Skip to content

Models

This page acts as the technical reference for the models subpackage.

The models package provides the inverse camera response function (ICRF) model classes, which are built on the torch.nn.Module class. The base model provides a guideline for creating a new ICRF model, while the concrete implementations provide different approaches to modelling an ICRF.

ICRFModelBase

Bases: Module, ABC

Base class for the ICRF model classes. Implements common functionality and acts as a guideline for implementing the model interface.

Attributes

_n_points: int how many datapoints the range [0, 1] is split into in modelling the ICRF. _channels: int how many color channels are managed by the model. One ICRF curve for each channel. interpolation_mode: InterpMode enum determining how the forward call of the model is handled. See InterpMode doc for more. _initial_power: float a guess at the initial form of the ICRF curve, represented by raising the linear range [0, 1] to this power. _fig: Figure a matplotlib Figure used for visualizing the model. _axs: List[Axes] a list of matplotlib axes used for model visualization. _lines_curve: List[Line2D] a list of the matplotlib Line2D objects for the plotted curves used in model visualization. _lines_deriv: List[Line2D] a list of the matplotlib Line2D objects for the plotted curves' derivatives used in model visualization.

Source code in clair_torch/models/base.py
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
class ICRFModelBase(nn.Module, ABC):
    """
    Base class for the ICRF model classes. Implements common functionality and acts as a guideline for implementing the
    model interface.

    The ICRF is held in the `_icrf` buffer with shape (channels, n_points): one row per color channel, each row sampled
    at n_points evenly spaced positions on [0, 1].

    Attributes
    ----------
    _n_points: int
        how many datapoints the range [0, 1] is split into in modelling the ICRF.
    _channels: int
        how many color channels are managed by the model. One ICRF curve for each channel.
    interpolation_mode: InterpMode
        enum determining how the forward call of the model is handled. See InterpMode doc for more.
    _initial_power: float
        a guess at the initial form of the ICRF curve, represented by raising the linear range [0, 1] to this power.
    _fig: Figure
        a matplotlib Figure used for visualizing the model.
    _axs: List[Axes]
        a list of matplotlib axes used for model visualization.
    _lines_curve: List[Line2D]
        a list of the matplotlib Line2D objects for the plotted curves used in model visualization.
    _lines_deriv: List[Line2D]
        a list of the matplotlib Line2D objects for the plotted curves' derivatives used in model visualization.
    """
    @typechecked
    def __init__(self, n_points: Optional[int] = 256, channels: Optional[int] = 3,
                 interpolation_mode: InterpMode = InterpMode.LINEAR, initial_power: float = 2.5,
                 icrf: Optional[torch.Tensor] = None):
        """
        Initializes the ICRF model instance with the given parameters. The icrf argument overrides n_points and channels
        by its shape if given.
        Args:
            n_points: how many datapoints the range [0, 1] is split into in modelling the ICRF.
            channels: how many color channels are managed by the model. One ICRF curve for each channel.
            interpolation_mode: enum determining how the forward call of the model is handled. See InterpMode doc for
                more.
            initial_power: a guess at the initial form of the ICRF curve, represented by raising the linear range [0, 1]
                to this power.
            icrf: an optional initial form of the ICRF curves with shape (channels, n_points); overrides n_points and
                channels.
        Raises:
            ValueError: if icrf is given but is not 2-D, if icrf is not given while n_points or channels is None, or if
                interpolation_mode has no forward implementation.
        """
        super().__init__()

        if icrf is not None:
            # The given curve's shape overrides the n_points and channels arguments.
            if icrf.dim() != 2:
                raise ValueError(f'icrf must be a 2-D tensor of shape (channels, n_points), got shape '
                                 f'{tuple(icrf.shape)}')
            channels, n_points = icrf.shape
        elif n_points is None or channels is None:
            # Both sizes are required to build the default ICRF curve.
            raise ValueError('n_points and channels must be given when icrf is not provided')

        self._channels = channels
        self._initial_power = initial_power
        self._n_points = n_points
        self.register_buffer("_x_axis_datapoints", torch.linspace(0, 1, n_points))

        if icrf is None:
            icrf = self._initialize_default_icrf()
        self.register_buffer("_icrf", icrf)

        self.interpolation_mode = interpolation_mode

        self._dispatch = {
            InterpMode.LOOKUP: self._forward_lookup,
            InterpMode.LINEAR: self._forward_linear,
            InterpMode.CATMULL: self._forward_catmull,
        }

        if self.interpolation_mode not in self._dispatch:
            raise ValueError(f'Unknown interpolation mode {self.interpolation_mode}')

        # Plot state is created lazily by plot_icrf.
        self._fig = None
        self._axs = None
        self._lines_curve = None
        self._lines_deriv = None

    @property
    def icrf(self):
        return self._icrf

    @property
    def channels(self):
        return self._channels

    @property
    def n_points(self):
        return self._n_points

    @property
    def initial_power(self):
        return self._initial_power

    @property
    def x_axis_datapoints(self):
        return self._x_axis_datapoints

    @abstractmethod
    def channel_params(self, c: int) -> list[nn.Parameter]:
        """
        Method for getting the model parameters for the channel of the given index. Subclasses implement the logic
        based on their model parameters. This should be the main method of accessing the optimization parameters for
        feeding them to a torch.Optimizer.
        Args:
            c: channel index.

        Returns:
            list of nn.Parameters.
        """

    @abstractmethod
    def update_icrf(self) -> None:
        """
        Method for constructing a new ICRF curve from the model parameters and updating the curve to the
        self._icrf attribute. Subclasses implement the logic based on their model parameters.
        """

    def _initialize_default_icrf(self) -> torch.Tensor:
        """
        Builds the default ICRF, x ** initial_power on n_points samples of [0, 1], repeated for every channel.

        Returns:
            tensor of shape (channels, n_points).
        """
        curve = torch.linspace(0, 1, self.n_points) ** self.initial_power
        return curve.unsqueeze(0).repeat(self.channels, 1)

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        return self._dispatch[self.interpolation_mode](image)

    def _gather_icrf(self, ix: torch.Tensor) -> torch.Tensor:
        """
        Looks up ICRF samples for a tensor of integer sample indices, using each element's own channel curve.

        Args:
            ix: (N, C, H, W) integer tensor of indices into the second (sample) axis of self.icrf.

        Returns:
            (N, C, H, W) tensor where element (n, c, h, w) equals self.icrf[c, ix[n, c, h, w]].
        """
        N, C, H, W = ix.shape
        # Channel index of shape (1, C, 1, 1); advanced indexing broadcasts it against ix, so channel c is paired
        # with exactly the pixels of plane c (a flat repeat of arange(C) would mispair them for H * W > 1).
        chan = torch.arange(C, device=ix.device).view(1, C, 1, 1)
        return self.icrf[chan, ix]

    def _forward_lookup(self, image: torch.Tensor) -> torch.Tensor:
        """
        Nearest-neighbour table lookup. Pixel values are rounded to the closest ICRF sample, so the result is not
        differentiable with respect to the image.
        Args:
            image: (N, C, H, W) tensor with values in [0, 1]; out-of-range values are clamped to the curve ends.
        Returns:
            (N, C, H, W) tensor of ICRF values. self.icrf has shape (C, L): first dim = channel, second dim = sample.
        """
        L = self.icrf.shape[1]

        # Integer sample index per pixel.
        idx = (image * (L - 1)).round().clamp(0, L - 1).long()  # (N, C, H, W)
        return self._gather_icrf(idx)

    def _forward_linear(self, image: torch.Tensor) -> torch.Tensor:
        """
        Piecewise-linear interpolation between neighbouring ICRF samples.
        Args:
            image: (N, C, H, W) tensor with values in [0, 1]; out-of-range values are clamped to the curve ends.
        Returns:
            (N, C, H, W) tensor of interpolated ICRF values.
        """
        L = self.icrf.shape[1]

        # Scale pixel values to the sample index range [0, L - 1].
        x = (image * (L - 1)).clamp(0, L - 1)

        x0 = x.floor().long()  # lower index
        x1 = (x0 + 1).clamp(max=L - 1)  # upper index; at x == L - 1 both neighbours are the last sample
        w = x - x0.float()  # fractional position between x0 and x1

        return self._gather_icrf(x0) * (1.0 - w) + self._gather_icrf(x1) * w

    def _forward_catmull(self, image: torch.Tensor) -> torch.Tensor:
        """
        Catmull-Rom cubic interpolation over the four ICRF samples surrounding each pixel value.
        Args:
            image: (N, C, H, W) tensor with values in [0, 1]; out-of-range values are clamped to the curve ends.
        Returns:
            (N, C, H, W) tensor of interpolated ICRF values.
        """
        L = self.icrf.shape[1]  # number of ICRF samples

        # Scale input to [0, L - 1].
        x = (image * (L - 1)).clamp(0, L - 1)
        x0 = x.floor().long()

        # Neighbouring indices x0-1, x0, x0+1, x0+2; clamping repeats the end samples at the curve boundaries.
        x_indices = [(x0 + offset).clamp(0, L - 1) for offset in (-1, 0, 1, 2)]

        # Fractional part.
        t = (x - x0.float()).clamp(0, 1)
        t2 = t * t
        t3 = t2 * t

        # Catmull-Rom basis weights.
        weights = [
            -0.5 * t3 + t2 - 0.5 * t,
            1.5 * t3 - 2.5 * t2 + 1.0,
            -1.5 * t3 + 2.0 * t2 + 0.5 * t,
            0.5 * t3 - 0.5 * t2,
        ]

        samples = [self._gather_icrf(ix) for ix in x_indices]
        return torch.stack([w * s for w, s in zip(weights, samples)], dim=0).sum(dim=0)

    def _prepare_icrf_plot_data(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Internal utility function to prepare the model data for plotting.
        Returns:
            Tuple of NumPy arrays (x_values of shape (n_points,), icrf of shape (channels, n_points),
            forward-difference derivative of shape (channels, n_points - 1)).
        """
        icrf_cpu = self.icrf.detach().cpu().numpy()
        x = np.linspace(0, 1, self.n_points)
        dy = np.diff(icrf_cpu, axis=1)
        dx = x[1] - x[0]
        dydx = dy / dx

        return x, icrf_cpu, dydx

    def plot_icrf(self) -> None:
        """
        Model utility function for live-plotting. The model class manages the state of the plot, while utilizing the
        general plotting function of the plotting module.
        """
        x, icrf_cpu, dydx = self._prepare_icrf_plot_data()

        if self._fig is None:
            plt.ion()  # Enable interactive mode for live plotting.
            self._fig, self._axs, self._lines_curve, self._lines_deriv = plot_data_and_diff(x, icrf_cpu, dydx)
        else:
            for c in range(self.channels):
                self._lines_curve[c].set_ydata(icrf_cpu[c, :])
                self._lines_deriv[c].set_ydata(dydx[c, :])

        self._fig.canvas.draw()
        self._fig.canvas.flush_events()
        plt.pause(0.01)

__init__(n_points=256, channels=3, interpolation_mode=InterpMode.LINEAR, initial_power=2.5, icrf=None)

Initializes the ICRF model instance with the given parameters. The icrf argument overrides n_points and channels by its shape if given. Args: n_points: how many datapoints the range [0, 1] is split into in modelling the ICRF. channels: how many color channels are managed by the model, one ICRF curve per channel. interpolation_mode: enum determining how the forward call of the model is handled. See InterpMode doc for more. initial_power: a guess at the initial form of the ICRF curve, represented by raising the linear range [0, 1] to this power. icrf: an optional initial form of the ICRF curves, overrides n_points and channels.

Source code in clair_torch/models/base.py
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
@typechecked
def __init__(self, n_points: Optional[int] = 256, channels: Optional[int] = 3,
             interpolation_mode: InterpMode = InterpMode.LINEAR, initial_power: float = 2.5,
             icrf: Optional[torch.Tensor] = None):
    """
    Initializes the ICRF model instance with the given parameters. The icrf argument overrides n_points and channels
    by its shape if given.
    Args:
        n_points: how many datapoints the range [0, 1] is split into in modelling the ICRF.
        channels: how many color channels are managed by the model. One ICRF curve for each channel.
        interpolation_mode: enum determining how the forward call of the model is handled. See InterpMode doc for
            more.
        initial_power: a guess at the initial form of the ICRF curve, represented by raising the linear range [0, 1]
            to this power.
        icrf: an optional initial form of the ICRF curves with shape (channels, n_points); overrides n_points and
            channels.
    Raises:
        ValueError: if icrf is given but is not 2-D, if icrf is not given while n_points or channels is None, or if
            interpolation_mode has no forward implementation.
    """
    super().__init__()

    if icrf is not None:
        # The given curve's shape overrides the n_points and channels arguments.
        if icrf.dim() != 2:
            raise ValueError(f'icrf must be a 2-D tensor of shape (channels, n_points), got shape '
                             f'{tuple(icrf.shape)}')
        channels, n_points = icrf.shape
    elif n_points is None or channels is None:
        # Both sizes are required to build the default ICRF curve.
        raise ValueError('n_points and channels must be given when icrf is not provided')

    self._channels = channels
    self._initial_power = initial_power
    self._n_points = n_points
    self.register_buffer("_x_axis_datapoints", torch.linspace(0, 1, n_points))

    if icrf is None:
        icrf = self._initialize_default_icrf()
    self.register_buffer("_icrf", icrf)

    self.interpolation_mode = interpolation_mode

    self._dispatch = {
        InterpMode.LOOKUP: self._forward_lookup,
        InterpMode.LINEAR: self._forward_linear,
        InterpMode.CATMULL: self._forward_catmull,
    }

    if self.interpolation_mode not in self._dispatch:
        raise ValueError(f'Unknown interpolation mode {self.interpolation_mode}')

    # Plot state is created lazily by plot_icrf.
    self._fig = None
    self._axs = None
    self._lines_curve = None
    self._lines_deriv = None

channel_params(c) abstractmethod

Method for getting the model parameters for the channel of the given index. Subclasses implement the logic based on their model parameters. This should be the main method of accessing the optimization parameters for feeding them to a torch.Optimizer. Args: c: channel index.

Returns:

Type Description
list[Parameter]

list of nn.Parameters.

Source code in clair_torch/models/base.py
108
109
110
111
112
113
114
115
116
117
118
119
@abstractmethod
def channel_params(self, c: int) -> list[nn.Parameter]:
    """
    Return the optimization parameters that belong to color channel ``c``.

    Each concrete subclass decides which of its parameters make up a channel. This is meant to be the primary way
    of collecting parameters to hand to a torch.Optimizer.

    Args:
        c: channel index.

    Returns:
        list of nn.Parameters for that channel.
    """

plot_icrf()

Model utility function for live-plotting. The model class manages the state of the plot, while utilizing the general plotting function of the plotting module.

Source code in clair_torch/models/base.py
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
def plot_icrf(self) -> None:
    """
    Draw the current ICRF curves and their derivatives in a live (interactive) matplotlib figure.

    The first call creates the figure through plot_data_and_diff and keeps the figure, axes and line handles on the
    model. Later calls only replace the y-data of the stored lines, so the same window is updated in place.
    """
    x_vals, curves, slopes = self._prepare_icrf_plot_data()

    if self._fig is not None:
        # Reuse the existing figure: swap in the new data for every channel.
        for ch in range(self.channels):
            self._lines_curve[ch].set_ydata(curves[ch, :])
            self._lines_deriv[ch].set_ydata(slopes[ch, :])
    else:
        plt.ion()  # interactive mode so the window refreshes without blocking
        (self._fig, self._axs,
         self._lines_curve, self._lines_deriv) = plot_data_and_diff(x_vals, curves, slopes)

    canvas = self._fig.canvas
    canvas.draw()
    canvas.flush_events()
    plt.pause(0.01)

update_icrf() abstractmethod

Method for constructing a new ICRF curve from the model parameters and updating the curve to the self._icrf attribute. Subclasses implement the logic based on their model parameters.

Source code in clair_torch/models/base.py
121
122
123
124
125
126
@abstractmethod
def update_icrf(self) -> None:
    """
    Rebuild the ICRF curve from the current model parameters and store it in ``self._icrf``.

    Each subclass supplies the construction that matches its own parametrization.
    """

ICRFModelDirect

Bases: ICRFModelBase

An ICRF model class that utilizes the datapoints directly as optimization parameters. Total number of parameters is therefore n_points * channels.

Attributes:

Inherits attributes from ICRFModelBase.

direct_params: nn.ParameterList

a list of nn parameters, each parameter corresponds to an actual datapoint in the ICRF curve.

Source code in clair_torch/models/icrf_model.py
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
class ICRFModelDirect(ICRFModelBase):
    """
    An ICRF model class that utilizes the datapoints directly as optimization parameters. Total number of parameters is
    therefore n_points * channels.

    Attributes:
    ----------
    Inherits attributes from ICRFModelBase.

    direct_params: nn.ParameterList
        a list of nn parameters, one per channel; each element holds that channel's n_points ICRF datapoints.
    """
    @typechecked
    def __init__(self, n_points: Optional[int] = 256, channels: Optional[int] = 3,
                 interpolation_mode: InterpMode = InterpMode.LINEAR, initial_power: float = 2.5,
                 icrf: Optional[torch.Tensor] = None):
        """
        Initializes the direct ICRF model. Shape resolution and the initial curve are handled by ICRFModelBase; the
        learnable per-channel parameters are then seeded from that initial curve.
        Args:
            n_points: how many datapoints the range [0, 1] is split into in modelling the ICRF.
            channels: how many color channels are managed by the model. One ICRF curve for each channel.
            interpolation_mode: enum determining how the forward call of the model is handled.
            initial_power: exponent of the default initial curve (linear range [0, 1] raised to this power), used when
                icrf is not given.
            icrf: an optional initial form of the ICRF curves with shape (channels, n_points); overrides n_points and
                channels.
        """
        super().__init__(n_points, channels, interpolation_mode, initial_power, icrf)

        # Use the shape and values resolved by the base class rather than the raw arguments, so a given icrf is
        # respected and the parameters agree with self.icrf from the start. Without icrf this equals
        # linspace(0, 1, n_points) ** initial_power per channel, as before.
        self.direct_params = nn.ParameterList([
            nn.Parameter(self.icrf[c].detach().clone()) for c in range(self.channels)
        ])

    def channel_params(self, c: int) -> list[nn.Parameter]:
        """
        Main method for accessing the optimization parameters of the model.
        Args:
            c: channel index.

        Returns:
            A list of nn.Parameters.
        """
        return [self.direct_params[c]]

    def update_icrf(self):
        """
        Directly stack the per-channel parameters to form the ICRF.
        """
        self._icrf = torch.stack([p for p in self.direct_params], dim=0)  # shape: (C, L)

channel_params(c)

Main method for accessing the optimization parameters of the model. Args: c: channel index.

Returns:

Type Description

A list of nn.Parameters.

Source code in clair_torch/models/icrf_model.py
112
113
114
115
116
117
118
119
120
121
def channel_params(self, c: int):
    """
    Return the trainable parameters belonging to one color channel.

    Args:
        c: index of the channel whose parameters are requested.

    Returns:
        A one-element list holding that channel's entry of direct_params.
    """
    channel_curve = self.direct_params[c]
    return [channel_curve]

update_icrf()

Directly stack the per-channel parameters to form the ICRF.

Source code in clair_torch/models/icrf_model.py
123
124
125
126
127
def update_icrf(self):
    """
    Rebuild the ICRF by stacking the per-channel parameter vectors, one row per channel.
    """
    rows = list(self.direct_params)
    self._icrf = torch.stack(rows)  # shape: (C, L)

ICRFModelPCA

Bases: ICRFModelBase

ICRF model class that utilizes a set of principal components for the optimization process. Each channel has one scalar coefficient per principal component, e.g. with 5 components each channel has 5 PCA coefficients to optimize. The shape of the given principal components is used to determine the number of datapoints, number of components and number of channels. In addition to these, an exponent for the linear range [0, 1] is used as a parameter for each channel.

Attributes:

Inherits attributes from ICRFModelBase

p: nn.ParameterList

a list of nn parameters representing the exponent of the base curve. One element for each channel.

coefficients: nn.ParameterList a list of nn parameters representing the PCA parameters, for each channel a number equal to the number of components.

Source code in clair_torch/models/icrf_model.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
class ICRFModelPCA(ICRFModelBase):
    """
    ICRF model class that utilizes a set of principal components for the optimization process. Each channel has one
    scalar coefficient per principal component, e.g. with 5 components each channel has 5 PCA coefficients to optimize.
    The shape of the given principal components is used to determine the number of datapoints, number of components
    and number of channels. In addition to these, an exponent for the linear range [0, 1] is used as a parameter for
    each channel.

    Attributes:
    -----------
    Inherits attributes from ICRFModelBase

    p: nn.ParameterList
        a list of nn parameters representing the exponent of the base curve. One element for each channel.
    coefficients: nn.ParameterList
        a list of nn parameters representing the PCA parameters, for each channel a number equal to the number of
            components.
    """
    @typechecked
    def __init__(self, pca_basis: torch.Tensor, interpolation_mode: InterpMode = InterpMode.LINEAR,
                 initial_power: float = 2.5, icrf: Optional[torch.Tensor] = None) -> None:
        """
        Initializes an ICRFModelPCA instance. The pca_basis is used to determine the shape of the actual ICRF.
        Args:
            pca_basis: a torch.Tensor representing the principal components. The shape is expected to be
                (n_points, num_components, channels).
            interpolation_mode: an InterpMode determining how the ICRF is used in a forward call.
            initial_power: a guess at the initial form of the ICRF curve, represented by raising the linear range [0, 1]
                to this power. Also used as the starting value of every per-channel exponent p.
            icrf: an optional initial form of the ICRF curves, overrides n_points and channels.
        """
        n_points, num_components, channels = pca_basis.shape

        super().__init__(n_points, channels, interpolation_mode, initial_power, icrf)

        # Base curve power (learnable scalar), one parameter per channel. Seeded from initial_power so that the curve
        # rebuilt by update_icrf starts from the same power as the default initial ICRF of the base class.
        self.p = nn.ParameterList([nn.Parameter(torch.tensor(float(initial_power))) for _ in range(channels)])

        # Model weights, num_components number per channel.
        self.coefficients = nn.ParameterList([nn.Parameter(torch.zeros(num_components)) for _ in range(channels)])

        self.register_buffer("pca_basis", pca_basis)
        self.register_buffer("x_values", torch.linspace(0, 1, n_points))  # (L,)

    def channel_params(self, c: int) -> list[nn.Parameter]:
        """
        Main method for accessing the optimization parameters of the model.
        Args:
            c: channel index.

        Returns:
            A list of nn.Parameters: the channel's exponent followed by its PCA coefficients.
        """
        return [self.p[c], self.coefficients[c]]

    def update_icrf(self):
        """
        Builds the full ICRF curve from the per-channel base power curve plus the PCA contribution, and stores it in
        self._icrf with shape (channels, n_points), the layout the base class indexes and plots.
        """
        p_tensor = torch.stack(list(self.p))  # (C,)
        coefficient_tensor = torch.stack(list(self.coefficients))  # (C, K)

        # Clamp away from zero, presumably to keep x ** p and its gradient w.r.t. p (x ** p * log x) finite at x = 0.
        x_safe = self.x_values.clamp(min=1e-6).unsqueeze(1)  # (L, 1)

        # Per-channel base curve, broadcast over L: (L, C)
        base_curve = x_safe.pow(p_tensor.unsqueeze(0))

        # PCA component contribution: (L, K, C) weighted by (1, K, C), summed over K -> (L, C)
        pca_curve = (self.pca_basis * coefficient_tensor.T.unsqueeze(0)).sum(dim=1)

        # Transpose to (C, L) to match the layout of the buffer created in ICRFModelBase.
        self._icrf = (base_curve + pca_curve).transpose(0, 1)

__init__(pca_basis, interpolation_mode=InterpMode.LINEAR, initial_power=2.5, icrf=None)

Initializes an ICRFModelPCA instance. The pca_basis is used to determine the shape of the actual ICRF. Args: pca_basis: a torch.Tensor representing the principal components. The shape is expected to be (n_points, num_components, channels). interpolation_mode: an InterpMode determining how the ICRF is used in a forward call. initial_power: a guess at the initial form of the ICRF curve, represented by raising the linear range [0, 1] to this power. icrf: an optional initial form of the ICRF curves, overrides n_points and channels.

Source code in clair_torch/models/icrf_model.py
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
@typechecked
def __init__(self, pca_basis: torch.Tensor, interpolation_mode: InterpMode = InterpMode.LINEAR,
             initial_power: float = 2.5, icrf: Optional[torch.Tensor] = None) -> None:
    """
    Initializes an ICRFModelPCA instance. The pca_basis is used to determine the shape of the actual ICRF.
    Args:
        pca_basis: a torch.Tensor representing the principal components. The shape is expected to be
            (n_points, num_components, channels).
        interpolation_mode: an InterpMode determining how the ICRF is used in a forward call.
        initial_power: a guess at the initial form of the ICRF curve, represented by raising the linear range [0, 1]
            to this power. Also used as the starting value of every per-channel exponent p.
        icrf: an optional initial form of the ICRF curves, overrides n_points and channels.
    """
    n_points, num_components, channels = pca_basis.shape

    super().__init__(n_points, channels, interpolation_mode, initial_power, icrf)

    # Base curve power (learnable scalar), one parameter per channel. Seeded from initial_power so that the curve
    # rebuilt by update_icrf starts from the same power as the default initial ICRF of the base class.
    self.p = nn.ParameterList([nn.Parameter(torch.tensor(float(initial_power))) for _ in range(channels)])

    # Model weights, num_components number per channel.
    self.coefficients = nn.ParameterList([nn.Parameter(torch.zeros(num_components)) for _ in range(channels)])

    self.register_buffer("pca_basis", pca_basis)
    self.register_buffer("x_values", torch.linspace(0, 1, n_points))  # (L,)

channel_params(c)

Main method for accessing the optimization parameters of the model. Args: c: channel index.

Returns:

Type Description
list[Parameter]

A list of nn.Parameters.

Source code in clair_torch/models/icrf_model.py
58
59
60
61
62
63
64
65
66
67
def channel_params(self, c: int) -> list[nn.Parameter]:
    """
    Return the trainable parameters belonging to one color channel.

    Args:
        c: index of the channel whose parameters are requested.

    Returns:
        A two-element list: the channel's base-curve exponent, then its PCA coefficients.
    """
    exponent = self.p[c]
    coeffs = self.coefficients[c]
    return [exponent, coeffs]

update_icrf()

Builds the full ICRF curve from the per-channel base power curve plus the PCA contribution.

Source code in clair_torch/models/icrf_model.py
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
def update_icrf(self):
    """
    Builds the full ICRF curve from the per-channel base power curve plus the PCA contribution, and stores it in
    self._icrf with shape (channels, n_points), the layout the base class indexes and plots.
    """
    p_tensor = torch.stack(list(self.p))  # (C,)
    coefficient_tensor = torch.stack(list(self.coefficients))  # (C, K)

    # Clamp away from zero, presumably to keep x ** p and its gradient w.r.t. p (x ** p * log x) finite at x = 0.
    x_safe = self.x_values.clamp(min=1e-6).unsqueeze(1)  # (L, 1)

    # Per-channel base curve, broadcast over L: (L, C)
    base_curve = x_safe.pow(p_tensor.unsqueeze(0))

    # PCA component contribution: (L, K, C) weighted by (1, K, C), summed over K -> (L, C)
    pca_curve = (self.pca_basis * coefficient_tensor.T.unsqueeze(0)).sum(dim=1)

    # Transpose to (C, L) to match the layout of the buffer created in ICRFModelBase.
    self._icrf = (base_curve + pca_curve).transpose(0, 1)