Skip to content

Inference

This page acts as the technical reference for the inference subpackage.

The inference subpackage provides functionality that can be used to measure a camera's linearity, create HDR images, linearize single images and compute quantitatively well-defined mean and uncertainty images from a stack of images or a video.

compute_hdr_image(dataloader, device, icrf_model=None, weight_fn=None, flat_field_dataset=None, gpu_transforms=None, dark_field_dataset=None)

Function for computing HDR merging of a set of images at different exposure times under stationary conditions. Uncertainty can be computed if the dataset is provided with PairedFrameSettings instances that include a path to an uncertainty image. Flat field correction can be performed if a FlatFieldArtefactMapDataset is provided and a matching artefact image is found. Similarly, a dark field correction can be performed with a DarkFieldArtefactMapDataset.

Args:

- dataloader: DataLoader containing an ImageMapDataset instance, which should contain FrameSettings or PairedFrameSettings instances.
- device: the device to run the computations on.
- icrf_model: an optional ICRF model to use to linearize the images before merging.
- weight_fn: a weighting function that takes as input the batch of images.
- flat_field_dataset: a FlatFieldArtefactMapDataset for flat field correction.
- gpu_transforms: optional transform operations to be performed on the image batch after moving the data to the desired device.
- dark_field_dataset: a DarkFieldArtefactMapDataset for dark field correction.

Returns:

Type Description
tuple[Tensor, Tensor | None]

A tuple representing - the HDR image (tensor) - uncertainty of the HDR image (tensor or None).

