Training
This page is the technical reference for the training subpackage. The subpackage provides a training loop for the ICRF models, along with the loss functions used inside that loop.
combined_gaussian_pair_weights(image_stack, i_idx, j_idx, scale=10.0)
Compute combined Gaussian weights for image pairs by summing weights from each image.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
image_stack | Tensor | Stack of images, ndim >= 2. Usual shape (N, C, H, W). | required |
i_idx | Tensor | Index tensor of shape (P,) for the first image in each pair, ndim=1. | required |
j_idx | Tensor | Index tensor of shape (P,) for the second image in each pair, ndim=1. | required |
scale | Optional[float] | Sharpness of the Gaussian weight. | 10.0 |
Returns:

Name | Type | Description |
---|---|---|
combined_weights | Tensor | Tensor of shape (P, C, H, W) for input of (N, C, H, W). |
Source code in clair_torch/training/losses.py
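The pairing logic can be sketched in plain Python. This is a minimal sketch under the assumption that the combined weight is the elementwise sum of each image's Gaussian value weight exp(-scale * (x - 0.5)**2); the library operates on (N, C, H, W) tensors rather than nested lists:

```python
import math

def gaussian_weight(x, scale=10.0):
    # Gaussian weight that peaks at mid-intensity (x = 0.5).
    return math.exp(-scale * (x - 0.5) ** 2)

def combined_pair_weight(stack, i_idx, j_idx, scale=10.0):
    # For each pair (i, j), the combined weight is the elementwise sum
    # of the two paired images' individual Gaussian weights.
    return [
        [gaussian_weight(a, scale) + gaussian_weight(b, scale)
         for a, b in zip(stack[i], stack[j])]
        for i, j in zip(i_idx, j_idx)
    ]

stack = [[0.5, 0.1], [0.5, 0.9]]          # two "images", two pixels each
weights = combined_pair_weight(stack, [0], [1])
```

A pixel that sits at mid-intensity in both images of a pair receives the maximum combined weight of 2.0, while pixels near the clipped extremes in either image contribute far less.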
compute_endpoint_penalty(curve, per_channel=False)
Get a penalty term for a function (tensor) whose endpoints are not exactly 0 and 1. The penalty is the sum of the squared deviations of the start and end values from 0 and 1 respectively.

Parameters:

Name | Type | Description | Default |
---|---|---|---|
curve | Tensor | The function to determine a penalty term for. Shape (D, C), with D the number of datapoints in the function and C the number of channels. | required |
per_channel | bool | Whether to return a per-channel loss, or sum the channel losses into a scalar tensor. | False |

Returns:

Type | Description |
---|---|
Tensor | Penalty terms as scalar tensor, or one element per channel. |
Source code in clair_torch/training/losses.py
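The penalty described above reduces to a few lines per channel. A minimal sketch, with the curve given channel-major as plain lists rather than a (D, C) tensor:

```python
def endpoint_penalty(curve, per_channel=False):
    # Squared deviation of the first sample from 0 plus squared
    # deviation of the last sample from 1, per channel.
    per_ch = [(ch[0] - 0.0) ** 2 + (ch[-1] - 1.0) ** 2 for ch in curve]
    return per_ch if per_channel else sum(per_ch)

curve = [[0.0, 0.5, 1.0],    # perfect endpoints: zero penalty
         [0.1, 0.5, 0.9]]    # deviates by 0.1 at both ends
penalties = endpoint_penalty(curve, per_channel=True)
```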
compute_monotonicity_penalty(curve, squared=True, per_channel=False)
Get a penalty term for a function (tensor) if it is not monotonically increasing.

Parameters:

Name | Type | Description | Default |
---|---|---|---|
curve | Tensor | The function to determine a penalty term for. | required |
squared | bool | Whether to square the penalty terms. | True |
per_channel | bool | Whether to return a per-channel loss, or sum the channel losses into a scalar tensor. | False |

Returns:

Type | Description |
---|---|
Tensor | Penalty terms as scalar tensor, or one element per channel. |
Source code in clair_torch/training/losses.py
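A common formulation of such a penalty sums only the "drops", i.e. the places where the curve decreases. This sketch assumes that formulation (the library's exact expression may differ) and works on a single channel as a plain list:

```python
def monotonicity_penalty(curve_ch, squared=True):
    # Accumulate only the negative forward differences: zero for a
    # monotonically increasing curve, positive wherever it dips.
    penalty = 0.0
    for a, b in zip(curve_ch, curve_ch[1:]):
        drop = max(0.0, a - b)
        penalty += drop ** 2 if squared else drop
    return penalty

increasing = [0.0, 0.3, 0.7, 1.0]
dipping = [0.0, 0.6, 0.4, 1.0]
```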
compute_range_penalty(curve, epsilon=1e-06, per_channel=False)
Get a penalty term for a function (tensor) if it breaks the [0, 1] value range.

Parameters:

Name | Type | Description | Default |
---|---|---|---|
curve | Tensor | The function to determine a penalty term for. | required |
epsilon | float | Small epsilon to avoid vanishing gradients. | 1e-06 |
per_channel | bool | Whether to return a per-channel loss, or sum the channel losses into a scalar tensor. | False |

Returns:

Type | Description |
---|---|
Tensor | Penalty terms as scalar tensor, or one element per channel. |
Source code in clair_torch/training/losses.py
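The idea can be sketched as a squared overshoot outside [0, 1]. This is an illustrative single-channel version; the library's differentiable implementation additionally uses the epsilon term to keep gradients alive near the boundary:

```python
def range_penalty(curve_ch):
    # Squared distance outside [0, 1]; zero for in-range values.
    return sum(max(0.0, -v) ** 2 + max(0.0, v - 1.0) ** 2
               for v in curve_ch)

in_range = [0.0, 0.5, 1.0]
out_of_range = [-0.1, 0.5, 1.2]   # overshoots by 0.1 below and 0.2 above
```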
compute_smoothness_penalty(curve, per_channel=False)
Get a penalty term for a function (tensor) if it is not sufficiently smooth.

Parameters:

Name | Type | Description | Default |
---|---|---|---|
curve | Tensor | The function to determine a penalty term for. | required |
per_channel | bool | Whether to return a per-channel loss, or sum the channel losses into a scalar tensor. | False |

Returns:

Type | Description |
---|---|
Tensor | Penalty terms as scalar tensor, or one element per channel. |
Source code in clair_torch/training/losses.py
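A standard choice of smoothness penalty is the sum of squared second differences of the curve, which is zero for a straight line and grows wherever the curve bends sharply. A sketch under that assumption (the library's exact expression may differ):

```python
def smoothness_penalty(curve_ch):
    # Discrete second derivative at each interior sample, squared and
    # summed: a straight line scores exactly zero.
    return sum(
        (curve_ch[k - 1] - 2.0 * curve_ch[k] + curve_ch[k + 1]) ** 2
        for k in range(1, len(curve_ch) - 1)
    )

straight = [0.0, 0.25, 0.5, 0.75, 1.0]
kinked = [0.0, 0.9, 0.95, 1.0]
```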
compute_spatial_linearity_loss(pixelwise_losses, pixelwise_errors=None, external_weights=None, valid_mask=None, use_uncertainty_weighting=True)
Computes the spatial mean, standard deviation, and uncertainty from the given pixelwise losses and optional pixelwise uncertainties. External weights and a valid mask can be used to modify the computation as needed.

Parameters:

Name | Type | Description | Default |
---|---|---|---|
pixelwise_losses | Tensor | Tensor containing the pixelwise linearity losses. | required |
pixelwise_errors | Optional[Tensor] | Tensor containing the uncertainty of the pixelwise linearity losses. | None |
external_weights | Optional[Tensor] | Tensor containing external weight values for the mean computation. | None |
valid_mask | Optional[Tensor] | Boolean tensor representing pixel positions that are valid for the computation. | None |
use_uncertainty_weighting | bool | Whether to utilize inverse-uncertainty-based weights. | True |

Returns:

Type | Description |
---|---|
tuple[Tensor, Tensor, Tensor] | Tuple of (the spatial mean of the linearity loss, the spatial standard deviation of the linearity loss, the spatial uncertainty of the linearity loss). |
Source code in clair_torch/training/losses.py
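The weighting scheme can be illustrated with a scalar sketch. This assumes inverse-variance weighting (weight proportional to 1 / error**2), which is one common reading of "inverse uncertainty based weights"; the library's tensor implementation may normalize differently:

```python
def weighted_spatial_mean(losses, errors=None, external_weights=None,
                          use_uncertainty_weighting=True):
    # Combine optional external weights with inverse-variance weights
    # derived from the per-pixel uncertainties.
    w = ([1.0] * len(losses) if external_weights is None
         else list(external_weights))
    if use_uncertainty_weighting and errors is not None:
        w = [wi / (e ** 2) for wi, e in zip(w, errors)]
    return sum(wi * li for wi, li in zip(w, losses)) / sum(w)

# the precisely measured pixel (error 0.1) dominates the mean
mean = weighted_spatial_mean([1.0, 3.0], errors=[0.1, 1.0])
```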
gaussian_value_weights(image, scale=30.0)
Compute Gaussian weights for an image based on intensity proximity to 0.5.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
image | Tensor | Tensor of shape (N, C, H, W), values in [0, 1]. | required |
scale | Optional[float] | Controls the sharpness of the Gaussian. | 30.0 |

Returns:

Name | Type | Description |
---|---|---|
weight | Tensor | Tensor of shape (N, C, H, W). |
Source code in clair_torch/training/losses.py
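The weighting formula for a single pixel value can be sketched directly, assuming the usual Gaussian form exp(-scale * (x - 0.5)**2) centered on mid-intensity:

```python
import math

def gaussian_value_weight(x, scale=30.0):
    # Weight is 1.0 at mid-grey (x = 0.5) and decays toward the
    # clipped extremes of the [0, 1] range.
    return math.exp(-scale * (x - 0.5) ** 2)

mid = gaussian_value_weight(0.5)
near_saturated = gaussian_value_weight(0.98)
```

At the default scale of 30, a nearly saturated pixel contributes almost nothing, which downweights clipped values in the loss computation.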
pixelwise_linearity_loss(image_value_stack, i_idx, j_idx, ratio_pairs, image_std_stack=None, use_relative=True)
Compute a differentiable linearity loss for a stack of images taken at different exposures.

Parameters:

Name | Type | Description | Default |
---|---|---|---|
image_value_stack | Tensor | Stack of images, ndim >= 2, shape (N, C, H, W), values in [0, 1]. | required |
i_idx | Tensor | Index tensor of shape (P,) for the first image in each pair, ndim=1. | required |
j_idx | Tensor | Index tensor of shape (P,) for the second image in each pair, ndim=1. | required |
ratio_pairs | Tensor | Tensor of shape (P,) of exposure ratios exposure[i] / exposure[j]. | required |
image_std_stack | Optional[Tensor] | Optional standard deviation estimates, same shape as image_value_stack. | None |
use_relative | bool | Compute relative difference instead of absolute. | True |

Returns:

Type | Description |
---|---|
tuple[Tensor, Optional[Tensor]] | A scalar loss (mean deviation from expected linearity across valid pixels and channel pairs), together with an optional second tensor (None when no standard deviation estimates are supplied). |
Source code in clair_torch/training/losses.py
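For linearized values, a pixel observed at two exposures should satisfy x_i ≈ ratio * x_j. The sketch below shows one plausible per-pixel deviation under that relationship, with the relative variant normalized by the pair's mean level; this is an illustration of the concept, not the library's exact normalization:

```python
def pair_deviation(xi, xj, ratio, use_relative=True):
    # Deviation of a pixel pair from the expected exposure relationship
    # x_i ≈ ratio * x_j, where ratio = exposure[i] / exposure[j].
    expected = ratio * xj
    deviation = abs(xi - expected)
    if use_relative:
        mean_level = (xi + expected) / 2.0
        deviation = deviation / mean_level if mean_level > 0.0 else deviation
    return deviation

perfect = pair_deviation(0.4, 0.2, 2.0)   # exactly linear pair
off = pair_deviation(0.5, 0.2, 2.0)       # brighter than the ratio predicts
```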
train_icrf(dataloader, batch_size, device, icrf_model, optimizers=None, schedulers=None, use_relative_linearity_loss=True, use_uncertainty_weighting=True, epochs=150, patience=300, alpha=1.0, beta=1.0, gamma=1.0, delta=1.0, lower_valid_threshold=1 / 255, upper_valid_threshold=254 / 255, exposure_ratio_threshold=0.1)
Training loop for the ICRF model. Requires a dataloader yielding batches of images for linearization and loss computation, and an initialized ICRF model with a matching number of channels and pixel values.

Parameters:

Name | Type | Description | Default |
---|---|---|---|
dataloader | ImageDataset | ImageDataset object for the value images. | required |
batch_size | int | The size of the batches associated with the dataloader. | required |
device | device | The device to perform the training on. | required |
icrf_model | ICRFModel | Initialized ICRFModel object. | required |
optimizers | Optional[list] | List of PyTorch optimizers to use. Must be either a single optimizer or one for each channel. If None, default optimizers are initialized for each channel. | None |
schedulers | Optional[list] | List of PyTorch learning rate schedulers: None, or a list of schedulers and Nones of equal length to optimizers. | None |
use_relative_linearity_loss | bool | Whether to use relative (True) or absolute (False) values for linearity loss computation. | True |
use_uncertainty_weighting | bool | Whether to utilize inverse-uncertainty-based weights in loss computation. | True |
epochs | int | Number of epochs to train the model. | 150 |
patience | int | Number of epochs until early stopping when the loss does not improve. | 300 |
alpha | float | Coefficient for the monotonicity penalty term. | 1.0 |
beta | float | Coefficient for the range penalty term. | 1.0 |
gamma | float | Coefficient for the endpoint penalty term. | 1.0 |
delta | float | Coefficient for the smoothness penalty term. | 1.0 |
lower_valid_threshold | float | Lower exclusive bound for considering a pixel valid. | 1 / 255 |
upper_valid_threshold | float | Upper exclusive bound for considering a pixel valid. | 254 / 255 |
exposure_ratio_threshold | float | Threshold for rejecting image pairs based on the ratio of the shorter exposure time to the longer exposure time of a pair. Should be in [0.0, 1.0]; pairs with values lower than the threshold are rejected from the linearity loss computation. | 0.1 |
Returns:

Type | Description |
---|---|
ICRFModelBase | The trained model. |
Source code in clair_torch/training/icrf_training.py
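The interaction of epochs and patience can be illustrated in isolation. This sketch assumes the usual early-stopping semantics (stop once the loss has failed to improve for `patience` consecutive epochs); the actual train_icrf loop additionally combines the linearity loss with the weighted penalty terms described above:

```python
def early_stop_epochs(losses, patience):
    # Return the number of epochs actually run when training stops
    # after `patience` epochs without an improvement in the loss.
    best = float("inf")
    since_best = 0
    for epoch, loss in enumerate(losses, start=1):
        if loss < best:
            best, since_best = loss, 0
        else:
            since_best += 1
            if since_best >= patience:
                return epoch
    return len(losses)

# the loss stalls from epoch 3 on; with patience=2, training stops at epoch 5
stopped_at = early_stop_epochs([1.0, 0.8, 0.7, 0.7, 0.7, 0.7], patience=2)
```

With the default patience of 300 exceeding the default 150 epochs, early stopping only takes effect when a shorter patience is passed explicitly.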