3.6.6. unit_scaling.transforms.simulate_fp8
- unit_scaling.transforms.simulate_fp8(module: M) → M
[Experimental] Given a module, uses TorchDynamo to return a new module which simulates the effect of running matmuls in FP8. As is standard in the literature (Noune et al., 2022; Micikevicius et al., 2022), we use the FP8 E4 format (E4M3) in the forward pass and FP8 E5 (E5M2) in the backward pass.
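To make the forward/backward split concrete, the sketch below computes the largest finite value, the smallest normal value and the relative precision of each format from its exponent and mantissa widths. It reads E4 as E4M3 and E5 as E5M2 and assumes IEEE-754-style conventions (the all-ones exponent is reserved for infinities and NaNs); the FloatFormat class is hypothetical and not part of unit_scaling. Under these assumptions E4M3 trades range for an extra mantissa bit, which suits weights and activations, while E5M2 keeps the wider range that gradients tend to need.

```python
from typing import NamedTuple


class FloatFormat(NamedTuple):
    """Hypothetical description of a small float format (not the unit_scaling API)."""

    exponent_bits: int
    mantissa_bits: int

    @property
    def bias(self) -> int:
        return 2 ** (self.exponent_bits - 1) - 1

    @property
    def max_value(self) -> float:
        # IEEE-style assumption: the all-ones exponent is reserved for inf/NaN,
        # so the largest usable biased exponent is 2**exponent_bits - 2.
        max_exponent = (2**self.exponent_bits - 2) - self.bias
        return (2 - 2.0**-self.mantissa_bits) * 2.0**max_exponent

    @property
    def min_normal(self) -> float:
        # Smallest positive normal value (biased exponent 1, zero mantissa).
        return 2.0 ** (1 - self.bias)

    @property
    def epsilon(self) -> float:
        # Gap between 1.0 and the next representable value.
        return 2.0**-self.mantissa_bits


E4M3 = FloatFormat(exponent_bits=4, mantissa_bits=3)  # forward pass
E5M2 = FloatFormat(exponent_bits=5, mantissa_bits=2)  # backward pass

for name, fmt in (("E4M3", E4M3), ("E5M2", E5M2)):
    print(f"{name}: max={fmt.max_value}, min_normal={fmt.min_normal}, eps={fmt.epsilon}")
# E4M3: max=240.0, min_normal=0.015625, eps=0.125
# E5M2: max=57344.0, min_normal=6.103515625e-05, eps=0.25
```

Under the OCP FP8 convention, E4M3 instead reclaims most of its top exponent and reaches 448; which convention unit_scaling follows is not stated on this page.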
Specifically, before each torch.nn.functional.linear() and torch.nn.functional.scaled_dot_product_attention() call, a quantisation op is inserted which simulates the effect of using FP8. This op clamps values to the range of the format and (stochastically) rounds them to values the format can represent. The same is done in the backward pass.
Models which use modules that contain these functions internally (such as torch.nn.Linear) will be inspected by TorchDynamo and have the correct quantisation ops inserted. If the equivalent unit-scaled functions from unit_scaling.functional are used in the module, these too will be quantised.
Simulation of formats is run in FP32, so users should not expect speedups from this method. Its purpose is to simulate the numerical effects of running matmuls in FP8.
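The quantisation op described above can be sketched in plain Python: clamp to the format's largest finite value, then round onto the grid of representable values, either to nearest or stochastically (rounding up with probability equal to the fractional part, which is unbiased in expectation). Emulating the format in a wider type (Python's float64 here, FP32 in the library) is what makes this possible without FP8 hardware. This is an illustrative sketch, not unit_scaling's implementation: the function names, the IEEE-style maximum values and the list-based linear layer are assumptions. The layer quantises its inputs to E4M3 in the forward pass and the incoming gradient to E5M2 in the backward pass, mirroring the split described in the text, and leaves accumulation in full precision.

```python
import math
import random
from typing import List, Optional, Tuple

# Hypothetical (exponent_bits, mantissa_bits) pairs; not the unit_scaling API.
E4M3 = (4, 3)  # forward pass
E5M2 = (5, 2)  # backward pass


def quantise(
    x: float,
    fmt: Tuple[int, int],
    stochastic: bool = False,
    rng: Optional[random.Random] = None,
) -> float:
    """Simulate casting x to a small float format; the result is a normal float."""
    exponent_bits, mantissa_bits = fmt
    bias = 2 ** (exponent_bits - 1) - 1
    # Assumes an IEEE-style reserved top exponent (E4M3 max 240, E5M2 max 57344).
    max_value = (2 - 2.0**-mantissa_bits) * 2.0 ** (2**exponent_bits - 2 - bias)
    x = max(-max_value, min(max_value, x))  # 1) reduce the range
    if x == 0.0:
        return 0.0
    # floor(log2|x|), but never below the smallest normal exponent, so that
    # subnormals share one fixed spacing.
    exponent = max(math.frexp(x)[1] - 1, 1 - bias)
    step = 2.0 ** (exponent - mantissa_bits)  # gap to the next representable value
    scaled = x / step
    if stochastic:
        # 2a) round up with probability equal to the fractional part (unbiased)
        rng = rng if rng is not None else random.Random()
        lower = math.floor(scaled)
        q = lower + (1 if rng.random() < scaled - lower else 0)
    else:
        # 2b) round to nearest, ties to even (Python's round)
        q = round(scaled)
    return q * step


def linear_forward(x: List[float], w: List[List[float]]) -> List[float]:
    """y = W x, with both matmul inputs quantised to E4M3 first."""
    xq = [quantise(v, E4M3) for v in x]
    wq = [[quantise(v, E4M3) for v in row] for row in w]
    # Accumulation is left in full precision.
    return [sum(wij * xj for wij, xj in zip(row, xq)) for row in wq]


def linear_backward(
    grad_y: List[float], x: List[float], w: List[List[float]]
) -> Tuple[List[float], List[List[float]]]:
    """Gradients of y = W x; the incoming gradient is quantised to E5M2."""
    gq = [quantise(g, E5M2) for g in grad_y]
    xq = [quantise(v, E4M3) for v in x]
    wq = [[quantise(v, E4M3) for v in row] for row in w]
    grad_x = [sum(gq[i] * wq[i][j] for i in range(len(gq))) for j in range(len(xq))]
    grad_w = [[gq[i] * xq[j] for j in range(len(xq))] for i in range(len(gq))]
    return grad_x, grad_w
```

For example, linear_forward([1.07, -0.1], [[1.0, 2.0], [0.5, 1000.0]]) returns [0.921875, -23.8125]: 1.07 rounds to 1.125, -0.1 to -0.1015625, and the weight 1000.0 is clamped to 240.0 before the matmul.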
- Parameters:
module (nn.Module) – the module to be quantised
- Returns:
a new module which, when used, runs its matmuls with inputs quantised to (simulated) FP8.
- Return type:
nn.Module