besskge.loss.MarginBasedLossFunction
- class besskge.loss.MarginBasedLossFunction(margin, negative_adversarial_sampling, negative_adversarial_scale=1.0, loss_scale=1.0)
Base class for margin-based loss functions.
Initialize margin-based loss function.
- Parameters:
margin (float) – The margin to be used in the loss computation.
negative_adversarial_sampling (bool) – see BaseLossFunction
negative_adversarial_scale (float) – see BaseLossFunction
loss_scale (float) – see BaseLossFunction
- abstract forward(positive_score, negative_score, triple_weight)
Compute batch loss.
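Because forward is abstract here, the concrete loss is left to subclasses. For illustration only, the sketch below assumes a margin ranking (hinge) objective, scores where higher means more plausible, a sum over the batch, and a triple_weight that is either a scalar or one weight per positive triple. None of these details are confirmed by this page, and `margin_ranking_loss_sketch` is a hypothetical name. Plain Python lists stand in for tensors so the sketch runs without PyTorch.

```python
from typing import List, Union


def margin_ranking_loss_sketch(
    positive_score: List[float],          # shape: (batch_size,)
    negative_score: List[List[float]],    # shape: (batch_size, n_negative)
    triple_weight: Union[List[float], float],
    negative_weights: Union[List[List[float]], float],
    margin: float,
    loss_scale: float = 1.0,
) -> float:
    """Hypothetical margin-based batch loss (assumes higher score = more plausible)."""
    total = 0.0
    for i, pos in enumerate(positive_score):
        negs = negative_score[i]
        # Scalar negative weight (shape ()) is shared by every negative in the row.
        if isinstance(negative_weights, list):
            row_weights = negative_weights[i]
        else:
            row_weights = [negative_weights] * len(negs)
        # Hinge term: a negative is penalized only if it scores
        # within `margin` of (or above) the positive.
        per_triple = sum(
            w * max(0.0, margin - pos + neg) for w, neg in zip(row_weights, negs)
        )
        # Assumption: triple_weight is a scalar or one weight per positive triple.
        tw = triple_weight[i] if isinstance(triple_weight, list) else triple_weight
        total += tw * per_triple
    # Assumption: the batch loss is a sum, rescaled by loss_scale.
    return loss_scale * total


# One positive scored 5.0, two negatives with uniform weight 0.5, margin 2.0.
print(margin_ranking_loss_sketch([5.0], [[4.0, 1.0]], 1.0, 0.5, margin=2.0))  # → 0.5
```

In the usage line, only the first negative falls inside the margin (5.0 - 4.0 = 1.0 < 2.0), so it contributes 0.5 × (2.0 - 1.0) = 0.5; the second negative is already 4.0 below the positive and contributes nothing.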
- get_negative_weights(negative_score)
Construct weights of negative samples, based on their score.
- Parameters:
negative_score (Tensor) – shape: (batch_size, n_negative). Scores of negative samples.
- Return type:
Tensor
- Returns:
shape: (batch_size, n_negative) if BaseLossFunction.negative_adversarial_sampling is enabled, else (). Weights of negative samples.
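The weights are per-negative, with shape (batch_size, n_negative), when negative_adversarial_sampling is enabled, and a single scalar of shape () otherwise. A common way to derive score-based weights is the self-adversarial negative sampling scheme introduced with RotatE: a softmax over negative_adversarial_scale * negative_score within each row, with the weights treated as constants during backpropagation. The sketch below assumes that scheme for the adversarial branch and assumes uniform weights 1 / n_negative for the scalar branch; neither is confirmed by this page, and `negative_weights_sketch` is a hypothetical stand-in that uses plain lists instead of tensors.

```python
import math
from typing import List, Union


def negative_weights_sketch(
    negative_score: List[List[float]],  # shape: (batch_size, n_negative)
    negative_adversarial_sampling: bool,
    negative_adversarial_scale: float = 1.0,
) -> Union[List[List[float]], float]:
    """Hypothetical analogue of get_negative_weights on plain Python lists."""
    if negative_adversarial_sampling:
        weights = []
        for row in negative_score:
            scaled = [negative_adversarial_scale * s for s in row]
            peak = max(scaled)  # subtract the row max for numerical stability
            exps = [math.exp(s - peak) for s in scaled]
            total = sum(exps)
            # Row-wise softmax: harder (higher-scoring) negatives get more weight.
            weights.append([e / total for e in exps])
        return weights
    # Assumption: without adversarial sampling, negatives are averaged uniformly,
    # so one scalar weight (shape ()) is shared by all of them.
    return 1.0 / len(negative_score[0])


print(negative_weights_sketch([[0.0, math.log(3.0)]], True))     # approximately [[0.25, 0.75]]
print(negative_weights_sketch([[1.0, 2.0, 3.0, 4.0]], False))  # → 0.25
```

Raising negative_adversarial_scale sharpens the softmax, concentrating weight on the highest-scoring negatives; as it approaches 0, the weights approach uniform.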