Files
spoterembedding/training/train_utils.py
Mathias Claassen 81bbf66aab Initial codebase (#1)
* Add project code

* Logger improvements

* Improvements to web demo code

* added create_wlasl_landmarks_dataset.py and xtract_mediapipe_landmarks.py

* Fix rotation augmentation

* fixed error in docstring, and removed unnecessary replace -1 -> 0

* Readme updates

* Share base notebooks

* Add notebooks and unify for different datasets

* requirements update

* fixes

* Make evaluate more deterministic

* Allow training with clearml

* refactor preprocessing and apply linter

* Minor fixes

* Minor notebook tweaks

* Readme updates

* Fix PR comments

* Remove unneeded code

* Add banner to Readme

---------

Co-authored-by: Gabriel Lema <gabriel.lema@xmartlabs.com>
2023-03-03 10:07:54 -03:00

72 lines
2.6 KiB
Python

import os
import random
import numpy as np
import pandas as pd
import plotly.express as px
import torch
from models import embeddings_scatter_plot, embeddings_scatter_plot_splits
def train_setup(seed, experiment_name):
    """Seed every RNG used in training and return a seeded ``torch.Generator``.

    Args:
        seed: Integer seed applied to Python's ``random``, NumPy, and torch
            (CPU and all CUDA devices).
        experiment_name: Currently unused; kept so existing callers keep working.

    Returns:
        A ``torch.Generator`` seeded with ``seed``. It can be passed to a
        ``DataLoader`` (``generator=``) so shuffling is reproducible.
    """
    random.seed(seed)
    np.random.seed(seed)
    # PYTHONHASHSEED is only read at interpreter start-up, so this affects
    # interpreters launched after this call (e.g. spawned subprocesses), not
    # string hashing in the current process.
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    # manual_seed_all covers every visible GPU (a no-op without CUDA), so a
    # separate torch.cuda.manual_seed call for the current device is redundant.
    torch.cuda.manual_seed_all(seed)
    # Deterministic cuDNN kernels alone are not enough for reproducibility:
    # benchmark mode autotunes kernel selection per run and must be disabled.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    generator = torch.Generator()
    generator.manual_seed(seed)
    return generator
def _log_tsne_scatter(tracker, title, tsne_results, labels, epoch, split=None):
    """Build a 2-D scatter of projected embeddings and log it via ``tracker.log_chart``.

    ``tsne_results`` is indexed as ``[:, 0]`` / ``[:, 1]``, so it must be a 2-D
    array with at least two columns. When ``split`` is given, it is added as a
    ``'split'`` column and used as the marker symbol.
    """
    data = {'x': tsne_results[:, 0],
            'y': tsne_results[:, 1],
            'label': labels}
    scatter_kwargs = {}
    if split is not None:
        data['split'] = split
        scatter_kwargs['symbol'] = 'split'
    df = pd.DataFrame(data)
    fig = px.scatter(df, y="y", x="x", color="label", **scatter_kwargs)
    tracker.log_chart(
        title=title,
        series="Scatter Plot",
        iteration=epoch,
        figure=fig,
    )


def create_embedding_scatter_plots(tracker, model, train_loader, val_loader, device, id_to_label, epoch, model_name,
                                   perplexity=40, n_iter=1000):
    """Log three embedding scatter plots for ``model`` to ``tracker``.

    The plots are logged in this order:

    1. The training-set projection.
    2. The validation-set projection.
    3. A joint projection of train and val, with points marked by split.

    Projections come from the project helpers ``embeddings_scatter_plot`` and
    ``embeddings_scatter_plot_splits``. Given the ``tsne_*`` names and the
    ``perplexity``/``n_iter`` arguments, these are presumably t-SNE embeddings;
    confirm against ``models``.

    Args:
        tracker: Object exposing ``log_chart(title=, series=, iteration=, figure=)``.
        model: Model whose embeddings are projected.
        train_loader: Data loader for the training split.
        val_loader: Data loader for the validation split.
        device: Device forwarded to the projection helpers.
        id_to_label: Mapping forwarded to the projection helpers, presumably
            from class id to display label.
        epoch: Logged as the chart iteration.
        model_name: String appended to every chart title.
        perplexity: Forwarded to the projection helpers (default 40, as before).
        n_iter: Forwarded to the projection helpers (default 1000, as before).
    """
    projection_kwargs = {'perplexity': perplexity, 'n_iter': n_iter}

    # Separate per-split projections.
    for title_prefix, loader in (("Training", train_loader), ("Validation", val_loader)):
        tsne_results, labels = embeddings_scatter_plot(model, loader, device, id_to_label, **projection_kwargs)
        _log_tsne_scatter(tracker, title_prefix + " Scatter Plot with Best Model: " + model_name,
                          tsne_results, labels, epoch)

    # Joint projection: both splits embedded together so they share one space.
    dataloaders = {'train': train_loader,
                   'val': val_loader}
    splits = list(dataloaders.keys())
    tsne_results_splits, labels_splits = embeddings_scatter_plot_splits(model, dataloaders,
                                                                        device, id_to_label, **projection_kwargs)
    tsne_results = np.vstack([tsne_results_splits[name] for name in splits])
    labels = np.concatenate([labels_splits[name] for name in splits])
    # One split name per point, in the same order as the stacked results.
    split_per_point = np.concatenate([[name] * len(labels_splits[name]) for name in splits])
    _log_tsne_scatter(tracker, "Scatter Plot of train and val with Best Model: " + model_name,
                      tsne_results, labels, epoch, split=split_per_point)