Source code in clair_torch/inference/hdr_merge.py
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
@typechecked
def compute_hdr_image(dataloader: DataLoader, device: str | torch.device,
                      icrf_model: Optional[ICRFModelBase] = None, weight_fn: Optional[Callable] = None,
                      flat_field_dataset: Optional[FlatFieldArtefactMapDataset] = None,
                      gpu_transforms: Optional[tr.BaseTransform | Iterable[tr.BaseTransform | None]] = None,
                      dark_field_dataset: Optional[DarkFieldArtefactMapDataset] = None) \
        -> tuple[torch.Tensor, torch.Tensor | None]:
    """
    Function for computing HDR merging of a set of images at different exposure times under stationary conditions.
    Uncertainty can be computed if the dataset is provided with PairedFrameSettings instances that include a path
    to an uncertainty image. Flat field correction can be performed if a FlatFieldArtefactMapDataset is provided and a
    matching artefact image is found. Similarly, a dark field correction can be performed with a
    DarkFieldArtefactMapDataset.

    Uncertainties are propagated to first order with autograd: for every uncertainty source (input image stds, dark
    field stds, flat field stds) the gradient of the output with respect to that source, summed over output pixels,
    is multiplied by the source's standard deviation, squared and accumulated into a variance image.

    Args:
        dataloader: DataLoader containing an ImageMapDataset instance, which should contain FrameSettings or
            PairedFrameSettings instances.
        device: the device to run the computations on.
        icrf_model: an optional ICRF model to use to linearize the images before merging.
        weight_fn: an optional weighting function, called as ``weight_fn(images)`` on each (transformed and dark
            field corrected) image batch. Uniform weights are used when None.
        flat_field_dataset: a FlatFieldArtefactMapDataset for flat field correction.
        gpu_transforms: Optional transform operations to be performed on the image batch after moving the data to the
            desired device.
        dark_field_dataset: a DarkFieldArtefactMapDataset for dark field correction.

    Returns:
        A tuple representing
            - the HDR image (tensor, detached)
            - uncertainty of the HDR image (tensor or None; None when no uncertainty source was available).

    Raises:
        ValueError: if the dataloader yields no batches.
    """
    expected_number_of_iterations = len(dataloader)
    main_dataset: ImageMapDataset = dataloader.dataset

    # Materialise the transforms once: a one-shot iterator (e.g. a generator) would otherwise be exhausted after the
    # first batch. A single transform (or None) is wrapped in a list; None entries are skipped below.
    if isinstance(gpu_transforms, Iterable):
        gpu_transforms = list(gpu_transforms)
    else:
        gpu_transforms = [gpu_transforms]

    running_average = None
    running_variance = None

    mean_handler = WBOMean(dim=0)

    # HDR val and std image computations.
    for index_batch, val_batch, std_batch, meta_batch in tqdm(dataloader, desc="Batches processed",
                                                              total=expected_number_of_iterations):

        # Stage tensors on correct device.
        images = val_batch.to(device=device)
        stds = std_batch.to(device=device) if std_batch is not None else None
        exposures = meta_batch['exposure_time'].to(device=device)

        # Run GPU transforms.
        for transform in gpu_transforms:
            if transform is not None:
                images = transform(images)

        # Keep a handle on the images before dark field correction: ``stds`` describes these pixels, so the gradient
        # must be taken with respect to them rather than the blurred tensor derived from them.
        raw_images = images
        raw_images.requires_grad_(stds is not None)

        # Look up the dark field matching the frames in this batch, if any.
        dark_field_val, dark_field_std = None, None
        if dark_field_dataset is not None:
            frame_settings_in_batch = [main_dataset.files[index] for index in index_batch]
            _, dark_field_val, dark_field_std, _ = dark_field_dataset.get_matching_artefact_images(
                frame_settings_in_batch)

            if dark_field_val is not None:
                dark_field_val = dark_field_val.to(device=device)
                dark_field_std = dark_field_std.to(device=device) if dark_field_std is not None else None
                dark_field_val.requires_grad_(dark_field_std is not None)
            else:
                dark_field_std = None

        # A graph is needed whenever any uncertainty source is present, and it must span the dark field correction,
        # the weighting, the linearization and the mean update so gradients reach every source.
        needs_grad = stds is not None or dark_field_std is not None

        with torch.set_grad_enabled(needs_grad):
            if dark_field_val is not None:
                images = gf.conditional_gaussian_blur(images, dark_field_val, threshold=0.05, kernel_size=3,
                                                      differentiable=True)

            # Compute weights.
            batch_weight = torch.ones_like(images) if weight_fn is None else weight_fn(images)

            linearized = icrf_model(images) if icrf_model is not None else images
            linearized = linearized / exposures.view(-1, 1, 1, 1)

            running_average = mean_handler.update_values(linearized, batch_weight)

        if needs_grad:
            # Pair every differentiable source with its standard deviation and differentiate them in one pass.
            sources = [(tensor, std) for tensor, std in ((raw_images, stds), (dark_field_val, dark_field_std))
                       if std is not None]
            gradients = torch.autograd.grad(
                outputs=running_average,
                inputs=[tensor for tensor, _ in sources],
                grad_outputs=torch.ones_like(running_average),
            )
            # NOTE(review): if WBOMean.update_values returns the cumulative weighted mean, these gradients scale with
            # 1 / (weight accumulated so far) rather than 1 / (final total weight), which would overstate the variance
            # contributed by early batches. Confirm against WBOMean.
            for gradient, (_, std) in zip(gradients, sources):
                variance_update = torch.sum((gradient * std) ** 2, dim=0, keepdim=True)
                running_variance = variance_update if running_variance is None else running_variance + variance_update

        mean_handler.internal_detach()

    if running_average is None:
        raise ValueError("The dataloader yielded no batches, so no HDR image can be computed.")

    # Flatfield corrections.
    if flat_field_dataset is not None:
        _, flatfield_val, flatfield_std, _ = flat_field_dataset.get_matching_artefact_images([main_dataset.files[0]])

        if flatfield_val is not None:
            flatfield_val = flatfield_val.to(device=device)
            flatfield_std = flatfield_std.to(device=device) if flatfield_std is not None else None

            # The normalisation mean is computed before gradients are enabled on the flat field, so it acts as a
            # constant. Otherwise the summed-gradient trick would add a cross-pixel term (from the mean) to every
            # pixel's sensitivity that largely cancels the direct per-pixel term.
            flatfield_mean = gf.flat_field_mean(flatfield_val, 1.0)
            flatfield_val.requires_grad_(flatfield_std is not None)

            # Re-root the graph at the merged image so its existing variance is propagated through the correction.
            track_input = running_variance is not None
            uncorrected = running_average.detach().requires_grad_(track_input)

            with torch.enable_grad():
                corrected = gf.flatfield_correction(uncorrected, flatfield_val, flatfield_mean)

            inputs = ([uncorrected] if track_input else []) + ([flatfield_val] if flatfield_std is not None else [])
            if inputs:
                gradients = iter(torch.autograd.grad(
                    outputs=corrected,
                    inputs=inputs,
                    grad_outputs=torch.ones_like(corrected),
                ))
                if track_input:
                    running_variance = running_variance * next(gradients) ** 2
                if flatfield_std is not None:
                    flatfield_term = (next(gradients) * flatfield_std) ** 2
                    running_variance = flatfield_term if running_variance is None else running_variance + flatfield_term

            running_average = corrected

    hdr_image = running_average.detach().squeeze()
    hdr_std = torch.sqrt(running_variance.detach().squeeze()) if running_variance is not None else None
    return hdr_image, hdr_std

compute_video_mean_and_std(dataloader, device, icrf_model=None)

