besskge.scoring.PairRE
- class besskge.scoring.PairRE(negative_sample_sharing, scoring_norm, sharding, n_relation_type, embedding_size, entity_initializer=[<function init_KGE_uniform>], relation_initializer=[<function init_KGE_uniform>], normalize_entities=True, inverse_relations=False)
PairRE scoring function [CHWC21].
Initialize PairRE model.
- Parameters:
  - negative_sample_sharing (bool) – See DistanceBasedScoreFunction.__init__().
  - scoring_norm (int) – See DistanceBasedScoreFunction.__init__().
  - sharding (Sharding) – Entity sharding.
  - n_relation_type (int) – Number of relation types in the knowledge graph.
  - embedding_size (int) – Size of entity and relation embeddings.
  - entity_initializer (Union[Tensor, List[Callable[..., Tensor]]]) – Initialization function or table for entity embeddings.
  - relation_initializer (Union[Tensor, List[Callable[..., Tensor]]]) – Initialization function or table for relation embeddings.
  - normalize_entities (bool) – If True, L2-normalize head and tail entity embeddings before projecting, as in [CHWC21]. Default: True.
  - inverse_relations (bool) – If True, learn embeddings for inverse relations. Default: False.
- broadcasted_distance(v1, v2)
Broadcasted distances of queries against sets of entities.
For each query and candidate, computes the p-distance of their embeddings.
- Parameters:
- Return type:
- Returns:
shape: (batch_size, B * n_neg) if BaseScoreFunction.negative_sample_sharing else (batch_size, n_neg)
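The two return shapes can be illustrated with the pure-Python sketch below. It assumes (an inference from the documented shapes, not a statement about besskge internals) that queries have shape (batch_size, d) and candidates arrive grouped as (B, n_neg, d): with sharing, every query is compared against all B * n_neg candidates; without sharing, query i only sees group i, so B must equal batch_size. `broadcasted_distance_sketch` is a hypothetical name.

```python
from typing import List, Sequence

Vector = Sequence[float]


def p_distance(a: Vector, b: Vector, p: int) -> float:
    # p-norm of the element-wise difference a - b.
    return sum(abs(x - y) ** p for x, y in zip(a, b)) ** (1.0 / p)


def broadcasted_distance_sketch(
    queries: Sequence[Vector],               # (batch_size, d)
    candidates: Sequence[Sequence[Vector]],  # (B, n_neg, d)
    p: int = 2,
    negative_sample_sharing: bool = True,
) -> List[List[float]]:
    if negative_sample_sharing:
        # Every query sees all B * n_neg candidates -> (batch_size, B * n_neg).
        flat = [c for group in candidates for c in group]
        return [[p_distance(q, c, p) for c in flat] for q in queries]
    # Query i only sees its own group -> (batch_size, n_neg); needs B == batch_size.
    if len(candidates) != len(queries):
        raise ValueError("without sharing, B must equal batch_size")
    return [
        [p_distance(q, c, p) for c in group]
        for q, group in zip(queries, candidates)
    ]
```

With queries [[0.0, 0.0], [1.0, 1.0]] and candidates [[[3.0, 4.0]], [[1.0, 1.0]]] (B = 2, n_neg = 1), sharing yields a 2 × 2 result, while `negative_sample_sharing=False` yields [[5.0], [0.0]].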
- forward(head_emb, relation_id, tail_emb)
- reduce_embedding(v)
p-norm reduction along the embedding dimension.
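A minimal sketch of such a reduction, assuming it collapses the last axis of a (batch, embedding_size) array; applied to the difference of two embeddings, it yields their p-distance. `reduce_embedding_sketch` is a hypothetical stand-in, not the library method.

```python
from typing import List, Sequence


def reduce_embedding_sketch(v: Sequence[Sequence[float]], p: int = 1) -> List[float]:
    # Collapse the last (embedding) dimension of a (batch, d) array to its p-norm.
    return [sum(abs(x) ** p for x in row) ** (1.0 / p) for row in v]
```

For instance, `reduce_embedding_sketch([[3.0, -4.0], [1.0, 1.0]], p=1)` returns [7.0, 2.0]; with p=2 the first row reduces to 5.0.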
- score_heads(head_emb, relation_id, tail_emb)
Score sets of head entities against fixed (r,t) queries.
- Parameters:
- Return type:
- Returns:
shape: (batch_size, B * n_heads) if BaseScoreFunction.negative_sample_sharing else (batch_size, n_heads). Scores of broadcasted triples.
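For PairRE, scoring heads against a fixed (r, t) query amounts to projecting the tail once and comparing it with every projected head candidate. The sketch below covers the shared-negatives case and assumes candidates of shape (B, n_heads, d), L2-normalized entities (normalize_entities=True), a negated p-distance score, and a hypothetical `relation_table` mapping each relation id to an (r_head, r_tail) pair as a stand-in for the learned relation embeddings. `score_heads_sketch` is likewise a hypothetical name.

```python
import math
from typing import List, Sequence, Tuple

Vector = Sequence[float]


def _unit(v: Vector) -> List[float]:
    # L2-normalize v, mirroring normalize_entities=True.
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]


def _project(v: Vector, r: Vector) -> List[float]:
    # Element-wise (Hadamard) product of an entity with a relation vector.
    return [a * b for a, b in zip(v, r)]


def _neg_p_dist(a: Vector, b: Vector, p: int) -> float:
    # Negated p-distance: higher means more plausible.
    return -(sum(abs(x - y) ** p for x, y in zip(a, b)) ** (1.0 / p))


def score_heads_sketch(
    head_cands: Sequence[Sequence[Vector]],  # (B, n_heads, d), shared by all queries
    relation_id: Sequence[int],              # (batch_size,)
    tail_emb: Sequence[Vector],              # (batch_size, d)
    relation_table: Sequence[Tuple[Vector, Vector]],  # hypothetical (r_head, r_tail) per id
    p: int = 1,
) -> List[List[float]]:
    """Shared-negatives case: returns scores of shape (batch_size, B * n_heads)."""
    flat = [_unit(h) for group in head_cands for h in group]
    scores = []
    for r, t in zip(relation_id, tail_emb):
        r_head, r_tail = relation_table[r]
        # The (r, t) query is fixed, so its tail projection is computed once.
        t_proj = _project(_unit(t), r_tail)
        scores.append([_neg_p_dist(_project(h, r_head), t_proj, p) for h in flat])
    return scores
```

With one tail [0.0, 2.0], all-ones projections for relation 0, and head candidates [0.0, 5.0] and [3.0, 4.0], the first candidate (which normalizes to the same vector as the tail) scores 0 and the second about -0.8, so it ranks first.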
- score_tails(head_emb, relation_id, tail_emb)
Score sets of tail entities against fixed (h,r) queries.
- Parameters:
- Return type:
- Returns:
shape: (batch_size, B * n_tails) if BaseScoreFunction.negative_sample_sharing else (batch_size, n_tails). Scores of broadcasted triples.
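On the tail side the roles swap: the (h, r) query is projected once and compared with each projected tail candidate. This sketch shows the case without negative sample sharing, where query i is scored only against its own n_tails candidates. It rests on the same assumptions as a PairRE illustration rather than the library code: L2-normalized entities, a negated p-distance with p = 1 by default, and a hypothetical `relation_table` of (r_head, r_tail) pairs. `score_tails_sketch` is a hypothetical name.

```python
import math
from typing import List, Sequence, Tuple

Vector = Sequence[float]


def _unit(v: Vector) -> List[float]:
    # L2-normalize v, mirroring normalize_entities=True.
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]


def _project(v: Vector, r: Vector) -> List[float]:
    # Element-wise (Hadamard) product of an entity with a relation vector.
    return [a * b for a, b in zip(v, r)]


def _neg_p_dist(a: Vector, b: Vector, p: int) -> float:
    # Negated p-distance: higher means more plausible.
    return -(sum(abs(x - y) ** p for x, y in zip(a, b)) ** (1.0 / p))


def score_tails_sketch(
    head_emb: Sequence[Vector],              # (batch_size, d)
    relation_id: Sequence[int],              # (batch_size,)
    tail_cands: Sequence[Sequence[Vector]],  # (batch_size, n_tails, d), one group per query
    relation_table: Sequence[Tuple[Vector, Vector]],  # hypothetical (r_head, r_tail) per id
    p: int = 1,
) -> List[List[float]]:
    """No-sharing case: returns scores of shape (batch_size, n_tails)."""
    if not len(head_emb) == len(relation_id) == len(tail_cands):
        raise ValueError("without sharing, each query needs its own candidate group")
    scores = []
    for h, r, group in zip(head_emb, relation_id, tail_cands):
        r_head, r_tail = relation_table[r]
        # The (h, r) query is fixed, so its head projection is computed once.
        h_proj = _project(_unit(h), r_head)
        scores.append([_neg_p_dist(h_proj, _project(_unit(t), r_tail), p) for t in group])
    return scores
```

In a two-query batch where the first candidate of each group matches its query after projection, those candidates score 0 and the mismatched ones score negative values, giving a (2, n_tails) score matrix.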