Specific IMNN classes¶
The available modules are:
SimulatorIMNN¶
-
class imnn.SimulatorIMNN(n_s, n_d, n_params, n_summaries, input_shape, θ_fid, model, optimiser, key_or_state, simulator)¶
Information maximising neural network fit with simulations on-the-fly
Defines the function to get simulations and compress them using an XLA compilable simulator.
The outline of the fitting procedure is that a set of \(i\in[1, n_s]\) random number generators are generated and used to generate a set of \(n_s\) simulations, \({\bf d}^i={\rm simulator}({\rm seed}^i, \theta^{\rm fid})\), at the fiducial model parameters, \(\theta^{\rm fid}\). These are passed directly through a network \(f_{{\bf w}}({\bf d})\) with network parameters \({\bf w}\) to obtain network outputs \({\bf x}^i\), and autodifferentiation is used to get the derivative of \(n_d\) of these outputs with respect to the physical model parameters, \(\partial{{\bf x}^i}/\partial\theta_\alpha\), where \(\alpha\) labels the physical parameter. With \({\bf x}^i\) and \(\partial{{\bf x}^i}/\partial\theta_\alpha\) the covariance
\[C_{ab} = \frac{1}{n_s-1}\sum_{i=1}^{n_s}(x^i_a-\mu_a) (x^i_b-\mu_b)\]and the derivative of the mean of the network outputs with respect to the model parameters
\[\frac{\partial\mu_a}{\partial\theta_\alpha} = \frac{1}{n_d} \sum_{i=1}^{n_d}\frac{\partial{x^i_a}}{\partial\theta_\alpha}\]can be calculated and used to form the Fisher information matrix
\[F_{\alpha\beta} = \frac{\partial\mu_a}{\partial\theta_\alpha} C^{-1}_{ab}\frac{\partial\mu_b}{\partial\theta_\beta}.\]The loss function is then defined as
\[\Lambda = -\log|{\bf F}| + r(\Lambda_2) \Lambda_2\]Since any linear rescaling of a sufficient statistic is also a sufficient statistic, the negative logarithm of the determinant of the Fisher information matrix needs to be regularised to fix the scale of the network outputs. We choose to fix this scale by constraining the covariance of network outputs as
\[\Lambda_2 = ||{\bf C}-{\bf I}|| + ||{\bf C}^{-1}-{\bf I}||\]This constraint forces the covariance to be approximately parameter independent, which justifies choosing the covariance-independent Gaussian Fisher information as above. To avoid having a dual optimisation objective, we use a smooth and dynamic regularisation strength which turns off the regularisation to focus on maximising the Fisher information once the covariance has set the scale
\[r(\Lambda_2) = \frac{\lambda\Lambda_2}{\Lambda_2-\exp(-\alpha\Lambda_2)}.\]Once the loss function is calculated, its automatic gradient is computed and used to update the network parameters via the optimiser function.
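The equations above translate almost line-for-line into jax. The following is a minimal, illustrative sketch of the loss computation from a set of network outputs and their parameter derivatives; the function and argument names are hypothetical and this is not the class's actual private implementation:

    import jax.numpy as np

    def sketch_loss(summaries, derivatives, λ, α):
        # summaries: float(n_s, n_summaries)
        # derivatives: float(n_d, n_summaries, n_params)
        n_s = summaries.shape[0]
        μ = np.mean(summaries, axis=0)
        # Covariance of the network outputs, C_ab
        C = (summaries - μ).T @ (summaries - μ) / (n_s - 1)
        invC = np.linalg.inv(C)
        # Derivative of the mean of the outputs wrt model parameters
        dμ_dθ = np.mean(derivatives, axis=0)
        # Gaussian Fisher information, F_αβ = ∂μ_a/∂θ_α C^{-1}_ab ∂μ_b/∂θ_β
        F = dμ_dθ.T @ invC @ dμ_dθ
        # Regularisation Λ2 = ||C - I|| + ||C^{-1} - I|| (Frobenius norms)
        I = np.eye(C.shape[0])
        Λ2 = np.linalg.norm(C - I) + np.linalg.norm(invC - I)
        # Dynamic regularisation strength r(Λ2), as defined above
        r = λ * Λ2 / (Λ2 - np.exp(-α * Λ2))
        # Loss Λ = -log|F| + r(Λ2) Λ2
        return -np.linalg.slogdet(F)[1] + r * Λ2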
-
simulator:
A function for generating a simulation on-the-fly (XLA compilable)
Public Methods:
__init__(n_s, n_d, n_params, n_summaries, …) – Constructor method
get_summaries(w, key[, validate]) – Gets all network outputs and derivatives wrt model parameters
Inherited from _IMNN:
__init__(n_s, n_d, n_params, n_summaries, …) – Constructor method
fit(λ, ε[, rng, patience, min_iterations, …]) – Fitting routine for the IMNN
get_α(λ, ε) – Calculate rate parameter for regularisation from closeness criterion
set_F_statistics([w, key, validate]) – Set necessary attributes for calculating score compressed summaries
get_summaries(w, key[, validate]) – Gets all network outputs and derivatives wrt model parameters
get_estimate(d) – Calculate score compressed parameter estimates from network outputs
plot([ax, expected_detF, colour, figsize, …]) – Plot fitting history
Private Methods:
_get_fitting_keys(rng) – Generates random numbers for simulation
Inherited from _IMNN:
_initialise_parameters(n_s, n_d, n_params, …) – Performs type checking and initialisation of class attributes
_initialise_model(model, optimiser, key_or_state) – Initialises neural network parameters or loads optimiser state
_initialise_history() – Initialises history dictionary attribute
_set_history(results) – Places results from fitting into the history dictionary
_set_inputs(rng, max_iterations) – Builds list of inputs for the XLA compilable fitting routine
_get_fitting_keys(rng) – Generates random numbers for simulation
_fit(inputs, λ=None, α=None[, min_iterations]) – Single iteration fitting algorithm
_fit_cond(inputs, patience, max_iterations) – Stopping condition for the fitting loop
_update_loop_vars(inputs) – Updates input parameters if max_detF is increased
_check_loop_vars(inputs, min_iterations) – Updates patience_counter if max_detF not increased
_update_history(inputs, history, counter, ind) – Puts current fitting statistics into history arrays
_slogdet(matrix) – Combined summed logarithmic determinant
_construct_derivatives(derivatives) – Builds derivatives of the network outputs wrt model parameters
_get_F_statistics([w, key, validate]) – Calculates the Fisher information and returns all statistics used
_calculate_F_statistics(summaries, derivatives) – Calculates the Fisher information matrix from network outputs
_get_regularisation_strength(Λ2, λ, α) – Coupling strength of the regularisation (amplified sigmoid)
_get_regularisation(C, invC) – Difference of the covariance (and its inverse) from identity
_get_loss(w, λ, α[, key]) – Calculates the loss function and returns auxiliary variables
_calculate_loss(summaries, derivatives, λ, α) – Calculates the loss function from network summaries and derivatives
_setup_plot([ax, expected_detF, figsize]) – Builds axes for history plot
-
_get_fitting_keys(rng)¶
Generates random numbers for simulation
- Parameters
rng (int(2,)) – A random number generator
- Returns
A new random number generator and random number generators for fitting (and validation)
- Return type
int(2,), int(2,), int(2,)
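In jax this plausibly amounts to a three-way key split; a hedged sketch of the equivalent call (the exact splitting scheme inside the class is an assumption):

    import jax

    rng = jax.random.PRNGKey(0)
    # One generator to carry forward, one for fitting and one for validation
    rng, fitting_key, validation_key = jax.random.split(rng, num=3)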
-
get_summaries(w, key, validate=False)¶
Gets all network outputs and derivatives wrt model parameters
A random seed for each simulation is obtained and n_d of them are used to calculate the network outputs of each of these simulations as well as the derivative of these network outputs with respect to the model parameters, calculated using jax autodifferentiation. The remaining n_s - n_d network outputs are then calculated and concatenated to those already calculated.
- Parameters
w (list or None, default=None) – The network parameters if wanting to calculate the Fisher information with a specific set of network parameters
key (int(2,) or None, default=None) – A random number generator for generating simulations on-the-fly
validate (bool, default=False) – Whether to get summaries of the validation set
- Returns
float(n_s, n_summaries) – The set of all network outputs used to calculate the covariance
float(n_d, n_summaries, n_params) – The set of all network output derivatives wrt model parameters
-
get_summary:
Return a single network output
-
get_derivatives:
Return the Jacobian of the network outputs wrt model parameters
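A minimal usage sketch for SimulatorIMNN, assuming a toy Gaussian simulator and a jax version where stax and optimizers live under jax.example_libraries (older versions export them from jax.experimental); all sizes and the fitting hyperparameters λ and ε are illustrative:

    import jax
    import jax.numpy as np
    from jax.example_libraries import stax, optimizers
    import imnn

    # Toy simulator: 10 draws from N(μ, Σ), with θ = (μ, Σ)
    def simulator(key, θ):
        return θ[0] + np.sqrt(θ[1]) * jax.random.normal(key, shape=(10,))

    rng = jax.random.PRNGKey(0)
    model_key, fit_key = jax.random.split(rng)

    imnn_obj = imnn.SimulatorIMNN(
        n_s=1000, n_d=1000, n_params=2, n_summaries=2, input_shape=(10,),
        θ_fid=np.array([0., 1.]),
        model=stax.serial(stax.Dense(64), stax.Relu, stax.Dense(2)),
        optimiser=optimizers.adam(step_size=1e-3),
        key_or_state=model_key, simulator=simulator)
    imnn_obj.fit(λ=10., ε=0.1, rng=fit_key)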
-
AggregatedSimulatorIMNN¶
-
class imnn.AggregatedSimulatorIMNN(n_s, n_d, n_params, n_summaries, input_shape, θ_fid, model, optimiser, key_or_state, simulator, host, devices, n_per_device)¶
Information maximising neural network fit with simulations on-the-fly
Defines the function to get simulations and compress them using an XLA compilable simulator.
The outline of the fitting procedure is that a set of \(i\in[1, n_s]\) random number generators are generated and used to generate a set of \(n_s\) simulations, \({\bf d}^i={\rm simulator}({\rm seed}^i, \theta^{\rm fid})\), at the fiducial model parameters, \(\theta^{\rm fid}\). These are passed directly through a network \(f_{{\bf w}}({\bf d})\) with network parameters \({\bf w}\) to obtain network outputs \({\bf x}^i\), and autodifferentiation is used to get the derivative of \(n_d\) of these outputs with respect to the physical model parameters, \(\partial{{\bf x}^i}/\partial\theta_\alpha\), where \(\alpha\) labels the physical parameter. With \({\bf x}^i\) and \(\partial{{\bf x}^i}/\partial\theta_\alpha\) the covariance
\[C_{ab} = \frac{1}{n_s-1}\sum_{i=1}^{n_s}(x^i_a-\mu_a) (x^i_b-\mu_b)\]and the derivative of the mean of the network outputs with respect to the model parameters
\[\frac{\partial\mu_a}{\partial\theta_\alpha} = \frac{1}{n_d} \sum_{i=1}^{n_d}\frac{\partial{x^i_a}}{\partial\theta_\alpha}\]can be calculated and used to form the Fisher information matrix
\[F_{\alpha\beta} = \frac{\partial\mu_a}{\partial\theta_\alpha} C^{-1}_{ab}\frac{\partial\mu_b}{\partial\theta_\beta}.\]The loss function is then defined as
\[\Lambda = -\log|{\bf F}| + r(\Lambda_2) \Lambda_2\]Since any linear rescaling of a sufficient statistic is also a sufficient statistic, the negative logarithm of the determinant of the Fisher information matrix needs to be regularised to fix the scale of the network outputs. We choose to fix this scale by constraining the covariance of network outputs as
\[\Lambda_2 = ||{\bf C}-{\bf I}|| + ||{\bf C}^{-1}-{\bf I}||\]This constraint forces the covariance to be approximately parameter independent, which justifies choosing the covariance-independent Gaussian Fisher information as above. To avoid having a dual optimisation objective, we use a smooth and dynamic regularisation strength which turns off the regularisation to focus on maximising the Fisher information once the covariance has set the scale
\[r(\Lambda_2) = \frac{\lambda\Lambda_2}{\Lambda_2-\exp(-\alpha\Lambda_2)}.\]To enable the use of large data (or networks) the whole procedure is aggregated. This means that the generation and passing of the simulations through the network is farmed out to the desired XLA devices, and recollected, n_per_device inputs at a time. These are then used to calculate the automatic gradient of the loss function with respect to the calculated summaries and derivatives, \(\partial\Lambda/\partial{\bf x}^i\) (which is a fairly small computation as long as n_summaries, n_s and n_d are not huge). Once this is calculated, the simulations are passed through the network again, this time calculating the Jacobian of the network outputs with respect to the network parameters, \(\partial{\bf x}^i/\partial{\bf w}\), which is combined via the chain rule to get
\[\frac{\partial\Lambda}{\partial{\bf w}} = \frac{\partial\Lambda}{\partial{\bf x}^i} \frac{\partial{\bf x}^i}{\partial{\bf w}}\]This can then be passed to the optimiser.
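The reshaping described by batch_shape and remaining_batch_shape below implies a divisibility constraint; a standalone sketch of it (the helper itself is hypothetical, not part of the class):

    import jax

    def batch_shapes(n_s, n_d, devices, n_per_device):
        # Simulations are farmed out n_per_device at a time to each device,
        # so both n_d and n_s - n_d must spread evenly over the devices
        chunk = len(devices) * n_per_device
        if (n_d % chunk != 0) or ((n_s - n_d) % chunk != 0):
            raise ValueError(
                "n_d and n_s - n_d must be divisible by n_devices * n_per_device")
        return ((n_d // chunk, len(devices), n_per_device),
                ((n_s - n_d) // chunk, len(devices), n_per_device))

    # e.g. 2000 fiducial simulations, 1000 with derivatives, 100 per pass,
    # using just the first available XLA device here
    print(batch_shapes(2000, 1000, jax.devices()[:1], 100))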
- Parameters
n_remaining (int) – The number of simulations where only the fiducial simulations are calculated. This is zero if n_s is equal to n_d.
n_iterations (int) – Number of iterations through the main summarising loop
n_remaining_iterations (int) – Number of iterations through the remaining simulations used for quick loops with no derivatives
batch_shape (tuple) – The shape which n_d should be reshaped to for aggregating: (n_d // (n_devices * n_per_device), n_devices, n_per_device)
remaining_batch_shape (tuple) – The shape which n_s - n_d should be reshaped to for aggregating: ((n_s - n_d) // (n_devices * n_per_device), n_devices, n_per_device)
-
simulator:
A function for generating a simulation on-the-fly (XLA compilable)
Public Methods:
__init__(n_s, n_d, n_params, n_summaries, …) – Constructor method
get_summary(input, w, θ[, derivative, gradient]) – Returns a single summary of a simulation or its gradient
get_summaries(w, key[, validate]) – Gets all network outputs and derivatives wrt model parameters
get_gradient(dΛ_dx, w[, key]) – Aggregates gradients together to update the network parameters
Inherited from _AggregatedIMNN:
__init__(n_s, n_d, n_params, n_summaries, …) – Constructor method
fit(λ, ε[, rng, patience, min_iterations, …]) – Fitting routine for the IMNN
get_summaries(w, key[, validate]) – Gets all network outputs and derivatives wrt model parameters
get_gradient(dΛ_dx, w[, key]) – Aggregates gradients together to update the network parameters
Inherited from SimulatorIMNN:
__init__(n_s, n_d, n_params, n_summaries, …) – Constructor method
get_summaries(w, key[, validate]) – Gets all network outputs and derivatives wrt model parameters
Inherited from _IMNN:
__init__(n_s, n_d, n_params, n_summaries, …) – Constructor method
fit(λ, ε[, rng, patience, min_iterations, …]) – Fitting routine for the IMNN
get_α(λ, ε) – Calculate rate parameter for regularisation from closeness criterion
set_F_statistics([w, key, validate]) – Set necessary attributes for calculating score compressed summaries
get_summaries(w, key[, validate]) – Gets all network outputs and derivatives wrt model parameters
get_estimate(d) – Calculate score compressed parameter estimates from network outputs
plot([ax, expected_detF, colour, figsize, …]) – Plot fitting history
Private Methods:
_set_shapes() – Calculates the shapes for batching over different devices
_collect_input(key[, validate]) – Returns the keys for generating simulations on-the-fly
_split_dΛ_dx(dΛ_dx) – Returns the gradient of loss function wrt summaries (derivatives)
Inherited from _AggregatedIMNN:
_set_devices(devices, n_per_device) – Checks that devices exist and that reshaping onto devices can occur
_set_batch_functions() – Creates jitted functions placed on desired XLA devices
_set_shapes() – Calculates the shapes for batching over different devices
_setup_progress_bar(print_rate, max_iterations) – Construct progress bar
_update_progress_bar(pbar, counter, …[, close]) – Updates (and closes) progress bar
_collect_input(key[, validate]) – Returns the keys for generating simulations on-the-fly
_get_batch_summaries(inputs, w, θ[, …]) – Vectorised batch calculation of summaries or gradients
_split_dΛ_dx(dΛ_dx) – Returns the gradient of loss function wrt summaries (derivatives)
_construct_gradient(layers[, aux, func]) – Multiuse function to iterate over tuple of network parameters
Inherited from SimulatorIMNN:
_get_fitting_keys(rng) – Generates random numbers for simulation
Inherited from _IMNN:
_initialise_parameters(n_s, n_d, n_params, …) – Performs type checking and initialisation of class attributes
_initialise_model(model, optimiser, key_or_state) – Initialises neural network parameters or loads optimiser state
_initialise_history() – Initialises history dictionary attribute
_set_history(results) – Places results from fitting into the history dictionary
_set_inputs(rng, max_iterations) – Builds list of inputs for the XLA compilable fitting routine
_get_fitting_keys(rng) – Generates random numbers for simulation
_fit(inputs, λ=None, α=None[, min_iterations]) – Single iteration fitting algorithm
_fit_cond(inputs, patience, max_iterations) – Stopping condition for the fitting loop
_update_loop_vars(inputs) – Updates input parameters if max_detF is increased
_check_loop_vars(inputs, min_iterations) – Updates patience_counter if max_detF not increased
_update_history(inputs, history, counter, ind) – Puts current fitting statistics into history arrays
_slogdet(matrix) – Combined summed logarithmic determinant
_construct_derivatives(derivatives) – Builds derivatives of the network outputs wrt model parameters
_get_F_statistics([w, key, validate]) – Calculates the Fisher information and returns all statistics used
_calculate_F_statistics(summaries, derivatives) – Calculates the Fisher information matrix from network outputs
_get_regularisation_strength(Λ2, λ, α) – Coupling strength of the regularisation (amplified sigmoid)
_get_regularisation(C, invC) – Difference of the covariance (and its inverse) from identity
_get_loss(w, λ, α[, key]) – Calculates the loss function and returns auxiliary variables
_calculate_loss(summaries, derivatives, λ, α) – Calculates the loss function from network summaries and derivatives
_setup_plot([ax, expected_detF, figsize]) – Builds axes for history plot
-
_collect_input(key, validate=False)¶
Returns the keys for generating simulations on-the-fly
- Parameters
key (None or int(2,)) – Random number generators for generating simulations
validate (bool, default=False) – Whether to return the set for validation or for fitting (always False)
-
_set_shapes()¶
Calculates the shapes for batching over different devices
- Raises
ValueError – If the difference between n_s and n_d won’t scale over xla devices
-
_split_dΛ_dx(dΛ_dx)¶
Returns the gradient of loss function wrt summaries (derivatives)
The gradient of the loss function with respect to the network outputs and their derivatives with respect to model parameters has to be reshaped and aggregated onto each XLA device, matching the format in which the keys for generating the simulations are produced.
- Parameters
dΛ_dx (tuple) –
dΛ_dx float(n_s, n_params, n_summaries) – the derivative of the loss function with respect to network summaries
d2Λ_dxdθ float(n_d, n_summaries, n_params) – the derivative of the loss function with respect to the derivative of network summaries with respect to model parameters
-
get_gradient(dΛ_dx, w, key=None)¶
Aggregates gradients together to update the network parameters
To avoid having to calculate the gradient with respect to all the simulations at once, we aggregate the gradient calculation by addition, looping over the keys for generating the simulations again and combining the results with the derivative of the loss function with respect to the network outputs (and their derivatives with respect to the model parameters). Whilst this is expensive, it is necessary since we cannot make an accurate stochastic estimate of the Fisher information and therefore need to use all the available simulations - which are probably too many to fit in memory.
- Parameters
dΛ_dx (tuple) –
dΛ_dx float(n_s, n_params, n_summaries) – the derivative of the loss function with respect to network summaries
d2Λ_dxdθ float(n_d, n_summaries, n_params) – the derivative of the loss function with respect to the derivative of network summaries with respect to model parameters
w (list) – Network parameters
key (None or int(2,)) – Random number generator
- Returns
The gradient of the loss function with respect to the network parameters calculated by aggregating
- Return type
list
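A sketch of the two-pass chain rule just described, using jax.vjp to pull the known dΛ/dx back through the network for one batch (apply_fn and the batch layout are assumptions; in the real class this is aggregated over devices):

    import jax

    def batch_gradient(apply_fn, w, batch, dΛ_dx_batch):
        # Second pass through the network, recording the vjp: x = f_w(batch)
        x, vjp_fn = jax.vjp(lambda w_: apply_fn(w_, batch), w)
        # Pull dΛ/dx back through the network to get this batch's dΛ/dw
        (dΛ_dw,) = vjp_fn(dΛ_dx_batch)
        return dΛ_dw

    # The total gradient is the sum of these contributions over all batches,
    # e.g. accumulated with jax.tree_util.tree_map(lambda a, b: a + b, g1, g2)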
-
get_summaries(w, key, validate=False)¶
Gets all network outputs and derivatives wrt model parameters
Loops through the generated random number generators on each XLA device to generate the simulations and passes them through the network to get the network outputs. These are then pushed back to the host for the computation of the loss function.
The fiducial simulations which have a derivative with respect to the model parameters counterpart are processed first and then the remaining fiducial simulations are processed and concatenated.
- Parameters
w (list or None, default=None) – The network parameters if wanting to calculate the Fisher information with a specific set of network parameters
key (int(2,)) – A random number generator for generating simulations on-the-fly
validate (bool, default=False) – Whether to get summaries of the validation set (always False)
- Returns
float(n_s, n_summaries) – The set of all network outputs used to calculate the covariance
float(n_d, n_summaries, n_params) – The set of all the derivatives of the network outputs with respect to model parameters
-
get_summary(input, w, θ, derivative=False, gradient=False)¶
Returns a single summary of a simulation or its gradient
- Parameters
input (float(input_shape) or tuple) – A random number generator for a single simulation to pass through the network, or a tuple of either (if gradient and not derivative):
dΛ_dx float(input_shape, n_params) – the derivative of the loss function with respect to a network summary
key int(2,) – A random number generator for a single simulation
or (if gradient and derivative):
tuple (gradients):
dΛ_dx float(input_shape, n_params) – the derivative of the loss function with respect to a network summary
d2Λ_dxdθ float(input_shape, n_params) – the derivative of the loss function with respect to the derivative of a network summary with respect to model parameters
key int(2,) – A random number generator for a single simulation
w (list) – The network parameters
θ (float(n_params,)) – The value of the parameters to generate the simulation at (fiducial), unused if not simulating on the fly
derivative (bool, default=False) – Whether a derivative of the simulation with respect to model parameters is also passed
gradient (bool, default=False) – Whether to calculate the gradient with respect to model parameters
- Returns
tuple (if gradient) – The gradient of the loss function with respect to model parameters
float(n_summaries,) (if not gradient) – The output of the network
float(n_summaries, n_params) (if derivative and not gradient) – The derivative of the output of the network wrt model parameters
GradientIMNN¶
-
class imnn.GradientIMNN(n_s, n_d, n_params, n_summaries, input_shape, θ_fid, model, optimiser, key_or_state, fiducial, derivative, validation_fiducial=None, validation_derivative=None)¶
Information maximising neural network fit using known derivatives
The outline of the fitting procedure is that a set of \(i\in[1, n_s]\) simulations \({\bf d}^i\), originally generated at fiducial model parameter \({\bf\theta}^{\rm fid}\), and their derivatives \(\partial{\bf d}^i/\partial\theta_\alpha\) with respect to model parameters are used. The fiducial simulations, \({\bf d}^i\), are passed through a network to obtain summaries, \({\bf x}^i\), and the jax automatic derivative of these summaries with respect to the inputs is calculated, \(\partial{\bf x}^i/\partial{\bf d}^j\delta_{ij}\). The chain rule is then used to calculate
\[\frac{\partial{\bf x}^i}{\partial\theta_\alpha} = \frac{\partial{\bf x}^i}{\partial{\bf d}^j} \frac{\partial{\bf d}^j}{\partial\theta_\alpha}\]With \({\bf x}^i\) and \(\partial{{\bf x}^i}/\partial\theta_\alpha\) the covariance
\[C_{ab} = \frac{1}{n_s-1}\sum_{i=1}^{n_s}(x^i_a-\mu_a) (x^i_b-\mu_b)\]and the derivative of the mean of the network outputs with respect to the model parameters
\[\frac{\partial\mu_a}{\partial\theta_\alpha} = \frac{1}{n_d} \sum_{i=1}^{n_d}\frac{\partial{x^i_a}}{\partial\theta_\alpha}\]can be calculated and used to form the Fisher information matrix
\[F_{\alpha\beta} = \frac{\partial\mu_a}{\partial\theta_\alpha} C^{-1}_{ab}\frac{\partial\mu_b}{\partial\theta_\beta}.\]The loss function is then defined as
\[\Lambda = -\log|{\bf F}| + r(\Lambda_2) \Lambda_2\]Since any linear rescaling of a sufficient statistic is also a sufficient statistic, the negative logarithm of the determinant of the Fisher information matrix needs to be regularised to fix the scale of the network outputs. We choose to fix this scale by constraining the covariance of network outputs as
\[\Lambda_2 = ||{\bf C}-{\bf I}|| + ||{\bf C}^{-1}-{\bf I}||\]This constraint forces the covariance to be approximately parameter independent, which justifies choosing the covariance-independent Gaussian Fisher information as above. To avoid having a dual optimisation objective, we use a smooth and dynamic regularisation strength which turns off the regularisation to focus on maximising the Fisher information once the covariance has set the scale
\[r(\Lambda_2) = \frac{\lambda\Lambda_2}{\Lambda_2-\exp(-\alpha\Lambda_2)}.\]Once the loss function is calculated, its automatic gradient is computed and used to update the network parameters via the optimiser function.
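The chain rule above can be sketched with jax.jvp, pushing each parameter's known simulation derivative forward through the network (apply_fn and the shapes are assumptions, not the class's actual internals):

    import jax

    def summary_and_derivative(apply_fn, w, d, dd_dθ):
        # d: float(input_shape), dd_dθ: float(input_shape, n_params)
        f = lambda d_: apply_fn(w, d_)
        # ∂x/∂θ_α = ∂x/∂d · ∂d/∂θ_α as a jacobian-vector product
        push = lambda tangent: jax.jvp(f, (d,), (tangent,))[1]
        x = f(d)                                           # float(n_summaries,)
        dx_dθ = jax.vmap(push, in_axes=-1, out_axes=-1)(dd_dθ)
        return x, dx_dθ                                    # (n_summaries, n_params)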
- Parameters
fiducial (float(n_s, input_shape)) – The simulations generated at the fiducial model parameter values used for calculating the covariance of network outputs (for fitting)
derivative (float(n_d, input_shape, n_params)) – The derivative of the simulations with respect to the model parameters (for fitting)
validation_fiducial (float(n_s, input_shape) or None) – The simulations generated at the fiducial model parameter values used for calculating the covariance of network outputs (for validation)
validation_derivative (float(n_d, input_shape, n_params) or None) – The derivative of the simulations with respect to the model parameters (for validation)
Public Methods:
__init__(n_s, n_d, n_params, n_summaries, …) – Constructor method
get_summaries(w[, key, validate]) – Gets all network outputs and derivatives wrt model parameters
Inherited from _IMNN:
__init__(n_s, n_d, n_params, n_summaries, …) – Constructor method
fit(λ, ε[, rng, patience, min_iterations, …]) – Fitting routine for the IMNN
get_α(λ, ε) – Calculate rate parameter for regularisation from closeness criterion
set_F_statistics([w, key, validate]) – Set necessary attributes for calculating score compressed summaries
get_summaries(w[, key, validate]) – Gets all network outputs and derivatives wrt model parameters
get_estimate(d) – Calculate score compressed parameter estimates from network outputs
plot([ax, expected_detF, colour, figsize, …]) – Plot fitting history
Private Methods:
_set_data(fiducial, derivative, …) – Checks and sets data attributes with the correct shape
Inherited from _IMNN:
_initialise_parameters(n_s, n_d, n_params, …) – Performs type checking and initialisation of class attributes
_initialise_model(model, optimiser, key_or_state) – Initialises neural network parameters or loads optimiser state
_initialise_history() – Initialises history dictionary attribute
_set_history(results) – Places results from fitting into the history dictionary
_set_inputs(rng, max_iterations) – Builds list of inputs for the XLA compilable fitting routine
_get_fitting_keys(rng) – Generates random numbers for simulation generation if needed
_fit(inputs, λ=None, α=None[, min_iterations]) – Single iteration fitting algorithm
_fit_cond(inputs, patience, max_iterations) – Stopping condition for the fitting loop
_update_loop_vars(inputs) – Updates input parameters if max_detF is increased
_check_loop_vars(inputs, min_iterations) – Updates patience_counter if max_detF not increased
_update_history(inputs, history, counter, ind) – Puts current fitting statistics into history arrays
_slogdet(matrix) – Combined summed logarithmic determinant
_construct_derivatives(derivatives) – Builds derivatives of the network outputs wrt model parameters
_get_F_statistics([w, key, validate]) – Calculates the Fisher information and returns all statistics used
_calculate_F_statistics(summaries, derivatives) – Calculates the Fisher information matrix from network outputs
_get_regularisation_strength(Λ2, λ, α) – Coupling strength of the regularisation (amplified sigmoid)
_get_regularisation(C, invC) – Difference of the covariance (and its inverse) from identity
_get_loss(w, λ, α[, key]) – Calculates the loss function and returns auxiliary variables
_calculate_loss(summaries, derivatives, λ, α) – Calculates the loss function from network summaries and derivatives
_setup_plot([ax, expected_detF, figsize]) – Builds axes for history plot
-
_set_data(fiducial, derivative, validation_fiducial, validation_derivative)¶
Checks and sets data attributes with the correct shape
- Parameters
fiducial (float(n_s, input_shape)) – The simulations generated at the fiducial model parameter values used for calculating the covariance of network outputs (for fitting)
derivative (float(n_d, input_shape, n_params)) – The derivative of the simulations with respect to the model parameters (for fitting)
validation_fiducial (float(n_s, input_shape) or None, default=None) – The simulations generated at the fiducial model parameter values used for calculating the covariance of network outputs (for validation). Sets validate = True attribute if provided
validation_derivative (float(n_d, input_shape, n_params) or None) – The derivative of the simulations with respect to the model parameters (for validation). Sets validate = True attribute if provided
-
get_summaries(w, key=None, validate=False)¶
Gets all network outputs and derivatives wrt model parameters
Selects either the fitting or validation sets and passes them through the network to get the network outputs.
- Parameters
w (list or None, default=None) – The network parameters if wanting to calculate the Fisher information with a specific set of network parameters
key (int(2,) or None, default=None) – A random number generator for generating simulations on-the-fly
validate (bool, default=False) – Whether to get summaries of the validation set
- Returns
float(n_s, n_summaries) – The set of all network outputs used to calculate the covariance
float(n_d, n_summaries, n_params) – The derivative of the network outputs wrt the model parameters
-
get_derivatives:
Calculates the Jacobian of the network output and its value
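A minimal instantiation sketch for GradientIMNN, with random stand-ins for the precomputed simulations and derivatives (in practice these come from your own pipeline); module locations and sizes are assumptions as in the SimulatorIMNN example above:

    import jax
    import jax.numpy as np
    from jax.example_libraries import stax, optimizers
    import imnn

    n_s, n_d, n_params, n_summaries, input_shape = 1000, 1000, 2, 2, (10,)
    rng = jax.random.PRNGKey(1)
    data_key, model_key = jax.random.split(rng)

    # Stand-ins for precomputed data: float(n_s, input_shape) simulations at
    # θ_fid and float(n_d, input_shape, n_params) derivatives wrt parameters
    fiducial = jax.random.normal(data_key, (n_s,) + input_shape)
    derivative = jax.random.normal(data_key, (n_d,) + input_shape + (n_params,))

    imnn_obj = imnn.GradientIMNN(
        n_s=n_s, n_d=n_d, n_params=n_params, n_summaries=n_summaries,
        input_shape=input_shape, θ_fid=np.array([0., 1.]),
        model=stax.serial(stax.Dense(64), stax.Relu, stax.Dense(n_summaries)),
        optimiser=optimizers.adam(step_size=1e-3),
        key_or_state=model_key, fiducial=fiducial, derivative=derivative)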
AggregatedGradientIMNN¶
-
class imnn.AggregatedGradientIMNN(n_s, n_d, n_params, n_summaries, input_shape, θ_fid, model, optimiser, key_or_state, fiducial, derivative, host, devices, n_per_device, validation_fiducial=None, validation_derivative=None, prefetch=None, cache=False)¶
Information maximising neural network fit using known derivatives
The outline of the fitting procedure is that a set of \(i\in[1, n_s]\) simulations \({\bf d}^i\), originally generated at fiducial model parameter \({\bf\theta}^{\rm fid}\), and their derivatives \(\partial{\bf d}^i/\partial\theta_\alpha\) with respect to model parameters are used. The fiducial simulations, \({\bf d}^i\), are passed through a network to obtain summaries, \({\bf x}^i\), and the jax automatic derivative of these summaries with respect to the inputs is calculated, \(\partial{\bf x}^i/\partial{\bf d}^j\delta_{ij}\). The chain rule is then used to calculate
\[\frac{\partial{\bf x}^i}{\partial\theta_\alpha} = \frac{\partial{\bf x}^i}{\partial{\bf d}^j} \frac{\partial{\bf d}^j}{\partial\theta_\alpha}\]With \({\bf x}^i\) and \(\partial{{\bf x}^i}/\partial\theta_\alpha\) the covariance
\[C_{ab} = \frac{1}{n_s-1}\sum_{i=1}^{n_s}(x^i_a-\mu_a) (x^i_b-\mu_b)\]and the derivative of the mean of the network outputs with respect to the model parameters
\[\frac{\partial\mu_a}{\partial\theta_\alpha} = \frac{1}{n_d} \sum_{i=1}^{n_d}\frac{\partial{x^i_a}}{\partial\theta_\alpha}\]can be calculated and used to form the Fisher information matrix
\[F_{\alpha\beta} = \frac{\partial\mu_a}{\partial\theta_\alpha} C^{-1}_{ab}\frac{\partial\mu_b}{\partial\theta_\beta}.\]The loss function is then defined as
\[\Lambda = -\log|{\bf F}| + r(\Lambda_2) \Lambda_2\]Since any linear rescaling of a sufficient statistic is also a sufficient statistic, the negative logarithm of the determinant of the Fisher information matrix needs to be regularised to fix the scale of the network outputs. We choose to fix this scale by constraining the covariance of network outputs as
\[\Lambda_2 = ||{\bf C}-{\bf I}|| + ||{\bf C}^{-1}-{\bf I}||\]This constraint forces the covariance to be approximately parameter independent, which justifies choosing the covariance-independent Gaussian Fisher information as above. To avoid having a dual optimisation objective, we use a smooth and dynamic regularisation strength which turns off the regularisation to focus on maximising the Fisher information once the covariance has set the scale
\[r(\Lambda_2) = \frac{\lambda\Lambda_2}{\Lambda_2-\exp(-\alpha\Lambda_2)}.\]To enable the use of large data (or networks) the whole procedure is aggregated. This means that the passing of the simulations through the network is farmed out to the desired XLA devices, and recollected, n_per_device inputs at a time. These are then used to calculate the automatic gradient of the loss function with respect to the calculated summaries and derivatives, \(\partial\Lambda/\partial{\bf x}^i\) (which is a fairly small computation as long as n_summaries, n_s and n_d are not huge). Once this is calculated, the simulations are passed through the network again, this time calculating the Jacobian of the network outputs with respect to the network parameters, \(\partial{\bf x}^i/\partial{\bf w}\), which is combined via the chain rule to get
\[\frac{\partial\Lambda}{\partial{\bf w}} = \frac{\partial\Lambda}{\partial{\bf x}^i} \frac{\partial{\bf x}^i}{\partial{\bf w}}\]This can then be passed to the optimiser.
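A sketch of the extra arguments that control the aggregation; the values are illustrative, and prefetch/cache simply tune the internal tf.data pipeline that serves simulations n_per_device at a time:

    import jax
    import tensorflow as tf

    host = jax.devices("cpu")[0]   # device on which statistics are collected
    devices = jax.devices()        # XLA devices that compress the simulations
    n_per_device = 100             # simulations per device per pass
    # Both n_d and n_s - n_d must be divisible by len(devices) * n_per_device

    prefetch = tf.data.AUTOTUNE    # overlap serving the next batch with compute
    cache = True                   # keep the simulations in memory between loops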
- Parameters
fiducial (float(n_s, input_shape)) – The simulations generated at the fiducial model parameter values used for calculating the covariance of network outputs (for fitting)
derivative (float(n_d, input_shape, n_params)) – The derivative of the simulations with respect to the model parameters (for fitting)
validation_fiducial (float(n_s, input_shape) or None) – The simulations generated at the fiducial model parameter values used for calculating the covariance of network outputs (for validation)
validation_derivative (float(n_d, input_shape, n_params) or None) – The derivative of the simulations with respect to the model parameters (for validation)
main (list of tf.data.Dataset().as_numpy_iterators()) – The simulations generated at the fiducial model parameter values used for calculating the covariance of network outputs and their derivatives with respect to the physical model parameters (for fitting). These are served n_per_device at a time as a numpy iterator from a TensorFlow dataset.
remaining (list of tf.data.Dataset().as_numpy_iterators()) – The n_s - n_d simulations generated at the fiducial model parameter values used for calculating the covariance of network outputs without a derivative counterpart (for fitting). These are served n_per_device at a time as a numpy iterator from a TensorFlow dataset.
validation_main (list of tf.data.Dataset().as_numpy_iterators()) – The simulations generated at the fiducial model parameter values used for calculating the covariance of network outputs and their derivatives with respect to the physical model parameters (for validation). These are served n_per_device at a time as a numpy iterator from a TensorFlow dataset.
validation_remaining (list of tf.data.Dataset().as_numpy_iterators()) – The n_s - n_d simulations generated at the fiducial model parameter values used for calculating the covariance of network outputs without a derivative counterpart (for validation). Served n_per_device at a time as a numpy iterator from a TensorFlow dataset.
n_remaining (int) – The number of simulations where only the fiducial simulations are calculated. This is zero if n_s is equal to n_d.
n_iterations (int) – Number of iterations through the main summarising loop
n_remaining_iterations (int) – Number of iterations through the remaining simulations used for quick loops with no derivatives
batch_shape (tuple) – The shape which n_d should be reshaped to for aggregating: (n_d // (n_devices * n_per_device), n_devices, n_per_device)
remaining_batch_shape (tuple) – The shape which n_s - n_d should be reshaped to for aggregating: ((n_s - n_d) // (n_devices * n_per_device), n_devices, n_per_device)
Public Methods:
__init__(n_s, n_d, n_params, n_summaries, …) – Constructor method
get_summary(input, w, θ[, derivative, gradient]) – Returns a single summary of a simulation or its gradient
get_summaries(w[, key, validate]) – Gets all network outputs and derivatives wrt model parameters
get_gradient(dΛ_dx, w[, key]) – Aggregates gradients together to update the network parameters
Inherited from _AggregatedIMNN:
__init__(n_s, n_d, n_params, n_summaries, …) – Constructor method
fit(λ, ε[, rng, patience, min_iterations, …]) – Fitting routine for the IMNN
get_summaries(w[, key, validate]) – Gets all network outputs and derivatives wrt model parameters
get_gradient(dΛ_dx, w[, key]) – Aggregates gradients together to update the network parameters
Inherited from GradientIMNN:
__init__(n_s, n_d, n_params, n_summaries, …) – Constructor method
get_summaries(w[, key, validate]) – Gets all network outputs and derivatives wrt model parameters
Inherited from _IMNN:
__init__(n_s, n_d, n_params, n_summaries, …) – Constructor method
fit(λ, ε[, rng, patience, min_iterations, …]) – Fitting routine for the IMNN
get_α(λ, ε) – Calculate rate parameter for regularisation from closeness criterion
set_F_statistics([w, key, validate]) – Set necessary attributes for calculating score compressed summaries
get_summaries(w[, key, validate]) – Gets all network outputs and derivatives wrt model parameters
get_estimate(d) – Calculate score compressed parameter estimates from network outputs
plot([ax, expected_detF, colour, figsize, …]) – Plot fitting history
Private Methods:
_set_shapes() – Calculates the shapes for batching over different devices
_set_dataset(prefetch, cache) – Collates data into loopable tensorflow dataset iterations
_collect_input(key[, validate]) – Returns validation or fitting sets
_split_dΛ_dx(dΛ_dx) – Returns the gradient of loss function wrt summaries (derivatives)
Inherited from _AggregatedIMNN:
_set_devices(devices, n_per_device) – Checks that devices exist and that reshaping onto devices can occur
_set_batch_functions() – Creates jitted functions placed on desired XLA devices
_set_shapes() – Calculates the shapes for batching over different devices
_setup_progress_bar(print_rate, max_iterations) – Construct progress bar
_update_progress_bar(pbar, counter, …[, close]) – Updates (and closes) progress bar
_collect_input(key[, validate]) – Returns validation or fitting sets
_get_batch_summaries(inputs, w, θ[, …]) – Vectorised batch calculation of summaries or gradients
_split_dΛ_dx(dΛ_dx) – Returns the gradient of loss function wrt summaries (derivatives)
_construct_gradient(layers[, aux, func]) – Multiuse function to iterate over tuple of network parameters
Inherited from GradientIMNN:
_set_data(fiducial, derivative, …) – Checks and sets data attributes with the correct shape
Inherited from _IMNN:
_initialise_parameters(n_s, n_d, n_params, …) – Performs type checking and initialisation of class attributes
_initialise_model(model, optimiser, key_or_state) – Initialises neural network parameters or loads optimiser state
_initialise_history() – Initialises history dictionary attribute
_set_history(results) – Places results from fitting into the history dictionary
_set_inputs(rng, max_iterations) – Builds list of inputs for the XLA compilable fitting routine
_get_fitting_keys(rng) – Generates random numbers for simulation generation if needed
_fit(inputs, λ=None, α=None[, min_iterations]) – Single iteration fitting algorithm
_fit_cond(inputs, patience, max_iterations) – Stopping condition for the fitting loop
_update_loop_vars(inputs) – Updates input parameters if max_detF is increased
_check_loop_vars(inputs, min_iterations) – Updates patience_counter if max_detF not increased
_update_history(inputs, history, counter, ind) – Puts current fitting statistics into history arrays
_slogdet(matrix) – Combined summed logarithmic determinant
_construct_derivatives(derivatives) – Builds derivatives of the network outputs wrt model parameters
_get_F_statistics([w, key, validate]) – Calculates the Fisher information and returns all statistics used
_calculate_F_statistics(summaries, derivatives) – Calculates the Fisher information matrix from network outputs
_get_regularisation_strength(Λ2, λ, α) – Coupling strength of the regularisation (amplified sigmoid)
_get_regularisation(C, invC) – Difference of the covariance (and its inverse) from identity
_get_loss(w, λ, α[, key]) – Calculates the loss function and returns auxiliary variables
_calculate_loss(summaries, derivatives, λ, α) – Calculates the loss function from network summaries and derivatives
_setup_plot([ax, expected_detF, figsize]) – Builds axes for history plot
-
_collect_input(key, validate=False)¶
Returns validation or fitting sets
- Parameters
key (None or int(2,)) – Random number generators not used in this case
validate (bool) – Whether to return the set for validation or for fitting
- Returns
list of tf.data.Dataset().as_numpy_iterators – The iterators for the main loop including simulations and their derivatives for fitting or validation
list of tf.data.Dataset().as_numpy_iterators – The iterators for the remaining loop simulations for fitting or validation
-
_set_dataset(prefetch, cache)¶
Collates data into loopable tensorflow dataset iterations
- Parameters
prefetch (tf.data.AUTOTUNE or int or None) – How many simulations to prefetch in the tensorflow dataset
cache (bool) – Whether to cache simulations in the tensorflow datasets
- Raises
ValueError – If cache is None
TypeError – If cache is wrong type
-
_set_shapes()¶
Calculates the shapes for batching over different devices
- Raises
ValueError – If the difference between n_s and n_d won’t scale over xla devices
-
_split_dΛ_dx(dΛ_dx)¶
Returns the gradient of loss function wrt summaries (derivatives)
The gradient of the loss function with respect to the network outputs and their derivatives with respect to model parameters has to be reshaped and aggregated onto each XLA device, matching the format in which the tensorflow dataset feeds simulations.
- Parameters
dΛ_dx (tuple) –
dΛ_dx float(n_s, n_params, n_summaries) – the derivative of the loss function with respect to network summaries
d2Λ_dxdθ float(n_d, n_summaries, n_params) – the derivative of the loss function with respect to the derivative of network summaries with respect to model parameters
-
get_gradient(dΛ_dx, w, key=None)¶
Aggregates gradients together to update the network parameters
To avoid having to calculate the gradient with respect to all the simulations at once, we aggregate the gradient calculation by addition, looping over the simulations again and combining them with the derivative of the loss function with respect to the network outputs (and their derivatives with respect to the model parameters). Whilst this is expensive, it is necessary since we cannot make an accurate stochastic estimate of the Fisher information and therefore need to use all the available simulations - which are probably too many to fit in memory.
- Parameters
dΛ_dx (tuple) –
dΛ_dx float(n_s, n_params, n_summaries) – the derivative of the loss function with respect to network summaries
d2Λ_dxdθ float(n_d, n_summaries, n_params) – the derivative of the loss function with respect to the derivative of network summaries with respect to model parameters
w (list) – Network parameters
key (None or int(2,)) – Random number generator used in SimulatorIMNN
- Returns
The gradient of the loss function with respect to the network parameters calculated by aggregating
- Return type
list
-
get_summaries(w, key=None, validate=False)¶
Gets all network outputs and derivatives wrt model parameters
Selects either the fitting or validation set and loops through its iterator on each XLA device, passing the simulations through the network to get the network outputs. These are then pushed back to the host for the computation of the loss function.
The fiducial simulations which have a derivative with respect to the model parameters counterpart are processed first and then the remaining fiducial simulations are processed and concatenated.
- Parameters
w (list or None, default=None) – The network parameters if wanting to calculate the Fisher information with a specific set of network parameters
key (int(2,) or None, default=None) – A random number generator for generating simulations on-the-fly
validate (bool, default=False) – Whether to get summaries of the validation set
- Returns
float(n_s, n_summaries) – The set of all network outputs used to calculate the covariance
float(n_d, n_summaries, n_params) – The set of all the derivatives of the network outputs with respect to model parameters
-
get_summary(input, w, θ, derivative=False, gradient=False)¶
Returns a single summary of a simulation or its gradient
- Parameters
input (float(input_shape) or tuple) – A single simulation to pass through the network, or a tuple of either (if gradient and not derivative):
dΛ_dx float(input_shape, n_params) – the derivative of the loss function with respect to a network summary
d float(input_shape) – a simulation to compress with the network
or (if gradient and derivative):
tuple (gradients):
dΛ_dx float(input_shape, n_params) – the derivative of the loss function with respect to a network summary
d2Λ_dxdθ float(input_shape, n_params) – the derivative of the loss function with respect to the derivative of a network summary with respect to model parameters
tuple (simulations):
d float(input_shape) – a simulation to compress with the network
dd_dθ float(input_shape, n_params) – the derivative of a simulation with respect to model parameters
or (if derivative and not gradient):
d float(input_shape) – a simulation to compress with the network
dd_dθ float(input_shape, n_params) – the derivative of a simulation with respect to model parameters
w (list) – The network parameters
θ (float(n_params,)) – The value of the parameters to generate the simulation at (fiducial), unused if not simulating on the fly
derivative (bool, default=False) – Whether a derivative of the simulation with respect to model parameters is also passed
gradient (bool, default=False) – Whether to calculate the gradient with respect to model parameters
- Returns
tuple (if gradient) – The gradient of the loss function with respect to model parameters
float(n_summaries,) (if not gradient) – The output of the network
float(n_summaries, n_params) (if derivative and not gradient) – The derivative of the output of the network wrt model parameters
NumericalGradientIMNN¶
-
class imnn.NumericalGradientIMNN(n_s, n_d, n_params, n_summaries, input_shape, θ_fid, model, optimiser, key_or_state, fiducial, derivative, δθ, validation_fiducial=None, validation_derivative=None)¶
Information maximising neural network fit using numerical derivatives
The outline of the fitting procedure is that a set of \(i\in[1, n_s]\) simulations \({\bf d}^i\), originally generated at fiducial model parameter \({\bf\theta}^{\rm fid}\), is used together with a set of \(i\in[1, n_d]\) simulations, \(\{{\bf d}_{\alpha^-}^i, {\bf d}_{\alpha^+}^i\}\), generated with the same seed at each \(i\) as the fiducial simulations but at model parameters equal to \({\bf\theta}^{\rm fid}\) apart from at parameter label \(\alpha\), with values
\[\theta_{\alpha^-} = \theta_\alpha^\rm{fid}-\delta\theta_\alpha\]and
\[\theta_{\alpha^+} = \theta_\alpha^\rm{fid}+\delta\theta_\alpha\]where \(\delta\theta_\alpha\) is a \(n_{params}\) length vector with the \(\alpha\) element having a value which perturbs the parameter \(\theta^{\rm fid}_\alpha\). This means there are \(2\times n_{params}\times n_d\) simulations used to calculate the numerical derivatives (this is extremely cheap compared to other machine learning methods). All these simulations are passed through a network \(f_{{\bf w}}({\bf d})\) with network parameters \({\bf w}\) to obtain network outputs \({\bf x}^i\) and \(\{{\bf x}_{\alpha^-}^i,{\bf x}_{\alpha^+}^i\}\). These perturbed values are combined to obtain
\[\frac{\partial{{\bf x}^i}}{\partial\theta_\alpha} = \frac{{\bf x}_{\alpha^+}^i - {\bf x}_{\alpha^-}^i} {\delta\theta_\alpha}\]With \({\bf x}^i\) and \(\partial{{\bf x}^i}/\partial\theta_\alpha\) the covariance
\[C_{ab} = \frac{1}{n_s-1}\sum_{i=1}^{n_s}(x^i_a-\mu_a) (x^i_b-\mu_b)\]and the derivative of the mean of the network outputs with respect to the model parameters
\[\frac{\partial\mu_a}{\partial\theta_\alpha} = \frac{1}{n_d} \sum_{i=1}^{n_d}\frac{\partial{x^i_a}}{\partial\theta_\alpha}\]can be calculated and used to form the Fisher information matrix
\[F_{\alpha\beta} = \frac{\partial\mu_a}{\partial\theta_\alpha} C^{-1}_{ab}\frac{\partial\mu_b}{\partial\theta_\beta}.\]The loss function is then defined as
\[\Lambda = -\log|{\bf F}| + r(\Lambda_2) \Lambda_2\]Since any linear rescaling of a sufficient statistic is also a sufficient statistic, the negative logarithm of the determinant of the Fisher information matrix needs to be regularised to fix the scale of the network outputs. We choose to fix this scale by constraining the covariance of network outputs as
\[\Lambda_2 = ||{\bf C}-{\bf I}|| + ||{\bf C}^{-1}-{\bf I}||\]This constraint forces the covariance to be approximately parameter independent, which justifies choosing the covariance-independent Gaussian Fisher information as above. To avoid having a dual optimisation objective, we use a smooth and dynamic regularisation strength which turns off the regularisation to focus on maximising the Fisher information once the covariance has set the scale
\[r(\Lambda_2) = \frac{\lambda\Lambda_2}{\Lambda_2-\exp(-\alpha\Lambda_2)}.\]Once the loss function is calculated, its automatic gradient is computed and used to update the network parameters via the optimiser function.
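Generating the derivative dataset amounts to re-running the simulator with matched seeds at the perturbed parameter values; a sketch following the document's \(\theta_{\alpha^\pm} = \theta^{\rm fid}_\alpha \pm \delta\theta_\alpha\) convention (the helper and shapes are illustrative, assuming keys of shape int(n_d, 2)):

    import jax
    import jax.numpy as np

    def perturbed_simulations(simulator, keys, θ_fid, δθ):
        # Returns float(n_d, 2, n_params, *input_shape): for each seed, the
        # simulation at θ_fid - δθ_α (index 0) and θ_fid + δθ_α (index 1)
        n_params = θ_fid.shape[0]
        def one_seed(key):
            def one_param(α):
                perturbation = np.zeros(n_params).at[α].set(δθ[α])
                return np.stack([simulator(key, θ_fid - perturbation),
                                 simulator(key, θ_fid + perturbation)])
            return np.stack([one_param(α) for α in range(n_params)], axis=1)
        return jax.vmap(one_seed)(keys)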
- Parameters
δθ (float(n_params,)) – Size of perturbation to model parameters for the numerical derivative
fiducial (float(n_s, input_shape)) – The simulations generated at the fiducial model parameter values used for calculating the covariance of network outputs (for fitting)
derivative (float(n_d, 2, n_params, input_shape)) – The simulations generated at parameter values perturbed from the fiducial used to calculate the numerical derivative of network outputs with respect to model parameters (for fitting)
validation_fiducial (float(n_s, input_shape) or None) – The simulations generated at the fiducial model parameter values used for calculating the covariance of network outputs (for validation)
validation_derivative (float(n_d, 2, n_params, input_shape) or None) – The simulations generated at parameter values perturbed from the fiducial used to calculate the numerical derivative of network outputs with respect to model parameters (for validation)
Public Methods:
__init__(n_s, n_d, n_params, n_summaries, …) – Constructor method
get_summaries(w[, key, validate]) – Gets all network outputs and derivatives wrt model parameters
Inherited from _IMNN:
__init__(n_s, n_d, n_params, n_summaries, …) – Constructor method
fit(λ, ε[, rng, patience, min_iterations, …]) – Fitting routine for the IMNN
get_α(λ, ε) – Calculate rate parameter for regularisation from closeness criterion
set_F_statistics([w, key, validate]) – Set necessary attributes for calculating score compressed summaries
get_summaries(w[, key, validate]) – Gets all network outputs and derivatives wrt model parameters
get_estimate(d) – Calculate score compressed parameter estimates from network outputs
plot([ax, expected_detF, colour, figsize, …]) – Plot fitting history
Private Methods:
_set_data(δθ, fiducial, derivative, …) – Checks and sets data attributes with the correct shape
_collect_input(key[, validate]) – Returns validation or fitting sets
_construct_derivatives(x_mp) – Builds derivatives of the network outputs wrt model parameters
Inherited from _IMNN:
_initialise_parameters(n_s, n_d, n_params, …) – Performs type checking and initialisation of class attributes
_initialise_model(model, optimiser, key_or_state) – Initialises neural network parameters or loads optimiser state
_initialise_history() – Initialises history dictionary attribute
_set_history(results) – Places results from fitting into the history dictionary
_set_inputs(rng, max_iterations) – Builds list of inputs for the XLA compilable fitting routine
_get_fitting_keys(rng) – Generates random numbers for simulation generation if needed
_fit(inputs, λ=None, α=None[, min_iterations]) – Single iteration fitting algorithm
_fit_cond(inputs, patience, max_iterations) – Stopping condition for the fitting loop
_update_loop_vars(inputs) – Updates input parameters if max_detF is increased
_check_loop_vars(inputs, min_iterations) – Updates patience_counter if max_detF not increased
_update_history(inputs, history, counter, ind) – Puts current fitting statistics into history arrays
_slogdet(matrix) – Combined summed logarithmic determinant
_construct_derivatives(x_mp) – Builds derivatives of the network outputs wrt model parameters
_get_F_statistics([w, key, validate]) – Calculates the Fisher information and returns all statistics used
_calculate_F_statistics(summaries, derivatives) – Calculates the Fisher information matrix from network outputs
_get_regularisation_strength(Λ2, λ, α) – Coupling strength of the regularisation (amplified sigmoid)
_get_regularisation(C, invC) – Difference of the covariance (and its inverse) from identity
_get_loss(w, λ, α[, key]) – Calculates the loss function and returns auxiliary variables
_calculate_loss(summaries, derivatives, λ, α) – Calculates the loss function from network summaries and derivatives
_setup_plot([ax, expected_detF, figsize]) – Builds axes for history plot
-
_collect_input(key, validate=False)¶
Returns validation or fitting sets
- Parameters
key (None or int(2,)) – Random number generators not used in this case
validate (bool) – Whether to return the set for validation or for fitting
- Returns
float(n_s, input_shape) – The fiducial simulations for fitting or validation
float(n_d, 2, n_params, input_shape) – The derivative simulations for fitting or validation
-
_construct_derivatives(x_mp)¶
Builds derivatives of the network outputs wrt model parameters
The network outputs from the simulations generated with model parameter values above and below the fiducial are subtracted from each other and divided by the perturbation size in each model parameter value. The axes are swapped such that the derivatives with respect to parameters are in the last axis.
\[\frac{\partial{\bf x}^i}{\partial\theta_\alpha} = \frac{{\bf x}^i_{\alpha^+}-{\bf x}^i_{\alpha^-}}{\delta\theta_\alpha}\]
- Parameters
x_mp (float(n_d, 2, n_params, n_summaries)) – The outputs of the network from simulations made at perturbed parameter values, used to construct the derivative of the network outputs with respect to the model parameters numerically
- Returns
The numerical derivatives of the network outputs with respect to the model parameters
- Return type
float(n_d, n_summaries, n_params)
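A sketch of this computation, assuming the lower and upper perturbations sit in axis 1 as (minus, plus); the function name mirrors the method but this is not the class's actual implementation:

    import jax.numpy as np

    def construct_derivatives(x_mp, δθ):
        # x_mp: float(n_d, 2, n_params, n_summaries), δθ: float(n_params,)
        # (x_plus - x_minus) / δθ_α for each parameter α
        dx_dθ = (x_mp[:, 1] - x_mp[:, 0]) / δθ[np.newaxis, :, np.newaxis]
        # Swap axes so parameters sit last: float(n_d, n_summaries, n_params)
        return np.swapaxes(dx_dθ, 1, 2)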
-
_set_data(δθ, fiducial, derivative, validation_fiducial, validation_derivative)¶
Checks and sets data attributes with the correct shape
- Parameters
δθ (float(n_params,)) – Size of perturbation to model parameters for the numerical derivative
fiducial (float(n_s, input_shape)) – The simulations generated at the fiducial model parameter values used for calculating the covariance of network outputs (for fitting)
derivative (float(n_d, 2, n_params, input_shape)) – The simulations generated at parameter values perturbed from the fiducial used to calculate the numerical derivative of network outputs with respect to model parameters (for fitting)
validation_fiducial (float(n_s, input_shape) or None, default=None) – The simulations generated at the fiducial model parameter values used for calculating the covariance of network outputs (for validation). Sets validate = True attribute if provided
validation_derivative (float(n_d, 2, n_params, input_shape) or None) – The simulations generated at parameter values perturbed from the fiducial used to calculate the numerical derivative of network outputs with respect to model parameters (for validation). Sets validate = True attribute if provided
-
get_summaries(w, key=None, validate=False)¶
Gets all network outputs and derivatives wrt model parameters
Selects either the fitting or validation sets and passes them through the network to get the network outputs. For the numerical derivatives, the array is first flattened along the batch axis before being passed through the model.
- Parameters
w (list or None, default=None) – The network parameters if wanting to calculate the Fisher information with a specific set of network parameters
key (int(2,) or None, default=None) – A random number generator for generating simulations on-the-fly
validate (bool, default=False) – Whether to get summaries of the validation set
- Returns
float(n_s, n_summaries) – The set of all network outputs used to calculate the covariance
float(n_d, 2, n_params, n_summaries) – The outputs of the network of simulations made at perturbed parameter values to construct the derivative of the network outputs with respect to the model parameters numerically
AggregatedNumericalGradientIMNN¶
-
class imnn.AggregatedNumericalGradientIMNN(n_s, n_d, n_params, n_summaries, input_shape, θ_fid, model, optimiser, key_or_state, fiducial, derivative, δθ, host, devices, n_per_device, validation_fiducial=None, validation_derivative=None, prefetch=None, cache=False)¶
Information maximising neural network fit using numerical derivatives
The outline of the fitting procedure is that a set of \(i\in[1, n_s]\) simulations \({\bf d}^i\), originally generated at fiducial model parameter \({\bf\theta}^{\rm fid}\), is used together with a set of \(i\in[1, n_d]\) simulations, \(\{{\bf d}_{\alpha^-}^i, {\bf d}_{\alpha^+}^i\}\), generated with the same seed at each \(i\) as the fiducial simulations but at model parameters equal to \({\bf\theta}^{\rm fid}\) apart from at parameter label \(\alpha\), with values
\[\theta_{\alpha^-} = \theta_\alpha^\rm{fid}-\delta\theta_\alpha\]and
\[\theta_{\alpha^+} = \theta_\alpha^\rm{fid}+\delta\theta_\alpha\]where \(\delta\theta_\alpha\) is a \(n_{params}\) length vector with the \(\alpha\) element having a value which perturbs the parameter \(\theta^{\rm fid}_\alpha\). This means there are \(2\times n_{params}\times n_d\) simulations used to calculate the numerical derivatives (this is extremely cheap compared to other machine learning methods). All these simulations are passed through a network \(f_{{\bf w}}({\bf d})\) with network parameters \({\bf w}\) to obtain network outputs \({\bf x}^i\) and \(\{{\bf x}_{\alpha^-}^i,{\bf x}_{\alpha^+}^i\}\). These perturbed values are combined to obtain
\[\frac{\partial{{\bf x}^i}}{\partial\theta_\alpha} = \frac{{\bf x}_{\alpha^+}^i - {\bf x}_{\alpha^-}^i} {\delta\theta_\alpha}\]With \({\bf x}^i\) and \(\partial{{\bf x}^i}/\partial\theta_\alpha\) the covariance
\[C_{ab} = \frac{1}{n_s-1}\sum_{i=1}^{n_s}(x^i_a-\mu_a) (x^i_b-\mu_b)\]and the derivative of the mean of the network outputs with respect to the model parameters
\[\frac{\partial\mu_a}{\partial\theta_\alpha} = \frac{1}{n_d} \sum_{i=1}^{n_d}\frac{\partial{x^i_a}}{\partial\theta_\alpha}\]can be calculated and used to form the Fisher information matrix
\[F_{\alpha\beta} = \frac{\partial\mu_a}{\partial\theta_\alpha} C^{-1}_{ab}\frac{\partial\mu_b}{\partial\theta_\beta}.\]The loss function is then defined as
\[\Lambda = -\log|{\bf F}| + r(\Lambda_2) \Lambda_2\]Since any linear rescaling of a sufficient statistic is also a sufficient statistic, the negative logarithm of the determinant of the Fisher information matrix needs to be regularised to fix the scale of the network outputs. We choose to fix this scale by constraining the covariance of network outputs as
\[\Lambda_2 = ||{\bf C}-{\bf I}|| + ||{\bf C}^{-1}-{\bf I}||\]The reason for choosing this constraint is that it forces the covariance to be approximately parameter independent, which justifies choosing the covariance-independent Gaussian Fisher information as above. To avoid having a dual optimisation objective, we use a smooth and dynamic regularisation strength which turns off the regularisation to focus on maximising the Fisher information once the covariance has set the scale
\[r(\Lambda_2) = \frac{\lambda\Lambda_2}{\Lambda_2-\exp (-\alpha\Lambda_2)}.\]To enable the use of large data (or networks) the whole procedure is aggregated. This means that the passing of the simulations through the network is farmed out to the desired XLA devices, and recollected, n_per_device inputs at a time. These are then used to calculate the automatic gradient of the loss function with respect to the calculated summaries and derivatives, \(\partial\Lambda/\partial{\bf x}^i\) (which is a fairly small computation as long as n_summaries and n_s (and n_d) are not huge). Once this is calculated, the simulations are passed through the network again, this time calculating the Jacobian of the network outputs with respect to the network parameters, \(\partial{\bf x}^i/\partial{\bf w}\), which is combined via the chain rule to get\[\frac{\partial\Lambda}{\partial{\bf w}} = \frac{\partial\Lambda}{\partial{\bf x}^i} \frac{\partial{\bf x}^i}{\partial{\bf w}}\]This can then be passed to the optimiser.
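As a concrete illustration of the loss defined above, here is a minimal sketch in jax.numpy. This is not the library's implementation; the shapes and the λ and α values are purely illustrative (in the library α is set from λ and the closeness criterion ε via get_α):

import jax.numpy as np

def loss(x, dx_dθ, λ=10.0, α=0.95):
    # x: (n_s, n_summaries) network outputs at the fiducial parameters
    # dx_dθ: (n_d, n_summaries, n_params) derivatives of the outputs
    C = np.cov(x, rowvar=False)                     # C_ab
    invC = np.linalg.inv(C)
    dμ_dθ = np.mean(dx_dθ, axis=0)                  # ∂μ_a/∂θ_α
    F = np.einsum("ai,ab,bj->ij", dμ_dθ, invC, dμ_dθ)
    I = np.eye(C.shape[0])
    Λ2 = np.linalg.norm(C - I) + np.linalg.norm(invC - I)
    r = λ * Λ2 / (Λ2 - np.exp(-α * Λ2))
    return -np.linalg.slogdet(F)[1] + r * Λ2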
- Parameters
δθ (float(n_params,)) – Size of perturbation to model parameters for the numerical derivative
fiducial (list of tf.data.Dataset().as_numpy_iterator()) – The simulations generated at the fiducial model parameter values used for calculating the covariance of network outputs (for fitting). These are served n_per_device at a time as a numpy iterator from a TensorFlow dataset (see the sketch after this list).
derivative (list of tf.data.Dataset().as_numpy_iterator()) – The simulations generated at parameter values perturbed from the fiducial used to calculate the numerical derivative of network outputs with respect to model parameters (for fitting). These are served n_per_device at a time as a numpy iterator from a TensorFlow dataset.
validation_fiducial (list of tf.data.Dataset().as_numpy_iterator()) – The simulations generated at the fiducial model parameter values used for calculating the covariance of network outputs (for validation). These are served n_per_device at a time as a numpy iterator from a TensorFlow dataset.
validation_derivative (list of tf.data.Dataset().as_numpy_iterator()) – The simulations generated at parameter values perturbed from the fiducial used to calculate the numerical derivative of network outputs with respect to model parameters (for validation). These are served n_per_device at a time as a numpy iterator from a TensorFlow dataset.
fiducial_iterations (int) – The number of iterations over the fiducial dataset
derivative_iterations (int) – The number of iterations over the derivative dataset
derivative_output_shape (tuple) – The shape of the output of the derivatives from the network
fiducial_batch_shape (tuple) – The shape of each batch of fiducial simulations (without input or summary shape)
derivative_batch_shape (tuple) – The shape of each batch of derivative simulations (without input or summary shape)
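A minimal sketch of building such iterators with TensorFlow, assuming the simulations are already held in memory (the variable names and sizes are illustrative, not the library's internals):

import tensorflow as tf

n_s, n_devices, n_per_device = 1000, 2, 100
data = tf.random.normal((n_s, 10))   # placeholder fiducial simulations

# One numpy iterator per XLA device, each serving n_per_device
# simulations at a time
fiducial = [
    tf.data.Dataset.from_tensor_slices(chunk)
    .batch(n_per_device)
    .as_numpy_iterator()
    for chunk in tf.split(data, n_devices)]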
-
_collect_input
(key, validate=False)¶ Returns validation or fitting sets
- Parameters
key (None or int(2,)) – Random number generators not used in this case
validate (bool) – Whether to return the set for validation or for fitting
- Returns
list of tf.data.Dataset().as_numpy_iterator() – The iterators for fiducial simulations for fitting or validation
list of tf.data.Dataset().as_numpy_iterator() – The iterators for derivative simulations for fitting or validation
-
_set_batch_functions
()¶ Creates jitted functions placed on desired XLA devices
So that each set of summaries is calculated on the correct device, the jitted functions are predefined on each of these devices (a sketch follows).
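A minimal sketch of such predefinition, assuming a hypothetical model_apply(w, d) network function. The device argument to jax.jit pins compilation to a given device; note that recent JAX releases deprecate it in favour of explicit sharding:

import jax

def model_apply(w, d):
    # placeholder network: any function mapping a batch of simulations
    # to summaries would do here
    return d[..., :2]

# one jitted copy of the compression function per available XLA device
batch_fns = [jax.jit(model_apply, device=device)
             for device in jax.devices()]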
-
_set_dataset
(prefetch, cache)¶ Transforms the data into lists of tensorflow dataset iterators
- Parameters
prefetch (tf.data.AUTOTUNE or int or None) – How many simulations to prefetch in the tensorflow dataset (a sketch of this transformation follows below)
cache (bool) – Whether to cache simulations in the tensorflow datasets
- Raises
ValueError – If cache and/or prefetch is None
TypeError – If cache and/or prefetch is the wrong type
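A minimal sketch of the kind of transformation described here; the exact private implementation and its validation will differ:

import tensorflow as tf

def make_iterator(data, prefetch=None, cache=False):
    if not isinstance(cache, bool):
        raise TypeError("cache must be a boolean")
    dataset = tf.data.Dataset.from_tensor_slices(data)
    if cache:
        dataset = dataset.cache()    # keep simulations in memory after the first pass
    if prefetch is not None:
        if prefetch is not tf.data.AUTOTUNE and not isinstance(prefetch, int):
            raise TypeError("prefetch must be tf.data.AUTOTUNE or an int")
        dataset = dataset.prefetch(prefetch)  # overlap loading with computation
    return dataset.as_numpy_iterator()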
-
_set_shapes
()¶ Calculates the shapes for batching over different devices
-
_split_dΛ_dx
(dΛ_dx)¶ Returns the gradient of loss function wrt summaries (derivatives)
The gradient of the loss function with respect to the network outputs and their derivatives with respect to the model parameters has to be reshaped and aggregated onto each XLA device, matching the format in which the tensorflow dataset feeds simulations (a sketch follows the return description below).
- Parameters
dΛ_dx (tuple) –
dΛ_dx float(n_s, n_params, n_summaries) – the derivative of the loss function with respect to network summaries
d2Λ_dxdθ float(n_d, 2, n_params, n_summaries) – the derivative of the loss function with respect to the derivative of network summaries with respect to model parameters
- Returns
list – a list of sets of derivatives of the loss function with respect to network summaries placed on each XLA device
list – a list of sets of derivatives of the loss function with respect to the derivative of network summaries with respect to model parameters
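A minimal sketch of that reshaping and placement, with illustrative shapes and names (this is not the private method itself):

import jax
import jax.numpy as np

n_s, n_summaries, n_per_device = 1000, 2, 100
devices = jax.devices()
dΛ_dx = np.zeros((n_s, n_summaries))    # placeholder loss gradient

# one slab per device, reshaped into n_per_device-sized batches to
# match the order in which the dataset iterators serve simulations
split_dΛ_dx = [
    jax.device_put(slab.reshape((-1, n_per_device, n_summaries)), device)
    for slab, device in zip(np.split(dΛ_dx, len(devices)), devices)]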
-
get_gradient
(dΛ_dx, w, key=None)¶ Aggregates gradients together to update the network parameters
To avoid having to calculate the gradient with respect to all the simulations at once, the gradient calculation is aggregated by addition: the simulations are looped over again and combined with the derivative of the loss function with respect to the network outputs (and their derivatives with respect to the model parameters). Whilst this is expensive, it is necessary because an accurate stochastic estimate of the Fisher information cannot be made, so all the available simulations must be used, which is probably too many to fit in memory at once (a sketch follows the return description below).
- Parameters
dΛ_dx (tuple) –
dΛ_dx float(n_s, n_params, n_summaries) – the derivative of the loss function with respect to network summaries
d2Λ_dxdθ float(n_d, 2, n_params, n_summaries) – the derivative of the loss function with respect to the derivative of network summaries with respect to model parameters
w (list) – Network parameters
key (None or int(2,)) – Random number generator used in SimulatorIMNN
- Returns
The gradient of the loss function with respect to the network parameters calculated by aggregating
- Return type
list
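A minimal sketch of that aggregation using a vector-Jacobian product, assuming a hypothetical model_apply(w, d) network function; jax.vjp contracts \(\partial\Lambda/\partial{\bf x}\) with \(\partial{\bf x}/\partial{\bf w}\) without ever forming the full Jacobian:

import jax

def accumulate_gradient(w, batches, dΛ_dx_batches, model_apply):
    total = None
    for d, dΛ_dx in zip(batches, dΛ_dx_batches):
        # second pass through the network: vjp applies the chain rule
        # (∂Λ/∂x)(∂x/∂w) for this batch of simulations
        _, vjp = jax.vjp(lambda w_: model_apply(w_, d), w)
        (grad,) = vjp(dΛ_dx)
        # aggregate by addition across batches
        total = grad if total is None else jax.tree_util.tree_map(
            lambda a, b: a + b, total, grad)
    return total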
-
get_summaries
(w, key=None, validate=False)¶ Gets all network outputs and derivatives wrt model parameters
Selects either the fitting or validation set and loops through the iterators on each XLA device, passing the simulations through the network to get the network outputs. These are then pushed back to the host for the computation of the loss function (a sketch follows the return description below).
The fiducial simulations are processed first, followed by the simulations varied with respect to the model parameters for the derivatives.
- Parameters
w (list or None, default=None) – The network parameters if wanting to calculate the Fisher information with a specific set of network parameters
key (int(2,) or None, default=None) – A random number generator for generating simulations on-the-fly
validate (bool, default=False) – Whether to get summaries of the validation set
- Returns
float(n_s, n_summaries) – The set of all network outputs used to calculate the covariance
float(n_d, 2, n_params, n_summaries) – The outputs of the network of simulations made at perturbed parameter values to construct the derivative of the network outputs with respect to the model parameters numerically
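A minimal sketch of that loop, assuming the per-device jitted functions and iterators from the earlier sketches:

import jax
import jax.numpy as np

def collect_summaries(w, batch_fns, iterators, n_iterations):
    summaries = []
    for fn, iterator in zip(batch_fns, iterators):
        for _ in range(n_iterations):
            # compress one batch on its device, then pull it back to host
            summaries.append(jax.device_get(fn(w, next(iterator))))
    return np.concatenate(summaries)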
-
get_summary
(inputs, w, θ, derivative=False, gradient=False)¶ Returns a single summary of a simulation or its gradient
- Parameters
inputs (float(input_shape) or tuple) – A single simulation to pass through the network, or a tuple of
dΛ_dx float(input_shape, n_params) – the derivative of the loss function with respect to a network summary
d float(input_shape) – a simulation to compress with the network
w (list) – The network parameters
θ (float(n_params,)) – The value of the parameters to generate the simulation at (fiducial), unused if not simulating on the fly
derivative (bool, default=False) – Whether a derivative of the simulation with respect to model parameters is also passed. This must be False for NumericalGradientIMNN
gradient (bool, default=False) – Whether to calculate the gradient with respect to model parameters