Skip to content

Training

This page acts as the technical reference for the training subpackage.

The training subpackage provides a training loop for the ICRF models, together with the loss and penalty functions that the loop uses.

combined_gaussian_pair_weights(image_stack, i_idx, j_idx, scale=10.0)

Compute combined Gaussian weights for image pairs by summing weights from each image.

Parameters:

Name Type Description Default
image_stack Tensor

stack of images, ndim >= 2. Usual shape (N, C, H, W)

required
i_idx Tensor

Index tensor of shape (P,) for first image in each pair, ndim=1.

required
j_idx Tensor

Index tensor of shape (P,) for second image in each pair, ndim=1.

required
scale Optional[float]

Sharpness of Gaussian weight.

10.0

Returns:

Name Type Description
combined_weights Tensor

Tensor of shape (P, C, H, W) for input of (N, C, H, W).

Source code in clair_torch/training/losses.py
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
def combined_gaussian_pair_weights(
    image_stack: torch.Tensor,
    i_idx: torch.Tensor,
    j_idx: torch.Tensor,
    scale: Optional[float] = 10.0
) -> torch.Tensor:
    """
    Build per-pair Gaussian value weights as the sum of the weights of both images in each pair.

    Args:
        image_stack: stack of images, ndim >= 2. Usual shape (N, C, H, W).
        i_idx: 1-D index tensor of shape (P,) selecting the first image of each pair.
        j_idx: 1-D index tensor of shape (P,) selecting the second image of each pair.
        scale: Sharpness of the Gaussian weight, forwarded to gaussian_value_weights.

    Returns:
        combined_weights: Tensor of shape (P, C, H, W) for an input of shape (N, C, H, W).
    """
    validate_multiple_dimensions([i_idx, j_idx], [1, 1])

    # Gather each side of the pairs and weight it; the generator is consumed in order (i first, then j).
    first_weights, second_weights = (
        gaussian_value_weights(image_stack[pair_side], scale) for pair_side in (i_idx, j_idx)
    )
    return first_weights + second_weights

compute_endpoint_penalty(curve, per_channel=False)

Get a penalty term for a function (tensor) whose endpoints are not exactly 0 and 1. The penalty is defined as the sum of the squares of the deviations of the first and last values from 0 and 1, respectively. Args: curve: the function to determine a penalty term for. Shape (C, D), with C the number of channels and D the number of datapoints per channel. per_channel: whether to return a per-channel loss, or sum the channel losses into a scalar tensor. Returns: Penalty terms as a scalar tensor, or one element per channel.

Source code in clair_torch/training/losses.py
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
def compute_endpoint_penalty(curve: torch.Tensor, per_channel: Optional[bool] = False) -> torch.Tensor:
    """
    Get a penalty term for a function (tensor) whose endpoints are not exactly 0 and 1. The penalty is defined as the
    sum of the squares of the deviations of the first and last values from 0 and 1, respectively.

    Args:
        curve: the function to determine a penalty term for. Shape (C, D), with C the number of channels and D the
            number of datapoints per channel. A 1-D tensor of shape (D,) is treated as a single channel.
        per_channel: whether to return a per-channel loss, or sum the channel losses into a scalar tensor.

    Returns:
        Penalty terms as a scalar tensor, or a tensor of shape (C,) with one element per channel.
    """
    if curve.ndim == 1:
        # A 1-D curve is one channel: (D,) -> (1, D). Unsqueezing at dim 1 would give (D, 1), making the "first" and
        # "last" datapoint both equal to the whole curve.
        curve = curve.unsqueeze(0)
    validate_dimensions(curve, (1, 2), raise_error=True)

    # Datapoints run along dim 1: column 0 is the start, column -1 the end of every channel.
    start_deviation = curve[:, 0] - 0  # - 0 to emphasize the definition of the loss.
    end_deviation = curve[:, -1] - 1
    penalty = start_deviation ** 2 + end_deviation ** 2  # Shape: (C,)

    if per_channel:
        return penalty
    return torch.sum(penalty)

compute_monotonicity_penalty(curve, squared=True, per_channel=False)

Get a penalty term for a function (tensor) if it is not monotonically increasing. Args: curve: the function to determine a penalty term for. squared: whether to square the penalty terms. per_channel: whether to return a per-channel loss, or sum the channel losses into a scalar tensor. Returns: Penalty terms as scalar tensor, or one element per channel.

Source code in clair_torch/training/losses.py
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
def compute_monotonicity_penalty(curve: torch.Tensor, squared: bool = True, per_channel: bool = False) -> torch.Tensor:
    """
    Get a penalty term for a function (tensor) if it is not monotonically increasing.

    Every step between consecutive datapoints that does not increase (difference <= 0) is penalized.

    Args:
        curve: the function to determine a penalty term for. Shape (C, D), with C the number of channels and D the
            number of datapoints per channel. A 1-D tensor of shape (D,) is treated as a single channel.
        squared: whether to penalize the squared step difference (True) or its magnitude linearly (False).
        per_channel: whether to return a per-channel loss, or sum the channel losses into a scalar tensor.

    Returns:
        Penalty terms as a scalar tensor, or a tensor of shape (C,) with one element per channel.
    """
    if curve.ndim == 1:
        curve = curve.unsqueeze(0)  # (D,) -> (1, D)

    df = curve[:, 1:] - curve[:, :-1]  # Shape: (C, D-1)
    # 1 where the curve is non-strictly increasing (flat steps included); cast to the curve dtype to avoid promotion.
    mask = (df <= 0).to(df.dtype)
    if squared:
        violations = mask * df.pow(2)  # penalize squared difference
    else:
        violations = mask * (-df)  # penalize linearly; -df >= 0 wherever mask is 1

    penalty = violations.sum(dim=1)  # Shape: (C,)

    if per_channel:
        return penalty
    return torch.sum(penalty)

compute_range_penalty(curve, epsilon=1e-06, per_channel=False)

Get a penalty term for a function (tensor) if it leaves the [0, 1] value range. Args: curve: the function to determine a penalty term for. epsilon: currently unused by the implementation; retained for signature compatibility. per_channel: whether to return a per-channel loss, or sum the channel losses into a scalar tensor.

Returns:

Type Description
Tensor

Penalty terms as scalar tensor, or one element per channel.

Source code in clair_torch/training/losses.py
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
def compute_range_penalty(curve: torch.Tensor, epsilon: float = 1e-6, per_channel: bool = False) -> torch.Tensor:
    """
    Get a penalty term for a function (tensor) if it leaves the [0, 1] value range. Each value below 0 or above 1
    contributes its linear distance to the range.

    Args:
        curve: the function to determine a penalty term for. Shape (C, D), with C the number of channels and D the
            number of datapoints per channel. A 1-D tensor of shape (D,) is treated as a single channel.
        epsilon: currently unused; kept so existing callers passing it keep working.
            NOTE(review): this was described as "small epsilon to avoid vanishing gradients" but never entered the
            computation. Adding a margin now would change results (and conflict with the endpoint penalty that
            drives curve[:, 0] to exactly 0), so it is deliberately left out.
        per_channel: whether to return a per-channel loss, or sum the channel losses into a scalar tensor.

    Returns:
        Penalty terms as a scalar tensor, or a tensor of shape (C,) with one element per channel.
    """
    if curve.ndim == 1:
        curve = curve.unsqueeze(0)  # (D,) -> (1, D)

    below = torch.relu(-curve)      # Penalize values under 0
    above = torch.relu(curve - 1)   # Penalize values over 1
    penalty = (below + above).sum(dim=1)  # Sum over datapoints, shape: (C,)

    if per_channel:
        return penalty
    return torch.sum(penalty)

compute_smoothness_penalty(curve, per_channel=False)

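Get a penalty term for a function (tensor) if it is not sufficiently smooth, computed as the sum of squared second differences along each channel. Args: curve: the function to determine a penalty term for, shape (C, D). per_channel: whether to return a per-channel loss, or sum the channel losses into a scalar tensor. Returns: Penalty terms as scalar tensor, or one element per channel.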

Source code in clair_torch/training/losses.py
135
136
137
138
139
140
141
142
143
144
145
146
147
148
def compute_smoothness_penalty(curve: torch.Tensor, per_channel: bool = False) -> torch.Tensor:
    """
    Get a penalty term for a function (tensor) if it is not sufficiently smooth. The penalty is the sum of squared
    discrete second differences (curve[k] - 2 * curve[k + 1] + curve[k + 2]) along each channel.

    Args:
        curve: the function to determine a penalty term for. Shape (C, D), with C the number of channels and D the
            number of datapoints per channel. A 1-D tensor of shape (D,) is treated as a single channel.
        per_channel: whether to return a per-channel loss, or sum the channel losses into a scalar tensor.

    Returns:
        Penalty terms as a scalar tensor, or a tensor of shape (C,) with one element per channel.
    """
    if curve.ndim == 1:
        curve = curve.unsqueeze(0)  # (D,) -> (1, D)

    second_diff = curve[:, :-2] - 2 * curve[:, 1:-1] + curve[:, 2:]  # Shape: (C, D-2)
    penalty = second_diff.pow(2).sum(dim=1)  # Shape: (C,)
    if per_channel:
        return penalty
    return torch.sum(penalty)

compute_spatial_linearity_loss(pixelwise_losses, pixelwise_errors=None, external_weights=None, valid_mask=None, use_uncertainty_weighting=True)

Computes the spatial mean, standard deviation and uncertainty based on the given pixelwise losses and possible pixelwise uncertainties. External weights and a valid mask can be used to modify the computation as needed. Args: pixelwise_losses: tensor containing the pixelwise linearity losses. pixelwise_errors: tensor containing the uncertainty of the pixelwise linearity losses. external_weights: tensor containing external weight values for mean computation. valid_mask: boolean tensor representing pixel positions that are valid for the computation. use_uncertainty_weighting: whether to utilize inverse uncertainty based weights.

Returns:

Type Description
Tensor

Tuple representing (the spatial mean of linearity loss, the spatial standard deviation of linearity loss,

Tensor

spatial uncertainty of the linearity loss).

Source code in clair_torch/training/losses.py
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
def compute_spatial_linearity_loss(pixelwise_losses: torch.Tensor, pixelwise_errors: Optional[torch.Tensor] = None,
                                   external_weights: Optional[torch.Tensor] = None,
                                   valid_mask: Optional[torch.Tensor] = None, use_uncertainty_weighting: bool = True) \
        -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
    """
    Computes the spatial mean, standard deviation and uncertainty based on the given pixelwise losses and
    possible pixelwise uncertainties. External weights and a valid mask can be used to modify the computation
    as needed.

    Args:
        pixelwise_losses: tensor containing the pixelwise linearity losses. Reduced over dims 2 and 3, so presumably
            shaped (P, C, H, W) — NOTE(review): confirm against callers.
        pixelwise_errors: tensor containing the uncertainty of the pixelwise linearity losses.
        external_weights: tensor containing external weight values for mean computation.
        valid_mask: boolean tensor representing pixel positions that are valid for the computation.
        use_uncertainty_weighting: whether to utilize inverse uncertainty based weights.

    Returns:
        Tuple (spatial mean of the linearity loss, spatial standard deviation of the linearity loss, spatial
        uncertainty of the linearity loss or None when pixelwise_errors is None).
    """
    # Uncertainty weighting uses inverse uncertainties as weights, prioritizing values with smaller uncertainties.
    # External weights are added on top. Weights are only built when at least one source actually contributes.
    # Building them when pixelwise_errors is given but uncertainty weighting is disabled (and there are no
    # external weights) would produce an all-zero weight tensor instead of falling back to an unweighted mean.
    apply_uncertainty = pixelwise_errors is not None and use_uncertainty_weighting

    weights: Optional[torch.Tensor] = None
    if apply_uncertainty or external_weights is not None:
        weights = torch.zeros_like(pixelwise_losses)
        if apply_uncertainty:
            weights = weights + 1 / (pixelwise_errors + 1e-6)  # 1e-6 guards against zero uncertainty
        if external_weights is not None:
            weights = weights + external_weights

    spatial_linearity_loss, spatial_linearity_loss_std = weighted_mean_and_std(
        pixelwise_losses, weights=weights, mask=valid_mask, dim=(2, 3)
    )
    if pixelwise_errors is not None:
        spatial_linearity_loss_error, _ = weighted_mean_and_std(pixelwise_errors, mask=valid_mask, dim=(2, 3))
    else:
        spatial_linearity_loss_error = None

    return spatial_linearity_loss, spatial_linearity_loss_std, spatial_linearity_loss_error

gaussian_value_weights(image, scale=30.0)

Compute Gaussian weights for an image based on intensity proximity to 0.5.

Parameters:

Name Type Description Default
image Tensor

torch.Tensor of shape (N, C, H, W), values in [0, 1]

required
scale Optional[float]

float, controls sharpness of the Gaussian (default 30)

30.0

Returns:

Name Type Description
weight Tensor

torch.Tensor of shape (N, C, H, W)

Source code in clair_torch/training/losses.py
193
194
195
196
197
198
199
200
201
202
203
204
205
def gaussian_value_weights(image: torch.Tensor, scale: Optional[float] = 30.0) -> torch.Tensor:
    """
    Weight every pixel value with a Gaussian centred at 0.5, i.e. exp(-scale * (image - 0.5)^2).

    Args:
        image: torch.Tensor of shape (N, C, H, W), values in [0, 1].
        scale: float, controls sharpness of the Gaussian (default 30); larger values narrow the peak.

    Returns:
        weight: torch.Tensor with the same shape as ``image``; exactly 1 where a value equals 0.5.
    """
    squared_offset = (image - 0.5).pow(2)
    return squared_offset.mul(-scale).exp()

pixelwise_linearity_loss(image_value_stack, i_idx, j_idx, ratio_pairs, image_std_stack=None, use_relative=True)

Compute a differentiable linearity loss for a stack of images taken at different exposures. Args: image_value_stack: stack of images, ndim >= 2, shape (N, C, H, W), values in [0, 1] i_idx: Index tensor of shape (P,) for first image in each pair, ndim=1. j_idx: Index tensor of shape (P,) for second image in each pair, ndim=1. ratio_pairs: Tensor of shape (P,) for exposure ratios exposure[i] / exposure[j] image_std_stack: optional standard deviation estimates, same shape as image_value_stack use_relative: compute relative difference instead of absolute

Returns:

Type Description
tuple[Tensor, Tensor | None]

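A tuple (abs_diff, std): the per-pixel absolute (or relative) deviation from the expected linear relation, shape (P, C, H, W), and its propagated standard deviation of the same shape, or None when no standard deviations are given.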

Source code in clair_torch/training/losses.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
def pixelwise_linearity_loss(
    image_value_stack: torch.Tensor,     # (N, C, H, W)
    i_idx: torch.Tensor,                 # (P,)
    j_idx: torch.Tensor,                 # (P,)
    ratio_pairs: torch.Tensor,           # (P,)
    image_std_stack: Optional[torch.Tensor] = None,
    use_relative: bool = True,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """
    Compute a differentiable linearity loss for a stack of images taken at different exposures.
    Args:
        image_value_stack: stack of images, ndim >= 2, shape (N, C, H, W), values in [0, 1]
        i_idx: Index tensor of shape (P,) for first image in each pair, ndim=1.
        j_idx: Index tensor of shape (P,) for second image in each pair, ndim=1.
        ratio_pairs: Tensor of shape (P,) for exposure ratios exposure[i] / exposure[j]
        image_std_stack: optional standard deviation estimates, same shape as image_value_stack
        use_relative: compute relative difference instead of absolute

    Returns:
        A scalar loss (mean deviation from expected linearity across valid pixels and channel pairs)
    """

    # Gather image pairs and reshape to (P, C, H, W)
    I_i = image_value_stack[i_idx]  # (P, C, H, W)
    I_j = image_value_stack[j_idx]  # (P, C, H, W)

    # Apply ratio: expect I_i ≈ I_j * ratio
    expected = I_j * ratio_pairs.view(-1, 1, 1, 1)

    diff = I_i - expected
    if use_relative:
        # Avoid division by zero — small epsilon
        expected_safe = expected + 1e-6
        diff = diff / expected_safe

    abs_diff = diff.abs()

    if image_std_stack is not None:
        std_i = image_std_stack[i_idx]
        std_j = image_std_stack[j_idx]
        if use_relative:
            eps = 1e-6
            # expected_safe = expected.clamp(min=eps)
            I_j_safe = I_j.clamp(min=eps)

            term1 = (std_i / expected_safe) ** 2
            term2 = ((I_i * std_j) / (expected_safe * I_j_safe)) ** 2

            std = torch.sqrt(term1 + term2 + eps)
        else:
            std = torch.sqrt(std_i ** 2 + (ratio_pairs.view(-1, 1, 1, 1) * std_j) ** 2)
    else:
        std = None

    return abs_diff, std

train_icrf(dataloader, batch_size, device, icrf_model, optimizers=None, schedulers=None, use_relative_linearity_loss=True, use_uncertainty_weighting=True, epochs=150, patience=300, alpha=1.0, beta=1.0, gamma=1.0, delta=1.0, lower_valid_threshold=1 / 255, upper_valid_threshold=254 / 255, exposure_ratio_threshold=0.1)

Training loop for the ICRF model. Requires a dataloader to yield batches of images for linearization and loss computation, and an initialized ICRF model with a matching number of channels and pixel values. Args: dataloader: DataLoader yielding (index, values, standard deviations, metadata) batches of the value images. batch_size: the size of the batches associated with the dataloader. device: the device to perform the training on. icrf_model: initialized ICRFModel object. optimizers: list of PyTorch optimizers to use. Must either be a single optimizer or one for each channel. If None is given, then initializes default optimizers for each channel. schedulers: list of PyTorch learning rate schedulers. None, or list of Schedulers and Nones of equal length to optimizers. use_relative_linearity_loss: whether to use relative (True) or absolute values (False) for linearity loss computation. use_uncertainty_weighting: whether to utilize inverse uncertainty based weights in loss computation. epochs: number of epochs to train the model. patience: number of epochs until early stopping when loss does not improve. alpha: coefficient for monotonicity penalty term. beta: coefficient for range penalty term. gamma: coefficient for endpoint penalty term. delta: coefficient for smoothness penalty term. lower_valid_threshold: lower exclusive bound for considering a pixel as valid. upper_valid_threshold: upper exclusive bound for considering a pixel as valid. exposure_ratio_threshold: threshold for rejecting image pairs based on the ratio of the shorter exposure time against the longer exposure time of a pair of images. Should be between [0.0, 1.0]. Pairs with values lower than the threshold value are rejected from the linearity loss computation.

Returns:

Type Description
ICRFModelBase

The trained model.

Source code in clair_torch/training/icrf_training.py
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
@typechecked
def train_icrf(
        dataloader: DataLoader,
        batch_size: int,
        device: str | torch.device,
        icrf_model: ICRFModelBase,
        optimizers: Optional[list[Optimizer]] = None,
        schedulers: Optional[list[_LRScheduler | ReduceLROnPlateau | None]] = None,
        use_relative_linearity_loss: bool = True,
        use_uncertainty_weighting: bool = True,
        epochs: int = 150,
        patience: int = 300,
        alpha: float = 1.0,
        beta: float = 1.0,
        gamma: float = 1.0,
        delta: float = 1.0,
        lower_valid_threshold: float = 1/255,
        upper_valid_threshold: float = 254/255,
        exposure_ratio_threshold: float = 0.1
) -> ICRFModelBase:
    """
    Training loop for the ICRF model. Requires a dataloader to yield batches of images for linearization and loss
    computation, and an initialized ICRF model with a matching number of channels and pixel values.

    Args:
        dataloader: DataLoader yielding (index, values, stds, metadata) batches, where stds may be None and metadata
            contains an 'exposure_time' tensor.
        batch_size: the size of the batches associated with the dataloader. Must be at least 2, since the linearity
            loss compares pairs of images within a batch.
        device: the device to perform the training on.
        icrf_model: initialized ICRFModel object.
        optimizers: list of PyTorch optimizers to use. Must either be a single optimizer or one for each channel.
            If None is given, then initializes default optimizers for each channel.
        schedulers: list of PyTorch learning rate schedulers. None, or list of Schedulers and Nones of equal length to
            optimizers. ReduceLROnPlateau schedulers are stepped with the epoch loss of their channel (or the total
            loss when a single optimizer is used); other schedulers are stepped without arguments.
        use_relative_linearity_loss: whether to use relative (True) or absolute values (False) for linearity loss
            computation.
        use_uncertainty_weighting: whether to utilize inverse uncertainty based weights in loss computation.
        epochs: number of epochs to train the model.
        patience: number of epochs without improvement, for every channel, before early stopping.
        alpha: coefficient for monotonicity penalty term.
        beta: coefficient for range penalty term.
        gamma: coefficient for endpoint penalty term.
        delta: coefficient for smoothness penalty term.
        lower_valid_threshold: lower exclusive bound for considering a pixel as valid.
        upper_valid_threshold: upper exclusive bound for considering a pixel as valid.
        exposure_ratio_threshold: threshold for rejecting image pairs based on the ratio of the shorter exposure time
            against the longer exposure time of a pair of images. Should be between [0.0, 1.0]. Pairs with values lower
            than the threshold value are rejected from the linearity loss computation.

    Returns:
        The trained model.

    Raises:
        ValueError: if batch_size is smaller than 2, if the number of optimizers is neither 1 nor the number of
            channels, or if the number of schedulers does not match the number of optimizers.
    """
    channels = icrf_model.channels

    if batch_size < 2:
        raise ValueError("Batch size must be larger than 1.")

    if optimizers is None:
        optimizers = [
            torch.optim.Adam(icrf_model.channel_params(c), lr=1e-3, amsgrad=False) for c in range(channels)
        ]
    if len(optimizers) not in (1, channels):
        raise ValueError(
            f"Expected a single optimizer or one per channel ({channels}), got {len(optimizers)}."
        )
    single_optimizer = len(optimizers) == 1

    # One entry per optimizer (first param group) so the indices line up with the reporting loop below.
    previous_lrs = [opt.param_groups[0]['lr'] for opt in optimizers]

    if schedulers is None:
        schedulers = [None] * len(optimizers)
    if len(schedulers) != len(optimizers):
        raise ValueError(f"Mismatched number of optimizers: {len(optimizers)} and schedulers: {len(schedulers)}.")

    best_losses = [float('inf')] * channels
    epochs_without_improvement = [0] * channels

    icrf_model.train()
    icrf_model.plot_icrf()

    for epoch in range(epochs):

        # `.device` rather than `Tensor.get_device()`: the latter returns -1 for CPU tensors, which torch.zeros rejects.
        running_loss = torch.zeros(channels, device=icrf_model.icrf.device)
        processed_batches = 0

        # No explicit `total`: tqdm uses len(dataloader), the number of batches (batch_size is images per batch).
        for _, val_batch, std_batch, meta_batch in tqdm(dataloader, desc=f"Epoch {epoch + 1}/{epochs}"):

            images = val_batch.to(device=device)
            stds = std_batch.to(device) if std_batch is not None else None
            exposures = meta_batch['exposure_time'].to(device=device)

            # Skip batch if number of images in batch is only one.
            if images.shape[0] < 2:
                print("Skipped batch due to single image.")
                continue

            i_idx, j_idx, ratio_pairs = get_valid_exposure_pairs(increasing_exposure_values=exposures,
                                                                 exposure_ratio_threshold=exposure_ratio_threshold)
            valid_mask = get_pairwise_valid_pixel_mask(images, i_idx, j_idx, stds,
                                                       val_lower=lower_valid_threshold, val_upper=upper_valid_threshold)
            gaussian_weight = combined_gaussian_pair_weights(images, i_idx, j_idx)

            for optimizer in optimizers:
                optimizer.zero_grad()

            # Gradients w.r.t. the input images are needed to propagate the input stds through the ICRF.
            images.requires_grad_(True)
            linearized = icrf_model(images)  # Shape: (N, C, H, W)

            if stds is not None:
                grads = torch.autograd.grad(
                    outputs=linearized,
                    inputs=images,
                    grad_outputs=torch.ones_like(linearized),
                    retain_graph=True
                )[0]
                linearized_stds = (grads * stds).abs()
            else:
                linearized_stds = None

            icrf_curve = icrf_model.icrf

            pixelwise_loss, pixelwise_errors = pixelwise_linearity_loss(linearized, i_idx, j_idx, ratio_pairs,
                                                                        linearized_stds, use_relative_linearity_loss)

            spatial_linearity_loss, _, _ = compute_spatial_linearity_loss(
                pixelwise_loss, pixelwise_errors, gaussian_weight, valid_mask, use_uncertainty_weighting
            )

            # Root of the summed squares over image pairs, one value per channel.
            linearity_loss = torch.sqrt((spatial_linearity_loss ** 2).sum(dim=0))

            monotonicity_loss = compute_monotonicity_penalty(icrf_curve, per_channel=True)
            range_loss = compute_range_penalty(icrf_curve, per_channel=True)
            endpoint_loss = compute_endpoint_penalty(icrf_curve, per_channel=True)
            smoothness_loss = compute_smoothness_penalty(icrf_curve, per_channel=True)

            channel_losses = (linearity_loss + alpha * monotonicity_loss + beta * range_loss
                              + gamma * endpoint_loss + delta * smoothness_loss)  # Shape: (C,)

            # Gradients accumulate, so one backward on the total yields the same parameter gradients as one backward
            # per channel, for both the shared-optimizer and per-channel-optimizer setups.
            # NOTE(review): retain_graph=True is kept in case the model reuses graph pieces (e.g. a cached icrf
            # tensor) across batches; confirm before dropping it.
            channel_losses.sum().backward(retain_graph=True)

            for optimizer in optimizers:
                optimizer.step()

            icrf_model.update_icrf()

            running_loss += channel_losses.detach()
            processed_batches += 1

        # Average over the batches that actually contributed; skipped batches would otherwise dilute the loss.
        avg_loss = (running_loss / max(processed_batches, 1)).cpu().numpy()
        print(f"Epoch {epoch + 1} Loss: {avg_loss}")
        update_loss_plot(epoch, avg_loss)

        for c in range(channels):
            if avg_loss[c] < best_losses[c]:
                best_losses[c] = avg_loss[c]
                epochs_without_improvement[c] = 0
            else:
                epochs_without_improvement[c] += 1

        if all(count >= patience for count in epochs_without_improvement):
            print(f"Early stopping triggered for all channels (patience = {patience} epochs).")
            break

        for c, scheduler in enumerate(schedulers):
            if scheduler is None:
                continue
            if isinstance(scheduler, ReduceLROnPlateau):
                # A single shared optimizer monitors the total loss; per-channel optimizers monitor their channel.
                metric = float(avg_loss.sum()) if single_optimizer else float(avg_loss[c])
                scheduler.step(metric)
            else:
                # Epoch-based schedulers take no metric; passing one would be interpreted as the epoch number.
                scheduler.step()

        for i, optimizer in enumerate(optimizers):
            current_lr = optimizer.param_groups[0]['lr']
            if current_lr != previous_lrs[i]:
                print(f"Optimizer {i} learning rate changed to: {current_lr}")
            previous_lrs[i] = current_lr

        if (epoch + 1) % 5 == 0:
            icrf_model.plot_icrf()

    return icrf_model