Examples
Whilst there are many specific ways to use the IMNN, depending on the size of the dataset and the speed at which a simulation can be generated, the easiest way to interact with the IMNN is via IMNN(), described at the end of this page.
IMNN
imnn.IMNN(n_s, n_d, n_params, n_summaries, input_shape, θ_fid, model, optimiser, key_or_state, simulator=None, fiducial=None, derivative=None, main=None, remaining=None, δθ=None, validation_fiducial=None, validation_derivative=None, validation_main=None, validation_remaining=None, host=None, devices=None, n_per_device=None, cache=None, prefetch=None, verbose=True)
Selection function to return the correct submodule based on inputs.
Because there are many different subclasses, each suited to a specific type of simulation (or a simulator) and to a particular way of calculating gradients, this function tries to return the appropriate one based on the data passed.
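The subclass annotations in the parameter list indicate which arguments each submodule consumes. The pure-Python sketch below is a hypothetical reconstruction of that selection logic, inferred only from those annotations; it is not the library's implementation, and the function name pick_imnn_subclass is invented for illustration.

```python
from typing import Any, Optional


def pick_imnn_subclass(
    simulator: Optional[Any] = None,
    fiducial: Optional[Any] = None,
    derivative: Optional[Any] = None,
    main: Optional[Any] = None,
    remaining: Optional[Any] = None,
    δθ: Optional[Any] = None,
    host: Optional[Any] = None,
    devices: Optional[Any] = None,
    n_per_device: Optional[int] = None,
) -> str:
    """Hypothetical dispatch: name the subclass suggested by the inputs."""
    # The Aggregated* variants are the ones that take host/devices/n_per_device.
    aggregated = (
        host is not None and devices is not None and n_per_device is not None
    )
    if simulator is not None:
        return "AggregatedSimulatorIMNN" if aggregated else "SimulatorIMNN"
    if main is not None and remaining is not None:
        # Only the dataset-backed autodiff variant takes main/remaining.
        return "DatasetGradientIMNN"
    if fiducial is not None and derivative is not None:
        if δθ is None:
            # Derivatives supplied directly: no finite differencing needed.
            return "AggregatedGradientIMNN" if aggregated else "GradientIMNN"
        if isinstance(fiducial, list):
            # A list of numpy iterators indicates the dataset-backed variant.
            return "DatasetNumericalGradientIMNN"
        return (
            "AggregatedNumericalGradientIMNN" if aggregated
            else "NumericalGradientIMNN"
        )
    raise ValueError("Cannot infer an IMNN subclass from the inputs given")
```

For example, passing arrays for fiducial and derivative together with δθ would select NumericalGradientIMNN in this sketch, while additionally passing host, devices and n_per_device would select AggregatedNumericalGradientIMNN.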
Parameters
n_s (int) – Number of simulations used to calculate summary covariance
n_d (int) – Number of simulations used to calculate mean of summary derivative
n_params (int) – Number of model parameters
n_summaries (int) – Number of summaries, i.e. outputs of the network
input_shape (tuple) – The shape of a single input to the network
θ_fid (float(n_params,)) – The fiducial parameter values used to generate inputs
model (tuple, len=2) – Tuple containing a function to initialise the neural network
    fn(rng: int(2), input_shape: tuple) -> tuple, list
and the neural network as a function of network parameters and inputs
    fn(w: list, d: float([None], input_shape)) -> float([None], n_summaries)
(Essentially stax-like, see jax.experimental.stax.)
optimiser (tuple, len=3) – Tuple containing functions to generate the optimiser state
    fn(x0: list) -> :obj:state
to update the state from a list of gradients
    fn(i: int, g: list, state: :obj:state) -> :obj:state
and to extract network parameters from the state
    fn(state: :obj:state) -> list
(See jax.experimental.optimizers.)
key_or_state (int(2) or :obj:state) – Either a stateless random number generator or the state object of a preinitialised optimiser
simulator (fn, optional requirement) – (SimulatorIMNN(), AggregatedSimulatorIMNN()) A function that generates a single simulation from a random number generator and a tuple (or array) of parameter values at which to generate the simulation. For later use in LFI/ABC it is also useful for the simulator to be able to broadcast to a batch of simulations on the zeroth axis:
    fn(int(2,), float([None], n_params)) -> float([None], input_shape)
fiducial (float or list, optional requirement) – The simulations generated at the fiducial model parameter values used for calculating the covariance of network outputs (for fitting)
    (float(n_s, input_shape)) – GradientIMNN(), NumericalGradientIMNN(), AggregatedGradientIMNN(), AggregatedNumericalGradientIMNN()
    (list of numpy iterators) – DatasetNumericalGradientIMNN()
derivative (float or list, optional requirement) – The simulations generated at parameter values perturbed from the fiducial used to calculate the numerical derivative of network outputs with respect to model parameters (for fitting)
    (float(n_d, input_shape, n_params)) – GradientIMNN(), AggregatedGradientIMNN()
    (float(n_d, 2, n_params, input_shape)) – NumericalGradientIMNN(), AggregatedNumericalGradientIMNN()
    (list of numpy iterators) – DatasetNumericalGradientIMNN()
main (list of numpy iterators, optional requirement) – (DatasetGradientIMNN()) The simulations generated at the fiducial model parameter values used for calculating the covariance of network outputs and their derivatives with respect to the physical model parameters (for fitting). These are served n_per_device at a time as a numpy iterator from a TensorFlow dataset.
remaining (list of numpy iterators, optional requirement) – (DatasetGradientIMNN()) The n_s - n_d simulations generated at the fiducial model parameter values that have no derivative counterpart, used for calculating the covariance of network outputs (for fitting). These are served n_per_device at a time as a numpy iterator from a TensorFlow dataset.
δθ (float(n_params,), optional requirement) – (NumericalGradientIMNN(), AggregatedNumericalGradientIMNN(), DatasetNumericalGradientIMNN()) Size of the perturbation to the model parameters for the numerical derivative
validation_fiducial (float or list, optional requirement) – The simulations generated at the fiducial model parameter values used for calculating the covariance of network outputs (for validation)
    (float(n_s, input_shape)) – GradientIMNN(), NumericalGradientIMNN(), AggregatedGradientIMNN(), AggregatedNumericalGradientIMNN()
    (list of numpy iterators) – DatasetNumericalGradientIMNN()
validation_derivative (float or list, optional requirement) – The simulations generated at parameter values perturbed from the fiducial used to calculate the numerical derivative of network outputs with respect to model parameters (for validation)
    (float(n_d, input_shape, n_params)) – GradientIMNN(), AggregatedGradientIMNN()
    (float(n_d, 2, n_params, input_shape)) – NumericalGradientIMNN(), AggregatedNumericalGradientIMNN()
    (list of numpy iterators) – DatasetNumericalGradientIMNN()
validation_main (list of numpy iterators, optional requirement) – (DatasetGradientIMNN()) The simulations generated at the fiducial model parameter values used for calculating the covariance of network outputs and their derivatives with respect to the physical model parameters (for validation). These are served n_per_device at a time as a numpy iterator from a TensorFlow dataset.
validation_remaining (list of numpy iterators, optional requirement) – (DatasetGradientIMNN()) The n_s - n_d simulations generated at the fiducial model parameter values that have no derivative counterpart, used for calculating the covariance of network outputs (for validation). These are served n_per_device at a time as a numpy iterator from a TensorFlow dataset.
host (jax.device, optional requirement) – (AggregatedSimulatorIMNN(), AggregatedGradientIMNN(), AggregatedNumericalGradientIMNN(), DatasetGradientIMNN(), DatasetNumericalGradientIMNN()) The main device where the Fisher calculation is performed (something like jax.devices("cpu")[0])
devices (list, optional requirement) – (AggregatedSimulatorIMNN(), AggregatedGradientIMNN(), AggregatedNumericalGradientIMNN(), DatasetGradientIMNN(), DatasetNumericalGradientIMNN()) A list of the available jax devices (from jax.devices())
n_per_device (int, optional requirement) – (AggregatedSimulatorIMNN(), AggregatedGradientIMNN(), AggregatedNumericalGradientIMNN(), DatasetGradientIMNN(), DatasetNumericalGradientIMNN()) Number of simulations to handle at once; for the best performance this should be as large as possible without overflowing memory
prefetch (tf.data.AUTOTUNE or int or None, optional, default=None) – How many simulations to prefetch in the TensorFlow dataset (could be used in AggregatedGradientIMNN() and AggregatedNumericalGradientIMNN())
cache (bool, optional, default=None) – Whether to cache simulations in the TensorFlow datasets (could be used in AggregatedGradientIMNN() and AggregatedNumericalGradientIMNN())
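The model and optimiser arguments expect stax-style and optimizers-style tuples. As a minimal sketch, assuming a single dense layer and plain gradient descent (both chosen only for illustration, with invented names init_fn, apply_fn, opt_init, opt_update and get_params), the tuples and the key could be built by hand as follows. In practice one would normally use jax.experimental.stax and jax.experimental.optimizers, which recent JAX releases provide as jax.example_libraries.stax and jax.example_libraries.optimizers.

```python
import jax
import jax.numpy as jnp

N_SUMMARIES = 2        # hypothetical number of network outputs
LEARNING_RATE = 1e-3   # hypothetical step size


def init_fn(rng, input_shape):
    """Stax-style initialiser: returns (output_shape, params)."""
    w = 0.1 * jax.random.normal(rng, (input_shape[-1], N_SUMMARIES))
    b = jnp.zeros((N_SUMMARIES,))
    return input_shape[:-1] + (N_SUMMARIES,), [w, b]


def apply_fn(params, inputs):
    """Network as a function of its parameters and a batch of inputs."""
    w, b = params
    return inputs @ w + b


def opt_init(x0):
    # For plain gradient descent the optimiser state is just the parameters.
    return list(x0)


def opt_update(i, g, state):
    # The step index i is unused by plain gradient descent.
    return [p - LEARNING_RATE * gp for p, gp in zip(state, g)]


def get_params(state):
    return state


model = (init_fn, apply_fn)
optimiser = (opt_init, opt_update, get_params)
key = jax.random.PRNGKey(0)  # two-element stateless key, the int(2) above

# Usage: initialise, wrap in an optimiser state, and evaluate the network.
output_shape, w = init_fn(key, (-1, 10))
state = opt_init(w)
summaries = apply_fn(get_params(state), jnp.ones((4, 10)))
```

Here a leading -1 in input_shape stands for the batch axis, following the stax convention, so the network maps a batch of shape (4, 10) to summaries of shape (4, n_summaries).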
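A simulator with the signature described for the simulator argument can be sketched as below. The Gaussian model, its (mean, variance) parameterisation and the ten-point input_shape are assumptions made only for illustration; the point is that the same function handles a single parameter vector and broadcasts over a batch of parameters on the zeroth axis.

```python
import jax
import jax.numpy as jnp

INPUT_SHAPE = (10,)  # hypothetical: ten data points per simulation


def simulator(key, θ):
    """Hypothetical simulator: Gaussian data with mean θ[..., 0] and
    variance θ[..., 1]. Accepts θ of shape (2,) or a batch of shape (n, 2)."""
    θ = jnp.asarray(θ)
    mean = θ[..., 0:1]              # keep a trailing axis so it broadcasts
    std = jnp.sqrt(θ[..., 1:2])
    shape = θ.shape[:-1] + INPUT_SHAPE
    return mean + std * jax.random.normal(key, shape)


key = jax.random.PRNGKey(0)
single = simulator(key, jnp.array([0.0, 1.0]))        # shape (10,)
batch = simulator(key, jnp.array([[0.0, 1.0]] * 5))   # shape (5, 10)
```

Supporting the batched call is not needed for fitting, but it makes the same function convenient for drawing many simulations at once in LFI/ABC afterwards.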
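The array shapes listed for fiducial, derivative and δθ can be produced from any single-simulation generator. The sketch below assumes a hypothetical Gaussian simulator, and it also assumes a layout for the numerical-derivative array that this page does not spell out: index 0 on the second axis holds simulations at θ_fid lowered by δθ/2 in one parameter, index 1 holds them raised by δθ/2, and each pair reuses the random key of the matching fiducial simulation (seed matching). Check the documentation of NumericalGradientIMNN before relying on that layout.

```python
import jax
import jax.numpy as jnp

n_s, n_d, n_params = 1000, 100, 2   # hypothetical sizes
input_shape = (10,)                 # hypothetical shape of one simulation
θ_fid = jnp.array([0.0, 1.0])       # hypothetical fiducial (mean, variance)
δθ = jnp.array([0.1, 0.1])          # hypothetical perturbation sizes


def simulator(key, θ):
    # Hypothetical Gaussian simulator: mean θ[0], variance θ[1].
    return θ[0] + jnp.sqrt(θ[1]) * jax.random.normal(key, input_shape)


keys = jax.random.split(jax.random.PRNGKey(1), n_s)

# fiducial: float(n_s, input_shape), one independent key per simulation.
fiducial = jax.vmap(simulator, in_axes=(0, None))(keys, θ_fid)

# Perturbed parameter sets, shape (2, n_params, n_params). Assumed layout:
# [0, j] is θ_fid with parameter j lowered by δθ[j] / 2,
# [1, j] is θ_fid with parameter j raised by δθ[j] / 2.
offsets = jnp.diag(δθ / 2.0)
θ_perturbed = jnp.stack([θ_fid - offsets, θ_fid + offsets])


def perturbed_sims(key):
    # Reuse one key for every perturbation (seed matching) so the
    # difference between the -/+ simulations comes from the parameters.
    return jax.vmap(jax.vmap(lambda θ: simulator(key, θ)))(θ_perturbed)


# derivative: float(n_d, 2, n_params, input_shape), seed-matched to the
# first n_d fiducial simulations.
derivative = jax.vmap(perturbed_sims)(keys[:n_d])

# Finite-difference derivative of each simulation with respect to θ.
d_sim_dθ = (derivative[:, 1] - derivative[:, 0]) / δθ[None, :, None]
```

Because the noise is shared between the lowered and raised simulations, it cancels in the finite difference; without seed matching the difference would be dominated by noise rather than by the parameter change.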
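The split between main and remaining, and the idea of serving simulations n_per_device at a time, can be illustrated with a small NumPy sketch. It assumes, for illustration only, that the first n_d fiducial simulations are the ones with derivative counterparts, and it uses a hypothetical generator, serve, as a stand-in for the numpy iterator of a repeated TensorFlow dataset (for example one obtained with tf.data.Dataset.as_numpy_iterator()). With n_s = 1000, n_d = 100 and n_per_device = 50, one pass takes 2 batches from main and 18 from remaining.

```python
import numpy as np

n_s, n_d = 1000, 100      # hypothetical simulation counts
n_per_device = 50         # hypothetical batch size per device
input_shape = (10,)

rng = np.random.default_rng(0)
fiducial = rng.normal(size=(n_s,) + input_shape)  # stand-in simulations

main_sims = fiducial[:n_d]        # the n_d simulations with derivatives
remaining_sims = fiducial[n_d:]   # the n_s - n_d simulations without


def serve(x, batch):
    """Yield `batch` simulations at a time, cycling forever like a
    repeated dataset iterator (hypothetical stand-in)."""
    assert len(x) % batch == 0, "count must divide evenly into batches"
    while True:
        for start in range(0, len(x), batch):
            yield x[start:start + batch]


main = serve(main_sims, n_per_device)
remaining = serve(remaining_sims, n_per_device)
```

In the real DatasetGradientIMNN the main iterator also has to provide the derivatives that go with each simulation; this sketch only shows how the counts divide up.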
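The host, devices and n_per_device arguments can be filled directly from JAX, as the parameter descriptions suggest. The sketch below does that and then makes a rough estimate of how many steps are needed to process all n_s simulations, under the assumption (not stated on this page) that every device handles n_per_device simulations in parallel at each step; the counts are hypothetical.

```python
import jax

n_s = 1000          # hypothetical number of fiducial simulations
n_per_device = 100  # hypothetical; as large as memory allows

host = jax.devices("cpu")[0]  # device where the Fisher calculation runs
devices = jax.devices()       # devices of the default backend

# Rough estimate, assuming every device processes n_per_device
# simulations in parallel at each step.
per_step = len(devices) * n_per_device
n_steps = -(-n_s // per_step)  # ceiling division
```

Raising n_per_device reduces the number of steps, which is why the description recommends making it as large as memory allows.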