Examples

Whilst there are many specific ways to use the IMNN, depending on the size of the dataset and the speed at which a simulation can be generated, the easiest way to interact with the IMNN is via IMNN(), described at the end of this page.

IMNN

imnn.IMNN(n_s, n_d, n_params, n_summaries, input_shape, θ_fid, model, optimiser, key_or_state, simulator=None, fiducial=None, derivative=None, main=None, remaining=None, δθ=None, validation_fiducial=None, validation_derivative=None, validation_main=None, validation_remaining=None, host=None, devices=None, n_per_device=None, cache=None, prefetch=None, verbose=True)

Selection function to return the correct subclass based on inputs

Because there are many different subclasses for working with specific types of simulations (or a simulator), differing in how their gradients are calculated, this function provides a way to return the desired one based on the data passed.
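The model and optimiser tuples described in the parameters below follow stax-like and optimizers-like conventions. A minimal self-contained sketch of the expected call signatures, using plain NumPy in place of jax, with all names (init_fn, apply_fn, the fixed learning rate, etc.) purely illustrative:

```python
import numpy as np

n_summaries = 2

# Sketch of the model tuple: (initialiser, network function).
# An integer seed stands in here for the int(2) jax PRNG key.
def init_fn(rng, input_shape):
    # fn(rng, input_shape) -> (output_shape, initial network parameters)
    generator = np.random.default_rng(rng)
    w = generator.normal(size=(input_shape[-1], n_summaries))
    b = np.zeros(n_summaries)
    return (n_summaries,), [w, b]

def apply_fn(w, d):
    # fn(w: list, d: float([None], input_shape)) -> float([None], n_summaries)
    weight, bias = w
    return d @ weight + bias

model = (init_fn, apply_fn)

# Sketch of the optimiser tuple: (state initialiser, update, parameter getter),
# here plain gradient descent with an arbitrary fixed learning rate
def opt_init(x0):
    return x0  # in this sketch the "state" is just the parameter list

def opt_update(i, g, state):
    return [p - 1e-3 * grad for p, grad in zip(state, g)]

def opt_get_params(state):
    return state

optimiser = (opt_init, opt_update, opt_get_params)
```

In practice these tuples would come from jax.experimental.stax and jax.experimental.optimizers (e.g. stax.serial(...) and optimizers.adam(...)), which follow the same (init, apply) and (init, update, get_params) patterns.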

Parameters
  • n_s (int) – Number of simulations used to calculate summary covariance

  • n_d (int) – Number of simulations used to calculate mean of summary derivative

  • n_params (int) – Number of model parameters

  • n_summaries (int) – Number of summaries, i.e. outputs of the network

  • input_shape (tuple) – The shape of a single input to the network

  • θ_fid (float(n_params,)) – The fiducial parameter values used to generate inputs

  • model (tuple, len=2) – Tuple containing a function to initialise the neural network fn(rng: int(2), input_shape: tuple) -> tuple, list and the neural network as a function of network parameters and inputs fn(w: list, d: float([None], input_shape)) -> float([None], n_summaries). (Essentially stax-like; see jax.experimental.stax)

  • optimiser (tuple, len=3) – Tuple containing functions to generate the optimiser state fn(x0: list) -> state, to update the state from a list of gradients fn(i: int, g: list, state: state) -> state, and to extract network parameters from the state fn(state: state) -> list. (See jax.experimental.optimizers)

  • key_or_state (int(2) or state) – Either a stateless random number generator key or the state object of a preinitialised optimiser

  • simulator (fn, optional requirement) – (SimulatorIMNN(), AggregatedSimulatorIMNN()) A function that generates a single simulation from a random number generator and a tuple (or array) of parameter values at which to generate the simulation. For use in LFI/ABC afterwards, it is also useful for the simulator to be able to broadcast to a batch of simulations on the zeroth axis, fn(int(2,), float([None], n_params)) -> float([None], input_shape)

  • fiducial (float or list, optional requirement) –

    The simulations generated at the fiducial model parameter values used for calculating the covariance of network outputs (for fitting)

    • (float(n_s, input_shape)) – GradientIMNN(), NumericalGradientIMNN(), AggregatedGradientIMNN(), AggregatedNumericalGradientIMNN()

    • (list of numpy iterators) – DatasetNumericalGradientIMNN()

  • derivative (float or list, optional requirement) –

    The simulations generated at parameter values perturbed from the fiducial used to calculate the numerical derivative of network outputs with respect to model parameters (for fitting)

    • (float(n_d, input_shape, n_params)) – GradientIMNN(), AggregatedGradientIMNN()

    • (float(n_d, 2, n_params, input_shape)) – NumericalGradientIMNN(), AggregatedNumericalGradientIMNN()

    • (list of numpy iterators) – DatasetNumericalGradientIMNN()

  • main (list of numpy iterators, optional requirement) – (DatasetGradientIMNN()) The simulations generated at the fiducial model parameter values used for calculating the covariance of network outputs and their derivatives with respect to the physical model parameters (for fitting). These are served n_per_device at a time as a numpy iterator from a TensorFlow dataset.

  • remaining (list of numpy iterators, optional requirement) – (DatasetGradientIMNN()) The n_s - n_d simulations generated at the fiducial model parameter values used for calculating the covariance of network outputs with a derivative counterpart (for fitting). These are served n_per_device at a time as a numpy iterator from a TensorFlow dataset.

  • δθ (float(n_params,), optional requirement) – (NumericalGradientIMNN(), AggregatedNumericalGradientIMNN(), DatasetNumericalGradientIMNN()) Size of perturbation to model parameters for the numerical derivative

  • validation_fiducial (float or list, optional requirement) –

    The simulations generated at the fiducial model parameter values used for calculating the covariance of network outputs (for validation)

    • (float(n_s, input_shape)) – GradientIMNN(), NumericalGradientIMNN(), AggregatedGradientIMNN(), AggregatedNumericalGradientIMNN()

    • (list of numpy iterators) – DatasetNumericalGradientIMNN()

  • validation_derivative (float or list, optional requirement) –

    The simulations generated at parameter values perturbed from the fiducial used to calculate the numerical derivative of network outputs with respect to model parameters (for validation)

    • (float(n_d, input_shape, n_params)) – GradientIMNN(), AggregatedGradientIMNN()

    • (float(n_d, 2, n_params, input_shape)) – NumericalGradientIMNN(), AggregatedNumericalGradientIMNN()

    • (list of numpy iterators) – DatasetNumericalGradientIMNN()

  • validation_main (list of numpy iterators, optional requirement) – (DatasetGradientIMNN()) The simulations generated at the fiducial model parameter values used for calculating the covariance of network outputs and their derivatives with respect to the physical model parameters (for validation). These are served n_per_device at a time as a numpy iterator from a TensorFlow dataset.

  • validation_remaining (list of numpy iterators, optional requirement) – (DatasetGradientIMNN()) The n_s - n_d simulations generated at the fiducial model parameter values used for calculating the covariance of network outputs with a derivative counterpart (for validation). These are served n_per_device at a time as a numpy iterator from a TensorFlow dataset.

  • host (jax.device, optional requirement) – (AggregatedSimulatorIMNN(), AggregatedGradientIMNN(), AggregatedNumericalGradientIMNN(), DatasetGradientIMNN(), DatasetNumericalGradientIMNN()) The main device where the Fisher calculation is performed (something like jax.devices("cpu")[0])

  • devices (list, optional requirement) – (AggregatedSimulatorIMNN(), AggregatedGradientIMNN(), AggregatedNumericalGradientIMNN(), DatasetGradientIMNN(), DatasetNumericalGradientIMNN()) A list of the available jax devices (from jax.devices())

  • n_per_device (int, optional requirement) – (AggregatedSimulatorIMNN(), AggregatedGradientIMNN(), AggregatedNumericalGradientIMNN(), DatasetGradientIMNN(), DatasetNumericalGradientIMNN()) Number of simulations to handle at once; this should be as large as possible without letting the memory overflow for the best performance

  • cache (bool, optional, default=None) – Whether to cache simulations in the TensorFlow datasets (can be used in AggregatedGradientIMNN() and AggregatedNumericalGradientIMNN())

  • prefetch (tf.data.AUTOTUNE or int or None, optional, default=None) – How many simulations to prefetch in the TensorFlow dataset (can be used in AggregatedGradientIMNN() and AggregatedNumericalGradientIMNN())
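As a concrete illustration of the shapes above, the following NumPy sketch builds fiducial and derivative arrays of the form expected by NumericalGradientIMNN(), together with a toy Gaussian simulator obeying the simulator contract. The seed-matching and the ±δθ/2 perturbation convention, like all names here, are illustrative assumptions rather than the library's prescriptions:

```python
import numpy as np

n_s = n_d = 100          # small values purely for illustration
n_params = 2
input_shape = (10,)
θ_fid = np.array([1.0, 0.5])   # fiducial (mean, standard deviation)
δθ = np.array([0.1, 0.1])      # perturbation sizes for numerical derivatives

# Toy simulator: fn(seed, θ) -> float(input_shape); an integer seed stands
# in for the int(2,) jax PRNG key
def simulator(seed, θ):
    mean, std = θ
    return np.random.default_rng(seed).normal(mean, std, size=input_shape)

# Fiducial simulations, float(n_s, input_shape)
fiducial = np.stack([simulator(i, θ_fid) for i in range(n_s)])

# Seed-matched perturbed simulations, float(n_d, 2, n_params, input_shape):
# axis 1 holds the lower/upper perturbations, assumed here to be θ_fid ± δθ/2
derivative = np.stack([
    np.stack([
        np.stack([simulator(i, θ_fid + sign * np.eye(n_params)[p] * δθ / 2)
                  for p in range(n_params)])
        for sign in (-1.0, 1.0)])
    for i in range(n_d)])
```

Arrays with these shapes (along with validation counterparts) and δθ would then be passed to IMNN() via the fiducial, derivative, validation_fiducial, validation_derivative and δθ keywords.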