import logging
from datetime import datetime
from typing import Optional

import numpy as np
import torch

from .batching_scheduler import BatchingScheduler

logger = logging.getLogger("BatchGrouper")


class BatchGrouper:
    """
    Clusters all `total_items` into at most `max_groups` groups based on the distances in
    `sorted_dists`. Each group holds `mini_batch_size` elements, and these elements are simply integers in
    the range 0...total_items.

    Distances between items are expected to be scaled to 0...1 such that distances between two items of the
    same class are higher the closer they are to 1, while distances between items of different classes are
    higher the closer they are to 0.

    The logic picks the highest-valued distance and assigns both items to the same cluster/group when possible,
    which may involve merging two clusters.
    A few thresholds limit the computational cost: if the scaled distance of a pair falls below
    `dist_threshold`, or more than `assign_threshold` percent of the items have already been assigned, we stop
    and distribute the remaining items among the groups that have space left.
    """
    # Counters
    next_group = 0
    items_assigned = 0

    # Thresholds
    dist_threshold = 0.5
    assign_threshold = 0.80

    def __init__(self, sorted_dists, total_items, mini_batch_size=32, dist_threshold=0.5,
                 assign_threshold=0.8) -> None:
        self.sorted_dists = sorted_dists
        self.total_items = total_items
        self.mini_batch_size = mini_batch_size
        self.max_groups = int(total_items / mini_batch_size)
        self.groups = {}
        self.item_to_group = {}
        self.items_assigned = 0
        self.next_group = 0
        self.dist_threshold = dist_threshold
        self.assign_threshold = assign_threshold

    def cluster_items(self):
        """Main entry point of this class. Performs the clustering described in the class docstring.

        :return list: Flat list of item indices, ordered so that each consecutive `mini_batch_size`
            chunk forms one group
        """
        for i in range(self.sorted_dists.shape[-1]):
            a, b, dist = self.sorted_dists[:, i]
            a, b = int(a), int(b)
            # Stop early once pairs become too easy or enough items have been assigned
            if dist < self.dist_threshold or self.items_assigned > self.total_items * self.assign_threshold:
                logger.info(f"Breaking with dist: {dist}, and {self.items_assigned} items assigned")
                break
            if a not in self.item_to_group and b not in self.item_to_group:
                g = self.create_or_get_group()
                self.assign_group(a, g)
                self.assign_group(b, g)
            elif a not in self.item_to_group:
                if not self.group_is_full(self.item_to_group[b]):
                    self.assign_group(a, self.item_to_group[b])
            elif b not in self.item_to_group:
                if not self.group_is_full(self.item_to_group[a]):
                    self.assign_group(b, self.item_to_group[a])
            else:
                grp_a = self.item_to_group[a]
                grp_b = self.item_to_group[b]
                self.merge_groups(grp_a, grp_b)
        self.assign_remaining_items()
        return list(np.concatenate(list(self.groups.values())).flat)

    def assign_group(self, item, group):
        """Assigns `item` to group `group`"""
        self.item_to_group[item] = group
        self.groups[group].append(item)
        self.items_assigned += 1

    def create_or_get_group(self):
        """Creates a new group if the current group count is less than `max_groups`.
        Otherwise returns the first group with enough space left.

        :return int: The group id
        """
        if self.next_group < self.max_groups:
            group = self.next_group
            self.groups[group] = []
            self.next_group += 1
        else:
            for i in range(self.next_group):
                # Callers assign items in pairs, so the group needs room for two more items
                if len(self.groups[i]) <= self.mini_batch_size - 2:
                    group = i
                    break
        return group

    def group_is_full(self, group):
        return len(self.groups[group]) == self.mini_batch_size

    def can_merge_groups(self, grp_a, grp_b):
        return grp_a != grp_b and (len(self.groups[grp_a]) + len(self.groups[grp_b]) < self.mini_batch_size)

    def merge_groups(self, grp_a, grp_b):
        """Merges the two groups together, if possible. Otherwise does nothing."""
        if grp_a > grp_b:
            grp_a, grp_b = grp_b, grp_a
        if self.can_merge_groups(grp_a, grp_b):
            logger.debug(f"MERGE {grp_a} with {grp_b}: {len(self.groups[grp_a])} {len(self.groups[grp_b])}")
            for b in self.groups[grp_b]:
                self.item_to_group[b] = grp_a
            self.groups[grp_a].extend(self.groups[grp_b])
            self.groups[grp_b] = []
            self.replace_group(grp_b)

    def replace_group(self, group):
        """Replaces the (now empty) `group` with the last group, keeping group ids contiguous

        :param int group: Group to replace
        """
        grp_to_change = self.next_group - 1
        if grp_to_change != group:
            for item in self.groups[grp_to_change]:
                self.item_to_group[item] = group
            self.groups[group] = self.groups[grp_to_change]
        del self.groups[grp_to_change]
        self.next_group -= 1

    def assign_remaining_items(self):
        """Assigns the remaining items to groups with space left"""
        grp_pointer = 0
        i = 0
        logger.info(f"Assigning rest of items: {self.items_assigned} of {self.total_items}")
        while i < self.total_items:
            if i not in self.item_to_group:
                if grp_pointer not in self.groups:
                    # This would happen if a group is still empty at this stage
                    assert grp_pointer < self.max_groups
                    new_group = self.create_or_get_group()
                    assert new_group == grp_pointer
                if len(self.groups[grp_pointer]) < self.mini_batch_size:
                    self.assign_group(i, grp_pointer)
                    i += 1
                else:
                    grp_pointer += 1
            else:
                i += 1
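

# Minimal usage sketch of BatchGrouper (illustrative values, not part of the
# training pipeline). Columns of `sorted_dists` are (item_a, item_b, scaled_dist),
# sorted by distance in descending order:
#
#   dists = np.array([[0, 2],
#                     [1, 3],
#                     [0.9, 0.8]])
#   grouper = BatchGrouper(dists, total_items=4, mini_batch_size=2)
#   grouper.cluster_items()  # -> [0, 1, 2, 3]: items 0/1 and 2/3 share a group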


def get_dist_tuple_list(dist_matrix):
    """Flattens the strictly lower triangle of `dist_matrix` into a (3, n_pairs) tensor
    whose columns are (row_index, col_index, distance)."""
    batch_size = dist_matrix.size()[0]
    indices = torch.tril_indices(batch_size, batch_size, offset=-1)
    values = dist_matrix[indices[0], indices[1]].cpu()
    return torch.cat([indices, values.unsqueeze(0)], dim=0)
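
# For example (illustrative): given a 4x4 distance matrix, this returns a
# (3, 6) tensor whose columns are (row, col, dist) for the six entries below
# the diagonal, i.e. the pairs (1,0), (2,0), (2,1), (3,0), (3,1), (3,2).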


def get_scaled_distances(embeddings, labels, device, same_label_factor=1):
    """Returns the distance matrix between all embeddings, scaled to the 0-1 range where 0 is easy and 1 is hard.
    This means that small distances between embeddings of the same class will be close to 0, while small
    distances between embeddings of different classes will be close to 1.

    :param torch.Tensor embeddings: Embeddings of the batch items
    :param torch.Tensor labels: Labels associated with the embeddings
    :param device: Device to run on (cuda or cpu)
    :param int same_label_factor: Multiplies the weight of same-class distances, giving them more or less
        importance compared to distinct-class distances, defaults to 1 (equal weight)
    :return torch.Tensor: Scaled distance matrix
    """
    # Get pairwise distance matrix, shape: (batch_size, batch_size)
    distance_matrix = torch.cdist(embeddings, embeddings, p=2)
    # Split the distances into same-label and distinct-label parts
    labels = labels.to(device)
    labels_equal = (labels.unsqueeze(0) == labels.unsqueeze(1)).squeeze()
    labels_distinct = torch.logical_not(labels_equal)
    pos_dist = distance_matrix * labels_equal
    neg_dist = distance_matrix * labels_distinct

    # Use some scaling to bring both to a range of 0-1
    pos_max = pos_dist.max()
    neg_max = neg_dist.max()
    # Closer to 1 is harder: large same-class distances and small distinct-class distances
    pos_dist = pos_dist / pos_max * same_label_factor
    neg_dist = 1 * labels_distinct - (neg_dist / neg_max)
    return pos_dist + neg_dist
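
# Illustrative sanity check (hypothetical values): with three items of class 0
# and one of class 1, the closest same-class pair scales toward 0 (easy
# positive pair), while for cross-class pairs the mapping is inverted, so the
# closest pair scores highest (hard negative pair):
#
#   emb = torch.tensor([[0.0], [0.1], [3.0], [3.2]])
#   labels = torch.tensor([0, 0, 0, 1])
#   scaled = get_scaled_distances(emb, labels, device="cpu")
#   # same-class pair (0, 1) scores ~0.03; cross-class pair (2, 3) ~0.94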


def sort_batches(inputs, labels, masks, embeddings, device, mini_batch_size=32,
                 scheduler: Optional[BatchingScheduler] = None):
    """Reorders a batch so that consecutive `mini_batch_size` chunks contain the hardest pairs,
    based on the scaled distances between `embeddings`."""
    start = datetime.now()

    same_label_factor = scheduler.get_scaling_same_label_factor() if scheduler else 1
    scaled_dist = get_scaled_distances(embeddings, labels, device, same_label_factor)
    # Get vector of (row, column, dist)
    dist_list = get_dist_tuple_list(scaled_dist)

    dist_list = dist_list.cpu().detach().numpy()
    # Sort distances descending by the last row
    sorted_dists = dist_list[:, dist_list[-1, :].argsort()[::-1]]

    # Loop through the list, assigning both items of each pair to the same group
    dist_threshold = scheduler.get_dist_threshold() if scheduler else 0.5
    grouper = BatchGrouper(sorted_dists, total_items=labels.size()[0], mini_batch_size=mini_batch_size,
                           dist_threshold=dist_threshold)
    indices = torch.tensor(grouper.cluster_items()).type(torch.IntTensor)
    final_inputs = torch.index_select(inputs, dim=0, index=indices)
    final_labels = torch.index_select(labels, dim=0, index=indices)
    final_masks = torch.index_select(masks, dim=0, index=indices)

    logger.info(f"Batch sorting took: {datetime.now() - start}")
    return final_inputs, final_labels, final_masks
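
# Minimal end-to-end sketch (hypothetical shapes; `inputs`, `masks` and the
# embedding size are placeholders, not dictated by this module):
#
#   B, T, D = 64, 16, 128
#   inputs = torch.randn(B, T, D)
#   labels = torch.randint(0, 10, (B,))
#   masks = torch.ones(B, T, dtype=torch.bool)
#   embeddings = torch.randn(B, 256)
#   inputs, labels, masks = sort_batches(inputs, labels, masks, embeddings,
#                                        device="cpu", mini_batch_size=32)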