besskge.batch_sampler.ShardedBatchSampler
- class besskge.batch_sampler.ShardedBatchSampler(partitioned_triple_set, negative_sampler, shard_bs, batches_per_step, seed, hrt_freq_weighting=False, weight_smoothing=0.0, duplicate_batch=False, return_triple_idx=False)
Base class for sharded batch samplers.
Initialize sharded batch sampler.
- Parameters:
partitioned_triple_set (PartitionedTripleSet) – The pre-processed collection of triples.
negative_sampler (ShardedNegativeSampler) – The sampler for negative entities.
shard_bs (int) – The micro-batch size, i.e. the number of positive triples processed on each shard.
batches_per_step (int) – The number of batches to sample at each call.
seed (int) – The RNG seed.
hrt_freq_weighting (bool) – If True, uses frequency-based triple weighting. Default: False.
weight_smoothing (float) – Weight-smoothing parameter for frequency-based triple weighting. Default: 0.0.
duplicate_batch (bool) – If True, the batch sampled from each triple partition has two identical halves. This is intended for use with the “ht” corruption scheme at inference time. Default: False.
return_triple_idx (bool) – If True, returns the indices (with respect to partitioned_triple_set.triples) of the triples in the batch. Default: False.
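The hrt_freq_weighting and weight_smoothing options are only named above. The following is a minimal sketch of one common form of frequency-based triple weighting, similar to the subsampling weights in the RotatE reference code: each triple (h, r, t) is weighted by the inverse square root of how often its (h, r) and (r, t) pairs occur, so triples involving very frequent pairs contribute less. This is an assumption, not besskge's formula: the helper `hrt_frequency_weights` is hypothetical, and treating `weight_smoothing` as an additive constant on the counts is an illustrative choice.

```python
from collections import Counter
from math import sqrt
from typing import List, Tuple

Triple = Tuple[int, int, int]  # (head, relation, tail) as integer ids


def hrt_frequency_weights(
    triples: List[Triple], weight_smoothing: float = 0.0
) -> List[float]:
    """Hypothetical helper: weight = 1 / sqrt(freq(h, r) + freq(r, t) + smoothing).

    Assumption: weight_smoothing is added to the pair counts; the library
    may define the smoothing differently.
    """
    # Count how often each (head, relation) and (relation, tail) pair appears.
    hr_count = Counter((h, r) for h, r, _ in triples)
    rt_count = Counter((r, t) for _, r, t in triples)
    return [
        1.0 / sqrt(hr_count[(h, r)] + rt_count[(r, t)] + weight_smoothing)
        for h, r, t in triples
    ]


triples = [(0, 0, 1), (0, 0, 2), (3, 1, 2)]
weights = hrt_frequency_weights(triples)
```

Here the first two triples share the (head, relation) pair (0, 0), so each gets weight 1/√3, while the third triple, whose pairs each occur once, gets the larger weight 1/√2.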
- get_dataloader(options, shuffle=True, num_workers=0, persistent_workers=False, buffer_size=16)
Returns the PopTorch dataloader.
Instantiate the appropriate poptorch.DataLoader class to iterate over the batch sampler. It uses asynchronous data loading to reduce the overhead of CPU-IPU I/O.
- Parameters:
options (Options) – The poptorch.Options used to compile and run the model.
shuffle (bool) – If True, shuffles triples at each new epoch. Default: True.
num_workers (int) – See torch.utils.data.DataLoader.__init__(). Default: 0.
persistent_workers (bool) – See torch.utils.data.DataLoader.__init__(). Default: False.
buffer_size (int) – Size of the ring buffer in shared memory used to preload batches. Default: 16.
- Return type:
DataLoader
- Returns:
The PopTorch dataloader.
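The buffer_size parameter bounds how many batches are prepared ahead of the consumer. As a conceptual model only, the sketch below mimics asynchronous loading with a bounded buffer using a background thread and a fixed-size queue.Queue in place of PopTorch's shared-memory ring buffer. The helper `prefetch` is hypothetical and is not part of PopTorch or besskge.

```python
import queue
import threading
from typing import Iterable, Iterator, Tuple, TypeVar

T = TypeVar("T")


def prefetch(batches: Iterable[T], buffer_size: int = 16) -> Iterator[T]:
    """Hypothetical helper: yield items from `batches` while a background
    thread keeps up to `buffer_size` of them ready in a bounded buffer."""
    # Bounded FIFO standing in for the shared-memory ring buffer.
    buf: "queue.Queue[Tuple[str, object]]" = queue.Queue(maxsize=buffer_size)

    def producer() -> None:
        try:
            for item in batches:
                # put() blocks while the buffer is full, so the producer
                # never runs more than buffer_size items ahead.
                buf.put(("item", item))
        except Exception as exc:
            # Forward errors so the consumer does not wait forever.
            buf.put(("error", exc))
            return
        buf.put(("done", None))

    # Daemon thread: if the consumer stops early, a producer blocked on
    # put() does not keep the process alive.
    threading.Thread(target=producer, daemon=True).start()
    while True:
        kind, payload = buf.get()
        if kind == "error":
            raise payload  # type: ignore[misc]
        if kind == "done":
            return
        yield payload  # type: ignore[misc]


loaded = list(prefetch(range(10), buffer_size=4))
```

The consumer still receives batches in their original order. The buffer changes only when batches are produced, which lets batch preparation overlap with consumption.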
- get_dataloader_sampler(shuffle)
Returns the dataloader sampler.
Instantiate the appropriate torch.utils.data.Sampler class for the torch.utils.data.DataLoader class to be used with the sharded batch sampler.
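Assume each index into the batch sampler yields a complete step's worth of batches. In that case the dataloader sampler only has to decide the order in which those indices are visited. The stdlib sketch below models that choice. It uses a seeded, reproducible reshuffle per epoch when shuffle is True and sequential order otherwise. The helper `batch_index_order` and its `epoch` argument are hypothetical. The real method returns a torch.utils.data.Sampler object, not a list. PyTorch's SequentialSampler and RandomSampler are the standard samplers for these two behaviors, but this document does not say which one besskge uses.

```python
import random
from typing import List


def batch_index_order(n_items: int, shuffle: bool, seed: int, epoch: int) -> List[int]:
    """Hypothetical helper: order in which to visit batch-sampler indices."""
    order = list(range(n_items))
    if shuffle:
        # A distinct but reproducible permutation for every epoch.
        rng = random.Random(seed + epoch)
        rng.shuffle(order)
    return order


epoch0 = batch_index_order(8, shuffle=True, seed=1234, epoch=0)
epoch1 = batch_index_order(8, shuffle=True, seed=1234, epoch=1)
```

With shuffle=False the order is simply 0, 1, ..., n_items - 1. With shuffle=True every epoch visits each index exactly once in a new order that can be recreated from the seed.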
- static worker_init_fn(worker_id)
Worker initialization function to be passed to torch.utils.data.DataLoader.
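When num_workers > 0, each DataLoader worker runs in its own process. A sampler that keeps its own random generator (this class takes a `seed` argument) may then be copied into every worker with the same generator state, so a common pattern is to reseed it per worker inside the worker initialization function. In real PyTorch code, torch.utils.data.get_worker_info() gives that function access to the worker's copy of the dataset. The sketch below simulates the pattern without PyTorch. `ToySampler` and `reseed_for_worker` are hypothetical, and deriving the worker seed as `seed + worker_id` is an assumption, not what besskge's worker_init_fn is documented to do.

```python
import random
from typing import List


class ToySampler:
    """Hypothetical stand-in for a batch sampler that owns its RNG."""

    def __init__(self, seed: int) -> None:
        self.seed = seed
        self.rng = random.Random(seed)

    def reseed_for_worker(self, worker_id: int) -> None:
        # Assumption: derive a distinct, reproducible stream per worker.
        self.rng = random.Random(self.seed + worker_id)

    def sample(self, k: int) -> List[int]:
        return [self.rng.randrange(1000) for _ in range(k)]


# Simulate two workers, each holding its own copy of the same sampler.
worker_draws = []
for worker_id in range(2):
    sampler = ToySampler(seed=42)
    sampler.reseed_for_worker(worker_id)
    worker_draws.append(sampler.sample(5))
```

Without the reseeding step, both simulated workers would draw exactly the same values, because each copy starts from `random.Random(42)`.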