Specific IMNN classes

SimulatorIMNN

class imnn.SimulatorIMNN(n_s, n_d, n_params, n_summaries, input_shape, θ_fid, model, optimiser, key_or_state, simulator)

Information maximising neural network fit with simulations on-the-fly

Defines the function to get simulations and compress them using an XLA compilable simulator.

The outline of the fitting procedure is that a set of \(i\in[1, n_s]\) random number generators are generated and used to produce a set of \(n_s\) simulations, \({\bf d}^i={\rm simulator}({\rm seed}^i, \theta^{\rm fid})\), at the fiducial model parameters, \(\theta^{\rm fid}\). These are passed directly through a network \(f_{{\bf w}}({\bf d})\) with network parameters \({\bf w}\) to obtain network outputs \({\bf x}^i\), and autodifferentiation is used to get the derivative of \(n_d\) of these outputs with respect to the physical model parameters, \(\partial{{\bf x}^i}/\partial\theta_\alpha\), where \(\alpha\) labels the physical parameter. With \({\bf x}^i\) and \(\partial{{\bf x}^i}/\partial\theta_\alpha\) the covariance

\[C_{ab} = \frac{1}{n_s-1}\sum_{i=1}^{n_s}(x^i_a-\mu_a) (x^i_b-\mu_b)\]

and the derivative of the mean of the network outputs with respect to the model parameters

\[\frac{\partial\mu_a}{\partial\theta_\alpha} = \frac{1}{n_d} \sum_{i=1}^{n_d}\frac{\partial{x^i_a}}{\partial\theta_\alpha}\]

can be calculated and used to form the Fisher information matrix

\[F_{\alpha\beta} = \frac{\partial\mu_a}{\partial\theta_\alpha} C^{-1}_{ab}\frac{\partial\mu_b}{\partial\theta_\beta}.\]
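As a minimal sketch (not the library internals), these statistics can be computed with jax.numpy from a set of summaries of shape (n_s, n_summaries) and derivatives of shape (n_d, n_summaries, n_params):

    import jax.numpy as np

    def fisher_statistics(summaries, derivatives):
        n_s = summaries.shape[0]
        # Covariance of the network outputs, C_ab
        difference = summaries - np.mean(summaries, axis=0)
        C = difference.T @ difference / (n_s - 1)
        invC = np.linalg.inv(C)
        # Derivative of the mean of the outputs wrt the model parameters
        dμ_dθ = np.mean(derivatives, axis=0)  # (n_summaries, n_params)
        # Gaussian Fisher information, F_αβ = dμ_a/dθ_α C^{-1}_ab dμ_b/dθ_β
        F = np.einsum("ai,ab,bj->ij", dμ_dθ, invC, dμ_dθ)
        return F, C, invC, dμ_dθ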

The loss function is then defined as

\[\Lambda = -\log|{\bf F}| + r(\Lambda_2) \Lambda_2\]

Since any linear rescaling of a sufficient statistic is also a sufficient statistic, the negative logarithm of the determinant of the Fisher information matrix needs to be regularised to fix the scale of the network outputs. We choose to fix this scale by constraining the covariance of network outputs as

\[\Lambda_2 = ||{\bf C}-{\bf I}|| + ||{\bf C}^{-1}-{\bf I}||\]

This constraint is chosen because it forces the covariance to be approximately parameter independent, which justifies using the covariance-independent Gaussian Fisher information above. To avoid having a dual optimisation objective, we use a smooth and dynamic regularisation strength which turns off the regularisation, so as to focus on maximising the Fisher information, once the covariance has set the scale

\[r(\Lambda_2) = \frac{\lambda\Lambda_2}{\Lambda_2+\exp (-\alpha\Lambda_2)}.\]
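A minimal sketch of the regularisation term and the full loss, under the definitions above (λ and α are the strength and rate hyperparameters set via fit and get_α):

    import jax.numpy as np

    def regularisation(C, invC):
        # Λ2 = ||C - I|| + ||C^{-1} - I||
        I = np.eye(C.shape[0])
        return np.linalg.norm(C - I) + np.linalg.norm(invC - I)

    def loss(F, C, invC, λ, α):
        Λ2 = regularisation(C, invC)
        # Smooth, dynamic strength: r → 0 as Λ2 → 0, r → λ for large Λ2
        r = λ * Λ2 / (Λ2 + np.exp(-α * Λ2))
        return -np.linalg.slogdet(F)[1] + r * Λ2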

Once the loss function is calculated, its automatic gradient is computed and used to update the network parameters via the optimiser function.

simulator:

A function for generating a simulation on-the-fly (XLA compilable)
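For orientation, a hypothetical end-to-end sketch with a toy Gaussian simulator is given below. The network, optimiser and all numerical values are illustrative assumptions (using the jax.experimental stax and optimizers interfaces from the JAX versions this library targets), not prescriptions:

    import jax
    import jax.numpy as np
    from jax.experimental import stax, optimizers
    import imnn

    def simulator(key, θ):
        # 10 draws from N(mean=θ[0], variance=θ[1]); XLA compilable
        return θ[0] + np.sqrt(θ[1]) * jax.random.normal(key, shape=(10,))

    model = stax.serial(stax.Dense(64), stax.LeakyRelu,
                        stax.Dense(64), stax.LeakyRelu,
                        stax.Dense(2))
    optimiser = optimizers.adam(step_size=1e-3)

    model_key, fit_key = jax.random.split(jax.random.PRNGKey(0))
    IMNN = imnn.SimulatorIMNN(
        n_s=1000, n_d=1000, n_params=2, n_summaries=2,
        input_shape=(10,), θ_fid=np.array([0., 1.]),
        model=model, optimiser=optimiser, key_or_state=model_key,
        simulator=simulator)
    IMNN.fit(λ=10., ε=0.1, rng=fit_key)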

Public Methods:

__init__(n_s, n_d, n_params, n_summaries, …)

Constructor method

get_summaries(w, key[, validate])

Gets all network outputs and derivatives wrt model parameters

Inherited from _IMNN

__init__(n_s, n_d, n_params, n_summaries, …)

Constructor method

fit(λ, ε[, rng, patience, min_iterations, …])

Fitting routine for the IMNN

get_α(λ, ε)

Calculate rate parameter for regularisation from closeness criterion

set_F_statistics([w, key, validate])

Set necessary attributes for calculating score compressed summaries

get_summaries(w, key[, validate])

Gets all network outputs and derivatives wrt model parameters

get_estimate(d)

Calculate score compressed parameter estimates from network outputs

plot([ax, expected_detF, colour, figsize, …])

Plot fitting history

Private Methods:

_get_fitting_keys(rng)

Generates random numbers for simulation

Inherited from _IMNN

_initialise_parameters(n_s, n_d, n_params, …)

Performs type checking and initialisation of class attributes

_initialise_model(model, optimiser, key_or_state)

Initialises neural network parameters or loads optimiser state

_initialise_history()

Initialises history dictionary attribute

_set_history(results)

Places results from fitting into the history dictionary

_set_inputs(rng, max_iterations)

Builds list of inputs for the XLA compilable fitting routine

_get_fitting_keys(rng)

Generates random numbers for simulation

_fit(inputs, λ=None, α=None[, min_iterations])

Single iteration fitting algorithm

_fit_cond(inputs, patience, max_iterations)

Stopping condition for the fitting loop

_update_loop_vars(inputs)

Updates input parameters if max_detF is increased

_check_loop_vars(inputs, min_iterations)

Updates patience_counter if max_detF not increased

_update_history(inputs, history, counter, ind)

Puts current fitting statistics into history arrays

_slogdet(matrix)

Combined summed logarithmic determinant

_construct_derivatives(derivatives)

Builds derivatives of the network outputs wrt model parameters

_get_F_statistics([w, key, validate])

Calculates the Fisher information and returns all statistics used

_calculate_F_statistics(summaries, derivatives)

Calculates the Fisher information matrix from network outputs

_get_regularisation_strength(Λ2, λ, α)

Coupling strength of the regularisation (amplified sigmoid)

_get_regularisation(C, invC)

Difference of the covariance (and its inverse) from identity

_get_loss(w, λ, α[, key])

Calculates the loss function and returns auxiliary variables

_calculate_loss(summaries, derivatives, λ, α)

Calculates the loss function from network summaries and derivatives

_setup_plot([ax, expected_detF, figsize])

Builds axes for history plot


_get_fitting_keys(rng)

Generates random numbers for simulation

Parameters

rng (int(2,)) – A random number generator

Returns

A new random number generator and random number generators for fitting (and validation)

Return type

int(2,), int(2,), int(2,)

get_summaries(w, key, validate=False)

Gets all network outputs and derivatives wrt model parameters

A random seed is obtained for each simulation, and n_d of them are used to calculate the network outputs of the corresponding simulations as well as the derivatives of those outputs with respect to the model parameters via jax autodifferentiation (see the sketch at the end of this section). The remaining n_s - n_d network outputs are then calculated and concatenated with those already computed.

Parameters
  • w (list or None, default=None) – The network parameters if wanting to calculate the Fisher information with a specific set of network parameters

  • key (int(2,) or None, default=None) – A random number generator for generating simulations on-the-fly

  • validate (bool, default=False) – Whether to get summaries of the validation set

Returns

  • float(n_s, n_summaries) – The set of all network outputs used to calculate the covariance

  • float(n_d, n_summaries, n_params) – The set of all network output derivatives wrt model parameters

get_summary:

Return a single network output

get_derivatives:

Return the Jacobian of the network outputs wrt model parameters
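The pattern used to obtain a summary and its parameter derivative on-the-fly can be sketched with jax autodifferentiation as follows; model_apply, w and simulator stand in for the network function, its parameters and the user-supplied simulator, and are assumptions of this sketch:

    import jax

    def summary_and_derivative(key, θ_fid, w, model_apply, simulator):
        def summary(θ):
            # Simulate at θ with a fixed seed, then compress
            return model_apply(w, simulator(key, θ))
        x = summary(θ_fid)                  # (n_summaries,)
        dx_dθ = jax.jacrev(summary)(θ_fid)  # (n_summaries, n_params)
        return x, dx_dθ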

AggregatedSimulatorIMNN

class imnn.AggregatedSimulatorIMNN(n_s, n_d, n_params, n_summaries, input_shape, θ_fid, model, optimiser, key_or_state, simulator, host, devices, n_per_device)

Information maximising neural network fit with simulations on-the-fly

Defines the function to get simulations and compress them using an XLA compilable simulator.

The outline of the fitting procedure is that a set of \(i\in[1, n_s]\) random number generators are generated and used to produce a set of \(n_s\) simulations, \({\bf d}^i={\rm simulator}({\rm seed}^i, \theta^{\rm fid})\), at the fiducial model parameters, \(\theta^{\rm fid}\). These are passed directly through a network \(f_{{\bf w}}({\bf d})\) with network parameters \({\bf w}\) to obtain network outputs \({\bf x}^i\), and autodifferentiation is used to get the derivative of \(n_d\) of these outputs with respect to the physical model parameters, \(\partial{{\bf x}^i}/\partial\theta_\alpha\), where \(\alpha\) labels the physical parameter. With \({\bf x}^i\) and \(\partial{{\bf x}^i}/\partial\theta_\alpha\) the covariance

\[C_{ab} = \frac{1}{n_s-1}\sum_{i=1}^{n_s}(x^i_a-\mu_a) (x^i_b-\mu_b)\]

and the derivative of the mean of the network outputs with respect to the model parameters

\[\frac{\partial\mu_a}{\partial\theta_\alpha} = \frac{1}{n_d} \sum_{i=1}^{n_d}\frac{\partial{x^i_a}}{\partial\theta_\alpha}\]

can be calculated and used to form the Fisher information matrix

\[F_{\alpha\beta} = \frac{\partial\mu_a}{\partial\theta_\alpha} C^{-1}_{ab}\frac{\partial\mu_b}{\partial\theta_\beta}.\]

The loss function is then defined as

\[\Lambda = -\log|{\bf F}| + r(\Lambda_2) \Lambda_2\]

Since any linear rescaling of a sufficient statistic is also a sufficient statistic, the negative logarithm of the determinant of the Fisher information matrix needs to be regularised to fix the scale of the network outputs. We choose to fix this scale by constraining the covariance of network outputs as

\[\Lambda_2 = ||{\bf C}-{\bf I}|| + ||{\bf C}^{-1}-{\bf I}||\]

This constraint is chosen because it forces the covariance to be approximately parameter independent, which justifies using the covariance-independent Gaussian Fisher information above. To avoid having a dual optimisation objective, we use a smooth and dynamic regularisation strength which turns off the regularisation, so as to focus on maximising the Fisher information, once the covariance has set the scale

\[r(\Lambda_2) = \frac{\lambda\Lambda_2}{\Lambda_2+\exp (-\alpha\Lambda_2)}.\]

To enable the use of large data (or networks) the whole procedure is aggregated. This means that the generation and passing of the simulations through the network is farmed out to the desired XLA devices, and recollected, n_per_device inputs at a time. These are then used to calculate the automatic gradient of the loss function with respect to the calculated summaries and derivatives, \(\partial\Lambda/\partial{\bf x}^i\) (which is a fairly small computation as long as n_summaries, n_s and n_d are not huge). Once this is calculated, the simulations are passed through the network AGAIN, this time calculating the Jacobian of the network outputs with respect to the network parameters, \(\partial{\bf x}^i/\partial{\bf w}\), which is then combined via the chain rule to get

\[\frac{\partial\Lambda}{\partial{\bf w}} = \frac{\partial\Lambda}{\partial{\bf x}^i} \frac{\partial{\bf x}^i}{\partial{\bf w}}\]

This can then be passed to the optimiser.
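A conceptual sketch of this two-pass gradient (assumed names, not the library internals; only the summary term is shown, the derivative term is combined in the same way):

    import jax

    def aggregated_gradient(w, data, model_apply, loss_fn):
        # First pass: summaries, then the small gradient dΛ/dx
        summaries = model_apply(w, data)
        dΛ_dx = jax.grad(loss_fn)(summaries)
        # Second pass: combine with ∂x/∂w via a vjp (the chain rule above)
        _, vjp = jax.vjp(lambda w: model_apply(w, data), w)
        (dΛ_dw,) = vjp(dΛ_dx)
        return dΛ_dw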

Parameters
  • n_remaining (int) – The number of simulations for which only the fiducial simulations are calculated. This is zero if n_s is equal to n_d.

  • n_iterations (int) – Number of iterations through the main summarising loop

  • n_remaining_iterations (int) – Number of iterations through the remaining simulations used for quick loops with no derivatives

  • batch_shape (tuple) – The shape which n_d should be reshaped to for aggregating: (n_d // (n_devices * n_per_device), n_devices, n_per_device)

  • remaining_batch_shape (tuple) – The shape which n_s - n_d should be reshaped to for aggregating: ((n_s - n_d) // (n_devices * n_per_device), n_devices, n_per_device) (a worked example of these shapes follows this list)
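As a worked example of these shapes (illustrative numbers only), with n_d = 1000 simulations, 2 devices and n_per_device = 100:

    n_d, n_devices, n_per_device = 1000, 2, 100  # illustrative values
    batch_shape = (n_d // (n_devices * n_per_device), n_devices, n_per_device)
    # batch_shape == (5, 2, 100): five aggregation passes of 2 x 100 inputs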

simulator:

A function for generating a simulation on-the-fly (XLA compilable)

Public Methods:

__init__(n_s, n_d, n_params, n_summaries, …)

Constructor method

get_summary(input, w, θ[, derivative, gradient])

Returns a single summary of a simulation or its gradient

get_summaries(w, key[, validate])

Gets all network outputs and derivatives wrt model parameters

get_gradient(dΛ_dx, w[, key])

Aggregates gradients together to update the network parameters

Inherited from _AggregatedIMNN

__init__(n_s, n_d, n_params, n_summaries, …)

Constructor method

fit(λ, ε[, rng, patience, min_iterations, …])

Fitting routine for the IMNN

get_summaries(w, key[, validate])

Gets all network outputs and derivatives wrt model parameters

get_gradient(dΛ_dx, w[, key])

Aggregates gradients together to update the network parameters

Inherited from SimulatorIMNN

__init__(n_s, n_d, n_params, n_summaries, …)

Constructor method

get_summaries(w, key[, validate])

Gets all network outputs and derivatives wrt model parameters

Inherited from _IMNN

__init__(n_s, n_d, n_params, n_summaries, …)

Constructor method

fit(λ, ε[, rng, patience, min_iterations, …])

Fitting routine for the IMNN

get_α(λ, ε)

Calculate rate parameter for regularisation from closeness criterion

set_F_statistics([w, key, validate])

Set necessary attributes for calculating score compressed summaries

get_summaries(w, key[, validate])

Gets all network outputs and derivatives wrt model parameters

get_estimate(d)

Calculate score compressed parameter estimates from network outputs

plot([ax, expected_detF, colour, figsize, …])

Plot fitting history

Private Methods:

_set_shapes()

Calculates the shapes for batching over different devices

_collect_input(key[, validate])

Returns the keys for generating simulations on-the-fly

_split_dΛ_dx(dΛ_dx)

Returns the gradient of loss function wrt summaries (derivatives)

Inherited from _AggregatedIMNN

_set_devices(devices, n_per_device)

Checks that devices exist and that reshaping onto devices can occur

_set_batch_functions()

Creates jitted functions placed on desired XLA devices

_set_shapes()

Calculates the shapes for batching over different devices

_setup_progress_bar(print_rate, max_iterations)

Construct progress bar

_update_progress_bar(pbar, counter, …[, close])

Updates (and closes) progress bar

_collect_input(key[, validate])

Returns the keys for generating simulations on-the-fly

_get_batch_summaries(inputs, w, θ[, …])

Vectorised batch calculation of summaries or gradients

_split_dΛ_dx(dΛ_dx)

Returns the gradient of loss function wrt summaries (derivatives)

_construct_gradient(layers[, aux, func])

Multiuse function to iterate over tuple of network parameters

Inherited from SimulatorIMNN

_get_fitting_keys(rng)

Generates random numbers for simulation

Inherited from _IMNN

_initialise_parameters(n_s, n_d, n_params, …)

Performs type checking and initialisation of class attributes

_initialise_model(model, optimiser, key_or_state)

Initialises neural network parameters or loads optimiser state

_initialise_history()

Initialises history dictionary attribute

_set_history(results)

Places results from fitting into the history dictionary

_set_inputs(rng, max_iterations)

Builds list of inputs for the XLA compilable fitting routine

_get_fitting_keys(rng)

Generates random numbers for simulation

_fit(inputs, λ=None, α=None[, min_iterations])

Single iteration fitting algorithm

_fit_cond(inputs, patience, max_iterations)

Stopping condition for the fitting loop

_update_loop_vars(inputs)

Updates input parameters if max_detF is increased

_check_loop_vars(inputs, min_iterations)

Updates patience_counter if max_detF not increased

_update_history(inputs, history, counter, ind)

Puts current fitting statistics into history arrays

_slogdet(matrix)

Combined summed logarithmic determinant

_construct_derivatives(derivatives)

Builds derivatives of the network outputs wrt model parameters

_get_F_statistics([w, key, validate])

Calculates the Fisher information and returns all statistics used

_calculate_F_statistics(summaries, derivatives)

Calculates the Fisher information matrix from network outputs

_get_regularisation_strength(Λ2, λ, α)

Coupling strength of the regularisation (amplified sigmoid)

_get_regularisation(C, invC)

Difference of the covariance (and its inverse) from identity

_get_loss(w, λ, α[, key])

Calculates the loss function and returns auxiliary variables

_calculate_loss(summaries, derivatives, λ, α)

Calculates the loss function from network summaries and derivatives

_setup_plot([ax, expected_detF, figsize])

Builds axes for history plot


_collect_input(key, validate=False)

Returns the keys for generating simulations on-the-fly

Parameters
  • key (None or int(2,)) – Random number generators for generating simulations

  • validate (bool, default=False) – Whether to return the set for validation or for fitting (always False)

_set_shapes()

Calculates the shapes for batching over different devices

Raises

ValueError – If the difference between n_s and n_d cannot be distributed across the XLA devices

_split_dΛ_dx(dΛ_dx)

Returns the gradient of loss function wrt summaries (derivatives)

The gradient of the loss function with respect to network outputs and their derivatives with respect to model parameters has to be reshaped and aggregated onto each XLA device, matching the format in which the keys for generating simulations are produced.

Parameters

dΛ_dx (tuple) –

  • dΛ_dx float(n_s, n_summaries) – the derivative of the loss function with respect to network summaries

  • d2Λ_dxdθ float(n_d, n_summaries, n_params) – the derivative of the loss function with respect to the derivative of network summaries with respect to model parameters

get_gradient(dΛ_dx, w, key=None)

Aggregates gradients together to update the network parameters

To avoid having to calculate the gradient with respect to all the simulations at once, we aggregate the gradient calculation by addition, looping over the keys for generating the simulations again, which we then combine with the derivative of the loss function with respect to the network outputs (and their derivatives with respect to the model parameters). Whilst this is expensive, it is necessary since we cannot make an accurate stochastic estimate of the Fisher information, and therefore we need to use all the simulations available - which are probably too many to fit in memory at once.
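A conceptual sketch of this accumulation, assuming for clarity one simulation per key and per slice of dΛ_dx (all names here are assumptions of the sketch):

    import jax
    import jax.numpy as np

    def accumulate_gradient(w, keys, θ_fid, dΛ_dx, model_apply, simulator):
        total = jax.tree_util.tree_map(np.zeros_like, w)
        for i, key in enumerate(keys):
            # Regenerate the simulation and take a vjp through the network
            forward = lambda w: model_apply(w, simulator(key, θ_fid))
            _, vjp = jax.vjp(forward, w)
            (grad,) = vjp(dΛ_dx[i])
            # Aggregate by addition so the full gradient is never held at once
            total = jax.tree_util.tree_map(np.add, total, grad)
        return total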

Parameters
  • dΛ_dx (tuple) –

    • dΛ_dx float(n_s, n_summaries) – the derivative of the loss function with respect to network summaries

    • d2Λ_dxdθ float(n_d, n_summaries, n_params) – the derivative of the loss function with respect to the derivative of network summaries with respect to model parameters

  • w (list) – Network parameters

  • key (None or int(2,)) – Random number generator

Returns

The gradient of the loss function with respect to the network parameters calculated by aggregating

Return type

list

get_summaries(w, key, validate=False)

Gets all network outputs and derivatives wrt model parameters

Loops through the generated random number generators on each XLA device to generate simulations and then passes them through the network to get the network outputs. These are then pushed back to the host for the computation of the loss function.

The fiducial simulations which have a derivative with respect to the model parameters counterpart are processed first and then the remaining fiducial simulations are processed and concatenated.

Parameters
  • w (list or None, default=None) – The network parameters if wanting to calculate the Fisher information with a specific set of network parameters

  • key (int(2,)) – A random number generator for generating simulations on-the-fly

  • validate (bool, default=False) – Whether to get summaries of the validation set (always False)

Returns

  • float(n_s, n_summaries) – The set of all network outputs used to calculate the covariance

  • float(n_d, n_summaries, n_params) – The set of all the derivatives of the network outputs with respect to model parameters

get_summary(input, w, θ, derivative=False, gradient=False)

Returns a single summary of a simulation or its gradient

Parameters
  • input (float(input_shape) or tuple) –

    A random number generator for a single simulation to pass through the network or a tuple of either (if gradient and not derivative)

    • dΛ_dx float(n_summaries,) – the derivative of the loss function with respect to a network summary

    • key int(2,) – A random number generator for a single simulation

    or (if gradient and derivative)

    • tuple (gradients)
      • dΛ_dx float(n_summaries,) – the derivative of the loss function with respect to a network summary

      • d2Λ_dxdθ float(n_summaries, n_params) – the derivative of the loss function with respect to the derivative of a network summary with respect to model parameters

    • key int(2,) – A random number generator for a single simulation

  • w (list) – The network parameters

  • θ (float(n_params,)) – The value of the parameters to generate the simulation at (fiducial), unused if not simulating on the fly

  • derivative (bool, default=False) – Whether a derivative of the simulation with respect to model parameters is also passed.

  • gradient (bool, default=False) – Whether to calculate the gradient of the loss function with respect to the network parameters

Returns

  • tuple (if gradient) – The gradient of the loss function with respect to the network parameters

  • float(n_summaries,) (if not gradient) – The output of the network

  • float(n_summaries, n_params) (if derivative and not gradient) – The derivative of the output of the network wrt model parameters

GradientIMNN

class imnn.GradientIMNN(n_s, n_d, n_params, n_summaries, input_shape, θ_fid, model, optimiser, key_or_state, fiducial, derivative, validation_fiducial=None, validation_derivative=None)

Information maximising neural network fit using known derivatives

The outline of the fitting procedure is that a set of \(i\in[1, n_s]\) simulations \({\bf d}^i\), originally generated at fiducial model parameter \({\bf\theta}^{\rm fid}\), and their derivatives \(\partial{\bf d}^i/\partial\theta_\alpha\) with respect to model parameters are used. The fiducial simulations, \({\bf d}^i\), are passed through a network to obtain summaries, \({\bf x}^i\), and the jax automatic derivative of these summaries with respect to the inputs is calculated, \(\partial{\bf x}^i/\partial{\bf d}^j\delta_{ij}\). The chain rule is then used to calculate

\[\frac{\partial{\bf x}^i}{\partial\theta_\alpha} = \frac{\partial{\bf x}^i}{\partial{\bf d}^j} \frac{\partial{\bf d}^j}{\partial\theta_\alpha}\]
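This chain rule can be sketched with a jvp, which pushes the known simulation derivative through the network one parameter direction at a time without building the full network Jacobian; model_apply and w are assumptions of this sketch:

    import jax
    import jax.numpy as np

    def output_derivative(w, d, dd_dθ, model_apply):
        # dd_dθ has shape (input_shape, n_params)
        def push(tangent):
            # jvp returns (f(d), J_f(d) @ tangent)
            _, dx = jax.jvp(lambda d: model_apply(w, d), (d,), (tangent,))
            return dx
        dx_dθ = jax.vmap(push, in_axes=-1)(dd_dθ)  # (n_params, n_summaries)
        return np.swapaxes(dx_dθ, 0, 1)            # (n_summaries, n_params)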

With \({\bf x}^i\) and \(\partial{{\bf x}^i}/\partial\theta_\alpha\) the covariance

\[C_{ab} = \frac{1}{n_s-1}\sum_{i=1}^{n_s}(x^i_a-\mu_a) (x^i_b-\mu_b)\]

and the derivative of the mean of the network outputs with respect to the model parameters

\[\frac{\partial\mu_a}{\partial\theta_\alpha} = \frac{1}{n_d} \sum_{i=1}^{n_d}\frac{\partial{x^i_a}}{\partial\theta_\alpha}\]

can be calculated and used to form the Fisher information matrix

\[F_{\alpha\beta} = \frac{\partial\mu_a}{\partial\theta_\alpha} C^{-1}_{ab}\frac{\partial\mu_b}{\partial\theta_\beta}.\]

The loss function is then defined as

\[\Lambda = -\log|{\bf F}| + r(\Lambda_2) \Lambda_2\]

Since any linear rescaling of a sufficient statistic is also a sufficient statistic, the negative logarithm of the determinant of the Fisher information matrix needs to be regularised to fix the scale of the network outputs. We choose to fix this scale by constraining the covariance of network outputs as

\[\Lambda_2 = ||{\bf C}-{\bf I}|| + ||{\bf C}^{-1}-{\bf I}||\]

This constraint is chosen because it forces the covariance to be approximately parameter independent, which justifies using the covariance-independent Gaussian Fisher information above. To avoid having a dual optimisation objective, we use a smooth and dynamic regularisation strength which turns off the regularisation, so as to focus on maximising the Fisher information, once the covariance has set the scale

\[r(\Lambda_2) = \frac{\lambda\Lambda_2}{\Lambda_2+\exp (-\alpha\Lambda_2)}.\]

Once the loss function is calculated, its automatic gradient is computed and used to update the network parameters via the optimiser function.

Parameters
  • fiducial (float(n_s, input_shape)) – The simulations generated at the fiducial model parameter values used for calculating the covariance of network outputs (for fitting)

  • derivative (float(n_d, input_shape, n_params)) – The derivative of the simulations with respect to the model parameters (for fitting)

  • validation_fiducial (float(n_s, input_shape) or None) – The simulations generated at the fiducial model parameter values used for calculating the covariance of network outputs (for validation)

  • validation_derivative (float(n_d, input_shape, n_params) or None) – The derivative of the simulations with respect to the model parameters (for validation)

Public Methods:

__init__(n_s, n_d, n_params, n_summaries, …)

Constructor method

get_summaries(w[, key, validate])

Gets all network outputs and derivatives wrt model parameters

Inherited from _IMNN

__init__(n_s, n_d, n_params, n_summaries, …)

Constructor method

fit(λ, ε[, rng, patience, min_iterations, …])

Fitting routine for the IMNN

get_α(λ, ε)

Calculate rate parameter for regularisation from closeness criterion

set_F_statistics([w, key, validate])

Set necessary attributes for calculating score compressed summaries

get_summaries(w[, key, validate])

Gets all network outputs and derivatives wrt model parameters

get_estimate(d)

Calculate score compressed parameter estimates from network outputs

plot([ax, expected_detF, colour, figsize, …])

Plot fitting history

Private Methods:

_set_data(fiducial, derivative, …)

Checks and sets data attributes with the correct shape

Inherited from _IMNN

_initialise_parameters(n_s, n_d, n_params, …)

Performs type checking and initialisation of class attributes

_initialise_model(model, optimiser, key_or_state)

Initialises neural network parameters or loads optimiser state

_initialise_history()

Initialises history dictionary attribute

_set_history(results)

Places results from fitting into the history dictionary

_set_inputs(rng, max_iterations)

Builds list of inputs for the XLA compilable fitting routine

_get_fitting_keys(rng)

Generates random numbers for simulation generation if needed

_fit(inputs, λ=None, α=None[, min_iterations])

Single iteration fitting algorithm

_fit_cond(inputs, patience, max_iterations)

Stopping condition for the fitting loop

_update_loop_vars(inputs)

Updates input parameters if max_detF is increased

_check_loop_vars(inputs, min_iterations)

Updates patience_counter if max_detF not increased

_update_history(inputs, history, counter, ind)

Puts current fitting statistics into history arrays

_slogdet(matrix)

Combined summed logarithmic determinant

_construct_derivatives(derivatives)

Builds derivatives of the network outputs wrt model parameters

_get_F_statistics([w, key, validate])

Calculates the Fisher information and returns all statistics used

_calculate_F_statistics(summaries, derivatives)

Calculates the Fisher information matrix from network outputs

_get_regularisation_strength(Λ2, λ, α)

Coupling strength of the regularisation (amplified sigmoid)

_get_regularisation(C, invC)

Difference of the covariance (and its inverse) from identity

_get_loss(w, λ, α[, key])

Calculates the loss function and returns auxillary variables

_calculate_loss(summaries, derivatives, λ, α)

Calculates the loss function from network summaries and derivatives

_setup_plot([ax, expected_detF, figsize])

Builds axes for history plot


_set_data(fiducial, derivative, validation_fiducial, validation_derivative)

Checks and sets data attributes with the correct shape

Parameters
  • fiducial (float(n_s, input_shape)) – The simulations generated at the fiducial model parameter values used for calculating the covariance of network outputs (for fitting)

  • derivative (float(n_d, input_shape, n_params)) – The derivative of the simulations with respect to the model parameters (for fitting)

  • validation_fiducial (float(n_s, input_shape) or None, default=None) – The simulations generated at the fiducial model parameter values used for calculating the covariance of network outputs (for validation). Sets validate = True attribute if provided

  • validation_derivative (float(n_d, input_shape, n_params) or None) – The derivative of the simulations with respect to the model parameters (for validation). Sets validate = True attribute if provided

get_summaries(w, key=None, validate=False)

Gets all network outputs and derivatives wrt model parameters

Selects either the fitting or validation sets and passes them through the network to get the network outputs.

Parameters
  • w (list or None, default=None) – The network parameters if wanting to calculate the Fisher information with a specific set of network parameters

  • key (int(2,) or None, default=None) – A random number generator for generating simulations on-the-fly

  • validate (bool, default=False) – Whether to get summaries of the validation set

Returns

  • float(n_s, n_summaries) – The set of all network outputs used to calculate the covariance

  • float(n_d, n_summaries, n_params) – The derivative of the network outputs wrt the model parameters

get_derivatives:

Calculates the Jacobian of the network output along with its value

AggregatedGradientIMNN

class imnn.AggregatedGradientIMNN(n_s, n_d, n_params, n_summaries, input_shape, θ_fid, model, optimiser, key_or_state, fiducial, derivative, host, devices, n_per_device, validation_fiducial=None, validation_derivative=None, prefetch=None, cache=False)

Information maximising neural network fit using known derivatives

The outline of the fitting procedure is that a set of \(i\in[1, n_s]\) simulations \({\bf d}^i\), originally generated at fiducial model parameter \({\bf\theta}^{\rm fid}\), and their derivatives \(\partial{\bf d}^i/\partial\theta_\alpha\) with respect to model parameters are used. The fiducial simulations, \({\bf d}^i\), are passed through a network to obtain summaries, \({\bf x}^i\), and the jax automatic derivative of these summaries with respect to the inputs is calculated, \(\partial{\bf x}^i/\partial{\bf d}^j\delta_{ij}\). The chain rule is then used to calculate

\[\frac{\partial{\bf x}^i}{\partial\theta_\alpha} = \frac{\partial{\bf x}^i}{\partial{\bf d}^j} \frac{\partial{\bf d}^j}{\partial\theta_\alpha}\]

With \({\bf x}^i\) and \(\partial{{\bf x}^i}/\partial\theta_\alpha\) the covariance

\[C_{ab} = \frac{1}{n_s-1}\sum_{i=1}^{n_s}(x^i_a-\mu_a) (x^i_b-\mu_b)\]

and the derivative of the mean of the network outputs with respect to the model parameters

\[\frac{\partial\mu_a}{\partial\theta_\alpha} = \frac{1}{n_d} \sum_{i=1}^{n_d}\frac{\partial{x^i_a}}{\partial\theta_\alpha}\]

can be calculated and used to form the Fisher information matrix

\[F_{\alpha\beta} = \frac{\partial\mu_a}{\partial\theta_\alpha} C^{-1}_{ab}\frac{\partial\mu_b}{\partial\theta_\beta}.\]

The loss function is then defined as

\[\Lambda = -\log|{\bf F}| + r(\Lambda_2) \Lambda_2\]

Since any linear rescaling of a sufficient statistic is also a sufficient statistic, the negative logarithm of the determinant of the Fisher information matrix needs to be regularised to fix the scale of the network outputs. We choose to fix this scale by constraining the covariance of network outputs as

\[\Lambda_2 = ||{\bf C}-{\bf I}|| + ||{\bf C}^{-1}-{\bf I}||\]

This constraint is chosen because it forces the covariance to be approximately parameter independent, which justifies using the covariance-independent Gaussian Fisher information above. To avoid having a dual optimisation objective, we use a smooth and dynamic regularisation strength which turns off the regularisation, so as to focus on maximising the Fisher information, once the covariance has set the scale

\[r(\Lambda_2) = \frac{\lambda\Lambda_2}{\Lambda_2+\exp (-\alpha\Lambda_2)}.\]

To enable the use of large data (or networks) the whole procedure is aggregated. This means that the passing of the simulations through the network is farmed out to the desired XLA devices, and recollected, n_per_device inputs at a time. These are then used to calculate the automatic gradient of the loss function with respect to the calculated summaries and derivatives, \(\partial\Lambda/\partial{\bf x}^i\) (which is a fairly small computation as long as n_summaries, n_s and n_d are not huge). Once this is calculated, the simulations are passed through the network AGAIN, this time calculating the Jacobian of the network outputs with respect to the network parameters, \(\partial{\bf x}^i/\partial{\bf w}\), which is then combined via the chain rule to get

\[\frac{\partial\Lambda}{\partial{\bf w}} = \frac{\partial\Lambda}{\partial{\bf x}^i} \frac{\partial{\bf x}^i}{\partial{\bf w}}\]

This can then be passed to the optimiser.

Parameters
  • fiducial (float(n_s, input_shape)) – The simulations generated at the fiducial model parameter values used for calculating the covariance of network outputs (for fitting)

  • derivative (float(n_d, input_shape, n_params)) – The derivative of the simulations with respect to the model parameters (for fitting)

  • validation_fiducial (float(n_s, input_shape) or None) – The simulations generated at the fiducial model parameter values used for calculating the covariance of network outputs (for validation)

  • validation_derivative (float(n_d, input_shape, n_params) or None) – The derivative of the simulations with respect to the model parameters (for validation)

  • main (list of tf.data.Dataset().as_numpy_iterator()) – The simulations generated at the fiducial model parameter values used for calculating the covariance of network outputs and their derivatives with respect to the physical model parameters (for fitting). These are served n_per_device at a time as a numpy iterator from a TensorFlow dataset.

  • remaining (list of tf.data.Dataset().as_numpy_iterator()) – The n_s - n_d simulations generated at the fiducial model parameter values used for calculating the covariance of network outputs without a derivative counterpart (for fitting). These are served n_per_device at a time as a numpy iterator from a TensorFlow dataset.

  • validation_main (list of tf.data.Dataset().as_numpy_iterator()) – The simulations generated at the fiducial model parameter values used for calculating the covariance of network outputs and their derivatives with respect to the physical model parameters (for validation). These are served n_per_device at a time as a numpy iterator from a TensorFlow dataset.

  • validation_remaining (list of tf.data.Dataset().as_numpy_iterator()) – The n_s - n_d simulations generated at the fiducial model parameter values used for calculating the covariance of network outputs without a derivative counterpart (for validation). These are served n_per_device at a time as a numpy iterator from a TensorFlow dataset.

  • n_remaining (int) – The number of simulations for which only the fiducial simulations are calculated. This is zero if n_s is equal to n_d.

  • n_iterations (int) – Number of iterations through the main summarising loop

  • n_remaining_iterations (int) – Number of iterations through the remaining simulations used for quick loops with no derivatives

  • batch_shape (tuple) – The shape which n_d should be reshaped to for aggregating: (n_d // (n_devices * n_per_device), n_devices, n_per_device)

  • remaining_batch_shape (tuple) – The shape which n_s - n_d should be reshaped to for aggregating: ((n_s - n_d) // (n_devices * n_per_device), n_devices, n_per_device)

Public Methods:

__init__(n_s, n_d, n_params, n_summaries, …)

Constructor method

get_summary(input, w, θ[, derivative, gradient])

Returns a single summary of a simulation or its gradient

get_summaries(w[, key, validate])

Gets all network outputs and derivatives wrt model parameters

get_gradient(dΛ_dx, w[, key])

Aggregates gradients together to update the network parameters

Inherited from _AggregatedIMNN

__init__(n_s, n_d, n_params, n_summaries, …)

Constructor method

fit(λ, ε[, rng, patience, min_iterations, …])

Fitting routine for the IMNN

get_summaries(w[, key, validate])

Gets all network outputs and derivatives wrt model parameters

get_gradient(dΛ_dx, w[, key])

Aggregates gradients together to update the network parameters

Inherited from GradientIMNN

__init__(n_s, n_d, n_params, n_summaries, …)

Constructor method

get_summaries(w[, key, validate])

Gets all network outputs and derivatives wrt model parameters

Inherited from _IMNN

__init__(n_s, n_d, n_params, n_summaries, …)

Constructor method

fit(λ, ε[, rng, patience, min_iterations, …])

Fitting routine for the IMNN

get_α(λ, ε)

Calculate rate parameter for regularisation from closeness criterion

set_F_statistics([w, key, validate])

Set necessary attributes for calculating score compressed summaries

get_summaries(w[, key, validate])

Gets all network outputs and derivatives wrt model parameters

get_estimate(d)

Calculate score compressed parameter estimates from network outputs

plot([ax, expected_detF, colour, figsize, …])

Plot fitting history

Private Methods:

_set_shapes()

Calculates the shapes for batching over different devices

_set_dataset(prefetch, cache)

Collates data into loopable tensorflow dataset iterations

_collect_input(key[, validate])

Returns validation or fitting sets

_split_dΛ_dx(dΛ_dx)

Returns the gradient of loss function wrt summaries (derivatives)

Inherited from _AggregatedIMNN

_set_devices(devices, n_per_device)

Checks that devices exist and that reshaping onto devices can occur

_set_batch_functions()

Creates jitted functions placed on desired XLA devices

_set_shapes()

Calculates the shapes for batching over different devices

_setup_progress_bar(print_rate, max_iterations)

Construct progress bar

_update_progress_bar(pbar, counter, …[, close])

Updates (and closes) progress bar

_collect_input(key[, validate])

Returns validation or fitting sets

_get_batch_summaries(inputs, w, θ[, …])

Vectorised batch calculation of summaries or gradients

_split_dΛ_dx(dΛ_dx)

Returns the gradient of loss function wrt summaries (derivatives)

_construct_gradient(layers[, aux, func])

Multiuse function to iterate over tuple of network parameters

Inherited from GradientIMNN

_set_data(fiducial, derivative, …)

Checks and sets data attributes with the correct shape

Inherited from _IMNN

_initialise_parameters(n_s, n_d, n_params, …)

Performs type checking and initialisation of class attributes

_initialise_model(model, optimiser, key_or_state)

Initialises neural network parameters or loads optimiser state

_initialise_history()

Initialises history dictionary attribute

_set_history(results)

Places results from fitting into the history dictionary

_set_inputs(rng, max_iterations)

Builds list of inputs for the XLA compilable fitting routine

_get_fitting_keys(rng)

Generates random numbers for simulation generation if needed

_fit(inputs, λ=None, α=None[, min_iterations])

Single iteration fitting algorithm

_fit_cond(inputs, patience, max_iterations)

Stopping condition for the fitting loop

_update_loop_vars(inputs)

Updates input parameters if max_detF is increased

_check_loop_vars(inputs, min_iterations)

Updates patience_counter if max_detF not increased

_update_history(inputs, history, counter, ind)

Puts current fitting statistics into history arrays

_slogdet(matrix)

Combined summed logarithmic determinant

_construct_derivatives(derivatives)

Builds derivatives of the network outputs wrt model parameters

_get_F_statistics([w, key, validate])

Calculates the Fisher information and returns all statistics used

_calculate_F_statistics(summaries, derivatives)

Calculates the Fisher information matrix from network outputs

_get_regularisation_strength(Λ2, λ, α)

Coupling strength of the regularisation (amplified sigmoid)

_get_regularisation(C, invC)

Difference of the covariance (and its inverse) from identity

_get_loss(w, λ, α[, key])

Calculates the loss function and returns auxiliary variables

_calculate_loss(summaries, derivatives, λ, α)

Calculates the loss function from network summaries and derivatives

_setup_plot([ax, expected_detF, figsize])

Builds axes for history plot


_collect_input(key, validate=False)

Returns validation or fitting sets

Parameters
  • key (None or int(2,)) – Random number generator, not used in this case

  • validate (bool) – Whether to return the set for validation or for fitting

Returns

  • list of tf.data.Dataset().as_numpy_iterator() – The iterators for the main loop including simulations and their derivatives for fitting or validation

  • list of tf.data.Dataset().as_numpy_iterator() – The iterators for the remaining loop simulations for fitting or validation

_set_dataset(prefetch, cache)

Collates data into loopable tensorflow dataset iterations

Parameters
  • prefetch (tf.data.AUTOTUNE or int or None) – How many simulations to prefetch in the tensorflow dataset (a sketch of such a pipeline follows below)

  • cache (bool) – Whether to cache simulations in the tensorflow datasets

Raises
  • ValueError – If cache is None

  • TypeError – If cache is wrong type
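A hypothetical sketch of the kind of pipeline this builds, with simulations batched n_per_device at a time, optionally cached and prefetched, and exposed as a numpy iterator (the exact internals differ):

    import tensorflow as tf

    def make_iterator(simulations, n_per_device, prefetch=None, cache=False):
        dataset = tf.data.Dataset.from_tensor_slices(simulations)
        if cache:
            dataset = dataset.cache()
        dataset = dataset.batch(n_per_device)
        if prefetch is not None:
            dataset = dataset.prefetch(prefetch)
        return dataset.repeat().as_numpy_iterator()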

_set_shapes()

Calculates the shapes for batching over different devices

Raises

ValueError – If the difference between n_s and n_d cannot be distributed across the XLA devices

_split_dΛ_dx(dΛ_dx)

Returns the gradient of loss function wrt summaries (derivatives)

The gradient of the loss function with respect to network outputs and their derivatives with respect to model parameters has to be reshaped and aggregated onto each XLA device, matching the format in which the TensorFlow dataset serves simulations.

Parameters

dΛ_dx (tuple) –

  • dΛ_dx float(n_s, n_summaries) – the derivative of the loss function with respect to network summaries

  • d2Λ_dxdθ float(n_d, n_summaries, n_params) – the derivative of the loss function with respect to the derivative of network summaries with respect to model parameters

get_gradient(dΛ_dx, w, key=None)

Aggregates gradients together to update the network parameters

To avoid having to calculate the gradient with respect to all the simulations at once, we aggregate the gradient calculation by addition, looping over the simulations again and combining them with the derivative of the loss function with respect to the network outputs (and their derivatives with respect to the model parameters). Whilst this is expensive, it is necessary since we cannot make an accurate stochastic estimate of the Fisher information, and therefore we need to use all the simulations available - which are probably too many to fit in memory at once.

Parameters
  • dΛ_dx (tuple) –

    • dΛ_dx float(n_s, n_summaries) – the derivative of the loss function with respect to network summaries

    • d2Λ_dxdθ float(n_d, n_summaries, n_params) – the derivative of the loss function with respect to the derivative of network summaries with respect to model parameters

  • w (list) – Network parameters

  • key (None or int(2,)) – Random number generator used in SimulatorIMNN

Returns

The gradient of the loss function with respect to the network parameters calculated by aggregating

Return type

list

get_summaries(w, key=None, validate=False)

Gets all network outputs and derivatives wrt model parameters

Selects either the fitting or validation sets and loops through the iterators on each XLA device, passing the simulations through the network to get the network outputs. These are then pushed back to the host for the computation of the loss function.

The fiducial simulations which have a derivative with respect to the model parameters counterpart are processed first and then the remaining fiducial simulations are processed and concatenated.

Parameters
  • w (list or None, default=None) – The network parameters if wanting to calculate the Fisher information with a specific set of network parameters

  • key (int(2,) or None, default=None) – A random number generator for generating simulations on-the-fly

  • validate (bool, default=False) – Whether to get summaries of the validation set

Returns

  • float(n_s, n_summaries) – The set of all network outputs used to calculate the covariance

  • float(n_d, n_summaries, n_params) – The set of all the derivatives of the network outputs with respect to model parameters

get_summary(input, w, θ, derivative=False, gradient=False)

Returns a single summary of a simulation or its gradient

Parameters
  • input (float(input_shape) or tuple) –

    A single simulation to pass through the network or a tuple of either (if gradient and not derivative)

    • dΛ_dx float(n_summaries,) – the derivative of the loss function with respect to a network summary

    • d float(input_shape) – a simulation to compress with the network

    or (if gradient and derivative)

    • tuple (gradients)

      • dΛ_dx float(n_summaries,) – the derivative of the loss function with respect to a network summary

      • d2Λ_dxdθ float(n_summaries, n_params) – the derivative of the loss function with respect to the derivative of a network summary with respect to model parameters

    • tuple (simulations)

      • d float(input_shape) – a simulation to compress with the network

      • dd_dθ float(input_shape, n_params) – the derivative of a simulation with respect to model parameters

    or (if derivative and not gradient)

    • d float(input_shape) – a simulation to compress with the network

    • dd_dθ float(input_shape, n_params) – the derivative of a simulation with respect to model parameters

  • w (list) – The network parameters

  • θ (float(n_params,)) – The value of the parameters to generate the simulation at (fiducial), unused if not simulating on the fly

  • derivative (bool, default=False) – Whether a derivative of the simulation with respect to model parameters is also passed.

  • gradient (bool, default=False) – Whether to calculate the gradient of the loss function with respect to the network parameters

Returns

  • tuple (if gradient) – The gradient of the loss function with respect to the network parameters

  • float(n_summaries,) (if not gradient) – The output of the network

  • float(n_summaries, n_params) (if derivative and not gradient) – The derivative of the output of the network wrt model parameters

NumericalGradientIMNN

class imnn.NumericalGradientIMNN(n_s, n_d, n_params, n_summaries, input_shape, θ_fid, model, optimiser, key_or_state, fiducial, derivative, δθ, validation_fiducial=None, validation_derivative=None)

Information maximising neural network fit using numerical derivatives

The outline of the fitting procedure is that a set of \(i\in[1, n_s]\) simulations \({\bf d}^i\) originally generated at fiducial model parameter \({\bf\theta}^{\rm fid}\) is used, together with a set of \(i\in[1, n_d]\) simulations, \(\{{\bf d}_{\alpha^-}^i, {\bf d}_{\alpha^+}^i\}\), generated with the same seed for each \(i\) at \({\bf\theta}^{\rm fid}\) except for the parameter labelled \(\alpha\), which takes the values

\[\theta_{\alpha^-} = \theta_\alpha^\rm{fid}-\delta\theta_\alpha\]

and

\[\theta_{\alpha^+} = \theta_\alpha^\rm{fid}+\delta\theta_\alpha\]

where \(\delta\theta_\alpha\) is a \(n_{params}\) length vector whose \(\alpha\)th element perturbs the parameter \(\theta^{\rm fid}_\alpha\). This means there are \(2\times n_{params}\times n_d\) simulations used to calculate the numerical derivatives (this is extremely cheap compared to other machine learning methods). All these simulations are passed through a network \(f_{{\bf w}}({\bf d})\) with network parameters \({\bf w}\) to obtain network outputs \({\bf x}^i\) and \(\{{\bf x}_{\alpha^-}^i,{\bf x}_{\alpha^+}^i\}\). These perturbed values are combined to obtain

\[\frac{\partial{{\bf x}^i}}{\partial\theta_\alpha} = \frac{{\bf x}_{\alpha^+}^i - {\bf x}_{\alpha^-}^i} {\delta\theta_\alpha}\]
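A minimal sketch of this construction, assuming x_mp holds the network outputs of the perturbed simulations with shape (n_d, 2, n_params, n_summaries) and that axis 1 orders the lower then upper perturbations:

    import jax.numpy as np

    def numerical_derivative(x_mp, δθ):
        # (x_plus - x_minus) / δθ, broadcast over the summary axis
        dx_dθ = (x_mp[:, 1] - x_mp[:, 0]) / δθ[np.newaxis, :, np.newaxis]
        # Parameters moved to the last axis: (n_d, n_summaries, n_params)
        return np.swapaxes(dx_dθ, 1, 2)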

With \({\bf x}^i\) and \(\partial{{\bf x}^i}/\partial\theta_\alpha\) the covariance

\[C_{ab} = \frac{1}{n_s-1}\sum_{i=1}^{n_s}(x^i_a-\mu_a) (x^i_b-\mu_b)\]

and the derivative of the mean of the network outputs with respect to the model parameters

\[\frac{\partial\mu_a}{\partial\theta_\alpha} = \frac{1}{n_d} \sum_{i=1}^{n_d}\frac{\partial{x^i_a}}{\partial\theta_\alpha}\]

can be calculated and used to form the Fisher information matrix

\[F_{\alpha\beta} = \frac{\partial\mu_a}{\partial\theta_\alpha} C^{-1}_{ab}\frac{\partial\mu_b}{\partial\theta_\beta}.\]

The loss function is then defined as

\[\Lambda = -\log|{\bf F}| + r(\Lambda_2) \Lambda_2\]

Since any linear rescaling of a sufficient statistic is also a sufficient statistic, the negative logarithm of the determinant of the Fisher information matrix needs to be regularised to fix the scale of the network outputs. We choose to fix this scale by constraining the covariance of network outputs as

\[\Lambda_2 = ||{\bf C}-{\bf I}|| + ||{\bf C}^{-1}-{\bf I}||\]

This constraint is chosen because it forces the covariance to be approximately parameter independent, which justifies using the covariance-independent Gaussian Fisher information above. To avoid having a dual optimisation objective, we use a smooth and dynamic regularisation strength which turns off the regularisation, so as to focus on maximising the Fisher information, once the covariance has set the scale

\[r(\Lambda_2) = \frac{\lambda\Lambda_2}{\Lambda_2+\exp (-\alpha\Lambda_2)}.\]

Once the loss function is calculated, its automatic gradient is computed and used to update the network parameters via the optimiser function.

Parameters
  • δθ (float(n_params,)) – Size of perturbation to model parameters for the numerical derivative

  • fiducial (float(n_s, input_shape)) – The simulations generated at the fiducial model parameter values used for calculating the covariance of network outputs (for fitting)

  • derivative (float(n_d, 2, n_params, input_shape)) – The simulations generated at parameter values perturbed from the fiducial used to calculate the numerical derivative of network outputs with respect to model parameters (for fitting)

  • validation_fiducial (float(n_s, input_shape) or None) – The simulations generated at the fiducial model parameter values used for calculating the covariance of network outputs (for validation)

  • validation_derivative (float(n_d, 2, n_params, input_shape) or None) – The simulations generated at parameter values perturbed from the fiducial used to calculate the numerical derivative of network outputs with respect to model parameters (for validation)

Public Methods:

__init__(n_s, n_d, n_params, n_summaries, …)

Constructor method

get_summaries(w[, key, validate])

Gets all network outputs and derivatives wrt model parameters

Inherited from _IMNN

__init__(n_s, n_d, n_params, n_summaries, …)

Constructor method

fit(λ, ε[, rng, patience, min_iterations, …])

Fitting routine for the IMNN

get_α(λ, ε)

Calculate rate parameter for regularisation from closeness criterion

set_F_statistics([w, key, validate])

Set necessary attributes for calculating score compressed summaries

get_summaries(w[, key, validate])

Gets all network outputs and derivatives wrt model parameters

get_estimate(d)

Calculate score compressed parameter estimates from network outputs

plot([ax, expected_detF, colour, figsize, …])

Plot fitting history

Private Methods:

_set_data(δθ, fiducial, derivative, …)

Checks and sets data attributes with the correct shape

_collect_input(key[, validate])

Returns validation or fitting sets

_construct_derivatives(x_mp)

Builds derivatives of the network outputs wrt model parameters

Inherited from _IMNN

_initialise_parameters(n_s, n_d, n_params, …)

Performs type checking and initialisation of class attributes

_initialise_model(model, optimiser, key_or_state)

Initialises neural network parameters or loads optimiser state

_initialise_history()

Initialises history dictionary attribute

_set_history(results)

Places results from fitting into the history dictionary

_set_inputs(rng, max_iterations)

Builds list of inputs for the XLA compilable fitting routine

_get_fitting_keys(rng)

Generates random numbers for simulation generation if needed

_fit(inputs, λ=None, α=None[, min_iterations])

Single iteration fitting algorithm

_fit_cond(inputs, patience, max_iterations)

Stopping condition for the fitting loop

_update_loop_vars(inputs)

Updates input parameters if max_detF is increased

_check_loop_vars(inputs, min_iterations)

Updates patience_counter if max_detF not increased

_update_history(inputs, history, counter, ind)

Puts current fitting statistics into history arrays

_slogdet(matrix)

Combined summed logarithmic determinant

_construct_derivatives(x_mp)

Builds derivatives of the network outputs wrt model parameters

_get_F_statistics([w, key, validate])

Calculates the Fisher information and returns all statistics used

_calculate_F_statistics(summaries, derivatives)

Calculates the Fisher information matrix from network outputs

_get_regularisation_strength(Λ2, λ, α)

Coupling strength of the regularisation (amplified sigmoid)

_get_regularisation(C, invC)

Difference of the covariance (and its inverse) from identity

_get_loss(w, λ, α[, key])

Calculates the loss function and returns auxiliary variables

_calculate_loss(summaries, derivatives, λ, α)

Calculates the loss function from network summaries and derivatives

_setup_plot([ax, expected_detF, figsize])

Builds axes for history plot


_collect_input(key, validate=False)

Returns validation or fitting sets

Parameters
  • key (None or int(2,)) – Random number generator, not used in this case

  • validate (bool) – Whether to return the set for validation or for fitting

Returns

  • float(n_s, input_shape) – The fiducial simulations for fitting or validation

  • float(n_d, 2, n_params, input_shape) – The derivative simulations for fitting or validation

_construct_derivatives(x_mp)

Builds derivatives of the network outputs wrt model parameters

The network outputs from the simulations generated with model parameter values above and below the fiducial are subtracted from each other and divided by the perturbation size in each model parameter value. The axes are swapped such that the derivatives with respect to parameters are in the last axis.

\[\frac{\partial{\bf x}^i}{\partial\theta_\alpha} = \frac{{\bf x}^i_{\alpha^+}-{\bf x}^i_{\alpha^-}}{ \delta\theta_\alpha}\]
Parameters

x_mp (float(n_d, 2, n_params, n_summaries)) – The outputs of the network from simulations made at perturbed parameter values, used to construct the derivative of the network outputs with respect to the model parameters numerically

Returns

The numerical derivatives of the network outputs with respect to the model parameters

Return type

float(n_d, n_summaries, n_params)

_set_data(δθ, fiducial, derivative, validation_fiducial, validation_derivative)

Checks and sets data attributes with the correct shape

Parameters
  • δθ (float(n_params,)) – Size of perturbation to model parameters for the numerical derivative

  • fiducial (float(n_s, input_shape)) – The simulations generated at the fiducial model parameter values used for calculating the covariance of network outputs (for fitting)

  • derivative (float(n_d, 2, n_params, input_shape)) – The simulations generated at parameter values perturbed from the fiducial used to calculate the numerical derivative of network outputs with respect to model parameters (for fitting)

  • validation_fiducial (float(n_s, input_shape) or None, default=None) – The simulations generated at the fiducial model parameter values used for calculating the covariance of network outputs (for validation). Sets validate = True attribute if provided

  • validation_derivative (float(n_d, 2, n_params, input_shape) or None) – The simulations generated at parameter values perturbed from the fiducial used to calculate the numerical derivative of network outputs with respect to model parameters (for validation). Sets validate = True attribute if provided

get_summaries(w, key=None, validate=False)

Gets all network outputs and derivatives wrt model parameters

Selects either the fitting or validation sets and passes them through the network to get the network outputs. For the numerical derivatives, the array is first flattened along the batch axis before being passed through the model.

Parameters
  • w (list or None, default=None) – The network parameters if wanting to calculate the Fisher information with a specific set of network parameters

  • key (int(2,) or None, default=None) – A random number generator for generating simulations on-the-fly

  • validate (bool, default=False) – Whether to get summaries of the validation set

Returns

  • float(n_s, n_summaries) – The set of all network outputs used to calculate the covariance

  • float(n_d, 2, n_params, n_summaries) – The outputs of the network of simulations made at perturbed parameter values to construct the derivative of the network outputs with respect to the model parameters numerically

AggregatedNumericalGradientIMNN

class imnn.AggregatedNumericalGradientIMNN(n_s, n_d, n_params, n_summaries, input_shape, θ_fid, model, optimiser, key_or_state, fiducial, derivative, δθ, host, devices, n_per_device, validation_fiducial=None, validation_derivative=None, prefetch=None, cache=False)

Information maximising neural network fit using numerical derivatives

The outline of the fitting procedure is that a set of \(i\in[1, n_s]\) simulations \({\bf d}^i\) originally generated at fiducial model parameter \({\bf\theta}^{\rm fid}\) is used, together with a set of \(i\in[1, n_d]\) simulations, \(\{{\bf d}_{\alpha^-}^i, {\bf d}_{\alpha^+}^i\}\), generated with the same seed for each \(i\) at \({\bf\theta}^{\rm fid}\) except for the parameter labelled \(\alpha\), which takes the values

\[\theta_{\alpha^-} = \theta_\alpha^\rm{fid}-\delta\theta_\alpha\]

and

\[\theta_{\alpha^+} = \theta_\alpha^\rm{fid}+\delta\theta_\alpha\]

where \(\delta\theta_\alpha\) is the \(\alpha\) element of an \(n_{params}\)-length vector of perturbations to the fiducial parameter values. This means there are \(2\times n_{params}\times n_d\) simulations used to calculate the numerical derivatives (extremely cheap compared to other machine learning methods). All these simulations are passed through a network \(f_{{\bf w}}({\bf d})\) with network parameters \({\bf w}\) to obtain network outputs \({\bf x}^i\) and \(\{{\bf x}_{\alpha^-}^i,{\bf x}_{\alpha^+}^i\}\). These perturbed values are combined to obtain

\[\frac{\partial{{\bf x}^i}}{\partial\theta_\alpha} = \frac{{\bf x}_{\alpha^+}^i - {\bf x}_{\alpha^-}^i} {\delta\theta_\alpha}\]

With \({\bf x}^i\) and \(\partial{{\bf x}^i}/\partial\theta_\alpha\) the covariance

\[C_{ab} = \frac{1}{n_s-1}\sum_{i=1}^{n_s}(x^i_a-\mu_a) (x^i_b-\mu_b)\]

and the derivative of the mean of the network outputs with respect to the model parameters

\[\frac{\partial\mu_a}{\partial\theta_\alpha} = \frac{1}{n_d} \sum_{i=1}^{n_d}\frac{\partial{x^i_a}}{\partial\theta_\alpha}\]

can be calculated and used to form the Fisher information matrix

\[F_{\alpha\beta} = \frac{\partial\mu_a}{\partial\theta_\alpha} C^{-1}_{ab}\frac{\partial\mu_b}{\partial\theta_\beta}.\]
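
These three steps translate almost directly into JAX; a minimal sketch (array and function names are illustrative):

    import jax.numpy as jnp

    def fisher_sketch(x, dx_dθ):
        # x: (n_s, n_summaries); dx_dθ: (n_d, n_summaries, n_params)
        n_s = x.shape[0]
        μ = jnp.mean(x, axis=0)
        Δ = x - μ
        C = Δ.T @ Δ / (n_s - 1)          # covariance of the network outputs
        invC = jnp.linalg.inv(C)
        dμ_dθ = jnp.mean(dx_dθ, axis=0)  # (n_summaries, n_params)
        F = dμ_dθ.T @ invC @ dμ_dθ       # Gaussian Fisher information
        return F, C, invC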

The loss function is then defined as

\[\Lambda = -\log|{\bf F}| + r(\Lambda_2) \Lambda_2\]

Since any linear rescaling of a sufficient statistic is also a sufficient statistic, the negative logarithm of the determinant of the Fisher information matrix needs to be regularised to fix the scale of the network outputs. We choose to fix this scale by constraining the covariance of the network outputs as

\[\Lambda_2 = ||{\bf C}-{\bf I}|| + ||{\bf C}^{-1}-{\bf I}||\]

The reason for choosing this constraint is that it forces the covariance to be approximately parameter independent, which justifies choosing the covariance-independent Gaussian Fisher information above. To avoid having a dual optimisation objective, we use a smooth and dynamic regularisation strength which turns off the regularisation, to focus on maximising the Fisher information, once the covariance has set the scale

\[r(\Lambda_2) = \frac{\lambda\Lambda_2}{\Lambda_2-\exp (-\alpha\Lambda_2)}.\]
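
A sketch of the full loss under these definitions (λ and α are the regularisation strength and rate parameters used by fit and get_α; C, invC and F are as in the Fisher sketch above):

    import jax.numpy as jnp

    def loss_sketch(F, C, invC, λ, α):
        # Λ2 penalises deviation of the output covariance (and its
        # inverse) from the identity, fixing the scale of the summaries
        I = jnp.eye(C.shape[0])
        Λ2 = jnp.linalg.norm(C - I) + jnp.linalg.norm(invC - I)
        # smooth, dynamic regularisation strength r(Λ2) as defined above
        r = λ * Λ2 / (Λ2 - jnp.exp(-α * Λ2))
        # minimising -log|F| maximises the Fisher information
        return -jnp.linalg.slogdet(F)[1] + r * Λ2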

To enable the use of large data (or networks) the whole procedure is aggregated. This means that the passing of the simulations through the network is farmed out to the desired XLA devices and recollected, n_per_device inputs at a time. These are then used to calculate the automatic gradient of the loss function with respect to the calculated summaries and derivatives, \(\partial\Lambda/\partial{\bf x}^i\) (a fairly small computation as long as n_summaries, n_s and n_d are not huge). Once this is calculated, the simulations are passed through the network again, this time calculating the Jacobian of the network outputs with respect to the network parameters, \(\partial{\bf x}^i/\partial{\bf w}\), which is combined via the chain rule to give

\[\frac{\partial\Lambda}{\partial{\bf w}} = \frac{\partial\Lambda}{\partial{\bf x}^i} \frac{\partial{\bf x}^i}{\partial{\bf w}}\]

This can then be passed to the optimiser.
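
The two-pass structure can be sketched, simplified to a single device, with jax.vjp (summarise is an assumed helper mapping network parameters and a chunk of simulations to summaries):

    import jax
    import jax.numpy as jnp

    def aggregated_gradient_sketch(summarise, w, data_chunks, dΛ_dx_chunks):
        # dΛ_dx_chunks holds the per-chunk pieces of ∂Λ/∂x from the
        # first pass through the network
        grad = None
        for d, g_x in zip(data_chunks, dΛ_dx_chunks):
            # the vector-Jacobian product applies the chain rule
            # ∂Λ/∂w = ∂Λ/∂x · ∂x/∂w without forming the full Jacobian
            _, vjp = jax.vjp(lambda w_: summarise(w_, d), w)
            (g_w,) = vjp(g_x)
            grad = g_w if grad is None else jax.tree_util.tree_map(
                jnp.add, grad, g_w)
        return grad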

Parameters
  • δθ (float(n_params,)) – Size of perturbation to model parameters for the numerical derivative

  • fiducial (list of tf.data.Dataset().as_numpy_iterator()) – The simulations generated at the fiducial model parameter values used for calculating the covariance of network outputs (for fitting). These are served n_per_device at a time as a numpy iterator from a TensorFlow dataset (see the sketch after this parameter list).

  • derivative (list of tf.data.Dataset().as_numpy_iterator()) – The simulations generated at parameter values perturbed from the fiducial used to calculate the numerical derivative of network outputs with respect to model parameters (for fitting). These are served n_per_device at a time as a numpy iterator from a TensorFlow dataset.

  • validation_fiducial (list of tf.data.Dataset().as_numpy_iterator() or None, default=None) – The simulations generated at the fiducial model parameter values used for calculating the covariance of network outputs (for validation). These are served n_per_device at a time as a numpy iterator from a TensorFlow dataset.

  • validation_derivative (list of tf.data.Dataset().as_numpy_iterator() or None, default=None) – The simulations generated at parameter values perturbed from the fiducial used to calculate the numerical derivative of network outputs with respect to model parameters (for validation). These are served n_per_device at a time as a numpy iterator from a TensorFlow dataset.

  • fiducial_iterations (int) – The number of iterations over the fiducial dataset

  • derivative_iterations (int) – The number of iterations over the derivative dataset

  • derivative_output_shape (tuple) – The shape of the output of the derivatives from the network

  • fiducial_batch_shape (tuple) – The shape of each batch of fiducial simulations (without input or summary shape)

  • derivative_batch_shape (tuple) – The shape of each batch of derivative simulations (without input or summary shape)
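
For illustration, inputs of this form could be built from in-memory arrays roughly as follows (make_iterators is a hypothetical helper, assuming the simulations divide evenly across devices; the class itself constructs and manages its datasets internally):

    import numpy as np
    import tensorflow as tf

    def make_iterators(sims, n_devices, n_per_device,
                       prefetch=tf.data.AUTOTUNE, cache=True):
        # one dataset per XLA device, each serving n_per_device
        # simulations at a time as numpy arrays
        iterators = []
        for shard in np.split(sims, n_devices):
            dataset = tf.data.Dataset.from_tensor_slices(shard)
            dataset = dataset.batch(n_per_device)
            if cache:
                dataset = dataset.cache()
            iterators.append(dataset.prefetch(prefetch).as_numpy_iterator())
        return iterators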

_collect_input(key, validate=False)

Returns validation or fitting sets

Parameters
  • key (None or int(2,)) – A random number generator, not used in this case

  • validate (bool) – Whether to return the set for validation or for fitting

Returns

  • list of tf.data.Dataset().as_numpy_iterator() – The iterators for fiducial simulations for fitting or validation

  • list of tf.data.Dataset().as_numpy_iterator() – The iterators for derivative simulations for fitting or validation

_set_batch_functions()

Creates jitted functions placed on desired XLA devices

So that each set of summaries is calculated on the correct device, we predefine the jitted functions on each of these devices

_set_dataset(prefetch, cache)

Transforms the data into lists of tensorflow dataset iterators

Parameters
  • prefetch (tf.data.AUTOTUNE or int or None) – How many simulations to prefetch in the tensorflow dataset

  • cache (bool) – Whether to cache simulations in the tensorflow datasets

Raises
  • ValueError – If cache and/or prefetch is None

  • TypeError – If cache and/or prefetch is wrong type

_set_shapes()

Calculates the shapes for batching over different devices

_split_dΛ_dx(dΛ_dx)

Returns the gradient of loss function wrt summaries (derivatives)

The gradient of the loss function with respect to the network outputs and their derivatives with respect to model parameters has to be reshaped and aggregated onto each XLA device to match the format in which the tensorflow datasets feed simulations (a rough sketch follows the returns list below).

Parameters

dΛ_dx (tuple) –

  • dΛ_dx float(n_s, n_summaries) – the derivative of the loss function with respect to network summaries

  • d2Λ_dxdθ float(n_d, 2, n_params, n_summaries) – the derivative of the loss function with respect to the derivative of network summaries with respect to model parameters

Returns

  • list – a list of sets of derivatives of the loss function with respect to network summaries placed on each XLA device

  • list – a list of sets of derivatives of the loss function with respect to the derivative of network summaries with respect to model parameters
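
A rough sketch of the reshaping (assuming n_s and the flattened derivative batch divide evenly across the devices; placement onto the devices themselves is omitted):

    import numpy as np

    def split_dΛ_dx_sketch(dΛ_dx, d2Λ_dxdθ, n_devices):
        # flatten the derivative piece to mirror the flattened batch of
        # perturbed simulations, then shard both across the devices
        flat = d2Λ_dxdθ.reshape((-1,) + d2Λ_dxdθ.shape[-1:])
        return np.split(dΛ_dx, n_devices), np.split(flat, n_devices)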

get_gradient(dΛ_dx, w, key=None)

Aggregates gradients together to update the network parameters

To avoid having to calculate the gradient with respect to all the simulations at once, we aggregate the gradient calculation by addition, looping over the simulations again and combining them with the derivative of the loss function with respect to the network outputs (and their derivatives with respect to the model parameters). Whilst this is expensive, it is necessary since we cannot accurately make a stochastic estimate of the Fisher information and therefore need to use all the available simulations, which are probably too many to fit in memory at once.

Parameters
  • dΛ_dx (tuple) –

    • dΛ_dx float(n_s, n_summaries) – the derivative of the loss function with respect to network summaries

    • d2Λ_dxdθ float(n_d, 2, n_params, n_summaries) – the derivative of the loss function with respect to the derivative of network summaries with respect to model parameters

  • w (list) – Network parameters

  • key (None or int(2,)) – Random number generator used in SimulatorIMNN

Returns

The gradient of the loss function with respect to the network parameters calculated by aggregating

Return type

list

get_summaries(w, key=None, validate=False)

Gets all network outputs and derivatives wrt model parameters

Selects either the fitting or validation set and loops through the iterators on each XLA device, passing the simulations through the network to get the network outputs. These are then pushed back to the host for the computation of the loss function.

The fiducial simulations are processed first and then the simulations which are varied with respect to model parameters for the derivatives.

Parameters
  • w (list or None, default=None) – The network parameters if wanting to calculate the Fisher information with a specific set of network parameters

  • key (int(2,) or None, default=None) – A random number generator for generating simulations on-the-fly

  • validate (bool, default=False) – Whether to get summaries of the validation set

Returns

  • float(n_s, n_summaries) – The set of all network outputs used to calculate the covariance

  • float(n_d, 2, n_params, n_summaries) – The outputs of the network of simulations made at perturbed parameter values to construct the derivative of the network outputs with respect to the model parameters numerically

get_summary(inputs, w, θ, derivative=False, gradient=False)

Returns a single summary of a simulation or its gradient

Parameters
  • inputs (float(input_shape) or tuple) –

    A single simulation to pass through the network or a tuple of
    • dΛ_dx float(input_shape, n_params) – the derivative of the loss function with respect to a network summary

    • d float(input_shape) – a simulation to compress with the network

  • w (list) – The network parameters

  • θ (float(n_params,)) – The value of the parameters to generate the simulation at (fiducial), unused if not simulating on the fly

  • derivative (bool, default=False) – Whether a derivative of the simulation with respect to model parameters is also passed. This must be False for NumericalGradientIMNN

  • gradient (bool, default=False) – Whether to calculate the gradient with respect to model parameters