besskge.utils.gather_indices
- besskge.utils.gather_indices(x, index)
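IPU-friendly gather function, similar to torch.take_along_dim(), for 2-dimensional tensors (indices along dim=1).
- Parameters:
  - x: 2-dimensional tensor with a rows.
  - index: 2-dimensional integer tensor of shape (b, k), holding the column indices to gather from x.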
- Return type: torch.Tensor
- Returns:
  shape: (b, k). For every row of x, takes the k elements of that row at the indices given by the corresponding row of index. If b == 1, the same indices are gathered from every row of x, so the result has one row per row of x, i.e. shape (a, k). If a == 1, all rows of index gather from x[0]. Otherwise, a == b is required.
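The documented semantics can be reproduced with a single 1-D gather on the flattened tensor. The sketch below assumes that replacing a 2-D gather with `torch.index_select` on `x.reshape(-1)` is what makes such a function IPU-friendly. The names `gather_indices_sketch` and `reference` and the column count `n` are hypothetical, and the code is not the BESS-KGE implementation. It checks each of the three row-count cases against `torch.take_along_dim`, expanding both arguments to the common number of rows first.

```python
import torch


def gather_indices_sketch(x: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    """Hypothetical flat-index gather following the documented broadcasting rules.

    x: shape (a, n); index: integer tensor of shape (b, k), values in [0, n).
    Requires a == b, a == 1 or b == 1; the result has shape (max(a, b), k).
    """
    a, n = x.shape
    b, k = index.shape
    if not (a == b or a == 1 or b == 1):
        raise ValueError(f"incompatible row counts: a={a}, b={b}")
    rows = max(a, b)
    if a == 1:
        # Every row of index reads from x[0], so no row offset is needed.
        offsets = torch.zeros(rows, dtype=torch.long, device=index.device)
    else:
        # Result row i reads from row i of x, which starts at i * n once flattened.
        offsets = torch.arange(rows, device=index.device) * n
    # Broadcast (b, k) indices against (rows, 1) offsets -> (rows, k) flat positions.
    flat_index = (index.long() + offsets.unsqueeze(1)).reshape(-1)
    # One 1-D index_select on the flattened tensor replaces the 2-D gather.
    return torch.index_select(x.reshape(-1), 0, flat_index).reshape(rows, k)


def reference(x: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    # Expand both inputs to the common row count, then gather along dim=1.
    rows = max(x.shape[0], index.shape[0])
    return torch.take_along_dim(
        x.expand(rows, -1), index.long().expand(rows, -1), dim=1
    )


x = torch.arange(12.0).reshape(3, 4)  # a = 3 rows, n = 4 columns
cases = [
    (x, torch.tensor([[3, 0]])),                  # b == 1: same indices for every row
    (x, torch.tensor([[0, 1], [2, 3], [1, 1]])),  # a == b == 3: row-wise indices
    (x[:1], torch.tensor([[3, 2], [0, 0]])),      # a == 1: all rows gather from x[0]
]
for xs, idx in cases:
    out = gather_indices_sketch(xs, idx)
    assert torch.equal(out, reference(xs, idx))
    print(tuple(out.shape), out.tolist())
```

Using zero offsets in the `a == 1` case lets every index row read from the single stored row without materializing an expanded copy of `x`. This is a design choice of the sketch, not a claim about the library's internals.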