# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
"""
General purpose utilities.
"""
import torch
[docs]def gather_indices(x: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    """
    IPU-friendly gather function like :func:`torch.take_along_dim`
    for 2-dimensional tensors (indices along dim=1).
    :param x: shape: (a, e)
    :param index: shape: (b, k)
    :return: shape: (b, k)
        For all rows of :code:`x`, take the `k` elements on the row with the
        indices specified by the corresponding row of :code:`index`.
        If :code:`b == 1`, the same indices are gathered from all rows of
        :code:`x`; if :code:`a == 1`, all rows in :code:`index` gather from
        :code:`x[0]`; otherwise :code:`a == b` is required.
    """
    bs, sq = x.shape
    _, mask_size = index.shape
    index_flattened = (
        index
        + torch.arange(bs, dtype=torch.int32, device=index.device)
        .mul(torch.tensor(sq, dtype=torch.int32, device=index.device))
        .unsqueeze(1)
    ).view(-1)
    x = torch.index_select(x.view(-1), 0, index_flattened)
    return x.view(-1, mask_size) 
[docs]def get_entity_filter(
    triples: torch.Tensor, filter_triples: torch.Tensor, filter_mode: str
) -> torch.Tensor:
    """
    Compare two sets of triples: for each triple (h,r,t) in the first set, find
    the entities `e` such that (e,r,t) (or (h,r,e), depending on `filter_mode`)
    appears in the second set of triples.
    :param triples: shape (x, 3)
        The set of triples to construct filters for.
    :param filter_triples: shape (y, 3)
        The set of triples determining the head/tail entities to filter.
    :param filter_mode:
        Set to "h" to look for entities appearing as heads of the same (r,t) pair,
        or to "t" to look for entities appearing as tails of the same (h,r) pair.
    :return: shape (z, 2)
        The sparse filters. Each row is given by a tuple (i, j), with i the index
        of the triple in `triples` to which the filter applies to and j the global
        ID of the entity to filter.
    """
    if filter_mode == "t":
        ent_col = 0
    elif filter_mode == "h":
        ent_col = 2
    else:
        raise ValueError("`filter_mode` needs to be either 'h' or 't'")
    relation_filter = (filter_triples[:, 1]) == triples[:, 1].view(-1, 1)
    entity_filter = (filter_triples[:, ent_col]) == triples[:, ent_col].view(-1, 1)
    filter = (entity_filter & relation_filter).nonzero(as_tuple=False)
    filter[:, 1] = filter_triples[:, 2 - ent_col].view(1, -1)[:, filter[:, 1]]
    return filter 
[docs]def complex_multiplication(v1: torch.Tensor, v2: torch.Tensor) -> torch.Tensor:
    """
    Batched complex multiplication.
    :param v1: shape: (a, 2*e)
        :code:`v1[:,:e]` real part, :code:`v1[:,e:]` imaginary part.
    :param v2: shape: (a, 2*e)
        :code:`v2[:,:e]` real part, :code:`v2[:,e:]` imaginary part.
    :return: shape: (a, 2*e)
        Row-wise complex multiplication.
    """
    cutpoint = v1.shape[-1] // 2
    v1_re, v1_im = torch.split(v1, cutpoint, dim=-1)
    v2_re, v2_im = torch.split(v2, cutpoint, dim=-1)
    return torch.concat(
        [v1_re * v2_re - v1_im * v2_im, v1_re * v2_im + v1_im * v2_re], dim=-1
    ) 
[docs]def complex_rotation(v: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
    r"""
    Batched rotation by unitary tensors.
    :param v: shape: (a, 2*e)
        Complex tensor to rotate:
        :code:`v[:,:e]` real part, :code:`v[:,e:]` imaginary part.
    :param r: shape: (a, e)
        Rotate :code:`v[k]` by :math:`e^{i \pi r[k]}`
    :return: shape: (a, 2*e)
        Row-wise rotated tensors.
    """
    # Always compute sin and cos in fp16, as faster on IPU
    if r.dtype == torch.float32 and r.device.type == "ipu":
        r_cos = torch.cos(r.to(dtype=torch.float16)).to(dtype=torch.float32)
        r_sin = torch.sin(r.to(dtype=torch.float16)).to(dtype=torch.float32)
    else:
        r_cos = torch.cos(r)
        r_sin = torch.sin(r)
    r_complex = torch.concat([r_cos, r_sin], dim=-1)
    return complex_multiplication(v, r_complex)