Skip to content

Datasets

This page acts as the technical reference for the datasets subpackage.

The datasets subpackage provides dataset classes that can be used with the PyTorch `DataLoader` class to manage the data loading process.

FlatFieldArtefactMapDataset

Bases: MultiFileArtefactMapDataset

Source code in clair_torch/datasets/image_dataset.py (lines 38–81)
class FlatFieldArtefactMapDataset(MultiFileArtefactMapDataset):
    """
    Map-style dataset of calibration images. Currently, mainly used for flat-field correction.
    """
    @typechecked
    def __init__(self, files: tuple[FrameSettings | PairedFrameSettings, ...], copy_preloaded_data: bool = True,
                 missing_std_mode: MissingStdMode = MissingStdMode.CONSTANT, missing_std_value: float = 0.0,
                 attributes_to_match: dict[str, None | int | float] | None = None,
                 cache_size: int = 0, missing_val_mode: MissingValMode = MissingValMode.ERROR,
                 default_get_item_key: str = "raw"):
        """
        Dataset class for handling calibration images. Currently, mainly used for flat-field correction.
        Args:
            files: tuple of FrameSettings / PairedFrameSettings objects composing the dataset of calibration images.
            copy_preloaded_data: whether preloaded data is returned as copies or as references; forwarded to the
                parent class.
            missing_std_mode: how missing uncertainty images should be dealt with. Read more in .enums.MissingStdMode.
            missing_std_value: a constant that is used in a manner defined by the missing_std_mode to deal with missing
                uncertainty images.
            attributes_to_match: mapping of attribute names to match specifications, forwarded to the parent class.
                None (the default) selects {"magnification": None, "illumination": None}. The meaning of the values
                is defined by the parent class / FrameSettings.is_match — TODO confirm.
            cache_size: forwarded to the parent class.
            missing_val_mode: forwarded to the parent class.
            default_get_item_key: forwarded to the parent class.
        """
        # Build a fresh dict per call instead of using a shared mutable default argument.
        if attributes_to_match is None:
            attributes_to_match = {"magnification": None, "illumination": None}

        super().__init__(files=files, copy_preloaded_data=copy_preloaded_data, missing_std_mode=missing_std_mode,
                         missing_std_value=missing_std_value, attributes_to_match=attributes_to_match,
                         cache_size=cache_size, missing_val_mode=missing_val_mode,
                         default_get_item_key=default_get_item_key)

    def _get_matching_image_settings_idx(self, reference_image_settings: FrameSettings,
                                         matching_attributes: dict[str, int | float | None]) -> int | None:
        """
        Internal helper method for getting a matching FrameSettings object. If no matches are found, returns None.
        Args:
            reference_image_settings: the FrameSettings object for which to find a match.
            matching_attributes: the attributes that should match between the reference and one of the FrameSettings
                contained.

        Returns:
            The index of the first FrameSettings object in self.files that matches the reference, or None if none do.
        """
        # First managed file whose settings match the reference on the given attributes; None if there is none.
        return next(
            (i for i, image_settings in enumerate(self.files)
             if image_settings.is_match(reference_image_settings, attributes=matching_attributes)),
            None,
        )

__init__(files, copy_preloaded_data=True, missing_std_mode=MissingStdMode.CONSTANT, missing_std_value=0.0, attributes_to_match=None, cache_size=0, missing_val_mode=MissingValMode.ERROR, default_get_item_key='raw')

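Dataset class for handling calibration images. Currently, mainly used for flat-field correction. Args: files: tuple of FrameSettings objects composing the dataset of calibration images. copy_preloaded_data: whether preloaded data is returned as copies or as references (forwarded to the parent class). missing_std_mode: how missing uncertainty images should be dealt with. Read more in .enums.MissingStdMode. missing_std_value: a constant that is used in a manner defined by the missing_std_mode to deal with missing uncertainty images. attributes_to_match: attributes used for matching frames, forwarded to the parent class; defaults to {"magnification": None, "illumination": None} when not given. cache_size, missing_val_mode, default_get_item_key: forwarded to the parent class.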

Source code in clair_torch/datasets/image_dataset.py (lines 39–62)
@typechecked
def __init__(self, files: tuple[FrameSettings | PairedFrameSettings, ...], copy_preloaded_data: bool = True,
             missing_std_mode: MissingStdMode = MissingStdMode.CONSTANT, missing_std_value: float = 0.0,
             attributes_to_match: dict[str, None | int | float] | None = None,
             cache_size: int = 0, missing_val_mode: MissingValMode = MissingValMode.ERROR,
             default_get_item_key: str = "raw"):
    """
    Dataset class for handling calibration images. Currently, mainly used for flat-field correction.
    Args:
        files: tuple of FrameSettings / PairedFrameSettings objects composing the dataset of calibration images.
        copy_preloaded_data: whether preloaded data is returned as copies or as references; forwarded to the
            parent class.
        missing_std_mode: how missing uncertainty images should be dealt with. Read more in .enums.MissingStdMode.
        missing_std_value: a constant that is used in a manner defined by the missing_std_mode to deal with missing
            uncertainty images.
        attributes_to_match: mapping of attribute names to match specifications, forwarded to the parent class.
            None (the default) selects {"magnification": None, "illumination": None}. The meaning of the values is
            defined by the parent class / FrameSettings.is_match — TODO confirm.
        cache_size: forwarded to the parent class.
        missing_val_mode: forwarded to the parent class.
        default_get_item_key: forwarded to the parent class.
    """
    # Build a fresh dict per call instead of using a shared mutable default argument.
    if attributes_to_match is None:
        attributes_to_match = {"magnification": None, "illumination": None}

    super().__init__(files=files, copy_preloaded_data=copy_preloaded_data, missing_std_mode=missing_std_mode,
                     missing_std_value=missing_std_value, attributes_to_match=attributes_to_match,
                     cache_size=cache_size, missing_val_mode=missing_val_mode,
                     default_get_item_key=default_get_item_key)

ImageMapDataset

Bases: MultiFileMapDataset

Source code in clair_torch/datasets/image_dataset.py (lines 15–35)
class ImageMapDataset(MultiFileMapDataset):
    """
    Map-style dataset of image files built on MultiFileMapDataset.
    """
    @typechecked
    def __init__(self, files: tuple[FrameSettings | PairedFrameSettings, ...], copy_preloaded_data: bool = True,
                 missing_std_mode: MissingStdMode = MissingStdMode.CONSTANT,
                 missing_std_value: float = 0.0, default_get_item_key: str = "raw",
                 missing_val_mode: MissingValMode = MissingValMode.ERROR):
        """
        Master image data object. The files attribute holds the FileSettings-based objects of the dataset.
        Per the original design, image tensors have shape (C, H, W) (channels, height, width); batching through a
        DataLoader adds a leading N (number of images in the batch), giving (N, C, H, W).
        Args:
            files: tuple of the FileSettings-based objects composing the dataset.
            copy_preloaded_data: whether preloaded data is returned as a new copy or as a reference to the data held
                in self.preloaded_dataset.
            missing_std_mode: how missing uncertainty images should be dealt with. Read more in .enums.MissingStdMode.
            missing_std_value: constant used, in the manner defined by missing_std_mode, for missing uncertainty
                images.
            default_get_item_key: sort order used for plain integer indexing; passed to the parent as
                default_getitem_key.
            missing_val_mode: forwarded to the parent class.
        """
        # The parent spells the sort-key parameter "default_getitem_key", so map it explicitly.
        parent_kwargs = {
            "files": files,
            "copy_preloaded_data": copy_preloaded_data,
            "missing_std_mode": missing_std_mode,
            "missing_std_value": missing_std_value,
            "default_getitem_key": default_get_item_key,
            "missing_val_mode": missing_val_mode,
        }
        super().__init__(**parent_kwargs)

__init__(files, copy_preloaded_data=True, missing_std_mode=MissingStdMode.CONSTANT, missing_std_value=0.0, default_get_item_key='raw', missing_val_mode=MissingValMode.ERROR)

ImageMapDataset is the master image data object. The files attribute holds a tuple of FileSettings-based objects. The image tensor shapes are (C, H, W), that is, number of channels, height and width. Through a DataLoader, the shape is expanded into (N, C, H, W), with N standing for the number of images in the batch. Args: files: tuple of the FileSettings-based objects composing the dataset. copy_preloaded_data: whether preloaded data should be returned as a new copy or as a reference to the preloaded data contained in self.preloaded_dataset. missing_std_mode: how missing uncertainty images should be dealt with. Read more in .enums.MissingStdMode. missing_std_value: a constant that is used in a manner defined by the missing_std_mode to deal with missing uncertainty images. default_get_item_key: the sort order used for plain integer indexing. missing_val_mode: forwarded to the parent class.

Source code in clair_torch/datasets/image_dataset.py (lines 16–35)
@typechecked
def __init__(self, files: tuple[FrameSettings | PairedFrameSettings, ...], copy_preloaded_data: bool = True,
             missing_std_mode: MissingStdMode = MissingStdMode.CONSTANT,
             missing_std_value: float = 0.0, default_get_item_key: str = "raw",
             missing_val_mode: MissingValMode = MissingValMode.ERROR):
    """
    Master image data object. The files attribute holds the FileSettings-based objects of the dataset.
    Per the original design, image tensors have shape (C, H, W) (channels, height, width); batching through a
    DataLoader adds a leading N (number of images in the batch), giving (N, C, H, W).
    Args:
        files: tuple of the FileSettings-based objects composing the dataset.
        copy_preloaded_data: whether preloaded data is returned as a new copy or as a reference to the data held in
            self.preloaded_dataset.
        missing_std_mode: how missing uncertainty images should be dealt with. Read more in .enums.MissingStdMode.
        missing_std_value: constant used, in the manner defined by missing_std_mode, for missing uncertainty images.
        default_get_item_key: sort order used for plain integer indexing; passed to the parent as
            default_getitem_key.
        missing_val_mode: forwarded to the parent class.
    """
    # The parent spells the sort-key parameter "default_getitem_key", so map it explicitly.
    parent_kwargs = {
        "files": files,
        "copy_preloaded_data": copy_preloaded_data,
        "missing_std_mode": missing_std_mode,
        "missing_std_value": missing_std_value,
        "default_getitem_key": default_get_item_key,
        "missing_val_mode": missing_val_mode,
    }
    super().__init__(**parent_kwargs)

MultiFileIterDataset

Bases: IterableDataset, ABC

A generic base class for iterable-style Dataset classes. Dataset classes must manage files via a concrete implementation of the generic base FileSettings class.

Source code in clair_torch/datasets/base.py (lines 299–311)
class MultiFileIterDataset(IterableDataset, ABC):
    """
    Abstract base for iterable-style datasets that span several files. Concrete subclasses manage their files through
    implementations of the generic FileSettings base class and must provide __iter__.
    """

    def __init__(self, files: List[FrameSettings | PairedFrameSettings]):
        # Store the managed file settings on the instance for concrete subclasses to use.
        self.files = files

    @abstractmethod
    def __iter__(self):
        """Produce the items of the dataset; concrete subclasses must override this."""

MultiFileMapDataset

Bases: Dataset, ABC

A generic base class for map-style Dataset classes. Dataset classes must manage files via a concrete implementation of the generic base FileSettings class.

Source code in clair_torch/datasets/base.py (lines 20–172)
class MultiFileMapDataset(Dataset, ABC):
    """
    A generic base class for map-style Dataset classes. Dataset classes must manage files via a concrete implementation
    of the generic base FileSettings class.
    """
    @typechecked
    def __init__(self, files: tuple[FrameSettings | PairedFrameSettings, ...], copy_preloaded_data: bool = True,
                 missing_std_mode: MissingStdMode = MissingStdMode.CONSTANT, missing_std_value: float = 0.0,
                 default_getitem_key: str = "raw", missing_val_mode: MissingValMode = MissingValMode.ERROR):
        """
        Initialize the dataset and build the per-key sort orders used for indexing.
        Args:
            files: tuple of FrameSettings / PairedFrameSettings objects managed by the dataset.
            copy_preloaded_data: whether items served from the preloaded data are returned as copies (tensors cloned,
                metadata deep-copied) or as references to the stored objects.
            missing_std_mode: how missing uncertainty images are dealt with. Read more in .enums.MissingStdMode.
            missing_std_value: constant used, in the manner defined by missing_std_mode, to fill in missing
                uncertainty images.
            default_getitem_key: sort order used when the dataset is indexed with a plain integer. Must be "raw" or
                one of the numeric metadata keys of the files.
            missing_val_mode: stored as self.missing_val_mode.

        Raises:
            ValueError: if default_getitem_key is not an available sort key.
        """
        self.files = files
        self.preloaded_dataset = None
        self.copy_preloaded_data = copy_preloaded_data
        self.missing_std_mode = missing_std_mode
        self.missing_std_value = missing_std_value
        # Scalar tensor used to synthesize missing uncertainty images in _load_value_and_std_image.
        self.shared_std_tensor = torch.tensor(missing_std_value)
        self.missing_val_mode = missing_val_mode

        # Fetch each file's numeric metadata once, then build, for every metadata key, the list of file indices
        # sorted by that key. Keys are taken from the first file; an empty dataset only gets the "raw" order below
        # instead of failing on files[0].
        metadatas = [file_settings.get_numeric_metadata() for file_settings in self.files]
        metadata_keys = metadatas[0].keys() if metadatas else ()
        self.sorted_indices = {
            meta_key: sorted(range(len(metadatas)), key=lambda i, k=meta_key: metadatas[i][k])
            for meta_key in metadata_keys
        }

        # Insert a raw sorting order in the indices.
        self.sorted_indices["raw"] = list(range(len(self.files)))

        if default_getitem_key not in self.sorted_indices:
            raise ValueError(f"The default access key {default_getitem_key} is not found in the dataset.")
        self.default_getitem_key = default_getitem_key

    def __len__(self) -> int:
        """
        The length of the dataset is defined as the number of files it manages.
        Returns:
            int representing the number of files.
        """
        return len(self.files)

    def __getitem__(self, key) -> tuple[int, torch.Tensor, torch.Tensor | None, dict[str, float | int]]:
        """
        Main access method for the items of the dataset. Uses the preloaded data if available; otherwise loads the
        value image and optional uncertainty image of the resolved file from disk (applying the file's transforms),
        together with its numeric metadata.
        Args:
            key: either an int index, interpreted in the default sort order (self.default_getitem_key), or a tuple
                (index, meta_key), interpreted in the sort order of the given key.

        Returns:
            A tuple (int, tensor, tensor | None, dict[str, float | int]), representing the index of the file in
            self.files, the value image, the optional uncertainty image and the numeric metadata dictionary.

        Raises:
            ValueError: if a tuple key does not have exactly two elements.
            KeyError: if meta_key is not an available sort key.
        """
        # Support (idx, meta_key) indexing
        if isinstance(key, tuple):
            if len(key) != 2:
                raise ValueError("Expected (index, meta_key) for tuple indexing.")
            idx, meta_key = key
        else:
            idx, meta_key = key, self.default_getitem_key

        # Resolve the position in the requested sort order into an index of self.files.
        sorted_idx = self.sorted_indices[meta_key][idx]

        # Utilize the preloaded dataset if it exists. Its lists are indexed by file index (see preload_dataset).
        if self.preloaded_dataset is not None:
            indices, val_tensors, std_tensors, numeric_metadatas = self.preloaded_dataset
            std_tensor = std_tensors[sorted_idx]
            if self.copy_preloaded_data:
                # The std image may be None (MissingStdMode.NONE), which has no .clone().
                return (indices[sorted_idx], val_tensors[sorted_idx].clone(),
                        None if std_tensor is None else std_tensor.clone(),
                        deepcopy(numeric_metadatas[sorted_idx]))
            return indices[sorted_idx], val_tensors[sorted_idx], std_tensor, numeric_metadatas[sorted_idx]

        # Load images normally
        val_image, std_image, numeric_metadatas = self._load_value_and_std_image(sorted_idx)

        return sorted_idx, val_image, std_image, numeric_metadatas

    def _load_value_and_std_image(self, idx: int) \
            -> tuple[torch.Tensor, torch.Tensor | None, dict[str, float | int]]:
        """
        Shared image loading function that handles the loading of both the value images and uncertainty images and the
        numeric metadata.
        Args:
            idx: index of the managed file to load images off of.

        Returns:
            Tuple of (tensor, tensor | None, dict) representing the value image, possible uncertainty image and numeric
            metadata dictionary.

        Raises:
            ValueError: if no std path exists and self.missing_std_mode is not a supported MissingStdMode.
        """
        file_settings = self.files[idx]
        input_paths = file_settings.get_input_paths()
        transforms = file_settings.get_transforms()

        # Paired settings return (value, std) tuples; single settings return only the value part.
        if isinstance(transforms, tuple):
            val_transforms, std_transforms = transforms
        else:
            val_transforms = transforms
            std_transforms = None

        if isinstance(input_paths, tuple):
            val_path, std_path = input_paths
        else:
            val_path = input_paths
            std_path = None

        val_image = load_image(val_path, val_transforms)
        if std_path is not None:
            std_image = load_image(std_path, std_transforms)
        else:
            # No std file: synthesize the uncertainty image according to missing_std_mode.
            if self.missing_std_mode == MissingStdMode.NONE:
                std_image = None
            elif self.missing_std_mode == MissingStdMode.CONSTANT:
                std_image = self.shared_std_tensor.expand_as(val_image)
            elif self.missing_std_mode == MissingStdMode.MULTIPLIER:
                std_image = val_image * self.shared_std_tensor
            else:
                raise ValueError(f"Unsupported MissingStdMode: {self.missing_std_mode}")

        numeric_metadata = file_settings.get_numeric_metadata()

        return val_image, std_image, numeric_metadata

    def preload_dataset(self) -> None:
        """
        Loads all data into memory through __getitem__ and stores it in self.preloaded_dataset as a tuple of lists
        (indices, value images, uncertainty images, numeric metadata dictionaries), in raw file order.
        """
        indices = []
        val_tensors = []
        std_tensors = []
        numeric_metadata = []

        for i in range(len(self)):
            # Load in raw file order so that position i of every list holds file i. __getitem__ indexes the preloaded
            # lists with the resolved file index; loading in the default sort order would return the wrong items
            # whenever default_getitem_key is not "raw".
            index, val_tensor, std_tensor, metadata = self[(i, "raw")]
            indices.append(index)
            val_tensors.append(val_tensor)
            std_tensors.append(std_tensor)
            numeric_metadata.append(metadata)

        self.preloaded_dataset = (indices, val_tensors, std_tensors, numeric_metadata)

    def _get_closest_frame_settings_idx(self, reference_frame_settings: FrameSettings, attribute: str) -> int | None:
        """
        Method for getting the closest possible match from the dataset for a given reference FrameSettings instance as
        determined by the given attribute.
        Args:
            reference_frame_settings: the FrameSettings instance for which to find the closest match.
            attribute: the attribute based on which the match is searched for.

        Returns:
            The index of the closest matching FrameSettings instance in this dataset, or None if dataset is empty.
        """

        ...

__getitem__(key)

Main access method for the items of the dataset. If preloaded data is available, it is used. Otherwise, the value image and optional uncertainty image of the resolved file are loaded from disk (with OpenCV), and the file's transforms are applied. The file's numeric metadata is returned with them. This method should be used as the main way to access the tensors.

Args:
- key: either an integer index, interpreted in the default sort order, or an (index, meta_key) tuple, interpreted in the sort order of the given numeric metadata key.

Returns:

Type Description
tuple[int, Tensor, Tensor | None, dict[str, float | int]]

A tuple (int, tensor, tensor | None, dict[str, float | int]) representing the file index, the value image, the optional uncertainty image and the numeric metadata dictionary.

Source code in clair_torch/datasets/base.py (lines 59–94)
def __getitem__(self, key) -> tuple[int, torch.Tensor, torch.Tensor | None, dict[str, float | int]]:
    """
    This method loads images from disk with OpenCV, converts them to PyTorch tensors, runs them through the given
    transformations, finally returning the image tensor and a scalar tensor of the exposure time. It also falls
    back on the preloaded tensors if they are available. This method should be used as the main way to access the
    tensors.
    Args:
        key: index of the item to get.

    Returns:
        A tuple (tensor, tensor | None, dict[str, float | int]), representing the value image, optional uncertainty
        image and a numeric metadata dictionary.
    """
    # Support (idx, meta_key) indexing
    if isinstance(key, tuple):
        if len(key) != 2:
            raise ValueError("Expected (index, meta_key) for tuple indexing.")
        idx, meta_key = key
    else:
        idx, meta_key = key, self.default_getitem_key

    # Resolve sorted index
    sorted_idx = self.sorted_indices[meta_key][idx]

    # Utilize the preloaded dataset if it exists.
    if self.preloaded_dataset is not None:
        indices, val_tensors, std_tensors, numeric_metadatas = self.preloaded_dataset
        if self.copy_preloaded_data:
            return indices[sorted_idx], val_tensors[sorted_idx].clone(), std_tensors[sorted_idx].clone(), deepcopy(numeric_metadatas[sorted_idx])
        else:
            return indices[sorted_idx], val_tensors[sorted_idx], std_tensors[sorted_idx], numeric_metadatas[sorted_idx]

    # Load images normally
    val_image, std_image, numeric_metadatas = self._load_value_and_std_image(sorted_idx)

    return sorted_idx, val_image, std_image, numeric_metadatas

__len__()

The length of the dataset is defined as the number of files it manages. Returns: int representing the number of files.

Source code in clair_torch/datasets/base.py (lines 51–57)
def __len__(self) -> int:
    """
    Size of the dataset, i.e. how many files it manages.
    Returns:
        The number of entries in self.files.
    """
    managed_files = self.files
    return len(managed_files)

preload_dataset()

Loads all data from disk into memory and stores it as a tuple of lists in self.preloaded_dataset. This method utilizes the `__getitem__` method.

Source code in clair_torch/datasets/base.py (lines 141–158)
def preload_dataset(self) -> None:
    """
    Loads all data into memory through __getitem__ and stores it in self.preloaded_dataset as a tuple of lists
    (indices, value images, uncertainty images, numeric metadata dictionaries), in raw file order.
    """
    indices = []
    val_tensors = []
    std_tensors = []
    numeric_metadata = []

    for i in range(len(self)):
        # Load in raw file order so that position i of every list holds file i. __getitem__ indexes the preloaded
        # lists with the resolved file index; loading in the default sort order would return the wrong items
        # whenever default_getitem_key is not "raw".
        index, val_tensor, std_tensor, metadata = self[(i, "raw")]
        indices.append(index)
        val_tensors.append(val_tensor)
        std_tensors.append(std_tensor)
        numeric_metadata.append(metadata)

    self.preloaded_dataset = (indices, val_tensors, std_tensors, numeric_metadata)

VideoIterableDataset

Bases: IterableDataset

Source code in clair_torch/datasets/video_frame_dataset.py (lines 13–75)
class VideoIterableDataset(torch.utils.data.IterableDataset):
    """
    Iterable-style dataset that streams the frames of one or more video files as a single dataset.
    """
    def __init__(self, frame_settings: List[FrameSettings], missing_std_mode=MissingStdMode.CONSTANT,
                 missing_std_value=0.0):
        """
        # TODO: add handling for std files along the main value files.
        Dataset class for video files. Treats all encompassed files as a single dataset, jumping smoothly from one file
        to the next upon exhausting the frames from one file.
        Args:
            frame_settings: list of FrameSettings objects composing the dataset.
            missing_std_mode: enum flag determining how missing uncertainty images should be handled.
            missing_std_value: a constant that is used in a manner defined by the missing_std_mode to deal with missing
                uncertainty images.
        """
        self.frame_settings = frame_settings
        self.missing_std_mode = missing_std_mode
        self.missing_std_value = missing_std_value
        # Scalar tensor used to synthesize uncertainty images, since std videos are not loaded yet.
        self.shared_std_tensor = torch.tensor(missing_std_value)
        # Count of frames yielded in the current iteration pass; reset at the start of every pass in __iter__.
        self._running_index = 0

    def __len__(self):
        """
        Total number of frames, summed from the "number_of_frames" numeric metadata entry of every file.
        """
        return sum(frame_setting.get_numeric_metadata()["number_of_frames"] for frame_setting in self.frame_settings)

    def __iter__(self) -> Tuple[int, torch.Tensor, torch.Tensor | None, dict[str, float | int]]:
        """
        Access method for the frames of this dataset. Iterates through the files and frames, moving on to the next file
        upon exhausting a file.
        Yields:
            Tuples (running index, frame tensor, uncertainty tensor or None, numeric metadata dictionary). The running
            index counts frames from 1 and restarts at every new iteration pass.

        Raises:
            ValueError: if self.missing_std_mode is not a supported MissingStdMode.
        """
        # Restart the count so that every pass (e.g. every epoch) yields the same indices instead of continuing
        # from where the previous pass stopped.
        self._running_index = 0

        for settings in self.frame_settings:
            input_paths = settings.get_input_paths()
            transforms = settings.get_transforms()

            # Paired settings return (value, std) tuples. Only the value parts are used for now: std videos are not
            # loaded yet (see TODO in __init__), so uncertainty images are synthesized according to missing_std_mode.
            if isinstance(input_paths, tuple):
                val_path, _ = input_paths
            else:
                val_path = input_paths

            if isinstance(transforms, tuple):
                val_transforms, _ = transforms
            else:
                val_transforms = transforms

            metadata = settings.get_numeric_metadata()

            for frame in load_video_frames_generator(val_path, val_transforms):
                if self.missing_std_mode == MissingStdMode.NONE:
                    std_image = None
                elif self.missing_std_mode == MissingStdMode.CONSTANT:
                    std_image = self.shared_std_tensor.expand_as(frame)
                elif self.missing_std_mode == MissingStdMode.MULTIPLIER:
                    std_image = frame * self.shared_std_tensor
                else:
                    raise ValueError(f"Unsupported MissingStdMode: {self.missing_std_mode}")

                self._running_index += 1
                yield self._running_index, frame, std_image, metadata

__init__(frame_settings, missing_std_mode=MissingStdMode.CONSTANT, missing_std_value=0.0)

TODO: add handling for std files along the main value files.

Dataset class for video files. Treats all encompassed files as a single dataset, moving seamlessly from one file to the next once the frames of a file are exhausted.

Args:
- frame_settings: list of FrameSettings objects composing the dataset.
- missing_std_mode: enum flag determining how missing uncertainty images should be handled.
- missing_std_value: a constant that is used in a manner defined by the missing_std_mode to deal with missing uncertainty images.

Source code in clair_torch/datasets/video_frame_dataset.py (lines 14–29)
def __init__(self, frame_settings: List[FrameSettings], missing_std_mode=MissingStdMode.CONSTANT,
             missing_std_value=0.0):
    """
    # TODO: add handling for std files along the main value files.
    Set up a dataset that chains the frames of several video files into one continuous stream: once a file runs out
    of frames, iteration continues with the next file.
    Args:
        frame_settings: list of FrameSettings objects, one per video file of the dataset.
        missing_std_mode: enum flag selecting how uncertainty images are produced when none are available.
        missing_std_value: constant applied according to missing_std_mode when producing missing uncertainty images.
    """
    # Number of frames yielded so far by this instance.
    self._running_index = 0
    self.frame_settings = frame_settings
    self.missing_std_mode = missing_std_mode
    # Scalar tensor reused to synthesize uncertainty images.
    self.shared_std_tensor = torch.tensor(missing_std_value)

__iter__()

Access method for the frames of this dataset. Iterates through the files and frames, moving on to the next file upon exhausting a file. Yields tuples of (running index, frame tensor, uncertainty tensor or None, numeric metadata dictionary).

Source code in clair_torch/datasets/video_frame_dataset.py (lines 38–75)
def __iter__(self) -> Tuple[int, torch.Tensor, torch.Tensor | None, dict[str, float | int]]:
    """
    Access method for the frames of this dataset. Iterates through the files and frames, moving on to the next file
    upon exhausting a file.
    Yields:
        Tuples (running index, frame tensor, uncertainty tensor or None, numeric metadata dictionary). The running
        index counts frames from 1 and restarts at every new iteration pass.

    Raises:
        ValueError: if self.missing_std_mode is not a supported MissingStdMode.
    """
    # Restart the count so that every pass (e.g. every epoch) yields the same indices instead of continuing from
    # where the previous pass stopped.
    self._running_index = 0

    for settings in self.frame_settings:
        input_paths = settings.get_input_paths()
        transforms = settings.get_transforms()

        # Paired settings return (value, std) tuples. Only the value parts are used for now: std videos are not
        # loaded yet (see TODO in __init__), so uncertainty images are synthesized according to missing_std_mode.
        if isinstance(input_paths, tuple):
            val_path, _ = input_paths
        else:
            val_path = input_paths

        if isinstance(transforms, tuple):
            val_transforms, _ = transforms
        else:
            val_transforms = transforms

        metadata = settings.get_numeric_metadata()

        for frame in load_video_frames_generator(val_path, val_transforms):
            if self.missing_std_mode == MissingStdMode.NONE:
                std_image = None
            elif self.missing_std_mode == MissingStdMode.CONSTANT:
                std_image = self.shared_std_tensor.expand_as(frame)
            elif self.missing_std_mode == MissingStdMode.MULTIPLIER:
                std_image = frame * self.shared_std_tensor
            else:
                raise ValueError(f"Unsupported MissingStdMode: {self.missing_std_mode}")

            self._running_index += 1
            yield self._running_index, frame, std_image, metadata

custom_collate(batch)

Custom collate function for handling possible None std images. The samples are first sorted by their 'exposure_time' metadata value. If any std image in the batch is None, the whole std batch is set to None.

Args:
- batch: the data batch from a Dataset, as a sequence of samples. Each sample is a tuple of four items (index, value image, std image or None, numeric metadata dict), similar to the return value.

Returns:

Type Description
tuple[Tensor, Tensor, Tensor | None, dict[str, Tensor]]

Batched data in a tuple:
  • Index tensor containing the indices that were utilized from the dataset.
  • Image tensor.
  • Possible uncertainty image tensor, or None.
  • Dictionary of metadata keys (str) and numeric values (torch.Tensor).
Source code in clair_torch/datasets/collate.py (lines 8–43)
def custom_collate(batch) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, dict[str, torch.Tensor]]:
    """
    Custom collate function for handling possible None std images. If any Nones are found in the batch, the whole
    batch is set to None.
    Args:
        batch: the data batch from a Dataset as a tuple. Expects tuple of four items, similar to the return value.

    Returns:
        Batched data in a tuple
        - Index tensor containing the indices that were utilized from the dataset
        - Image tensor
        - Possible uncertainty image tensor.
        - Dictionary of metadata keys (str) and numeric values (torch.Tensor).
    """

    sorted_batch = sorted(batch, key=lambda x: x[3]['exposure_time'])
    indices, val_images, std_images, metas = zip(*sorted_batch)

    # Collate val_images and metas normally.
    index_batch = default_collate(indices)
    val_batch = default_collate(val_images)
    meta_batch = default_collate(metas)

    # Std_images are either collated normally or the batch is set to None.
    contains_none = False
    for std in std_images:
        if std is None:
            contains_none = True
            break

    if contains_none:
        std_batch = None
    else:
        std_batch = default_collate(std_images)

    return index_batch, val_batch, std_batch, meta_batch