besskge.bess.BessKGE
- class besskge.bess.BessKGE(negative_sampler, score_fn, loss_fn=None, evaluation=None, return_scores=False, augment_negative=False)[source]
Base class for distributed training and inference of KGE models, using the distribution framework BESS [CJM+22]. To be used in combination with a batch sampler based on a “ht_shardpair”-partitioned triple set.
Initialize BESS-KGE module.
- Parameters:
negative_sampler (ShardedNegativeSampler) – Sampler of negative entities.
score_fn (BaseScoreFunction) – Scoring function.
loss_fn (Optional[BaseLossFunction]) – Loss function, required when training. Default: None.
evaluation (Optional[Evaluation]) – Evaluation module, for computing metrics on device. Default: None.
return_scores (bool) – If True, return positive and negative scores of batches to the host. Default: False.
augment_negative (bool) – If True, augment sampled negative entities with the head/tails (according to the corruption scheme) of other positive triples in the micro-batch. Default: False.
- forward(head, relation, tail, negative, triple_mask=None, triple_weight=None, negative_mask=None)[source]
The forward step.
Comprises four phases:
Gather relevant embeddings from local memory;
Share embeddings with other devices through collective operators;
Score positive and negative triples;
Compute loss/metrics.
Each device scores n_shard * positive_per_partition positive triples.
- Parameters:
head (Tensor) – shape: (1, n_shard, positive_per_partition) Head indices.
relation (Tensor) – shape: (1, n_shard, positive_per_partition) Relation indices.
tail (Tensor) – shape: (1, n_shard, positive_per_partition) Tail indices.
triple_mask (Optional[Tensor]) – shape: (1, n_shard, positive_per_partition) Mask to filter the triples in the micro-batch before computing metrics.
negative (Tensor) – shape: (1, n_shard, B, padded_negative) Indices of negative entities, with B = 1, 2 or n_shard * positive_per_partition.
triple_weight (Optional[Tensor]) – shape: (1, n_shard * positive_per_partition,) or (1,) Weights of positive triples.
negative_mask (Optional[Tensor]) – shape: (1, B, n_shard, padded_negative) Mask to identify padding negatives, to discard when computing metrics.
- Return type:
- Returns:
Micro-batch loss, scores and metrics.
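As a rough sketch of the shapes in the parameter list of forward, the helper below computes the expected shape of each argument from n_shard, positive_per_partition, padded_negative and B. The name forward_input_shapes is hypothetical and not part of besskge; the shapes are copied from the parameter descriptions as written, and no tensors are created.

```python
from typing import Dict, Tuple


def forward_input_shapes(
    n_shard: int,
    positive_per_partition: int,
    padded_negative: int,
    b: int,
) -> Dict[str, Tuple[int, ...]]:
    """Hypothetical helper (not a besskge API): expected shapes of the
    arguments of BessKGE.forward, as stated in its parameter list."""
    n_positive = n_shard * positive_per_partition
    # The documentation allows B = 1, 2 or n_shard * positive_per_partition.
    if b not in (1, 2, n_positive):
        raise ValueError(f"B must be 1, 2 or {n_positive}, got {b}")
    triple_shape = (1, n_shard, positive_per_partition)
    return {
        "head": triple_shape,
        "relation": triple_shape,
        "tail": triple_shape,
        "triple_mask": triple_shape,
        "negative": (1, n_shard, b, padded_negative),
        # Per-triple weights; the documentation also allows shape (1,).
        "triple_weight": (1, n_positive),
        "negative_mask": (1, b, n_shard, padded_negative),
    }


shapes = forward_input_shapes(
    n_shard=4, positive_per_partition=8, padded_negative=16, b=1
)
# Each device scores n_shard * positive_per_partition positive triples.
n_positive_per_device = 4 * 8
```

A check like this can be useful when writing a custom batch sampler, since a mismatch between the sampler output and these shapes is easy to introduce.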
- property n_embedding_parameters: int
Returns the number of trainable parameters in the embedding tables.
- abstract score_batch(head, relation, tail, negative)[source]
Compute positive and negative scores for the micro-batch.
- Parameters:
head (Tensor) – see BessKGE.forward()
relation (Tensor) – see BessKGE.forward()
tail (Tensor) – see BessKGE.forward()
negative (Tensor) – see BessKGE.forward()
- Return type:
- Returns:
Positive (shape: (n_shard * positive_per_partition,)) and negative (shape: (n_shard * positive_per_partition, n_negative)) scores of the micro-batch.
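To illustrate only the output shapes described above, here is a toy, framework-free sketch. Everything in it is an assumption made for illustration: the name toy_score_batch, the TransE-style negative L1 distance, the tail-only corruption, and the fact that it takes embedding vectors directly instead of index tensors. It is not how besskge implements score_batch.

```python
from typing import List, Tuple

Vector = List[float]


def l1_distance(a: Vector, b: Vector) -> float:
    """Sum of absolute coordinate differences between two vectors."""
    return sum(abs(x - y) for x, y in zip(a, b))


def toy_score_batch(
    heads: List[Vector],
    relations: List[Vector],
    tails: List[Vector],
    negatives: List[List[Vector]],
) -> Tuple[List[float], List[List[float]]]:
    """Toy stand-in (not the besskge implementation).

    Returns positive scores of shape (n_triples,) and negative scores of
    shape (n_triples, n_negative), using score = -||h + r - t||_1 and
    treating every negative as a corrupted tail.
    """
    positive_scores: List[float] = []
    negative_scores: List[List[float]] = []
    for h, r, t, negs in zip(heads, relations, tails, negatives):
        translated = [x + y for x, y in zip(h, r)]  # h + r
        positive_scores.append(-l1_distance(translated, t))
        negative_scores.append([-l1_distance(translated, n) for n in negs])
    return positive_scores, negative_scores


# Two positive triples, three negative tails each, 2-dimensional embeddings.
pos, neg = toy_score_batch(
    heads=[[0.0, 0.0], [1.0, 1.0]],
    relations=[[1.0, 0.0], [0.0, 1.0]],
    tails=[[1.0, 0.0], [1.0, 3.0]],
    negatives=[
        [[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]],
        [[1.0, 2.0], [0.0, 0.0], [1.0, 1.0]],
    ],
)
```

Here pos has one entry per positive triple and neg has one row per positive triple with one column per negative, mirroring the (n_shard * positive_per_partition,) and (n_shard * positive_per_partition, n_negative) shapes in the return description.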