Source code for besskge.batch_sampler

# Copyright (c) 2023 Graphcore Ltd. All rights reserved.

"""
Classes for sampling batches of positive and negative triples for each processing device,
according to the BESS distribution scheme.
"""

import warnings
from abc import ABC, abstractmethod
from typing import Dict, List, Union, cast

import einops
import numpy as np
import poptorch
import torch
from numpy.typing import NDArray

from besskge.negative_sampler import ShardedNegativeSampler
from besskge.sharding import PartitionedTripleSet


class ShardedBatchSampler(torch.utils.data.Dataset[Dict[str, torch.Tensor]], ABC):
    """
    Base class for sharded batch sampler.
    """

    def __init__(
        self,
        partitioned_triple_set: PartitionedTripleSet,
        negative_sampler: ShardedNegativeSampler,
        shard_bs: int,
        batches_per_step: int,
        seed: int,
        hrt_freq_weighting: bool = False,
        weight_smoothing: float = 0.0,
        duplicate_batch: bool = False,
        return_triple_idx: bool = False,
    ):
        """
        Initialize sharded batch sampler.

        :param partitioned_triple_set:
            The pre-processed collection of triples.
        :param negative_sampler:
            The sampler for negative entities.
        :param shard_bs:
            The micro-batch size. This is the number of positive triples
            processed on each shard.
        :param batches_per_step:
            The number of batches to sample at each call.
        :param seed:
            The RNG seed.
        :param hrt_freq_weighting:
            If True, uses frequency-based triple weighting. Default: False.
        :param weight_smoothing:
            Weight-smoothing parameter for frequency-based triple weighting.
            Default: 0.0.
        :param duplicate_batch:
            If True, the batch sampled from each triple partition has two
            identical halves. This is to be used with the "ht" corruption
            scheme at inference time. Default: False.
        :param return_triple_idx:
            If True, return the indices (wrt partitioned_triple_set.triples)
            of the triples in the batch. Default: False.
        """
        self.n_shard = partitioned_triple_set.sharding.n_shard
        self.triples = partitioned_triple_set.triples
        self.dummy = partitioned_triple_set.dummy
        self.triple_counts = partitioned_triple_set.triple_counts
        self.triple_offsets = partitioned_triple_set.triple_offsets
        self.triple_partition_mode = partitioned_triple_set.partition_mode
        self.negative_sampler = negative_sampler
        self.shard_bs = shard_bs
        self.batches_per_step = batches_per_step
        self.duplicate_batch = duplicate_batch

        if self.triple_partition_mode == "ht_shardpair":
            # The micro-batch on device N is formed of n_shard blocks,
            # corresponding to triple partitions (h_shard, t_shard)
            # with h_shard = N and t_shard = 0, ..., n_shard-1.
            self.positive_per_partition = int(np.ceil(self.shard_bs / self.n_shard))
        else:
            self.positive_per_partition = self.shard_bs
        if self.duplicate_batch:
            self.positive_per_partition //= 2
        if self.negative_sampler.corruption_scheme == "ht":
            # Each partition is split into two halves, so positive_per_partition
            # needs to be even.
            self.positive_per_partition = (self.positive_per_partition // 2) * 2
        # Total number of triples sampled from each partition at each call
        self.partition_sample_size = (
            self.batches_per_step * self.positive_per_partition
        )

        self.hrt_freq_weighting = hrt_freq_weighting
        self.return_triple_idx = return_triple_idx
        self.seed = seed
        self.rng = np.random.default_rng(self.seed)

        if self.hrt_freq_weighting:
            if self.dummy != "none":
                warnings.warn(
                    "hrt frequency weights are being computed on dummy entities"
                )
            _, hr_idx, hr_count = np.unique(
                self.triples[..., 0]
                + partitioned_triple_set.sharding.n_entity * self.triples[..., 1],
                return_counts=True,
                return_inverse=True,
            )
            _, rt_idx, rt_count = np.unique(
                self.triples[..., 2]
                + partitioned_triple_set.sharding.n_entity * self.triples[..., 1],
                return_counts=True,
                return_inverse=True,
            )
            self.hrt_weights = np.sqrt(
                1.0 / (hr_count[hr_idx] + rt_count[rt_idx] + weight_smoothing)
            )

    def __len__(self) -> int:
        """
        Returns the length of the batch sampler.

        The length of the batch sampler is based on the length of the largest
        triple partition. At each call,
        :attr:`ShardedBatchSampler.partition_sample_size` triples from each
        partition are returned.

        :return:
            The length of the batch sampler.
        """
        return (
            int(np.ceil(self.triple_counts.max() / self.partition_sample_size))
            * self.partition_sample_size
        )

    def __getitem__(self, idx: List[int]) -> Dict[str, torch.Tensor]:
        """
        Sample batch.

        :param idx:
            The batch index.

        :return:
            Indices of head, relation, tail and negative entities in the batch,
            and associated weights and masks.
        """
        sample_triple_dict = self.sample_triples(idx)
        if self.duplicate_batch:
            sample_triple_dict = {
                k: einops.repeat(
                    v,
                    "step shard ... triple -> step shard ... (2 triple)",
                )
                for k, v in sample_triple_dict.items()
            }
        sample_idx = cast(NDArray[np.int64], sample_triple_dict.pop("sample_idx"))
        head, relation, tail = einops.rearrange(
            self.triples[sample_idx],
            "... hrt -> hrt ...",
        )
        if self.triple_partition_mode == "ht_shardpair":
            # Prepare tail indices for AllToAll exchange
            tail = einops.rearrange(
                tail, "step shard_h shard_t triple -> step shard_t shard_h triple"
            )
        batch_dict = {
            "head": head.astype(np.int32),
            "relation": relation.astype(np.int32),
            "tail": tail.astype(np.int32),
            **sample_triple_dict,
        }

        sample_negative_dict = self.negative_sampler(sample_idx)
        if "negative_entities" in sample_negative_dict.keys():
            negative_entities = sample_negative_dict.pop("negative_entities")
            batch_dict.update(negative=negative_entities.astype(np.int32))
        batch_dict.update(**sample_negative_dict)

        if self.dummy in ["head", "tail"]:
            batch_dict.pop(self.dummy)

        if self.hrt_freq_weighting:
            triple_weight = einops.rearrange(
                self.hrt_weights[sample_idx],
                "step shard ... triple -> step shard (... triple)",
            )
            triple_weight /= np.sum(triple_weight, axis=-1, keepdims=True)
            triple_weight *= self.shard_bs
            batch_dict.update(triple_weight=triple_weight.astype(np.float32))

        if self.return_triple_idx:
            batch_dict.update(triple_idx=sample_idx)

        return {k: torch.from_numpy(v) for k, v in batch_dict.items()}

    @abstractmethod
    def sample_triples(
        self, idx: List[int]
    ) -> Dict[str, Union[NDArray[np.int64], NDArray[np.bool_]]]:
        """
        Sample positive triples in the batch.

        :param idx:
            The batch index.

        :return:
            Per-partition indices of positive triples, and other relevant data.
        """
        raise NotImplementedError

    def get_dataloader_sampler(
        self, shuffle: bool
    ) -> torch.utils.data.Sampler[List[int]]:
        """
        Returns the dataloader sampler.

        Instantiate the appropriate :class:`torch.utils.data.Sampler` class for
        the :class:`torch.utils.data.DataLoader` class to be used with the
        sharded batch sampler.

        :param shuffle:
            Shuffle triples at each new epoch.

        :return:
            The dataloader sampler.
        """
        sampler = (
            torch.utils.data.RandomSampler(self)
            if shuffle
            else torch.utils.data.SequentialSampler(self)
        )
        return torch.utils.data.BatchSampler(
            sampler, batch_size=self.partition_sample_size, drop_last=False
        )

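    # Note: since the BatchSampler above uses
    # batch_size=self.partition_sample_size and __len__ returns a multiple of
    # partition_sample_size, every index list passed to __getitem__ (and hence
    # to sample_triples) has exactly partition_sample_size entries, i.e.
    # batches_per_step * positive_per_partition positions within each triple
    # partition.
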
    def get_dataloader(
        self,
        options: poptorch.Options,
        shuffle: bool = True,
        num_workers: int = 0,
        persistent_workers: bool = False,
        buffer_size: int = 16,
    ) -> poptorch.DataLoader:
        """
        Returns the PopTorch dataloader.

        Instantiate the appropriate :class:`poptorch.DataLoader` class to
        iterate over the batch sampler. It uses asynchronous data-loading
        to minimize CPU-IPU I/O.

        :param options:
            `poptorch.Options` used to compile and run the model.
        :param shuffle:
            If True, shuffles triples at each new epoch. Default: True.
        :param num_workers:
            see :meth:`torch.utils.data.DataLoader.__init__`. Default: 0.
        :param persistent_workers:
            see :meth:`torch.utils.data.DataLoader.__init__`. Default: False.
        :param buffer_size:
            Size of the ring buffer in shared memory used to preload batches.

        :return:
            The PopTorch dataloader.
        """
        return poptorch.DataLoader(
            options=options,
            dataset=self,
            batch_size=None,
            sampler=self.get_dataloader_sampler(shuffle=shuffle),
            drop_last=False,
            num_workers=num_workers,
            persistent_workers=persistent_workers,
            worker_init_fn=self.worker_init_fn,
            mode=poptorch.DataLoaderMode.Async,
            async_options={
                "early_preload": True,
                "buffer_size": buffer_size,
                "sharing_strategy": poptorch.SharingStrategy.SharedMemory,
            },
        )

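    # Illustrative usage sketch (hypothetical; `batch_sampler` is an instance
    # of a concrete subclass). Setting deviceIterations to batches_per_step is
    # an assumption about how the options are typically configured, not a
    # requirement imposed by this class:
    #
    #   opts = poptorch.Options()
    #   opts.deviceIterations(batch_sampler.batches_per_step)
    #   dl = batch_sampler.get_dataloader(options=opts, shuffle=True)
    #   for batch in dl:
    #       ...  # batch is a Dict[str, torch.Tensor], as built by __getitem__
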
    @staticmethod
    def worker_init_fn(worker_id: int) -> None:
        """
        Worker initialization function to be passed to
        :class:`torch.utils.data.DataLoader`.

        :param worker_id:
            Worker ID.
        """
        worker_info = torch.utils.data.get_worker_info()
        if worker_info:
            dataset_unwrapped = cast(ShardedBatchSampler, worker_info.dataset)
            worker_seed = dataset_unwrapped.seed + worker_id
            dataset_unwrapped.rng = np.random.default_rng(worker_seed)
            dataset_unwrapped.negative_sampler.rng = np.random.default_rng(worker_seed)


class RigidShardedBatchSampler(ShardedBatchSampler):
    """
    At each call, sample triples with the same specified indices from all
    triple partitions, repeating triples in shorter ones to pad to the same
    length. Returns a mask to identify padding triples.
    """

    # docstr-coverage: inherited
    def __init__(
        self,
        partitioned_triple_set: PartitionedTripleSet,
        negative_sampler: ShardedNegativeSampler,
        shard_bs: int,
        batches_per_step: int,
        seed: int,
        hrt_freq_weighting: bool = False,
        weight_smoothing: float = 0.0,
        duplicate_batch: bool = False,
        return_triple_idx: bool = False,
    ) -> None:
        super(RigidShardedBatchSampler, self).__init__(
            partitioned_triple_set,
            negative_sampler,
            shard_bs,
            batches_per_step,
            seed,
            hrt_freq_weighting,
            weight_smoothing,
            duplicate_batch,
            return_triple_idx,
        )

        padded_partition_length = len(self)
        expand_axes = (0, 1) if self.triple_partition_mode == "ht_shardpair" else (0,)

        self.triple_mask = (
            np.expand_dims(np.arange(padded_partition_length), axis=expand_axes)
            < self.triple_counts[..., None]
        )  # shape (n_shard, [n_shard,] padded_partition_length)

        triple_padded_idx = (
            np.expand_dims(np.arange(padded_partition_length), axis=expand_axes)
            % self.triple_counts[..., None]
        ) + self.triple_offsets[..., None]
        # Index safeguard for when the last partition is empty
        self.triple_padded_idx = np.minimum(
            triple_padded_idx,
            self.triples.shape[0] - 1,
        )  # shape (n_shard, [n_shard,] padded_partition_length)

    # docstr-coverage: inherited
    def sample_triples(
        self, idx: List[int]
    ) -> Dict[str, Union[NDArray[np.int64], NDArray[np.bool_]]]:
        sample_idx = einops.rearrange(
            self.triple_padded_idx[..., idx],
            "shard ... (step triple) -> step shard ... triple",
            step=self.batches_per_step,
        )
        batch_mask = einops.rearrange(
            self.triple_mask[..., idx],
            "shard ... (step triple) -> step shard ... triple",
            step=self.batches_per_step,
        )

        return dict(sample_idx=sample_idx, triple_mask=batch_mask)

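    # Illustrative example (hypothetical numbers): for a partition with
    # triple_counts = 5, triple_offsets = 100 and padded_partition_length = 8,
    # the constructor above produces
    #
    #   triple_padded_idx = [100, 101, 102, 103, 104, 100, 101, 102]
    #   triple_mask       = [  T,   T,   T,   T,   T,   F,   F,   F]
    #
    # i.e. the partition is padded by wrapping around its own triples, and the
    # mask flags the repeated (padding) entries so they can be discarded
    # downstream.
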
class RandomShardedBatchSampler(ShardedBatchSampler):
    """
    Sample random indices (with replacement) from all triple partitions.
    No padding of triple partitions applied.
    """

    # docstr-coverage: inherited
    def sample_triples(
        self, idx: List[int]
    ) -> Dict[str, Union[NDArray[np.int64], NDArray[np.bool_]]]:
        sample_size = (
            (
                self.batches_per_step,
                self.n_shard,
                self.n_shard,
                self.positive_per_partition,
            )
            if self.triple_partition_mode == "ht_shardpair"
            else (
                self.batches_per_step,
                self.n_shard,
                self.positive_per_partition,
            )
        )
        sample_idx = np.expand_dims(
            self.triple_offsets, axis=(0, -1)
        ) + self.rng.integers(
            1 << 63,
            size=sample_size,
        ) % np.expand_dims(
            self.triple_counts, axis=(0, -1)
        )
        return dict(sample_idx=sample_idx)

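    # Illustrative example (hypothetical numbers): for a partition with
    # triple_offsets = 100 and triple_counts = 5, the expression above draws
    # integers in [0, 2**63), reduces them modulo the partition size to
    # positions 0..4, and shifts by the offset, so the sampled global indices
    # fall in {100, ..., 104} and are drawn with replacement.
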
    def __len__(self) -> int:
        return int(np.ceil(self.triple_counts.max() / self.partition_sample_size))

    # docstr-coverage: inherited
    def get_dataloader_sampler(
        self, shuffle: bool = True
    ) -> torch.utils.data.Sampler[List[int]]:
        sampler = torch.utils.data.SequentialSampler(self)
        return torch.utils.data.BatchSampler(sampler, batch_size=1, drop_last=False)
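
# Illustrative end-to-end sketch (hypothetical variable names; the triple set
# and negative sampler are assumed to have been built with besskge.sharding
# and besskge.negative_sampler):
#
#   batch_sampler = RigidShardedBatchSampler(
#       partitioned_triple_set=triple_set,   # a PartitionedTripleSet
#       negative_sampler=neg_sampler,        # a ShardedNegativeSampler
#       shard_bs=512,
#       batches_per_step=8,
#       seed=1234,
#   )
#   dataloader = batch_sampler.get_dataloader(options=poptorch.Options())
#   batch = next(iter(dataloader))           # Dict[str, torch.Tensor]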