Source code for besskge.metric

# Copyright (c) 2023 Graphcore Ltd. All rights reserved.

"""
Utilities for computing metrics to evaluate the predictions of KGE models.
"""

import re
from abc import ABC, abstractmethod
from typing import Callable, Dict, List, Optional, cast

import torch


class BaseMetric(ABC):
    @abstractmethod
    def __call__(self, prediction_rank: torch.Tensor) -> torch.Tensor:
        """
        Returns metric values for the elements in the batch.

        :param prediction_rank: shape: (batch_size,)
            The rank of the ground truth among the ordered predictions.

        :return:
            Metric values for the elements in the batch.
        """
        raise NotImplementedError
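As an illustrative sketch (not part of the library source), a new per-triple metric can be added by subclassing :class:`BaseMetric` and implementing :meth:`__call__` on the tensor of prediction ranks; `MeanRank` below is a hypothetical name.

class MeanRank(BaseMetric):
    """Hypothetical metric returning the raw prediction rank (e.g. to compute Mean Rank after a "sum" reduction)."""

    def __call__(self, prediction_rank: torch.Tensor) -> torch.Tensor:
        # The rank itself is the metric value for each triple.
        return prediction_rank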
class ReciprocalRank(BaseMetric):
    """
    Reciprocal rank (for example to compute MRR).
    Returns the reciprocal rank of the ground truth among predictions.
    """

    def __init__(self) -> None:
        """
        Initialize the reciprocal rank metric (no parameters required).
        """

    # docstr-coverage: inherited
    def __call__(self, prediction_rank: torch.Tensor) -> torch.Tensor:
        return torch.reciprocal(prediction_rank)
class HitsAtK(BaseMetric):
    """
    Hits@K metric.
    Returns the count of triples where the ground truth is among
    the K most likely predicted entities.
    """

    def __init__(
        self,
        k: int,
    ) -> None:
        """
        :param k:
            Maximum acceptable rank.
        """
        self.K = k

    # docstr-coverage: inherited
    def __call__(self, prediction_rank: torch.Tensor) -> torch.Tensor:
        return (prediction_rank <= self.K).to(torch.float)
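A minimal usage sketch (not part of the library source): the metric callables operate element-wise on a tensor of prediction ranks, so batch-level MRR and Hits@K can be obtained by averaging their outputs.

ranks = torch.tensor([1.0, 3.0, 12.0])  # rank of the ground truth for each triple
reciprocal = ReciprocalRank()(ranks)    # tensor([1.0000, 0.3333, 0.0833])
hits_at_10 = HitsAtK(k=10)(ranks)       # tensor([1., 1., 0.])
mrr = reciprocal.mean()                 # ~0.4722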
#: Metric str -> callable
METRICS_DICT = {"mrr": ReciprocalRank, "hits@k": HitsAtK}
class Evaluation:
    """
    A module for computing link prediction metrics.
    """

    def __init__(
        self,
        metric_list: List[str],
        mode: str = "average",
        worst_rank_infty: bool = False,
        reduction: str = "none",
        return_ranks: bool = False,
    ) -> None:
        """
        Initialize evaluation module.

        :param metric_list:
            List of metrics to compute. Currently supports "mrr" and "hits@K".
        :param mode:
            Mode used for metrics. Can be "optimistic", "pessimistic" or
            "average". Default: "average".
        :param worst_rank_infty:
            If True, assign a prediction rank of infinity as the worst possible
            rank. If False, assign a prediction rank of `n_negative + 1` as the
            worst possible rank. Default: False.
        :param reduction:
            Method used to reduce metrics along the batch dimension.
            Currently supports "none" (no reduction) and "sum".
        :param return_ranks:
            If True, return prediction ranks alongside metrics.
        """
        if mode not in ["pessimistic", "optimistic", "average"]:
            raise ValueError(f"Mode {mode} not supported for evaluation")
        self.mode = mode
        self.return_ranks = return_ranks
        if reduction == "none":
            self.reduction = lambda x: x
        elif reduction == "sum":
            self.reduction = lambda x: torch.sum(x, dim=0)
        else:
            raise ValueError(f"Reduction {reduction} not supported for evaluation")
        self.worst_rank_infty = worst_rank_infty

        self.metrics: Dict[str, Callable[[torch.Tensor], torch.Tensor]]
        hits_k_filter = [re.search(r"hits@(\d+)", a) for a in metric_list]
        self.metrics = {
            a[0]: METRICS_DICT["hits@k"](k=int(a[1])) for a in hits_k_filter if a
        }
        self.metrics.update(
            {
                m_name: METRICS_DICT[m_name]()
                for m_name in list(set(metric_list) - set(self.metrics.keys()))
            }
        )
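As a construction sketch (assuming the module is importable as `besskge.metric`), the regex above maps each "hits@K" string to a :class:`HitsAtK` instance with the parsed cutoff, while the remaining names are looked up directly in `METRICS_DICT`.

from besskge.metric import Evaluation

evaluation = Evaluation(
    metric_list=["mrr", "hits@1", "hits@10"],
    mode="average",
    worst_rank_infty=False,
    reduction="sum",
)
# evaluation.metrics now maps "hits@1" and "hits@10" to HitsAtK instances
# (with K=1 and K=10) and "mrr" to a ReciprocalRank instance.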
    def ranks_from_scores(
        self, pos_score: torch.Tensor, candidate_score: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute the prediction rank from the score of the positive triple
        (ground truth) and the scores of triples corrupted with the
        candidate entities.

        :param pos_score: shape: (batch_size,)
            Scores of positive triples.
        :param candidate_score: shape: (batch_size, n_candidate)
            Scores of candidate triples.

        :return:
            The rank of the positive score among the ordered scores
            of the candidate triples.
        """
        batch_size, n_negative = candidate_score.shape
        pos_score = pos_score.reshape(-1, 1)
        if pos_score.shape[0] != batch_size:
            raise ValueError(
                "`pos_score` and `candidate_score` need to have the same size"
                " at dimension 0"
            )
        pos_score.nan_to_num_(-torch.inf)
        if self.mode == "optimistic":
            n_better = torch.sum(candidate_score > pos_score, dim=-1).to(torch.float32)
            if self.worst_rank_infty:
                mask = n_better == n_negative
        elif self.mode == "pessimistic":
            n_better = torch.sum(candidate_score >= pos_score, dim=-1).to(
                torch.float32
            )
            if self.worst_rank_infty:
                mask = n_better == n_negative
        elif self.mode == "average":
            n_better_opt = torch.sum(candidate_score > pos_score, dim=-1).to(
                torch.float32
            )
            n_better_pess = torch.sum(candidate_score >= pos_score, dim=-1).to(
                torch.float32
            )
            n_better = torch.tensor(
                0.5, dtype=torch.float32, device=n_better_opt.device
            ) * (n_better_opt + n_better_pess)
            if self.worst_rank_infty:
                mask = torch.logical_or(
                    n_better_opt == n_negative, n_better_pess == n_negative
                )

        batch_rank = (
            torch.tensor(1.0, dtype=torch.float32, device=n_better.device) + n_better
        )
        if self.worst_rank_infty:
            batch_rank[mask] = torch.inf

        return batch_rank
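A worked sketch of the three tie-breaking modes (not part of the library source): with one candidate scoring higher than the positive triple and one candidate tied with it, the optimistic, pessimistic and average ranks differ as follows.

pos_score = torch.tensor([2.0])
candidate_score = torch.tensor([[3.0, 2.0, 1.0]])  # one candidate ties with the positive

for mode in ("optimistic", "pessimistic", "average"):
    rank = Evaluation(["mrr"], mode=mode).ranks_from_scores(pos_score, candidate_score)
    print(mode, rank)  # optimistic: 2.0, pessimistic: 3.0, average: 2.5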
    def ranks_from_indices(
        self, ground_truth: torch.Tensor, candidate_indices: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute the prediction rank from the ground truth ID and the ORDERED
        candidate IDs.

        :param ground_truth: shape: (batch_size,)
            Indices of ground truth entities for each query.
        :param candidate_indices: shape: (batch_size, n_candidates)
            Indices of the top `n_candidates` predicted entities, ordered by
            decreasing likelihood. The indices on each row are assumed to be
            distinct.

        :return:
            The rank of the ground truth among the predictions.
        """
        batch_size, n_negative = candidate_indices.shape
        ground_truth = ground_truth.reshape(-1, 1)
        if ground_truth.shape[0] != batch_size:
            raise ValueError(
                "`ground_truth` and `candidate_indices` need to have the same size"
                " at dimension 0"
            )
        worst_rank = torch.inf if self.worst_rank_infty else float(n_negative + 1)
        ranks = torch.where(
            ground_truth == candidate_indices,
            torch.arange(
                1, n_negative + 1, dtype=torch.float32, device=ground_truth.device
            ),
            worst_rank,
        )
        batch_rank = ranks.min(dim=-1)[0]

        return cast(torch.Tensor, batch_rank)
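An index-based sketch (not part of the library source): when the ground truth entity does not appear among the ordered candidates, it receives the worst rank, here `n_candidates + 1` since `worst_rank_infty=False`.

ground_truth = torch.tensor([7, 4])
candidate_indices = torch.tensor([[2, 7, 5], [9, 1, 0]])  # ordered by decreasing likelihood

evaluation = Evaluation(["mrr"], worst_rank_infty=False)
ranks = evaluation.ranks_from_indices(ground_truth, candidate_indices)
# ranks = tensor([2., 4.]): entity 7 is the second candidate in row 0;
# entity 4 is absent from row 1, so it gets the worst rank 3 + 1 = 4.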
    def dict_metrics_from_ranks(
        self, batch_rank: torch.Tensor, triple_mask: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Compute the required metrics starting from the prediction ranks of
        the elements in the batch.

        :param batch_rank: shape: (batch_size,)
            Prediction rank for each element in the batch.
        :param triple_mask: shape: (batch_size,)
            Boolean mask. If provided, the metrics of the elements where
            :code:`triple_mask` is False are set to 0.0.

        :return:
            The dictionary of (reduced) batch metrics.
        """
        output = {}
        for m_name, m_fn in self.metrics.items():
            batch_metric = m_fn(batch_rank)
            if triple_mask is not None:
                batch_metric = torch.where(
                    triple_mask,
                    batch_metric,
                    torch.tensor(
                        [0.0], dtype=batch_metric.dtype, device=batch_metric.device
                    ),
                )
            output[m_name] = self.reduction(batch_metric)

        return output
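Putting the pieces together in a small sketch (not part of the library source): prediction ranks are converted into per-metric tensors keyed by the metric names.

evaluation = Evaluation(["mrr", "hits@3"], reduction="none")
batch_rank = torch.tensor([1.0, 2.0, 10.0])
metrics = evaluation.dict_metrics_from_ranks(batch_rank)
# {"hits@3": tensor([1., 1., 0.]), "mrr": tensor([1.0000, 0.5000, 0.1000])}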
    def stacked_metrics_from_ranks(
        self, batch_rank: torch.Tensor, triple_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Like :meth:`Evaluation.dict_metrics_from_ranks`, but the outputs for
        the different metrics are returned stacked in a single tensor,
        following the ordering of :attr:`Evaluation.metrics`.

        :param batch_rank: shape: (batch_size,)
            See :meth:`Evaluation.dict_metrics_from_ranks`.
        :param triple_mask: shape: (batch_size,)
            See :meth:`Evaluation.dict_metrics_from_ranks`.

        :return: shape: (1, n_metrics, batch_size) if :attr:`reduction` = "none",
            else (1, n_metrics)
            The stacked (reduced) metrics for the batch.
        """
        return torch.stack(
            list(self.dict_metrics_from_ranks(batch_rank, triple_mask).values())
        ).unsqueeze(0)
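Continuing the sketch above, the stacked variant returns the same values as a single tensor, one row per metric in the order of :attr:`Evaluation.metrics`.

stacked = evaluation.stacked_metrics_from_ranks(batch_rank)
# stacked.shape == (1, 2, 3) with reduction="none";
# with reduction="sum" the shape would be (1, 2).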