Files
spoterembedding/training/online_batch_mining.py
Mathias Claassen 81bbf66aab Initial codebase (#1)
* Add project code

* Logger improvements

* Improvements to web demo code

* added create_wlasl_landmarks_dataset.py and xtract_mediapipe_landmarks.py

* Fix rotation augmentation

* fixed error in docstring, and removed unnecessary replace -1 -> 0

* Readme updates

* Share base notebooks

* Add notebooks and unify for different datasets

* requirements update

* fixes

* Make evaluate more deterministic

* Allow training with clearml

* refactor preprocessing and apply linter

* Minor fixes

* Minor notebook tweaks

* Readme updates

* Fix PR comments

* Remove unneeded code

* Add banner to Readme

---------

Co-authored-by: Gabriel Lema <gabriel.lema@xmartlabs.com>
2023-03-03 10:07:54 -03:00

106 lines
4.5 KiB
Python

import torch
import torch.nn as nn
import torch.nn.functional as F
# An arbitrary small value used for numerical stability (avoids division by
# zero and filters out losses that are zero up to rounding).
eps = 1e-8


# Adapted from https://qdrant.tech/articles/triplet-loss/
class BatchAllTripletLoss(nn.Module):
    """Triplet loss computed over every valid triplet in a batch ("batch all").

    For each anchor ``i``, positive ``j`` (same label, different index) and
    negative ``k`` (different label) the per-triplet loss is
    ``relu(d(i, j) - d(i, k) + margin)`` where ``d`` is the Euclidean distance.

    Args:
        device: Kept for backward compatibility and stored as ``self.device``.
            The computation itself runs on the device of the embeddings
            passed to ``forward``.
        margin: Margin value in the triplet loss equation.
        filter_easy_triplets: Default averaging mode. If True, the loss is
            averaged only over triplets with a strictly positive loss;
            otherwise it is averaged over all valid triplets.
    """

    def __init__(self, device, margin=1., filter_easy_triplets=True):
        super().__init__()
        self.margin = margin
        self.device = device
        self.filter_easy_triplets = filter_easy_triplets

    def get_triplet_mask(self, labels: torch.Tensor) -> torch.Tensor:
        """Compute a boolean mask selecting the valid triplets.

        A triplet ``(i, j, k)`` is valid if
        ``labels[i] == labels[j] and labels[i] != labels[k]``
        and ``i``, ``j``, ``k`` are different.

        Args:
            labels: Batch of integer labels. shape: (batch_size,)

        Returns:
            Boolean mask of shape (batch_size, batch_size, batch_size), on the
            same device as ``labels``.
        """
        batch_size = labels.size()[0]
        # shape: (batch_size, batch_size); True where both samples share a label
        same_label = labels.unsqueeze(0) == labels.unsqueeze(1)
        # shape: (batch_size, batch_size); True off the diagonal (i != j)
        different_index = ~torch.eye(batch_size, dtype=torch.bool, device=labels.device)

        # Anchor-positive pairs: same label, different sample.
        # shape: (batch_size, batch_size, 1)
        anchor_positive = (same_label & different_index).unsqueeze(2)
        # Anchor-negative pairs: different label.
        # shape: (batch_size, 1, batch_size)
        anchor_negative = (~same_label).unsqueeze(1)

        # `i != k` and `j != k` need no explicit check: the negative carries a
        # different label than both anchor and positive, so it cannot be the
        # same sample as either of them.
        return anchor_positive & anchor_negative

    def forward(self, embeddings, labels, filter_easy_triplets=None):
        """Compute the batch-all triplet loss.

        Args:
            embeddings: Batch of embeddings, e.g., output of the encoder.
                shape: (batch_size, embedding_dim)
            labels: Batch of integer labels associated with embeddings.
                shape: (batch_size,)
            filter_easy_triplets: Optional per-call override of the averaging
                mode. ``None`` (default) uses the value given to the
                constructor.

        Returns:
            A tuple ``(loss, valid_triplets, used_triplets)``:
            ``loss`` is the scalar loss tensor, ``valid_triplets`` is a 0-dim
            tensor with the number of valid triplets in the batch, and
            ``used_triplets`` is the number of triplets the loss was averaged
            over (an ``int`` when easy triplets are filtered, otherwise the
            same tensor as ``valid_triplets``).
        """
        if filter_easy_triplets is None:
            filter_easy_triplets = self.filter_easy_triplets

        # step 1 - pairwise Euclidean distances
        # shape: (batch_size, batch_size)
        distance_matrix = torch.cdist(embeddings, embeddings, p=2)

        # step 2 - loss values for all n^3 (anchor, positive, negative) index
        # combinations via broadcasting
        # shape: (batch_size, batch_size, 1)
        anchor_positive_dists = distance_matrix.unsqueeze(2)
        # shape: (batch_size, 1, batch_size)
        anchor_negative_dists = distance_matrix.unsqueeze(1)
        # shape: (batch_size, batch_size, batch_size)
        triplet_loss = anchor_positive_dists - anchor_negative_dists + self.margin

        # step 3 - zero out invalid triplets, then clamp easy ones (negative
        # loss) to 0. The mask follows the loss tensor's device so inputs do
        # not have to live on the device given to the constructor.
        mask = self.get_triplet_mask(labels)
        valid_triplets = mask.sum()
        mask = mask.to(triplet_loss.device)
        # masked_fill (rather than multiplying by the mask) keeps non-finite
        # values of invalid triplets from leaking into the sum.
        triplet_loss = F.relu(triplet_loss.masked_fill(~mask, 0.0))

        if not filter_easy_triplets:
            # Average over every valid triplet, easy ones included.
            triplet_loss = triplet_loss.sum() / (valid_triplets + eps)
            return triplet_loss, valid_triplets, valid_triplets

        # step 4 - average over the triplets that still produce a positive loss
        num_positive_losses = (triplet_loss > eps).float().sum()
        # We want to factor in how many triplets were used compared to batch_size
        # (used_triplets * 3 / batch_size). The effect of this should be similar
        # to LR decay but penalizing batches with fewer hard triplets.
        percent_used_factor = min(1.0, num_positive_losses * 3 / labels.size()[0])
        triplet_loss = triplet_loss.sum() / (num_positive_losses + eps) * percent_used_factor
        return triplet_loss, valid_triplets, int(num_positive_losses)