"""Training and evaluation loops for the classification and embedding models."""
import numpy as np
import torch
from sklearn.manifold import TSNE
from sklearn.metrics import silhouette_score

from training.batch_sorter import sort_batches
from utils import get_logger


def train_epoch(model, dataloader, criterion, optimizer, device, scheduler=None):
    """Run one supervised classification epoch; returns loss and accuracy statistics."""
    pred_correct, pred_all = 0, 0
    running_loss = 0.0
    model.train(True)
    for i, data in enumerate(dataloader):
        inputs, labels = data
        inputs = inputs.squeeze(0).to(device)
        labels = labels.to(device, dtype=torch.long)

        optimizer.zero_grad()
        outputs = model(inputs).expand(1, -1, -1)
        loss = criterion(outputs[0], labels[0])
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        # Statistics: softmax is monotonic, so this picks the highest-scoring class.
        if int(torch.argmax(torch.nn.functional.softmax(outputs, dim=2))) == int(labels[0][0]):
            pred_correct += 1
        pred_all += 1

    epoch_loss = running_loss / len(dataloader)
    model.train(False)
    if scheduler:
        scheduler.step(epoch_loss)
    return epoch_loss, pred_correct, pred_all, (pred_correct / pred_all)


def train_epoch_embedding(model, epoch_iters, train_loader, val_loader, criterion, optimizer, device,
                          scheduler=None):
    """Run one triplet-loss epoch over pre-built (anchor, positive, negative) triplets."""
    running_loss = []
    model.train(True)
    for i, (anchor, positive, negative, a_mask, p_mask, n_mask) in enumerate(train_loader):
        optimizer.zero_grad()
        anchor_emb = model(anchor.to(device), a_mask.to(device))
        positive_emb = model(positive.to(device), p_mask.to(device))
        negative_emb = model(negative.to(device), n_mask.to(device))
        # The embeddings are already on `device`, so no extra transfer is needed here.
        loss = criterion(anchor_emb, positive_emb, negative_emb)
        loss.backward()
        optimizer.step()
        running_loss.append(loss.item())
        if i == epoch_iters:
            break
    epoch_loss = np.mean(running_loss)

    # Validation: track embedding quality via the silhouette coefficient.
    model.train(False)
    val_silhouette_coef = evaluate_embedding(model, val_loader, device)
    if scheduler:
        scheduler.step(val_silhouette_coef)
    return epoch_loss, val_silhouette_coef


def train_epoch_embedding_online(model, epoch_iters, train_loader, val_loader, criterion, optimizer, device,
                                 scheduler=None, enable_batch_sorting=False, mini_batch_size=None,
                                 pre_batch_mining_count=1, batching_scheduler=None):
    """Run one epoch with online triplet mining: `criterion` mines triplets inside each mini-batch
    and returns (loss, valid_triplets, used_triplets)."""
    running_loss = []
    iter_used_triplets = []
    iter_valid_triplets = []
    iter_pct_used = []
    model.train(True)
    mini_batch = mini_batch_size or train_loader.batch_size
    for i, (inputs, labels, masks) in enumerate(train_loader):
        labels_size = labels.size()[0]
        batch_loop_count = labels_size // mini_batch
        if batch_loop_count == 0:
            continue

        if enable_batch_sorting:
            # Trim a trailing partial mini-batch so batch sorting only sees full mini-batches.
            if labels_size < train_loader.batch_size:
                trim_count = labels_size % mini_batch
                if trim_count > 0:
                    inputs = inputs[:-trim_count]
                    labels = labels[:-trim_count]
                    masks = masks[:-trim_count]
            # Embed the whole batch without gradients, then reorder it into harder mini-batches.
            embeddings = None
            with torch.no_grad():
                for j in range(batch_loop_count):
                    batch_embed = compute_batched_embeddings(model, device, inputs, masks, mini_batch, j)
                    embeddings = batch_embed if embeddings is None else torch.cat([embeddings, batch_embed], dim=0)
            inputs, labels, masks = sort_batches(inputs, labels, masks, embeddings, device,
                                                 mini_batch_size=mini_batch_size, scheduler=batching_scheduler)
            del embeddings
            del batch_embed
            mining_loop_count = pre_batch_mining_count
        else:
            mining_loop_count = 1

        for k in range(mining_loop_count):
            for j in range(batch_loop_count):
                optimizer.zero_grad(set_to_none=True)
                batch_labels = labels[mini_batch * j:mini_batch * (j + 1)]
                if batch_labels.size()[0] == 0:
                    break
                embeddings = compute_batched_embeddings(model, device, inputs, masks, mini_batch, j)
                loss, valid_triplets, used_triplets = criterion(embeddings, batch_labels)
                loss.backward()
                optimizer.step()
                running_loss.append(loss.item())
                if valid_triplets > 0:
                    iter_used_triplets.append(used_triplets)
                    iter_valid_triplets.append(valid_triplets)
                    iter_pct_used.append((used_triplets * 100) / valid_triplets)
        if epoch_iters > 0 and i * batch_loop_count * pre_batch_mining_count >= epoch_iters:
            print("Breaking out because of epoch_iters filter")
            break

    epoch_loss = np.mean(running_loss)
    mean_used_triplets = np.mean(iter_used_triplets)
    triplets_stats = {
        "valid_triplets": np.mean(iter_valid_triplets),
        "used_triplets": mean_used_triplets,
        "pct_used": np.mean(iter_pct_used),
    }
    if batching_scheduler:
        batching_scheduler.step(mean_used_triplets)

    # Validation: track embedding quality via the silhouette coefficient.
    model.train(False)
    with torch.no_grad():
        val_silhouette_coef = evaluate_embedding(model, val_loader, device)
    if scheduler:
        scheduler.step(val_silhouette_coef)
    return epoch_loss, val_silhouette_coef, triplets_stats


def compute_batched_embeddings(model, device, inputs, masks, mini_batch, iteration):
    """Embed the `iteration`-th mini-batch slice of `inputs`/`masks`."""
    batch_inputs = inputs[mini_batch * iteration:mini_batch * (iteration + 1)]
    batch_masks = masks[mini_batch * iteration:mini_batch * (iteration + 1)]
    return model(batch_inputs.to(device), batch_masks.to(device)).squeeze(1)


def evaluate(model, dataloader, device, print_stats=False):
    """Evaluate top-1 classification accuracy; optionally log per-class accuracies."""
    logger = get_logger(__name__)
    pred_correct, pred_all = 0, 0
    stats = {i: [0, 0] for i in range(251)}  # per-class [correct, total]; 251 = number of classes
    with torch.no_grad():
        for i, data in enumerate(dataloader):
            inputs, labels = data
            inputs = inputs.squeeze(0).to(device)
            labels = labels.to(device, dtype=torch.long)
            outputs = model(inputs).expand(1, -1, -1)
            # Statistics: softmax is monotonic, so this picks the highest-scoring class.
            if int(torch.argmax(torch.nn.functional.softmax(outputs, dim=2))) == int(labels[0][0]):
                stats[int(labels[0][0])][0] += 1
                pred_correct += 1
            stats[int(labels[0][0])][1] += 1
            pred_all += 1
    if print_stats:
        stats = {key: value[0] / value[1] for key, value in stats.items() if value[1] != 0}
        print("Label accuracies statistics:")
        print(str(stats) + "\n")
        logger.info("Label accuracies statistics:")
        logger.info(str(stats) + "\n")
    return pred_correct, pred_all, (pred_correct / pred_all)


def evaluate_embedding(model, dataloader, device):
    """Compute the silhouette coefficient of the embeddings produced over `dataloader`."""
    val_embeddings = []
    labels_emb = []
    with torch.no_grad():
        for i, (inputs, labels, masks) in enumerate(dataloader):
            inputs = inputs.to(device)
            masks = masks.to(device)
            outputs = model(inputs, masks)
            for n in range(outputs.shape[0]):
                val_embeddings.append(outputs[n, 0].cpu().detach().numpy())
                labels_emb.append(labels.detach().numpy()[n])
    silhouette_coefficient = silhouette_score(
        X=np.array(val_embeddings),
        labels=np.array(labels_emb).reshape(len(labels_emb))
    )
    return silhouette_coefficient


def embeddings_scatter_plot(model, dataloader, device, id_to_label, perplexity=40, n_iter=1000):
    """Project embeddings to 2D with t-SNE for plotting; returns points and string labels."""
    val_embeddings = []
    labels_emb = []
    with torch.no_grad():
        for i, (inputs, labels, masks) in enumerate(dataloader):
            inputs = inputs.to(device)
            masks = masks.to(device)
            outputs = model(inputs, masks)
            for n in range(outputs.shape[0]):
                val_embeddings.append(outputs[n, 0].cpu().detach().numpy())
                labels_emb.append(id_to_label[int(labels.detach().numpy()[n])])
    tsne = TSNE(n_components=2, verbose=0, perplexity=perplexity, n_iter=n_iter)
    tsne_results = tsne.fit_transform(np.array(val_embeddings))
    return tsne_results, labels_emb


def embeddings_scatter_plot_splits(model, dataloaders, device, id_to_label, perplexity=40, n_iter=1000):
    """Run a single t-SNE over all splits so their 2D projections share one embedding space."""
    labels_split = {}
    embeddings_split = {}
    splits = list(dataloaders.keys())
    with torch.no_grad():
        for split, dataloader in dataloaders.items():
            labels_str = []
            embeddings = []
            for i, (inputs, labels, masks) in enumerate(dataloader):
                inputs = inputs.to(device)
                masks = masks.to(device)
                outputs = model(inputs, masks)
                for n in range(outputs.shape[0]):
                    embeddings.append(outputs[n, 0].cpu().detach().numpy())
                    labels_str.append(id_to_label[int(labels.detach().numpy()[n])])
            labels_split[split] = labels_str
            embeddings_split[split] = embeddings
    tsne = TSNE(n_components=2, verbose=0, perplexity=perplexity, n_iter=n_iter)
    all_embeddings = np.vstack([embeddings_split[split] for split in splits])
    tsne_results = tsne.fit_transform(all_embeddings)
    # Slice the joint t-SNE result back into per-split chunks, in split order.
    tsne_results_dict = {}
    curr_index = 0
    for split in splits:
        len_embeddings = len(embeddings_split[split])
        tsne_results_dict[split] = tsne_results[curr_index:curr_index + len_embeddings]
        curr_index += len_embeddings
    return tsne_results_dict, labels_split


def evaluate_top_k(model, dataloader, device, k=5):
    """Evaluate top-k accuracy: a prediction counts if the label is among the k highest scores."""
    pred_correct, pred_all = 0, 0
    with torch.no_grad():
        for i, data in enumerate(dataloader):
            inputs, labels = data
            inputs = inputs.squeeze(0).to(device)
            labels = labels.to(device, dtype=torch.long)
            outputs = model(inputs).expand(1, -1, -1)
            if int(labels[0][0]) in torch.topk(outputs, k).indices.tolist()[0][0]:
                pred_correct += 1
            pred_all += 1
    return pred_correct, pred_all, (pred_correct / pred_all)
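

if __name__ == "__main__":
    # Minimal usage sketch (smoke test), not part of the training pipeline: it wires
    # train_epoch_embedding_online to toy stand-ins so the interfaces it expects are
    # visible. `ToyEncoder` and `batch_all_triplet_loss` below are illustrative
    # placeholders, not the project's real model or criterion; the only contracts
    # assumed are model(inputs, masks) -> (batch, 1, dim) embeddings and
    # criterion(embeddings, labels) -> (loss, valid_triplets, used_triplets).
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset

    class ToyEncoder(nn.Module):
        """Stand-in encoder: mask-weighted mean pool over time, then a linear projection."""

        def __init__(self, in_dim=32, emb_dim=8):
            super().__init__()
            self.proj = nn.Linear(in_dim, emb_dim)

        def forward(self, inputs, masks):
            # Keep a singleton dim so the pipeline's `.squeeze(1)` / `outputs[n, 0]` work.
            pooled = (inputs * masks.unsqueeze(-1)).sum(1) / masks.sum(1, keepdim=True)
            return self.proj(pooled).unsqueeze(1)

    def batch_all_triplet_loss(embeddings, labels, margin=0.2):
        """Stand-in batch-all mining criterion matching the (loss, valid, used) contract."""
        dists = torch.cdist(embeddings, embeddings)
        same = labels.unsqueeze(0) == labels.unsqueeze(1)
        losses = []
        valid = used = 0
        # Enumerate all (anchor, positive, negative) triples within the mini-batch.
        for a in range(len(labels)):
            for p in range(len(labels)):
                if a == p or not same[a, p]:
                    continue
                for n in range(len(labels)):
                    if same[a, n]:
                        continue
                    valid += 1
                    tl = torch.relu(dists[a, p] - dists[a, n] + margin)
                    if tl > 0:
                        used += 1
                        losses.append(tl)
        loss = torch.stack(losses).mean() if losses else dists.sum() * 0.0
        return loss, valid, used

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Dummy (inputs, labels, masks) data shaped like the real loaders' batches.
    dummy_inputs = torch.randn(64, 16, 32)
    dummy_labels = torch.randint(0, 5, (64,))
    dummy_masks = torch.ones(64, 16)
    loader = DataLoader(TensorDataset(dummy_inputs, dummy_labels, dummy_masks), batch_size=32)

    model = ToyEncoder().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    epoch_loss, val_silhouette, triplet_stats = train_epoch_embedding_online(
        model, epoch_iters=0, train_loader=loader, val_loader=loader,
        criterion=batch_all_triplet_loss, optimizer=optimizer, device=device,
        mini_batch_size=16)
    print(epoch_loss, val_silhouette, triplet_stats)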