transforms
Module containing custom transforms to be used with torchvision.transforms.
- class AutoNorm(dataset: RasterDataset, length: int = 128, roi: BoundingBox | None = None, inplace=False)
Transform that automatically calculates the mean and standard deviation of the dataset and uses them to normalise the data.
Uses torchvision.transforms.Normalize for the normalisation.
- dataset
Dataset to calculate the mean and standard deviation of.
- Type:
RasterDataset
- sampler
Sampler used to create valid queries for the dataset to find data files.
- Type:
RandomGeoSampler
- Parameters:
dataset (RasterDataset) – Dataset to calculate the mean and standard deviation of.
length (int) – Optional; Number of samples to draw from the dataset when calculating the mean and standard deviation. Default 128.
roi (BoundingBox) – Optional; Region of interest for the sampler to sample from.
inplace (bool) – Optional; Performs the normalisation transform in place on the tensor. Default False.
Added in version 0.26.
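The statistic-gathering step can be sketched without torchgeo. This is a minimal illustration under stated assumptions, not minerva's implementation: each sample is a list of channels, each channel a flat list of pixel values, and the standard deviation is the population value pooled over every sampled pixel. `channel_stats` and `normalise` are hypothetical names; the second applies the per-channel `(x - mean) / std` that torchvision.transforms.Normalize computes.

```python
import math

# Hypothetical layout: a sample is a list of channels, each a flat list of pixels.
Sample = list[list[float]]


def channel_stats(samples: list[Sample]) -> tuple[list[float], list[float]]:
    """Return per-channel mean and population standard deviation over all samples."""
    n_channels = len(samples[0])
    means, stds = [], []
    for c in range(n_channels):
        # Pool every pixel of channel c across all sampled patches.
        values = [v for sample in samples for v in sample[c]]
        mean = sum(values) / len(values)
        var = sum((v - mean) ** 2 for v in values) / len(values)
        means.append(mean)
        stds.append(math.sqrt(var))
    return means, stds


def normalise(sample: Sample, means: list[float], stds: list[float]) -> Sample:
    """Apply (x - mean) / std to each channel, as torchvision's Normalize does."""
    return [
        [(v - m) / s for v in channel]
        for channel, m, s in zip(sample, means, stds)
    ]
```

For example, the two 2-channel samples `[[1, 3], [10, 10]]` and `[[5, 7], [20, 20]]` give means `[4.0, 15.0]` and standard deviations `[sqrt(5), 5.0]`.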
- class ClassTransform(transform: dict[int, int])
Transform to be applied to a mask to convert from one labelling schema to another.
- forward(mask: LongTensor) LongTensor
Transforms the given mask from the original label schema to the new one.
- Parameters:
mask (LongTensor) – Mask in the original label schema.
- Returns:
Mask transformed into new label schema.
- Return type:
LongTensor
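A mapping such as `{0: 0, 1: 5, 2: 5}` merges old classes 1 and 2 into new class 5. The sketch below shows that relabelling on a mask stored as a 2D list of ints; `apply_class_transform` is a hypothetical helper, and raising `KeyError` for labels missing from the mapping is an assumption, not documented behaviour.

```python
def apply_class_transform(mask: list[list[int]], transform: dict[int, int]) -> list[list[int]]:
    """Relabel every pixel in mask using an old-label -> new-label mapping."""
    # Assumption: every label in the mask must be a key of transform;
    # a missing label raises KeyError here.
    return [[transform[label] for label in row] for row in mask]
```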
- class DetachedColorJitter(*args, **kwargs)
Sends the RGB channels of multi-spectral images to be transformed by ColorJitter.
- forward(img: Tensor) Tensor
Detaches the RGB channels of the input image to be sent to ColorJitter. All other channels bypass ColorJitter and are concatenated onto the colour-jittered RGB channels.
- Parameters:
img (Tensor) – Input image.
- Raises:
ValueError – If the number of channels of input img is 2.
- Returns:
Colour-jittered image.
- Return type:
Tensor
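The split-and-concatenate idea can be sketched with channels stored as lists. This assumes the RGB bands are the first 3 channels and passes 1- and 3-channel images to the jitter whole; `detached_jitter` and its `jitter` argument are hypothetical stand-ins for the class and ColorJitter. The 2-channel check mirrors the documented ValueError, which fits ColorJitter only accepting 1- or 3-channel tensor images.

```python
from typing import Callable

Channel = list[float]


def detached_jitter(img: list[Channel], jitter: Callable[[list[Channel]], list[Channel]]) -> list[Channel]:
    """Jitter the first 3 (RGB) channels and pass any remaining channels through untouched."""
    if len(img) == 2:
        # Mirrors the documented ValueError for 2-channel inputs.
        raise ValueError("2-channel images are not supported")
    if len(img) <= 3:
        # Assumption: 1- or 3-channel images are sent to the jitter whole.
        return jitter(img)
    rgb, rest = img[:3], img[3:]
    # Concatenate the untouched channels back onto the jittered RGB channels.
    return jitter(rgb) + rest
```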
- class MinervaCompose(transforms: list[Callable[[...], Any]] | Callable[[...], Any] | dict[str, list[Callable[[...], Any]] | Callable[[...], Any]], change_detection: bool = False)
Adaptation of torchvision.transforms.Compose. Composes several transforms together.
Designed to work with both Tensor inputs and torchgeo sample dicts.
This transform does not support torchscript.
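- Parameters:
transforms (list[Callable] | Callable | dict[str, list[Callable] | Callable]) – Transform, list of transforms, or dict of transforms (or lists of transforms) to compose.
change_detection (bool) – Optional; Flag for if transforming a change detection dataset, which has "image1" and "image2" keys rather than "image". Default False.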
Example
>>> transforms.MinervaCompose([
...     transforms.CenterCrop(10),
...     transforms.PILToTensor(),
...     transforms.ConvertImageDtype(torch.float),
... ])
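A minimal sketch of the composition logic, assuming that for a sample dict only the `"image"` key (or `"image1"` and `"image2"` when `change_detection` is set) is transformed. `ComposeSketch` is hypothetical and does not model the dict form of `transforms`.

```python
from __future__ import annotations

from typing import Any, Callable

Transform = Callable[[Any], Any]


class ComposeSketch:
    """Hypothetical stand-in for MinervaCompose: applies transforms in order."""

    def __init__(self, transforms: list[Transform] | Transform, change_detection: bool = False) -> None:
        # Accept a single callable as well as a list, as the real signature does.
        self.transforms = transforms if isinstance(transforms, list) else [transforms]
        self.change_detection = change_detection

    def _apply(self, x: Any) -> Any:
        for t in self.transforms:
            x = t(x)
        return x

    def __call__(self, sample: Any) -> Any:
        if isinstance(sample, dict):
            # Assumption: only the image key(s) of a sample dict are transformed.
            keys = ("image1", "image2") if self.change_detection else ("image",)
            out = dict(sample)
            for key in keys:
                out[key] = self._apply(out[key])
            return out
        # Plain (tensor-like) inputs are transformed directly.
        return self._apply(sample)
```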
- class Normalise(norm_value: int)
Transform that normalises an image tensor based on the bit size.
- Parameters:
norm_value (int) – Value to normalise the image with.
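Assuming the normalisation is a division by `norm_value`, an 8-bit image would use `2**8 - 1 = 255` to map its full range onto [0, 1]. The helper below is a hypothetical sketch on a flat list of pixel values.

```python
def normalise_by_value(img: list[float], norm_value: int) -> list[float]:
    """Scale pixel values by dividing by norm_value (assumed behaviour)."""
    # e.g. norm_value = 255 maps 8-bit values 0..255 onto 0.0..1.0
    return [v / norm_value for v in img]
```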
- class PairCreate
Transform that takes a sample and returns a pair of the same sample.
- class SeasonTransform(season: str = 'random')
Configure which seasons from a patch are passed through to the model.
Adapted from source: https://github.com/zhu-xlab/SSL4EO-S12/tree/main
- Parameters:
season (str) – How to handle which seasons to return:
"pair" – Randomly pick 2 seasons to return that will form a pair.
"random" – Randomly pick a single season to return. Default.
Added in version 0.28.
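A sketch of the two modes, assuming the patch's first axis indexes seasons and that `"pair"` draws two distinct seasons. `select_seasons` is a hypothetical name, and rejecting unknown modes with `ValueError` is an assumption of this sketch.

```python
from __future__ import annotations

import random
from typing import Any, Sequence


def select_seasons(patch: Sequence[Any], season: str = "random", rng: random.Random | None = None) -> Any:
    """Pick seasons from a patch whose first axis indexes seasons (assumed layout)."""
    rng = rng or random.Random()
    if season == "pair":
        # Two distinct seasons of the same patch, forming a pair.
        first, second = rng.sample(range(len(patch)), 2)
        return patch[first], patch[second]
    if season == "random":
        # One season chosen uniformly at random.
        return patch[rng.randrange(len(patch))]
    raise ValueError(f"Unknown season mode: {season!r}")
```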
- class SelectChannels(channels: list[int])
Transform to select which channels to keep by passing a list of indices.
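Channel selection amounts to indexing the channel axis with the given list; as with tensor indexing, the order of `channels` sets the output order. A sketch with channels on the first axis of a nested list (`select_channels` is hypothetical):

```python
def select_channels(img: list[list[float]], channels: list[int]) -> list[list[float]]:
    """Keep only the channels at the given indices, in the order given."""
    return [img[i] for i in channels]
```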
- class SingleLabel(mode: str = 'modal')
Reduces a mask to a single label using the transform mode provided.
- Parameters:
mode (str) – Mode of operation. Currently only supports "modal" or "centre".
Added in version 0.22.
- forward(mask: LongTensor) LongTensor
Forward pass of the transform, reducing the input mask to a single label.
- Parameters:
mask (LongTensor) – Input mask to reduce to a single label.
- Raises:
NotImplementedError – If mode is not "modal" or "centre".
- Returns:
The single label as a 0D, 1-element tensor.
- Return type:
LongTensor
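Both modes can be sketched on a 2D list mask. This returns a plain int rather than a 0D tensor, and two details are assumptions: ties in `"modal"` go to the label seen first (the behaviour of `collections.Counter.most_common`), and `"centre"` takes index `n // 2` on each axis, which picks the later of the two middle indices on an even-length axis. `single_label` is a hypothetical name.

```python
from collections import Counter


def single_label(mask: list[list[int]], mode: str = "modal") -> int:
    """Reduce a 2D mask to one label using the "modal" or "centre" mode."""
    if mode == "modal":
        # Most frequent label across all pixels.
        counts = Counter(label for row in mask for label in row)
        return counts.most_common(1)[0][0]
    if mode == "centre":
        # Label of the pixel at the centre of the mask.
        return mask[len(mask) // 2][len(mask[0]) // 2]
    raise NotImplementedError(f"Mode {mode!r} is not supported")
```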
- class SwapKeys(from_key: str, to_key: str)
Transform to set one key in a torchgeo sample dict to another.
Useful for testing autoencoders that predict their input.
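- Parameters:
from_key (str) – Key in the sample dict whose value is copied.
to_key (str) – Key in the sample dict to set to that value.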
Added in version 0.22.
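For an autoencoder, the target can be made equal to the input by copying the image into the mask key. The sketch assumes the value under `from_key` is copied to `to_key` and that a new dict is returned; `swap_keys` is hypothetical, and the real transform may modify the sample in place.

```python
from typing import Any


def swap_keys(sample: dict[str, Any], from_key: str, to_key: str) -> dict[str, Any]:
    """Return a copy of sample with to_key set to the value stored under from_key."""
    out = dict(sample)  # shallow copy so the caller's dict is left untouched
    out[to_key] = out[from_key]
    return out
```

A typical call for an autoencoder would be `swap_keys(sample, "image", "mask")`.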
- class ToRGB(channels: tuple[int, int, int] | None = None)
Reduces the number of channels down to RGB.
- channels
Optional; Tuple defining which channels in expected input images contain the RGB bands. If None, it is assumed that the RGB bands are in the first 3 channels.
- Parameters:
channels (tuple[int, int, int]) – Optional; Tuple defining which channels in expected input images contain the RGB bands. If None, it is assumed that the RGB bands are in the first 3 channels.
Added in version 0.22.
- forward(img: Tensor) Tensor
Performs a forward pass of the transform, returning an RGB image.
- Parameters:
img (Tensor) – Image to convert to RGB.
- Returns:
Image of only the RGB channels of img.
- Return type:
Tensor
- Raises:
ValueError – If img has fewer channels than specified in channels.
ValueError – If img has fewer than 3 channels and channels is None.
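A sketch of the selection and the two documented error cases, with channels on the first axis of a nested list. Reading "fewer channels than specified in channels" as any index in `channels` being out of range is an assumption; `to_rgb` is a hypothetical name.

```python
from __future__ import annotations


def to_rgb(img: list[list[float]], channels: tuple[int, int, int] | None = None) -> list[list[float]]:
    """Return only the RGB channels of img (channels on the first axis)."""
    if channels is None:
        if len(img) < 3:
            raise ValueError("img has fewer than 3 channels and channels is None")
        # Default: the RGB bands are the first 3 channels.
        return img[:3]
    if len(img) <= max(channels):
        # An index in channels points past the last channel of img.
        raise ValueError("img has fewer channels than specified in channels")
    return [img[i] for i in channels]
```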
- get_transform(transform_params: dict[str, Any]) Callable[[...], Any]
Creates a transform object based on config parameters.
- Parameters:
transform_params (dict[str, Any]) – Arguments to construct the transform with. Should also include a "_target_" key defining the import path to the transform object.
- Returns:
Initialised transform object specified by config parameters.
Example
>>> params = {"_target_": "torchvision.transforms.RandomResizedCrop", "size": 128}
>>> transform = get_transform(params)
- Raises:
TypeError – If the created transform object is not itself Callable.
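One way to resolve a `"_target_"` import path is with `importlib`, sketched below; this is not necessarily how minerva does it (Hydra's `instantiate` follows the same `_target_` convention). `get_transform_sketch` is hypothetical; every key other than `"_target_"` is passed as a keyword argument, and a non-callable result raises TypeError as documented.

```python
import importlib
from typing import Any, Callable


def get_transform_sketch(transform_params: dict[str, Any]) -> Callable[..., Any]:
    """Build an object from a config dict whose "_target_" key is an import path."""
    params = dict(transform_params)  # avoid mutating the caller's config
    module_path, _, name = params.pop("_target_").rpartition(".")
    # Import the module and look up the class or function named at the end of the path.
    target = getattr(importlib.import_module(module_path), name)
    transform = target(**params)  # remaining keys become keyword arguments
    if not callable(transform):
        raise TypeError(f"{transform!r} is not callable")
    return transform
```

With torchvision installed, the `RandomResizedCrop` config from the example would resolve through the same path.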
- init_auto_norm(dataset: RasterDataset, length: int = 128, roi: BoundingBox | None = None, inplace=False) RasterDataset
Uses minerva.transforms.AutoNorm to automatically find the mean and standard deviation of dataset, creating a normalisation transform that is then added to the existing transforms of dataset.
- Parameters:
- Returns:
dataset with an additional minerva.transforms.AutoNorm transform added to its torchgeo.datasets.RasterDataset.transforms attribute.
- Return type:
RasterDataset
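How the new transform joins the existing ones can be sketched with a stub dataset. This assumes the normalisation runs after any transforms already on the dataset; `DatasetStub` and `add_transform` are hypothetical, and the real function may wrap the transforms in a MinervaCompose instead of a closure.

```python
from typing import Any, Callable, Optional


class DatasetStub:
    """Hypothetical stand-in for a dataset with an optional transforms attribute."""

    def __init__(self, transforms: Optional[Callable[[Any], Any]] = None) -> None:
        self.transforms = transforms


def add_transform(dataset: DatasetStub, new: Callable[[Any], Any]) -> DatasetStub:
    """Chain new after any existing transforms and return the same dataset."""
    existing = dataset.transforms
    if existing is None:
        dataset.transforms = new
    else:
        # Run the existing transforms first, then the new (e.g. normalisation) one.
        dataset.transforms = lambda sample: new(existing(sample))
    return dataset
```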
- make_transformations(transform_params: dict[str, Any] | Literal[False], change_detection: bool = False) MinervaCompose | None
Constructs a transform or series of transforms based on the parameters provided.
- Parameters:
transform_params (dict[str, Any] | Literal[False]) – Parameters defining the desired transforms. The name of each transform should be the key, while the kwargs for the transform should be the value of that key as a dict.
change_detection (bool) – Optional; Flag for if transforming a change detection dataset, which has "image1" and "image2" keys rather than "image". Default False.
Example
>>> transform_params = {
...     "crop": {"_target_": "torchvision.transforms.CenterCrop", "size": 128},
...     "flip": {"_target_": "torchvision.transforms.RandomHorizontalFlip", "p": 0.7},
... }
>>> transforms = make_transformations(transform_params)
- Returns:
If no parameters are passed, None is returned. If only one transform is defined by the parameters, a single transform object is returned. If multiple transforms are defined, a Compose object of the transform objects is returned.
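The documented return rules can be sketched as follows. `build` is a hypothetical hook standing in for constructing one transform from its kwargs dict (the role get_transform plays for a single transform), and `compose` is a stand-in for MinervaCompose; dict insertion order is assumed to set the order the transforms run in.

```python
from __future__ import annotations

from typing import Any, Callable, Literal


def compose(transforms: list[Callable[[Any], Any]]) -> Callable[[Any], Any]:
    """Chain transforms left to right (stand-in for MinervaCompose)."""
    def composed(x: Any) -> Any:
        for t in transforms:
            x = t(x)
        return x
    return composed


def make_transformations_sketch(
    transform_params: dict[str, Any] | Literal[False],
    build: Callable[[dict[str, Any]], Callable[[Any], Any]],
) -> Callable[[Any], Any] | None:
    """Build one transform per named entry; None if there are no parameters."""
    if not transform_params:
        return None
    # Insertion order of the dict decides the order transforms are applied in.
    transforms = [build(params) for params in transform_params.values()]
    if len(transforms) == 1:
        return transforms[0]
    return compose(transforms)
```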