3.1.22.10. unit_scaling.functional.matmul

unit_scaling.functional.matmul(left: Tensor, right: Tensor, constraint: str | None = 'to_output_scale') → Tensor

A unit-scaled matrix product of two tensors.
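
As a rough illustration of what "unit-scaled" means here: if left and right hold independent unit-variance entries, each output element sums k unit-variance products over the inner dimension, so its variance grows like k. Dividing by sqrt(k) brings it back to about 1. The plain-Python sketch below checks this numerically, assuming standard-normal inputs. The helpers scaled_dot and output_std are hypothetical and not part of unit_scaling; the sketch illustrates the scaling principle rather than the library's implementation.

```python
import math
import random
import statistics


def scaled_dot(a, b):
    """Dot product divided by sqrt(k), where k is the length of the summed dimension."""
    return sum(x * y for x, y in zip(a, b)) / math.sqrt(len(a))


def output_std(k, scaled=True, trials=1000, seed=0):
    """Empirical std of a k-term dot product of independent standard-normal vectors."""
    rng = random.Random(seed)
    samples = []
    for _ in range(trials):
        a = [rng.gauss(0.0, 1.0) for _ in range(k)]
        b = [rng.gauss(0.0, 1.0) for _ in range(k)]
        samples.append(scaled_dot(a, b) if scaled else sum(x * y for x, y in zip(a, b)))
    return statistics.pstdev(samples)


for k in (16, 64, 256):
    # Unscaled std grows like sqrt(k); the 1/sqrt(k)-scaled std stays near 1.
    print(k, round(output_std(k, scaled=False), 2), round(output_std(k), 2))
```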

Matrix product of two tensors.

The behavior depends on the dimensionality of the tensors as follows:

  • If both tensors are 1-dimensional, the dot product (scalar) is returned.

  • If both arguments are 2-dimensional, the matrix-matrix product is returned.

  • If the first argument is 1-dimensional and the second argument is 2-dimensional, a 1 is prepended to its dimension for the purpose of the matrix multiply. After the matrix multiply, the prepended dimension is removed.

  • If the first argument is 2-dimensional and the second argument is 1-dimensional, the matrix-vector product is returned.

  • If both arguments are at least 1-dimensional and at least one argument is N-dimensional (where N > 2), then a batched matrix multiply is returned. If the first argument is 1-dimensional, a 1 is prepended to its dimension for the purpose of the batched matrix multiply and removed after. If the second argument is 1-dimensional, a 1 is appended to its dimension for the purpose of the batched matrix multiply and removed after. The non-matrix (i.e. batch) dimensions are broadcast (and thus must be broadcastable). For example, if left is a \((j \times 1 \times n \times n)\) tensor and right is a \((k \times n \times n)\) tensor, the output will be a \((j \times k \times n \times n)\) tensor.

    Note that the broadcasting logic only looks at the batch dimensions when determining if the inputs are broadcastable, and not the matrix dimensions. For example, if left is a \((j \times 1 \times n \times m)\) tensor and right is a \((k \times m \times p)\) tensor, these inputs are valid for broadcasting even though the final two dimensions (i.e. the matrix dimensions) are different. The output will be a \((j \times k \times n \times p)\) tensor.
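
The dimensionality rules above determine only the output shape, so they can be sketched without PyTorch. The pure-Python functions below (matmul_shape and broadcast_batch are hypothetical helpers, not part of unit_scaling or PyTorch) mirror the listed cases: 1-D arguments are promoted to matrices and the added dimension is removed afterwards, and only the batch dimensions take part in broadcasting.

```python
from itertools import zip_longest


def broadcast_batch(a, b):
    """Broadcast two batch shapes, aligning them from the right (NumPy-style)."""
    out = []
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and 1 not in (x, y):
            raise ValueError(f"batch shapes {a} and {b} are not broadcastable")
        out.append(y if x == 1 else x)
    return tuple(reversed(out))


def matmul_shape(left, right):
    """Output shape of matmul(left, right) for the dimensionality cases above."""
    if not left or not right:
        raise ValueError("both arguments must be at least 1-dimensional")
    # Promote 1-D arguments to matrices: prepend a 1 to left, append a 1 to right.
    left_is_vec, right_is_vec = len(left) == 1, len(right) == 1
    lhs = (1, *left) if left_is_vec else tuple(left)
    rhs = (*right, 1) if right_is_vec else tuple(right)
    (n, k), (k2, p) = lhs[-2:], rhs[-2:]
    if k != k2:
        raise ValueError(f"inner dimensions differ: {k} vs {k2}")
    # Only the batch (non-matrix) dimensions take part in broadcasting.
    out = broadcast_batch(lhs[:-2], rhs[:-2]) + (n, p)
    # Remove the dimensions that were added for 1-D arguments.
    if left_is_vec:
        out = out[:-2] + out[-1:]
    if right_is_vec:
        out = out[:-1]
    return out


print(matmul_shape((2, 1, 4, 5), (3, 5, 6)))  # (2, 3, 4, 6)
```

If PyTorch is installed, torch.matmul(torch.ones(2, 1, 4, 5), torch.ones(3, 5, 6)).shape should report the same result, torch.Size([2, 3, 4, 6]).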

This operation has support for arguments with sparse layouts. In particular, the matrix-matrix product (both arguments 2-dimensional) supports sparse arguments with the same restrictions as torch.mm().

Warning

Sparse support is a beta feature and some layout(s)/dtype/device combinations may not be supported, or may not have autograd support. If you notice missing functionality please open a feature request.

This operator supports TensorFloat32.

On certain ROCm devices, when using float16 inputs, this function will use a different precision for the backward pass.

Note

This function takes no out parameter. (In torch.matmul(), the 1-dimensional dot product version does not support out either.)

Parameters:
  • left (Tensor) – the first tensor to be multiplied

  • right (Tensor) – the second tensor to be multiplied

  • constraint (Optional[str]) – The name of the constraint function to be applied to the outputs and input gradients. It must be one of: [None, ‘gmean’, ‘hmean’, ‘amean’, ‘to_output_scale’, ‘to_left_grad_scale’, ‘to_right_grad_scale’] (see unit_scaling.constraints for details on these constraint functions). Defaults to ‘to_output_scale’.
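
The constraint options can be sketched as follows, assuming 2-D operands of shape (m, k) and (k, n) and assuming each pass is scaled by 1/sqrt of the dimension it sums over: the forward pass sums over k, the gradient for left (grad_out @ rightᵀ) sums over n, and the gradient for right (leftᵀ @ grad_out) sums over m. The helpers matmul_scales and constrained_scales and the CONSTRAINTS table are hypothetical stand-ins for the functions in unit_scaling.constraints. The sketch also assumes that a named constraint collapses the three scales into one shared value, while None leaves them independent.

```python
import math


def matmul_scales(left_shape, right_shape):
    """Assumed ideal scales for a 2-D matmul: 1/sqrt of the dimension each product sums over."""
    left_size, inner_size = left_shape[-2:]
    right_size = right_shape[-1]
    output_scale = inner_size ** -0.5     # forward: left @ right sums over the inner dim
    left_grad_scale = right_size ** -0.5  # grad_left = grad_out @ right.T sums over right_size
    right_grad_scale = left_size ** -0.5  # grad_right = left.T @ grad_out sums over left_size
    return output_scale, left_grad_scale, right_grad_scale


# Hypothetical constraint table: each entry maps the three scales to one shared scale.
CONSTRAINTS = {
    "gmean": lambda out, lg, rg: (out * lg * rg) ** (1 / 3),
    "hmean": lambda out, lg, rg: 3 / (1 / out + 1 / lg + 1 / rg),
    "amean": lambda out, lg, rg: (out + lg + rg) / 3,
    "to_output_scale": lambda out, lg, rg: out,
    "to_left_grad_scale": lambda out, lg, rg: lg,
    "to_right_grad_scale": lambda out, lg, rg: rg,
}


def constrained_scales(left_shape, right_shape, constraint="to_output_scale"):
    """Return (output, left-grad, right-grad) scales after applying a constraint."""
    scales = matmul_scales(left_shape, right_shape)
    if constraint is None:
        return scales  # unconstrained: each pass keeps its own ideal scale
    shared = CONSTRAINTS[constraint](*scales)
    return shared, shared, shared


for name in [None, *CONSTRAINTS]:
    print(name, [round(s, 4) for s in constrained_scales((32, 64), (64, 256), name)])
```

Sharing one scale has a concrete benefit: if the forward output is multiplied by α, the exact gradient with respect to either input is also multiplied by α, so equal forward and backward scales keep the backward pass a true gradient of the scaled forward function. With None, each pass gets its ideal scale at the cost of that exactness.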