transforms

Module containing custom transforms to be used with torchvision.transforms.

class AutoNorm(dataset: RasterDataset, length: int = 128, roi: BoundingBox | None = None, inplace=False)

Transform that automatically calculates the mean and standard deviation of the dataset and uses them to normalise the data.

Uses torchvision.transforms.Normalize for the normalisation.

dataset

Dataset to calculate the mean and standard deviation of.

Type:

RasterDataset

sampler

Sampler used to create valid queries for the dataset to find data files.

Type:

RandomGeoSampler

Parameters:
  • dataset (RasterDataset) – Dataset to calculate the mean and standard deviation of.

  • length (int) – Optional; Number of samples from the dataset to calculate the mean and standard deviation of.

  • roi (BoundingBox) – Optional; Region of interest for sampler to sample from.

  • inplace (bool) – Optional; Performs the normalisation transform inplace on the tensor. Default False.

Added in version 0.26.
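The underlying idea can be sketched in plain Python. This is a hypothetical re-implementation for illustration only: the real transform draws length samples from the dataset and delegates the per-value operation to torchvision.transforms.Normalize.

```python
import statistics

def compute_norm_stats(samples: list[list[float]]) -> tuple[list[float], list[float]]:
    """Compute per-channel mean and standard deviation from sampled pixel values.

    ``samples`` is a stand-in for pixels drawn from ``length`` patches of the
    dataset: one list of pixel values per channel.
    """
    means = [statistics.fmean(channel) for channel in samples]
    stds = [statistics.pstdev(channel) for channel in samples]
    return means, stds

def normalise(pixel: float, mean: float, std: float) -> float:
    """The same operation torchvision.transforms.Normalize applies per value."""
    return (pixel - mean) / std

# Two channels of sampled pixel values.
channels = [[0.0, 2.0, 4.0], [10.0, 12.0, 14.0]]
means, stds = compute_norm_stats(channels)
# means == [2.0, 12.0]
```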

class ClassTransform(transform: dict[int, int])

Transform to be applied to a mask to convert from one labelling schema to another.

transform

Mapping from one labelling schema to another.

Type:

dict[int, int]

Parameters:

transform (dict[int, int]) – Mapping from one labelling schema to another.

forward(mask: LongTensor) → LongTensor

Transforms the given mask from the original label schema to the new.

Parameters:

mask (LongTensor) – Mask in the original label schema.

Returns:

Mask transformed into new label schema.

Return type:

LongTensor
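The remapping amounts to an element-wise dictionary lookup. A minimal pure-Python sketch of the idea (the real transform operates on a LongTensor):

```python
def remap_labels(mask: list[list[int]], mapping: dict[int, int]) -> list[list[int]]:
    """Element-wise lookup converting each label to the new schema."""
    return [[mapping[label] for label in row] for row in mask]

# Collapse classes 1 and 2 into class 1, keep class 0.
old_mask = [[0, 1], [2, 0]]
new_mask = remap_labels(old_mask, {0: 0, 1: 1, 2: 1})
# new_mask == [[0, 1], [1, 0]]
```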

class DetachedColorJitter(*args, **kwargs)

Sends RGB channels of multi-spectral images to be transformed by ColorJitter.

forward(img: Tensor) → Tensor

Detaches RGB channels of input image to be sent to ColorJitter.

All other channels bypass ColorJitter and are concatenated onto the colour jittered RGB channels.

Parameters:

img (Tensor) – Input image.

Raises:

ValueError – If number of channels of input img is 2.

Returns:

Color jittered image.

Return type:

Tensor
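The split-jitter-recombine pattern can be sketched without torchvision. Here channel lists stand in for tensor channels, and jitter is a hypothetical placeholder for the wrapped ColorJitter:

```python
def detached_jitter(img: list[list[float]], jitter) -> list[list[float]]:
    """Apply ``jitter`` to the first three (RGB) channels only; pass the rest through."""
    if len(img) == 2:
        raise ValueError("2-channel images cannot be split into RGB + extras")
    if len(img) < 3:
        return jitter(img)  # single-channel images go through whole
    rgb, extras = img[:3], img[3:]
    return jitter(rgb) + extras  # concatenate untouched channels back on

add_one = lambda channels: [[v + 1.0 for v in c] for c in channels]
out = detached_jitter([[1.0], [2.0], [3.0], [4.0]], add_one)
# out == [[2.0], [3.0], [4.0], [4.0]] (the 4th channel bypasses the jitter)
```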

class MinervaCompose(transforms: list[Callable[[...], Any]] | Callable[[...], Any] | dict[str, list[Callable[[...], Any]] | Callable[[...], Any]], change_detection: bool = False)

Adaptation of torchvision.transforms.Compose. Composes several transforms together.

Designed to work with both Tensor and torchgeo sample dicts.

This transform does not support torchscript.

transforms

List of composed transforms.

Type:

list[Callable[…, Any]] | Callable[…, Any]

Parameters:
  • transforms (Sequence[Callable[..., Any]] | Callable[..., Any]) – List of transforms to compose.

  • change_detection (bool) – Flag for if transforming a change detection dataset which has "image1" and "image2" keys rather than "image".

Example

>>> transforms.MinervaCompose([
...     transforms.CenterCrop(10),
...     transforms.PILToTensor(),
...     transforms.ConvertImageDtype(torch.float),
... ])

class Normalise(norm_value: int)

Transform that normalises an image tensor based on the bit size.

norm_value

Value to normalise image with.

Type:

int

Parameters:

norm_value (int) – Value to normalise image with.

forward(img: Tensor) → Tensor

Normalises the input image using norm_value.

Parameters:

img (Tensor) – Image tensor to be normalised. Should have a bit size that relates to norm_value.

Returns:

Input image tensor normalised by norm_value.

Return type:

Tensor
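In essence this is a division by the maximum value of the image's bit depth (e.g. 255 for 8-bit imagery). A pure-Python sketch of the idea:

```python
def normalise_by_bit_size(img: list[float], norm_value: int) -> list[float]:
    """Scale pixel values into [0, 1] by dividing by norm_value.

    norm_value should match the bit depth of the imagery,
    e.g. 255 for uint8 or 65535 for uint16 data.
    """
    return [pixel / norm_value for pixel in img]

scaled = normalise_by_bit_size([0.0, 51.0, 255.0], 255)
# scaled == [0.0, 0.2, 1.0]
```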

class PairCreate

Transform that takes a sample and returns a pair of the same sample.

static forward(sample: Any) → tuple[Any, Any]

Takes a sample and returns it and a copy as a tuple pair.

Parameters:

sample (Any) – Sample to duplicate.

Returns:

Tuple of two copies of the sample.

Return type:

tuple[Any, Any]

class SeasonTransform(season: str = 'random')

Configure what seasons from a patch are parsed through to the model.

Adapted from source: https://github.com/zhu-xlab/SSL4EO-S12/tree/main

Parameters:
  • season (str) – How to handle which seasons to return:

      "pair" – Randomly pick 2 seasons to return that will form a pair.

      "random" – Randomly pick a single season to return.

Added in version 0.28.
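The two modes can be sketched as follows. This is a hypothetical re-implementation for illustration; the real transform uses such choices to index the seasonal dimension of the input tensor:

```python
import random

def pick_seasons(n_seasons: int, season: str = "random") -> list[int]:
    """Return the indices of the season(s) to keep from a patch of n_seasons seasons."""
    if season == "pair":
        # Two distinct seasons that will form a pair (e.g. for contrastive learning).
        return random.sample(range(n_seasons), k=2)
    if season == "random":
        return [random.randrange(n_seasons)]
    raise ValueError(f"Unknown season mode: {season!r}")

indices = pick_seasons(4, "pair")  # two distinct indices in [0, 4)
```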

class SelectChannels(channels: list[int])

Transform to select which channels to keep by passing a list of indices.

channels

Channel indices to keep.

Type:

list[int]

Parameters:

channels (list[int]) – Channel indices to keep.

forward(img: Tensor) → Tensor

Select the desired channels from the input image and return.

Parameters:

img (Tensor) – Input image.

Returns:

Selected channels of the input image.

Return type:

Tensor
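Functionally this is index selection along the channel dimension. A pure-Python sketch (the real transform indexes a Tensor):

```python
def select_channels(img: list[list[float]], channels: list[int]) -> list[list[float]]:
    """Keep only the channels at the given indices."""
    return [img[i] for i in channels]

img = [[0.1], [0.2], [0.3], [0.4]]
subset = select_channels(img, [0, 2])
# subset == [[0.1], [0.3]]
```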

class SingleLabel(mode: str = 'modal')

Reduces a mask to a single label using transform mode provided.

mode

Mode of operation.

Type:

str

Parameters:

mode (str) – Mode of operation. Currently only supports "modal" or "centre".

Added in version 0.22.

forward(mask: LongTensor) → LongTensor

Forward pass of the transform, reducing the input mask to a single label.

Parameters:

mask (LongTensor) – Input mask to reduce to a single label.

Raises:

NotImplementedError – If mode is not "modal" or "centre".

Returns:

The single label as a 0D, 1-element tensor.

Return type:

LongTensor
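The two supported modes can be sketched in plain Python: "modal" takes the most frequent label in the mask, while "centre" takes the label at the central pixel. A hypothetical re-implementation for illustration (the real transform returns a LongTensor):

```python
from collections import Counter

def single_label(mask: list[list[int]], mode: str = "modal") -> int:
    """Reduce a 2D mask to a single label."""
    if mode == "modal":
        # Most common label across the whole mask.
        flat = [label for row in mask for label in row]
        return Counter(flat).most_common(1)[0][0]
    if mode == "centre":
        # Label at the central pixel.
        return mask[len(mask) // 2][len(mask[0]) // 2]
    raise NotImplementedError(f"Mode {mode!r} is not supported")

mask = [[1, 1, 2], [1, 3, 2], [2, 2, 2]]
# single_label(mask, "modal") == 2   (five 2s vs three 1s)
# single_label(mask, "centre") == 3  (the central pixel)
```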

class SwapKeys(from_key: str, to_key: str)

Transform to set one key in a torchgeo sample dict to another.

Useful for testing autoencoders to predict their input.

from_key

Key for the value to set to to_key.

Type:

str

to_key

Key to set the value from from_key to.

Type:

str

Parameters:
  • from_key (str) – Key for the value to set to to_key.

  • to_key (str) – Key to set the value from from_key to.

Added in version 0.22.

forward(sample: dict[str, Any]) → dict[str, Any]

Sets the value of to_key in sample to the value of from_key and returns the sample.

Parameters:

sample (dict[str, Any]) – Sample dict from torchgeo containing from_key.

Returns:

Sample with to_key set to the value of from_key.

Return type:

dict[str, Any]
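For example, when testing an autoencoder that should reconstruct its input, the sample's "image" can be copied into its "mask" key so the input doubles as the target. A minimal sketch:

```python
from typing import Any

def swap_keys(sample: dict[str, Any], from_key: str, to_key: str) -> dict[str, Any]:
    """Set sample[to_key] to the value of sample[from_key] and return the sample."""
    sample[to_key] = sample[from_key]
    return sample

sample = {"image": [1, 2, 3], "mask": None}
sample = swap_keys(sample, "image", "mask")
# sample["mask"] is now the same object as sample["image"]
```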

class ToRGB(channels: tuple[int, int, int] | None = None)

Reduces the number of channels down to RGB.

channels

Optional; Tuple defining which channels in expected input images contain the RGB bands. If None, it is assumed that the RGB bands are in the first 3 channels.

Type:

tuple[int, int, int]

Parameters:

channels (tuple[int, int, int]) – Optional; Tuple defining which channels in expected input images contain the RGB bands. If None, it is assumed that the RGB bands are in the first 3 channels.

Added in version 0.22.

forward(img: Tensor) → Tensor

Performs a forward pass of the transform, returning an RGB image.

Parameters:

img (Tensor) – Image to convert to RGB.

Returns:

Image of only the RGB channels of img.

Return type:

Tensor

get_transform(transform_params: dict[str, Any]) → Callable[[...], Any]

Creates a transform object based on config parameters.

Parameters:

transform_params (dict[str, Any]) – Arguments to construct the transform with. Should also include the "_target_" key defining the import path to the transform object.

Returns:

Initialised transform object specified by config parameters.

Example

>>> params = {"_target_": "torchvision.transforms.RandomResizedCrop", "size": 128}
>>> transform = get_transform(params)
Raises:

TypeError – If created transform object is itself not Callable.

init_auto_norm(dataset: RasterDataset, length: int = 128, roi: BoundingBox | None = None, inplace=False) → RasterDataset

Uses AutoNorm to automatically find the mean and standard deviation of dataset to create a normalisation transform that is then added to the existing transforms of dataset.

Parameters:
  • dataset (RasterDataset) – Dataset to find and apply the normalisation conditions to.

  • length (int) – Optional; Number of samples from the dataset to calculate the mean and standard deviation of.

  • roi (BoundingBox) – Optional; Region of interest for the sampler to sample from.

  • inplace (bool) – Optional; Performs the normalisation transform inplace on the tensor. Default False.

Returns:

dataset with an additional AutoNorm transform added to its transforms attribute.

Return type:

RasterDataset

make_transformations(transform_params: dict[str, Any] | Literal[False], change_detection: bool = False) → MinervaCompose | None

Constructs a transform or series of transforms based on parameters provided.

Parameters:
  • transform_params (dict[str, Any] | Literal[False]) – Parameters defining transforms desired. The name of each transform should be the key, while the kwargs for the transform should be the value of that key as a dict.

  • change_detection (bool) – Flag for if transforming a change detection dataset which has "image1" and "image2" keys rather than "image".

Example

>>> transform_params = {
...     "crop": {"_target_": "torchvision.transforms.CenterCrop", "size": 128},
...     "flip": {"_target_": "torchvision.transforms.RandomHorizontalFlip", "p": 0.7},
... }
>>> transforms = make_transformations(transform_params)
Returns:

If transform_params is False, None is returned. Otherwise, returns a MinervaCompose object composing the transform or series of transforms defined by the parameters.