besskge.utils.gather_indices

besskge.utils.gather_indices(x, index)

IPU-friendly gather function, analogous to torch.take_along_dim() for 2-dimensional tensors (gathering indices along dim=1).
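One way to express a dim=1 gather without a general gather op is to offset each row's column indices by that row's start position in the flattened tensor, then do a single 1-D torch.index_select. The helper gather_rows_flat below is hypothetical: it is a sketch under that assumption, not besskge's implementation, and it only covers the case where x and index have the same number of rows (a == b). Its result matches torch.take_along_dim(x, index, dim=1).

```python
import torch


def gather_rows_flat(x: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: gather along dim=1 of a 2-D tensor via index_select.

    Assumes x has shape (n, e) and index has shape (n, k), i.e. a == b.
    """
    n_rows, row_len = x.shape
    _, k = index.shape
    # Row r starts at position r * row_len in the flattened x.
    offsets = torch.arange(n_rows, device=x.device).unsqueeze(1) * row_len
    flat_index = (index + offsets).reshape(-1)
    # One 1-D index_select on the flattened tensor, then restore shape (n, k).
    return torch.index_select(x.reshape(-1), 0, flat_index).reshape(n_rows, k)


x = torch.arange(12.0).reshape(3, 4)
index = torch.tensor([[0, 3], [2, 2], [1, 0]])
out = gather_rows_flat(x, index)  # values: [[0., 3.], [6., 6.], [9., 8.]]
```

The offset trick turns a 2-D gather into a plain 1-D lookup, which is the kind of operation that tends to map well onto accelerators; whether this matches besskge's internals is not stated in this page.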

Parameters:
  • x (Tensor) – shape: (a, e)

  • index (Tensor) – shape: (b, k)

Return type:

Tensor

Returns:

shape: (max(a, b), k), i.e. (b, k) whenever a == b or a == 1.
For each output row, take the k elements of the corresponding row of x at the column indices given by the corresponding row of index. If b == 1, the same indices are gathered from every row of x; if a == 1, every row of index gathers from x[0]; otherwise a == b is required.
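The row-matching rules above can be checked against a plain PyTorch reference. The function gather_indices_reference below is hypothetical (it is not part of besskge): it expands whichever side has a single row and then calls torch.gather along dim=1, so under these rules the result has max(a, b) rows. It is a sketch of the documented semantics, not of the IPU-friendly implementation.

```python
import torch


def gather_indices_reference(x: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference for the documented semantics (not besskge code).

    x: (a, e), index: (b, k); requires a == b, a == 1, or b == 1.
    """
    a, b = x.shape[0], index.shape[0]
    if not (a == b or a == 1 or b == 1):
        raise ValueError(f"incompatible row counts: a={a}, b={b}")
    rows = max(a, b)
    # expand() broadcasts a size-1 row dimension; it is a no-op otherwise.
    x_b = x.expand(rows, -1)          # (rows, e)
    index_b = index.expand(rows, -1)  # (rows, k)
    return torch.gather(x_b, 1, index_b)  # (rows, k)


# b == 1: the same column indices are gathered from every row of x.
x = torch.arange(12.0).reshape(3, 4)
same = torch.tensor([[3, 0]])
out_b1 = gather_indices_reference(x, same)  # values: [[3., 0.], [7., 4.], [11., 8.]]

# a == 1: every row of index gathers from x0[0].
x0 = torch.tensor([[10.0, 20.0, 30.0]])
idx = torch.tensor([[2], [0], [1], [2]])
out_a1 = gather_indices_reference(x0, idx)  # values: [[30.], [10.], [20.], [30.]]
```

torch.take_along_dim(x, index, dim=1) broadcasts the non-gathered dimension in the same way, so it gives the same values for all three valid cases.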