Parent classes¶
There are several different cases which might be desirable when fitting an IMNN, i.e. generating simulations on-the-fly with SimulatorIMNN(), using a fixed set of simulations with analytic gradients with GradientIMNN(), or using numerical derivatives with NumericalGradientIMNN(). When the datasets are very large (either in number of elements, n_s, or in shape, input_shape) the gradients can be manually aggregated using AggregatedSimulatorIMNN(), imnn.AggregatedGradientIMNN() or AggregatedNumericalGradientIMNN(). These are all wrappers around a base class, _IMNN(), and optionally an aggregation class, _AggregatedIMNN(). For completeness these are documented here.
The available modules are:
Base class¶
class imnn.imnn._imnn._IMNN(n_s, n_d, n_params, n_summaries, input_shape, θ_fid, model, optimiser, key_or_state)¶

Information maximising neural network parent class
This class defines the general fitting framework for information maximising neural networks. It includes the generic calculations of the Fisher information matrix from the outputs of a neural network as well as an XLA compilable fitting routine (with and without a progress bar). This class also provides a plotting routine for fitting history and a function to calculate the score compression of network outputs to quasi-maximum likelihood estimates of model parameter values.
The outline of the fitting procedure is that a set of \(i\in[1, n_s]\) simulations and \(n_d\) derivatives with respect to physical model parameters are used to calculate network outputs and their derivatives with respect to the physical model parameters, \({\bf x}^i\) and \(\partial{{\bf x}^i}/\partial\theta_\alpha\), where \(\alpha\) labels the physical parameter. The exact details of how these are calculated depend on the type of available data (see list of different IMNN below). With \({\bf x}^i\) and \(\partial{{\bf x}^i}/\partial\theta_\alpha\) the covariance
\[C_{ab} = \frac{1}{n_s-1}\sum_{i=1}^{n_s}(x^i_a-\mu_a) (x^i_b-\mu_b)\]where \(\mu_a = \frac{1}{n_s}\sum_{i=1}^{n_s}x^i_a\) is the mean of the network outputs, and the derivative of the mean of the network outputs with respect to the model parameters
\[\frac{\partial\mu_a}{\partial\theta_\alpha} = \frac{1}{n_d} \sum_{i=1}^{n_d}\frac{\partial{x^i_a}}{\partial\theta_\alpha}\]can be calculated and used to form the Fisher information matrix
\[F_{\alpha\beta} = \frac{\partial\mu_a}{\partial\theta_\alpha} C^{-1}_{ab}\frac{\partial\mu_b}{\partial\theta_\beta}.\]The loss function is then defined as
\[\Lambda = -\log|{\bf F}| + r(\Lambda_2) \Lambda_2\]Since any linear rescaling of a sufficient statistic is also a sufficient statistic the negative logarithm of the determinant of the Fisher information matrix needs to be regularised to fix the scale of the network outputs. We choose to fix this scale by constraining the covariance of network outputs as
\[\Lambda_2 = ||{\bf C}-{\bf I}|| + ||{\bf C}^{-1}-{\bf I}||\]One benefit of choosing this constraint is that it forces the covariance to be approximately parameter independent, which justifies choosing the covariance-independent Gaussian Fisher information as above. To avoid having a dual optimisation objective, we use a smooth and dynamic regularisation strength which turns off the regularisation to focus on maximising the Fisher information once the covariance has set the scale
\[r(\Lambda_2) = \frac{\lambda\Lambda_2}{\Lambda_2-\exp (-\alpha\Lambda_2)}.\]Once the loss function is calculated the automatic gradient is calculated and used to update the network parameters via the optimiser function. Note that for large input data sizes, large \(n_s\) or very large networks the gradients may need to be accumulated manually via the _AggregatedIMNN() class.

_IMNN is designed as the parent class for a range of specific-case IMNNs. There is a helper function (IMNN) which should return the correct case when provided with the correct data. These different subclasses are:

SimulatorIMNN():
Fit an IMNN using simulations generated on-the-fly from a jax (XLA compilable) simulator

GradientIMNN():
Fit an IMNN using a precalculated set of fiducial simulations and their derivatives with respect to model parameters

NumericalGradientIMNN():
Fit an IMNN using a precalculated set of fiducial simulations and simulations generated using parameter values just above and below the fiducial parameter values to make a numerical estimate of the derivatives of the network outputs. Best stability is achieved when the seeds of the simulations are matched between all parameter directions for the numerical derivative

AggregatedSimulatorIMNN():
SimulatorIMNN distributed over multiple jax devices with gradients aggregated manually. This might be necessary for very large input sizes as batching cannot be done when calculating the Fisher information matrix

AggregatedGradientIMNN():
GradientIMNN distributed over multiple jax devices with gradients aggregated manually. This might be necessary for very large input sizes as batching cannot be done when calculating the Fisher information matrix

AggregatedNumericalGradientIMNN():
NumericalGradientIMNN distributed over multiple jax devices with gradients aggregated manually. This might be necessary for very large input sizes as batching cannot be done when calculating the Fisher information matrix

DatasetGradientIMNN():
AggregatedGradientIMNN with prebuilt TensorFlow datasets

DatasetNumericalGradientIMNN():
AggregatedNumericalGradientIMNN with prebuilt TensorFlow datasets

There are currently two other parent classes

_AggregatedIMNN():
This is the parent class which provides the fitting routine when the gradients of the network parameters are aggregated manually rather than automatically by jax. This is necessary if the size of an entire batch of simulations (and their derivatives with respect to model parameters) together with the network parameters and their calculated gradients is too large to fit into memory. Note that there is a significant performance loss from using the aggregation, so it should only be used for these large data cases
- Parameters
n_s (int) – Number of simulations used to calculate network output covariance
n_d (int) – Number of simulations used to calculate mean of network output derivative with respect to the model parameters
n_params (int) – Number of model parameters
n_summaries (int) – Number of summaries, i.e. outputs of the network
input_shape (tuple) – The shape of a single input to the network
θ_fid (float(n_params,)) – The value of the fiducial parameter values used to generate inputs
validate (bool) – Whether a validation set is being used
simulate (bool) – Whether input simulations are generated on the fly
_run_with_pbar (bool) – Book-keeping parameter noting that a progress bar is used when fitting (induces a performance hit). If run_with_pbar = True and run_without_pbar = True then a jit compilation error would occur and so this is prevented
_run_without_pbar (bool) – Book-keeping parameter noting that a progress bar is not used when fitting. If run_with_pbar = True and run_without_pbar = True then a jit compilation error would occur and so this is prevented
F (float(n_params, n_params)) – Fisher information matrix calculated from the network outputs
invF (float(n_params, n_params)) – Inverse Fisher information matrix calculated from the network outputs
C (float(n_summaries, n_summaries)) – Covariance of the network outputs
invC (float(n_summaries, n_summaries)) – Inverse covariance of the network outputs
μ (float(n_summaries,)) – Mean of the network outputs
dμ_dθ (float(n_summaries, n_params)) – Derivative of the mean of the network outputs with respect to model parameters
state (:obj:state) – The optimiser state used for updating the network parameters and optimisation algorithm
initial_w (list) – List of the network parameter values at initialisation (to restart)
final_w (list) – List of the network parameter values at the end of fitting
best_w (list) – List of the network parameter values which provide the maximum value of the determinant of the Fisher matrix
w (list) – List of the network parameter values (either final or best depending on the setting when calling fit(…))
history (dict) –
- A dictionary containing the fitting history. Keys are
detF – determinant of the Fisher information at the end of each iteration
detC – determinant of the covariance of network outputs at the end of each iteration
detinvC – determinant of the inverse covariance of network outputs at the end of each iteration
Λ2 – value of the covariance regularisation at the end of each iteration
r – value of the regularisation coupling at the end of each iteration
val_detF – determinant of the Fisher information of the validation data at the end of each iteration
val_detC – determinant of the covariance of network outputs given the validation data at the end of each iteration
val_detinvC – determinant of the inverse covariance of network outputs given the validation data at the end of each iteration
val_Λ2 – value of the covariance regularisation given the validation data at the end of each iteration
val_r – value of the regularisation coupling given the validation data at the end of each iteration
max_detF – maximum value of the determinant of the Fisher information on the validation data (if available)
-
model:
Neural network as a function of network parameters and inputs
-
_get_parameters:
Function which extracts the network parameters from the state
-
_model_initialiser:
Function to initialise neural network weights from RNG and shape tuple
-
_opt_initialiser:
Function which generates the optimiser state from network parameters
-
_update:
Function which updates the state from a gradient
Public Methods:

__init__(n_s, n_d, n_params, n_summaries, …) – Constructor method
fit(λ, ε[, rng, patience, min_iterations, …]) – Fitting routine for the IMNN
get_α(λ, ε) – Calculate rate parameter for regularisation from closeness criterion
set_F_statistics([w, key, validate]) – Set necessary attributes for calculating score compressed summaries
get_summaries([w, key, validate]) – Gets all network outputs and derivatives wrt model parameters
get_estimate(d) – Calculate score compressed parameter estimates from network outputs
plot([ax, expected_detF, colour, figsize, …]) – Plot fitting history
Private Methods:

_initialise_parameters(n_s, n_d, n_params, …) – Performs type checking and initialisation of class attributes
_initialise_model(model, optimiser, key_or_state) – Initialises neural network parameters or loads optimiser state
_initialise_history() – Initialises history dictionary attribute
_set_history(results) – Places results from fitting into the history dictionary
_set_inputs(rng, max_iterations) – Builds list of inputs for the XLA compilable fitting routine
_get_fitting_keys(rng) – Generates random numbers for simulation generation if needed
_fit(inputs, λ=None, α=None[, min_iterations]) – Single iteration fitting algorithm
_fit_cond(inputs, patience, max_iterations) – Stopping condition for the fitting loop
_update_loop_vars(inputs) – Updates input parameters if max_detF is increased
_check_loop_vars(inputs, min_iterations) – Updates patience_counter if max_detF not increased
_update_history(inputs, history, counter, ind) – Puts current fitting statistics into history arrays
_slogdet(matrix) – Combined summed logarithmic determinant
_construct_derivatives(derivatives) – Builds derivatives of the network outputs wrt model parameters
_get_F_statistics([w, key, validate]) – Calculates the Fisher information and returns all statistics used
_calculate_F_statistics(summaries, derivatives) – Calculates the Fisher information matrix from network outputs
_get_regularisation_strength(Λ2, λ, α) – Coupling strength of the regularisation (amplified sigmoid)
_get_regularisation(C, invC) – Difference of the covariance (and its inverse) from identity
_get_loss(w, λ, α[, key]) – Calculates the loss function and returns auxiliary variables
_calculate_loss(summaries, derivatives, λ, α) – Calculates the loss function from network summaries and derivatives
_setup_plot([ax, expected_detF, figsize]) – Builds axes for history plot
-
_calculate_F_statistics
(summaries, derivatives)¶ Calculates the Fisher information matrix from network outputs
If the numerical derivative is being calculated then the derivatives are first constructed. If the mean is to be returned (for use in score compression), this is calculated and pushed to the results tuple. Then the covariance of the summaries is taken and inverted and the mean of the derivative of network summaries with respect to the model parameters is found and these are used to calculate the Gaussian form of the Fisher information matrix.
- Parameters
summaries (float(n_s, n_summaries)) – The network outputs
derivatives (float(n_d, n_summaries, n_params)) – The derivative of the network outputs wrt the model parameters. Note that when NumericalGradientIMNN is being used the shape is float(n_d, 2, n_params, n_summaries), which is then constructed into the numerical derivative in _construct_derivatives.
- Returns
F (float(n_params, n_params)) – Fisher information matrix
C (float(n_summaries, n_summaries)) – Covariance of network outputs
invC (float(n_summaries, n_summaries)) – Inverse covariance of network outputs
dμ_dθ (float(n_summaries, n_params)) – The derivative of the mean of network outputs with respect to model parameters
μ (float(n_summaries)) – The mean of the network outputs
- Return type
tuple
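For orientation, here is a minimal jax.numpy sketch of this calculation. It mirrors the shapes and formulas quoted above rather than the library's exact implementation:

import jax.numpy as np

def calculate_F_statistics(summaries, derivatives):
    # summaries: float(n_s, n_summaries); derivatives: float(n_d, n_summaries, n_params)
    μ = np.mean(summaries, axis=0)                       # mean of the network outputs
    C = np.cov(summaries, rowvar=False)                  # covariance (normalised by n_s - 1)
    invC = np.linalg.inv(C)                              # inverse covariance
    dμ_dθ = np.mean(derivatives, axis=0)                 # mean derivative wrt model parameters
    F = np.einsum("ia,ij,jb->ab", dμ_dθ, invC, dμ_dθ)    # Gaussian Fisher information
    return F, C, invC, dμ_dθ, μ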
-
_calculate_loss
(summaries, derivatives, λ, α)¶ Calculates the loss function from network summaries and derivatives
- Parameters
summaries (float(n_s, n_summaries)) – The network outputs
derivatives (float(n_d, n_summaries, n_params)) – The derivative of the network outputs wrt the model parameters. Note that when NumericalGradientIMNN is being used the shape is float(n_d, 2, n_params, n_summaries), which is then constructed into the numerical derivative in _construct_derivatives.
λ (float) – Coupling strength of the regularisation
α (float) – Rate parameter for the regularisation, calculated from the ϵ closeness criterion
- Returns
float – Value of the regularised loss function
tuple –
- Fitting statistics calculated on a single iteration
F (float(n_params, n_params)) – Fisher information matrix
C (float(n_summaries, n_summaries)) – Covariance of network outputs
invC (float(n_summaries, n_summaries)) – Inverse covariance of network outputs
Λ2 (float) – Covariance regularisation
r (float) – Regularisation coupling strength
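As a rough sketch of how the regularised loss is assembled from these quantities (reusing the calculate_F_statistics sketch above; the matrix norm is taken to be the Frobenius norm here, which is an assumption rather than a statement about the exact implementation):

import jax.numpy as np

def calculate_loss(summaries, derivatives, λ, α):
    # Fisher statistics as in the sketch for _calculate_F_statistics above
    F, C, invC, dμ_dθ, μ = calculate_F_statistics(summaries, derivatives)
    I = np.eye(C.shape[0])
    Λ2 = np.linalg.norm(C - I) + np.linalg.norm(invC - I)  # covariance regularisation
    r = λ * Λ2 / (Λ2 - np.exp(-α * Λ2))                    # dynamic coupling strength
    sign, logdetF = np.linalg.slogdet(F)
    Λ = -sign * logdetF + r * Λ2                            # regularised loss
    return Λ, (F, C, invC, Λ2, r)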
-
_check_loop_vars
(inputs, min_iterations)¶ Updates patience_counter if max_detF is not increased

If the determinant of the Fisher information matrix calculated in a given iteration is not larger than the max_detF calculated so far, then the patience_counter is increased by one, as long as the number of iterations is greater than the minimum number of iterations that should be run.

- Parameters
inputs (tuple) –
patience_counter (int) – Number of iterations where there is no increase in the value of the determinant of the Fisher information matrix
counter (int) – While loop iteration counter
detF (float(max_iterations, 1) or float(max_iterations, 2)) – History of the determinant of the Fisher information matrix
max_detF (float) – Maximum value of the determinant of the Fisher information matrix calculated so far
w (list) – Value of the network parameters in the current iteration
best_w (list) – Value of the network parameters which obtained the maximum determinant of the Fisher information matrix
min_iterations (int) – Number of iterations that should be run before considering early stopping using the patience counter
- Returns
(described in Parameters)
- Return type
tuple
-
_construct_derivatives
(derivatives)¶ Builds derivatives of the network outputs wrt model parameters
A pass-through (identity) operation in _IMNN, SimulatorIMNN and GradientIMNN, but necessary to construct correctly shaped derivatives when using NumericalGradientIMNN.

- Parameters
derivatives (float(n_d, n_summaries, n_params)) – The derivatives of the network outputs with respect to the model parameters
- Returns
The derivatives of the network outputs with respect to the model parameters
- Return type
float(n_d, n_summaries, n_params)
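For NumericalGradientIMNN the construction amounts to a finite difference. A sketch, assuming the pair index orders the simulations made below and then above the fiducial, and that δθ is the parameter-wise difference between the upper and lower parameter values (both assumptions, not guaranteed by this docstring):

import jax.numpy as np

def construct_numerical_derivatives(derivatives, δθ):
    # derivatives: float(n_d, 2, n_params, n_summaries)
    # δθ: float(n_params,), upper minus lower parameter values
    finite_difference = derivatives[:, 1] - derivatives[:, 0]  # (n_d, n_params, n_summaries)
    return np.moveaxis(finite_difference, 1, 2) / δθ           # (n_d, n_summaries, n_params)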
-
_fit
(inputs, λ=None, α=None, min_iterations=None)¶ Single iteration fitting algorithm
This function performs the network parameter updates, first getting any necessary random number generators for simulators and then extracting the network parameters from the state. These parameters are used to calculate the gradient of the loss function with respect to the network parameters (see the _IMNN class docstrings). Once the loss function is calculated, the gradient is used to update the network parameters via the optimiser function and the current iteration's statistics are saved to the history arrays. If validation is used (recommended for GradientIMNN and NumericalGradientIMNN) then all necessary statistics to calculate the loss function are calculated and pushed to the history arrays.

The patience_counter is increased if the value of the determinant of the Fisher information matrix does not increase over the previous iterations, up to patience number of iterations, at which point early stopping occurs, but only if the number of iterations so far performed is greater than a specified min_iterations.

- Parameters
inputs (tuple) –
max_detF (float) – Maximum value of the determinant of the Fisher information matrix calculated so far
best_w (list) – Value of the network parameters which obtained the maximum determinant of the Fisher information matrix
detF (float(max_iterations, 1) or float(max_iterations, 2)) – History of the determinant of the Fisher information matrix
detC (float(max_iterations, 1) or float(max_iterations, 2)) – History of the determinant of the covariance of network outputs
detinvC (float(max_iterations, 1) or float(max_iterations, 2)) – History of the determinant of the inverse covariance of network outputs
Λ2 (float(max_iterations, 1) or float(max_iterations, 2)) – History of the covariance regularisation
r (float(max_iterations, 1) or float(max_iterations, 2)) – History of the regularisation coupling strength
counter (int) – While loop iteration counter
patience_counter (int) – Number of iterations where there is no increase in the value of the determinant of the Fisher information matrix
state (:obj: state) – Optimiser state used for updating the network parameters and optimisation algorithm
rng (int(2,)) – Stateless random number generator
λ (float) – Coupling strength of the regularisation
α (float) – Rate parameter for regularisation coupling
min_iterations (int) – Number of iterations that should be run before considering early stopping using the patience counter
- Returns
loop variables (described in Parameters)
- Return type
tuple
-
_fit_cond
(inputs, patience, max_iterations)¶ Stopping condition for the fitting loop
The fitting loop stops either when max_iterations is reached or when the patience counter reaches patience, i.e. patience_counter iterations have passed without an increase in the determinant of the Fisher information matrix.

- Parameters
inputs (tuple) –
max_detF (float) – Maximum value of the determinant of the Fisher information matrix calculated so far
best_w (list) – Value of the network parameters which obtained the maximum determinant of the Fisher information matrix
detF (float(max_iterations, 1) or float(max_iterations, 2)) – History of the determinant of the Fisher information matrix
detC (float(max_iterations, 1) or float(max_iterations, 2)) – History of the determinant of the covariance of network outputs
detinvC (float(max_iterations, 1) or float(max_iterations, 2)) – History of the determinant of the inverse covariance of network outputs
Λ2 (float(max_iterations, 1) or float(max_iterations, 2)) – History of the covariance regularisation
r (float(max_iterations, 1) or float(max_iterations, 2)) – History of the regularisation coupling strength
counter (int) – While loop iteration counter
patience_counter (int) – Number of iterations where there is no increase in the value of the determinant of the Fisher information matrix
state (:obj: state) – Optimiser state used for updating the network parameters and optimisation algorithm
rng (int(2,)) – Stateless random number generator
patience (int) – Number of iterations to stop the fitting when there is no increase in the value of the determinant of the Fisher information matrix
max_iterations (int) – Maximum number of iterations to run the fitting procedure for
- Returns
True if either the patience_counter has not reached the patience criterion or the counter has not reached max_iterations
- Return type
bool
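In essence the condition reduces to the following logic (a sketch only; the real function unpacks these values from the inputs tuple):

import jax.numpy as np

def fit_cond(patience_counter, counter, patience, max_iterations):
    # Keep looping while neither the patience criterion nor the
    # maximum number of iterations has been reached
    return np.logical_and(
        np.less(patience_counter, patience),
        np.less(counter, max_iterations))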
-
_get_F_statistics
(w=None, key=None, validate=False)¶ Calculates the Fisher information and returns all statistics used
First gets the summaries and derivatives and then uses them to calculate the Fisher information matrix from the outputs, returning all the constituents of the calculation which are needed for the score compression or the regularisation of the loss function.
- Parameters
w (list or None, default=None) – The network parameters if wanting to calculate the Fisher information with a specific set of network parameters
key (int(2,) or None, default=None) – A random number generator for generating simulations on-the-fly
validate (bool, default=False) – Whether to calculate Fisher information using the validation set
- Returns
F (float(n_params, n_params)) – Fisher information matrix
C (float(n_summaries, n_summaries)) – Covariance of network outputs
invC (float(n_summaries, n_summaries)) – Inverse covariance of network outputs
dμ_dθ (float(n_summaries, n_params)) – The derivative of the mean of network outputs with respect to model parameters
μ (float(n_summaries,)) – The mean of the network outputs
- Return type
tuple
-
_get_fitting_keys
(rng)¶ Generates random numbers for simulation generation if needed
- Parameters
rng (int(2,) or None) – A random number generator
- Returns
A new random number generator and random number generators for training and validation, or empty values
- Return type
int(2,), int(2,), int(2,) or None, None, None
-
_get_loss
(w, λ, α, key=None)¶ Calculates the loss function and returns auxiliary variables

First gets the summaries and derivatives and then uses them to calculate the loss function. This function is separated out so that jax.grad can be used directly, rather than calculating the derivative of the summaries as is done with _AggregatedIMNN.

- Parameters
w (list) – The network parameters if wanting to calculate the Fisher information with a specific set of network parameters
λ (float) – Coupling strength of the regularisation
α (float) – Rate parameter for the regularisation, calculated from the ϵ closeness criterion
key (int(2,) or None, default=None) – A random number generator for generating simulations on-the-fly
- Returns
float – Value of the regularised loss function
tuple –
- Fitting statistics calculated on a single iteration
F (float(n_params, n_params)) – Fisher information matrix
C (float(n_summaries, n_summaries)) – Covariance of network outputs
invC (float(n_summaries, n_summaries)) – Inverse covariance of network outputs
Λ2 (float) – Covariance regularisation
r (float) – Regularisation coupling strength
-
_get_regularisation
(C, invC)¶ Difference of the covariance (and its inverse) from identity
The negative logarithm of the determinant of the Fisher information matrix needs to be regularised to fix the scale of the network outputs since any linear rescaling of a sufficient statistic is also a sufficient statistic. We choose to fix this scale by constraining the covariance of network outputs as
\[\Lambda_2 = ||\bf{C}-\bf{I}|| + ||\bf{C}^{-1}-\bf{I}||\]One benefit of choosing this constraint is that it forces the covariance to be approximately parameter independent which justifies choosing the covariance independent Gaussian Fisher information.
- Parameters
C (float(n_summaries, n_summaries)) – Covariance of the network outputs
invC (float(n_summaries, n_summaries)) – Inverse covariance of the network outputs
- Returns
Regularisation loss terms for the distance of the covariance and its determinant from the identity matrix
- Return type
float
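A minimal sketch of this term, assuming the matrix norm is the Frobenius norm (the choice of norm is an assumption here):

import jax.numpy as np

def get_regularisation(C, invC):
    # Distance of the covariance and its inverse from the identity matrix
    I = np.eye(C.shape[0])
    return np.linalg.norm(C - I) + np.linalg.norm(invC - I)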
-
_get_regularisation_strength
(Λ2, λ, α)¶ Coupling strength of the regularisation (amplified sigmoid)
To dynamically turn off the regularisation when the scale of the covariance is set to approximately the identity matrix, a smooth sigmoid conditional on the value of the regularisation is used. The rate, α, is calculated from a closeness condition of the covariance (and the inverse covariance) to the identity matrix using get_α.

- Parameters
Λ2 (float) – Covariance regularisation
λ (float) – Coupling strength of the regularisation
α (float) – Rate parameter for the regularisation, calculated from the ϵ closeness criterion
- Returns
Smooth, dynamic regularisation strength
- Return type
float
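Written out, the coupling strength is a direct transcription of the expression quoted in the class description:

import jax.numpy as np

def get_regularisation_strength(Λ2, λ, α):
    # Smooth, dynamic coupling which switches the regularisation off
    # once the covariance is close to the identity
    return λ * Λ2 / (Λ2 - np.exp(-α * Λ2))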
-
_initialise_history
()¶ Initialises history dictionary attribute
Notes
- The contents of the history dictionary are
detF – determinant of the Fisher information at the end of each iteration
detC – determinant of the covariance of network outputs at the end of each iteration
detinvC – determinant of the inverse covariance of network outputs at the end of each iteration
Λ2 – value of the covariance regularisation at the end of each iteration
r – value of the regularisation coupling at the end of each iteration
val_detF – determinant of the Fisher information of the validation data at the end of each iteration
val_detC – determinant of the covariance of network outputs given the validation data at the end of each iteration
val_detinvC – determinant of the inverse covariance of network outputs given the validation data at the end of each iteration
val_Λ2 – value of the covariance regularisation given the validation data at the end of each iteration
val_r – value of the regularisation coupling given the validation data at the end of each iteration
max_detF – maximum value of the determinant of the Fisher information on the validation data (if available)
-
_initialise_model
(model, optimiser, key_or_state)¶ Initialises neural network parameters or loads optimiser state
- Parameters
model (tuple, len=2) – Tuple containing a function to initialise the neural network, fn(rng: int(2), input_shape: tuple) -> tuple, list, and the neural network as a function of network parameters and inputs, fn(w: list, d: float([None], input_shape)) -> float([None], n_summaries). (Essentially stax-like, see jax.experimental.stax)
optimiser (tuple or obj, len=3) – Tuple containing functions to generate the optimiser state, fn(x0: list) -> :obj:state, to update the state from a list of gradients, fn(i: int, g: list, state: :obj:state) -> :obj:state, and to extract network parameters from the state, fn(state: :obj:state) -> list. (See jax.experimental.optimizers)
key_or_state (int(2) or :obj:state) – Either a stateless random number generator or the state object of a preinitialised optimiser
Notes
The design of the model follows jax’s stax module in that the model is encapsulated by two functions, one to initialise the network and one to call the model, i.e.:
import jax
from jax.experimental import stax

rng = jax.random.PRNGKey(0)
data_key, model_key = jax.random.split(rng)

input_shape = (10,)
inputs = jax.random.normal(data_key, shape=input_shape)

model = stax.serial(
    stax.Dense(10),
    stax.LeakyRelu,
    stax.Dense(10),
    stax.LeakyRelu,
    stax.Dense(2))

output_shape, initial_params = model[0](model_key, input_shape)
outputs = model[1](initial_params, inputs)
Note that the model used in the IMNN is assumed to be totally broadcastable, i.e. any batch shape can be used for inputs. This might require having a layer which reshapes all batch dimensions into a single dimension and then unwraps it at the last layer. A model such as that above is already fully broadcastable.
The optimiser should follow jax's experimental optimizers module in that the optimiser is encapsulated by three functions: one to initialise the state, one to update the state from a list of gradients and one to extract the network parameters from the state, i.e.
from jax.experimental import optimizers
import jax.numpy as np

optimiser = optimizers.adam(step_size=1e-3)
initial_state = optimiser[0](initial_params)
params = optimiser[2](initial_state)

def scalar_output(params, inputs):
    return np.sum(model[1](params, inputs))

counter = 0
grad = jax.grad(scalar_output, argnums=0)(params, inputs)
state = optimiser[1](counter, grad, initial_state)
This function either initialises the neural network and the optimiser state, if passed a stateless random number generator in key_or_state, or loads a predefined state if the state is passed to key_or_state. The functions get mapped to the class attributes

self.model = model[1]
self._model_initialiser = model[0]
self._opt_initialiser = optimiser[0]
self._update = optimiser[1]
self._get_parameters = optimiser[2]
The state is made into the state class attribute and the parameters are assigned to the initial_w, final_w, best_w and w class attributes (where w stands for weights). There is some type checking done, but for freedom of choice of model there will be very few raised warnings.
- Raises
TypeError – If the random number generator is not correct, or if there is no possible way to construct a model or an optimiser from the passed parameters
ValueError – If any input is None, or if the functions for the model or optimiser do not conform to the necessary specifications
-
_initialise_parameters
(n_s, n_d, n_params, n_summaries, input_shape, θ_fid)¶ Performs type checking and initialisation of class attributes
- Parameters
n_s (int) – Number of simulations used to calculate summary covariance
n_d (int) – Number of simulations used to calculate mean of summary derivative
n_params (int) – Number of model parameters
n_summaries (int) – Number of summaries, i.e. outputs of the network
input_shape (tuple) – The shape of a single input to the network
θ_fid (float(n_params,)) – The value of the fiducial parameter values used to generate inputs
- Raises
TypeError – Any of the parameters are not correct type
ValueError – If any of the parameters are None or if θ_fid has the wrong shape
-
_set_history
(results)¶ Places results from fitting into the history dictionary
- Parameters
results (list) –
- List of results from fitting procedure. These are:
detF (float(n_iterations, 2)) – determinant of the Fisher information, detF[:, 0] for training and detF[:, 1] for validation
detC (float(n_iterations, 2)) – determinant of the covariance of network outputs, detC[:, 0] for training and detC[:, 1] for validation
detinvC (float(n_iterations, 2)) – determinant of the inverse covariance of network outputs, detinvC[:, 0] for training and detinvC[:, 1] for validation
Λ2 (float(n_iterations, 2)) – value of the covariance regularisation, Λ2[:, 0] for training and Λ2[:, 1] for validation
r (float(n_iterations, 2)) – value of the regularisation coupling, r[:, 0] for training and r[:, 1] for validation
-
_set_inputs
(rng, max_iterations)¶ Builds list of inputs for the XLA compilable fitting routine
- Parameters
rng (int(2,) or None) – A stateless random number generator
max_iterations (int) – Maximum number of iterations to run the fitting procedure for
Notes
- The list of inputs to the routine are
max_detF (float) – The maximum value of the determinant of the Fisher information matrix calculated so far. This is zero if not run before, or the value from previous calls to fit
best_w (list) – The value of the network parameters which obtained the maximum determinant of the Fisher information matrix. This is the initial network parameter values if not run before
detF (float(max_iterations, 1) or float(max_iterations, 2)) – A container for all possible values of the determinant of the Fisher information matrix during each iteration of fitting. If there is no validation (for simulation on-the-fly for example) then this container has a shape of (max_iterations, 1), otherwise validation values are stored in detF[:, 1]
detC (float(max_iterations, 1) or float(max_iterations, 2)) – A container for all possible values of the determinant of the covariance of network outputs during each iteration of fitting. If there is no validation (for simulation on-the-fly for example) then this container has a shape of (max_iterations, 1), otherwise validation values are stored in detC[:, 1]
detinvC (float(max_iterations, 1) or float(max_iterations, 2)) – A container for all possible values of the determinant of the inverse covariance of network outputs during each iteration of fitting. If there is no validation (for simulation on-the-fly for example) then this container has a shape of (max_iterations, 1), otherwise validation values are stored in detinvC[:, 1]
Λ2 (float(max_iterations, 1) or float(max_iterations, 2)) – A container for all possible values of the covariance regularisation during each iteration of fitting. If there is no validation (for simulation on-the-fly for example) then this container has a shape of (max_iterations, 1), otherwise validation values are stored in Λ2[:, 1]
r (float(max_iterations, 1) or float(max_iterations, 2)) – A container for all possible values of the regularisation coupling strength during each iteration of fitting. If there is no validation (for simulation on-the-fly for example) then this container has a shape of (max_iterations, 1), otherwise validation values are stored in r[:, 1]
counter (int) – Iteration counter used to note whether the while loop reaches max_iterations. If not, the history objects (above) get truncated to length counter. This starts with value zero
patience_counter (int) – Counts the number of iterations where there is no increase in the value of the determinant of the Fisher information matrix, used for early stopping. This starts with value zero
state (:obj:state) – The current optimiser state used for updating the network parameters and optimisation algorithm
rng (int(2,)) – A stateless random number generator which gets updated on each iteration
-
_setup_plot
(ax=None, expected_detF=None, figsize=(5, 15))¶ Builds axes for history plot
- Parameters
ax (mpl.axes or None, default=None) – An axes object of predefined axes to be labelled
expected_detF (float or None, default=None) – Value of the expected determinant of the Fisher information to plot a horizontal line at to check fitting progress
figsize (tuple, default=(5, 15)) – The size of the figure to be produced
- Returns
An axes object of labelled axes
- Return type
mpl.axes
-
_slogdet
(matrix)¶ Combined summed logarithmic determinant
- Parameters
matrix (float(n, n)) – An n x n matrix to calculate the summed logarithmic determinant of
- Returns
The summed logarithmic determinant multiplied by its sign
- Return type
float
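Based on the description above, this is essentially a thin wrapper around np.linalg.slogdet (a sketch):

import jax.numpy as np

def slogdet(matrix):
    # Sign and logarithm of the determinant combined into a single signed value
    sign, logdet = np.linalg.slogdet(matrix)
    return sign * logdet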
-
_update_history
(inputs, history, counter, ind)¶ Puts current fitting statistics into history arrays
- Parameters
inputs (tuple) –
- Fitting statistics calculated on a single iteration
F (float(n_params, n_params)) – Fisher information matrix
C (float(n_summaries, n_summaries)) – Covariance of network outputs
invC (float(n_summaries, n_summaries)) – Inverse covariance of network outputs
_Λ2 (float) – Covariance regularisation
_r (float) – Regularisation coupling strength
history (tuple) –
- History arrays containing fitting statistics for each iteration
detF (float(max_iterations, 1) or float(max_iterations, 2)) – Determinant of the Fisher information matrix
detC (float(max_iterations, 1) or float(max_iterations, 2)) – Determinant of the covariance of network outputs
detinvC (float(max_iterations, 1) or float(max_iterations, 2)) – Determinant of the inverse covariance of network outputs
Λ2 (float(max_iterations, 1) or float(max_iterations, 2)) – Covariance regularisation
r (float(max_iterations, 1) or float(max_iterations, 2)) – Regularisation coupling strength
counter (int) – Current iteration to insert a single iteration statistics into the history
ind (int) – Values of either 0 (fitting) or 1 (validation) to separate the fitting and validation historys
- Returns
float(max_iterations, 1) or float(max_iterations, 2) – History of the determinant of the Fisher information matrix
float(max_iterations, 1) or float(max_iterations, 2) – History of the determinant of the covariance of network outputs
float(max_iterations, 1) or float(max_iterations, 2) – History of the determinant of the inverse covariance of network outputs
float(max_iterations, 1) or float(max_iterations, 2) – History of the covariance regularisation
float(max_iterations, 1) or float(max_iterations, 2) – History of the regularisation coupling strength
-
_update_loop_vars
(inputs)¶ Updates input parameters if max_detF is increased

If the determinant of the Fisher information matrix calculated in a given iteration is larger than the max_detF calculated so far, then the patience_counter is reset to zero, max_detF is replaced with the current value of detF, and the network parameters in this iteration replace the previous parameters which obtained the highest determinant of the Fisher information, best_w.

- Parameters
inputs (tuple) –
patience_counter (int) – Number of iterations where there is no increase in the value of the determinant of the Fisher information matrix
counter (int) – While loop iteration counter
detF (float(max_iterations, 1) or float(max_iterations, 2)) – History of the determinant of the Fisher information matrix
max_detF (float) – Maximum value of the determinant of the Fisher information matrix calculated so far
w (list) – Value of the network parameters in the current iteration
best_w (list) – Value of the network parameters which obtained the maximum determinant of the Fisher information matrix
- Returns
(described in Parameters)
- Return type
tuple
-
fit
(λ, ε, rng=None, patience=100, min_iterations=100, max_iterations=100000, print_rate=None, best=True)¶ Fitting routine for the IMNN
- Parameters
λ (float) – Coupling strength of the regularisation
ϵ (float) – Closeness criterion describing how close to 1 the determinant of the covariance (and inverse covariance) of the network outputs is desired to be
rng (int(2,) or None, default=None) – Stateless random number generator
patience (int, default=100) – Number of iterations where there is no increase in the value of the determinant of the Fisher information matrix, used for early stopping
min_iterations (int, default=100) – Number of iterations that should be run before considering early stopping using the patience counter
max_iterations (int, default=int(1e5)) – Maximum number of iterations to run the fitting procedure for
print_rate (int or None, default=None) – Number of iterations before updating the progress bar whilst fitting. There is a performance hit from updating the progress bar more often and there is a large performance hit from using the progress bar at all (possible RET_CHECK failure if print_rate is not None when using GPUs). For this reason it is set to None as default
best (bool, default=True) – Whether to set the network parameter attribute self.w to the parameter values that obtained the maximum determinant of the Fisher information matrix, or to the parameter values at the final iteration of fitting
Example
We are going to summarise the mean and variance of some random Gaussian noise with 10 data points per example using a SimulatorIMNN. In this case we are going to generate the simulations on-the-fly with a simulator written in jax (from the examples directory). We will use 1000 simulations to estimate the covariance of the network outputs and the derivative of the mean of the network outputs with respect to the model parameters (Gaussian mean and variance) and generate the simulations at a fiducial μ=0 and Σ=1. The network will be a stax model with hidden layers of [128, 128, 128] activated with leaky relu and outputting 2 summaries. Optimisation will be via Adam with a step size of 1e-3. Rather arbitrarily we’ll set the regularisation strength and covariance identity constraint to λ=10 and ϵ=0.1 (these are relatively unimportant for such an easy model).

import jax
import jax.numpy as np
from jax.experimental import stax, optimizers
from imnn import SimulatorIMNN

rng = jax.random.PRNGKey(0)

n_s = 1000
n_d = 1000
n_params = 2
n_summaries = 2
input_shape = (10,)
simulator_args = {"input_shape": input_shape}
θ_fid = np.array([0., 1.])

def simulator(rng, θ):
    return θ[0] + jax.random.normal(
        rng, shape=input_shape) * np.sqrt(θ[1])

model = stax.serial(
    stax.Dense(128),
    stax.LeakyRelu,
    stax.Dense(128),
    stax.LeakyRelu,
    stax.Dense(128),
    stax.LeakyRelu,
    stax.Dense(n_summaries))

optimiser = optimizers.adam(step_size=1e-3)

λ = 10.
ϵ = 0.1

model_key, fit_key = jax.random.split(rng)

imnn = SimulatorIMNN(
    n_s=n_s, n_d=n_d, n_params=n_params,
    n_summaries=n_summaries, input_shape=input_shape,
    θ_fid=θ_fid, model=model, optimiser=optimiser,
    key_or_state=model_key, simulator=simulator)

imnn.fit(λ, ϵ, rng=fit_key, min_iterations=1000,
         patience=250, print_rate=None)
Notes
A minimum number of iterations should be run before stopping based on the maximum determinant of the Fisher information achieved, since the loss function has dual objectives. Because the determinant of the covariance of the network outputs is forced to 1 quickly, this can be to the detriment of the value of the determinant of the Fisher information matrix early in the fitting procedure. For this reason, starting early stopping only after the covariance has converged is advised. This is not currently implemented but could be considered in the future.
The best fit network parameter values are probably not the most representative set of parameters when simulating on-the-fly since there is a high chance of a statistically overly-informative set of data being generated. Instead, if using fit() consider using best=False, which sets self.w = self.final_w, i.e. the network parameter values obtained in the last iteration. Also consider using a larger patience value if using fit() to overcome the fact that a flukishly high value for the determinant might have been obtained due to the realisation of the dataset.

Due to something unusual that I can’t work out, there is a massive performance hit when calling jax.jit(self._fit) compared with directly decorating _fit with @partial(jax.jit, static_argnums=0). Unfortunately this means having to duplicate _fit to include a version where the loop condition is decorated with a progress bar, because the tqdm module cannot use a jitted tracer. If the progress bar is not used then the fully decorated jitted _fit function is used and it is super quick. Otherwise, just the body of the loop is jitted so that the condition function can be decorated by the progress bar (at the expense of a performance hit). I imagine that something can be improved here.

There is a chance of a RET_CHECK failure when using the progress bar on GPUs (this doesn’t seem to be a problem on CPUs). If this is the case then print_rate=None should be used.

_fit:
Main fitting function implemented as a jax.lax.while_loop

_fit_pbar:
Main fitting function as a jax.lax.while_loop with progress bar
- Raises
TypeError – If any input has the wrong type
ValueError – If any input (except rng and print_rate) is None
ValueError – If rng has the wrong shape
ValueError – If rng is None but simulating on-the-fly
ValueError – If calling fit with print_rate=None after a previous call with print_rate as an integer value
ValueError – If calling fit with print_rate as an integer after a previous call with print_rate=None
-
get_estimate
(d)¶ Calculate score compressed parameter estimates from network outputs
Using score compression we can get parameter estimates under the transformation
\[\hat{\boldsymbol{\theta}}_\alpha=\theta^{\rm{fid}}_\alpha+ \bf{F}^{-1}_{\alpha\beta}\frac{\partial\mu_i}{\partial \theta_\beta}\bf{C}^{-1}_{ij}(x(\bf{w}, \bf{d})-\mu)_j\]where \(x_j\) is the \(j\)-th output of the network with network parameters \(\bf{w}\) and input data \(\bf{d}\).
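The transformation can be written directly with jax.numpy. A sketch for a single summary vector, assuming the Fisher statistics (θ_fid, invF, invC, dμ_dθ and μ) have already been set, for example via set_F_statistics (the function name here is illustrative, not the class's API):

import jax.numpy as np

def score_compress(x, θ_fid, invF, invC, dμ_dθ, μ):
    # x: float(n_summaries,), the network output for the data being compressed
    score = np.einsum("ij,j->i", invC, x - μ)                   # C^{-1}(x - μ)
    return θ_fid + np.einsum("ab,ib,i->a", invF, dμ_dθ, score)  # quasi-ML estimate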
Examples
Assuming that an IMNN has been fit (as in the example in imnn.imnn._imnn.IMNN.fit()) then we can obtain a pseudo-maximum likelihood estimate of some target data (which is generated with parameter values μ=1, Σ=2) using

rng, target_key = jax.random.split(rng)
target_data = model_simulator(target_key, np.array([1., 2.]))

imnn.get_estimate(target_data)
>>> DeviceArray([0.1108716, 1.7881424], dtype=float32)
The one standard deviation uncertainty on these parameter estimates (assuming the fiducial is at the maximum-likelihood estimate - which we know it isn’t here) estimated by the square root of the inverse Fisher information matrix is
np.sqrt(np.diag(imnn.invF))
>>> DeviceArray([0.31980422, 0.47132865], dtype=float32)
Note that we can compare the values estimated by the IMNN to the value of the mean and the variance of the target data itself, which is what the IMNN should be summarising
np.mean(target_data)
>>> DeviceArray(0.10693721, dtype=float32)

np.var(target_data)
>>> DeviceArray(1.70872, dtype=float32)
Note that batches of data can be summarised at once using get_estimate. In this example we will draw 10 different values of μ from between \(-10 < \mu < 10\) and 10 different values of Σ from between \(0 < \Sigma < 10\) and generate a batch of 10 different input data which we can summarise using the IMNN.

rng, mean_keys, var_keys = jax.random.split(rng, num=3)

mean_vals = jax.random.uniform(
    mean_keys, minval=-10, maxval=10, shape=(10,))
var_vals = jax.random.uniform(
    var_keys, minval=0, maxval=10, shape=(10,))

np.stack([mean_vals, var_vals], -1)
>>> DeviceArray([[ 3.8727236,  1.6727388],
                 [-3.1113386,  8.14554  ],
                 [ 9.87299  ,  1.4134324],
                 [ 4.4837523,  1.5812075],
                 [-9.398947 ,  3.5737753],
                 [-2.0789695,  9.978279 ],
                 [-6.2622285,  6.828809 ],
                 [ 4.6470118,  6.0823894],
                 [ 5.7369494,  8.856505 ],
                 [ 4.248898 ,  5.114669 ]], dtype=float32)

batch_target_keys = np.array(jax.random.split(rng, num=10))
batch_target_data = jax.vmap(model_simulator)(
    batch_target_keys, (mean_vals, var_vals))

imnn.get_estimate(batch_target_data)
>>> DeviceArray([[ 4.6041985,  8.344688 ],
                 [-3.5172062,  7.7219954],
                 [13.229679 , 23.668312 ],
                 [ 5.745726 , 10.020965 ],
                 [-9.734651 , 21.076218 ],
                 [-1.8083427,  6.1901293],
                 [-8.626409 , 18.894459 ],
                 [ 5.7684307,  9.482665 ],
                 [ 6.7861238, 14.128591 ],
                 [ 4.900367 ,  9.472563 ]], dtype=float32)
- Parameters
d (float(None, input_shape)) – Input data to be compressed to score compressed parameter estimates
- Returns
Score compressed parameter estimates
- Return type
float(None, n_params)
-
single_element:
Returns a single score compressed summary
-
multiple_elements:
Returns a batch of score compressed summaries
- Raises
ValueError – If the Fisher statistics are not set after running fit or set_F_statistics.
-
get_summaries
(w=None, key=None, validate=False)¶ Gets all network outputs and derivatives wrt model parameters
- Parameters
w (list or None, default=None) – The network parameters if wanting to calculate the Fisher information with a specific set of network parameters
key (int(2,) or None, default=None) – A random number generator for generating simulations on-the-fly
validate (bool, default=False) – Whether to get summaries of the validation set
- Raises
ValueError –
-
get_α
(λ, ε)¶ Calculate rate parameter for regularisation from closeness criterion
- Parameters
λ (float) – Coupling strength of the regularisation
ϵ (float) – Closeness criterion describing how close to 1 the determinant of the covariance (and inverse covariance) of the network outputs is desired to be
- Returns
The steepness of the tanh-like function (or rate) which determines how fast the determinant of the covariance of the network outputs should be pushed to 1
- Return type
float
-
plot
(ax=None, expected_detF=None, colour='C0', figsize=(5, 15), label='', filename=None, ncol=1)¶ Plot fitting history
Plots a three-panel vertical figure with the determinant of the Fisher information matrix in the first subplot, the covariance and the inverse covariance in the second, and the regularisation term and the regularisation coupling strength in the final subplot.
A predefined axes object can be passed to fill, and these axes can be decorated via a call to _setup_plot (for horizontal plots, for example).

Example

Assuming that an IMNN has been fit (as in the example in imnn.imnn._imnn.IMNN.fit()) then we can make a training plot of the history by simply running

imnn.plot(expected_detF=50, filename="history_plot.png")

Note we know the analytic value of the determinant of the Fisher information for this problem (\(|\bf{F}|=50\)) so we can add this line to the plot too, and save the output as a png named history_plot.

- Parameters
ax (mpl.axes or None, default=None) – An axes object of predefined axes to be labelled
expected_detF (float or None, default=None) – Value of the expected determinant of the Fisher information to plot a horizontal line at to check fitting progress
colour (str or rgb/a value or list, default="C0") – Colour to plot the lines
figsize (tuple, default=(5, 15)) – The size of the figure to be produced
label (str, default="") – Name to add to description in legend
filename (str or None, default=None) – Filename to save plot to
ncol (int, default=1) – Number of columns to have in the legend
- Returns
An axes object of the filled plot
- Return type
mpl.axes
-
set_F_statistics
(w=None, key=None, validate=True)¶ Set necessary attributes for calculating score compressed summaries
- Parameters
w (list or None, default=None) – The network parameters if wanting to calculate the Fisher information with a specific set of network parameters
key (int(2,) or None, default=None) – A random number generator for generating simulations on-the-fly
validate (bool, default=True) – Whether to calculate Fisher information using the validation set
Aggregation of gradients¶
class imnn.imnn._aggregated_imnn._AggregatedIMNN(host, devices, n_per_device)¶

Manual aggregation of gradients for the IMNN parent class
This class defines the overriding fitting functions for _IMNN() which allow gradients to be aggregated manually. This is necessary if networks or input data are extremely large (or if the number of simulations necessary to estimate the covariance of network outputs, n_s, is very large), since all operations may not fit in memory.

The aggregation is done by calculating n_per_device network outputs at once on each available jax.device() and then scanning over all n_s inputs and n_d simulations necessary to calculate the derivative of the mean of the network outputs with respect to the model parameters. This gives a set of summaries and derivatives from which the loss function
\[\Lambda = -\log|\bf{F}| + r(\Lambda_2) \Lambda_2\](see IMNN: Information maximising neural networks) can be calculated, together with its gradient with respect to these summaries, \(\frac{\partial\Lambda}{\partial x_i^j}\), and derivatives, \(\frac{\partial\Lambda}{\partial\partial{x_i^j}/ \partial\theta_\alpha}\), where \(i\) labels the network output and \(j\) labels the simulation. Note that these are small in comparison to the gradient with respect to the network parameters, since their sizes are n_s * n_summaries and n_d * n_summaries * n_params respectively. Once \(\frac{\partial\Lambda}{\partial{x_i^j}}\) and \(\frac{\partial\Lambda}{\partial\partial{x_i^j}/\partial \theta_\alpha}\) are calculated, the gradients of the network outputs with respect to the network parameters, \(\frac{\partial{x_i^j}}{\partial{w_{ab}^l}}\) and \(\frac{\partial\partial{x_i^j}/\partial\theta_\alpha} {\partial{w_{ab}^l}}\), are calculated and the chain rule is used to get
\[\frac{\partial\Lambda}{\partial{w_{ab}^l}} = \frac{\partial \Lambda}{\partial{x_i^j}}\frac{\partial{x_i^j}} {\partial{w_{ab}^l}} + \frac{\partial\Lambda} {\partial\partial{x_i^j}/\partial\theta_\alpha} \frac{\partial\partial{x_i^j}/\partial\theta_\alpha} {\partial{w_{ab}^l}}\]Note that the memory use is kept low because only n_per_device simulations are handled at once before being summed into a single gradient list on each device. n_per_device should be as large as possible to get the best performance. If everything will fit in memory then this class should be avoided.

The _AggregatedIMNN class doesn’t directly inherit from _IMNN(), but is meant to be built within a child class of it. For this reason there are attributes which are not explicitly set here but are used within the module; these are noted in Other Parameters below.
- Parameters
host (jax.device) – The main device where the Fisher information calculation is performed
devices (list) – A list of the available jax devices (from jax.devices())
n_devices (int) – Number of devices to aggregate the calculation over
n_per_device (int) – Number of simulations to handle at once; this should be as large as possible without letting the memory overflow for the best performance
-
model:
Neural network as a function of network parameters and inputs
-
_get_parameters:
Function which extracts the network parameters from the state
-
_model_initialiser:
Function to initialise neural network weights from RNG and shape tuple
-
_opt_initialiser:
Function which generates the optimiser state from network parameters
-
_update:
Function which updates the state from a gradient
-
batch_summaries:
Jitted function to calculate summaries on each XLA device
-
batch_summaries_with_derivatives:
Jitted function to calculate summaries from derivative on each device
-
batch_gradients:
Jitted function to calculate gradient on each XLA device
-
batch_gradients_with_derivatives:
Jitted function to calculate the gradient from the derivative on each device
Public Methods:

__init__(host, devices, n_per_device) – Constructor method
fit(λ, ε[, rng, patience, min_iterations, …]) – Fitting routine for the IMNN
get_summaries(w[, key, validate]) – Gets all network outputs and derivatives wrt model parameters
get_gradient(dΛ_dx, w[, key]) – Aggregates gradients together to update the network parameters

Private Methods:

_set_devices(devices, n_per_device) – Checks that devices exist and that reshaping onto devices can occur
_set_batch_functions() – Creates jitted functions placed on desired XLA devices
_set_shapes() – Calculates the shapes for batching over different devices
_setup_progress_bar(print_rate, max_iterations) – Construct progress bar
_update_progress_bar(pbar, counter, …[, close]) – Updates (and closes) progress bar
_collect_input(key[, validate]) – Returns the dataset to be iterated over
_get_batch_summaries(inputs, w, θ[, …]) – Vectorised batch calculation of summaries or gradients
_split_dΛ_dx(dΛ_dx) – Separates dΛ_dx and d2Λ_dxdθ and reshapes them for aggregation
_construct_gradient(layers[, aux, func]) – Multiuse function to iterate over tuple of network parameters
-
_collect_input
(key, validate=False)¶ Returns the dataset to be iterated over
- Parameters
key (int(2,)) – A random number generator
validate (bool, default=False) – Whether to use the validation set or not
-
_construct_gradient
(layers, aux=None, func='zeros')¶ Multiuse function to iterate over tuple of network parameters
- The options are:
"zeros"
– to create an empty gradient array"einsum"
– to combine tuple ofdx_dw
withdΛ_dx
"derivative_einsum"
– to combine tuple ofd2x_dwdθ
withd2Λ_dxdθ
"sum"
– to reduce sum batches of gradients on the first axis
- Parameters
layers (tuple) – The tuple of tuples of arrays to be iterated over
aux (float(various shapes)) – parameter to pass dΛ_dx and d2Λ_dxdθ to einsum
func (str) –
- Option for the function to apply
"zeros"
– to create an empty gradient array"einsum"
– to combine tuple ofdx_dw
withdΛ_dx
"derivative_einsum"
– to combine tuple ofd2x_dwdθ
withd2Λ_dxdθ
"sum"
– to reduce sum batches of gradients on the first axis
- Returns
Tuple of objects like the gradient of the loss function with respect to the network parameters
- Return type
tuple
- Raises
ValueError – If applied function is not implemented
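For intuition, the "zeros" and "sum" options might look something like the following over a tuple (pytree) of parameter arrays; this is an illustrative sketch, not the class's implementation:

import jax
import jax.numpy as np

def construct_gradient(layers, func="zeros"):
    if func == "zeros":
        # Empty gradient container with the same structure as the parameters
        return jax.tree_util.tree_map(np.zeros_like, layers)
    if func == "sum":
        # Reduce-sum a batch of gradients over their first (batch) axis
        return jax.tree_util.tree_map(lambda x: np.sum(x, axis=0), layers)
    raise ValueError(f"{func} not implemented")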
-
_get_batch_summaries
(inputs, w, θ, gradient=False, derivative=False)¶ Vectorised batch calculation of summaries or gradients
- Parameters
inputs (tuple) –
dΛ_dx if gradient (float(n_per_device, n_summaries) or tuple)
  dΛ_dx (float(n_per_device, n_summaries)) – The gradient of the loss function with respect to network outputs
  d2Λ_dxdθ if derivative (float(n_per_device, n_summaries, n_params)) – The gradient of the loss function with respect to the derivative of network outputs with respect to model parameters
keys if SimulatorIMNN() (int(n_per_device, 2)) – The keys for generating simulations on-the-fly
d if NumericalGradientIMNN() (float(n_per_device, input_shape) or tuple)
  d (float(n_per_device, input_shape)) – The simulations to be evaluated
  dd_dθ if derivative (float(n_per_device, input_shape, n_params)) – The derivative of the simulations to be evaluated with respect to model parameters
w (list) – Network model parameters
θ (float(n_params,)) – The value of the model parameters to generate simulations at/to perform the derivative calculation
gradient (bool) – Whether to do the gradient calculation
derivative (bool, default=False) – Whether the gradient of loss function with respect to the derivative of the network outputs with respect to the model parameters is being used
- Returns
x if not
gradient
(float(n_per_device, n_summaries) or tuple)x (float(n_per_device, n_summaries)) – The network outputs
dd_dθ if
derivative
(float(n_per_device, n_summaries, n_params)) – The derivative of the network outputs with respect to model parameters
- if
gradient
(tuple) – The accumlated and aggregated gradient of the loss function with respect to the network parameters
- if
- Return type
float(n_devices, n_per_device, n_summaries) or tuple
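A minimal sketch of the vectorisation idea (not the library code): a single-simulation summary function is mapped over the n_per_device simulations held on a device using jax.vmap, with the network parameters broadcast; the stand-in model below is purely illustrative.

    import jax
    import jax.numpy as np

    def summary(w, d):
        # Illustrative stand-in for model(w, d): 10 data points -> 2 summaries
        return np.dot(d, w)

    def get_batch_summaries(inputs, w):
        # Vectorise the per-simulation evaluation over the leading
        # n_per_device axis, keeping the parameters w fixed
        return jax.vmap(lambda d: summary(w, d))(inputs)

    w = np.ones((10, 2))
    batch = jax.random.normal(jax.random.PRNGKey(0), (100, 10))
    x = get_batch_summaries(batch, w)  # shape (100, 2)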
- _set_batch_functions()¶
Creates jitted functions placed on the desired XLA devices
So that each set of summaries is calculated on the correct device, the jitted functions are predefined on each of these devices (a sketch follows this entry).
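For illustration only, a sketch of pre-placing jitted copies of a function on each device via jax.jit's device argument (newer JAX versions prefer explicit device placement of the data instead); the batch function here is a stand-in.

    import jax

    def batch_fn(x):
        # Stand-in for the per-device summary/gradient calculation
        return x * 2.

    # One jitted copy of the function pinned to each available device
    batch_fns = [jax.jit(batch_fn, device=device) for device in jax.devices()]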
- _set_devices(devices, n_per_device)¶
Checks that devices exist and that reshaping onto devices can occur
Because of the aggregation, the simulations must be split evenly across the different devices, so these splits are checked (a sketch of such checks follows this entry).
- Parameters
devices (list) – A list of the available jax devices (from jax.devices())
n_per_device (int) – Number of simulations to handle at once; this should be as large as possible without overflowing memory for the best performance
- Raises
ValueError – If devices or n_per_device are None
ValueError – If balanced splitting cannot be done
TypeError – If devices is not a list or if n_per_device is not an int
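A sketch of the kind of checks described above (the exact checks in the library may differ): the device list and n_per_device are type-checked and the number of simulations must divide evenly into n_devices × n_per_device chunks.

    import jax

    def set_devices(devices, n_per_device, n_s, n_d):
        if devices is None or n_per_device is None:
            raise ValueError("devices and n_per_device must not be None")
        if not isinstance(devices, list):
            raise TypeError("devices must be a list of jax devices")
        if not isinstance(n_per_device, int):
            raise TypeError("n_per_device must be an integer")
        n_devices = len(devices)
        # Balanced splitting: every pass must fill each device with exactly
        # n_per_device simulations
        if (n_s % (n_devices * n_per_device) != 0
                or n_d % (n_devices * n_per_device) != 0):
            raise ValueError(
                "n_s and n_d must be divisible by n_devices * n_per_device")
        return n_devices

    n_devices = set_devices([jax.devices()[0]], 100, n_s=1000, n_d=1000)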
- _set_shapes()¶
Calculates the shapes for batching over different devices
Not implemented
- Raises
ValueError – Not implemented in _AggregatedIMNN
- _setup_progress_bar(print_rate, max_iterations)¶
Constructs the progress bar (a tqdm sketch follows this entry)
- Parameters
print_rate (int or None) – The rate at which the progress bar is updated (no bar if None)
max_iterations (int) – The maximum number of iterations, used to set the upper limit of the bar
- Returns
progress bar or None – The TQDM progress bar object
int or None – The print rate (after checking for int or None)
int or None – The difference between max_iterations and the print rate
- Raises
TypeError – If print_rate is not an integer
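A minimal sketch of such a construction (not the library code), assuming a TQDM bar that is skipped entirely when print_rate is None; the third return value follows the description above.

    from tqdm import tqdm

    def setup_progress_bar(print_rate, max_iterations):
        # No progress bar at all when print_rate is None
        if print_rate is None:
            return None, None, None
        if not isinstance(print_rate, int):
            raise TypeError("print_rate must be an integer or None")
        pbar = tqdm(total=max_iterations, desc="Fitting")
        # Difference between max_iterations and the print rate, as described above
        remainder = max_iterations - print_rate
        return pbar, print_rate, remainder

    pbar, print_rate, remainder = setup_progress_bar(10, 1000)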
- _split_dΛ_dx(dΛ_dx)¶
Separates dΛ_dx and d2Λ_dxdθ and reshapes them for aggregation (a reshaping sketch follows this entry)
- Parameters
dΛ_dx (tuple) –
dΛ_dx (float(n_s, n_summaries)) – The derivative of the loss function wrt the network outputs
d2Λ_dxdθ (float(n_d, n_summaries, n_params)) – The derivative of the loss function wrt the derivative of the network outputs wrt the model parameters
- Raises
ValueError – function not implemented in parent class
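A sketch of the reshaping described (assumed layout, not the library code): the per-simulation loss gradients are split into per-device, per-batch chunks of n_per_device simulations ready for aggregation.

    import jax.numpy as np

    def split_dΛ_dx(dΛ_dx, d2Λ_dxdθ, n_devices, n_per_device):
        n_s, n_summaries = dΛ_dx.shape
        n_d, _, n_params = d2Λ_dxdθ.shape
        # Chunks of n_per_device simulations for each device and batch
        split = dΛ_dx.reshape(
            (n_devices, n_s // (n_devices * n_per_device),
             n_per_device, n_summaries))
        derivative_split = d2Λ_dxdθ.reshape(
            (n_devices, n_d // (n_devices * n_per_device),
             n_per_device, n_summaries, n_params))
        return split, derivative_split

    split, derivative_split = split_dΛ_dx(
        np.zeros((1000, 2)), np.zeros((1000, 2, 2)),
        n_devices=1, n_per_device=100)
    # split.shape == (1, 10, 100, 2)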
- _update_progress_bar(pbar, counter, patience_counter, max_detF, detF, detC, detinvC, Λ2, r, print_rate, max_iterations, remainder, close=False)¶
Updates (and closes) the progress bar
Checks whether a pbar is used and, if so, checks whether the iteration coincides with the print rate, is within the last print_rate iterations before the final iteration, or whether the last iteration has been reached and the bar should be closed (a tqdm sketch follows this entry).
- Parameters
pbar (progress bar object) – The TQDM progress bar
counter (int) – The value of the current iteration
patience_counter (int) – The number of iterations where the maximum of the determinant of the Fisher information matrix has not increased
max_detF (float) – Maximum of the determinant of the Fisher information matrix
detF (float(n_params, n_params)) – Fisher information matrix
detC (float(n_summaries, n_summaries)) – Covariance of the network summaries
detinvC (float(n_summaries, n_summaries)) – Inverse covariance of the network summaries
Λ2 (float) – Value of the regularisation term
r (float) – Value of the dynamic regularisation coupling strength
print_rate (int or None) – The number of iterations to run before updating the progress bar
max_iterations (int) – The maximum number of iterations to run
remainder (int or None) – The number of iterations before max_iterations to check progress
close (bool, default=False) – Whether to close the progress bar (on final iteration)
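As a rough illustration (not the library code) of how such a bar might be advanced and annotated with the tracked statistics, using tqdm's update and set_postfix; the statistics are passed as plain scalars here for display.

    from tqdm import tqdm

    def update_progress_bar(pbar, counter, detF, detC, detinvC, Λ2, r,
                            print_rate, close=False):
        if pbar is not None and counter % print_rate == 0:
            # Advance the bar and display the tracked statistics
            pbar.update(print_rate)
            pbar.set_postfix(detF=detF, detC=detC, detinvC=detinvC, Λ2=Λ2, r=r)
        if pbar is not None and close:
            pbar.close()

    pbar = tqdm(total=1000, desc="Fitting")
    update_progress_bar(pbar, counter=10, detF=1., detC=1., detinvC=1.,
                        Λ2=0.1, r=1., print_rate=10)
    pbar.close()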
- fit(λ, ε, rng=None, patience=100, min_iterations=100, max_iterations=100000, print_rate=None, best=True)¶
Fitting routine for the IMNN
- Parameters
λ (float) – Coupling strength of the regularisation
ϵ (float) – Closeness criterion describing how close to 1 the determinant of the covariance (and inverse covariance) of the network outputs is desired to be
rng (int(2,) or None, default=None) – Stateless random number generator
patience (int, default=100) – Number of iterations where there is no increase in the value of the determinant of the Fisher information matrix, used for early stopping
min_iterations (int, default=100) – Number of iterations that should be run before considering early stopping using the patience counter
max_iterations (int, default=int(1e5)) – Maximum number of iterations to run the fitting procedure for
print_rate (int or None, default=None) – Number of iterations before updating the progress bar whilst fitting. There is a performance hit from updating the progress bar more often and a large performance hit from using the progress bar at all (possible RET_CHECK failure if print_rate is not None when using GPUs). For this reason it is set to None by default.
best (bool, default=True) – Whether to set the network parameter attribute self.w to the parameter values that obtained the maximum determinant of the Fisher information matrix, or to the parameter values at the final iteration of fitting
Example
We are going to summarise the mean and variance of some random Gaussian noise with 10 data points per example using an AggregatedSimulatorIMNN. In this case the simulations are generated on-the-fly with a simulator written in jax (from the examples directory) and passed through the network on each of the GPUs in jax.devices("gpu"), with 100 simulations handled on each device at a time. The main computation will be done on the CPU. We will use 1000 simulations to estimate the covariance of the network outputs and the derivative of the mean of the network outputs with respect to the model parameters (Gaussian mean and variance) and generate the simulations at a fiducial μ=0 and Σ=1. The network will be a stax model with hidden layers of [128, 128, 128] activated with leaky relu and outputting 2 summaries. Optimisation will be via Adam with a step size of 1e-3. Rather arbitrarily we'll set the regularisation strength and covariance identity constraint to λ=10 and ϵ=0.1 (these are relatively unimportant for such an easy model).

    import jax
    import jax.numpy as np
    from jax.experimental import stax, optimizers
    from imnn import AggregatedSimulatorIMNN

    rng = jax.random.PRNGKey(0)

    n_s = 1000
    n_d = 1000
    n_params = 2
    n_summaries = 2
    input_shape = (10,)
    θ_fid = np.array([0., 1.])

    def simulator(rng, θ):
        return θ[0] + jax.random.normal(
            rng, shape=input_shape) * np.sqrt(θ[1])

    model = stax.serial(
        stax.Dense(128), stax.LeakyRelu,
        stax.Dense(128), stax.LeakyRelu,
        stax.Dense(128), stax.LeakyRelu,
        stax.Dense(n_summaries))

    optimiser = optimizers.adam(step_size=1e-3)

    λ = 10.
    ϵ = 0.1

    model_key, fit_key = jax.random.split(rng)

    host = jax.devices("cpu")[0]
    devices = jax.devices("gpu")
    n_per_device = 100

    imnn = AggregatedSimulatorIMNN(
        n_s=n_s, n_d=n_d, n_params=n_params,
        n_summaries=n_summaries, input_shape=input_shape,
        θ_fid=θ_fid, model=model, optimiser=optimiser,
        key_or_state=model_key, simulator=simulator,
        host=host, devices=devices, n_per_device=n_per_device)

    imnn.fit(λ, ϵ, rng=fit_key, min_iterations=1000,
             patience=250, print_rate=None)
Notes
A minimum number of iterations should be run before stopping based on the maximum determinant of the Fisher information achieved, since the loss function has dual objectives. Because the determinant of the covariance of the network outputs is forced to 1 quickly, this can be to the detriment of the value of the determinant of the Fisher information matrix early in the fitting procedure. For this reason, starting early stopping after the covariance has converged is advised. This is not currently implemented but could be considered in the future.
The best-fit network parameter values are probably not the most representative set of parameters when simulating on-the-fly since there is a high chance of a statistically overly-informative set of data being generated. Instead, if using fit() consider using best=False, which sets self.w=self.final_w, i.e. the network parameter values obtained in the last iteration. Also consider using a larger patience value if using fit() to overcome the fact that a flukishly high value of the determinant might have been obtained due to the realisation of the dataset.
- Raises
TypeError – If any input has the wrong type
ValueError – If any input (except rng) is None
ValueError – If rng has the wrong shape
ValueError – If rng is None but simulating on-the-fly
- get_keys_and_params: Jitted collection of parameters and random numbers
- calculate_loss: Returns the jitted gradient of the loss function wrt summaries
- validation_loss: Jitted loss and auxiliary statistics from the validation set
- get_gradient(dΛ_dx, w, key=None)¶
Aggregates gradients together to update the network parameters
To avoid calculating the gradient with respect to all the simulations at once, the gradient calculation is aggregated by addition: the simulations are looped over again and combined with the derivative of the loss function with respect to the network outputs (and with respect to their derivatives with respect to the model parameters), as in the sketch after this entry. Whilst this is expensive, it is necessary since we cannot accurately make a stochastic estimate of the Fisher information and therefore need to use all the available simulations, which are probably too large to fit in memory at once.
- Parameters
dΛ_dx (tuple) –
dΛ_dx (float(n_s, n_summaries)) – The derivative of the loss function with respect to the network summaries
d2Λ_dxdθ (float(n_d, n_summaries, n_params)) – The derivative of the loss function with respect to the derivative of the network summaries with respect to the model parameters
w (list) – Network parameters
key (None or int(2,)) – Random number generator used in SimulatorIMNN
- Returns
The gradient of the loss function with respect to the network parameters calculated by aggregating
- Return type
list
- Raises
ValueError – function not implemented in parent class
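To make the aggregation idea concrete, here is a minimal sketch under simplifying assumptions (a toy linear model, a single device, dΛ_dx of shape (n_s, n_summaries)): the simulations are looped over in batches, each batch's contribution to the parameter gradient is formed from dΛ_dx via a vector-Jacobian product, and the contributions are summed.

    import jax
    import jax.numpy as np

    def model(w, d):
        # Illustrative stand-in for the network: a single linear layer
        return np.dot(d, w)

    def get_gradient(dΛ_dx, w, data, n_per_device=100):
        gradient = np.zeros_like(w)
        for start in range(0, data.shape[0], n_per_device):
            batch = data[start:start + n_per_device]
            batch_dΛ_dx = dΛ_dx[start:start + n_per_device]
            # dΛ/dw for this batch from dΛ/dx, without forming dx/dw explicitly
            _, vjp = jax.vjp(lambda w_: model(w_, batch), w)
            gradient = gradient + vjp(batch_dΛ_dx)[0]
        return gradient

    w = np.ones((10, 2))
    data = jax.random.normal(jax.random.PRNGKey(0), (1000, 10))
    dΛ_dx = np.ones((1000, 2))
    grad = get_gradient(dΛ_dx, w, data)  # shape (10, 2)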
- get_summaries(w, key=None, validate=False)¶
Gets all network outputs and derivatives wrt model parameters (a sketch follows this entry)
- Parameters
w (list or None, default=None) – The network parameters if wanting to calculate the Fisher information with a specific set of network parameters
key (int(2,) or None, default=None) – A random number generator for generating simulations on-the-fly
validate (bool, default=False) – Whether to get summaries of the validation set
- Returns
float(n_s, n_summaries) – The network outputs
float(n_d, n_summaries, n_params) – The derivative of the network outputs wrt the model parameters
- Raises
ValueError – function not implemented in parent class
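For the on-the-fly (SimulatorIMNN-like) case, a minimal sketch of obtaining the network outputs and their derivatives with respect to the model parameters by differentiating through the simulator and a stand-in network with jax.jacrev; the simulator matches the fitting example above, while the linear network and shapes are assumptions.

    import jax
    import jax.numpy as np

    input_shape = (10,)

    def simulator(rng, θ):
        return θ[0] + jax.random.normal(rng, shape=input_shape) * np.sqrt(θ[1])

    def network(w, d):
        # Illustrative stand-in for model(w, d): 10 data points -> 2 summaries
        return np.dot(d, w)

    def get_summaries(w, key, θ_fid, n_s=1000):
        keys = np.array(jax.random.split(key, num=n_s))
        summary = lambda θ, k: network(w, simulator(k, θ))
        # Network outputs at the fiducial parameter values, one per simulation
        x = jax.vmap(lambda k: summary(θ_fid, k))(keys)
        # Derivatives of the outputs with respect to the model parameters
        dx_dθ = jax.vmap(lambda k: jax.jacrev(summary, argnums=0)(θ_fid, k))(keys)
        return x, dx_dθ

    w = np.ones((10, 2))
    x, dx_dθ = get_summaries(w, jax.random.PRNGKey(1), np.array([0., 1.]))
    # x.shape == (1000, 2); dx_dθ.shape == (1000, 2, 2)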