# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
"""
Utilities for initializing and managing entity/relation embedding tables.
"""
from typing import Callable, List, Optional, Union
import numpy as np
import torch
from besskge.sharding import Sharding
def init_xavier_norm(embedding_table: torch.Tensor, gain: float = 1.0) -> torch.Tensor:
    """
    Initialize embeddings according to Xavier normal scheme, with
    `fan_in = 0`, `fan_out=row_size`.

    Values are drawn in-place from a normal distribution with mean 0 and
    standard deviation `gain * sqrt(2 / row_size)`, where `row_size` is the
    size of the last dimension of :attr:`embedding_table`.

    :param embedding_table:
        Tensor of embedding parameters to initialize (overwritten in-place).
    :param gain:
        Scaling factor for standard deviation. Default: 1.0.

    :return:
        Initialized tensor.
    """
    row_size = embedding_table.shape[-1]
    # Xavier std is gain * sqrt(2 / (fan_in + fan_out)) with fan_in = 0.
    std = gain * np.sqrt(2.0 / row_size)
    return torch.nn.init.normal_(embedding_table, std=std)
def init_KGE_normal(
    embedding_table: torch.Tensor,
    std: float = 1.0,
    divide_by_embedding_size: bool = True,
) -> torch.Tensor:
    """
    Initialize embeddings according to normal distribution with mean 0.

    The tensor is filled in-place. When :attr:`divide_by_embedding_size`
    is set, the effective standard deviation is `std / row_size`, where
    `row_size` is the size of the last dimension of :attr:`embedding_table`.

    :param embedding_table:
        Tensor of embedding parameters to initialize (overwritten in-place).
    :param std:
        Standard deviation. Default: 1.0.
    :param divide_by_embedding_size:
        Rescale standard deviation by `1/row_size`. Default: True.

    :return:
        Initialized tensor.
    """
    if divide_by_embedding_size:
        std = std / embedding_table.shape[-1]
    return torch.nn.init.normal_(embedding_table, std=std)
def initialize_entity_embedding(
    sharding: Sharding,
    initializer: Union[torch.Tensor, List[Callable[..., torch.Tensor]]],
    row_size: Optional[List[int]] = None,
) -> torch.nn.Parameter:
    """
    Initialize entity embedding table.

    :param sharding:
        Entity sharding.
    :param initializer:
        Embedding table or list of initializing functions. If providing
        an embedding table, this can either be sharded
        (shape: [n_shard, max_entity_per_shard, row_size])
        or unsharded (shape: [n_entity, row_size]).
        If providing list of initializers, this needs to be of same length
        as :attr:`row_size`.
    :param row_size:
        Number of parameters for each entity.
        This needs to be a list, with the lengths of the different embedding tensors
        to allocate for each entity. Each embedding tensor, once allocated, is
        initialized with the corresponding entry of :attr:`initializer`.
        Can be omitted if passing an embedding table as :attr:`initializer`.

    :raises ValueError:
        If the provided table is not 2- or 3-dimensional, if its shape is
        not compatible with :attr:`sharding`, if it is incompatible with
        :attr:`row_size`, if :attr:`row_size` is missing when passing
        initializing functions, or if the number of initializing functions
        differs from the length of :attr:`row_size`.

    :return: shape: (n_shard, max_ent_per_shard, row_size)
        Entity embedding table.
    """
    if isinstance(initializer, torch.Tensor):
        if initializer.dim() == 3:
            # Already sharded: only the two leading dimensions are constrained.
            expected_shape = torch.Size(
                [sharding.n_shard, sharding.max_entity_per_shard]
            )
            if initializer.size()[:2] != expected_shape:
                raise ValueError(
                    "Shape of sharded table provided for initialization"
                    " is not compatible with sharding"
                )
            entity_embedding = initializer.to(torch.float32)
        elif initializer.dim() == 2:
            if initializer.shape[0] != sharding.n_entity:
                raise ValueError(
                    "Number of rows of table provided for initialization"
                    " different from number of entities."
                )
            # Indices are clamped to the last valid entity, so any
            # out-of-range entry reuses the last entity's row
            # (presumably these are padding slots of shards holding fewer
            # than max_entity_per_shard entities -- confirm against Sharding).
            gather_idx = torch.from_numpy(
                np.minimum(sharding.shard_and_idx_to_entity, sharding.n_entity - 1)
            )
            entity_embedding = initializer[gather_idx].to(torch.float32)
        else:
            raise ValueError("Table for initialization needs to be 2- or 3-dimensional")

        # Explicit raise instead of `assert`, so that the check is not
        # stripped when running Python with -O.
        if row_size and sum(row_size) != entity_embedding.shape[-1]:
            raise ValueError(
                "Initialization tensor and row_size provided are incompatible"
            )
        return torch.nn.Parameter(entity_embedding)

    if not row_size:
        raise ValueError(
            "If not providing an embedding table, row_size needs to be specified"
        )
    if len(initializer) != len(row_size):
        raise ValueError(
            "Different number of embedding splits and initializers provided"
        )

    shard_shape = (sharding.n_shard, sharding.max_entity_per_shard)
    # Zero-width float32 seed: keeps the dtype of the concatenated table
    # at least float32, whatever dtype the initializers return.
    seed = torch.empty((*shard_shape, 0), dtype=torch.float32)
    table_slices = [
        init(torch.empty(size=(*shard_shape, slice_size), dtype=torch.float32))
        for slice_size, init in zip(row_size, initializer)
    ]
    entity_embedding = torch.concat([seed, *table_slices], dim=-1)
    return torch.nn.Parameter(entity_embedding)
