BESS overview
When distributing the workload over \(n\) workers (=IPUs), BESS randomly splits the entity embedding table into \(n\) shards of equal size, each of which is stored in a worker’s memory. The embedding table for relation types, on the other hand, is replicated across workers, as it is usually much smaller.
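The random, equal-size sharding described above can be sketched in a few lines of plain Python. This is only an illustration and not the library's implementation: the function `shard_entities` and its return values are hypothetical names, and the sketch assumes the number of entities is a multiple of the number of shards (otherwise some form of padding would be needed).

```python
import random

def shard_entities(n_entity, n_shard, seed=0):
    """Randomly assign entities to n_shard shards of equal size.

    Hypothetical helper, not part of besskge. Returns two lists indexed by
    global entity ID: the shard storing the entity and its row inside that shard.
    """
    rng = random.Random(seed)
    perm = list(range(n_entity))
    rng.shuffle(perm)  # random order, so shard membership is random
    shard_of = [0] * n_entity
    local_id = [0] * n_entity
    for pos, entity in enumerate(perm):
        shard_of[entity] = pos % n_shard   # deal entities round-robin
        local_id[entity] = pos // n_shard  # row index within the shard
    return shard_of, local_id

shard_of, local_id = shard_entities(n_entity=12, n_shard=4)
shard_sizes = [shard_of.count(s) for s in range(4)]
print(shard_sizes)  # [3, 3, 3, 3]
```

Each entity is then addressed by the pair (shard, local row), which is all a worker needs to look up an embedding in its own memory.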
The entity sharding induces a partitioning of the triples in the dataset, according to the shard-pair of the head entity and the tail entity. At execution time (for both training and inference), batches are constructed by sampling triples uniformly from each of the \(n^2\) shard-pairs. Negative entities, used to corrupt the head or tail of a triple in order to construct negative samples, are also sampled in a balanced way to ensure a variety that is beneficial to the final embedding quality.
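As a rough sketch of the induced partitioning, the snippet below groups triples by the shard-pair of head and tail, then draws the same number of triples from every pair. The helper names, the toy data and the choice to sample with replacement are assumptions made for illustration only.

```python
import random

def partition_triples(triples, shard_of, n_shard):
    """Group (head, relation, tail) triples by (head shard, tail shard)."""
    blocks = {(i, j): [] for i in range(n_shard) for j in range(n_shard)}
    for h, r, t in triples:
        blocks[(shard_of[h], shard_of[t])].append((h, r, t))
    return blocks

def sample_batch(blocks, per_pair, rng):
    """Draw per_pair triples (with replacement) from every non-empty shard-pair."""
    return {
        pair: [rng.choice(block) for _ in range(per_pair)]
        for pair, block in blocks.items()
        if block  # empty shard-pairs are skipped in this sketch
    }

# Toy setup: 4 entities on 2 shards, entity ID -> shard (made-up values)
toy_shard_of = [0, 0, 1, 1]
toy_triples = [(0, 0, 1), (0, 1, 2), (1, 0, 3), (2, 0, 0),
               (3, 1, 1), (2, 1, 3), (3, 0, 2)]

blocks = partition_triples(toy_triples, toy_shard_of, n_shard=2)
batch = sample_batch(blocks, per_pair=2, rng=random.Random(0))
```

Here `batch` holds one block of two triples for each of the \(n^2 = 4\) shard-pairs, so every worker contributes the same number of heads and tails.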
This batching scheme allows us to balance workload and communication across workers. First, each worker needs to gather the same number of embeddings from its on-chip memory, both for positive and negative samples. These include the embeddings needed by the worker itself, and the embeddings needed by its peers.
The batch in Figure 2 can then be reconstructed by sharing the embeddings of positive tails and negative entities between workers through a balanced AllToAll collective operator. Head embeddings remain in place, as each triple block is then scored on the worker where the head embedding is stored.
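The exchange of tail embeddings can be mimicked on a single machine by treating AllToAll as a transpose of per-worker send buffers. The sketch below is a simulation with made-up one-dimensional embeddings, not the IPU collective itself; `all_to_all`, `tables` and `tail_idx` are hypothetical names.

```python
def all_to_all(send):
    """Simulated AllToAll: recv[dst][src] is the block that src sent to dst."""
    n = len(send)
    return [[send[src][dst] for src in range(n)] for dst in range(n)]

n_shard, per_pair = 2, 2

# tables[w][k]: embedding of the k-th entity stored on worker w (toy values)
tables = [[[10.0], [11.0], [12.0]],
          [[20.0], [21.0], [22.0]]]

# tail_idx[i][j]: local rows (on worker j) of the tails in triple block (i, j)
tail_idx = [[[0, 2], [1, 1]],
            [[2, 0], [0, 1]]]

# Worker j gathers, for every head worker i, the tail embeddings of block (i, j)
send = [[[tables[j][k] for k in tail_idx[i][j]] for i in range(n_shard)]
        for j in range(n_shard)]

recv = all_to_all(send)
# recv[i][j] now sits on worker i and holds the tail embeddings of block (i, j),
# ready to be scored against the heads that never left worker i.
```

Every worker sends and receives exactly `per_pair` embeddings per peer, which is what makes the collective balanced.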
The distribution scheme presented above is implemented in besskge.bess.EmbeddingMovingBessKGE.
While communication is always balanced, exchanging negative embeddings between workers can turn out to be expensive when using many negative samples per triple, or when the embedding dimension is large. In these cases, using besskge.bess.ScoreMovingBessKGE can increase overall throughput.
This alternative distribution scheme works in the same way as besskge.bess.EmbeddingMovingBessKGE for the sharding of entities and the partitioning of triples, as well as for the way embeddings for positive triples are shared through AllToAll collectives and scored. The difference lies in how negative scores are computed: instead of sending negative embeddings to the query’s worker, all queries are replicated on each device through an AllGather collective and scored against the (partial) set of negatives stored on that device. The resulting scores are then sent to the query’s worker via an additional, balanced AllToAll.
This allows us to communicate negative scores instead of negative embeddings, which is cheaper because each score is a single scalar rather than a full embedding vector, although it requires additional collective communications between devices.
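To make the difference between the two schemes concrete, the following single-process simulation computes negative scores both ways and checks that they agree. It is a sketch under simplifying assumptions that are not taken from the library: a plain dot product stands in for the scoring function, every query on a worker is scored against all negatives held by each worker, and the collectives are simulated with nested lists. All function and variable names are hypothetical.

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def all_gather(local):
    """Simulated AllGather: every worker receives all workers' local data."""
    return [list(local) for _ in local]

def all_to_all(send):
    """Simulated AllToAll: recv[dst][src] is the block that src sent to dst."""
    n = len(send)
    return [[send[src][dst] for src in range(n)] for dst in range(n)]

n = 2
# queries[w]: query vectors owned by worker w (toy values)
queries = [[[1.0, 0.0], [0.0, 2.0]], [[1.0, 1.0], [2.0, -1.0]]]
# negatives[v]: negative entity embeddings stored on worker v (toy values)
negatives = [[[3.0, 4.0], [1.0, 1.0]], [[5.0, 6.0], [7.0, 8.0]]]

# Score moving: replicate queries, score locally, then send scores back.
gathered = all_gather(queries)  # gathered[v][w]: worker w's queries, on worker v
partial = [[[[dot(q, neg) for neg in negatives[v]] for q in gathered[v][w]]
            for w in range(n)] for v in range(n)]
score_moving = all_to_all(partial)  # score_moving[w][v]: scores vs. worker v's negatives

# Embedding moving: send the negative embeddings, score on the query's worker.
recv_emb = all_to_all([[negatives[v] for _ in range(n)] for v in range(n)])
embedding_moving = [[[[dot(q, neg) for neg in recv_emb[w][v]] for q in queries[w]]
                     for v in range(n)] for w in range(n)]

assert score_moving == embedding_moving
```

Both paths produce the same scores; what changes is the payload of the final AllToAll (scalars instead of vectors) and the extra AllGather needed to replicate the queries.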