Source code for pytext.utils.loss_utils

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import numpy
import torch
from pytext.utils.cuda_utils import FloatTensor


def range_to_anchors_and_delta(precision_range, num_anchors):
    """Calculates anchor points from precision range.

    Args:
        precision_range: an interval (a, b), where 0.0 <= a <= b <= 1.0
        num_anchors: int, number of equally spaced anchor points. Must be
            at least 1.

    Returns:
        precision_values: A `Tensor` of [num_anchors] equally spaced values
            in the interval precision_range (the lower endpoint is excluded,
            the upper endpoint is included).
        delta: The spacing between the values in precision_values.

    Raises:
        ValueError: If precision_range is invalid or num_anchors < 1.
    """
    # Validate precision_range.
    if len(precision_range) != 2:
        raise ValueError(
            "length of precision_range (%d) must be 2" % len(precision_range)
        )
    if not 0 <= precision_range[0] <= precision_range[1] <= 1:
        raise ValueError(
            "precision values must follow 0 <= %f <= %f <= 1"
            % (precision_range[0], precision_range[1])
        )
    # Validate num_anchors: zero would divide by zero below, and a negative
    # count would silently produce an empty anchor set with a negative delta.
    if num_anchors < 1:
        raise ValueError("num_anchors (%d) must be at least 1" % num_anchors)

    lower, upper = precision_range[0], precision_range[1]
    # Sample num_anchors + 1 points and drop the first one, so the anchors
    # are lower + delta, lower + 2 * delta, ..., upper.
    precision_values = numpy.linspace(start=lower, stop=upper, num=num_anchors + 1)[
        1:
    ]
    delta = (upper - lower) / num_anchors
    return FloatTensor(precision_values), delta
def build_class_priors(
    labels,
    class_priors=None,
    weights=None,
    positive_pseudocount=1.0,
    negative_pseudocount=1.0,
):
    """Build class priors, if necessary.

    For each class, the class priors are estimated as
    (P + sum_i w_i y_i) / (P + N + sum_i w_i),
    where y_i is the ith label, w_i is the ith weight, P is a pseudo-count of
    positive labels, and N is a pseudo-count of negative labels.

    Args:
        labels: A `Tensor` with shape [batch_size, num_classes].
            Entries should be in [0, 1].
        class_priors: None, or a floating point `Tensor` of shape [C]
            containing the prior probability of each class (i.e. the fraction
            of the training data consisting of positive examples). If given,
            it is returned unchanged. If None, the priors are estimated from
            `labels` and `weights` using the formula above.
        weights: None, or a `Tensor` of shape broadcastable to labels,
            [N, 1] or [N, C], where `N = batch_size`, `C = num_classes`.
            None means every example has weight 1.
        positive_pseudocount: Number of positive labels used to initialize
            the class priors.
        negative_pseudocount: Number of negative labels used to initialize
            the class priors.

    Returns:
        class_priors: the `class_priors` argument if it was provided,
            otherwise a Tensor holding the weighted per-class priors.
    """
    if class_priors is not None:
        return class_priors

    # The signature advertises weights=None; treat it as unit weights rather
    # than failing on `None * labels`.
    if weights is None:
        weights = torch.ones_like(labels)

    weighted_label_counts = (weights * labels).sum(0)
    weight_sum = weights.sum(0)
    return torch.div(
        weighted_label_counts + positive_pseudocount,
        weight_sum + positive_pseudocount + negative_pseudocount,
    )
