# Copyright (c) 2023 Graphcore Ltd. All rights reserved.

"""Classes for simulating (non-standard) number formats."""

from dataclasses import dataclass
from typing import Tuple, cast

import torch
from torch import Tensor

from ._internal_utils import generate__all__

Shape = Tuple[int, ...]
@dataclass
class FPFormat:
    """Generic representation of a floating-point number format."""

    exponent_bits: int
    mantissa_bits: int
    rounding: str = "stochastic"  # "stochastic" | "nearest"
    srbits: int = 0  # number of bits for stochastic rounding; zero => use all bits

    def __post_init__(self) -> None:
        assert self.exponent_bits >= 2, "FPFormat requires at least 2 exponent bits"
        assert (
            self.srbits == 0 or self.rounding == "stochastic"
        ), "Nonzero srbits for non-stochastic rounding"
        if self.srbits == 0 and self.rounding == "stochastic":
            self.srbits = 23 - self.mantissa_bits
    @property
    def bits(self) -> int:
        """The number of bits used by the format."""
        return 1 + self.exponent_bits + self.mantissa_bits

    def __str__(self) -> str:  # pragma: no cover
        return (
            f"E{self.exponent_bits}M{self.mantissa_bits}-"
            + dict(stochastic="SR", nearest="RN")[self.rounding]
        )

    @property
    def max_absolute_value(self) -> float:
        """The maximum absolute value representable by the format."""
        max_exponent = 2 ** (self.exponent_bits - 1) - 1
        return cast(float, 2**max_exponent * (2 - 2**-self.mantissa_bits))

    @property
    def min_absolute_normal(self) -> float:
        """The minimum absolute normal value representable by the format."""
        min_exponent = 1 - 2 ** (self.exponent_bits - 1)
        return cast(float, 2**min_exponent)

    @property
    def min_absolute_subnormal(self) -> float:
        """The minimum absolute subnormal value representable by the format."""
        return self.min_absolute_normal * 2.0**-self.mantissa_bits
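    # A hedged worked example (not part of the library source), following the
    # formulas above for an FP8 E4M3-style format:
    #
    #     fmt = FPFormat(exponent_bits=4, mantissa_bits=3)
    #     fmt.bits                    # 1 + 4 + 3 = 8
    #     str(fmt)                    # "E4M3-SR" (stochastic rounding by default)
    #     fmt.max_absolute_value      # 2**7 * (2 - 2**-3) = 240.0
    #     fmt.min_absolute_normal     # 2**(1 - 8) = 0.0078125
    #     fmt.min_absolute_subnormal  # 2**-7 * 2**-3 = 2**-10 ≈ 0.00098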
    def quantise(self, x: Tensor) -> Tensor:
        """Non-differentiably quantise the given tensor in this format."""
        absmax = self.max_absolute_value
        # Scale so that this format's exponent range lines up with float32's,
        # allowing the mantissa to be truncated by integer bit-masking below.
        downscale = 2.0 ** (127 - 2 ** (self.exponent_bits - 1))
        mask = torch.tensor(2 ** (23 - self.mantissa_bits) - 1, device=x.device)
        if self.rounding == "stochastic":
            srbitsbar = 23 - self.mantissa_bits - self.srbits
            offset = (
                torch.randint(
                    0, 2**self.srbits, x.shape, dtype=torch.int32, device=x.device
                )
                << srbitsbar
            )
            # Correct for bias. We can do this only for srbits < 23 - mantissa_bits,
            # but it is only likely to matter when srbits is small.
            if srbitsbar > 0:
                offset += 1 << (srbitsbar - 1)
        elif self.rounding == "nearest":
            offset = mask // 2
        else:  # pragma: no cover
            raise ValueError(
                f'Unexpected FPFormat(rounding="{self.rounding}"),'
                ' expected "stochastic" or "nearest"'
            )
        q = x.to(torch.float32)
        q = torch.clip(q, -absmax, absmax)
        q /= downscale
        # Round by adding the offset to the raw bits, then truncate the mantissa.
        q = ((q.view(torch.int32) + offset) & ~mask).view(torch.float32)
        q *= downscale
        return q.to(x.dtype)
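    # A hedged usage sketch (not part of the library source): round-to-nearest
    # quantisation of a tensor to an FP8 E4M3-style format.
    #
    #     fmt = FPFormat(exponent_bits=4, mantissa_bits=3, rounding="nearest")
    #     x = torch.randn(4)
    #     fmt.quantise(x)  # values snapped to the nearest representable number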
    def quantise_fwd(self, x: Tensor) -> Tensor:
        """Quantise the given tensor in the forward pass only."""

        class QuantiseForward(torch.autograd.Function):
            @staticmethod
            def forward(ctx: torch.autograd.function.FunctionCtx, x: Tensor) -> Tensor:
                return self.quantise(x)

            @staticmethod
            def backward(  # type:ignore[override]
                ctx: torch.autograd.function.FunctionCtx, grad_y: Tensor
            ) -> Tensor:
                return grad_y

        return QuantiseForward.apply(x)  # type: ignore
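    # A hedged sketch (not part of the library source): forward-only quantisation
    # behaves like a straight-through estimator, so gradients pass through unchanged.
    #
    #     fmt = FPFormat(4, 3)
    #     x = torch.randn(4, requires_grad=True)
    #     fmt.quantise_fwd(x).sum().backward()
    #     assert torch.equal(x.grad, torch.ones(4))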
    def quantise_bwd(self, x: Tensor) -> Tensor:
        """Quantise the given tensor in the backward pass only."""

        class QuantiseBackward(torch.autograd.Function):
            @staticmethod
            def forward(ctx: torch.autograd.function.FunctionCtx, x: Tensor) -> Tensor:
                return x

            @staticmethod
            def backward(  # type:ignore[override]
                ctx: torch.autograd.function.FunctionCtx, grad_y: Tensor
            ) -> Tensor:
                return self.quantise(grad_y)

        return QuantiseBackward.apply(x)  # type: ignore
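    # A hedged sketch (not part of the library source): backward-only quantisation
    # returns its input unchanged and instead quantises the incoming gradient.
    #
    #     fmt = FPFormat(4, 3, rounding="nearest")
    #     x = torch.randn(4, requires_grad=True)
    #     fmt.quantise_bwd(x).backward(torch.full((4,), 0.3))
    #     # x.grad holds 0.3 rounded to E4M3, i.e. 0.3125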
def format_to_tuple(format: FPFormat) -> Tuple[int, int]:
    """Convert the format into a tuple of `(exponent_bits, mantissa_bits)`."""
    return (format.exponent_bits, format.mantissa_bits)
def tuple_to_format(t: Tuple[int, int]) -> FPFormat:
    """Given a tuple of `(exponent_bits, mantissa_bits)`, return the corresponding
    :class:`FPFormat`."""
    return FPFormat(*t)
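# A hedged sketch (not part of the library source): the two helpers round-trip
# for formats using the default rounding settings (the rounding mode and srbits
# are not captured by the tuple form).
#
#     fmt = FPFormat(exponent_bits=5, mantissa_bits=2)
#     assert tuple_to_format(format_to_tuple(fmt)) == fmt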
__all__ = generate__all__(__name__)