* Add project code * Logger improvements * Improvements to web demo code * added create_wlasl_landmarks_dataset.py and extract_mediapipe_landmarks.py * Fix rotation augmentation * fixed error in docstring, and removed unnecessary replace -1 -> 0 * Readme updates * Share base notebooks * Add notebooks and unify for different datasets * requirements update * fixes * Make evaluate more deterministic * Allow training with clearml * refactor preprocessing and apply linter * Minor fixes * Minor notebook tweaks * Readme updates * Fix PR comments * Remove unneeded code * Add banner to Readme --------- Co-authored-by: Gabriel Lema <gabriel.lema@xmartlabs.com>
106 lines
4.5 KiB
Python
106 lines
4.5 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
|
|
eps = 1e-8  # an arbitrary small value to be used for numerical stability tricks


# Adapted from https://qdrant.tech/articles/triplet-loss/


class BatchAllTripletLoss(nn.Module):
    """Uses all valid triplets in a batch to compute the Triplet loss.

    For every valid triplet (anchor ``i``, positive ``j``, negative ``k``) the
    per-triplet loss is ``relu(d(i, j) - d(i, k) + margin)``, where ``d`` is the
    Euclidean distance between embeddings.

    Args:
        device: Stored as ``self.device``. The triplet mask is placed on the
            same device as the computed losses, so this no longer has to match
            the embeddings' device.
        margin: Margin value in the Triplet Loss equation.
        filter_easy_triplets: Default averaging mode for ``forward`` (a call may
            override it). If True, the loss is averaged only over triplets with
            a positive loss and scaled by a factor that penalizes batches with
            few such triplets. If False, it is averaged over all valid triplets.
    """

    def __init__(self, device, margin=1., filter_easy_triplets=True):
        super().__init__()
        self.margin = margin
        self.device = device
        self.filter_easy_triplets = filter_easy_triplets

    def get_triplet_mask(self, labels):
        """Compute a mask for valid triplets.

        Args:
            labels: Batch of integer labels. shape: (batch_size,)

        Returns:
            Boolean mask tensor (on ``labels.device``) indicating which triplets
            are valid. Shape: (batch_size, batch_size, batch_size)
            A triplet is valid if:
            `labels[i] == labels[j] and labels[i] != labels[k]`
            and `i`, `j`, `k` are different.
        """
        # step 1 - get a mask for distinct indices

        # shape: (batch_size, batch_size)
        indices_equal = torch.eye(labels.size()[0], dtype=torch.bool, device=labels.device)
        indices_not_equal = torch.logical_not(indices_equal)
        # shape: (batch_size, batch_size, 1)
        i_not_equal_j = indices_not_equal.unsqueeze(2)
        # shape: (batch_size, 1, batch_size)
        i_not_equal_k = indices_not_equal.unsqueeze(1)
        # shape: (1, batch_size, batch_size)
        j_not_equal_k = indices_not_equal.unsqueeze(0)
        # Shape: (batch_size, batch_size, batch_size)
        distinct_indices = torch.logical_and(torch.logical_and(i_not_equal_j, i_not_equal_k), j_not_equal_k)

        # step 2 - get a mask for valid anchor-positive-negative triplets

        # shape: (batch_size, batch_size)
        labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
        # shape: (batch_size, batch_size, 1)
        i_equal_j = labels_equal.unsqueeze(2)
        # shape: (batch_size, 1, batch_size)
        i_equal_k = labels_equal.unsqueeze(1)
        # shape: (batch_size, batch_size, batch_size)
        valid_indices = torch.logical_and(i_equal_j, torch.logical_not(i_equal_k))

        # step 3 - combine two masks
        mask = torch.logical_and(distinct_indices, valid_indices)

        return mask

    def forward(self, embeddings, labels, filter_easy_triplets=None):
        """Computes the loss value.

        Args:
            embeddings: Batch of embeddings, e.g., output of the encoder. shape: (batch_size, embedding_dim)
            labels: Batch of integer labels associated with embeddings. shape: (batch_size,)
            filter_easy_triplets: Per-call override of the averaging mode chosen at
                construction time. ``None`` (the default) uses ``self.filter_easy_triplets``.

        Returns:
            Tuple ``(loss, valid_triplets, used_triplets)``:
                loss: Scalar loss tensor.
                valid_triplets: Number of valid triplets in the batch (0-dim tensor).
                used_triplets: Number of triplets the loss was averaged over. When
                    filtering, this is an ``int`` count of positive-loss triplets;
                    otherwise it is the same tensor as ``valid_triplets``.
        """
        # Previously the argument was accepted but ignored; honor it when given.
        if filter_easy_triplets is None:
            filter_easy_triplets = self.filter_easy_triplets

        # step 1 - get distance matrix
        # shape: (batch_size, batch_size)
        distance_matrix = torch.cdist(embeddings, embeddings, p=2)

        # step 2 - compute loss values for all triplets by applying broadcasting to distance matrix
        # shape: (batch_size, batch_size, 1)
        anchor_positive_dists = distance_matrix.unsqueeze(2)
        # shape: (batch_size, 1, batch_size)
        anchor_negative_dists = distance_matrix.unsqueeze(1)
        # get loss values for all possible n^3 triplets
        # shape: (batch_size, batch_size, batch_size)
        triplet_loss = anchor_positive_dists - anchor_negative_dists + self.margin

        # step 3 - filter out invalid or easy triplets by setting their loss values to 0
        # shape: (batch_size, batch_size, batch_size)
        mask = self.get_triplet_mask(labels)
        valid_triplets = mask.sum()
        # The mask lives on labels.device. Move it to the losses' device (not
        # self.device) so embeddings and labels may sit on different devices.
        triplet_loss = triplet_loss * mask.to(device=triplet_loss.device, dtype=triplet_loss.dtype)
        # easy triplets have negative loss values
        triplet_loss = F.relu(triplet_loss)

        if filter_easy_triplets:
            # step 4 - compute scalar loss value by averaging positive losses
            num_positive_losses = (triplet_loss > eps).float().sum()
            # We want to factor in how many triplets were used compared to batch_size (used_triplets * 3 / batch_size)
            # The effect of this should be similar to LR decay but penalizing batches with fewer hard triplets
            batch_size = max(labels.size()[0], 1)  # avoid 0/0 -> NaN on an empty batch
            percent_used_factor = torch.clamp(num_positive_losses * 3 / batch_size, max=1.0)

            triplet_loss = triplet_loss.sum() / (num_positive_losses + eps) * percent_used_factor
            return triplet_loss, valid_triplets, int(num_positive_losses)

        triplet_loss = triplet_loss.sum() / (valid_triplets + eps)
        return triplet_loss, valid_triplets, valid_triplets
|