def initialize_relation_embedding(
    n_relation_type: int,
    inverse_relations: bool,
    initializer: Union[torch.Tensor, List[Callable[..., torch.Tensor]]],
    row_size: Optional[List[int]] = None,
) -> torch.nn.Parameter:
    """
    Initialize relation embedding table.

    :param n_relation_type:
        Number of relation types.
    :param inverse_relations:
        If True, learn embeddings for inverse relations, in addition to direct ones.
        Needs to be set to `True` when inverse triples are added to the dataset.
        Given a relation with ID `i`, its inverse is the one with
        ID `i+n_relation_type`.
    :param initializer:
        Embedding table or list of initializing functions.
        If providing list of initializers, this needs to be of same length
        as :attr:`row_size`.
        An embedding table is used as provided (cast to float32):
        :attr:`n_relation_type` and :attr:`inverse_relations` only determine
        the number of rows when initializing functions are passed.
    :param row_size:
        Number of parameters for each relation type.
        This needs to be a list, with the lengths of the different embedding tensors
        to allocate for each relation. Each embedding tensor, once allocated, is
        initialized with the corresponding entry of :attr:`initializer`.
        Can be omitted if passing an embedding table as :attr:`initializer`.

    :raises ValueError:
        If the provided table is not 2-dimensional or is incompatible with
        :attr:`row_size`, if :attr:`row_size` is missing when passing
        initializing functions, or if the number of initializing functions
        differs from the length of :attr:`row_size`.

    :return:
        Relation embedding table. When built from initializing functions,
        it has `2 * n_relation_type` rows if :attr:`inverse_relations` is
        True, `n_relation_type` rows otherwise.
    """
    if isinstance(initializer, torch.Tensor):
        if initializer.dim() != 2:
            raise ValueError("Table for initialization needs to be 2-dimensional")
        relation_embedding = initializer.to(torch.float32)
        # Explicit raise instead of `assert`, so that the check is not
        # stripped when running Python with -O.
        if row_size and sum(row_size) != relation_embedding.shape[-1]:
            raise ValueError(
                "Initialization tensor and row_size provided are incompatible"
            )
        return torch.nn.Parameter(relation_embedding)

    if not row_size:
        raise ValueError(
            "If not providing an embedding table, row_size needs to be specified"
        )
    if len(initializer) != len(row_size):
        raise ValueError(
            "Different number of embedding splits and initializers provided"
        )

    # Inverse relations get their own rows, appended after the direct ones.
    n_rows = 2 * n_relation_type if inverse_relations else n_relation_type
    # Zero-width float32 seed: keeps the dtype of the concatenated table
    # at least float32, whatever dtype the initializers return.
    seed = torch.empty((n_rows, 0), dtype=torch.float32)
    table_slices = [
        init(torch.empty(size=(n_rows, slice_size), dtype=torch.float32))
        for slice_size, init in zip(row_size, initializer)
    ]
    relation_embedding = torch.concat([seed, *table_slices], dim=-1)
    return torch.nn.Parameter(relation_embedding)
def refactor_embedding_sharding(
    entity_embedding: torch.nn.Parameter,
    old_sharding: Sharding,
    new_sharding: Sharding,
) -> torch.nn.Parameter:
    """
    Refactor sharded entity embedding table to pass from
    one entity sharding to a different one.

    The table is first indexed with `old_sharding.entity_to_shard` and
    `old_sharding.entity_to_idx` to obtain an unsharded table, which is
    then passed to :func:`initialize_entity_embedding` together
    with `new_sharding`.

    :param entity_embedding: shape: (n_shard_old, max_ent_per_shard_old, row_size)
        Entity embedding table sharded according to `old_sharding`.
    :param old_sharding:
        The current entity sharding.
    :param new_sharding:
        The new entity sharding.

    :return: shape: (n_shard_new, max_ent_per_shard_new, row_size)
        The refactored entity embedding table, sharded according
        to `new_sharding`.
    """
    # Work on a detached view: the re-indexing below should not be
    # tracked by autograd.
    embedding_table = entity_embedding.detach()
    # Gather one row per entity (presumably in global entity ID order --
    # confirm against Sharding.entity_to_shard / Sharding.entity_to_idx).
    unsharded_table = embedding_table[
        old_sharding.entity_to_shard, old_sharding.entity_to_idx
    ]
    return initialize_entity_embedding(
        initializer=unsharded_table, sharding=new_sharding
    )