Source code for besskge.sharding

# Copyright (c) 2023 Graphcore Ltd. All rights reserved.

"""
Sharding of embedding tables and triple sets for distributed execution.
"""

import dataclasses
import warnings
from pathlib import Path
from typing import Optional, Tuple

import numpy as np
from numpy.typing import NDArray

from besskge.dataset import KGDataset


@dataclasses.dataclass
class Sharding:
    """
    A mapping of entities to shards (and back again).
    """

    #: Number of shards
    n_shard: int

    #: Entity shard by global ID;
    #: int32[n_entity]
    entity_to_shard: NDArray[np.int32]

    #: Entity local ID on shard by global ID;
    #: int32[n_entity]
    entity_to_idx: NDArray[np.int32]

    #: Entity global ID by (shard, local_ID);
    #: int32[n_shard, max_entity_per_shard]
    shard_and_idx_to_entity: NDArray[np.int32]

    #: Number of true entities (excluding padding) in each shard;
    #: int64[n_shard]
    shard_counts: NDArray[np.int64]

    #: Number of entities of each type on each shard;
    #: int64[n_shard, n_types]
    entity_type_counts: Optional[NDArray[np.int64]]

    #: Offsets for entities of the same type on each shard
    #: (entities remain clustered by type also locally);
    #: int64[n_shard, n_types]
    entity_type_offsets: Optional[NDArray[np.int64]]

    @property
    def n_entity(self) -> int:
        """
        Number of entities in the knowledge graph.
        """
        return len(self.entity_to_shard)

    @property
    def max_entity_per_shard(self) -> int:
        """
        Number of entities in a shard, after applying padding.
        """
        return self.shard_and_idx_to_entity.shape[1]

    @classmethod
    def create(
        cls,
        n_entity: int,
        n_shard: int,
        seed: int,
        type_offsets: Optional[NDArray[np.int64]] = None,
    ) -> "Sharding":
        """
        Construct a random, balanced sharding of entities.

        :param n_entity:
            Number of entities in the knowledge graph.
        :param n_shard:
            Number of shards.
        :param seed:
            Seed for random sharding.
        :param type_offsets: shape: (n_types,)
            Global offsets of entity types. Default: None.

        :return:
            Random sharding of n_entity entities in n_shard shards.
        """
        rng = np.random.default_rng(seed)
        max_entity_per_shard = int(np.ceil(n_entity / n_shard))
        # Keep global entity IDs sorted on each shard, to preserve type-based clustering
        shard_and_idx_to_entity = np.sort(
            rng.permutation(n_shard * max_entity_per_shard).reshape(
                n_shard, max_entity_per_shard
            ),
            axis=1,
        )
        entity_to_shard, entity_to_idx = np.divmod(
            np.argsort(shard_and_idx_to_entity.flatten())[:n_entity],
            max_entity_per_shard,
        )
        shard_deduction = np.sum(
            shard_and_idx_to_entity[:, -n_shard:] >= n_entity, axis=-1
        )
        # Number of actual entities in each shard
        shard_counts = max_entity_per_shard - shard_deduction

        entity_type_counts: Optional[NDArray[np.int64]]
        entity_type_offsets: Optional[NDArray[np.int64]]
        if type_offsets is not None:
            type_id_per_shard = (
                np.digitize(shard_and_idx_to_entity, bins=type_offsets)
                + len(type_offsets) * np.arange(n_shard)[:, None]
                - 1
            )
            # Per-shard entity type counts and offsets
            entity_type_counts = np.bincount(
                type_id_per_shard.flatten(), minlength=len(type_offsets) * n_shard
            ).reshape(n_shard, -1)
            entity_type_offsets = np.c_[
                [0] * n_shard, np.cumsum(entity_type_counts, axis=1)[:, :-1]
            ]
            entity_type_counts[:, -1] -= shard_deduction
        else:
            entity_type_counts = entity_type_offsets = None

        return cls(
            n_shard=n_shard,
            entity_to_shard=entity_to_shard,
            entity_to_idx=entity_to_idx,
            shard_and_idx_to_entity=shard_and_idx_to_entity,
            shard_counts=shard_counts,
            entity_type_counts=entity_type_counts,
            entity_type_offsets=entity_type_offsets,
        )
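
    # Worked example of the padding arithmetic above (hypothetical numbers,
    # not part of the original module): with n_entity=10 and n_shard=3,
    # max_entity_per_shard = ceil(10/3) = 4, so 12 slots are allocated and the
    # permutation also places the 2 padding IDs (10 and 11). There are always
    # fewer than n_shard padding IDs, and since they are the largest, after
    # the per-shard sort they can only occupy the trailing positions of a row,
    # which is why shard_deduction only inspects the last n_shard columns.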

    def save(self, out_file: Path) -> None:
        """
        Save sharding to .npz file.

        :param out_file:
            Path to output file.
        """
        np.savez(out_file, **dataclasses.asdict(self))

    @classmethod
    def load(cls, path: Path) -> "Sharding":
        """
        Load a :class:`Sharding` object saved with :func:`Sharding.save`.

        :param path:
            Path to saved :class:`Sharding` object.

        :return:
            The saved :class:`Sharding` object.
        """
        data = dict(np.load(path))
        return cls(n_shard=int(data.pop("n_shard")), **data)
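

def _example_random_sharding() -> None:
    # Usage sketch, not part of the original module: shards 1000 entities of
    # two types over 4 shards and round-trips the result through save()/load().
    # All numbers, the type boundaries (0 and 600) and the output path are
    # hypothetical.
    sharding = Sharding.create(
        n_entity=1000, n_shard=4, seed=0, type_offsets=np.array([0, 600])
    )
    assert sharding.n_entity == 1000
    assert sharding.max_entity_per_shard == 250  # ceil(1000 / 4), no padding
    assert int(sharding.shard_counts.sum()) == 1000
    # Mapping global ID -> (shard, local ID) -> global ID is the identity
    entities = np.arange(1000)
    recovered = sharding.shard_and_idx_to_entity[
        sharding.entity_to_shard[entities], sharding.entity_to_idx[entities]
    ]
    assert np.array_equal(recovered, entities)
    sharding.save(Path("sharding.npz"))
    restored = Sharding.load(Path("sharding.npz"))
    assert restored.n_shard == sharding.n_shard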


@dataclasses.dataclass
class PartitionedTripleSet:
    """
    A partitioned collection of triples.

    If :code:`partition_mode = 'h_shard'`, each triple is assigned to one of
    `n_shard` partitions based on the shard where the head entity is stored.
    Similarly, if :code:`partition_mode = 't_shard'`, each triple is assigned
    to one of `n_shard` partitions based on the shard where the tail entity
    is stored.

    If :code:`partition_mode = 'ht_shardpair'`, each triple is assigned to
    one of `n_shard^2` partitions based on the shard-pair `(shard_h, shard_t)`.
    Shard-pairs are ordered as:
    `(0,0), (0,1), ..., (0, n_shard-1), (1,0), ..., (n_shard-1, n_shard-1)`.
    """

    #: Sharding of entities
    sharding: Sharding

    #: Whether the collection contains inverse triples (t,r_inv,h)
    #: for each regular triple (h,r,t)
    inverse_triples: bool

    #: Partitioning criterion for triples;
    #: "h_shard", "t_shard", "ht_shardpair"
    partition_mode: str

    #: If the set is constructed from (h,r,?) (resp. (?,r,t)) queries,
    #: dummy tails (resp. heads) are added to make pairs into triples.
    #: "head", "tail", "none"
    dummy: Optional[str]

    #: h/r/t IDs for triples, ordered by partition.
    #: Local IDs for heads (resp. tails) and global IDs
    #: for tails (resp. heads) if partition_mode = "h_shard" (resp. "t_shard");
    #: local IDs for heads and tails if partition_mode = "ht_shardpair";
    #: int32[n_triple, {h,r,t}]
    triples: NDArray[np.int32]

    #: Number of triples in each partition;
    #: int64[n_shard] or int64[n_shard, n_shard]
    triple_counts: NDArray[np.int64]

    #: Delimiting indices of ordered partitions;
    #: int64[n_shard] or int64[n_shard, n_shard]
    triple_offsets: NDArray[np.int64]

    #: Sorting indices to order triples by partition;
    #: int64[n_triple]
    triple_sort_idx: NDArray[np.int64]

    #: Entity type IDs of triple head/tail;
    #: int32[n_triple, {h_type, t_type}]
    types: Optional[NDArray[np.int32]]

    #: Global IDs of (possibly triple-specific) negative heads;
    #: int32[n_triple or 1, n_neg_heads]
    neg_heads: Optional[NDArray[np.int32]]

    #: Global IDs of (possibly triple-specific) negative tails;
    #: int32[n_triple or 1, n_neg_tails]
    neg_tails: Optional[NDArray[np.int32]]

    @classmethod
    def partition_triples(
        cls,
        triples: NDArray[np.int32],
        sharding: Sharding,
        partition_mode: str,
    ) -> Tuple[
        NDArray[np.int32], NDArray[np.int64], NDArray[np.int64], NDArray[np.int64]
    ]:
        """
        Partition a set of triples based on the entity sharding.

        :param triples: shape: (n_triple, 3)
            The set of triples, with global entity IDs.
        :param sharding:
            The entity sharding to use.
        :param partition_mode:
            The triple partition mode. Can be "h_shard", "t_shard",
            "ht_shardpair".

        :return:
            The triples sorted by partition (with sharded entity columns
            converted to local IDs), the partition counts, the partition
            offsets and the sorting indices.
        """
        n_shard = sharding.n_shard
        offsets: NDArray[np.int64]
        if partition_mode in ["h_shard", "t_shard"]:
            column_id = 0 if partition_mode == "h_shard" else -1
            partition_id = sharding.entity_to_shard[triples[:, column_id]]
            counts = np.bincount(partition_id, minlength=n_shard)
            offsets = np.concatenate([np.array([0]), np.cumsum(counts)[:-1]])
        elif partition_mode == "ht_shardpair":
            shard_h, shard_t = sharding.entity_to_shard[triples[:, [0, 2]].T]
            partition_id = shard_h * n_shard + shard_t
            counts = np.bincount(
                partition_id, minlength=n_shard * n_shard
            ).reshape(n_shard, n_shard)
            offsets = np.concatenate(
                [np.array([0]), np.cumsum(counts)[:-1]]
            ).reshape(n_shard, n_shard)
        else:
            raise ValueError(
                f"Partition mode {partition_mode} not supported"
                " for triple partitioning"
            )
        sort_idx = np.argsort(partition_id)
        sorted_triples: NDArray[np.int32]
        sorted_triples = triples[sort_idx]
        # Convert sharded entity columns from global to local IDs
        if partition_mode in ["h_shard", "ht_shardpair"]:
            sorted_triples[:, 0] = sharding.entity_to_idx[sorted_triples[:, 0]]
        if partition_mode in ["t_shard", "ht_shardpair"]:
            sorted_triples[:, -1] = sharding.entity_to_idx[sorted_triples[:, -1]]
        return sorted_triples, counts, offsets, sort_idx

    @classmethod
    def create_from_dataset(
        cls,
        dataset: KGDataset,
        part: str,
        sharding: Sharding,
        partition_mode: str = "ht_shardpair",
        add_inverse_triples: bool = False,
    ) -> "PartitionedTripleSet":
        """
        Create a partitioned triple set from a :class:`KGDataset` part.

        :param dataset:
            Knowledge graph dataset.
        :param part:
            The dataset part to shard.
        :param sharding:
            The entity sharding to use.
        :param partition_mode:
            The triple partition mode. Can be "h_shard", "t_shard",
            "ht_shardpair". Default: "ht_shardpair".
        :param add_inverse_triples:
            If True, add the inverse triple (t, r_inv, h) for each regular
            triple (h, r, t) in the set. Default: False.

        :return:
            Partitioned set of triples.
        """
        triples = dataset.triples[part]
        n_triples = triples.shape[0]
        if add_inverse_triples:
            inverse_triples = np.copy(triples[:, ::-1])
            inverse_triples[:, 1] += dataset.n_relation_type
            triples = np.concatenate([triples, inverse_triples], axis=0)
        (
            sorted_triples,
            counts,
            offsets,
            sort_idx,
        ) = PartitionedTripleSet.partition_triples(triples, sharding, partition_mode)

        ht_types = dataset.ht_types
        if ht_types and part in ht_types.keys():
            types = ht_types[part]
            if add_inverse_triples:
                types = np.concatenate([types, types[:, ::-1]], axis=0)
            types = types[sort_idx]
        else:
            types = None

        if add_inverse_triples:
            if (dataset.neg_heads and part in dataset.neg_heads.keys()) and (
                dataset.neg_tails and part in dataset.neg_tails.keys()
            ):
                n_negatives = dataset.neg_heads[part].shape[-1]
                neg_heads_broad = np.broadcast_to(
                    dataset.neg_heads[part], (n_triples, n_negatives)
                )
                neg_tails_broad = np.broadcast_to(
                    dataset.neg_tails[part], (n_triples, n_negatives)
                )
                # Negative heads of an inverse triple are the negative tails
                # of the corresponding regular triple, and vice versa
                neg_heads_extended = np.concatenate(
                    [neg_heads_broad, neg_tails_broad], axis=0
                )
                neg_tails_extended = np.concatenate(
                    [neg_tails_broad, neg_heads_broad], axis=0
                )
            elif (dataset.neg_heads and part in dataset.neg_heads.keys()) != (
                dataset.neg_tails and part in dataset.neg_tails.keys()
            ):
                raise ValueError(
                    "To use inverse triples, either both or"
                    " neither of negative heads and tails need to"
                    f" be defined for the {part} part of the dataset"
                )

        if dataset.neg_heads and part in dataset.neg_heads.keys():
            if add_inverse_triples:
                neg_heads = neg_heads_extended
            else:
                neg_heads = dataset.neg_heads[part]
            neg_heads = neg_heads.reshape(-1, neg_heads.shape[-1])
            if neg_heads.shape[0] != 1:
                neg_heads = neg_heads[sort_idx]
        else:
            neg_heads = None
        if dataset.neg_tails and part in dataset.neg_tails.keys():
            if add_inverse_triples:
                neg_tails = neg_tails_extended
            else:
                neg_tails = dataset.neg_tails[part]
            neg_tails = neg_tails.reshape(-1, neg_tails.shape[-1])
            if neg_tails.shape[0] != 1:
                neg_tails = neg_tails[sort_idx]
        else:
            neg_tails = None

        return cls(
            sharding=sharding,
            inverse_triples=add_inverse_triples,
            partition_mode=partition_mode,
            dummy="none",
            triples=sorted_triples,
            triple_counts=counts,
            triple_offsets=offsets,
            triple_sort_idx=sort_idx,
            types=types,
            neg_heads=neg_heads,
            neg_tails=neg_tails,
        )
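
    # Note on the inverse-triple convention used by create_from_dataset above
    # (a restatement of the code, not new behaviour): for a regular triple
    # (h, r, t) the added inverse is (t, r + n_relation_type, h), so relation
    # IDs in [0, n_relation_type) denote regular relations and those in
    # [n_relation_type, 2 * n_relation_type) their inverses; a relation
    # embedding table consumed downstream would therefore need
    # 2 * n_relation_type rows when add_inverse_triples=True.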

    @classmethod
    def create_from_queries(
        cls,
        dataset: KGDataset,
        sharding: Sharding,
        queries: NDArray[np.int32],
        query_mode: str,
        ground_truth: Optional[NDArray[np.int32]] = None,
        negative: Optional[NDArray[np.int32]] = None,
        negative_type: Optional[str] = None,
    ) -> "PartitionedTripleSet":
        """
        Create a partitioned triple set from a set of (h,r,?) or (?,r,t)
        queries. Pairs are completed to triples by adding dummy entities.

        :param dataset:
            Knowledge graph dataset.
        :param sharding:
            The entity sharding to use.
        :param queries: shape: (n_query, 2)
            The set of (h, r) or (r, t) queries.
            Global IDs for entities/relations.
        :param query_mode:
            "hr" for (h,r,?) queries, "rt" for (?,r,t) queries.
        :param ground_truth: shape: (n_query,)
            If known, the global ID of the ground truth tail/head.
            Default: None.
        :param negative: shape: (N, n_negative)
            Global IDs of negative entities to score against each query.
            This can be query-specific (N=n_query) or the same for all
            queries (N=1). Default: None (namely score each query against
            all entities in the graph).
        :param negative_type:
            Score each query only against entities of a specific type.
            Default: None (namely score each query against entities
            of any type).

        :return:
            Partitioned set of queries (with dummy h/t completion).
        """
        n_query = queries.shape[0]
        # Dummy entities to complete queries (= pairs) to triples
        if negative_type:
            if (
                not dataset.type_offsets
                or negative_type not in dataset.type_offsets.keys()
            ):
                raise ValueError(
                    f"{negative_type} is not the label of"
                    " a type of entity in the KGDataset"
                )
            ds_type_offsets = dataset.type_offsets
            # Inclusive (start, end) range of global IDs for each entity type
            type_range_dict = {
                k: (a, b - 1)
                for ((k, a), b) in zip(
                    dataset.type_offsets.items(),
                    [*list(dataset.type_offsets.values())[1:], dataset.n_entity],
                )
            }
            type_range = type_range_dict[negative_type]
            if negative is not None:
                # Check that all negatives provided are of the requested type
                if np.any(negative < type_range[0]) or np.any(
                    negative > type_range[1]
                ):
                    warnings.warn(
                        "The negative entities provided are not all"
                        " of the specified negative_type"
                    )

        if ground_truth is not None:
            fill_column = ground_truth.reshape(n_query, 1)
        elif negative_type:
            fill_column = np.full(fill_value=type_range[0], shape=(n_query, 1))
        else:
            fill_column = np.full(fill_value=0, shape=(n_query, 1))

        if negative is not None:
            negative = negative.reshape(-1, negative.shape[-1])
        elif negative_type:
            negative = np.expand_dims(
                np.arange(type_range[0], type_range[1] + 1), axis=0
            )
        else:
            negative = np.expand_dims(np.arange(sharding.n_entity), axis=0)

        if query_mode == "hr":
            triples = np.concatenate([queries, fill_column], axis=-1)
            partition_mode = "h_shard"
            dummy = "tail" if ground_truth is None else None
            neg_heads = None
            neg_tails = negative
        elif query_mode == "rt":
            triples = np.concatenate([fill_column, queries], axis=-1)
            partition_mode = "t_shard"
            dummy = "head" if ground_truth is None else None
            neg_heads = negative
            neg_tails = None
        else:
            raise ValueError(f"Query mode {query_mode} not supported")

        (
            sorted_triples,
            counts,
            offsets,
            sort_idx,
        ) = PartitionedTripleSet.partition_triples(triples, sharding, partition_mode)

        if negative_type:
            types = (
                np.digitize(
                    sorted_triples[:, [0, 2]],
                    np.fromiter(ds_type_offsets.values(), dtype=np.int32),
                )
                - 1
            )
        else:
            types = None

        if neg_heads is not None and neg_heads.shape[0] != 1:
            neg_heads = neg_heads[sort_idx]
        if neg_tails is not None and neg_tails.shape[0] != 1:
            neg_tails = neg_tails[sort_idx]

        return cls(
            sharding=sharding,
            inverse_triples=False,
            partition_mode=partition_mode,
            dummy=dummy,
            triples=sorted_triples,
            triple_counts=counts,
            triple_offsets=offsets,
            triple_sort_idx=sort_idx,
            types=types,
            neg_heads=neg_heads,
            neg_tails=neg_tails,
        )
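

def _example_partition_triples() -> None:
    # Usage sketch, not part of the original module: partitions a toy set of
    # triples (hypothetical IDs) by (head shard, tail shard) pair and checks
    # the bookkeeping returned by partition_triples().
    sharding = Sharding.create(n_entity=8, n_shard=2, seed=0)
    triples = np.array(
        [[0, 0, 5], [3, 1, 2], [7, 0, 1], [4, 2, 6]], dtype=np.int32
    )
    (
        sorted_triples,
        counts,
        offsets,
        sort_idx,
    ) = PartitionedTripleSet.partition_triples(triples, sharding, "ht_shardpair")
    # One cell per shard-pair, ordered (0,0), (0,1), (1,0), (1,1)
    assert counts.shape == offsets.shape == (2, 2)
    assert int(counts.sum()) == triples.shape[0]
    # Head/tail columns now hold local (per-shard) IDs; relations are unchanged
    assert np.array_equal(sorted_triples[:, 1], triples[sort_idx][:, 1])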