Module poptorch_experimental_addons.sharded
A collection of functions to support sharded matrix multiplications under a variety of different sharded tensor constraints.
Only 1-D tensor sharding is currently supported in poptorch.
Note: function names use the convention {in1}{in2}{out}_sharded_matmul, where each {…} can be:
- rep: replicated
- col: column-sharded
- row: row-sharded
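For intuition, the following is an illustrative sketch (plain CPU torch, no poptorch; names and the contiguous chunking are placeholders, not part of this module) of what row-sharded and column-sharded mean for a 2-D tensor split across replication_factor replicas:

import torch

replication_factor = 4
W = torch.randn(8, 12)  # full, unsharded tensor

# Row-sharded: each replica holds a slice of the rows, local shape (8 // 4, 12).
row_shards = torch.chunk(W, replication_factor, dim=0)

# Column-sharded: each replica holds a slice of the columns, local shape (8, 12 // 4).
col_shards = torch.chunk(W, replication_factor, dim=1)

assert row_shards[0].shape == (2, 12)
assert col_shards[0].shape == (8, 3)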
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
"""
A collection of functions to support sharded matrix multiplications under a variety
of different sharded tensor constraints.

Only 1-D tensor sharding is currently supported in poptorch.

Note: function names use the convention {in1}{in2}{out}_sharded_matmul where {...} can
evaluate to:
    - rep: replicated
    - col: column-sharded
    - row: row-sharded
"""

from typing import Any

import einops
import torch

from .collectives import all_gather_cross_replica, all_reduce_cross_replica_sum

def rowcolrow_sharded_matmul(
    X: torch.Tensor, Y: torch.Tensor, replication_factor: int, num_chunks: int = 1
) -> Any:
    """
    Matrix multiplication for row-sharded x column-sharded -> row-sharded tensors

    Gathers the right multiplicand across IPU program replicas
    """
    X_local_outer_dim, _ = X.shape
    _, Y_local_outer_dim = Y.shape
    result = torch.zeros((X_local_outer_dim, Y_local_outer_dim * replication_factor))
    # Split the contraction dimension into num_chunks pieces so that only one
    # chunk of Y is gathered across replicas at a time.
    X = einops.rearrange(X, "m (k c) -> m k c", k=num_chunks)
    Y = einops.rearrange(Y, "(k c) n -> c k n", k=num_chunks)
    for i in range(num_chunks):
        index = torch.tensor([i])
        Yg = all_gather_cross_replica(
            torch.index_select(Y, dim=1, index=index).squeeze(), replication_factor
        )
        Yg = einops.rearrange(Yg, "r c n -> c (r n)")
        Xp = torch.index_select(X, dim=1, index=index).squeeze()
        result += Xp @ Yg
    return result

def repcolcol_sharded_matmul(
    X: torch.Tensor, Y: torch.Tensor, replication_factor: int
) -> Any:
    """
    Matrix multiplication for replicated x column-sharded -> column-sharded tensors
    """
    X = all_reduce_cross_replica_sum(X, replication_factor, insert_in_grad_graph=True)
    return X @ Y

def colrowrep_sharded_matmul(
    X: torch.Tensor, Y: torch.Tensor, replication_factor: int
) -> Any:
    """Matrix multiplication for column-sharded x row-sharded -> replicated tensors"""
    out = X @ Y
    return all_reduce_cross_replica_sum(out, replication_factor)
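Typical usage composes these primitives Megatron-style inside a model that poptorch replicates across IPUs. The sketch below is hypothetical (the ShardedMLP class, its shapes and initialisation are illustrative and not part of this module; the cross-replica collectives only do real work when the model runs replicated under poptorch): the first projection keeps the activations column-sharded and the second returns a replicated output.

import torch

from poptorch_experimental_addons.sharded import (
    colrowrep_sharded_matmul,
    repcolcol_sharded_matmul,
)


class ShardedMLP(torch.nn.Module):
    """Hypothetical two-layer MLP with tensor-parallel weights."""

    def __init__(self, hidden: int, ffn: int, replication_factor: int):
        super().__init__()
        self.replication_factor = replication_factor
        # Each replica holds a column shard of W1 and the matching row shard of W2.
        self.W1 = torch.nn.Parameter(torch.randn(hidden, ffn // replication_factor))
        self.W2 = torch.nn.Parameter(torch.randn(ffn // replication_factor, hidden))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # replicated x column-sharded -> column-sharded
        h = repcolcol_sharded_matmul(x, self.W1, self.replication_factor)
        h = torch.relu(h)
        # column-sharded x row-sharded -> replicated
        return colrowrep_sharded_matmul(h, self.W2, self.replication_factor)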
Functions
def colrowrep_sharded_matmul(X: torch.Tensor, Y: torch.Tensor, replication_factor: int) -> Any
Matrix multiplication for column-sharded x row-sharded -> replicated tensors
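A CPU emulation of the arithmetic (the list comprehension and the plain Python sum stand in for the per-replica matmuls and the cross-replica all-reduce; this is not the IPU code path):

import torch

R = 4                      # replication_factor
M, K, N = 8, 16, 12
X, Y = torch.randn(M, K), torch.randn(K, N)

# Each replica holds a column shard of X and the matching row shard of Y;
# the local matmul yields a partial product and the cross-replica sum completes it.
partials = [
    X_r @ Y_r
    for X_r, Y_r in zip(torch.chunk(X, R, dim=1), torch.chunk(Y, R, dim=0))
]
torch.testing.assert_close(sum(partials), X @ Y)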
def repcolcol_sharded_matmul(X: torch.Tensor, Y: torch.Tensor, replication_factor: int) -> Any
Matrix multiplication for replicated x column-sharded -> column-sharded tensors
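A CPU emulation of the forward arithmetic (the all_reduce_cross_replica_sum call with insert_in_grad_graph=True appears to target the gradient graph, so the forward pass is a plain local matmul against the local column shard; this sketch covers only the forward maths):

import torch

R = 4
M, K, N = 8, 16, 12
X, Y = torch.randn(M, K), torch.randn(K, N)  # X is replicated: every replica sees the same X

# Each replica multiplies the shared X by its own column shard of Y,
# producing the matching column shard of the output.
out_shards = [X @ Y_r for Y_r in torch.chunk(Y, R, dim=1)]
torch.testing.assert_close(torch.cat(out_shards, dim=1), X @ Y)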
def rowcolrow_sharded_matmul(X: torch.Tensor, Y: torch.Tensor, replication_factor: int, num_chunks: int = 1) -> Any
Matrix multiplication for row-sharded x column-sharded -> row-sharded tensors
Gathers the right multiplicand across IPU program replicas
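A CPU emulation of the arithmetic for num_chunks=1 (torch.cat stands in for all_gather_cross_replica; the chunked loop in the real implementation gathers Y one contraction-chunk at a time and accumulates, which does not change the result):

import torch

R = 4                      # replication_factor
M, K, N = 8, 16, 12
X, Y = torch.randn(M, K), torch.randn(K, N)

X_shards = torch.chunk(X, R, dim=0)   # row-sharded left operand
Y_shards = torch.chunk(Y, R, dim=1)   # column-sharded right operand

# Every replica gathers all column shards of Y, then multiplies by its own
# row shard of X; concatenating the per-replica results recovers the full product.
Y_gathered = torch.cat(Y_shards, dim=1)
out_shards = [X_r @ Y_gathered for X_r in X_shards]
torch.testing.assert_close(torch.cat(out_shards, dim=0), X @ Y)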