Source code for besskge.bess

# Copyright (c) 2023 Graphcore Ltd. All rights reserved.

"""
PyTorch modules implementing the BESS distribution scheme :cite:p:`BESS`
for KGE training and inference on multiple IPUs.
"""

from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Tuple, Union, cast

import numpy as np
import poptorch
import torch
from poptorch_experimental_addons.collectives import (
    all_gather_cross_replica as all_gather,
)
from poptorch_experimental_addons.collectives import (
    all_to_all_single_cross_replica as all_to_all,
)

from besskge.loss import BaseLossFunction
from besskge.metric import Evaluation
from besskge.negative_sampler import (
    PlaceholderNegativeSampler,
    ShardedNegativeSampler,
    TripleBasedShardedNegativeSampler,
)
from besskge.scoring import BaseScoreFunction
from besskge.utils import gather_indices

BAD_NEGATIVE_SCORE = -50000.0


class BessKGE(torch.nn.Module, ABC):
    """
    Base class for distributed training and inference of KGE models,
    using the distribution framework BESS :cite:p:`BESS`.

    To be used in combination with a batch sampler based on a
    "ht_shardpair"-partitioned triple set.
    """

    def __init__(
        self,
        negative_sampler: ShardedNegativeSampler,
        score_fn: BaseScoreFunction,
        loss_fn: Optional[BaseLossFunction] = None,
        evaluation: Optional[Evaluation] = None,
        return_scores: bool = False,
        augment_negative: bool = False,
    ) -> None:
        """
        Initialize BESS-KGE module.

        :param negative_sampler:
            Sampler of negative entities.
        :param score_fn:
            Scoring function.
        :param loss_fn:
            Loss function, required when training. Default: None.
        :param evaluation:
            Evaluation module, for computing metrics on device.
            Default: None.
        :param return_scores:
            If True, return positive and negative scores of batches
            to the host. Default: False.
        :param augment_negative:
            If True, augment sampled negative entities with the heads/tails
            (according to the corruption scheme) of other positive triples
            in the micro-batch. Default: False.
        """
        super().__init__()
        self.sharding = score_fn.sharding
        self.negative_sampler = negative_sampler
        self.score_fn = score_fn
        self.loss_fn = loss_fn
        self.evaluation = evaluation
        self.return_scores = return_scores
        self.augment_negative = augment_negative

        if not (loss_fn or evaluation or return_scores):
            raise ValueError(
                "Nothing to return. At least one of loss_fn,"
                " evaluation or return_scores needs to be != None"
            )

        if self.augment_negative:
            assert (
                score_fn.negative_sample_sharing
            ), "Negative augmentation requires negative sample sharing"
            assert not isinstance(
                self, ScoreMovingBessKGE
            ), "ScoreMovingBessKGE does not support negative augmentation"

        if negative_sampler.flat_negative_format:
            assert (
                score_fn.negative_sample_sharing
            ), "Using flat negative format requires negative sample sharing"
        elif score_fn.negative_sample_sharing and isinstance(
            self.negative_sampler, TripleBasedShardedNegativeSampler
        ):
            raise ValueError(
                "Negative sample sharing cannot be used"
                " with non-flat triple-specific negatives"
            )

        self.entity_embedding = self.score_fn.entity_embedding
        self.entity_embedding_size: int = self.score_fn.entity_embedding.shape[-1]

    @property
    def n_embedding_parameters(self) -> int:
        """
        Returns the number of trainable parameters in the embedding tables.
        """
        return (
            self.score_fn.entity_embedding.numel()
            + self.score_fn.relation_embedding.numel()
        )
    def forward(
        self,
        head: torch.Tensor,
        relation: torch.Tensor,
        tail: torch.Tensor,
        negative: torch.Tensor,
        triple_mask: Optional[torch.Tensor] = None,
        triple_weight: Optional[torch.Tensor] = None,
        negative_mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, Any]:
        """
        The forward step.

        Comprises four phases:

        1) Gather relevant embeddings from local memory;
        2) Share embeddings with other devices through collective operators;
        3) Score positive and negative triples;
        4) Compute loss/metrics.

        Each device scores `n_shard * positive_per_partition` positive triples.

        :param head: shape: (1, n_shard, positive_per_partition)
            Head indices.
        :param relation: shape: (1, n_shard, positive_per_partition)
            Relation indices.
        :param tail: shape: (1, n_shard, positive_per_partition)
            Tail indices.
        :param triple_mask: shape: (1, n_shard, positive_per_partition)
            Mask to filter the triples in the micro-batch
            before computing metrics.
        :param negative: shape: (1, n_shard, B, padded_negative)
            Indices of negative entities,
            with `B = 1, 2 or n_shard * positive_per_partition`.
        :param triple_weight: shape: (1, n_shard * positive_per_partition,) or (1,)
            Weights of positive triples.
        :param negative_mask: shape: (1, B, n_shard, padded_negative)
            Mask to identify padding negatives,
            to discard when computing metrics.

        :return:
            Micro-batch loss, scores and metrics.
        """
        if triple_weight is None:
            triple_weight = torch.tensor(
                [1.0],
                dtype=torch.float32,
                requires_grad=False,
                device="ipu",
            )

        head, relation, tail, negative, triple_weight = (
            head.squeeze(0),
            relation.squeeze(0),
            tail.squeeze(0),
            negative.squeeze(0),
            triple_weight.squeeze(0),
        )

        positive_score, negative_score = self.score_batch(
            head, relation, tail, negative
        )

        if negative_mask is not None:
            negative_mask = negative_mask.squeeze(0).flatten(start_dim=-2)
            # shape (B, n_shard * padded_neg_length)
            if (
                self.negative_sampler.flat_negative_format
                and self.negative_sampler.corruption_scheme == "ht"
            ):
                cutpoint = relation.shape[1] // 2
                mask_h, mask_t = torch.split(negative_mask, 1, dim=0)
                negative_mask = torch.concat(
                    [
                        mask_h.expand(relation.shape[0], cutpoint, -1),
                        mask_t.expand(
                            relation.shape[0], relation.shape[1] - cutpoint, -1
                        ),
                    ],
                    dim=1,
                ).flatten(end_dim=1)

        if self.augment_negative:
            step = (
                1
                if self.negative_sampler.flat_negative_format
                else 1 + negative.shape[0] * negative.shape[2]
            )
            aug_mask = (
                torch.arange(
                    negative_score.shape[1],
                    device=negative_score.device,
                    dtype=torch.int,
                )[None, :]
                == step
                * torch.arange(
                    negative_score.shape[0],
                    device=negative_score.device,
                    dtype=torch.int,
                )[:, None]
            )
            if self.negative_sampler.corruption_scheme == "ht":
                aug_mask = (
                    aug_mask[: aug_mask.shape[0] // 2, :]
                    .reshape(relation.shape[0], relation.shape[1] // 2, -1)
                    .repeat(1, 2, 1)
                    .flatten(end_dim=1)
                )
            if negative_mask is not None:
                aug_mask[:, -negative_mask.shape[1] :] = ~negative_mask
            # Discard score of true head/tail from negative scores
            negative_score += (
                torch.tensor(
                    BAD_NEGATIVE_SCORE,
                    dtype=negative_score.dtype,
                    device=negative_score.device,
                )
                * aug_mask
            )
        elif negative_mask is not None:
            # Kill scores of padding negatives
            negative_score += torch.tensor(
                BAD_NEGATIVE_SCORE,
                dtype=negative_score.dtype,
                device=negative_score.device,
            ) * (~negative_mask)

        out_dict: Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]] = dict()
        if self.return_scores:
            out_dict.update(
                positive_score=positive_score, negative_score=negative_score
            )
        if self.loss_fn:
            # Losses are always computed in FP32
            loss = self.loss_fn(
                positive_score.float(),
                negative_score.float(),
                triple_weight,
            )
            out_dict.update(loss=poptorch.identity_loss(loss, reduction="none"))
        if self.evaluation:
            if triple_mask is not None:
                triple_mask = triple_mask.flatten()
            with torch.no_grad():
                batch_rank = self.evaluation.ranks_from_scores(
                    positive_score, negative_score
                )
                if self.evaluation.return_ranks:
                    out_dict.update(ranks=batch_rank)
                out_dict.update(
                    metrics=self.evaluation.stacked_metrics_from_ranks(
                        batch_rank, triple_mask
                    )
                )

        return out_dict
    @abstractmethod
    def score_batch(
        self,
        head: torch.Tensor,
        relation: torch.Tensor,
        tail: torch.Tensor,
        negative: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute positive and negative scores for the micro-batch.

        :param head: see :meth:`BessKGE.forward`
        :param relation: see :meth:`BessKGE.forward`
        :param tail: see :meth:`BessKGE.forward`
        :param negative: see :meth:`BessKGE.forward`

        :return:
            Positive (shape: (n_shard * positive_per_partition,))
            and negative (shape: (n_shard * positive_per_partition, n_negative))
            scores of the micro-batch.
        """
        raise NotImplementedError
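

# The masking used in BessKGE.forward adds BAD_NEGATIVE_SCORE to the scores of
# padding negatives (or of the true head/tail when negatives are augmented), so
# that they never win a ranking and contribute negligibly to the loss. Below is
# a minimal host-side sketch of that idea using plain torch; the `_demo_*` name
# is illustrative only and is not part of the besskge API.
def _demo_mask_padding_negatives() -> torch.Tensor:
    # Illustrative example: 4 triples, 6 negatives each.
    negative_score = torch.randn(4, 6)
    # True = real negative, False = padding entity added to equalize shard sizes.
    negative_mask = torch.tensor([[True] * 5 + [False]] * 4)
    # Adding a very large negative constant pushes the padding scores to the
    # bottom of any top-k / ranking computation.
    return negative_score + BAD_NEGATIVE_SCORE * (~negative_mask)
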
class EmbeddingMovingBessKGE(BessKGE):
    """
    Compute negative scores on the shard where the positive triples
    are scored (namely the head shard).
    This requires moving the embeddings of negative entities between
    shards, which can be done with a single AllToAll collective.

    Each triple is scored against a total number of entities equal to
    `n_negative * n_shard` if negative sample sharing is disabled,
    or otherwise to `n_negative * n_shard * B` (see :meth:`BessKGE.forward`)
    for the "h", "t" corruption schemes,
    and `n_negative * n_shard * (B > 2 ? B // 2 : 1)` for "ht".
    """

    # docstr-coverage: inherited
    def score_batch(
        self,
        head: torch.Tensor,
        relation: torch.Tensor,
        tail: torch.Tensor,
        negative: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Gather embeddings
        n_shard = relation.shape[0]
        negative_flat = negative.flatten(start_dim=1)
        gather_idx = torch.concat([head, tail, negative_flat], dim=1)
        head_embedding, tail_and_negative_embedding = torch.split(
            self.entity_embedding[gather_idx],
            [head.shape[1], tail.shape[1] + negative_flat.shape[1]],
            dim=1,
        )

        # Share negative and tail embeddings
        if self.negative_sampler.local_sampling:
            tail_embedding, negative_embedding = torch.split(
                tail_and_negative_embedding,
                [tail.shape[1], negative_flat.shape[1]],
                dim=1,
            )
            tail_embedding = all_to_all(tail_embedding, n_shard)
        else:
            tail_and_negative_embedding = all_to_all(
                tail_and_negative_embedding, n_shard
            )
            tail_embedding, negative_embedding = torch.split(
                tail_and_negative_embedding,
                [tail.shape[1], negative_flat.shape[1]],
                dim=1,
            )
        negative_embedding = (
            negative_embedding.reshape(*negative.shape, self.entity_embedding_size)
            .transpose(0, 1)
            .flatten(start_dim=1, end_dim=2)
        )

        positive_score = self.score_fn.score_triple(
            head_embedding.flatten(end_dim=1),
            relation.flatten(end_dim=1),
            tail_embedding.flatten(end_dim=1),
        )

        if self.negative_sampler.corruption_scheme == "h":
            if self.augment_negative:
                negative_embedding = torch.concat(
                    [
                        head_embedding.view(
                            negative_embedding.shape[0], -1, self.entity_embedding_size
                        ),
                        negative_embedding,
                    ],
                    dim=1,
                )
            negative_score = self.score_fn.score_heads(
                negative_embedding,
                relation.flatten(end_dim=1),
                tail_embedding.flatten(end_dim=1),
            )
        elif self.negative_sampler.corruption_scheme == "t":
            if self.augment_negative:
                negative_embedding = torch.concat(
                    [
                        tail_embedding.view(
                            negative_embedding.shape[0], -1, self.entity_embedding_size
                        ),
                        negative_embedding,
                    ],
                    dim=1,
                )
            negative_score = self.score_fn.score_tails(
                head_embedding.flatten(end_dim=1),
                relation.flatten(end_dim=1),
                negative_embedding,
            )
        elif self.negative_sampler.corruption_scheme == "ht":
            cut_point = relation.shape[1] // 2
            relation_half1, relation_half2 = torch.split(
                relation,
                cut_point,
                dim=1,
            )
            head_half1, head_half2 = torch.split(
                head_embedding,
                cut_point,
                dim=1,
            )
            tail_half1, tail_half2 = torch.split(
                tail_embedding,
                cut_point,
                dim=1,
            )
            if self.negative_sampler.flat_negative_format:
                negative_heads, negative_tails = torch.split(
                    negative_embedding, 1, dim=0
                )
            else:
                negative_embedding = negative_embedding.reshape(
                    *relation.shape[:2], -1, self.entity_embedding_size
                )
                negative_heads, negative_tails = torch.split(
                    negative_embedding, cut_point, dim=1
                )
                negative_heads = negative_heads.flatten(end_dim=1)
                negative_tails = negative_tails.flatten(end_dim=1)
            if self.augment_negative:
                negative_heads = torch.concat(
                    [
                        head_half1.view(
                            negative_heads.shape[0], -1, self.entity_embedding_size
                        ),
                        negative_heads,
                    ],
                    dim=1,
                )
                negative_tails = torch.concat(
                    [
                        tail_half2.view(
                            negative_tails.shape[0], -1, self.entity_embedding_size
                        ),
                        negative_tails,
                    ],
                    dim=1,
                )
            negative_score_heads = self.score_fn.score_heads(
                negative_heads,
                relation_half1.flatten(end_dim=1),
                tail_half1.flatten(end_dim=1),
            )
            negative_score_tails = self.score_fn.score_tails(
                head_half2.flatten(end_dim=1),
                relation_half2.flatten(end_dim=1),
                negative_tails,
            )
            negative_score = torch.concat(
                [
                    negative_score_heads.reshape(*relation_half1.shape[:2], -1),
                    negative_score_tails.reshape(*relation_half2.shape[:2], -1),
                ],
                dim=1,
            ).flatten(end_dim=1)

        return positive_score, negative_score
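

# EmbeddingMovingBessKGE relies on a single cross-replica AllToAll to exchange
# tail and negative embeddings between shards. The sketch below is a host-side,
# single-process analogue of what such an exchange achieves, with all replicas'
# tensors stacked along a leading "sender" axis; it is an illustration of the
# collective's semantics, not the poptorch_experimental_addons API.
def _demo_all_to_all_analogue() -> torch.Tensor:
    # Each of the n_shard replicas holds a tensor of shape (n_shard, block, emb)
    # whose i-th slice along dim 0 is destined for replica i.
    n_shard, block, emb = 4, 3, 2
    stacked = torch.arange(
        n_shard * n_shard * block * emb, dtype=torch.float32
    ).reshape(n_shard, n_shard, block, emb)
    # Conceptually, the AllToAll swaps the sender axis with the leading tensor
    # axis: after the exchange, replica r holds slice [s, r] from every sender s.
    return stacked.transpose(0, 1).contiguous()
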
class ScoreMovingBessKGE(BessKGE):
    """
    Compute negative scores on the shard where the negative entities are stored.
    This avoids moving embeddings between shards (convenient when the number of
    negative entities is very large, for example when scoring queries against
    all entities in the knowledge graph, or when using a large embedding size).

    AllGather collectives are required to replicate queries on all devices, so
    that they can be scored against the local negative entities. An AllToAll
    collective is then used to send the scores back to the correct device.

    For the number of negative samples scored for each triple, see the
    corresponding value documented in :class:`EmbeddingMovingBessKGE` and,
    if using negative sample sharing, multiply that by `n_shard`.

    Does not support local sampling or negative augmentation.
    """

    # docstr-coverage: inherited
    def score_batch(
        self,
        head: torch.Tensor,
        relation: torch.Tensor,
        tail: torch.Tensor,
        negative: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        n_shard = self.sharding.n_shard

        # Gather embeddings
        # relation_embedding = self.score_fn.relation_embedding[relation]
        negative_flat = negative.flatten(start_dim=1)
        gather_idx = torch.concat([head, tail, negative_flat], dim=1)
        head_embedding, tail_embedding, negative_embedding = torch.split(
            self.entity_embedding[gather_idx],
            [head.shape[1], tail.shape[1], negative_flat.shape[1]],
            dim=1,
        )
        negative_embedding = negative_embedding.reshape(
            *negative.shape, self.entity_embedding_size
        )
        if (
            isinstance(self.negative_sampler, TripleBasedShardedNegativeSampler)
            and self.negative_sampler.flat_negative_format
        ):
            # Negatives are replicated along dimension 0, for local scoring only
            # one copy is needed
            negative_embedding = negative_embedding[0].unsqueeze(0)

        relation_all = all_gather(relation, n_shard)
        if self.negative_sampler.corruption_scheme == "h":
            tail_embedding_all = all_gather(tail_embedding, n_shard).transpose(0, 1)
            negative_score = self.score_fn.score_heads(
                negative_embedding.flatten(end_dim=1),
                relation_all.flatten(end_dim=2),
                tail_embedding_all.flatten(end_dim=2),
            )
        elif self.negative_sampler.corruption_scheme == "t":
            head_embedding_all = all_gather(head_embedding, n_shard)
            negative_score = self.score_fn.score_tails(
                head_embedding_all.flatten(end_dim=2),
                relation_all.flatten(end_dim=2),
                negative_embedding.flatten(end_dim=1),
            )
        elif self.negative_sampler.corruption_scheme == "ht":
            cut_point = relation.shape[1] // 2
            relation_half1, relation_half2 = torch.split(
                relation_all,
                cut_point,
                dim=2,
            )
            tail_embedding_all = all_gather(
                tail_embedding[:, :cut_point, :], n_shard
            ).transpose(0, 1)
            head_embedding_all = all_gather(head_embedding[:, cut_point:, :], n_shard)
            if self.negative_sampler.flat_negative_format:
                negative_heads, negative_tails = torch.split(
                    negative_embedding, 1, dim=1
                )
                negative_heads = negative_heads.flatten(end_dim=1)
                negative_tails = negative_tails.flatten(end_dim=1)
            else:
                negative_embedding = negative_embedding.reshape(
                    self.sharding.n_shard,
                    *relation.shape[:2],
                    -1,
                    self.entity_embedding_size
                )
                negative_heads, negative_tails = torch.split(
                    negative_embedding, cut_point, dim=2
                )
                negative_heads = negative_heads.flatten(end_dim=2)
                negative_tails = negative_tails.flatten(end_dim=2)
            negative_score_heads = self.score_fn.score_heads(
                negative_heads,
                relation_half1.flatten(end_dim=2),
                tail_embedding_all.flatten(end_dim=2),
            )
            negative_score_tails = self.score_fn.score_tails(
                head_embedding_all.flatten(end_dim=2),
                relation_half2.flatten(end_dim=2),
                negative_tails,
            )
            negative_score = torch.concat(
                [
                    negative_score_heads.reshape(*relation_half1.shape[:3], -1),
                    negative_score_tails.reshape(*relation_half2.shape[:3], -1),
                ],
                dim=2,
            ).flatten(end_dim=2)

        # Send negative scores back to corresponding triple processing device
        negative_score = (
            all_to_all(
                negative_score.reshape(
                    n_shard, relation.shape[0] * relation.shape[1], -1
                ),
                n_shard,
            )
            .transpose(0, 1)
            .flatten(start_dim=1)
        )

        # Recover micro-batch tail embeddings (#TODO: avoidable?)
        tail_embedding = all_to_all(tail_embedding, n_shard)
        positive_score = self.score_fn.score_triple(
            head_embedding.flatten(end_dim=1),
            relation.flatten(end_dim=1),
            tail_embedding.flatten(end_dim=1),
        )

        return positive_score, negative_score
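

# Both score_batch implementations above handle the "ht" corruption scheme by
# splitting each shard's triples at cut_point: the first half is scored against
# negative heads, the second against negative tails. A tiny illustrative sketch
# of that split (the `_demo_*` name is not part of the besskge API).
def _demo_ht_corruption_split() -> Tuple[torch.Tensor, torch.Tensor]:
    # (n_shard, positive_per_partition) relation indices, values are arbitrary.
    relation = torch.arange(2 * 8).reshape(2, 8)
    cut_point = relation.shape[1] // 2
    # relation_half1 -> head-corrupted triples, relation_half2 -> tail-corrupted.
    relation_half1, relation_half2 = torch.split(relation, cut_point, dim=1)
    return relation_half1, relation_half2
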
class TopKQueryBessKGE(torch.nn.Module):
    """
    Distributed scoring of (h, r, ?) or (?, r, t) queries (against all entities
    in the knowledge graph, or a query-specific set), returning the top-k most
    likely completions, based on the BESS :cite:p:`BESS` inference scheme.
    To be used in combination with a batch sampler based on a
    "h_shard"/"t_shard"-partitioned triple set.

    If the correct tail/head is known, it can be passed as an input in order
    to compute metrics on the final predictions.

    This class is recommended over :class:`BessKGE` when the number of negatives
    is large, for example when one wants to score queries against all entities
    in the knowledge graph, as it uses a sliding window over the negative sample
    size via an on-device for-loop.

    Only to be used for inference.
    """

    def __init__(
        self,
        k: int,
        candidate_sampler: Union[
            TripleBasedShardedNegativeSampler, PlaceholderNegativeSampler
        ],
        score_fn: BaseScoreFunction,
        evaluation: Optional[Evaluation] = None,
        return_scores: bool = False,
        window_size: int = 100,
    ) -> None:
        """
        Initialize TopK BESS-KGE module.

        :param k:
            For each query return the top-k most likely predictions.
        :param candidate_sampler:
            Sampler of candidate entities to score against queries.
            Use :class:`besskge.negative_sampler.PlaceholderNegativeSampler`
            to score queries against all entities in the knowledge graph,
            avoiding unnecessary loading of negative entities on device.
        :param score_fn:
            Scoring function.
        :param evaluation:
            Evaluation module, for computing metrics on device.
            Default: None.
        :param return_scores:
            If True, return scores of the top-k best completions.
            Default: False.
        :param window_size:
            Size of the sliding window, namely the number of negative entities
            scored against each query at each step of the on-device for-loop.
            Should be decreased with large batch sizes, to avoid an OOM error.
            Default: 100.
        """
        super().__init__()
        self.sharding = score_fn.sharding
        self.negative_sampler = candidate_sampler
        self.score_fn = score_fn
        self.evaluation = evaluation
        self.return_scores = return_scores
        self.k = k
        self.window_size = window_size

        if self.negative_sampler.flat_negative_format:
            assert (
                score_fn.negative_sample_sharing
            ), "Using flat negative format requires negative sample sharing"
        elif score_fn.negative_sample_sharing:
            raise ValueError(
                "Negative sample sharing cannot be used"
                " with non-flat triple-specific negatives"
            )
        if self.negative_sampler.corruption_scheme not in ["h", "t"]:
            raise ValueError(
                "TopKQueryBessKGE only supports the 'h', 't' corruption schemes"
            )
        if isinstance(self.negative_sampler, TripleBasedShardedNegativeSampler):
            assert self.negative_sampler.mask_on_gather, (
                "TopKQueryBessKGE requires setting mask_on_gather=True"
                " in the candidate_sampler"
            )

        self.entity_embedding = self.score_fn.entity_embedding
        self.entity_embedding_size: int = self.entity_embedding.shape[-1]
    def forward(
        self,
        relation: torch.Tensor,
        head: Optional[torch.Tensor] = None,
        tail: Optional[torch.Tensor] = None,
        negative: Optional[torch.Tensor] = None,
        triple_mask: Optional[torch.Tensor] = None,
        negative_mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, Any]:
        """
        Forward step.

        Similarly to :class:`ScoreMovingBessKGE`, candidates are scored on the
        device where they are gathered, then scores for the same query against
        candidates in different shards are collected together via an AllToAll.
        At each iteration of the for-loop, only the top-k best query responses
        and respective scores are kept to be used in the next iteration, while
        the rest are discarded.

        :param relation: shape: (1, shard_bs,)
            Relation indices.
        :param head: shape: (1, shard_bs,)
            Head indices, if known. Default: None.
        :param tail: shape: (1, shard_bs,)
            Tail indices, if known. Default: None.
        :param negative: shape: (1, n_shard, B, padded_negative)
            Candidates to score against the queries.
            It can be the same set for all queries (B=1),
            or specific for each query in the batch (B=shard_bs).
            If None, score each query against all entities in the
            knowledge graph. Default: None.
        :param triple_mask: shape: (1, shard_bs,)
            Mask to filter the triples in the micro-batch
            before computing metrics. Default: None.
        :param negative_mask: shape: (1, n_shard, B, padded_negative)
            If candidates are provided, mask to discard padding negatives
            when computing best completions.
            Requires the use of :code:`mask_on_gather=True` in the candidate
            sampler (see
            :class:`besskge.negative_sampler.TripleBasedShardedNegativeSampler`).
            Default: None.
        """
        relation = relation.squeeze(0)
        if head is not None:
            head = head.squeeze(0)
        if tail is not None:
            tail = tail.squeeze(0)

        candidate: torch.Tensor
        if negative is None:
            candidate = torch.arange(
                self.sharding.max_entity_per_shard,
                dtype=torch.int32,
                device=relation.device,
            )
        else:
            assert negative_mask is not None
            candidate = negative.squeeze(0)
            negative_mask = negative_mask.squeeze(0)
            if self.negative_sampler.flat_negative_format:
                candidate = candidate[0]
                negative_mask = negative_mask[0]
            negative_mask = negative_mask.reshape(-1, negative_mask.shape[-1])
        candidate = candidate.reshape(-1, candidate.shape[-1])
        # shape (1 or total_bs, n_negative_per_shard)

        n_shard = self.sharding.n_shard
        shard_bs = relation.shape[0]
        n_best = self.k + 1

        relation_all = all_gather(relation, n_shard)
        if self.negative_sampler.corruption_scheme == "h":
            tail_embedding = self.entity_embedding[tail]
            tail_embedding_all = all_gather(tail_embedding, n_shard)
        elif self.negative_sampler.corruption_scheme == "t":
            head_embedding = self.entity_embedding[head]
            head_embedding_all = all_gather(head_embedding, n_shard)

        def loop_body(
            curr_score: torch.Tensor, curr_idx: torch.Tensor, slide_idx: torch.Tensor
        ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            mask = slide_idx < candidate.shape[-1]
            slide_idx = torch.where(
                mask,
                slide_idx,
                torch.tensor(
                    [candidate.shape[-1] - 1], dtype=torch.int32, device=mask.device
                ),
            )
            if negative_mask is not None:
                mask = torch.logical_and(
                    mask, gather_indices(negative_mask, slide_idx)
                )
            neg_ent_idx = gather_indices(
                candidate, slide_idx
            )  # shape (1 or n_sh * shard_bs, ws)
            negative_embedding = self.entity_embedding[neg_ent_idx]

            if self.negative_sampler.corruption_scheme == "h":
                negative_score = self.score_fn.score_heads(
                    negative_embedding,
                    relation_all.flatten(end_dim=1),
                    tail_embedding_all.flatten(end_dim=1),
                )
            elif self.negative_sampler.corruption_scheme == "t":
                negative_score = self.score_fn.score_tails(
                    head_embedding_all.flatten(end_dim=1),
                    relation_all.flatten(end_dim=1),
                    negative_embedding,
                )
            negative_score += torch.tensor(
                BAD_NEGATIVE_SCORE, dtype=negative_score.dtype, device=mask.device
            ) * (~mask).to(
                dtype=negative_score.dtype
            )  # shape (n_shard * shard_bs, ws)

            top_k_scores = torch.topk(
                torch.concat([negative_score, curr_score], dim=1),
                k=n_best,
                dim=1,
            )
            indices_broad = neg_ent_idx.broadcast_to(*negative_score.shape)
            indices = torch.concat([indices_broad, curr_idx], dim=1)
            curr_idx = gather_indices(indices, top_k_scores.indices)

            return (
                cast(torch.Tensor, top_k_scores.values),  # mypy check
                curr_idx,
                slide_idx
                + torch.tensor(
                    self.window_size, dtype=torch.int32, device=slide_idx.device
                ),
            )

        n_rep = int(np.ceil(candidate.shape[-1] / self.window_size))

        best_curr_score = torch.full(
            fill_value=BAD_NEGATIVE_SCORE,
            size=(n_shard * shard_bs, n_best),
            requires_grad=False,
            dtype=self.score_fn.entity_embedding.dtype,
            device=candidate.device,
        )
        best_curr_idx = torch.full(
            fill_value=self.sharding.max_entity_per_shard,
            size=(n_shard * shard_bs, n_best),
            requires_grad=False,
            dtype=torch.int32,
            device=candidate.device,
        )
        slide_idx = (
            torch.arange(self.window_size, dtype=torch.int32, device=relation.device)
            .to(torch.int32)
            .reshape(1, -1)
        )

        best_curr_score, best_curr_idx, _ = poptorch.for_loop(
            n_rep,
            loop_body,
            [
                best_curr_score,
                best_curr_idx,
                slide_idx,
            ],
        )  # shape (total_bs, n_best)

        # Send back queries to original shard
        best_score = all_to_all(
            best_curr_score.reshape(n_shard, shard_bs, n_best),
            n_shard,
        )
        best_idx = all_to_all(
            best_curr_idx.reshape(n_shard, shard_bs, n_best),
            n_shard,
        )

        # Discard padding shard entities
        best_score += torch.tensor(
            BAD_NEGATIVE_SCORE, dtype=best_score.dtype, device=best_idx.device
        ) * (
            best_idx
            >= torch.from_numpy(self.sharding.shard_counts)[:, None, None].to(
                dtype=torch.int32, device=best_idx.device
            )
        )

        # Best global indices
        best_global_idx = (
            gather_indices(
                torch.from_numpy(
                    self.sharding.shard_and_idx_to_entity,
                ).to(dtype=torch.int32, device=best_idx.device),
                best_idx.reshape(self.sharding.n_shard, -1),
            )
            .reshape(*best_idx.shape)
            .transpose(0, 1)
            .flatten(start_dim=1)
        )

        # Final topk among best k from all shards
        topk_final = torch.topk(
            best_score.transpose(0, 1).flatten(start_dim=1), k=self.k, dim=1
        )

        out_dict: Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]
        topk_global_id = gather_indices(best_global_idx, topk_final.indices)
        out_dict = dict(topk_global_id=topk_global_id)
        if self.return_scores:
            out_dict.update(topk_scores=topk_final.values)
        if self.evaluation:
            if triple_mask is not None:
                triple_mask = triple_mask.flatten()
            with torch.no_grad():
                ground_truth = (
                    tail if self.negative_sampler.corruption_scheme == "t" else head
                )
                assert (
                    ground_truth is not None
                ), "Evaluation requires providing ground truth entities"
                batch_rank = self.evaluation.ranks_from_indices(
                    ground_truth, topk_global_id
                )
                if self.evaluation.return_ranks:
                    out_dict.update(ranks=batch_rank)
                out_dict.update(
                    metrics=self.evaluation.stacked_metrics_from_ranks(
                        batch_rank, triple_mask
                    )
                )

        return out_dict
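

# TopKQueryBessKGE keeps only a running top-k while sliding a fixed-size window
# over the candidate set, so the full score matrix is never materialised. The
# function below is a host-side torch analogue of that loop, assuming the full
# float32 score matrix is available (sketch only; the on-device version above
# additionally handles sharding, padding masks, entity index bookkeeping and
# poptorch.for_loop). The `_demo_*` name is illustrative, not part of besskge.
def _demo_sliding_window_topk(
    scores: torch.Tensor, k: int, window_size: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    # scores: (n_queries, n_candidates), float32.
    n_queries, n_candidates = scores.shape
    best_score = torch.full((n_queries, k), BAD_NEGATIVE_SCORE)
    best_idx = torch.full((n_queries, k), -1, dtype=torch.int64)
    for start in range(0, n_candidates, window_size):
        window = scores[:, start : start + window_size]
        window_idx = torch.arange(start, start + window.shape[1]).expand(
            n_queries, -1
        )
        # Merge the current window with the running best and keep the top-k.
        merged_scores = torch.concat([window, best_score], dim=1)
        merged_idx = torch.concat([window_idx, best_idx], dim=1)
        topk = torch.topk(merged_scores, k=k, dim=1)
        best_score = topk.values
        best_idx = torch.gather(merged_idx, 1, topk.indices)
    # For example, _demo_sliding_window_topk(torch.randn(8, 1000), 10, 100)
    # returns the same values as torch.topk over the full score matrix.
    return best_score, best_idx
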
class AllScoresBESS(torch.nn.Module):
    """
    Distributed scoring of (h, r, ?) or (?, r, t) queries against the entities
    in the knowledge graph, returning all scores to host in blocks, based on
    the BESS :cite:p:`BESS` inference scheme.
    To be used in combination with a batch sampler based on a
    "h_shard"/"t_shard"-partitioned triple set.

    Since each iteration on IPU computes only part of the scores (based on the
    size of the sliding window), metrics should be computed on host after
    aggregating data (see :class:`besskge.pipeline.AllScoresPipeline`).

    Only to be used for inference.
    """

    def __init__(
        self,
        candidate_sampler: PlaceholderNegativeSampler,
        score_fn: BaseScoreFunction,
        window_size: int = 1000,
    ) -> None:
        """
        Initialize AllScores BESS-KGE module.

        :param candidate_sampler:
            :class:`besskge.negative_sampler.PlaceholderNegativeSampler` class,
            specifying the corruption scheme.
        :param score_fn:
            Scoring function.
        :param window_size:
            Size of the sliding window, namely the number of negative entities
            scored against each query at each step on IPU and returned to host.
            Should be decreased with large batch sizes, to avoid an OOM error.
            Default: 1000.
        """
        super().__init__()
        self.sharding = score_fn.sharding
        self.score_fn = score_fn
        self.negative_sampler = candidate_sampler
        self.window_size = window_size

        if not score_fn.negative_sample_sharing:
            raise ValueError("AllScoresBESS requires using negative sample sharing")
        if self.negative_sampler.corruption_scheme not in ["h", "t"]:
            raise ValueError(
                "AllScoresBESS only supports the 'h', 't' corruption schemes"
            )
        if not isinstance(self.negative_sampler, PlaceholderNegativeSampler):
            raise ValueError(
                "AllScoresBESS requires a `PlaceholderNegativeSampler`"
                " candidate_sampler"
            )

        self.entity_embedding = self.score_fn.entity_embedding
        self.entity_embedding_size: int = self.entity_embedding.shape[-1]

        self.candidate = torch.arange(self.window_size, dtype=torch.int32)
        self.n_step = int(
            np.ceil(self.sharding.max_entity_per_shard / self.window_size)
        )
    def forward(
        self,
        step: torch.Tensor,
        relation: torch.Tensor,
        head: Optional[torch.Tensor] = None,
        tail: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward step.

        Similarly to :class:`ScoreMovingBessKGE`, candidates are scored on the
        device where they are gathered, then scores for the same query against
        candidates in different shards are collected together via an AllToAll.

        :param step:
            The index of the block (of size self.window_size) of entities
            on each IPU to score against queries.
        :param relation: shape: (1, shard_bs,)
            Relation indices.
        :param head: shape: (1, shard_bs,)
            Head indices, if known. Default: None.
        :param tail: shape: (1, shard_bs,)
            Tail indices, if known. Default: None.

        :return:
            The scores for the completions.
        """
        relation = relation.squeeze(0)
        if head is not None:
            head = head.squeeze(0)
        if tail is not None:
            tail = tail.squeeze(0)

        n_shard = self.sharding.n_shard
        shard_bs = relation.shape[0]

        relation_all = all_gather(relation, n_shard)
        if self.negative_sampler.corruption_scheme == "h":
            tail_embedding = self.entity_embedding[tail]
            tail_embedding_all = all_gather(tail_embedding, n_shard)
        elif self.negative_sampler.corruption_scheme == "t":
            head_embedding = self.entity_embedding[head]
            head_embedding_all = all_gather(head_embedding, n_shard)

        # Local indices of the entities to score against queries
        ent_slice = torch.minimum(
            step * self.window_size
            + torch.arange(self.window_size, device=relation.device),
            torch.tensor(self.sharding.max_entity_per_shard - 1),
        )
        negative_embedding = self.entity_embedding[ent_slice]

        if self.negative_sampler.corruption_scheme == "h":
            scores = self.score_fn.score_heads(
                negative_embedding,
                relation_all.flatten(end_dim=1),
                tail_embedding_all.flatten(end_dim=1),
            )
        elif self.negative_sampler.corruption_scheme == "t":
            scores = self.score_fn.score_tails(
                head_embedding_all.flatten(end_dim=1),
                relation_all.flatten(end_dim=1),
                negative_embedding,
            )

        # Send back queries to original shard
        scores = (
            all_to_all(
                scores.reshape(n_shard, shard_bs, self.window_size),
                n_shard,
            )
            .transpose(0, 1)
            .flatten(start_dim=1)
        )  # shape (bs, n_shard * ws)

        return scores
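

# AllScoresBESS walks over the locally stored entities in blocks of window_size;
# indices past the end of the shard are clamped to the last valid row, so the
# final block simply repeats the last entity (whose scores can be discarded on
# host). A host-side sketch of that index computation, mirroring the
# torch.minimum clamp in AllScoresBESS.forward; the `_demo_*` name is
# illustrative only.
def _demo_entity_block_indices(
    step: int, window_size: int, max_entity_per_shard: int
) -> torch.Tensor:
    # Local indices of the entities scored at this step.
    return torch.minimum(
        step * window_size + torch.arange(window_size),
        torch.tensor(max_entity_per_shard - 1),
    )
    # e.g. _demo_entity_block_indices(3, 1000, 3500) yields 3000..3499 followed
    # by 3499 repeated for the out-of-range positions.
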