besskge.batch_sampler.RigidShardedBatchSampler

class besskge.batch_sampler.RigidShardedBatchSampler(partitioned_triple_set, negative_sampler, shard_bs, batches_per_step, seed, hrt_freq_weighting=False, weight_smoothing=0.0, duplicate_batch=False, return_triple_idx=False)

At each call, sample triples with the same specified indices from all triple partitions, repeating triples in shorter ones to pad to the same length. Returns a mask to identify padding triples.
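The padding idea can be sketched as follows. This is a minimal sketch, not the BESS-KGE implementation: it assumes the shared indices range over the longest partition, that shorter partitions wrap around with modular indexing to repeat their own triples, and that every partition is non-empty. The function name `rigid_sample` is hypothetical.

```python
import numpy as np


def rigid_sample(partitions, idx):
    """Sample the same indices from every partition, padding shorter ones.

    partitions: list of non-empty int arrays of shape (n_triples_p, 3).
    idx: 1-D int array of indices in [0, longest partition length).
    Returns (batch, mask): batch has shape (n_partitions, len(idx), 3);
    mask is True for genuine triples and False for padding repeats.
    """
    idx = np.asarray(idx)
    batch, mask = [], []
    for triples in partitions:
        n = len(triples)
        # Hypothetical padding rule: indices past the end of a shorter
        # partition wrap around, so its existing triples are repeated.
        batch.append(triples[idx % n])
        # Only indices that fall inside the partition are real triples.
        mask.append(idx < n)
    return np.stack(batch), np.stack(mask)


parts = [
    np.array([[0, 0, 1], [1, 0, 2], [2, 1, 0]]),
    np.array([[3, 1, 4]]),
]
batch, mask = rigid_sample(parts, np.array([0, 2]))
print(batch.shape)  # (2, 2, 3)
```

The second partition has a single triple, so index 2 repeats it and the corresponding mask entry is False; a loss computed downstream can use the mask to ignore padding triples.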

Initialize the sharded batch sampler.

Parameters:
  • partitioned_triple_set (PartitionedTripleSet) – The pre-processed collection of triples.

  • negative_sampler (ShardedNegativeSampler) – The sampler for negative entities.

  • shard_bs (int) – The micro-batch size. This is the number of positive triples processed on each shard.

  • batches_per_step (int) – The number of batches to sample at each call.

  • seed (int) – The RNG seed.

  • hrt_freq_weighting (bool) – If True, uses frequency-based triple weighting. Default: False.

  • weight_smoothing (float) – Weight-smoothing parameter for frequency-based triple weighting. Default: 0.0.

  • duplicate_batch (bool) – If True, the batch sampled from each triple partition has two identical halves. This is to be used with the “ht” corruption scheme at inference time. Default: False.

  • return_triple_idx (bool) – If True, return the indices (with respect to partitioned_triple_set.triples) of the triples in the batch. Default: False.
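Together, shard_bs and batches_per_step mean that each call provides batches_per_step micro-batches of shard_bs positive triples per shard. The sketch below arranges a flat list of triple indices into that layout and shows one way duplicate_batch could produce two identical halves. It is a minimal sketch under stated assumptions: the helper `layout_call_indices` is hypothetical, and it assumes that with duplicate_batch only shard_bs // 2 distinct indices are drawn per micro-batch and then repeated.

```python
import numpy as np


def layout_call_indices(flat_idx, shard_bs, batches_per_step, duplicate_batch=False):
    """Arrange flat triple indices into an array of shape (batches_per_step, shard_bs)."""
    flat_idx = np.asarray(flat_idx)
    if duplicate_batch:
        if shard_bs % 2:
            raise ValueError("shard_bs must be even when duplicate_batch=True")
        # Assumption: half a micro-batch of distinct indices per batch...
        half = flat_idx.reshape(batches_per_step, shard_bs // 2)
        # ...concatenated with itself to form two identical halves.
        return np.concatenate([half, half], axis=1)
    return flat_idx.reshape(batches_per_step, shard_bs)


out = layout_call_indices(np.arange(4), shard_bs=4, batches_per_step=2, duplicate_batch=True)
print(out.tolist())  # [[0, 1, 0, 1], [2, 3, 2, 3]]
```

Without duplicate_batch, the same call would need batches_per_step * shard_bs = 8 distinct indices.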

get_dataloader(options, shuffle=True, num_workers=0, persistent_workers=False, buffer_size=16)

Returns the PopTorch dataloader.

Instantiate the appropriate poptorch.DataLoader class to iterate over the batch sampler. It uses asynchronous data loading to minimize the overhead of CPU-IPU I/O.

Parameters:
  • options (Options) – poptorch.Options used to compile and run the model.

  • shuffle (bool) – If True, shuffles triples at each new epoch. Default: True.

  • num_workers (int) – See torch.utils.data.DataLoader.__init__(). Default: 0.

  • persistent_workers (bool) – See torch.utils.data.DataLoader.__init__(). Default: False.

  • buffer_size (int) – Size of the ring buffer in shared memory used to preload batches. Default: 16.

Return type:

DataLoader

Returns:

The PopTorch dataloader.

get_dataloader_sampler(shuffle)

Returns the dataloader sampler.

Instantiate the appropriate torch.utils.data.Sampler class for the torch.utils.data.DataLoader class to be used with the sharded batch sampler.

Parameters:

shuffle (bool) – If True, shuffles triples at each new epoch.
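A sampler of type Sampler[List[int]] yields one list of indices per call to the batch sampler, and with shuffle=True draws a fresh permutation every epoch. The dependency-free sketch below mimics that contract. The class name `EpochBatchIndexSampler` and its arguments are hypothetical; a real implementation would subclass torch.utils.data.Sampler, and here the last list is simply allowed to be shorter, which may differ from how BESS-KGE handles the remainder.

```python
import random
from typing import Iterator, List


class EpochBatchIndexSampler:
    """Yield lists of indices, one list per batch-sampler call."""

    def __init__(self, n_items: int, indices_per_call: int, shuffle: bool, seed: int = 0):
        self.n_items = n_items
        self.indices_per_call = indices_per_call
        self.shuffle = shuffle
        self.rng = random.Random(seed)

    def __iter__(self) -> Iterator[List[int]]:
        order = list(range(self.n_items))
        if self.shuffle:
            # A new permutation each time iteration restarts, i.e. each epoch.
            self.rng.shuffle(order)
        for start in range(0, self.n_items, self.indices_per_call):
            yield order[start:start + self.indices_per_call]

    def __len__(self) -> int:
        # Ceiling division: the final list may hold fewer indices.
        return -(-self.n_items // self.indices_per_call)


print(list(EpochBatchIndexSampler(5, 2, shuffle=False)))  # [[0, 1], [2, 3], [4]]
```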

Return type:

Sampler[List[int]]

Returns:

The dataloader sampler.

sample_triples(idx)

Sample positive triples in the batch.

Parameters:

idx (List[int]) – The batch indices.

Return type:

Dict[str, Union[ndarray[Any, dtype[int64]], ndarray[Any, dtype[bool_]]]]

Returns:

Per-partition indices of positive triples, and other relevant data.

static worker_init_fn(worker_id)

Worker initialization function to be passed to torch.utils.data.DataLoader.

Parameters:

worker_id (int) – Worker ID.
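DataLoader workers run in separate processes, so each one usually needs its own reproducible random stream. A minimal sketch, assuming the per-worker generator is derived from the sampler seed and worker_id through NumPy's SeedSequence mixing; the helper `make_worker_rng` is hypothetical. A real worker_init_fn could retrieve its copy of the dataset with torch.utils.data.get_worker_info() and attach such a generator to it.

```python
import numpy as np


def make_worker_rng(seed: int, worker_id: int) -> np.random.Generator:
    """Build an independent, reproducible generator for one worker."""
    # default_rng mixes a sequence of ints through SeedSequence, so
    # (seed, 0) and (seed, 1) give statistically independent streams.
    return np.random.default_rng([seed, worker_id])


a = make_worker_rng(42, 0).integers(0, 2**31, size=8)
a_again = make_worker_rng(42, 0).integers(0, 2**31, size=8)
print(bool((a == a_again).all()))  # True
```

Deriving the stream from both values keeps runs reproducible for a fixed seed while preventing workers from producing identical batches.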

Return type:

None