besskge.utils.gather_indices
- besskge.utils.gather_indices(x, index)
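IPU-friendly gather function like torch.take_along_dim() for 2-dimensional tensors (indices along dim=1).
- Parameters:
  - x – 2-dimensional tensor with a rows to gather from.
  - index – integer tensor of shape (b, k) holding, for each row, the k column indices to gather.
- Return type:
  torch.Tensor
- Returns: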
shape: (b, k). Here x has a rows and index has b rows. For each row of x, take the k elements whose column indices are given by the corresponding row of index. If b == 1, the same indices are gathered from all rows of x; if a == 1, all rows in index gather from x[0]; otherwise a == b is required.
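The row-pairing and broadcasting rules above can be illustrated with a small plain-Python reference sketch. It is not the library's implementation (which works on tensors and is written to be IPU-friendly); the helper name `gather_indices_reference` and the list-of-lists representation are hypothetical, chosen only so the semantics can be checked without a tensor library. Note that when b == 1 and a > 1 the sketch returns one output row per row of x, so in general it produces max(a, b) rows.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def gather_indices_reference(
    x: Sequence[Sequence[float]], index: Sequence[Sequence[int]]
) -> Matrix:
    """Hypothetical pure-Python reference for the documented gather rules.

    x has a rows; index has b rows of k column indices each.
    Output row i holds the elements of one row of x picked by one row of index.
    """
    a, b = len(x), len(index)
    # Row counts must match unless one side has a single row to broadcast.
    if a != b and a != 1 and b != 1:
        raise ValueError(f"incompatible row counts: a={a}, b={b}")
    out: Matrix = []
    for i in range(max(a, b)):
        row = x[0] if a == 1 else x[i]          # a == 1: every index row reads x[0]
        idx = index[0] if b == 1 else index[i]  # b == 1: same indices for every row of x
        out.append([row[j] for j in idx])
    return out


x = [[10, 11, 12, 13],
     [20, 21, 22, 23]]
print(gather_indices_reference(x, [[3, 0], [1, 1]]))       # [[13, 10], [21, 21]]
print(gather_indices_reference(x, [[2, 1]]))               # [[12, 11], [22, 21]]
print(gather_indices_reference([[5, 6, 7]], [[0], [2]]))   # [[5], [7]]
```

The three calls exercise the a == b, b == 1 and a == 1 cases in turn; any other combination of row counts raises an error, matching the "otherwise a == b is required" condition.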