def weighted_hinge_loss(labels, logits, positive_weights=1.0, negative_weights=1.0):
    """Compute an element-wise hinge loss with separate positive/negative weights.

    The loss is
    ``positive_weights * max(0, 1 - logits) * labels
    + negative_weights * max(0, 1 + logits) * (1 - labels)``.

    Args:
        labels: one-hot representation `Tensor` of shape broadcastable to logits
        logits: A `Tensor` of shape [N, C] or [N, C, K]
        positive_weights: Scalar or Tensor
        negative_weights: same shape as positive_weights

    Returns:
        3D Tensor of shape [N, C, K], where K is length of positive weights
        or 2D Tensor of shape [N, C]

    Raises:
        ValueError: If exactly one of the weights is a Tensor, or if both are
            Tensors of different shapes.
    """
    positive_weights_is_tensor = torch.is_tensor(positive_weights)
    negative_weights_is_tensor = torch.is_tensor(negative_weights)

    # Validate positive_weights and negative_weights
    if positive_weights_is_tensor ^ negative_weights_is_tensor:
        raise ValueError(
            "positive_weights and negative_weights must be same shape Tensor "
            "or both be scalars. But positive_weight_is_tensor: %r, while "
            "negative_weight_is_tensor: %r"
            % (positive_weights_is_tensor, negative_weights_is_tensor)
        )

    if positive_weights_is_tensor and (
        positive_weights.size() != negative_weights.size()
    ):
        # The placeholders are str.format style, so they must be filled with
        # .format(); applying `%` to them raises TypeError instead of the
        # intended ValueError.
        raise ValueError(
            "shape of positive_weights and negative_weights "
            "must be the same! "
            "shape of positive_weights is {0}, "
            "but shape of negative_weights is {1}".format(
                positive_weights.size(), negative_weights.size()
            )
        )

    # positive_term / negative_term: Tensor [N, C] or [N, C, K]
    positive_term = (1 - logits).clamp(min=0) * labels
    negative_term = (1 + logits).clamp(min=0) * (1 - labels)

    if positive_weights_is_tensor and positive_term.dim() == 2:
        # Add a trailing axis so the 2D terms broadcast against the weight
        # tensors along a new last dimension.
        return (
            positive_term.unsqueeze(-1) * positive_weights
            + negative_term.unsqueeze(-1) * negative_weights
        )
    return positive_term * positive_weights + negative_term * negative_weights
def true_positives_lower_bound(labels, logits, weights):
    """Lower bound on true positives.

    This is the true_positives_lower_bound defined in the paper
    "Scalable Learning of Non-Decomposable Objectives".

    Args:
        labels: A `Tensor` of shape broadcastable to logits.
        logits: A `Tensor` of shape [N, C] or [N, C, K]. If the third
            dimension is present, the lower bound is computed on each slice
            [:, :, k] independently.
        weights: Per-example loss coefficients, with shape [N, 1] or [N, C]

    Returns:
        A `Tensor` of shape [C] or [C, K].
    """
    # Hinge loss restricted to positive examples (negative weight is zero);
    # a `Tensor` of shape [N, C] or [N, C, K].
    hinge_on_positives = weighted_hinge_loss(labels, logits, negative_weights=0.0)
    per_example_bound = labels - hinge_on_positives

    # Give the weights a trailing axis when the loss has one more dimension.
    if hinge_on_positives.dim() > weights.dim():
        coefficients = weights.unsqueeze(-1)
    else:
        coefficients = weights

    # Reduce over the batch dimension.
    return (coefficients * per_example_bound).sum(0)
def false_postives_upper_bound(labels, logits, weights):
    """Upper bound on false positives.

    This is the false_positives_upper_bound defined in the paper
    "Scalable Learning of Non-Decomposable Objectives".

    Args:
        labels: A `Tensor` of shape broadcastable to logits.
        logits: A `Tensor` of shape [N, C] or [N, C, K]. If the third
            dimension is present, the upper bound is computed on each slice
            [:, :, k] independently.
        weights: Per-example loss coefficients, with shape
            broadcast-compatible with that of `labels`, i.e. [N, 1] or [N, C]

    Returns:
        A `Tensor` of shape [C] or [C, K].
    """
    # Hinge loss restricted to negative examples (positive weight is zero).
    hinge_on_negatives = weighted_hinge_loss(labels, logits, positive_weights=0)

    # Give the weights a trailing axis when the loss has one more dimension.
    if hinge_on_negatives.dim() > weights.dim():
        coefficients = weights.unsqueeze(-1)
    else:
        coefficients = weights

    # Reduce over the batch dimension.
    return (coefficients * hinge_on_negatives).sum(0)
class LagrangeMultiplier(torch.autograd.Function):
    """Autograd function that clamps in forward and flips the gradient sign.

    The forward pass returns the input clamped to be non-negative. The
    backward pass returns the negated incoming gradient and does not apply
    the clamp's mask.
    """

    @staticmethod
    def forward(ctx, input):
        """Return `input` with every entry below zero replaced by zero."""
        ctx.save_for_backward(input)
        return torch.clamp(input, min=0)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the negation of `grad_output`."""
        return -grad_output
def lagrange_multiplier(x):
    """Run `x` through `LagrangeMultiplier` and return the result."""
    multiplier = LagrangeMultiplier.apply(x)
    return multiplier