3.6.5. unit_scaling.transforms.simulate_format
- unit_scaling.transforms.simulate_format(module: M, fwd_format: FPFormat, bwd_format: FPFormat) → M
[Experimental] Given a module, uses TorchDynamo to return a new module which simulates the effect of using the supplied formats for matmuls.
Specifically, before each torch.nn.functional.linear() and torch.nn.functional.scaled_dot_product_attention() call, a quantisation op is inserted which simulates the effect of using the supplied fwd_format. This op reduces the range of values to that of the given format, and (stochastically) rounds values to only those representable by the format.
The same is true for the backward pass, where an op is inserted to quantise to the bwd_format. Models which use modules that contain these functions internally (such as torch.nn.Linear) will be inspected by TorchDynamo and have the correct quantisation ops inserted.
If the equivalent unit-scaled functions from unit_scaling.functional are used in the module, these too will be quantised.
Simulation of formats is run in FP32. Users should not expect speedups from using this method. The purpose is to simulate the numerical effects of running matmuls in various formats.
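- Parameters:
module (M) – the module whose matmul inputs are to be quantised.
fwd_format (FPFormat) – the format simulated in the forward pass.
bwd_format (FPFormat) – the format simulated in the backward pass.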
- Returns:
a new module which, when used, will run using the simulated formats.
- Return type:
nn.Module
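The quantisation op described above (range reduction followed by stochastic rounding) can be illustrated with a small pure-Python sketch. This is not the library's implementation: ToyFormat and quantise are hypothetical names introduced here, the sketch operates on single Python floats rather than tensors, and it assumes an IEEE-like layout (top exponent code reserved, subnormals kept, no inf/nan handling), which may differ from what FPFormat does.

```python
import math
import random
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyFormat:
    """Hypothetical stand-in for a floating-point format (assumed IEEE-like)."""

    exponent_bits: int
    mantissa_bits: int

    @property
    def bias(self) -> int:
        return 2 ** (self.exponent_bits - 1) - 1

    @property
    def min_exponent(self) -> int:
        # Exponent of the smallest normal value; subnormals share its spacing.
        return 1 - self.bias

    @property
    def max_exponent(self) -> int:
        # Assumption: the top exponent code is reserved, as in IEEE 754.
        return 2 ** self.exponent_bits - 2 - self.bias

    @property
    def max_value(self) -> float:
        return (2 - 2.0 ** -self.mantissa_bits) * 2.0 ** self.max_exponent


def quantise(x: float, fmt: ToyFormat, rng: random.Random) -> float:
    """Clip x to the format's range, then stochastically round it to the grid."""
    # Step 1: reduce the range of values to that of the format.
    x = max(-fmt.max_value, min(fmt.max_value, x))
    if x == 0.0:
        return 0.0

    # Step 2: find the spacing between representable values near x.
    # math.frexp returns (m, e) with 0.5 <= |m| < 1, so floor(log2|x|) == e - 1.
    _, e = math.frexp(x)
    exponent = max(e - 1, fmt.min_exponent)
    step = 2.0 ** (exponent - fmt.mantissa_bits)

    # Step 3: stochastic rounding. Round up with probability equal to the
    # fractional position of x between its two neighbours on the grid.
    scaled = x / step
    lower = math.floor(scaled)
    round_up = rng.random() < (scaled - lower)
    return (lower + int(round_up)) * step


# Usage: a toy E5M2 format (5 exponent bits, 2 mantissa bits).
e5m2 = ToyFormat(exponent_bits=5, mantissa_bits=2)
rng = random.Random(0)
clipped = quantise(1e6, e5m2, rng)   # out of range, clipped to 57344.0
exact = quantise(1.0, e5m2, rng)     # already representable, stays 1.0
noisy = quantise(1.1, e5m2, rng)     # becomes either 1.0 or 1.25
```

Because each value rounds up with probability proportional to its distance from the lower neighbour, the quantised result is unbiased in expectation: averaging many draws of quantise(1.1, ...) approaches 1.1, even though each individual draw is 1.0 or 1.25.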