besskge.batch_sampler.ShardedBatchSampler
- class besskge.batch_sampler.ShardedBatchSampler(partitioned_triple_set, negative_sampler, shard_bs, batches_per_step, seed, hrt_freq_weighting=False, weight_smoothing=0.0, duplicate_batch=False, return_triple_idx=False)
 Base class for sharded batch samplers.
Initialize the sharded batch sampler.
- Parameters:
 partitioned_triple_set (PartitionedTripleSet) – The pre-processed collection of triples.
 negative_sampler (ShardedNegativeSampler) – The sampler for negative entities.
 shard_bs (int) – The micro-batch size, i.e. the number of positive triples processed on each shard.
 batches_per_step (int) – The number of batches to sample at each call.
 seed (int) – The RNG seed.
 hrt_freq_weighting (bool) – If True, uses frequency-based triple weighting. Default: False.
 weight_smoothing (float) – Weight-smoothing parameter for frequency-based triple weighting. Default: 0.0.
 duplicate_batch (bool) – If True, the batch sampled from each triple partition has two identical halves. This is to be used with the “ht” corruption scheme at inference time. Default: False.
 return_triple_idx (bool) – If True, returns the indices (with respect to partitioned_triple_set.triples) of the triples in the batch. Default: False.
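The frequency-weighting and duplicate-batch options can be illustrated with a small, torch-free sketch. It assumes RotatE-style subsampling weights, where a triple (h, r, t) is weighted by 1/sqrt(count(h, r) + count(r, t)), and it treats weight_smoothing as an additive constant on those counts. Both choices are assumptions, and BESS-KGE's actual formula may differ. The helpers hrt_frequency_weights and sample_partition_batch are hypothetical and not part of the besskge API.

```python
import math
import random
from collections import Counter
from typing import List, Optional, Sequence, Tuple

# (head, relation, tail) as integer ids
Triple = Tuple[int, int, int]


def hrt_frequency_weights(
    triples: Sequence[Triple], weight_smoothing: float = 0.0
) -> List[float]:
    """Hypothetical RotatE-style weights, normalized to sum to 1.

    Raw weight of (h, r, t) is 1 / sqrt(count(h, r) + count(r, t) + weight_smoothing).
    Using weight_smoothing as an additive count offset is an assumption.
    """
    hr_count = Counter((h, r) for h, r, _ in triples)
    rt_count = Counter((r, t) for _, r, t in triples)
    raw = [
        1.0 / math.sqrt(hr_count[(h, r)] + rt_count[(r, t)] + weight_smoothing)
        for h, r, t in triples
    ]
    total = sum(raw)
    return [w / total for w in raw]


def sample_partition_batch(
    n_triples: int,
    shard_bs: int,
    rng: random.Random,
    weights: Optional[Sequence[float]] = None,
    duplicate_batch: bool = False,
) -> List[int]:
    """Hypothetical helper: draw shard_bs triple indices from one partition."""
    if duplicate_batch and shard_bs % 2 != 0:
        raise ValueError("shard_bs must be even when duplicate_batch=True")
    n_draw = shard_bs // 2 if duplicate_batch else shard_bs
    if weights is None:
        # Uniform sampling without replacement
        idx = rng.sample(range(n_triples), n_draw)
    else:
        # Frequency-weighted sampling (with replacement)
        idx = rng.choices(range(n_triples), weights=weights, k=n_draw)
    # With duplicate_batch, the second half repeats the first half exactly
    return idx + idx if duplicate_batch else idx
```

For example, with triples = [(0, 0, 1), (0, 0, 2), (1, 1, 2), (2, 0, 1)], the triple (1, 1, 2) gets the largest weight because its (h, r) and (r, t) pairs are the rarest. sample_partition_batch(4, 4, random.Random(0), weights=w, duplicate_batch=True) returns a list whose two halves are identical. Weighting of this kind is meant to keep frequent (h, r) and (r, t) patterns from dominating the loss.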
- get_dataloader(options, shuffle=True, num_workers=0, persistent_workers=False, buffer_size=16)
 Returns the PopTorch dataloader.
Instantiate the appropriate poptorch.DataLoader class to iterate over the batch sampler. It uses asynchronous data loading to minimize CPU-IPU I/O.
- Parameters:
 options (Options) – The poptorch.Options used to compile and run the model.
 shuffle (bool) – If True, shuffles triples at each new epoch. Default: True.
 num_workers (int) – See torch.utils.data.DataLoader.__init__(). Default: 0.
 persistent_workers (bool) – See torch.utils.data.DataLoader.__init__(). Default: False.
 buffer_size (int) – Size of the ring buffer in shared memory used to preload batches. Default: 16.
- Return type:
 DataLoader
- Returns:
 The PopTorch dataloader.
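The buffer_size parameter bounds how many batches are preloaded ahead of consumption. The sketch below is only a conceptual, single-process analog. It uses a background thread and a bounded queue.Queue instead of PopTorch's shared-memory ring buffer and worker processes. prefetch is a hypothetical helper, not part of poptorch or besskge.

```python
import queue
import threading
from typing import Callable, Iterable, Iterator, TypeVar, cast

T = TypeVar("T")

_END = object()  # sentinel marking the end of the stream


class _ProducerError:
    """Wraps an exception raised while producing batches."""

    def __init__(self, exc: BaseException) -> None:
        self.exc = exc


def prefetch(make_batches: Callable[[], Iterable[T]], buffer_size: int = 16) -> Iterator[T]:
    """Hypothetical helper: yield batches produced ahead of time by a background thread.

    At most buffer_size items wait in the bounded buffer. The producer blocks
    when it is full, so loading overlaps with consumption without unbounded growth.
    """
    buf: queue.Queue = queue.Queue(maxsize=buffer_size)

    def producer() -> None:
        try:
            for batch in make_batches():
                buf.put(batch)  # blocks while the buffer is full
        except Exception as exc:  # forward the error to the consumer
            buf.put(_ProducerError(exc))
        finally:
            buf.put(_END)

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = buf.get()
        if item is _END:
            return
        if isinstance(item, _ProducerError):
            raise item.exc
        yield cast(T, item)
```

Because put blocks once buffer_size batches are waiting, memory use stays bounded while loading proceeds in parallel with consumption. Errors raised by the producer are re-raised in the consumer rather than silently ending the stream.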
- get_dataloader_sampler(shuffle)
 Returns the dataloader sampler.
Instantiate the appropriate torch.utils.data.Sampler class for the torch.utils.data.DataLoader class to be used with the sharded batch sampler.
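Conceptually, the returned sampler decides the order in which the DataLoader requests items from the batch sampler. The stand-in below assumes a sequential order when shuffle is False and a fresh permutation per epoch when shuffle is True, which is the behavior of torch.utils.data.SequentialSampler and RandomSampler. StepIndexSampler is a hypothetical, torch-free class that implements the same __iter__/__len__ protocol. It is not the object besskge returns.

```python
import random
from typing import Iterator, Optional


class StepIndexSampler:
    """Hypothetical torch-free stand-in for a map-style sampler.

    Yields item indices 0..n_items-1: in sequential order when shuffle=False,
    in a new random order on every iteration (i.e. every epoch) when shuffle=True.
    """

    def __init__(self, n_items: int, shuffle: bool, seed: Optional[int] = None) -> None:
        self.n_items = n_items
        self.shuffle = shuffle
        self.rng = random.Random(seed)

    def __iter__(self) -> Iterator[int]:
        order = list(range(self.n_items))
        if self.shuffle:
            # A fresh permutation each time __iter__ is called
            self.rng.shuffle(order)
        return iter(order)

    def __len__(self) -> int:
        return self.n_items
```

A DataLoader calls iter() on its sampler at the start of every epoch. This is why the shuffle happens inside __iter__ rather than in __init__.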
- static worker_init_fn(worker_id)
 Worker initialization function to be passed to torch.utils.data.DataLoader.
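When num_workers > 0, each DataLoader worker process receives its own copy of the dataset. Any RNG object that the dataset owns is copied with identical state, so without reseeding, every worker would draw the same samples. The sketch below shows one way to derive an independent but reproducible RNG per worker. ToySampler and reseed_for_worker are hypothetical, and mixing the seed as the string f"{seed}:{worker_id}" is an assumption rather than what besskge does. In PyTorch, a worker's dataset copy is reachable inside the worker through torch.utils.data.get_worker_info().dataset.

```python
import random
from dataclasses import dataclass, field


@dataclass
class ToySampler:
    """Hypothetical stand-in for a batch sampler that owns its own RNG."""

    seed: int
    rng: random.Random = field(init=False)

    def __post_init__(self) -> None:
        self.rng = random.Random(self.seed)


def reseed_for_worker(dataset: ToySampler, worker_id: int) -> None:
    """Give one worker's dataset copy an independent, reproducible RNG stream.

    A real worker_init_fn receives only worker_id and would fetch the copy via
    torch.utils.data.get_worker_info().dataset; here it is passed explicitly.
    """
    # Hypothetical seed mixing: str seeds are hashed deterministically by random.Random
    dataset.rng = random.Random(f"{dataset.seed}:{worker_id}")
```

Calling reseed_for_worker on two identical copies with different worker_id values makes their draws diverge. Repeating a run with the same seed and worker_id reproduces the same draws.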