3.6.5. unit_scaling.transforms.simulate_format

unit_scaling.transforms.simulate_format(module: M, fwd_format: FPFormat, bwd_format: FPFormat) → M [source]

[Experimental] Given a module, uses TorchDynamo to return a new module which simulates the effect of using the supplied formats for matmuls.

Specifically, before each torch.nn.functional.linear() and torch.nn.functional.scaled_dot_product_attention() call, a quantisation op is inserted which simulates the effect of using the supplied fwd_format. This op reduces the range of values to that representable by the format, and stochastically rounds each value to one of the format's representable values.

The same is true for the backward pass, where an op is inserted to quantise gradients to the bwd_format. Models which use modules that call these functions internally (such as torch.nn.Linear) will be inspected by TorchDynamo and have the appropriate quantisation ops inserted.

If the equivalent unit-scaled functions from unit_scaling.functional are used in the module, these too will be quantised.

Simulation of formats runs in FP32, so users should not expect speedups from this method. Its purpose is to model the numerical effects of running matmuls in various formats.
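To illustrate what the inserted quantisation op does, the sketch below implements range reduction plus stochastic rounding for a generic float format in plain Python. This is a hypothetical helper for exposition only, not part of unit_scaling's API: it takes exponent and mantissa bit widths directly, ignores subnormals and special values, and saturates out-of-range inputs rather than producing infinities.

```python
import math
import random

def quantise_stochastic(x: float, exponent_bits: int, mantissa_bits: int) -> float:
    """Illustrative sketch (not unit_scaling API): clamp x to the range of a
    float format with the given bit widths, then stochastically round it to a
    representable value. Subnormals and special values are ignored."""
    if x == 0.0:
        return 0.0
    # Largest magnitude representable by the format (naive formula, no
    # reserved exponent values): (2 - 2^-M) * 2^(2^(E-1) - 1).
    max_exp = 2 ** (exponent_bits - 1) - 1
    max_val = (2.0 - 2.0 ** -mantissa_bits) * 2.0 ** max_exp
    # Range reduction: saturate to the format's representable range.
    x = max(-max_val, min(max_val, x))
    # Rescale so representable values sit on an integer grid, then round
    # up with probability equal to the fractional part (stochastic rounding).
    exp = math.floor(math.log2(abs(x)))
    scale = 2.0 ** (exp - mantissa_bits)
    scaled = x / scale
    lo = math.floor(scaled)
    frac = scaled - lo
    rounded = lo + (1 if random.random() < frac else 0)
    return rounded * scale
```

For example, with 4 exponent and 3 mantissa bits, quantise_stochastic(1.1, 4, 3) returns either 1.0 or 1.125 (the two nearest representable values), with probability proportional to proximity; stochastic rounding keeps the quantisation unbiased in expectation, which is why it is used here rather than round-to-nearest.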

Parameters:
  • module (nn.Module) – the module to be quantised

  • fwd_format (FPFormat) – the quantisation format to be used in the forward pass (activations and weights)

  • bwd_format (FPFormat) – the quantisation format to be used in the backward pass (gradients of activations and weights)

Returns:

a new module which, when used, runs with the simulated formats.

Return type:

nn.Module