Function for computing the mean and standard deviation of the frames in a given dataset of video files. All frames in all the videos are treated as belonging to the same dataset. The returned uncertainty is the per-pixel standard deviation divided by the square root of the number of frames, i.e. the standard error of the mean.

Args:

- dataloader: a DataLoader containing a VideoIterableDataset, representing the dataset.
- device: torch device to run the computation on.
- icrf_model: an ICRF model to optionally linearize the pixel values before computing the mean and std.

Returns:

Type Description
tuple[Tensor, Tensor]

The per-pixel mean of the (possibly linearized) video frames in the dataset, and its uncertainty (the per-pixel standard deviation divided by the square root of the number of frames).

Source code in clair_torch/inference/inferential_statistics.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
@typechecked
def compute_video_mean_and_std(dataloader: DataLoader, device: str | torch.device,
                               icrf_model: Optional[ICRFModelBase] = None) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Function for computing the per-pixel mean of the frames in a given dataset of video files, together with the
    uncertainty of that mean. All frames in all the videos are treated as belonging to the same dataset.

    Args:
        dataloader: a DataLoader containing a VideoIterableDataset, representing the dataset.
        device: torch device to run the computation on.
        icrf_model: an ICRF model to optionally linearize the pixel values before computing the mean and std.

    Returns:
        A tuple of
            - the per-pixel mean of the (possibly linearized) video frames,
            - the per-pixel standard deviation reported by WBOMeanVar (variance mode SAMPLE_FREQUENCY) divided by
              the square root of the number of frames, i.e. the standard error of the mean.

    Raises:
        ValueError: if the dataloader yields no frames.
    """
    # An iterable-style dataset need not define __len__, in which case len(DataLoader) raises TypeError; the
    # progress bar then simply runs without a known total.
    try:
        total_iterations = len(dataloader)
    except TypeError:
        total_iterations = None

    mean_handler = WBOMeanVar(dim=0, variance_mode=VarianceMode.SAMPLE_FREQUENCY)
    number_of_frames = 0

    with torch.inference_mode():
        for _, val_batch, _, _ in tqdm(dataloader, desc="Number of batches processed", total=total_iterations):

            frames = val_batch.to(device=device)
            number_of_frames += frames.shape[0]

            # Explicit None check: the model is an object whose truthiness should not decide whether to linearize.
            if icrf_model is not None:
                frames = icrf_model(frames)

            mean_handler.update_values(frames, None)

    if number_of_frames == 0:
        raise ValueError("The dataloader yielded no frames, so no mean or standard deviation can be computed.")

    return mean_handler.mean.squeeze(), torch.sqrt(mean_handler.variance().squeeze()) / math.sqrt(number_of_frames)

linearize_dataset_generator(dataloader, device, icrf_model, flatfield_dataset=None, gpu_transforms=None, dark_field_dataset=None)

Generator function to yield single linearized images and their possible associated uncertainty.

Args:

- dataloader: Torch dataloader with custom collate function.
- device: the device on which to run the linearization.
- icrf_model: the ICRF model used to linearize the images.
- flatfield_dataset: a FlatFieldArtefactMapDataset, used to select the appropriate flat field correction image for each linearized image.
- gpu_transforms: transform operations to run on each image as the first thing in the process.
- dark_field_dataset: a DarkFieldArtefactMapDataset, used to select the appropriate dark field correction image for each linearized image.

Returns:

Type Description
Generator[tuple[Tensor, Tensor, dict], None, None]

A generator object yielding a tuple of image, uncertainty image and metadata dictionary.

Source code in clair_torch/inference/linearization.py
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
@typechecked
def linearize_dataset_generator(
        dataloader: DataLoader,
        device: str | torch.device,
        icrf_model: ICRFModelBase,
        flatfield_dataset: Optional[FlatFieldArtefactMapDataset] = None,
        gpu_transforms: Optional[BaseTransform | Sequence[BaseTransform]] = None,
        dark_field_dataset: Optional[DarkFieldArtefactMapDataset] = None
) -> Generator[tuple[torch.Tensor, torch.Tensor, dict], None, None]:
    """
    Generator function to yield single linearized images and their possible associated uncertainty.

    Uncertainties are propagated to first order with autograd: for every uncertainty source (input image stds, dark
    field stds, flat field stds) the gradient of the output with respect to that source, summed over output pixels,
    is multiplied by the source's standard deviation, squared and accumulated into a variance image.

    Args:
        dataloader: Torch dataloader with custom collate function. Must use batch_size 1.
        device: the device on which to run the linearization.
        icrf_model: the ICRF model used to linearize the images.
        flatfield_dataset: A FlatFieldArtefactMapDataset, used to select the appropriate flat field correction image
            for each linearized image.
        gpu_transforms: transform operations to run on each image as the first thing in the process.
        dark_field_dataset: A DarkFieldArtefactMapDataset, used to select the appropriate dark field correction image
            for each linearized image.

    Yields:
        Tuples of (linearized image, uncertainty image, metadata from the dataloader). Both tensors are squeezed,
        detached and moved to the CPU. The uncertainty image is all zeros when no uncertainty source is available.

    Raises:
        ValueError: when iteration starts, if the dataloader's batch_size is not 1.
    """
    main_dataset: ImageMapDataset = dataloader.dataset

    if dataloader.batch_size != 1:
        raise ValueError("For linearization only batch_size of 1 is allowed.")

    if isinstance(gpu_transforms, Iterable):
        gpu_transforms = list(gpu_transforms)
    else:
        gpu_transforms = [gpu_transforms] if gpu_transforms is not None else []

    # Stage the flat field once for the whole dataset; it is matched against the first frame only.
    flatfield_val, flatfield_std, flatfield_mean = None, None, None
    if flatfield_dataset is not None:
        _, flatfield_val, flatfield_std, _ = flatfield_dataset.get_matching_artefact_images([main_dataset.files[0]])

        if flatfield_val is not None:
            flatfield_val = flatfield_val.to(device=device)
            # Computed before gradients are enabled on the flat field, so the normalisation mean acts as a constant.
            flatfield_mean = gf.flat_field_mean(flatfield_val, 1.0)
            flatfield_std = flatfield_std.to(device=device) if flatfield_std is not None else None
            flatfield_val.requires_grad_(flatfield_std is not None)
        else:
            flatfield_std = None

    for index_batch, val_batch, std_batch, meta_batch in dataloader:

        # Stage tensors.
        images = val_batch.to(device=device)
        stds = std_batch.to(device=device) if std_batch is not None else None

        # Run GPU transforms.
        for transform in gpu_transforms:
            if transform is not None:
                images = transform(images)

        # Keep a handle on the images before dark field correction: ``stds`` describes these pixels, so the gradient
        # must be taken with respect to them rather than the blurred tensor derived from them.
        raw_images = images
        raw_images.requires_grad_(stds is not None)

        # Look up the dark field matching this frame, if any.
        dark_field_val, dark_field_std = None, None
        if dark_field_dataset is not None:
            frame_settings_in_batch = [main_dataset.files[index] for index in index_batch]
            _, dark_field_val, dark_field_std, _ = dark_field_dataset.get_matching_artefact_images(
                frame_settings_in_batch)

            if dark_field_val is not None:
                dark_field_val = dark_field_val.to(device=device)
                dark_field_std = dark_field_std.to(device=device) if dark_field_std is not None else None
                dark_field_val.requires_grad_(dark_field_std is not None)
            else:
                dark_field_std = None

        # A graph is needed whenever any uncertainty source is present, and it must span both the dark field
        # correction and the linearization so gradients reach every source.
        needs_grad = stds is not None or dark_field_std is not None

        with torch.set_grad_enabled(needs_grad):
            if dark_field_val is not None:
                images = gf.conditional_gaussian_blur(images, dark_field_val, threshold=0.05, kernel_size=3,
                                                      differentiable=True)
            linearized = icrf_model(images)

        # Uncertainty of the linearization and of the dark field correction.
        variance = torch.zeros_like(raw_images)
        if needs_grad:
            sources = [(tensor, std) for tensor, std in ((raw_images, stds), (dark_field_val, dark_field_std))
                       if std is not None]
            gradients = torch.autograd.grad(
                outputs=linearized,
                inputs=[tensor for tensor, _ in sources],
                grad_outputs=torch.ones_like(linearized),
            )
            for gradient, (_, std) in zip(gradients, sources):
                variance = variance + (gradient * std) ** 2

        # Apply the flat field correction and propagate both the existing variance and the flat field's own
        # uncertainty through it.
        if flatfield_val is not None:
            track_input = needs_grad
            uncorrected = linearized.detach().requires_grad_(track_input)

            with torch.enable_grad():
                corrected = gf.flatfield_correction(uncorrected, flatfield_val, flatfield_mean)

            inputs = ([uncorrected] if track_input else []) + ([flatfield_val] if flatfield_std is not None else [])
            if inputs:
                gradients = iter(torch.autograd.grad(
                    outputs=corrected,
                    inputs=inputs,
                    grad_outputs=torch.ones_like(corrected),
                ))
                if track_input:
                    variance = variance * next(gradients) ** 2
                if flatfield_std is not None:
                    variance = variance + (next(gradients) * flatfield_std) ** 2

            linearized = corrected

        yield linearized.squeeze().detach().cpu(), torch.sqrt(variance).squeeze().detach().cpu(), meta_batch