Common

This page acts as the technical reference for the common subpackage.

The common package provides utilities used by all the other packages in this project, including:

- Classes and functions for IO operations.
- Enums used by various classes.
- General mathematical functions and transformation operations in both function and class forms.

BaseFileSettings

Bases: ABC

Base class for implementing classes that manage IO operations.

BaseFileSettings acts as the guideline for designing classes for camera_linearity_torch's IO operations.

Attributes

input_path: Path
    a filepath from which the data will be read.
output_path: Path
    a filepath to which data can be saved. Based on the optional output_path init argument. If None is given, the output path is constructed by joining default_output_root with the filename of input_path, so the output file is created in that directory upon saving.
default_output_root: Path
    a default directory path in which the output file is created upon saving. Based on the default_output_root init argument. If None is given, defaults to a new directory called 'clair_torch_output' in the directory of the input file.
cpu_transforms: Transform | Sequence[Transform] | None
    Optional collection of Transforms that will be performed on the data right after reading it from a file.

Source code in clair_torch/common/base.py
class BaseFileSettings(ABC):
    """
    Base class for implementing classes that manage IO operations.

    BaseFileSettings acts as the guideline for designing classes for camera_linearity_torch's IO operations.

    Attributes
    ----------
    input_path: Path
        a filepath from which the data will be read.
    output_path: Path
        a filepath to which data can be saved. Based on the optional output_path init argument. If None is given, the
        output path is constructed by joining default_output_root with the filename of input_path, so the output file
        is created in that directory upon saving.
    default_output_root: Path
        a default dirpath to utilize as the directory in which the output file is created upon saving. Based on the
        default_output_root init argument. If None is given, then defaults to a new dirpath called
        'clair_torch_output' in the directory of the input file.
    cpu_transforms: Transform | Sequence[Transform] | None
        Optional collection of Transforms that will be performed on the data right after reading it from a file.
    """
    @typechecked
    def __init__(self, input_path: str | Path, output_path: Optional[str | Path] = None,
                 default_output_root: Optional[str | Path] = None,
                 cpu_transforms: Optional[BaseTransform | Sequence[BaseTransform]] = None):
        """
        Initializes the instance with the given paths. Output and default roots are optional and defer to a default
        output root if None is given. output_path overrides any possible default_output_root values. Upon loading data
        the given cpu_transforms are performed on the data sequentially.
        Args:
            input_path: the path at which the file is found.
            output_path: the path to which a modified file is written.
            default_output_root: a root directory path to utilize if no output_path is given.
            cpu_transforms: Transform(s) to be performed on the data on the cpu-side upon loading the data.
        """
        input_path = Path(input_path)
        output_path = Path(output_path) if output_path is not None else None
        default_output_root = Path(default_output_root) if default_output_root is not None else None

        validate_input_file_path(input_path, None)
        self.input_path = input_path
        if default_output_root is None:
            self.default_output_root = self.input_path.parent.joinpath("clair_torch_output")
        else:
            self.default_output_root = default_output_root

        if output_path is not None:
            self.output_path = output_path
        else:
            self.output_path = self.default_output_root.joinpath(self.input_path.name)

        if not is_potentially_valid_file_path(self.output_path):
            raise IOError(f"Invalid path for your OS: {self.output_path}")

        if cpu_transforms is not None and not isinstance(cpu_transforms, (list, tuple)):
            cpu_transforms = [cpu_transforms]

        if cpu_transforms is not None and not all(isinstance(t, BaseTransform) for t in cpu_transforms):
            raise TypeError(f"At least one item in transforms is of incorrect type: {cpu_transforms}")

        self.cpu_transforms = cpu_transforms

    @abstractmethod
    def get_input_paths(self) -> Path | tuple[Path, ...]:
        """
        Method for getting the input path(s) from a FileSettings class.
        Returns:
            A single Path or tuple of Paths.
        """
        pass

    @abstractmethod
    def get_output_paths(self) -> Path | tuple[Path, ...]:
        """
        Method for getting the output path(s) from a FileSettings class.
        Returns:
            A single Path or tuple of Paths.
        """
        pass

    @abstractmethod
    def get_transforms(self) -> List[BaseTransform] | None:
        """
        Method for getting the possible Transform operations from a FileSettings class.
        Returns:
            A list of Transforms or None if no Transforms are given.
        """
        pass

    @abstractmethod
    def get_numeric_metadata(self) -> dict:
        """
        Method for getting numeric metadata associated with a file. Should always return at least an empty dict.
        Returns:
            dict[str, int | float]
        """
        pass

    @abstractmethod
    def get_text_metadata(self) -> dict:
        """
        Method for getting the text metadata associated with a file. Should always return at least an empty dict.
        Returns:
            dict[str, str]
        """
        pass

    @abstractmethod
    def get_all_metadata(self) -> dict:
        """
        Method for getting all the metadata associated with a file. Should always return at least an empty dict.
        Returns:
            dict[str, int | float | str]
        """
        pass
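
To make the contract concrete, here is a minimal sketch of a subclass, assuming a single input file and no metadata. The class name and the trivial method bodies are illustrative, not part of the library.

from pathlib import Path
from typing import List

from clair_torch.common.base import BaseFileSettings
from clair_torch.common.transforms import BaseTransform


class SingleFileSettings(BaseFileSettings):
    """Hypothetical minimal subclass: one input file, no metadata."""

    def get_input_paths(self) -> Path:
        return self.input_path

    def get_output_paths(self) -> Path:
        return self.output_path

    def get_transforms(self) -> List[BaseTransform] | None:
        return self.cpu_transforms

    def get_numeric_metadata(self) -> dict:
        return {}

    def get_text_metadata(self) -> dict:
        return {}

    def get_all_metadata(self) -> dict:
        return {}


# __init__ validates the input path, so it must point at an existing file:
# settings = SingleFileSettings("data/frame_0001.tif")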

__init__(input_path, output_path=None, default_output_root=None, cpu_transforms=None)

Initializes the instance with the given paths. Output and default roots are optional and defer to a default output root if None is given. output_path overrides any possible default_output_root values. Upon loading data, the given cpu_transforms are performed on the data sequentially.

Args:
    input_path: the path at which the file is found.
    output_path: the path to which a modified file is written.
    default_output_root: a root directory path to utilize if no output_path is given.
    cpu_transforms: Transform(s) to be performed on the data on the cpu-side upon loading the data.

Source code in clair_torch/common/base.py
@typechecked
def __init__(self, input_path: str | Path, output_path: Optional[str | Path] = None,
             default_output_root: Optional[str | Path] = None,
             cpu_transforms: Optional[BaseTransform | Sequence[BaseTransform]] = None):
    """
    Initializes the instance with the given paths. Output and default roots are optional and defer to a default
    output root if None is given. output_path overrides any possible default_output_root values. Upon loading data
    the given cpu_transforms are performed on the data sequentially.
    Args:
        input_path: the path at which the file is found.
        output_path: the path to which a modified file is written.
        default_output_root: a root directory path to utilize if no output_path is given.
        cpu_transforms: Transform(s) to be performed on the data on the cpu-side upon loading the data.
    """
    input_path = Path(input_path)
    output_path = Path(output_path) if output_path is not None else None
    default_output_root = Path(default_output_root) if default_output_root is not None else None

    validate_input_file_path(input_path, None)
    self.input_path = input_path
    if default_output_root is None:
        self.default_output_root = self.input_path.parent.joinpath("clair_torch_output")
    else:
        self.default_output_root = default_output_root

    if output_path is not None:
        self.output_path = output_path
    else:
        self.output_path = self.default_output_root.joinpath(self.input_path.name)

    if not is_potentially_valid_file_path(self.output_path):
        raise IOError(f"Invalid path for your OS: {self.output_path}")

    if cpu_transforms is not None and not isinstance(cpu_transforms, (list, tuple)):
        cpu_transforms = [cpu_transforms]

    if cpu_transforms is not None and not all(isinstance(t, BaseTransform) for t in cpu_transforms):
        raise TypeError(f"At least one item in transforms is of incorrect type: {cpu_transforms}")

    self.cpu_transforms = cpu_transforms

get_all_metadata() abstractmethod

Method for getting all the metadata associated with a file. Should always return at least an empty dict. Returns: dict[str, int | float | str]

Source code in clair_torch/common/base.py
@abstractmethod
def get_all_metadata(self) -> dict:
    """
    Method for getting all the metadata associated with a file. Should always return at least an empty dict.
    Returns:
        dict[str, int | float | str]
    """
    pass

get_input_paths() abstractmethod

Method for getting the input path(s) from a FileSettings class. Returns: A single Path or tuple of Paths.

Source code in clair_torch/common/base.py
@abstractmethod
def get_input_paths(self) -> Path | tuple[Path, ...]:
    """
    Method for getting the input path(s) from a FileSettings class.
    Returns:
        A single Path or tuple of Paths.
    """
    pass

get_numeric_metadata() abstractmethod

Method for getting numeric metadata associated with a file. Should always return at least an empty dict. Returns: dict[str, int | float]

Source code in clair_torch/common/base.py
@abstractmethod
def get_numeric_metadata(self) -> dict:
    """
    Method for getting numeric metadata associated with a file. Should always return at least an empty dict.
    Returns:
        dict[str, int | float]
    """
    pass

get_output_paths() abstractmethod

Method for getting the output path(s) from a FileSettings class. Returns: A single Path or tuple of Paths.

Source code in clair_torch/common/base.py
@abstractmethod
def get_output_paths(self) -> Path | tuple[Path, ...]:
    """
    Method for getting the output path(s) from a FileSettings class.
    Returns:
        A single Path or tuple of Paths.
    """
    pass

get_text_metadata() abstractmethod

Method for getting the text metadata associated with a file. Should always return at least an empty dict. Returns: dict[str, str]

Source code in clair_torch/common/base.py
@abstractmethod
def get_text_metadata(self) -> dict:
    """
    Method for getting the text metadata associated with a file. Should always return at least an empty dict.
    Returns:
        dict[str, str]
    """
    pass

get_transforms() abstractmethod

Method for getting the possible Transform operations from a FileSettings class. Returns: A list of Transforms or None if no Transforms are given.

Source code in clair_torch/common/base.py
@abstractmethod
def get_transforms(self) -> List[BaseTransform] | None:
    """
    Method for getting the possible Transform operations from a FileSettings class.
    Returns:
        A list of Transforms or None if no Transforms are given.
    """
    pass

BaseTransform

Bases: ABC

Base class for Transform classes, which typically wrap a function from general functions. Must be callable, taking a torch.Tensor as input and returning a torch.Tensor.

Source code in clair_torch/common/transforms.py
class BaseTransform(ABC):
    """
    Base class for Transform classes, which typically wrap a function from general functions. Must be callable, taking
    a torch.Tensor as input and returning a torch.Tensor.
    """
    def __init__(self, *args, **kwargs) -> None:
        ...

    @abstractmethod
    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        ...

    @abstractmethod
    def to_config(self) -> dict[str, Any]:
        """
        Serialization function, which dumps the instance into a dictionary representation.
        Returns: dictionary that can be used to deserialize back into an equivalent instance.
        """
        ...

    @classmethod
    def from_config(cls, cfg: dict[str, Any]):
        """
        Deserialization function, which initializes a new instance from a dictionary representation
        of the class.
        Args:
            cfg: the configuration dictionary.

        Returns:
            A new instance of this class.
        """
        return cls(**cfg)
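
As a sketch of the intended serialization round trip: a concrete transform only needs __call__ and to_config, while from_config is inherited. The Scale class below is hypothetical, not part of the library.

from typing import Any

import torch

from clair_torch.common.transforms import BaseTransform


class Scale(BaseTransform):
    """Hypothetical transform: multiply the tensor by a fixed factor."""

    def __init__(self, factor: float = 1.0):
        super().__init__()
        self.factor = factor

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.factor

    def to_config(self) -> dict[str, Any]:
        return {"factor": self.factor}


# Serialize to a config dict, then rebuild an equivalent instance.
t = Scale(2.0)
t2 = Scale.from_config(t.to_config())
assert torch.equal(t2(torch.ones(3)), t(torch.ones(3)))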

from_config(cfg) classmethod

Deserialization function, which initializes a new instance from a dictionary representation of the class.

Args:
    cfg: the configuration dictionary.

Returns:
    A new instance of this class.

Source code in clair_torch/common/transforms.py
@classmethod
def from_config(cls, cfg: dict[str, Any]):
    """
    Deserialization function, which initializes a new instance from a dictionary representation
    of the class.
    Args:
        cfg: the configuration dictionary.

    Returns:
        A new instance of this class.
    """
    return cls(**cfg)

to_config() abstractmethod

Serialization function, which dumps the instance into a dictionary representation. Returns: dictionary that can be used to deserialize back into an equivalent instance.

Source code in clair_torch/common/transforms.py
@abstractmethod
def to_config(self) -> dict[str, Any]:
    """
    Serialization function, which dumps the instance into a dictionary representation.
    Returns: dictionary that can be used to deserialize back into an equivalent instance.
    """
    ...

CastTo

Bases: BaseTransform

Transform for casting the tensor to the given datatype and device. If data_type or device is not given, the tensor's existing dtype or device is kept.

Source code in clair_torch/common/transforms.py
@_register_transform
class CastTo(BaseTransform):
    """
    Transform for casting the tensor to the given datatype and device. If data_type or device is not given, the
    tensor's existing dtype or device is kept.
    """
    @typechecked()
    def __init__(self, data_type: Optional[str | torch.dtype] = None, device: Optional[str | torch.device] = None):
        super().__init__()
        if isinstance(data_type, str):
            data_type = DTYPE_MAP[data_type]
        if isinstance(device, str):
            device = torch.device(device)

        self.data_type = data_type
        self.device = device

    @typechecked
    def __call__(self, x: torch.Tensor):

        data_type = self.data_type if self.data_type is not None else x.dtype
        device = self.device if self.device is not None else x.device

        return x.to(dtype=data_type, device=device)

    def to_config(self) -> dict[str, Any]:
        # data_type may be None (keep dtype as-is), so guard the reverse lookup.
        dtype = REVERSE_DTYPE_MAP[self.data_type] if self.data_type is not None else None
        return {
            "data_type": dtype,
            "device": str(self.device) if self.device is not None else None
        }
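
A short usage sketch; the string key "float32" is assumed to be present in DTYPE_MAP.

import torch

from clair_torch.common.transforms import CastTo

x = torch.randint(0, 255, (3, 4, 4), dtype=torch.uint8)
y = CastTo(data_type="float32", device="cpu")(x)  # "float32" key assumed
print(y.dtype, y.device)  # torch.float32 cpu

# With no arguments, dtype and device are kept as they are.
assert CastTo()(x).dtype == torch.uint8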

ClampAlongDims

Bases: BaseTransform

Transform for clamping the tensor values between a min and max value, along the given dimension(s).

Source code in clair_torch/common/transforms.py
@_register_transform
class ClampAlongDims(BaseTransform):
    """
    Transform for clamping the tensor values between a min and max value, along the given dimension(s).
    """
    @typechecked
    def __init__(self, dim: int | tuple[int, ...], min_max_pairs: tuple[float, float] | list[tuple[float, float]]):

        super().__init__()
        self.dim = dim
        self.min_max_pairs = min_max_pairs

    @typechecked
    def __call__(self, x: torch.Tensor) -> torch.Tensor:

        return clamp_along_dims(x, self.dim, self.min_max_pairs)

    def to_config(self) -> dict[str, Any]:
        return {
            "dim": self.dim,
            "min_max_pairs": self.min_max_pairs,
        }
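
A usage sketch, assuming clamp_along_dims applies one (min, max) pair per index along the given dimension (for example, per channel):

import torch

from clair_torch.common.transforms import ClampAlongDims

x = torch.rand(3, 8, 8)  # C x H x W
# One (min, max) pair per channel along dim 0; exact semantics assumed.
clamp = ClampAlongDims(dim=0, min_max_pairs=[(0.0, 0.5), (0.1, 0.9), (0.0, 1.0)])
y = clamp(x)
print(bool((y[0] <= 0.5).all()))  # True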

CvToTorch

Bases: BaseTransform

Transform for converting a tensor from OpenCV's dimension and channel ordering to PyTorch's dimension and channel ordering.

Source code in clair_torch/common/transforms.py
@_register_transform
class CvToTorch(BaseTransform):
    """
    Transform for converting a tensor from OpenCV's dimension and channel ordering to PyTorch's dimension and
    channel ordering.
    """

    def __init__(self):
        super().__init__()

    @typechecked
    def __call__(self, x: torch.Tensor) -> torch.Tensor:

        return cv_to_torch(x)

    def to_config(self) -> dict[str, Any]:
        return {}
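
The conversion itself lives in the cv_to_torch general function. As an illustration of the typical layout change (an assumption about the exact behavior, such as whether BGR is also flipped to RGB):

import torch

# Illustrative only: a typical OpenCV -> PyTorch layout conversion.
hwc = torch.randint(0, 256, (480, 640, 3), dtype=torch.uint8)  # H x W x C
chw = hwc.permute(2, 0, 1).contiguous()                        # C x H x W
print(chw.shape)  # torch.Size([3, 480, 640])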

FileSettings

Bases: BaseFileSettings

Class for managing input and output paths related to an arbitrary file. Main use is to manage the IO settings for use inside a PyTorch Dataset class.

Attributes:

Inherits attributes from BaseFileSettings.

Source code in clair_torch/common/file_settings.py
class FileSettings(BaseFileSettings):
    """
    Class for managing input and output paths related to an arbitrary file. Main use is to manage the IO settings for
    use inside a PyTorch Dataset class.

    Attributes:
    -----------
    Inherits attributes from BaseFileSettings.
    """

    @typechecked
    def __init__(self, input_path: str | Path, output_path: Optional[str | Path] = None,
                 default_output_root: Optional[str | Path] = None,
                 cpu_transforms: Optional[BaseTransform | Iterable[BaseTransform]] = None):
        """
        Initializes the instance with the given paths. Output and default roots are optional and defer to a default
        output root if None is given. output_path overrides any possible default_output_root values. Upon loading data
        the given cpu_transforms are performed on the data sequentially.
        Args:
            input_path: the path at which the file is found.
            output_path: the path to which a modified file is written.
            default_output_root: a root directory path to utilize if no output_path is given.
            cpu_transforms: Transform(s) to be performed on the data on the cpu-side upon loading the data.
        """
        super().__init__(input_path, output_path, default_output_root, cpu_transforms)

    def get_input_paths(self) -> Path:
        """
        Method for getting the input path.
        Returns:
            Path.
        """
        return self.input_path

    def get_output_paths(self) -> Path:
        """
        Method for getting the output path.
        Returns:
            Path.
        """
        return self.output_path

    def get_candidate_std_output_path(self) -> Path:
        """
        Method for getting a candidate output path for an uncertainty file.
        Returns:
            Constructed uncertainty file output path based on the main input file and its filetype suffix.
        """
        candidate_name = self.input_path.stem + " STD" + self.input_path.suffix
        std_output_path = self.output_path.parent / candidate_name

        return std_output_path

    def get_transforms(self) -> List[BaseTransform] | None:
        """
        Method for getting the possible Transform operations to be performed on the data upon reading it.
        Returns:
            List[Transform] or None, if no Transform operations were given on init.
        """
        return self.cpu_transforms

    def get_numeric_metadata(self) -> dict:
        """
        Unused method stub inherited from the base class. This is used in classes further down the subclass tree.
        Returns:
            Empty dict.
        """
        return {}

    def get_text_metadata(self) -> dict:
        """
        Unused method stub inherited from the base class. This is used in classes further down the subclass tree.
        Returns:
            Empty dict.
        """
        return {}

    def get_all_metadata(self) -> dict:
        """
        Unused method stub inherited from the base class. This is used in classes further down the subclass tree.
        Returns:
            Empty dict.
        """
        return {}
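
A usage sketch of the default output path behavior; the path below is hypothetical and must point at an existing file for validation to pass.

from clair_torch.common.file_settings import FileSettings

fs = FileSettings("data/frame_0001.tif")  # hypothetical input file
print(fs.get_input_paths())   # data/frame_0001.tif
print(fs.get_output_paths())  # data/clair_torch_output/frame_0001.tif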

__init__(input_path, output_path=None, default_output_root=None, cpu_transforms=None)

Initializes the instance with the given paths. Output and default roots are optional and defer to a default output root if None is given. output_path overrides any possible default_output_root values. Upon loading data, the given cpu_transforms are performed on the data sequentially.

Args:
    input_path: the path at which the file is found.
    output_path: the path to which a modified file is written.
    default_output_root: a root directory path to utilize if no output_path is given.
    cpu_transforms: Transform(s) to be performed on the data on the cpu-side upon loading the data.

Source code in clair_torch/common/file_settings.py
@typechecked
def __init__(self, input_path: str | Path, output_path: Optional[str | Path] = None,
             default_output_root: Optional[str | Path] = None,
             cpu_transforms: Optional[BaseTransform | Iterable[BaseTransform]] = None):
    """
    Initializes the instance with the given paths. Output and default roots are optional and defer to a default
    output root if None is given. output_path overrides any possible default_output_root values. Upon loading data
    the given cpu_transforms are performed on the data sequentially.
    Args:
        input_path: the path at which the file is found.
        output_path: the path to which output a modified file.
        default_output_root: a root directory path to utilize if no output_path is given.
        cpu_transforms: Transform(s) to be performed on the data on the cpu-side upon loading the data.
    """
    super().__init__(input_path, output_path, default_output_root, cpu_transforms)

get_all_metadata()

Unused method stub inherited from the base class. This is used in classes further down the subclass tree. Returns: Empty dict.

Source code in clair_torch/common/file_settings.py
def get_all_metadata(self) -> dict:
    """
    Unused method stub inherited from the base class. This is used in classes further down the subclass tree.
    Returns:
        Empty dict.
    """
    return {}

get_candidate_std_output_path()

Method for getting a candidate output path for an uncertainty file. Returns: Constructed uncertainty file output path based on the main input file and its filetype suffix.

Source code in clair_torch/common/file_settings.py
def get_candidate_std_output_path(self) -> Path:
    """
    Method for getting a candidate output path for an uncertainty file.
    Returns:
        Constructed uncertainty file output path based on the main input file and its filetype suffix.
    """
    candidate_name = self.input_path.stem + " STD" + self.input_path.suffix
    std_output_path = self.output_path.parent / candidate_name

    return std_output_path

get_input_paths()

Method for getting the input path. Returns: Path.

Source code in clair_torch/common/file_settings.py
def get_input_paths(self) -> Path:
    """
    Method for getting the input path.
    Returns:
        Path.
    """
    return self.input_path

get_numeric_metadata()

Unused method stub inherited from the base class. This is used in classes further down the subclass tree. Returns: Empty dict.

Source code in clair_torch/common/file_settings.py
def get_numeric_metadata(self) -> dict:
    """
    Unused method stub inherited from the base class. This is used in classes further down the subclass tree.
    Returns:
        Empty dict.
    """
    return {}

get_output_paths()

Method for getting the output path. Returns: Path.

Source code in clair_torch/common/file_settings.py
def get_output_paths(self) -> Path:
    """
    Method for getting the output path.
    Returns:
        Path.
    """
    return self.output_path

get_text_metadata()

Unused method stub inherited from the base class. This is used in classes further down the subclass tree. Returns: Empty dict.

Source code in clair_torch/common/file_settings.py
def get_text_metadata(self) -> dict:
    """
    Unused method stub inherited from the base class. This is used in classes further down the subclass tree.
    Returns:
        Empty dict.
    """
    return {}

get_transforms()

Method for getting the possible Transform operations to be performed on the data upon reading it. Returns: List[Transform] or None, if no Transform operations were given on init.

Source code in clair_torch/common/file_settings.py
def get_transforms(self) -> List[BaseTransform] | None:
    """
    Method for getting the possible Transform operations to be performed on the data upon reading it.
    Returns:
        List[Transform] or None, if no Transform operations were given on init.
    """
    return self.cpu_transforms

FrameSettings

Bases: FileSettings

Class for managing input and output paths related to an arbitrary file, with the addition of managing image related metadata. Main use is to manage the IO settings for use inside a PyTorch Dataset class.

Attributes:

Inherits attributes from FileSettings.

metadata: BaseMetadata
    an encapsulated instance of a BaseMetadata subclass for managing the metadata related to images.

Source code in clair_torch/common/file_settings.py
class FrameSettings(FileSettings):
    """
    Class for managing input and output paths related to an arbitrary file, with the addition of managing image related
    metadata. Main use is to manage the IO settings for use inside a PyTorch Dataset class.

    Attributes:
    -----------
    Inherits attributes from FileSettings.
    metadata: BaseMetadata
        an encapsulated instance of a BaseMetadata subclass for managing the metadata related to images.
    """
    @typechecked
    def __init__(self, input_path: str | Path, output_path: Optional[str | Path] = None,
                 default_output_root: Optional[str | Path] = None,
                 cpu_transforms: Optional[BaseTransform | Iterable[BaseTransform]] = None,
                 metadata_cls: Type[BaseMetadata] = ImagingMetadata,
                 *metadata_args, **metadata_kwargs):
        super().__init__(input_path, output_path, default_output_root, cpu_transforms)

        self.metadata = metadata_cls(input_path, *metadata_args, **metadata_kwargs)

    def get_numeric_metadata(self) -> dict[str, float | int | None]:
        """
        Method for getting the numeric metadata managed by the encapsulated Metadata class.
        Returns:
            dict[str, int | float].
        """
        return self.metadata.get_numeric_metadata()

    def get_text_metadata(self) -> dict[str, str | None]:
        """
        Method for getting the text metadata managed by the encapsulated Metadata class.
        Returns:
            dict[str, str].
        """
        return self.metadata.get_text_metadata()

    def get_all_metadata(self) -> dict[str, str | int | float | None]:
        """
        Method for getting all the metadata managed by the encapsulated Metadata class.
        Returns:
            dict[str, str | int | float | None].
        """
        return self.metadata.get_all_metadata()

    @typechecked
    def is_match(self, reference: FrameSettings | PairedFrameSettings | BaseMetadata,
                 attributes: dict[str, Optional[int | float]]) -> bool:
        """
        Method for evaluating whether the metadata contained in a given FrameSettings instance or Metadata instance
        is a match, based on the given attributes, which act as keys to the metadata dictionary in a Metadata
        instance.
        Args:
            reference: a FrameSettings instance or a BaseMetadata subclass instance.
            attributes: a mapping whose keys name the metadata entries whose associated values must be equal
                for a successful match.

        Returns:
            bool, True for a successful match, False for failed.
        """
        if isinstance(reference, (FrameSettings, PairedFrameSettings)):
            reference_metadata = reference.metadata
        elif isinstance(reference, BaseMetadata):
            reference_metadata = reference
        else:
            return False

        return self.metadata.is_match(reference_metadata, attributes)
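
A matching sketch. The paths and the "exposure_time" key are hypothetical (the available keys depend on what metadata_cls parses from the file), and the meaning of the numeric value in attributes, such as a tolerance with None for exact equality, is an assumption.

from clair_torch.common.file_settings import FrameSettings

a = FrameSettings("data/shot_a.tif")  # hypothetical paths, files must exist
b = FrameSettings("data/shot_b.tif")
# True only if both files agree on the "exposure_time" metadata entry.
print(a.is_match(b, {"exposure_time": None}))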

get_all_metadata()

Method for getting all the metadata managed by the encapsulated Metadata class. Returns: dict[str, str | int | float | None].

Source code in clair_torch/common/file_settings.py
def get_all_metadata(self) -> dict[str, str | int | float | None]:
    """
    Method for getting all the metadata managed by the encapsulated Metadata class.
    Returns:
        dict[str, str | int | float | None].
    """
    return self.metadata.get_all_metadata()

get_numeric_metadata()

Method for getting the numeric metadata managed by the encapsulated Metadata class. Returns: dict[str, int | float].

Source code in clair_torch/common/file_settings.py
def get_numeric_metadata(self) -> dict[str, float | int | None]:
    """
    Method for getting the numeric metadata managed by the encapsulated Metadata class.
    Returns:
        dict[str, int | float].
    """
    return self.metadata.get_numeric_metadata()

get_text_metadata()

Method for getting the text metadata managed by the encapsulated Metadata class. Returns: dict[str, str].

Source code in clair_torch/common/file_settings.py
def get_text_metadata(self) -> dict[str, str | None]:
    """
    Method for getting the text metadata managed by the encapsulated Metadata class.
    Returns:
        dict[str, str].
    """
    return self.metadata.get_text_metadata()

is_match(reference, attributes)

Method for evaluating whether the metadata contained in a given FrameSettings instance or Metadata instance is a match, based on the given attributes, which act as keys to the metadata dictionary in a Metadata instance.

Args:
    reference: a FrameSettings instance or a BaseMetadata subclass instance.
    attributes: a mapping whose keys name the metadata entries whose associated values must be equal for a successful match.

Returns:
    bool: True for a successful match, False for failed.

Source code in clair_torch/common/file_settings.py
@typechecked
def is_match(self, reference: FrameSettings | PairedFrameSettings | BaseMetadata,
             attributes: dict[str, Optional[int | float]]) -> bool:
    """
    Method for evaluating whether the metadata contained in a given FrameSettings instance or Metadata instance
    is a match, based on the given attributes, which act as keys to the metadata dictionary in a Metadata
    instance.
    Args:
        reference: a FrameSettings instance or a BaseMetadata subclass instance.
        attributes: a mapping whose keys name the metadata entries whose associated values must be equal
            for a successful match.

    Returns:
        bool, True for a successful match, False for failed.
    """
    if isinstance(reference, (FrameSettings, PairedFrameSettings)):
        reference_metadata = reference.metadata
    elif isinstance(reference, BaseMetadata):
        reference_metadata = reference
    else:
        return False

    return self.metadata.is_match(reference_metadata, attributes)

InterpMode

Bases: Enum

Manages the interpolation modes used in ICRF model classes.

Source code in clair_torch/common/enums.py
10
11
12
13
14
15
16
class InterpMode(Enum):
    """
    Manages the interpolation modes used in ICRF model classes.
    """
    LOOKUP = auto()  # nearest-neighbour LUT, no-grad fast path
    LINEAR = auto()
    CATMULL = auto()

MissingStdMode

Bases: Enum

Manages how missing uncertainty images are dealt with in ImageDataset classes.

Source code in clair_torch/common/enums.py
class MissingStdMode(Enum):
    """
    Manages how missing uncertainty images are dealt with in ImageDataset classes.
    """
    NONE = auto()
    CONSTANT = auto()
    MULTIPLIER = auto()

Normalize

Bases: BaseTransform

Transform for normalizing a tensor by a minimum and maximum value. If no values are given, dynamically use min and max of the given tensor.

Source code in clair_torch/common/transforms.py
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
@_register_transform
class Normalize(BaseTransform):
    """
    Transform for normalizing a tensor by a minimum and maximum value.
    If no values are given, dynamically use min and max of the given tensor.
    """

    @typechecked
    def __init__(self, max_val: Optional[int | float] = None, min_val: Optional[int | float] = None,
                 target_range: Iterable[int | float] = (0.0, 1.0)):
        super().__init__()
        target_range = normalize_container(target_range, tuple, convert_if_iterable=True)

        self.max_val = max_val
        self.min_val = min_val
        self.target_range = target_range

    @typechecked
    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return normalize_tensor(x, max_val=self.max_val, min_val=self.min_val, target_range=self.target_range)

    def to_config(self) -> dict[str, Any]:
        return {
            "max_val": self.max_val,
            "min_val": self.min_val,
            "target_range": tuple(self.target_range),
        }
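
A sketch of the expected behavior, assuming normalize_tensor performs an affine min-max mapping into target_range:

import torch

from clair_torch.common.transforms import Normalize

x = torch.tensor([0.0, 127.5, 255.0])
norm = Normalize(max_val=255.0, min_val=0.0, target_range=(0.0, 1.0))
print(norm(x))  # expected: tensor([0.0000, 0.5000, 1.0000])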

PairedFileSettings

Bases: FileSettings

Class for managing input and output paths of a pair of arbitrary files. Main use is to handle the paired IO operations of a value image and its associated uncertainty image. Composed of two instances of FileSettings.

Attributes:

Inherits attributes from FileSettings.

Source code in clair_torch/common/file_settings.py
class PairedFileSettings(FileSettings):
    """
    Class for managing input and output paths of a pair of arbitrary files. Main use is to handle the paired IO
    operations of a value image and its associated uncertainty image. Composed of two instances of FileSettings.

    Attributes:
    -----------
    Inherits attributes from FileSettings.
    """
    @typechecked
    def __init__(self, val_input_path: str | Path, std_input_path: Optional[str | Path] = None,
                 val_output_path: Optional[str | Path] = None,
                 std_output_path: Optional[str | Path] = None,
                 default_output_root: Optional[str | Path] = None,
                 val_cpu_transforms: Optional[BaseTransform | Iterable[BaseTransform]] = None,
                 std_cpu_transforms: Optional[BaseTransform | Iterable[BaseTransform]] = None):
        """
        Initializes a PairedFileSettings object with the given paths. The class both inherits from FileSettings and is
        a composition of two instances of FileSettings. Referring to self.val_input_path is the same as referring to
        self.val_settings.input_path; the same holds for the other attributes. PairedFileSettings should only be used
        when an uncertainty input file actually exists, even though the std_input_path parameter is Optional. The
        optionality is left to enable easy implicit seeking of std files. Use regular FileSettings if no uncertainty
        files are to be used.
        Args:
            val_input_path: input path for the main value file.
            std_input_path: input path for the associated uncertainty file.
            val_output_path: output path for the modified value file.
            std_output_path: output path for the modified uncertainty file.
            default_output_root: a root directory path to utilize if no output_path is given.
            val_cpu_transforms: transform operations for the value file.
            std_cpu_transforms: transform operations for the uncertainty file.
        """
        super().__init__(val_input_path, val_output_path, default_output_root, val_cpu_transforms)

        self.val_settings = FileSettings(
            val_input_path,
            output_path=val_output_path,
            default_output_root=default_output_root,
            cpu_transforms=val_cpu_transforms
        )

        # This enables implicit seeking of STD files, when no STD input file is directly given.
        if std_input_path is None:
            candidate_name = self.val_settings.input_path.stem + " STD" + self.val_settings.input_path.suffix
            std_input_path = self.val_settings.input_path.parent / candidate_name

        self.std_settings = FileSettings(
            std_input_path,
            output_path=std_output_path,
            default_output_root=self.val_settings.default_output_root,
            cpu_transforms=std_cpu_transforms
        )

    def get_input_paths(self) -> Tuple[Path, Path]:
        """
        Method for getting the input paths of a PairedFileSettings instance. Overrides the inherited method by deferring
        the process for the two encapsulated instances.
        Returns:
            Tuple of Paths, first for value file, second for uncertainty file.
        """
        return self.val_settings.get_input_paths(), self.std_settings.get_input_paths()

    def get_output_paths(self) -> Tuple[Path, Path]:
        """
        Method for getting the output paths of a PairedFileSettings instance. Overrides the inherited method by deferring
        the process for the two encapsulated instances.
        Returns:
            Tuple of paths, first for value file, second for uncertainty file.
        """
        return self.val_settings.get_output_paths(), self.std_settings.get_output_paths()

    def get_transforms(self) -> Tuple[List[BaseTransform], List[BaseTransform]]:
        """
        Method for getting the transformation operations. Overrides the inherited method by deferring the process for
        the two encapsulated instances.
        Returns:
            Tuple of Lists of Transforms, first for value file Transforms, second for uncertainty file Transforms.
        """
        return self.val_settings.get_transforms(), self.std_settings.get_transforms()
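
A sketch of the implicit STD seeking; the paths are hypothetical, and both files must exist, since each encapsulated FileSettings validates its input path.

from clair_torch.common.file_settings import PairedFileSettings

# With no std_input_path, a sibling "<stem> STD<suffix>" file is sought
# next to the value file.
pfs = PairedFileSettings("data/frame_0001.tif")
val_in, std_in = pfs.get_input_paths()
print(std_in.name)  # frame_0001 STD.tif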

__init__(val_input_path, std_input_path=None, val_output_path=None, std_output_path=None, default_output_root=None, val_cpu_transforms=None, std_cpu_transforms=None)

Initializes a PairedFileSettings object with the given paths. The class both inherits from FileSettings and is a composition of two instances of FileSettings. Referring to self.val_input_path is the same as referring to self.val_settings.input_path; the same holds for the other attributes. PairedFileSettings should only be used when an uncertainty input file actually exists, even though the std_input_path parameter is Optional. The optionality is left to enable easy implicit seeking of std files. Use regular FileSettings if no uncertainty files are to be used.

Args:
    val_input_path: input path for the main value file.
    std_input_path: input path for the associated uncertainty file.
    val_output_path: output path for the modified value file.
    std_output_path: output path for the modified uncertainty file.
    default_output_root: a root directory path to utilize if no output_path is given.
    val_cpu_transforms: transform operations for the value file.
    std_cpu_transforms: transform operations for the uncertainty file.

Source code in clair_torch/common/file_settings.py
@typechecked
def __init__(self, val_input_path: str | Path, std_input_path: Optional[str | Path] = None,
             val_output_path: Optional[str | Path] = None,
             std_output_path: Optional[str | Path] = None,
             default_output_root: Optional[str | Path] = None,
             val_cpu_transforms: Optional[BaseTransform | Iterable[BaseTransform]] = None,
             std_cpu_transforms: Optional[BaseTransform | Iterable[BaseTransform]] = None):
    """
    Initializes a PairedFileSettings object with the given paths. The class both inherits from FileSettings and is
    a composition of two instances of FileSettings. Referring to self.val_input_path is the same as referring to
    self.val_settings.input_path; the same holds for the other attributes. PairedFileSettings should only be used
    when an uncertainty input file actually exists, even though the std_input_path parameter is Optional. The
    optionality is left to enable easy implicit seeking of std files. Use regular FileSettings if no uncertainty
    files are to be used.
    Args:
        val_input_path: input path for the main value file.
        std_input_path: input path for the associated uncertainty file.
        val_output_path: output path for the modified value file.
        std_output_path: output path for the modified uncertainty file.
        default_output_root: a root directory path to utilize if no output_path is given.
        val_cpu_transforms: transform operations for the value file.
        std_cpu_transforms: transform operations for the uncertainty file.
    """
    super().__init__(val_input_path, val_output_path, default_output_root, val_cpu_transforms)

    self.val_settings = FileSettings(
        val_input_path,
        output_path=val_output_path,
        default_output_root=default_output_root,
        cpu_transforms=val_cpu_transforms
    )

    # This enables implicit seeking of STD files, when no STD input file is directly given.
    if std_input_path is None:
        candidate_name = self.val_settings.input_path.stem + " STD" + self.val_settings.input_path.suffix
        std_input_path = self.val_settings.input_path.parent / candidate_name

    self.std_settings = FileSettings(
        std_input_path,
        output_path=std_output_path,
        default_output_root=self.val_settings.default_output_root,
        cpu_transforms=std_cpu_transforms
    )

get_input_paths()

Method for getting the input paths of a PairedFileSettings instance. Overrides the inherited method by deferring the process for the two encapsulated instances. Returns: Tuple of Paths, first for value file, second for uncertainty file.

Source code in clair_torch/common/file_settings.py
def get_input_paths(self) -> Tuple[Path, Path]:
    """
    Method for getting the input paths of a PairedFileSettings instance. Overrides the inherited method by deferring
    the process for the two encapsulated instances.
    Returns:
        Tuple of Paths, first for value file, second for uncertainty file.
    """
    return self.val_settings.get_input_paths(), self.std_settings.get_input_paths()

get_output_paths()

Method for getting the output paths of a PairedFileSettings instance. Overrides the inherited method by deferring the process for the two encapsulated instances. Returns: Tuple of paths, first for value file, second for uncertainty file.

Source code in clair_torch/common/file_settings.py
def get_output_paths(self) -> Tuple[Path, Path]:
    """
    Method for getting the output paths of a PairedFileSettings instance. Overrides the inherited method by deferring
    the process for the two encapsulated instances.
    Returns:
        Tuple of paths, first for value file, second for uncertainty file.
    """
    return self.val_settings.get_output_paths(), self.std_settings.get_output_paths()

get_transforms()

Method for getting the transformation operations. Overrides the inherited method by deferring the process for the two encapsulated instances. Returns: Tuple of Lists of Transforms, first for value file Transforms, second for uncertainty file Transforms.

Source code in clair_torch/common/file_settings.py
def get_transforms(self) -> Tuple[List[BaseTransform], List[BaseTransform]]:
    """
    Method for getting the transformation operations. Overrides the inherited method by deferring the process for
    the two encapsulated instances.
    Returns:
        Tuple of Lists of Transforms, first for value file Transforms, second for uncertainty file Transforms.
    """
    return self.val_settings.get_transforms(), self.std_settings.get_transforms()

PairedFrameSettings

Bases: PairedFileSettings

Class for managing paired files with their associated metadatas, based on the PairedFileSettings class.

Attributes:

Inherits attributes from PairedFileSettings.

val_metadata: BaseMetadata
    an encapsulated instance of a BaseMetadata subclass for managing the metadata related to the value image.
std_metadata: BaseMetadata
    an encapsulated instance of a BaseMetadata subclass for managing the metadata related to the possible uncertainty image. Typically these are equal to the val_metadata values, so by default the std metadata is not parsed. The parse_std_meta init argument determines whether std_metadata is parsed.

Source code in clair_torch/common/file_settings.py
class PairedFrameSettings(PairedFileSettings):
    """
    Class for managing paired files with their associated metadatas, based on the PairedFileSettings class.

    Attributes:
    -----------
    Inherits attributes from PairedFileSettings.
    val_metadata: BaseMetadata
        an encapsulated instance of a BaseMetadata subclass for managing the metadata related to the value image.
    std_metadata: BaseMetadata
        an encapsulated instance of a BaseMetadata subclass for managing the metadata related to the possible uncertainty
        image. Typically, these are equal to the val_metadata values, so by default this is not parsed. The parse_std_meta
        init argument is used to determine whether the std_metadata is parsed or not.
    """
    @typechecked
    def __init__(self, val_input_path: str | Path, std_input_path: Optional[str | Path] = None,
                 val_output_path: Optional[str | Path] = None, std_output_path: Optional[str | Path] = None,
                 default_output_root: Optional[str | Path] = None,
                 val_cpu_transforms: Optional[BaseTransform | Iterable[BaseTransform]] = None,
                 std_cpu_transforms: Optional[BaseTransform | Iterable[BaseTransform]] = None,
                 parse_std_meta: bool = False, metadata_cls: Type[BaseMetadata] = ImagingMetadata,
                 *metadata_args, **metadata_kwargs):
        """
        File settings related aspects are all delegated to the PairedFileSettings super class. The additional
        responsibility of this class is maintaining and allowing access to the metadata of each of the encapsulated
        files via subclasses of BaseMetadata.
        Args:
            val_input_path: input path for the main value file.
            std_input_path: input path for the associated uncertainty file.
            val_output_path: output path for the modified value file.
            std_output_path: output path for the modified uncertainty file.
            default_output_root: a root directory path to utilize if no output_path is given.
            val_cpu_transforms: transform operations for the value file.
            std_cpu_transforms: transform operations for the uncertainty file.
            parse_std_meta: whether to parse a Metadata class instance for the STD file.
            metadata_cls: a subclass of BaseMetadata.
            *metadata_args: additional args to pass when instantiating the given metadata_cls.
            **metadata_kwargs: additional kwargs to pass when instantiating the given metadata_cls.
        """
        super().__init__(val_input_path, std_input_path, val_output_path, std_output_path, default_output_root,
                         val_cpu_transforms, std_cpu_transforms)

        self.val_metadata = metadata_cls(self.val_settings.input_path, *metadata_args, **metadata_kwargs)

        self.std_metadata = (
            metadata_cls(self.std_settings.input_path, *metadata_args, **metadata_kwargs)
            if parse_std_meta and self.std_settings is not None
            else None
        )

    def get_numeric_metadata(self) -> dict[str, float | int | None]:
        """
        Method for getting the numeric metadata of the value image.
        Returns:
            dict[str, float | int | None]
        """
        return self.val_metadata.get_numeric_metadata()

    def get_text_metadata(self) -> dict[str, str | None]:
        """
        Method for getting the text metadata of the value image.
        Returns:
            dict[str, str | None]
        """
        return self.val_metadata.get_text_metadata()

    def get_all_metadata(self) -> dict[str, str | int | float]:
        """
        Method for getting all the metadata of the value image.
        Returns:
            dict[str, str | int | float | None]
        """
        return self.val_metadata.get_all_metadata()

    @typechecked
    def is_match(self, reference: FrameSettings | PairedFrameSettings | BaseMetadata,
                 attributes: dict[str, Optional[int | float]]) -> bool:
        """
        Method for evaluating whether the metadata contained in a given FrameSettings instance or Metadata instance
        is a match, based on the given attributes, which act as keys to the metadata dictionary in a Metadata
        instance. Utilizes the val_metadata associated with the value image.
        Args:
            reference: a FrameSettings instance or a BaseMetadata subclass instance.
            attributes: a mapping whose keys name the metadata entries whose associated values must be equal
                for a successful match.

        Returns:
            bool, True for a successful match, False for failed.
        """
        if isinstance(reference, (FrameSettings, PairedFrameSettings)):
            reference_meta = reference.val_metadata
        elif isinstance(reference, BaseMetadata):
            reference_meta = reference
        else:
            return False

        return self.val_metadata.is_match(reference_meta, attributes)
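
A sketch combining paired IO with metadata access; the path is hypothetical and the metadata content depends on metadata_cls.

from clair_torch.common.file_settings import PairedFrameSettings

pfs = PairedFrameSettings("data/frame_0001.tif")  # parse_std_meta defaults to False
print(pfs.get_numeric_metadata())  # numeric metadata of the value image
print(pfs.std_metadata)            # None, since parse_std_meta is False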

__init__(val_input_path, std_input_path=None, val_output_path=None, std_output_path=None, default_output_root=None, val_cpu_transforms=None, std_cpu_transforms=None, parse_std_meta=False, metadata_cls=ImagingMetadata, *metadata_args, **metadata_kwargs)

File settings related aspects are all delegated to the PairedFileSettings super class. The additional responsibility of this class is maintaining and allowing access to the metadata of each of the encapsulated files via subclasses of BaseMetadata.

Args:
    val_input_path: input path for the main value file.
    std_input_path: input path for the associated uncertainty file.
    val_output_path: output path for the modified value file.
    std_output_path: output path for the modified uncertainty file.
    default_output_root: a root directory path to utilize if no output_path is given.
    val_cpu_transforms: transform operations for the value file.
    std_cpu_transforms: transform operations for the uncertainty file.
    parse_std_meta: whether to parse a Metadata class instance for the STD file.
    metadata_cls: a subclass of BaseMetadata.
    *metadata_args: additional args to pass when instantiating the given metadata_cls.
    **metadata_kwargs: additional kwargs to pass when instantiating the given metadata_cls.

Source code in clair_torch/common/file_settings.py
@typechecked
def __init__(self, val_input_path: str | Path, std_input_path: Optional[str | Path] = None,
             val_output_path: Optional[str | Path] = None, std_output_path: Optional[str | Path] = None,
             default_output_root: Optional[str | Path] = None,
             val_cpu_transforms: Optional[BaseTransform | Iterable[BaseTransform]] = None,
             std_cpu_transforms: Optional[BaseTransform | Iterable[BaseTransform]] = None,
             parse_std_meta: bool = False, metadata_cls: Type[BaseMetadata] = ImagingMetadata,
             *metadata_args, **metadata_kwargs):
    """
    File settings related aspects are all delegated to the PairedFileSettings super class. The additional
    responsibility of this class is maintaining and allowing access to the metadata of each of the encapsulated
    files via subclasses of BaseMetadata.
    Args:
        val_input_path: input path for the main value file.
        std_input_path: input path for the associated uncertainty file.
        val_output_path: output path for the modified value file.
        std_output_path: output path for the modified uncertainty file.
        default_output_root: a root directory path to utilize if no output_path is given.
        val_cpu_transforms: transform operations for the value file.
        std_cpu_transforms: transform operations for the uncertainty file.
        parse_std_meta: whether to parse a Metadata class instance for the STD file.
        metadata_cls: a subclass of BaseMetadata.
        *metadata_args: additional args to pass when instantiating the given metadata_cls.
        **metadata_kwargs: additional kwargs to pass when instantiating the given metadata_cls.
    """
    super().__init__(val_input_path, std_input_path, val_output_path, std_output_path, default_output_root,
                     val_cpu_transforms, std_cpu_transforms)

    self.val_metadata = metadata_cls(self.val_settings.input_path, *metadata_args, **metadata_kwargs)

    self.std_metadata = (
        metadata_cls(self.std_settings.input_path, *metadata_args, **metadata_kwargs)
        if parse_std_meta and self.std_settings is not None
        else None
    )

get_all_metadata()

Method for getting all the metadata of the value image. Returns: dict[str, str | int | float | None]

Source code in clair_torch/common/file_settings.py
def get_all_metadata(self) -> dict[str, str | int | float]:
    """
    Method for getting all the metadata of the value image.
    Returns:
        dict[str, str | int | float | None]
    """
    return self.val_metadata.get_all_metadata()

get_numeric_metadata()

Method for getting the numeric metadata of the value image. Returns: dict[str, float | int | None]

Source code in clair_torch/common/file_settings.py
def get_numeric_metadata(self) -> dict[str, float | int | None]:
    """
    Method for getting the numeric metadata of the value image.
    Returns:
        dict[str, float | int | None]
    """
    return self.val_metadata.get_numeric_metadata()

get_text_metadata()

Method for getting the text metadata of the value image. Returns: dict[str, str | None]

Source code in clair_torch/common/file_settings.py
def get_text_metadata(self) -> dict[str, str | None]:
    """
    Method for getting the text metadata of the value image.
    Returns:
        dict[str, str | None]
    """
    return self.val_metadata.get_text_metadata()

is_match(reference, attributes)

Method for evaluating whether the metadata contained in a given FrameSettings instance or Metadata instance is a match based on the given attributes, whose keys act as keys to the metadata dictionary in a Metadata instance. Utilizes the val_metadata associated with the value image.

Args:

reference: a FrameSettings instance or a BaseMetadata subclass instance.
attributes: a mapping of attribute names (metadata dictionary keys) whose associated values must be equal for a successful match.

Returns:

bool: True for a successful match, False for a failed one.

Source code in clair_torch/common/file_settings.py
@typechecked
def is_match(self, reference: FrameSettings | PairedFrameSettings | BaseMetadata,
             attributes: dict[str, Optional[int | float]]) -> bool:
    """
    Method for evaluating whether the metadata contained in a given FrameSettings instance or Metadata instance is
    a match based on the given attributes, which act as keys to the metadata dictionary in a Metadata instance.
    Utilizes the val_metadata associated with the value image.
    Args:
        reference: a FrameSettings instance or a BaseMetadata subclass instance.
        attributes: a mapping of attribute names (metadata dictionary keys) whose associated values must be equal
            for a successful match.

    Returns:
        bool, True for a successful match, False for failed.
    """
    if isinstance(reference, (FrameSettings, PairedFrameSettings)):
        reference_meta = reference.val_metadata
    elif isinstance(reference, BaseMetadata):
        reference_meta = reference
    else:
        return False

    return self.val_metadata.is_match(reference_meta, attributes)
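
A hedged usage sketch: the paths are hypothetical, the files are assumed to exist and to be parseable by the default ImagingMetadata class, and the 'exposure_time' key is illustrative rather than a documented metadata key.

# Hypothetical sketch; the paths and the 'exposure_time' key are illustrative.
from clair_torch.common.file_settings import PairedFrameSettings

ref = PairedFrameSettings("shots/img_001.tif", "shots/img_001_std.tif")
cand = PairedFrameSettings("shots/img_002.tif", "shots/img_002_std.tif")

# True when the value-image metadata of both objects agree on the given key;
# the exact matching semantics are those of BaseMetadata.is_match.
same_exposure = cand.is_match(ref, attributes={"exposure_time": None})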

StridedDownscale

Bases: BaseTransform

Transform for applying a spatial downscale on the input tensor.

Source code in clair_torch/common/transforms.py
@_register_transform
class StridedDownscale(BaseTransform):
    """
    Transform for applying a spatial downscale on the input tensor.
    """
    @typechecked
    def __init__(self, step_size: int):
        super().__init__()
        # A step size of 0 would make the strided slice below invalid, so require >= 1.
        if step_size < 1:
            raise ValueError("step_size must be a positive integer.")

        self.step_size = step_size

    @typechecked
    def __call__(self, x: torch.Tensor):

        strided = x[..., ::self.step_size, ::self.step_size]

        return strided

    def to_config(self) -> dict[str, Any]:
        return {
            "step_size": self.step_size
        }
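
A minimal usage sketch, assuming the import path shown in the source note above; the shapes are illustrative.

import torch

from clair_torch.common.transforms import StridedDownscale

down = StridedDownscale(step_size=2)

x = torch.rand(3, 480, 640)   # (C, H, W)
y = down(x)                   # keeps every 2nd pixel along the last two dims
assert y.shape == (3, 240, 320)
assert down.to_config() == {"step_size": 2}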

TorchToCv

Bases: BaseTransform

Transform for converting a tensor from PyTorch (C, H, W) format with RGB channels to OpenCV format (H, W, C) with BGR channels.

Source code in clair_torch/common/transforms.py
@_register_transform
class TorchToCv(BaseTransform):
    """
    Transform for converting a tensor from PyTorch (C, H, W) format with RGB channels to OpenCV format (H, W, C) with
    BGR channels.
    """

    def __init__(self):
        super().__init__()

    @typechecked
    def __call__(self, x: torch.Tensor) -> torch.Tensor:

        return torch_to_cv(x)

    def to_config(self) -> dict[str, Any]:
        return {}

VarianceMode

Bases: Enum

Manages how the variance is computed in the WBOMeanVar class.

Source code in clair_torch/common/enums.py
class VarianceMode(Enum):
    """
    Manages how the variance is computed in the WBOMeanVar class.
    """
    POPULATION = auto()
    SAMPLE_FREQUENCY = auto()
    RELIABILITY_WEIGHTS = auto()

WBOMean

Bases: object

Source code in clair_torch/common/statistics.py
class WBOMean(object):

    @typechecked
    def __init__(self, dim: int | tuple[int, ...] = 0):
        """
        Weighted batched online mean (WBOMean). Allows the computation of a weighted mean value of a dataset in a single
        pass. Attributes are read-only properties and the user only needs to assign the dimension(s) over which the
        value is computed and supply the batches of new values and their associated weights.
        Args:
            dim: the dimension(s) along which to compute the values.
        """
        if isinstance(dim, int):
            dim = (dim,)
        elif not (isinstance(dim, tuple) and all(isinstance(d, int) for d in dim)):
            # Tuples of ints pass through unchanged; anything else is rejected.
            raise TypeError(f"Expected dim as int or tuple of int, got {type(dim)}")

        self._dim = dim
        self._mean = 0.0
        self._sum_of_weights = 0.0

    @property
    def mean(self) -> float | torch.Tensor:
        return self._mean

    @property
    def sum_of_weights(self) -> float | torch.Tensor:
        return self._sum_of_weights

    @property
    def dim(self) -> int | tuple[int, ...]:
        return self._dim

    def internal_detach(self, *, in_place: bool = True):
        """
        Break the autograd graph attached to the internal state.
        Args:
            in_place: Whether to call detach so the tensor is modified in-place or to put into a new tensor.
        """
        if torch.is_tensor(self._mean):
            if in_place:
                self._mean.detach_()
            else:
                self._mean = self._mean.detach()

        if torch.is_tensor(self._sum_of_weights):
            if in_place:
                self._sum_of_weights.detach_()
            else:
                self._sum_of_weights = self._sum_of_weights.detach()

    @typechecked
    def update_values(self, batch_values: torch.Tensor, batch_weights: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        The public method for updating the weighted mean value. Returns the new mean on update.
        Args:
            batch_values: the new batch of values used for updating the collective weighted mean.
            batch_weights: the new batch of weights associated with the given batch_values.

        Returns:
            The new mean value as a float or Tensor.
        """
        batch_size = int(torch.prod(torch.tensor([batch_values.shape[d] for d in self.dim])))

        if batch_weights is not None:
            total_batch_weights = torch.sum(batch_weights, dim=self.dim, keepdim=True)
            total_batch_mean = torch.sum(batch_weights * batch_values, dim=self.dim, keepdim=True) / (
                    total_batch_weights + 1e-6)
        else:
            total_batch_mean = torch.mean(batch_values, dim=self.dim, keepdim=True)
            total_batch_weights = torch.full_like(total_batch_mean, batch_size, dtype=batch_values.dtype)

        self._update_internal_values(total_batch_mean, total_batch_weights,
                                     None, None)
        return self.mean

    def _update_internal_values(self, total_batch_mean: torch.Tensor, total_batch_weights: torch.Tensor,
                                total_squared_batch_weights: torch.Tensor, total_batch_variance: torch.Tensor):
        """
        Implementation of the updating of the weighted mean according to DOI 10.1007/s00180-015-0637-z.
        Args:
            total_batch_mean: the mean of the new batch.
            total_batch_weights: the sum of the weights in the new batch.
            total_squared_batch_weights: unused dummy variable for inheriting class.
            total_batch_variance: unused dummy variable for inheriting class.
        Returns:
            None
        """
        W_A = self.sum_of_weights
        W_B = total_batch_weights
        mean_A = self.mean
        mean_B = total_batch_mean
        W = W_A + W_B

        new_mean = mean_A + (W_B / W) * (mean_B - mean_A)

        self._mean = new_mean
        self._sum_of_weights = W
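
As a sanity check, the online update over several batches should agree with the weighted mean computed directly over the full data. A minimal sketch, assuming the import path shown in the source note above:

import torch

from clair_torch.common.statistics import WBOMean

torch.manual_seed(0)
values = torch.rand(8, 3)
weights = torch.rand(8, 3)

online = WBOMean(dim=0)
online.update_values(values[:5], weights[:5])          # first batch
mean = online.update_values(values[5:], weights[5:])   # second batch

direct = (weights * values).sum(dim=0, keepdim=True) / weights.sum(dim=0, keepdim=True)
# The small epsilon in the weighted-batch denominator causes a tiny drift.
assert torch.allclose(mean, direct, atol=1e-4)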

__init__(dim=0)

Weighted batched online mean (WBOMean). Allows the computation of a weighted mean value of a dataset in a single pass. Attributes are read-only properties and the user only needs to assign the dimension(s) over which the value is computed and supply the batches of new values and their associated weights. Args: dim: the dimension(s) along which to compute the values.

Source code in clair_torch/common/statistics.py
@typechecked
def __init__(self, dim: int | tuple[int, ...] = 0):
    """
    Weighted batched online mean (WBOMean). Allows the computation of a weighted mean value of a dataset in a single
    pass. Attributes are read-only properties and the user only needs to assign the dimension(s) over which the
    value is computed and supply the batches of new values and their associated weights.
    Args:
        dim: the dimension(s) along which to compute the values.
    """
    if isinstance(dim, int):
        dim = (dim,)
    elif not (isinstance(dim, tuple) and all(isinstance(d, int) for d in dim)):
        # Tuples of ints pass through unchanged; anything else is rejected.
        raise TypeError(f"Expected dim as int or tuple of int, got {type(dim)}")

    self._dim = dim
    self._mean = 0.0
    self._sum_of_weights = 0.0

internal_detach(*, in_place=True)

Break the autograd graph attached to the internal state. Args: in_place: Whether to call detach so the tensor is modified in-place or to put into a new tensor.

Source code in clair_torch/common/statistics.py
def internal_detach(self, *, in_place: bool = True):
    """
    Break the autograd graph attached to the internal state.
    Args:
        in_place: Whether to call detach so the tensor is modified in-place or to put into a new tensor.
    """
    if torch.is_tensor(self._mean):
        if in_place:
            self._mean.detach_()
        else:
            self._mean = self._mean.detach()

    if torch.is_tensor(self._sum_of_weights):
        if in_place:
            self._sum_of_weights.detach_()
        else:
            self._sum_of_weights = self._sum_of_weights.detach()

update_values(batch_values, batch_weights=None)

The public method for updating the weighted mean value. Returns the new mean on update. Args: batch_values: the new batch of values used for updating the collective weighted mean. batch_weights: the new batch of weights associated with the given batch_values.

Returns:

Tensor: The new mean value as a float or Tensor.

Source code in clair_torch/common/statistics.py
@typechecked
def update_values(self, batch_values: torch.Tensor, batch_weights: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    The public method for updating the weighted mean value. Returns the new mean on update.
    Args:
        batch_values: the new batch of values used for updating the collective weighted mean.
        batch_weights: the new batch of weights associated with the given batch_values.

    Returns:
        The new mean value as a float or Tensor.
    """
    batch_size = int(torch.prod(torch.tensor([batch_values.shape[d] for d in self.dim])))

    if batch_weights is not None:
        total_batch_weights = torch.sum(batch_weights, dim=self.dim, keepdim=True)
        total_batch_mean = torch.sum(batch_weights * batch_values, dim=self.dim, keepdim=True) / (
                total_batch_weights + 1e-6)
    else:
        total_batch_mean = torch.mean(batch_values, dim=self.dim, keepdim=True)
        total_batch_weights = torch.full_like(total_batch_mean, batch_size, dtype=batch_values.dtype)

    self._update_internal_values(total_batch_mean, total_batch_weights,
                                 None, None)
    return self.mean

WBOMeanVar

Bases: WBOMean

Source code in clair_torch/common/statistics.py
class WBOMeanVar(WBOMean):

    @typechecked
    def __init__(self, dim: int | tuple[int, ...] = 0, variance_mode: VarianceMode = VarianceMode.RELIABILITY_WEIGHTS):
        """
        Weighted batched online mean and variance. Allows the computation of a weighted mean and variance of a dataset in a
        single pass. Attributes are read-only properties and the user only needs to assign the dimension(s) over which the
        value is computed and supply the batches of new values and their associated weights.
        Args:
            dim: the dimension(s) along which to compute the values.
            variance_mode: determines the normalization method to compute the variance.
        """
        super().__init__(dim=dim)

        self._variance_mode = variance_mode

        self._dispatch = {
            VarianceMode.POPULATION: self._variance_population,
            VarianceMode.RELIABILITY_WEIGHTS: self._variance_reliability_weights,
            VarianceMode.SAMPLE_FREQUENCY: self._variance_sample_frequency
        }

        if self._variance_mode not in self._dispatch:
            raise ValueError(f"Unknown variance mode {self._variance_mode}")

        self._m2 = 0.0
        self._sum_of_squared_weights = 0.0

    @property
    def sum_of_squared_weights(self) -> float | torch.Tensor:
        return self._sum_of_squared_weights

    @property
    def m2(self) -> float | torch.Tensor:
        return self._m2

    def variance(self) -> float | torch.Tensor:
        return self._dispatch[self._variance_mode]()

    def _variance_sample_frequency(self) -> float | torch.Tensor:
        return self._m2 * self._sample_frequency_scale(self._sum_of_weights)

    def _variance_reliability_weights(self) -> float | torch.Tensor:
        return self._m2 * self._reliability_weights_scale(self._sum_of_weights, self._sum_of_squared_weights)

    def _variance_population(self) -> float | torch.Tensor:
        return self._m2 * self._population_scale(self._sum_of_weights)

    @staticmethod
    def _sample_frequency_scale(sum_of_weights: float | torch.Tensor):
        return 1 / (sum_of_weights - 1)

    @staticmethod
    def _reliability_weights_scale(sum_of_weights: float | torch.Tensor, sum_of_squared_weights: float | torch.Tensor):
        return 1 / (sum_of_weights - sum_of_squared_weights / sum_of_weights)

    @staticmethod
    def _population_scale(sum_of_weights: float | torch.Tensor):
        return 1 / sum_of_weights

    def internal_detach(self, *, in_place: bool = True) -> None:
        """
        Break the autograd graph attached to the internal state.
        Args:
            in_place: Whether to call detach so the tensor is modified in-place or to put into a new tensor.
        """
        if torch.is_tensor(self._mean):
            if in_place:
                self._mean.detach_()
            else:
                self._mean = self._mean.detach()

        if torch.is_tensor(self._m2):
            if in_place:
                self._m2.detach_()
            else:
                self._m2 = self._m2.detach()

        if torch.is_tensor(self._sum_of_weights):
            if in_place:
                self._sum_of_weights.detach_()
            else:
                self._sum_of_weights = self._sum_of_weights.detach()

        if torch.is_tensor(self._sum_of_squared_weights):
            if in_place:
                self._sum_of_squared_weights.detach_()
            else:
                self._sum_of_squared_weights = self._sum_of_squared_weights.detach()

    @typechecked
    def update_values(self, batch_values: torch.Tensor, batch_weights: Optional[torch.Tensor] = None) -> \
            tuple[float | torch.Tensor, float | torch.Tensor]:
        """
        The public method for updating the weighted mean value. Returns the new mean on update.
        Args:
            batch_values: the new batch of values used for updating the collective weighted mean.
            batch_weights: the new batch of weights associated with the given batch_values.

        Returns:
            The new mean value and variance values as tuple of floats or tensors.
        """
        batch_size = int(torch.prod(torch.tensor([batch_values.shape[d] for d in self.dim])))

        if batch_weights is not None:
            total_batch_weights = torch.sum(batch_weights, dim=self.dim, keepdim=True)
            total_squared_batch_weights = torch.sum(batch_weights ** 2, dim=self.dim, keepdim=True)
            total_batch_mean = torch.sum(batch_weights * batch_values, dim=self.dim, keepdim=True) / (
                    total_batch_weights + 1e-6)
            m2 = torch.sum(batch_weights * (batch_values - total_batch_mean) ** 2, dim=self.dim, keepdim=True)
        else:
            total_batch_mean = torch.mean(batch_values, dim=self.dim, keepdim=True)
            total_batch_weights = torch.full_like(total_batch_mean, batch_size, dtype=batch_values.dtype)
            total_squared_batch_weights = total_batch_weights
            m2 = torch.sum((batch_values - total_batch_mean) ** 2, dim=self.dim, keepdim=True)

        total_batch_variance = m2

        self._update_internal_values(total_batch_mean, total_batch_weights,
                                     total_squared_batch_weights, total_batch_variance)
        return self.mean, self.m2

    def _update_internal_values(self, total_batch_mean: torch.Tensor, total_batch_weights: torch.Tensor,
                                total_squared_batch_weights: torch.Tensor, total_batch_variance: torch.Tensor) -> None:
        """
        Implementation of the updating of the weighted mean according to DOI 10.1007/s00180-015-0637-z.
        Args:
            total_batch_mean: the mean of the new batch.
            total_batch_weights: the sum of the weights in the new batch.

        Returns:
            None
        """
        W_A = self.sum_of_weights
        W_B = total_batch_weights
        M_A = self.m2
        M_B = total_batch_variance
        mean_A = self.mean
        mean_B = total_batch_mean
        W = W_A + W_B

        M_AB = M_A + M_B + (W_A * W_B / W) * (mean_B - mean_A) ** 2
        new_mean = mean_A + (W_B / W) * (mean_B - mean_A)

        self._mean = new_mean
        self._m2 = M_AB
        self._sum_of_weights = W
        self._sum_of_squared_weights += total_squared_batch_weights
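
With unit weights, SAMPLE_FREQUENCY normalization reduces to the unbiased sample variance, which gives a simple check against torch.var. A sketch assuming the module paths shown in the source notes above:

import torch

from clair_torch.common.enums import VarianceMode
from clair_torch.common.statistics import WBOMeanVar

torch.manual_seed(0)
x = torch.rand(100)

stat = WBOMeanVar(dim=0, variance_mode=VarianceMode.SAMPLE_FREQUENCY)
for chunk in x.split(25):      # four batches, implicit unit weights
    stat.update_values(chunk)

# With unit weights the SAMPLE_FREQUENCY scale 1 / (N - 1) matches torch.var.
assert torch.allclose(stat.variance().squeeze(), x.var(), atol=1e-5)
assert torch.allclose(stat.mean.squeeze(), x.mean(), atol=1e-6)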

__init__(dim=0, variance_mode=VarianceMode.RELIABILITY_WEIGHTS)

Weighted batched online mean and variance. Allows the computation of a weighted mean and variance of a dataset in a single pass. Attributes are read-only properties and the user only needs to assign the dimension(s) over which the value is computed and supply the batches of new values and their associated weights. Args: dim: the dimension(s) along which to compute the values. variance_mode: determines the normalization method to compute the variance.

Source code in clair_torch/common/statistics.py
@typechecked
def __init__(self, dim: int | tuple[int, ...] = 0, variance_mode: VarianceMode = VarianceMode.RELIABILITY_WEIGHTS):
    """
    Weighted batched online mean and variance. Allows the computation of a weighted mean and variance of a dataset in a
    single pass. Attributes are read-only properties and the user only needs to assign the dimension(s) over which the
    value is computed and supply the batches of new values and their associated weights.
    Args:
        dim: the dimension(s) along which to compute the values.
        variance_mode: determines the normalization method to compute the variance.
    """
    super().__init__(dim=dim)

    self._variance_mode = variance_mode

    self._dispatch = {
        VarianceMode.POPULATION: self._variance_population,
        VarianceMode.RELIABILITY_WEIGHTS: self._variance_reliability_weights,
        VarianceMode.SAMPLE_FREQUENCY: self._variance_sample_frequency
    }

    if self._variance_mode not in self._dispatch:
        raise ValueError(f"Unknown variance mode {self._variance_mode}")

    self._m2 = 0.0
    self._sum_of_squared_weights = 0.0

internal_detach(*, in_place=True)

Break the autograd graph attached to the internal state. Args: in_place: Whether to call detach so the tensor is modified in-place or to put into a new tensor.

Source code in clair_torch/common/statistics.py
def internal_detach(self, *, in_place: bool = True) -> None:
    """
    Break the autograd graph attached to the internal state.
    Args:
        in_place: Whether to call detach so the tensor is modified in-place or to put into a new tensor.
    """
    if torch.is_tensor(self._mean):
        if in_place:
            self._mean.detach_()
        else:
            self._mean = self._mean.detach()

    if torch.is_tensor(self._m2):
        if in_place:
            self._m2.detach_()
        else:
            self._m2 = self._m2.detach()

    if torch.is_tensor(self._sum_of_weights):
        if in_place:
            self._sum_of_weights.detach_()
        else:
            self._sum_of_weights = self._sum_of_weights.detach()

    if torch.is_tensor(self._sum_of_squared_weights):
        if in_place:
            self._sum_of_squared_weights.detach_()
        else:
            self._sum_of_squared_weights = self._sum_of_squared_weights.detach()

update_values(batch_values, batch_weights=None)

The public method for updating the weighted mean and M2 values. Returns the new mean and M2 on update. Args: batch_values: the new batch of values used for updating the collective weighted mean. batch_weights: the new batch of weights associated with the given batch_values.

Returns:

tuple[float | Tensor, float | Tensor]: The new mean and M2 values; call variance() for the normalized variance.

Source code in clair_torch/common/statistics.py
@typechecked
def update_values(self, batch_values: torch.Tensor, batch_weights: Optional[torch.Tensor] = None) -> \
        tuple[float | torch.Tensor, float | torch.Tensor]:
    """
    The public method for updating the weighted mean value. Returns the new mean on update.
    Args:
        batch_values: the new batch of values used for updating the collective weighted mean.
        batch_weights: the new batch of weights associated with the given batch_values.

    Returns:
        The new mean value and variance values as tuple of floats or tensors.
    """
    batch_size = int(torch.prod(torch.tensor([batch_values.shape[d] for d in self.dim])))

    if batch_weights is not None:
        total_batch_weights = torch.sum(batch_weights, dim=self.dim, keepdim=True)
        total_squared_batch_weights = torch.sum(batch_weights ** 2, dim=self.dim, keepdim=True)
        total_batch_mean = torch.sum(batch_weights * batch_values, dim=self.dim, keepdim=True) / (
                total_batch_weights + 1e-6)
        m2 = torch.sum(batch_weights * (batch_values - total_batch_mean) ** 2, dim=self.dim, keepdim=True)
    else:
        total_batch_mean = torch.mean(batch_values, dim=self.dim, keepdim=True)
        total_batch_weights = torch.full_like(total_batch_mean, batch_size, dtype=batch_values.dtype)
        total_squared_batch_weights = total_batch_weights
        m2 = torch.sum((batch_values - total_batch_mean) ** 2, dim=self.dim, keepdim=True)

    total_batch_variance = m2

    self._update_internal_values(total_batch_mean, total_batch_weights,
                                 total_squared_batch_weights, total_batch_variance)
    return self.mean, self.m2

clamp_along_dims(x, dim, min_max_pairs)

Clamp a tensor along specified dimension(s).

Parameters:

x (Tensor): Input tensor. Required.
dim (int | tuple[int, ...]): Dimensions along which to apply the min/max. Required.
min_max_pairs (tuple[float, float] | list[tuple[float, float]]): A single (min, max) tuple applied to all slices, or a list of tuples whose length must match the number of slices along dim. Required.

Returns:

Tensor: Clamped tensor of the same shape as x.

Source code in clair_torch/common/general_functions.py
@typechecked
def clamp_along_dims(x: torch.Tensor, dim: int | tuple[int, ...],
                     min_max_pairs: tuple[float, float] | list[tuple[float, float]]) -> torch.Tensor:
    """
    Clamp a tensor along specified dimension(s).

    Args:
        x: Input tensor.
        dim: int or tuple of ints, dimensions along which to apply min/max.
        min_max_pairs: Single tuple (min, max) applied to all slices. List of tuples; length must match the number of
            slices along dims.

    Returns:
        Clamped tensor of same shape as x.
    """
    if isinstance(dim, int):
        dim = (dim,)

    dim = tuple(d % x.ndim for d in dim)  # handle negative dims

    # Determine slice shape along specified dims
    slice_shape = tuple(x.shape[d] for d in dim)
    num_slices = torch.tensor(slice_shape).prod().item()

    # Handle single min-max pair
    if isinstance(min_max_pairs, tuple):
        return torch.clamp(x, min=min_max_pairs[0], max=min_max_pairs[1])

    # Otherwise we expect a list of min-max pairs
    if len(min_max_pairs) != num_slices:
        raise ValueError(f"Expected 1 or {num_slices} min/max pairs, got {len(min_max_pairs)}")

    # Convert min/max pairs to tensors of shape slice_shape
    mins = torch.tensor([pair[0] for pair in min_max_pairs], dtype=x.dtype).reshape(slice_shape)
    maxs = torch.tensor([pair[1] for pair in min_max_pairs], dtype=x.dtype).reshape(slice_shape)

    # Create broadcasting shape
    broadcast_shape = [1] * x.ndim
    for i, d in enumerate(dim):
        broadcast_shape[d] = slice_shape[i]

    mins = mins.reshape(broadcast_shape)
    maxs = maxs.reshape(broadcast_shape)

    # Clamp using broadcasting
    return torch.clamp(x, min=mins, max=maxs)
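
For example, per-channel clamping along dim 1 takes one (min, max) pair per channel slice. A minimal sketch, assuming the module path shown in the source note above:

import torch

from clair_torch.common.general_functions import clamp_along_dims

x = torch.rand(2, 3, 4, 4)    # (N, C, H, W)

# A single pair clamps every element identically.
y_global = clamp_along_dims(x, dim=1, min_max_pairs=(0.1, 0.9))

# Three channel slices along dim 1 -> three (min, max) pairs.
y_per_channel = clamp_along_dims(x, dim=1,
                                 min_max_pairs=[(0.0, 0.5), (0.1, 0.6), (0.2, 0.7)])
assert float(y_per_channel[:, 0].max()) <= 0.5
assert float(y_per_channel[:, 2].min()) >= 0.2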

cli_parse_args_from_config()

CLI utility function to read parameter values from a .yaml config file into a dictionary. Returns: dictionary of the parsed keys and values.

Source code in clair_torch/common/general_functions.py
def cli_parse_args_from_config() -> dict:
    """
    CLI utility function to read parameter values from a .yaml config file into a dictionary.
    Returns:
        dictionary of the parsed keys and values.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True, help="Path to config file")
    args = parser.parse_args()

    with open(args.config) as f:
        config = yaml.safe_load(f)

    return config
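
A sketch of the intended call pattern; the script name and the keys in the YAML file are hypothetical.

# my_script.py (hypothetical)
from clair_torch.common.general_functions import cli_parse_args_from_config

config = cli_parse_args_from_config()
print(config["batch_size"])

# Invoked as:
#   python my_script.py --config config.yaml
# where config.yaml might contain, e.g.:
#   batch_size: 8
#   input_dir: ./frames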

conditional_gaussian_blur(image, mask_map, threshold, kernel_size, differentiable=False, alpha=50.0)

Apply a Gaussian blur at input-image positions where the given map has a value larger than the given threshold. Optionally use a differentiable soft mask instead of a boolean mask.

Parameters:

image (Tensor): input image of shape (..., C, H, W). Required.
mask_map (Tensor): map for filtering. Shape must be (1, C, H, W) or (N, C, H, W), where N matches the batch dimension of image or is 1 (broadcast). Required.
threshold (float): threshold value for filtering. Required.
kernel_size (int): size of the Gaussian blur kernel. Required.
differentiable (bool): if True, use a soft differentiable mask via sigmoid. Default: False.
alpha (float): steepness of the sigmoid when differentiable=True. Default: 50.0.

Returns:

Tensor: Filtered image, same shape as the input.

Source code in clair_torch/common/general_functions.py
@typechecked
def conditional_gaussian_blur(image: torch.Tensor, mask_map: torch.Tensor, threshold: float, kernel_size: int,
                              differentiable: bool = False, alpha: float = 50.0) -> torch.Tensor:
    """
    Apply a gaussian blur on input image positions at which the given map has value larger than the given threshold.
    Optionally use a differentiable soft mask instead of a boolean mask.

    Args:
        image: input image of shape (..., C, H, W).
        mask_map: map for filtering. Shape must be (1, C, H, W) or (N, C, H, W),
                  where N matches the batch dimension of image or is 1 (broadcasted).
        threshold: threshold value for filtering.
        kernel_size: size of the gaussian blur kernel.
        differentiable: if True, use a soft differentiable mask via sigmoid.
        alpha: steepness of sigmoid when differentiable=True.

    Returns:
        Filtered image, same shape as input.
    """
    *leading, C, H, W = image.shape
    image_flat = image.reshape(-1, C, H, W)  # flatten leading dims into batch
    N = image_flat.size(0)

    blur_transform = GaussianBlur(kernel_size=kernel_size, sigma=1.0)
    blurred = blur_transform(image_flat)

    # validate mask shape
    if mask_map.shape[0] not in (1, N):
        raise ValueError(
            f"mask_map batch dimension must be 1 or {N}, "
            f"got {mask_map.shape[0]}"
        )

    if differentiable:
        # smooth sigmoid mask ∈ (0, 1)
        mask = torch.sigmoid((mask_map - threshold) * alpha)
    else:
        # boolean mask
        mask = (mask_map > threshold).to(image.dtype)

    # broadcast mask to match image batch if needed
    if mask.shape[0] == 1 and N > 1:
        mask = mask.expand(N, -1, -1, -1)

    # mix blurred and original
    out = mask * blurred + (1 - mask) * image_flat

    return out.reshape(*leading, C, H, W)
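
A minimal usage sketch; the shapes and threshold are illustrative.

import torch

from clair_torch.common.general_functions import conditional_gaussian_blur

image = torch.rand(2, 3, 64, 64)        # (N, C, H, W)
noise_map = torch.rand(1, 3, 64, 64)    # broadcast across the batch dimension

# Hard boolean mask: blur only where noise_map exceeds 0.8.
out_hard = conditional_gaussian_blur(image, noise_map, threshold=0.8, kernel_size=5)

# Soft sigmoid mask, usable inside a gradient-based optimization loop.
out_soft = conditional_gaussian_blur(image, noise_map, threshold=0.8, kernel_size=5,
                                     differentiable=True, alpha=50.0)
assert out_hard.shape == image.shape and out_soft.shape == image.shape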

cv_to_torch(x)

Function for transforming a tensor with OpenCV channel and dimension ordering into PyTorch channel and dimension ordering. Expects a tensor with two or three dimensions. Args: x: tensor to transform.

Returns:

Tensor: the converted torch.Tensor.

Source code in clair_torch/common/general_functions.py
def cv_to_torch(x: torch.Tensor) -> torch.Tensor:
    """
    Function for transforming a tensor with OpenCV channel and dimension ordering into PyTorch channel and dimension
    ordering. Expects a tensor with two or three dimensions.
    Args:
        x: tensor to transform.

    Returns:
        torch.Tensor
    """
    if x.ndim == 2:
        # Grayscale image: (H, W) -> (1, H, W)
        x = x.unsqueeze(0)
    elif x.ndim == 3 and x.shape[2] == 3:
        # Color image: BGR to RGB, then permute to (C, H, W)
        x = x[:, :, [2, 1, 0]]  # BGR to RGB
        x = x.permute(2, 0, 1)
    else:
        raise ValueError(f"Unexpected image shape: {x.shape}")

    return x
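
A small check of the channel reordering; the shapes are illustrative.

import torch

from clair_torch.common.general_functions import cv_to_torch

bgr_hwc = torch.randint(0, 256, (480, 640, 3), dtype=torch.uint8)  # OpenCV-style (H, W, C), BGR

rgb_chw = cv_to_torch(bgr_hwc)
assert rgb_chw.shape == (3, 480, 640)
# Channel 0 (red) of the output comes from BGR index 2 of the input.
assert torch.equal(rgb_chw[0], bgr_hwc[..., 2])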

file_settings_constructor(dir_paths, file_pattern, recursive=False, default_output_root=None, val_cpu_transforms=None, std_cpu_transforms=None, metadata_cls=None, strict_exclusive=True, *metadata_args, **metadata_kwargs)

Utility function for creating instances of the FileSettings, PairedFileSettings, FrameSettings and PairedFrameSettings classes. FrameSettings classes are created when a metadata class is provided, otherwise FileSettings are created. The created objects are assigned to three categories, each having a list of its own to hold the objects: 1. Paired files, containing either PairedFrameSettings or PairedFileSettings objects. 2. Main files, containing either FrameSettings or FileSettings objects. 3. Uncertainty files, containing either FrameSettings or FileSettings objects. The strict_exclusive parameter controls whether a file can be present in the objects of multiple categories or only one. In strict mode, for example, the main file and uncertainty file present in a PairedSettings object aren't allowed to also be present in a Settings object in the main and uncertainty categories.

Args:

dir_paths: one or multiple paths from which to collect files for creating objects.
file_pattern: a glob pattern for the file search. Use '*.png' for example to search for any .png file.
recursive: whether to extend the file search recursively to subdirectories of the given paths.
default_output_root: optional default output root directory for the created objects.
val_cpu_transforms: optional main file transform operations to attach to the created objects.
std_cpu_transforms: optional uncertainty file transform operations to attach to the created objects.
metadata_cls: a subclass of BaseMetadata to use in creating FrameSettings and PairedFrameSettings objects.
strict_exclusive: whether to allow a particular file to exist in multiple object categories, or only in the highest priority one.
*metadata_args: args for the instantiation of the given metadata_cls.
**metadata_kwargs: kwargs for the instantiation of the given metadata_cls.

Returns:

A tuple of (paired_settings, main_settings, std_settings).

Source code in clair_torch/common/file_settings.py
@typechecked
def file_settings_constructor(
    dir_paths: Path | Sequence[Path],
    file_pattern: str,
    recursive: bool = False,
    default_output_root: Optional[Path] = None,
    val_cpu_transforms: Optional[BaseTransform | Iterable[BaseTransform]] = None,
    std_cpu_transforms: Optional[BaseTransform | Iterable[BaseTransform]] = None,
    metadata_cls: Optional[Type[BaseMetadata]] = None,
    strict_exclusive: bool = True,
    *metadata_args,
    **metadata_kwargs
):
    """
    Utility function for creating instances of FileSettings, PairedFileSettings, FrameSettings and PairedFrameSettings
    classes. FrameSettings classes are created when a metadata class is provided, otherwise FileSettings are created.
    The created objects are assigned to three categories, each having a list of their own to hold the objects:
    1. Paired files, containing either PairedFrameSettings or PairedFileSettings objects. 2. Main files, containing
    either FrameSettings or FileSettings objects. 3. Uncertainty files, containing either FrameSettings or FileSettings
    objects. The strict_exclusive parameter controls whether a file can be present in the objects of multiple categories
    of only one. In strict mode, for example, the main file and uncertainty file present in a PairedSettings object
    aren't allowed to also be present in a Settings object in the main and uncertainty categories.
    Args:
        dir_paths: one or multiple paths from which to collect files for creating objects.
        file_pattern: a glob pattern for the file search. Use '*.png' for example to search for any .png file.
        recursive: whether to extend the file search recursively to subdirectories of the given paths.
        default_output_root: optional default output root directory for the created objects.
        val_cpu_transforms: optional main file transform operations to attach to the created objects.
        std_cpu_transforms: optional uncertainty file transform operations to attach to the created objects.
        metadata_cls: a subclass of BaseMetadata to use in creating FrameSettings and PairedFrameSettings objects.
        strict_exclusive: whether to allow a particular file to exist in multiple object categories, or only in the
            highest priority one.
        *metadata_args: args for the instantiation of the given metadata_cls.
        **metadata_kwargs: kwargs for the instantiation of the given metadata_cls.

    Returns:
        A tuple (paired_settings, main_settings, std_settings).
    """
    def make_single_settings(cls: Callable, path_key: int, cpu_transforms, path_pair):
        return cls(
            input_path=path_pair[path_key],
            default_output_root=default_output_root,
            cpu_transforms=cpu_transforms,
            metadata_cls=metadata_cls if cls is FrameSettings else None,
            *metadata_args,
            **metadata_kwargs
        )

    def make_paired_settings(cls: Callable, pair):
        return cls(
            val_input_path=pair[0],
            std_input_path=pair[1],
            default_output_root=default_output_root,
            val_cpu_transforms=val_cpu_transforms,
            std_cpu_transforms=std_cpu_transforms,
            metadata_cls=metadata_cls if cls is PairedFrameSettings else None,
            *metadata_args,
            **metadata_kwargs
        )

    search_paths = [dir_paths] if isinstance(dir_paths, Path) else dir_paths

    all_paired, all_main, all_std = [], [], []

    for dir_path in search_paths:
        current_files = _get_file_input_paths_by_pattern(dir_path, file_pattern, recursive)
        paired, main, std = _pair_main_and_std_files(current_files, strict_exclusive=strict_exclusive)
        all_paired.extend(paired)
        all_main.extend(main)
        all_std.extend(std)

    if metadata_cls:
        paired_cls, frame_cls = PairedFrameSettings, FrameSettings
    else:
        paired_cls, frame_cls = PairedFileSettings, FileSettings

    paired_settings = tuple(make_paired_settings(paired_cls, p) for p in all_paired)
    main_settings = tuple(make_single_settings(frame_cls, 0, val_cpu_transforms, m) for m in all_main)
    std_settings = tuple(make_single_settings(frame_cls, 1, std_cpu_transforms, s) for s in all_std)

    return paired_settings, main_settings, std_settings
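
A hedged usage sketch: the directory is hypothetical, and the import location of ImagingMetadata is an assumption, since only file_settings.py is documented here.

# Hypothetical sketch; the directory and the ImagingMetadata module path are assumptions.
from pathlib import Path

from clair_torch.common.file_settings import file_settings_constructor
from clair_torch.common.metadata import ImagingMetadata  # module path assumed

paired, main, std = file_settings_constructor(
    dir_paths=Path("./calibration_frames"),
    file_pattern="*.tif",
    recursive=True,
    metadata_cls=ImagingMetadata,   # providing a metadata class yields (Paired)FrameSettings
)
print(len(paired), len(main), len(std))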

flat_field_mean(flat_field, mid_area_side_fraction)

Computes the spatial mean over a centered square ROI for each image and channel.

Parameters:

flat_field (Tensor): Input tensor of shape (N, C, H, W). Required.
mid_area_side_fraction (float): Fraction of the spatial dims to use for the ROI. Must lie in the range (0.0, 1.0]. Required.

Returns:

Tensor: Mean over the ROI, shape (N, C, 1, 1).

Source code in clair_torch/common/general_functions.py
@typechecked
def flat_field_mean(flat_field: torch.Tensor, mid_area_side_fraction: float) -> torch.Tensor:
    """
    Computes the spatial mean over a centered square ROI for each image and channel.

    Args:
        flat_field (Tensor): Input tensor of shape (N, C, H, W)
        mid_area_side_fraction (float): Fraction of spatial dims to use for the ROI. Must lie in range (0.0, 1.0].

    Returns:
        Tensor: Mean over the ROI, shape (N, C, 1, 1).
    """
    # A fraction of exactly 0.0 would divide by zero below, so the lower bound is exclusive.
    if not 0.0 < mid_area_side_fraction <= 1.0:
        raise ValueError("mid_area_side_fraction should be in the range (0.0, 1.0].")

    N, C, H, W = flat_field.shape

    ROI_dx = math.floor(W * mid_area_side_fraction)
    ROI_dy = math.floor(H * mid_area_side_fraction)

    ROI_start_index = (math.floor(1 / mid_area_side_fraction) - 1) / 2

    x_start = math.floor(ROI_start_index * ROI_dx)
    x_end = math.floor((ROI_start_index + 1) * ROI_dx)
    y_start = math.floor(ROI_start_index * ROI_dy)
    y_end = math.floor((ROI_start_index + 1) * ROI_dy)

    cropped = flat_field[:, :, y_start:y_end, x_start:x_end]  # shape: (N, C, ROI_dy, ROI_dx)

    return cropped.mean(dim=(-1, -2), keepdim=True)  # shape: (N, C, 1, 1)

flatfield_correction(images, flatfield, flatfield_mean_val, epsilon=1e-06)

Computes a flat-field corrected version of input image by utilizing the given flat-field image and a given spatial mean. Ideally expects both images and flatfield in shape (N, C, H, W) but others are allowed. Match argument shapes based on requirements. For example with images (N, C, H, W) use flatfield (1, C, H, W) to apply same flatfield across the batch dimension, (N, 1, H, W) to apply unique flatfields across batch while disregarding channel specific features. Similarly, use flatfield_mean_val (1, C, 1, 1) to apply channel-specific scaling uniformly across the batch.

Parameters:

images (Tensor): Image tensor of shape (N, C, H, W). Required.
flatfield (Tensor): Flat field calibration image, same shape as images or broadcastable to it. Required.
flatfield_mean_val (Tensor): Values used to scale the image. Match the shape based on the given images and flatfield. Required.
epsilon (float): Small constant to avoid division by zero. Default: 1e-06.

Returns:

Tensor: Corrected image tensor, same shape as the input.

Source code in clair_torch/common/general_functions.py
@typechecked
def flatfield_correction(images: torch.Tensor, flatfield: torch.Tensor, flatfield_mean_val: torch.Tensor,
                         epsilon: float = 1e-6) -> torch.Tensor:
    """
    Computes a flat-field corrected version of input image by utilizing the given flat-field image and a given spatial
    mean. Ideally expects both images and flatfield in shape (N, C, H, W) but others are allowed. Match argument shapes
    based on requirements. For example with images (N, C, H, W) use flatfield (1, C, H, W) to apply same flatfield
    across the batch dimension, (N, 1, H, W) to apply unique flatfields across batch while disregarding channel specific
    features. Similarly, use flatfield_mean_val (1, C, 1, 1) to apply channel-specific scaling uniformly across the batch.

    Args:
        images: Image tensor of shape (N, C, H, W).
        flatfield: Flat field calibration image, same shape as `images` or broadcastable to images.
        flatfield_mean_val: Values used to scale the image. Match shape based on the given images and flatfield.
        epsilon: Small constant to avoid division by zero.

    Returns:
        Tensor: Corrected image tensor, same shape as input.
    """

    is_broadcastable(images.shape, flatfield.shape, raise_error=True)
    is_broadcastable(images.shape, flatfield_mean_val.shape, raise_error=True)

    corrected_images = (images / (flatfield + epsilon)) * flatfield_mean_val

    return corrected_images
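
The two functions are typically chained: compute the ROI mean of the flat field, then rescale. A minimal sketch with illustrative shapes:

import torch

from clair_torch.common.general_functions import flat_field_mean, flatfield_correction

images = torch.rand(4, 3, 128, 128)             # frames to correct, (N, C, H, W)
flatfield = torch.rand(1, 3, 128, 128) + 0.5    # shared flat field, kept away from zero

# Per-channel mean over the centered 50% x 50% ROI: shape (1, 3, 1, 1).
ff_mean = flat_field_mean(flatfield, mid_area_side_fraction=0.5)

corrected = flatfield_correction(images, flatfield, ff_mean)
assert corrected.shape == images.shape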

get_pairwise_valid_pixel_mask(image_value_stack, i_idx, j_idx, image_std_stack=None, val_lower=0.0, val_upper=1.0, std_lower=None, std_upper=None)

For a batch of images, for all pairs given by indices i_idx and j_idx, compute a pairwise boolean mask by marking invalid pixels as False if they lie outside the valid range defined by lower and upper, in either one of the images in a given pair.

Args:

image_value_stack: batch of value images, shape (N, C, H, W).
i_idx: the first set of indices to create pairs off of, shape (P,).
j_idx: the second set of indices to create pairs off of, shape (P,).
image_std_stack: batch of uncertainty images associated with the images in image_value_stack, shape (N, C, H, W).
val_lower: lower threshold for marking pixels as invalid in the value image.
val_upper: upper threshold for marking pixels as invalid in the value image.
std_lower: lower threshold for marking pixels as invalid in the std image.
std_upper: upper threshold for marking pixels as invalid in the std image.

Returns:

Tensor: A boolean tensor that marks invalid pixel positions in a pair with False, shape (P, C, H, W).

Source code in clair_torch/common/general_functions.py
@typechecked
def get_pairwise_valid_pixel_mask(image_value_stack: torch.Tensor, i_idx: torch.Tensor, j_idx: torch.Tensor,
                                  image_std_stack: Optional[torch.Tensor] = None,
                                  val_lower: float = 0.0, val_upper: float = 1.0,
                                  std_lower: Optional[float] = None, std_upper: Optional[float] = None) -> torch.Tensor:
    """
    For a batch of images, for all pairs given by indices i_idx and j_idx, compute a pairwise boolean mask by marking
    invalid pixels as False, if they lie outside the valid range defined by lower and upper, in either one of the images
    in a given pair.
    Args:
        image_value_stack: batch of value images, shape (N, C, H, W).
        i_idx: the first set of indices to create pairs off of, shape (P,).
        j_idx: the second set of indices to create pairs off of, shape (P,).
        image_std_stack: batch of uncertainty images associated with the images in image_value_stack, shape (N, C, H, W).
        val_lower: lower threshold for marking pixels as invalid in value image.
        val_upper: upper threshold for marking pixels as invalid in value image.
        std_lower: lower threshold for marking pixels as invalid in std image.
        std_upper: upper threshold for marking pixels as invalid in std image.

    Returns:
        A boolean tensor that marks invalid pixel positions in a pair with False, shape (P, C, H, W).
    """
    if val_lower > val_upper:
        raise ValueError("Lower threshold cannot be a larger value than upper threshold.")
    if std_lower is not None and std_upper is not None and (std_lower > std_upper):
        raise ValueError("Lower threshold cannot be a larger value than upper threshold.")

    # Gather image pairs and reshape to (P, C, H, W)
    val_i, val_j = image_value_stack[i_idx], image_value_stack[j_idx]

    valid_mask = (val_i >= val_lower) & (val_i <= val_upper) & (val_j >= val_lower) & (val_j <= val_upper)

    if image_std_stack is not None and (std_lower is not None or std_upper is not None):
        std_i, std_j = image_std_stack[i_idx], image_std_stack[j_idx]
        valid_std_mask = (std_i >= std_lower) & (std_i <= std_upper) & (std_j >= std_lower) & (std_j <= std_upper)
        valid_mask = valid_mask & valid_std_mask

    return valid_mask

get_valid_exposure_pairs(increasing_exposure_values, exposure_ratio_threshold=None)

Generate valid (i, j) index pairs for exposure comparison, based on a minimum ratio threshold.

Parameters:

increasing_exposure_values (Tensor): Shape (N,) exposure values in increasing order. Required.
exposure_ratio_threshold (float, optional): Minimum exposure ratio to accept a pair. Default: None.

Returns:

Tuple of:
- i_idx: (P,) indices of the first images in valid pairs (i < j)
- j_idx: (P,) indices of the second images
- ratio_pairs: (P,) exposure ratios exposure[i] / exposure[j]

Source code in clair_torch/common/general_functions.py
@typechecked
def get_valid_exposure_pairs(increasing_exposure_values: torch.Tensor, exposure_ratio_threshold: Optional[float] = None,
                             ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Generate valid (i, j) index pairs for exposure comparison, based on a minimum ratio threshold.

    Args:
        increasing_exposure_values (Tensor): Shape (N,) exposure values in an increasing order.
        exposure_ratio_threshold (float, optional): Minimum exposure ratio to accept a pair.

    Returns:
        Tuple of:
            - i_idx: (P,) indices of first images in valid pairs (i < j)
            - j_idx: (P,) indices of second images
            - ratio_pairs: (P,) exposure ratios exposure[i] / exposure[j]
    """

    N = increasing_exposure_values.shape[0]
    device = increasing_exposure_values.device
    ratios = increasing_exposure_values.view(N, 1) / increasing_exposure_values.view(1, N)  # (N, N)
    i_idx, j_idx = torch.triu_indices(N, N, offset=1)
    i_idx = i_idx.to(device=device)
    j_idx = j_idx.to(device=device)
    ratio_pairs = ratios[i_idx, j_idx]  # (P,)

    if exposure_ratio_threshold is not None:
        mask = ratio_pairs >= exposure_ratio_threshold
        i_idx = i_idx[mask]
        j_idx = j_idx[mask]
        ratio_pairs = ratio_pairs[mask]

    return i_idx, j_idx, ratio_pairs
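
The two functions above are designed to be chained: first select index pairs by exposure ratio, then mask out pixels that are invalid in either member of each pair. A sketch with illustrative values; note that with increasing exposures and i < j the ratios exposure[i] / exposure[j] fall below 1, so the threshold acts as a floor on them.

import torch

from clair_torch.common.general_functions import (
    get_pairwise_valid_pixel_mask, get_valid_exposure_pairs)

exposures = torch.tensor([1 / 1000, 1 / 250, 1 / 60, 1 / 15])  # increasing
stack = torch.rand(4, 3, 32, 32)                               # matching value images

# Pairs whose ratio falls below the threshold (too far apart in exposure) are dropped.
i_idx, j_idx, ratios = get_valid_exposure_pairs(exposures, exposure_ratio_threshold=0.1)

# Boolean mask of pixels valid in both members of every surviving pair.
mask = get_pairwise_valid_pixel_mask(stack, i_idx, j_idx, val_lower=0.05, val_upper=0.95)
assert mask.shape == (i_idx.numel(), 3, 32, 32)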

group_frame_settings_by_attributes(list_of_frame_settings, attributes)

Sort FrameSettings objects into separate groups based on the values of the given attributes. Args: list_of_frame_settings: list of the FrameSettings to sort. attributes: the attributes to base the grouping on.

Returns:

List[Tuple[dict[str, str | float | int], List[FrameSettings]]]: A list of tuples, the first item of each tuple containing a dictionary of the attributes used to generate that group, and the second item containing a list of the FrameSettings objects belonging to that group.

Source code in clair_torch/common/file_settings.py
@typechecked
def group_frame_settings_by_attributes(list_of_frame_settings: List[FrameSettings],
                                       attributes: dict[str, None | int | float]) \
        -> List[Tuple[dict[str, str | float | int], List[FrameSettings]]]:
    """
    Sort FrameSettings objects into separate groups based on the values of the given attributes.
    Args:
        list_of_frame_settings: list of the FrameSettings to sort.
        attributes: the attributes to base the grouping on.

    Returns:
        List of tuples, the first item in the tuple containing a dictionary of the attributes used to generate that
        group. The second item in the tuple contains a list of the FrameSettings objects belonging to that group.
    """

    list_of_grouped_frame_settings = []
    list_of_group_metas = []

    for frame_settings in list_of_frame_settings:

        current_metas = frame_settings.get_all_metadata()

        # Generate the first group automatically.
        if not list_of_grouped_frame_settings:

            group_list = [frame_settings]
            list_of_grouped_frame_settings.append(group_list)
            list_of_group_metas.append({k: current_metas[k] for k in attributes if k in current_metas})
            continue

        # Loop through the current groups and check if the current frame_settings fits into any of them.
        number_of_groups = len(list_of_grouped_frame_settings)
        for i, group_list in enumerate(list_of_grouped_frame_settings):

            current_target_metadata = group_list[0].metadata
            candidate_metadata = frame_settings.metadata

            # Add to existing group.
            if current_target_metadata.is_match(candidate_metadata, attributes):
                group_list.append(frame_settings)
                break
            # Generate a new group and add to it.
            if number_of_groups - 1 - i == 0:
                additional_group_list = [frame_settings]
                list_of_grouped_frame_settings.append(additional_group_list)
                list_of_group_metas.append({k: current_metas[k] for k in attributes if k in current_metas})
                break

    return list(zip(list_of_group_metas, list_of_grouped_frame_settings))

load_icrf_txt(path, source_channel_order=ChannelOrder.BGR, source_dimension_order=DimensionOrder.BSC)

Utility function for loading an inverse camera response function from a .txt file. Expects a 2D NumPy array with shape (N, C), with N representing the number of datapoints.

Args:

path: path to the text file containing the ICRF data.
source_channel_order: the order in which the channels are expected to be in the file.
source_dimension_order: the order in which the dimensions are expected to be in the file.

Returns:

Tensor: torch.Tensor representing the ICRF.

Source code in clair_torch/common/data_io.py
@typechecked
def load_icrf_txt(path: str | Path, source_channel_order: ChannelOrder = ChannelOrder.BGR,
                  source_dimension_order: DimensionOrder = DimensionOrder.BSC) -> torch.Tensor:
    """
    Utility function for loading an inverse camera response function from a .txt file. Expects a 2D NumPy array
    with shape (N, C), with N representing the number of datapoints.
    Args:
        path: path to the text file containing the ICRF data.
        source_channel_order: the order in which the channels are expected to be in the file.
        source_dimension_order: the order in which the dimensions are expected to be in the file.

    Returns:
        torch.Tensor representing the ICRF.
    """
    if isinstance(path, str):
        path = Path(path)

    validate_input_file_path(path, suffix=".txt")

    try:
        data = torch.from_numpy(np.loadtxt(path)).float()  # shape: (256, 3), BGR order
    except Exception as e:
        raise IOError(f'Failed to load NumPy array from {path}: {e}')

    if source_dimension_order == DimensionOrder.BCS:
        pass
    elif source_dimension_order == DimensionOrder.BSC:
        data = torch.transpose(data, 0, 1)

    # No change for any and RGB ordering.
    if source_channel_order == ChannelOrder.ANY or source_channel_order == ChannelOrder.RGB:
        pass
    # Reverse BGR order into RGB order.
    elif source_channel_order == ChannelOrder.BGR:
        data = data[[2, 1, 0], :]
    # Raise value error for unknown channel ordering.
    else:
        raise ValueError(f"Unknown channel order {source_channel_order}:")

    return data

load_image(file_path, transforms=None)

Generic function to load a single image from the given path. Allows also the definition of transformations to be performed on the image before returning it upstream. Args: file_path: path to the image file. transforms: single Transform or Iterable of Transforms to be performed on the image before returning it.

Returns:

Tensor: The loaded and possibly transformed image in Tensor format.

Source code in clair_torch/common/data_io.py
@typechecked
def load_image(file_path: str | Path,
               transforms: Optional[BaseTransform | Iterable[BaseTransform]] = None) -> torch.Tensor:
    """
    Generic function to load a single image from the given path. Allows also the definition of transformations to be
    performed on the image before returning it upstream.
    Args:
        file_path: path to the image file.
        transforms: single Transform or Iterable of Transforms to be performed on the image before returning it.

    Returns:
        The loaded and possibly transformed image in Tensor format.
    """
    file_path = Path(file_path)
    validate_input_file_path(file_path, suffix=None)

    transforms = normalize_container(transforms)

    try:
        image = cv.imread(str(file_path), cv.IMREAD_UNCHANGED)
    except Exception as e:
        raise IOError(f"Failed to load image from {file_path}: {e}")

    # cv.imread signals a failed read by returning None rather than raising.
    if image is None:
        raise IOError(f"Failed to load image from {file_path}")

    image = torch.from_numpy(image)
    image = cv_to_torch(image)

    if transforms:
        for transform in transforms:
            image = transform(image)

    return image
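
A short usage sketch; the file name is hypothetical, and any CPU-side transforms defined in the project could be passed via transforms:

from clair_torch.common.data_io import load_image

# cv.IMREAD_UNCHANGED preserves the source bit depth and any alpha channel;
# cv_to_torch then converts the array to the package's channel-first layout.
image = load_image("capture.tif")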

load_principal_components(file_paths)

Loads principal component data from text files, one file per color channel. The paths should be given in the desired channel order; e.g. for RGB images, the list should point to the red, green and blue files, in that order.

Parameters:

    file_paths (list[str | Path], required): List of paths to the .txt files (one per channel).

Returns:

    Tensor: A torch.Tensor of shape (n_points, n_components, channels).

Source code in clair_torch/common/data_io.py
@typechecked
def load_principal_components(file_paths: list[str | Path]) -> torch.Tensor:
    """
    Loads principal component data from text files, one file per color channel. The paths should be given in the
    desired channel order; e.g. for RGB images, the list should point to the red, green and blue files, in that order.

    Args:
        file_paths: List of paths to the .txt files (one per channel).

    Returns:
        A torch.Tensor of shape (n_points, n_components, channels).
    """
    pcs = []
    for path in file_paths:
        data = np.loadtxt(path)
        pcs.append(torch.tensor(data, dtype=torch.float32))  # Shape: (n_points, n_components)

    pcs_tensor = torch.stack(pcs, dim=2)  # Shape: (n_points, n_components, channels)
    return pcs_tensor
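
A brief usage sketch; the file names are hypothetical, and passing them in red, green, blue order yields an RGB-ordered channel axis:

from clair_torch.common.data_io import load_principal_components

# Each file holds a (n_points, n_components) array for one channel; stacking
# on dim=2 yields a tensor of shape (n_points, n_components, 3).
pcs = load_principal_components(["pca_red.txt", "pca_green.txt", "pca_blue.txt"])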

load_video_frames_generator(file_path, transforms=None)

Function for loading frames from a video file through a generator.

Parameters:

    file_path: path to the video file.
    transforms: optional list of transform operations to perform on each frame before yielding them.

Returns:

    Generator[torch.Tensor, None, None]: a generator that yields the video frames as torch.Tensors.

Source code in clair_torch/common/data_io.py
@typechecked
def load_video_frames_generator(file_path: str | Path,
                                transforms: Optional[BaseTransform | Iterable[BaseTransform]] = None) \
                                -> Generator[torch.Tensor, None, None]:
    """
    Function for loading frames from a video file through a generator.
    Args:
        file_path: path to the video file.
        transforms: optional list of transform operations to perform on each frame before yielding them.

    Returns:
        Generator, which yields torch.Tensors.
    """
    validate_input_file_path(file_path, suffix=None)

    transforms = normalize_container(transforms)

    cap = cv.VideoCapture(str(file_path))
    # cv.VideoCapture does not raise on a bad path; an unopened capture simply yields no frames.
    if not cap.isOpened():
        raise IOError(f"Failed to open video file {file_path}")
    success, frame = cap.read()

    while success:
        tensor_frame = torch.from_numpy(frame)
        tensor_frame = cv_to_torch(tensor_frame)

        if transforms:
            for transform in transforms:
                tensor_frame = transform(tensor_frame)

        yield tensor_frame
        success, frame = cap.read()

    cap.release()
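
Because frames are yielded one at a time, a long video never has to fit in memory. A usage sketch with a hypothetical file and a placeholder per-frame function:

from clair_torch.common.data_io import load_video_frames_generator

for frame in load_video_frames_generator("recording.mp4"):
    process_frame(frame)  # process_frame is a placeholder for per-frame work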

normalize_tensor(x, max_val=None, min_val=None, target_range=(0.0, 1.0))

Normalize a tensor by a given min and max value. If not provided, uses the min and max of the tensor.

Parameters:

    x (Tensor, required): Input tensor.
    max_val (Optional[float], default None): Optional maximum value for normalization.
    min_val (Optional[float], default None): Optional minimum value for normalization.
    target_range (tuple[float, float], default (0.0, 1.0)): Tuple specifying the (min, max) target range.

Returns:

    Tensor: The normalized tensor.

Source code in clair_torch/common/general_functions.py
def normalize_tensor(x: torch.Tensor, max_val: Optional[float] = None, min_val: Optional[float] = None,
                     target_range: tuple[float, float] = (0.0, 1.0)) -> torch.Tensor:
    """
    Normalize a tensor by a given min and max value. If not provided, uses the min and max of the tensor.

    Args:
        x: Input tensor.
        max_val: Optional maximum value for normalization.
        min_val: Optional minimum value for normalization.
        target_range: Tuple specifying the (min, max) target range.

    Returns:
        The normalized tensor.
    """
    max_val = x.max() if max_val is None else max_val
    min_val = x.min() if min_val is None else min_val

    denominator = max_val - min_val
    if denominator == 0:
        raise ValueError("Normalization range is zero (min == max); cannot normalize.")

    # Normalize to [0, 1]
    x_normalized = (x - min_val) / denominator

    # Scale to [target_min, target_max]
    min_target, max_target = target_range
    target_span = max_target - min_target
    x_scaled = x_normalized * target_span + min_target

    return x_scaled
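
A small worked example: with min_val=0 and max_val=10, the values are first mapped to [0, 1] and then rescaled to the target range:

import torch
from clair_torch.common.general_functions import normalize_tensor

x = torch.tensor([0.0, 5.0, 10.0])
# (x - 0) / 10 -> [0.0, 0.5, 1.0]; scaled to (-1.0, 1.0) -> [-1.0, 0.0, 1.0]
y = normalize_tensor(x, max_val=10.0, min_val=0.0, target_range=(-1.0, 1.0))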

save_icrf_txt(icrf, path, target_channel_order=ChannelOrder.BGR, target_dimension_order=DimensionOrder.BSC)

Utility function to save an ICRF into a .txt file at the given filepath.

Parameters:

    icrf: the ICRF tensor.
    path: the filepath where to save the file.
    target_channel_order: the order in which the channels are to be in the saved data.
    target_dimension_order: the order in which the dimensions are to be in the saved data.

Returns:

    None

Source code in clair_torch/common/data_io.py
@typechecked
def save_icrf_txt(icrf: torch.Tensor, path: str | Path, target_channel_order: ChannelOrder = ChannelOrder.BGR,
                  target_dimension_order: DimensionOrder = DimensionOrder.BSC) -> None:
    """
    Utility function to save an ICRF into a .txt file at the given filepath.
    Args:
        icrf: the ICRF tensor.
        path: the filepath where to save the file.
        target_channel_order: the order in which the channels are to be in the saved data.
        target_dimension_order: the order in which the dimensions are to be in the saved data.

    Returns:
        None
    """
    path = Path(path)
    if not is_potentially_valid_file_path(path):
        raise IOError(f"Invalid path for your OS: {path}")

    data = icrf.detach().cpu()
    if target_channel_order == ChannelOrder.RGB:
        pass
    elif target_channel_order == ChannelOrder.BGR:
        data = data[[2, 1, 0], :]

    if target_dimension_order == DimensionOrder.BCS:
        pass
    elif target_dimension_order == DimensionOrder.BSC:
        data = torch.transpose(data, 0, 1)

    data = data.numpy()

    try:
        np.savetxt(path, data)
    except Exception as e:
        raise IOError(f"Couldn't save data to path {path}") from e

    return
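
Since the default channel and dimension orders of save_icrf_txt and load_icrf_txt mirror each other, a save followed by a load should round-trip (a sketch assuming the working directory is writable):

import torch
from clair_torch.common.data_io import save_icrf_txt, load_icrf_txt

icrf = torch.rand(3, 256)  # (C, S): 3 channels, 256 datapoints
save_icrf_txt(icrf, "icrf_out.txt")       # written as (S, C) in BGR order
restored = load_icrf_txt("icrf_out.txt")  # defaults undo the reordering
assert torch.allclose(icrf, restored)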

save_image(tensor, image_save_path, dtype=np.dtype('float64'), params=None)

Save a PyTorch tensor as an image (e.g. a float TIFF), casting to the given NumPy dtype before writing.

Parameters:

    tensor (Tensor, required): A PyTorch tensor of shape (C, H, W) or (H, W), dtype float32.
    image_save_path (str | Path, required): Path to save the image.
    dtype (np.dtype, default np.dtype('float64')): the NumPy datatype to use to save the image.
    params (Optional[Sequence[int]], default None): Sequence of params to pass to OpenCV's imwrite function.
Source code in clair_torch/common/data_io.py
@typechecked
def save_image(tensor: torch.Tensor, image_save_path: str | Path, dtype: np.dtype = np.dtype("float64"),
               params: Optional[Sequence[int]] = None) -> None:
    """
    Save a PyTorch tensor as an image (e.g. a float TIFF), casting to the given NumPy dtype before writing.

    Args:
        tensor: A PyTorch tensor of shape (C, H, W) or (H, W), dtype float32.
        image_save_path: Path to save the image.
        dtype: the NumPy datatype to use to save the image.
        params: Sequence of params to pass to OpenCV imwrite function.
    """
    if not is_potentially_valid_file_path(image_save_path):
        raise IOError(f"Invalid path for your OS: {image_save_path}")

    # Coerce to Path so that .parent works when a plain str is passed.
    image_save_path = Path(image_save_path)
    image_save_path.parent.mkdir(parents=True, exist_ok=True)
    if not image_save_path.parent.exists():
        raise IOError(f"Couldn't create the directory structure for path {image_save_path}")

    if tensor.is_cuda:
        tensor = tensor.cpu()

    array = tensor.detach().numpy().astype(dtype=dtype)

    # Handle 3-channel (C, H, W) case: convert to (H, W, C) and reorder to BGR
    if array.ndim == 3:
        array = np.transpose(array, (1, 2, 0))  # (H, W, C)
        if array.shape[2] == 3:
            array = array[:, :, [2, 1, 0]]  # RGB → BGR

    success = cv.imwrite(str(image_save_path), array, params or [])
    if not success:
        raise IOError(f"Failed to save image to {image_save_path}")

weighted_mean_and_std(values, weights=None, mask=None, dim=None, keepdim=False, eps=1e-08, compute_std=True)

Computes the weighted mean and standard deviation of values, with optional weights and boolean mask.

Parameters:

    values (Tensor): Input tensor.
    weights (Tensor or None): Optional weights, broadcastable to values.
    mask (Tensor or None): Optional boolean mask, where True marks a valid value.
    dim (int or tuple of ints): Axis or axes to reduce over.
    keepdim (bool): Keep reduced dimensions.
    eps (float): Small value to avoid division by zero.
    compute_std (bool): Whether to compute the standard deviation.

Returns:

    (mean, std): Tuple of tensors, each of the reduced shape; std is None when compute_std is False.

Source code in clair_torch/common/general_functions.py
@typechecked
def weighted_mean_and_std(
        values: torch.Tensor, weights: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None,
        dim=None, keepdim=False, eps: int | float = 1e-8, compute_std: Optional[bool] = True) \
        -> tuple[torch.Tensor, torch.Tensor | None]:
    """
    Computes the weighted mean and standard deviation of `values`, with optional weights and boolean mask.
    Args:
        values (Tensor): Input tensor.
        weights (Tensor or None): Optional weights, broadcastable to `values`.
        mask (Tensor or None): Optional boolean mask, where True = valid value.
        dim (int or tuple of ints): Axis or axes to reduce over.
        keepdim (bool): Keep reduced dimensions.
        eps (float): Small value to avoid division by zero.
        compute_std: whether to compute std or not.

    Returns:
        (mean, std): Tuple of tensors, each of the reduced shape. std is None if compute_std is False.
    """
    std = None

    if mask is not None:

        mask = mask.to(dtype=values.dtype)
        values = values * mask
        weights = weights * mask if weights is not None else mask

    if weights is None:

        mean = values.mean(dim=dim, keepdim=True)

        if compute_std:
            sq_diff = (values - mean) ** 2
            std = torch.sqrt(sq_diff.mean(dim=dim, keepdim=True))

    else:

        weighted_sum = (values * weights).sum(dim=dim, keepdim=True)
        total_weight = weights.sum(dim=dim, keepdim=True)

        # Handle edge case: total weight is zero (e.g. all weights are 0 or all masked).
        # The zero check must come before clamping, otherwise the mask can never trigger.
        zero_weight_mask = (total_weight == 0)
        safe_weight = total_weight.clamp(min=eps)

        mean = weighted_sum / safe_weight

        if compute_std:
            sq_diff = (values - mean) ** 2
            weighted_sq_diff = (sq_diff * weights).sum(dim=dim, keepdim=True)
            std = torch.sqrt(weighted_sq_diff / safe_weight)

        # Set mean and std to 0 where total weight is zero
        mean = torch.where(zero_weight_mask, torch.zeros_like(mean), mean)
        if compute_std:
            std = torch.where(zero_weight_mask, torch.zeros_like(std), std)

    if not keepdim:
        mean = mean.squeeze(dim) if dim is not None else mean.squeeze()
        if compute_std:
            std = std.squeeze(dim) if dim is not None else std.squeeze()

    return mean, std
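
A small example of how a zero weight excludes a value from both statistics:

import torch
from clair_torch.common.general_functions import weighted_mean_and_std

values = torch.tensor([1.0, 2.0, 3.0, 100.0])
weights = torch.tensor([1.0, 1.0, 1.0, 0.0])  # zero weight drops the outlier
mean, std = weighted_mean_and_std(values, weights=weights, dim=0)
# mean == 2.0 and std ~= 0.816, as if only the first three values existed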