* Add project code * Logger improvements * Improvements to web demo code * added create_wlasl_landmarks_dataset.py and xtract_mediapipe_landmarks.py * Fix rotation augmentation * fixed error in docstring, and removed unnecessary replace -1 -> 0 * Readme updates * Share base notebooks * Add notebooks and unify for different datasets * requirements update * fixes * Make evaluate more deterministic * Allow training with clearml * refactor preprocessing and apply linter * Minor fixes * Minor notebook tweaks * Readme updates * Fix PR comments * Remove unneeded code * Add banner to Readme --------- Co-authored-by: Gabriel Lema <gabriel.lema@xmartlabs.com>
72 lines
2.6 KiB
Python
72 lines
2.6 KiB
Python
import os
|
|
import random
|
|
import numpy as np
|
|
import pandas as pd
|
|
import plotly.express as px
|
|
import torch
|
|
|
|
from models import embeddings_scatter_plot, embeddings_scatter_plot_splits
|
|
|
|
|
|
def train_setup(seed, experiment_name):
    """Seed the random number generators used by training and return a seeded generator.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch's CPU and
    CUDA RNGs with ``seed``, and switches cuDNN to deterministic mode with
    autotuning disabled so repeated runs select the same kernels.

    Args:
        seed: Integer seed applied to every generator.
        experiment_name: Accepted for interface compatibility; not used by
            this function.

    Returns:
        A ``torch.Generator`` seeded with ``seed``.
        NOTE(review): presumably meant to be passed to a ``DataLoader`` so
        shuffling is reproducible -- confirm against the callers.
    """
    random.seed(seed)
    np.random.seed(seed)

    # Hash randomization is fixed when the interpreter starts, so this
    # assignment cannot change the current process; it is only inherited by
    # child processes started afterwards.
    os.environ["PYTHONHASHSEED"] = str(seed)

    torch.manual_seed(seed)
    # The CUDA seeding calls are silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = True
    # With benchmarking enabled cuDNN may pick a different convolution
    # algorithm on each run, which defeats the deterministic flag above.
    torch.backends.cudnn.benchmark = False

    generator = torch.Generator()
    generator.manual_seed(seed)
    return generator
|
|
|
|
|
|
def create_embedding_scatter_plots(tracker, model, train_loader, val_loader, device, id_to_label, epoch, model_name):
    """Log three 2-D embedding scatter plots to ``tracker``.

    Charts are logged in this order, all under the series "Scatter Plot" and
    with ``iteration=epoch``:

    1. the training loader on its own,
    2. the validation loader on its own,
    3. both loaders projected together, with the split drawn as the marker
       symbol.

    The 2-D coordinates and labels come from ``embeddings_scatter_plot`` and
    ``embeddings_scatter_plot_splits``, which are called with
    ``perplexity=40`` and ``n_iter=1000``.
    NOTE(review): presumably a t-SNE projection of the model embeddings --
    confirm in ``models``.

    Args:
        tracker: Object exposing ``log_chart(title, series, iteration, figure)``.
        model: Model forwarded to the embedding helpers.
        train_loader: Loader for the training split.
        val_loader: Loader for the validation split.
        device: Device forwarded to the embedding helpers.
        id_to_label: Mapping forwarded to the embedding helpers.
        epoch: Value reported as the chart iteration.
        model_name: String appended to every chart title.
    """
    chart_series = "Scatter Plot"

    # One chart per split, each with its own independent projection.
    single_split_charts = (
        ("Training Scatter Plot with Best Model: ", train_loader),
        ("Validation Scatter Plot with Best Model: ", val_loader),
    )
    for title_prefix, loader in single_split_charts:
        points, point_labels = embeddings_scatter_plot(
            model, loader, device, id_to_label, perplexity=40, n_iter=1000
        )
        frame = pd.DataFrame({
            'x': points[:, 0],
            'y': points[:, 1],
            'label': point_labels,
        })
        chart = px.scatter(frame, y="y", x="x", color="label")
        tracker.log_chart(
            title=title_prefix + model_name,
            series=chart_series,
            iteration=epoch,
            figure=chart,
        )

    # Joint chart: both splits come from a single helper call and are then
    # stacked in dictionary order (train first, then val).
    loaders_by_split = {'train': train_loader, 'val': val_loader}
    split_names = list(loaders_by_split.keys())
    points_by_split, labels_by_split = embeddings_scatter_plot_splits(
        model, loaders_by_split, device, id_to_label, perplexity=40, n_iter=1000
    )
    stacked_points = np.vstack([points_by_split[name] for name in split_names])
    stacked_labels = np.concatenate([labels_by_split[name] for name in split_names])
    # Repeat each split name once per sample so the column lines up with the
    # stacked rows.
    split_column = np.concatenate(
        [[name] * len(labels_by_split[name]) for name in split_names]
    )

    joint_frame = pd.DataFrame({
        'x': stacked_points[:, 0],
        'y': stacked_points[:, 1],
        'label': stacked_labels,
        'split': split_column,
    })
    joint_chart = px.scatter(joint_frame, y="y", x="x", color="label", symbol='split')
    tracker.log_chart(
        title="Scatter Plot of train and val with Best Model: " + model_name,
        series=chart_series,
        iteration=epoch,
        figure=joint_chart,
    )
|