besskge.loss.BaseLossFunction

class besskge.loss.BaseLossFunction(*args, **kwargs)

Base class for a loss function.

Losses are always computed in FP32.

Initializes internal Module state, shared by both nn.Module and ScriptModule.

abstract forward(positive_score, negative_score, triple_weight)

Compute batch loss.

Parameters:
  • positive_score (Tensor) – shape: (batch_size,) Scores of positive triples.

  • negative_score (Tensor) – shape: (batch_size, n_negative) Scores of negative triples.

  • triple_weight (Tensor) – shape: (batch_size,) or () Weights of positive triples.

Return type:

Tensor

Returns:

The batch loss.
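Because forward is abstract, each concrete loss decides how the three inputs are combined. The following is a minimal, framework-free sketch of one possible combination, using plain Python lists in place of tensors. The log-sigmoid form, the `margin` parameter, the uniform negative weights and the names `log_sigmoid` and `forward_sketch` are all assumptions made for illustration; they are not taken from besskge.

```python
import math


def log_sigmoid(x):
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def forward_sketch(positive_score, negative_score, triple_weight, margin=0.0):
    """Hypothetical log-sigmoid loss (not the besskge implementation).

    positive_score: list of floats, shape (batch_size,)
    negative_score: list of lists, shape (batch_size, n_negative)
    triple_weight:  list of floats, shape (batch_size,), or a single float, shape ()
    """
    batch_loss = 0.0
    for i, (pos, negs) in enumerate(zip(positive_score, negative_score)):
        # Accept either per-triple weights or one scalar weight for the whole batch.
        if isinstance(triple_weight, (list, tuple)):
            w = triple_weight[i]
        else:
            w = triple_weight
        # Assumption: uniform weights over the negatives of this triple.
        neg_w = 1.0 / len(negs)
        pos_term = -log_sigmoid(pos + margin)
        neg_term = -sum(neg_w * log_sigmoid(-s - margin) for s in negs)
        batch_loss += w * 0.5 * (pos_term + neg_term)
    return batch_loss


# Two positive triples, three negatives each, equal weights summing to 1.
loss = forward_sketch(
    positive_score=[3.0, 1.5],
    negative_score=[[-2.0, 0.5, -1.0], [0.0, -3.0, 1.0]],
    triple_weight=[0.5, 0.5],
)
```

Passing a scalar `triple_weight` (shape ()) applies the same weight to every positive triple, which matches the two shapes listed for that parameter.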

get_negative_weights(negative_score)

Construct weights of negative samples, based on their score.

Parameters:

negative_score (Tensor) – shape: (batch_size, n_negative) Scores of negative samples.

Return type:

Tensor

Returns:

shape: (batch_size, n_negative) if BaseLossFunction.negative_adversarial_sampling else () Weights of negative samples.
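The two return shapes can be illustrated with a small framework-free sketch, using plain Python lists in place of tensors. It assumes that self-adversarial weighting is a softmax over each row of negative scores multiplied by the reciprocal temperature, and that the non-adversarial case returns a single uniform weight of 1 / n_negative. Both are assumptions made for illustration, and the function names `softmax` and `negative_weights_sketch` are hypothetical.

```python
import math


def softmax(row):
    """Numerically stable softmax over one row of scores."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def negative_weights_sketch(negative_score, adversarial_sampling, adversarial_scale=1.0):
    """Hypothetical weighting of negative samples (not the besskge implementation).

    negative_score: list of lists, shape (batch_size, n_negative)
    """
    if adversarial_sampling:
        # One weight per negative sample: shape (batch_size, n_negative).
        # Higher-scoring (harder) negatives receive larger weights.
        return [
            softmax([adversarial_scale * s for s in row])
            for row in negative_score
        ]
    # Assumption: uniform weighting, returned as a single scalar (shape ()).
    return 1.0 / len(negative_score[0])


scores = [[2.0, 0.0, -1.0], [0.5, 0.5, 0.5]]
weights = negative_weights_sketch(scores, adversarial_sampling=True, adversarial_scale=1.0)
uniform = negative_weights_sketch(scores, adversarial_sampling=False)
```

In this sketch a larger `adversarial_scale` (a lower temperature) concentrates the weight on the highest-scoring negatives, while a scale of zero reduces the softmax to uniform weights.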

loss_scale: Tensor

Loss scaling factor; may be needed when using FP16 weights.

negative_adversarial_sampling: bool

Use self-adversarial weighting of negative samples.

negative_adversarial_scale: Tensor

Reciprocal temperature of self-adversarial weighting.