besskge.loss.BaseLossFunction
- class besskge.loss.BaseLossFunction(*args, **kwargs)
Base class for a loss function.
Losses are always computed in FP32.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
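The FP32 guarantee above can be pictured with a minimal sketch. It assumes the base class is a `torch.nn.Module` whose `forward` upcasts incoming scores before delegating to a subclass hook. The names `SketchBaseLoss`, `LogSigmoidLossSketch` and `compute`, and the exact loss formula, are hypothetical illustrations, not the besskge API.

```python
from abc import ABC, abstractmethod

import torch
import torch.nn.functional as F


class SketchBaseLoss(torch.nn.Module, ABC):
    """Hypothetical stand-in for BaseLossFunction (not the besskge source)."""

    def forward(
        self, positive_score: torch.Tensor, negative_score: torch.Tensor
    ) -> torch.Tensor:
        # Upcast scores that may arrive in lower precision (e.g. FP16) so the
        # loss itself is always evaluated in FP32.
        return self.compute(positive_score.float(), negative_score.float())

    @abstractmethod
    def compute(
        self, positive_score: torch.Tensor, negative_score: torch.Tensor
    ) -> torch.Tensor:
        """Loss on FP32 scores: positive (batch_size,), negative (batch_size, n_negative)."""


class LogSigmoidLossSketch(SketchBaseLoss):
    """Hypothetical subclass: log-sigmoid loss, negatives averaged uniformly."""

    def compute(self, positive_score, negative_score):
        pos_term = F.logsigmoid(positive_score)  # (batch_size,)
        neg_term = F.logsigmoid(-negative_score).mean(dim=-1)  # (batch_size,)
        return -(pos_term + neg_term).mean()  # scalar


loss_fn = LogSigmoidLossSketch()
positive = torch.randn(4).half()  # FP16 scores, as a mixed-precision model might emit
negative = torch.randn(4, 8).half()
loss = loss_fn(positive, negative)
print(loss.dtype)  # torch.float32
```

Keeping the cast in the base class means individual loss subclasses never have to remember it themselves.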
- get_negative_weights(negative_score)
Construct the weights of negative samples based on their scores.
- Parameters:
negative_score (Tensor) – shape: (batch_size, n_negative) Scores of negative samples.
- Return type:
Tensor
- Returns:
shape: (batch_size, n_negative) if BaseLossFunction.negative_adversarial_sampling, else (). Weights of negative samples.
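The two return shapes can be illustrated with a minimal sketch. It assumes that, when `negative_adversarial_sampling` is enabled, weights follow the self-adversarial scheme popularised by RotatE (a softmax over the negatives, detached from the graph), and that otherwise a single 0-dim weight of 1 / n_negative is returned. The function name `get_negative_weights_sketch`, the `adversarial_temperature` parameter and the uniform value are hypothetical, not taken from besskge.

```python
import torch


def get_negative_weights_sketch(
    negative_score: torch.Tensor,
    negative_adversarial_sampling: bool,
    adversarial_temperature: float = 1.0,
) -> torch.Tensor:
    """Hypothetical re-implementation for illustration (not the besskge source).

    negative_score: (batch_size, n_negative) scores of negative samples.
    """
    if negative_adversarial_sampling:
        # Self-adversarial weighting: softmax over the n_negative axis, so
        # higher-scoring (harder) negatives receive larger weights. The result
        # is detached so gradients do not flow through the weights.
        return torch.softmax(adversarial_temperature * negative_score, dim=-1).detach()
    # Otherwise return one 0-dim weight shared by all negatives. The uniform
    # value 1 / n_negative is an assumption made for this sketch.
    return torch.tensor(1.0 / negative_score.shape[-1], dtype=torch.float32)


scores = torch.tensor([[0.0, 0.0], [1.0, 1.0]])
adv = get_negative_weights_sketch(scores, negative_adversarial_sampling=True)
uni = get_negative_weights_sketch(scores, negative_adversarial_sampling=False)
print(adv.shape, uni.shape)  # torch.Size([2, 2]) torch.Size([])
```

Because a 0-dim tensor broadcasts against any (batch_size, n_negative) tensor, a caller can multiply either result by per-negative losses and sum over the last axis without branching on the flag.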