* Add project code * Logger improvements * Improvements to web demo code * added create_wlasl_landmarks_dataset.py and xtract_mediapipe_landmarks.py * Fix rotation augmentation * fixed error in docstring, and removed unnecessary replace -1 -> 0 * Readme updates * Share base notebooks * Add notebooks and unify for different datasets * requirements update * fixes * Make evaluate more deterministic * Allow training with clearml * refactor preprocessing and apply linter * Minor fixes * Minor notebook tweaks * Readme updates * Fix PR comments * Remove unneeded code * Add banner to Readme --------- Co-authored-by: Gabriel Lema <gabriel.lema@xmartlabs.com>
105 lines
3.7 KiB
Python
105 lines
3.7 KiB
Python
import unittest
|
|
# from traceback_with_variables import activate_by_import #noqa
|
|
import torch
|
|
from training.batch_sorter import BatchGrouper, sort_batches, get_scaled_distances, get_dist_tuple_list
|
|
|
|
|
|
class TestBatchSorting(unittest.TestCase):
    """Tests for ``BatchGrouper`` group assignment and merging.

    Every test runs against a fresh grouper built in ``setUp`` from a
    distance array computed on random embeddings and labels.
    """

    # Mini batch size handed to the grouper.
    MINI_BATCH_SIZE = 32
    # Number of items used to build the distance fixture and the grouper.
    TOTAL_ITEMS = 32 * 8

    def get_sorted_dists(self):
        """Build a distance array from random data, ordered by its last row.

        Returns:
            The NumPy array produced from ``get_dist_tuple_list`` with its
            columns reordered so that the values of the last row are in
            descending order.
        """
        device = get_device()
        embeddings = torch.rand(self.TOTAL_ITEMS, 8).to(device)
        labels = torch.rand(self.TOTAL_ITEMS, 1)
        scaled_dist = get_scaled_distances(embeddings, labels, device)
        # Get vector of (row, column, dist)
        dist_list = get_dist_tuple_list(scaled_dist)

        dist_array = dist_list.cpu().detach().numpy()
        # argsort is ascending; reversing it yields descending order of the
        # last row, which is then used to permute the columns.
        descending_order = dist_array[-1, :].argsort()[::-1]
        return dist_array[:, descending_order]

    def setUp(self) -> None:
        """Create a new ``BatchGrouper`` before each test."""
        dists = self.get_sorted_dists()
        self.grouper = BatchGrouper(
            sorted_dists=dists,
            total_items=self.TOTAL_ITEMS,
            mini_batch_size=self.MINI_BATCH_SIZE,
        )
        super().setUp()

    def test_assigns_and_merges(self):
        """Merging two small groups moves every item into the first group."""
        group0 = self.grouper.create_or_get_group()
        self.grouper.assign_group(1, group0)
        self.grouper.assign_group(2, group0)

        group1 = self.grouper.create_or_get_group()
        self.grouper.assign_group(3, group1)
        self.grouper.assign_group(4, group1)

        # Merge groups
        self.grouper.merge_groups(group0, group1)
        self.assertEqual(len(self.grouper.groups[group0]), 4)
        self.assertNotIn(group1, self.grouper.groups)
        self.assertEqual(self.grouper.item_to_group[3], group0)
        self.assertEqual(self.grouper.item_to_group[4], group0)

    def test_full_groups(self):
        """A merge is expected to be skipped when 30 + 3 items are combined.

        NOTE(review): 33 items exceed the mini batch size of 32, which is
        presumably why no merge is expected - confirm against BatchGrouper.
        """
        group0 = self.grouper.create_or_get_group()
        for item in range(30):
            self.grouper.assign_group(item, group0)

        self.assertFalse(self.grouper.group_is_full(group0))
        initial_group_len = len(self.grouper.groups[group0])

        group1 = self.grouper.create_or_get_group()
        for item in range(30, 33):
            self.grouper.assign_group(item, group1)

        self.grouper.merge_groups(group0, group1)
        # Assert no merge done
        self.assertEqual(len(self.grouper.groups[group0]), initial_group_len)
        self.assertIn(group1, self.grouper.groups)
        self.assertEqual(self.grouper.item_to_group[31], group1)
        self.assertEqual(self.grouper.item_to_group[32], group1)

    def test_replace_groups(self):
        """After a merge, the id of the merged-away group is reused.

        Groups of 20, 3 and 7 items are created; the 3-item group is merged
        with the 20-item one. The assertions expect the 7 items of the last
        group to be stored under the second group's id afterwards.
        NOTE(review): presumably the grouper relocates its last group into
        the freed slot - confirm against BatchGrouper.merge_groups.
        """
        group0 = self.grouper.create_or_get_group()
        for item in range(20):
            self.grouper.assign_group(item, group0)

        group1 = self.grouper.create_or_get_group()
        for item in range(20, 23):
            self.grouper.assign_group(item, group1)

        group2 = self.grouper.create_or_get_group()
        for item in range(23, 30):
            self.grouper.assign_group(item, group2)

        self.grouper.merge_groups(group1, group0)
        self.assertEqual(len(self.grouper.groups[group0]), 23)
        self.assertIn(group1, self.grouper.groups)
        self.assertNotIn(group2, self.grouper.groups)
        self.assertEqual(len(self.grouper.groups[group1]), 7)
|
|
|
|
|
|
def get_device():
    """Return the torch device used by the tests in this module (CPU)."""
    return torch.device("cpu")
|
|
|
|
|
|
def test_get_scaled_distances():
    """Every value returned by ``get_scaled_distances`` must lie in [0, 1]."""
    target_device = get_device()
    embeddings = torch.rand(4, 3)
    # Four samples, the last two sharing the same label.
    class_ids = torch.tensor([0, 1, 2, 2])

    scaled = get_scaled_distances(embeddings, class_ids, target_device)

    assert (scaled >= 0).all()
    assert (scaled <= 1).all()
|
|
|
|
|
|
def test_batch_sorter_indices():
    """Check that ``sort_batches`` keeps inputs, labels and masks aligned.

    The first entry of the sorted inputs is looked up among the original
    input rows; the label and mask stored at that original position must
    equal the first entries of the sorted labels and masks.
    """
    device = get_device()
    num_items = 32 * 16
    inputs = torch.rand(num_items, 1000)
    labels = torch.rand(num_items, 1)
    masks = torch.rand(num_items, 100)
    embeddings = torch.rand(num_items, 32).to(device)

    i_out, l_out, m_out = sort_batches(inputs, labels, masks, embeddings, device)

    # Indices of the original rows that are element-wise equal to the first
    # sorted input.
    matching_rows = torch.all(inputs == i_out[0], dim=1).nonzero(as_tuple=True)[0]
    # Fail with a clear message instead of an opaque IndexError when the
    # first sorted input does not correspond to any original row.
    assert matching_rows.numel() > 0, (
        "first sorted input does not match any original input row"
    )
    first_match_index = matching_rows[0]

    assert torch.all(labels[first_match_index] == l_out[0])
    assert torch.all(masks[first_match_index] == m_out[0])
|