# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
"""
General purpose utilities.
"""
import torch
def gather_indices(x: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    """
    IPU-friendly gather function like :func:`torch.take_along_dim`
    for 2-dimensional tensors (indices along dim=1).

    :param x: shape: (a, e)
    :param index: shape: (b, k)

    :return: shape: (max(a, b), k)
        For all rows of :code:`x`, take the `k` elements on the row with the
        indices specified by the corresponding row of :code:`index`.
        If :code:`b == 1`, the same indices are gathered from all rows of
        :code:`x`; if :code:`a == 1`, all rows in :code:`index` gather from
        :code:`x[0]`; otherwise :code:`a == b` is required.
        Non-contiguous inputs and :code:`k == 0` are supported.
    """
    n_rows, row_len = x.shape
    # Flat position of the first element of each row of x, shape (a, 1),
    # built in int32.
    # NOTE(review): arange * row_len is computed in int32 and would overflow
    # if a * e >= 2**31.
    row_offsets = (
        torch.arange(n_rows, dtype=torch.int32, device=index.device)
        .mul(torch.tensor(row_len, dtype=torch.int32, device=index.device))
        .unsqueeze(1)
    )
    # Broadcasting (b, k) + (a, 1) gives the flat positions, shape (max(a, b), k).
    flat_index = index + row_offsets
    # reshape instead of view: x, and the broadcast sum (when index is
    # non-contiguous), need not be contiguous, and view(-1) would raise.
    gathered = torch.index_select(x.reshape(-1), 0, flat_index.reshape(-1))
    # Use the explicit broadcast shape; view(-1, k) is ambiguous when k == 0.
    return gathered.view(flat_index.shape)
def get_entity_filter(
    triples: torch.Tensor, filter_triples: torch.Tensor, filter_mode: str
) -> torch.Tensor:
    """
    Compare two sets of triples: for each triple (h,r,t) in the first set, find
    the entities `e` such that (e,r,t) (or (h,r,e), depending on `filter_mode`)
    appears in the second set of triples.

    :param triples: shape (x, 3)
        The set of triples to construct filters for.
    :param filter_triples: shape (y, 3)
        The set of triples determining the head/tail entities to filter.
    :param filter_mode:
        Set to "h" to look for entities appearing as heads of the same (r,t) pair,
        or to "t" to look for entities appearing as tails of the same (h,r) pair.

    :return: shape (z, 2)
        The sparse filters. Each row is given by a tuple (i, j), with i the index
        of the triple in `triples` to which the filter applies to and j the global
        ID of the entity to filter.
        Rows are ordered by i, then by position of the match in `filter_triples`.

    :raises ValueError: if `filter_mode` is neither "h" nor "t".
    """
    if filter_mode not in ("h", "t"):
        raise ValueError("`filter_mode` needs to be either 'h' or 't'")
    # Column that must agree (together with the relation) between the two sets:
    # the head when filtering tails, the tail when filtering heads.
    anchor_col = 0 if filter_mode == "t" else 2
    # Column holding the entity IDs to report (the opposite end of the triple).
    target_col = 2 - anchor_col

    # (x, y) boolean match matrices: row i of `triples` vs row j of `filter_triples`.
    same_relation = triples[:, 1].unsqueeze(1) == filter_triples[:, 1]
    same_anchor = triples[:, anchor_col].unsqueeze(1) == filter_triples[:, anchor_col]
    matches = torch.nonzero(same_relation & same_anchor)

    # Replace the matching filter-triple index j with that triple's entity ID.
    matches[:, 1] = filter_triples[matches[:, 1], target_col]
    return matches
def complex_multiplication(v1: torch.Tensor, v2: torch.Tensor) -> torch.Tensor:
    """
    Batched complex multiplication.

    :param v1: shape: (a, 2*e)
        :code:`v1[:,:e]` real part, :code:`v1[:,e:]` imaginary part.
    :param v2: shape: (a, 2*e)
        :code:`v2[:,:e]` real part, :code:`v2[:,e:]` imaginary part.

    :return: shape: (a, 2*e)
        Row-wise complex multiplication.
        The split is taken along the last dimension, so arbitrary leading
        dimensions (broadcast between :code:`v1` and :code:`v2`) also work.

    :raises ValueError: if the last dimension of :code:`v1` is odd, or if
        :code:`v1` and :code:`v2` have different last dimensions.
    """
    width = v1.shape[-1]
    if width % 2 != 0:
        raise ValueError(
            f"complex_multiplication: last dimension must be even (2*e), got {width}"
        )
    # Without this check, torch.split(v2, width // 2) can produce two unequal
    # chunks that broadcast silently against v1's halves, giving wrong results.
    if v2.shape[-1] != width:
        raise ValueError(
            "complex_multiplication: v1 and v2 must have the same last dimension, "
            f"got {width} and {v2.shape[-1]}"
        )
    cutpoint = width // 2
    v1_re, v1_im = torch.split(v1, cutpoint, dim=-1)
    v2_re, v2_im = torch.split(v2, cutpoint, dim=-1)
    # (a + ib)(c + id) = (ac - bd) + i(ad + bc)
    return torch.concat(
        [v1_re * v2_re - v1_im * v2_im, v1_re * v2_im + v1_im * v2_re], dim=-1
    )
def complex_rotation(v: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
    r"""
    Batched rotation by unitary tensors.

    :param v: shape: (a, 2*e)
        Complex tensor to rotate:
        :code:`v[:,:e]` real part, :code:`v[:,e:]` imaginary part.
    :param r: shape: (a, e)
        Rotation angles in radians: :code:`v[k]` is multiplied elementwise by
        :math:`e^{i r[k]} = \cos(r[k]) + i \sin(r[k])`.
        Note that no factor of :math:`\pi` is applied to :code:`r`.

    :return: shape: (a, 2*e)
        Row-wise rotated tensors.
    """
    # For fp32 angles on IPU, evaluate cos/sin in fp16 (faster on IPU) and cast
    # the results back to fp32. This trades precision for speed, particularly
    # for angles of large magnitude. Other dtypes/devices use r's own precision.
    if r.dtype == torch.float32 and r.device.type == "ipu":
        r_cos = torch.cos(r.to(dtype=torch.float16)).to(dtype=torch.float32)
        r_sin = torch.sin(r.to(dtype=torch.float16)).to(dtype=torch.float32)
    else:
        r_cos = torch.cos(r)
        r_sin = torch.sin(r)
    # Pack the unit complex numbers cos(r) + i sin(r) in the same
    # [real | imaginary] layout as v, then multiply.
    r_complex = torch.concat([r_cos, r_sin], dim=-1)
    return complex_multiplication(v, r_complex)