Initial codebase (#1)
* Add project code
* Logger improvements
* Improvements to web demo code
* Added create_wlasl_landmarks_dataset.py and extract_mediapipe_landmarks.py
* Fix rotation augmentation
* Fixed error in docstring, and removed unnecessary replace of -1 -> 0
* Readme updates
* Share base notebooks
* Add notebooks and unify for different datasets
* Requirements update
* Fixes
* Make evaluate more deterministic
* Allow training with ClearML
* Refactor preprocessing and apply linter
* Minor fixes
* Minor notebook tweaks
* Readme updates
* Fix PR comments
* Remove unneeded code
* Add banner to Readme
---------
Co-authored-by: Gabriel Lema <gabriel.lema@xmartlabs.com>
6
.flake8
Normal file
@@ -0,0 +1,6 @@
[flake8]
max-line-length = 130
per-file-ignores =
    __init__.py: F401
exclude =
    .git,__pycache__,
13
Dockerfile
Normal file
@@ -0,0 +1,13 @@
FROM pytorch/pytorch

WORKDIR /app
COPY ./requirements.txt /app/

RUN pip install -r requirements.txt
RUN apt-get -y update
RUN apt-get -y install git
RUN apt-get -y install ffmpeg libsm6 libxext6

COPY . /app/
RUN git config --global --add safe.directory /app
CMD ./train.sh
137
README.md
Normal file
@@ -0,0 +1,137 @@
<img src="assets/banner.png" width=100%/>

# SPOTER Embeddings

This repository contains code for the Spoter embedding model.
<!-- explained in this [blog post](link...). -->
The model is heavily based on [Spoter], which was presented in
[Sign Pose-Based Transformer for Word-Level Sign Language Recognition](https://openaccess.thecvf.com/content/WACV2022W/HADCV/html/Bohacek_Sign_Pose-Based_Transformer_for_Word-Level_Sign_Language_Recognition_WACVW_2022_paper.html), with one of the main modifications being
that this is an embedding model instead of a classification model.
This allows for several zero-shot tasks on unseen Sign Language datasets from around the world.
<!-- More details about this are shown in the blog post mentioned above. -->

## Modifications on [SPOTER](https://github.com/matyasbohacek/spoter)
Here is a list of the main modifications made to the Spoter code and model architecture:

* The output layer is still a linear layer, but it is trained using triplet loss instead of CrossEntropyLoss. The output of the model
is therefore an embedding vector that can be used for several downstream tasks.
* We started with the keypoints dataset published by Spoter but later created new datasets using BlazePose from MediaPipe (as is done in [Spoter 2](https://arxiv.org/abs/2210.00893)). This improves results considerably.
* We build batches so that they contain several hard triplets, and then compute the loss on all hard triplets found in each batch.
* Some code refactoring to accommodate the new classes we implemented.
* A minor code fix to avoid exceptions when using the rotation augmentation.

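The hard-triplet batching described above can be sketched in a few lines of NumPy. This is a minimal illustration, not the repository's actual loss code; the Euclidean distance and the margin value of 0.2 are assumptions:

```python
import numpy as np

def hard_triplets(embeddings: np.ndarray, labels: np.ndarray, margin: float = 0.2):
    """For each anchor, pick the hardest positive (farthest same-class sample)
    and the hardest negative (closest different-class sample), keeping only
    triplets that violate the triplet-loss margin."""
    # Pairwise Euclidean distance matrix, shape (N, N)
    dists = np.linalg.norm(embeddings[:, None] - embeddings[None, :], axis=-1)
    triplets = []
    for a in range(len(labels)):
        same = (labels == labels[a]) & (np.arange(len(labels)) != a)
        diff = labels != labels[a]
        if not same.any() or not diff.any():
            continue
        p = np.where(same)[0][np.argmax(dists[a][same])]  # hardest positive
        n = np.where(diff)[0][np.argmin(dists[a][diff])]  # hardest negative
        if dists[a, p] - dists[a, n] + margin > 0:        # margin is violated
            triplets.append((a, p, n))
    return triplets
```

Only the triplets returned here would contribute to the loss; batches where all classes are already well separated yield no triplets at all.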
<!-- Include GIFs for Spoter and Spoter embeddings. This could be linked from the blog post -->


## Results

![Embeddings scatter plot](assets/scatter_plot.png)

We used the silhouette score to measure how well the embedding clusters are defined during training.
The silhouette score is high (close to 1) when the clusters of different classes are well separated from each other, and low (close to -1) in the opposite case.
Our best model reached 0.7 on the train set and 0.1 on the validation set.

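Such a score can be computed with scikit-learn (already pinned in this repo's requirements); a toy sketch with illustrative data, not the repository's evaluation code:

```python
import numpy as np
from sklearn.metrics import silhouette_score

rng = np.random.default_rng(0)
# Toy "embeddings": two tight, well-separated clusters in 16 dimensions
embeddings = np.concatenate([
    rng.normal(loc=0.0, scale=0.05, size=(50, 16)),
    rng.normal(loc=1.0, scale=0.05, size=(50, 16)),
])
labels = np.array([0] * 50 + [1] * 50)

# Well-separated clusters score close to 1
score = silhouette_score(embeddings, labels)
print(f"silhouette: {score:.2f}")
```

In training, the same call is made on the model's embedding vectors with the sign class as the label.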
### Classification accuracy
While the model was not trained with classification specifically in mind, it can still be used for that purpose.
Here we show top-1 and top-5 classification accuracy, calculated by taking the 1 (or 5) nearest vectors of different classes to the target vector.

To estimate the accuracy for LSA, we take a "train" set as given and then classify the holdout set based on the closest vectors from the "train" set.
This is done using the model trained on the WLASL100 dataset only, to show that our model has zero-shot capabilities.

![Classification accuracy](assets/accuracy.png)
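This nearest-neighbour evaluation can be sketched as follows; an illustration of the metric, not the repository's evaluation code, assuming Euclidean distance between embeddings:

```python
import numpy as np

def topk_accuracy(train_emb, train_labels, test_emb, test_labels, k=5):
    """A test sample counts as correct if its class is among the k classes
    whose closest "train" embedding is nearest to the test embedding."""
    classes = np.unique(train_labels)
    hits = []
    for x, y in zip(test_emb, test_labels):
        # Distance from x to the closest train sample of each class
        class_dists = [np.min(np.linalg.norm(train_emb[train_labels == c] - x, axis=1))
                       for c in classes]
        top_classes = classes[np.argsort(class_dists)[:k]]
        hits.append(y in top_classes)
    return float(np.mean(hits))
```

With `k=1` this is plain nearest-neighbour classification; `k=5` gives the top-5 numbers reported above.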
<!-- Also link the product blog here -->


## Get Started

The recommended way of running code from this repo is with **Docker**.

Clone this repository and run:

```shell
docker build -t spoter_embeddings .
docker run --rm -it --entrypoint=bash --gpus=all -v $PWD:/app spoter_embeddings
```

> Running without specifying the `entrypoint` will train the model with the hyperparameters specified in `train.sh`.

If you prefer running in a **virtual environment** instead, first install the dependencies:

```shell
pip install -r requirements.txt
```

> We tested this using Python 3.7.13. Other versions may work.

To train the model, run `train.sh` in Docker or in your virtual environment.

The hyperparameters, with their descriptions, can be found in the [train.py](link...) file.


## Data

As with SPOTER, this model works on top of sequences of signers' skeletal data extracted from videos.
This means that the input data has a much lower dimension than raw video, and the model is therefore
quicker and lighter, while you can choose any SOTA body pose model to preprocess the video.
This makes our model lightweight and able to run in real time (for example, it takes around 40 ms to process a 4-second
25 FPS video inside a web browser using onnxruntime).

For ready-to-use datasets, refer to the [Spoter] repository.

For best results, we recommend building your own dataset by downloading a Sign Language video dataset such as [WLASL] and then using the `extract_mediapipe_landmarks.py` and `create_wlasl_landmarks_dataset.py` scripts to create a body keypoints dataset that can be used to train the Spoter embeddings model.

You can run these scripts as follows:

```bash
# This will extract landmarks from the downloaded videos
python3 preprocessing.py extract -videos <path_to_video_folder> --output-landmarks <path_to_landmarks_folder>

# This will create a dataset (csv file) with the first 100 classes, splitting 20% of it into the test set and 80% into the train set
python3 preprocessing.py create -videos <path_to_video_folder> -lmks <path_to_landmarks_folder> --dataset-folder=<output_folder> --create-new-split -ts=0.2
```

## Example notebooks
There are two Jupyter notebooks included in the `notebooks` folder:
* `embeddings_evaluation.ipynb`: shows how to evaluate a model
* `visualize_embeddings.ipynb`: visualizes model embeddings, optionally with the embedded input video

## Tracking experiments with ClearML
The code supports tracking experiments, datasets, and models on a ClearML server.
If you want to do this, make sure to pass the following arguments to train.py:

```
--dataset_loader=clearml
--tracker=clearml
```

Also make sure your `clearml.conf` file is configured correctly.
If using Docker, you can map it into the container by adding these volumes when running `docker run`:

```
-v $HOME/clearml.conf:/root/clearml.conf -v $HOME/.clearml:/root/.clearml
```

## Model conversion

Follow these steps to convert your model to ONNX, TensorFlow, or TFLite:
* Install the additional dependencies listed in `conversion_requirements.txt`. This is best done inside the Docker container.
* Run `python convert.py -c <PATH_TO_PYTORCH_CHECKPOINT>`. Add `-tf` if you want to export TensorFlow and TFLite models too.
* The output models are written to a folder named `converted_models`.

> You can test your model's performance in a web browser. Check out the README in the [web](/web/) folder.

## License

The **code** is published under the [Apache License 2.0](./LICENSE), which allows for both academic and commercial use provided
the relevant license and copyright notice is included, our work is cited, and all changes are stated.

The license for the [WLASL](https://arxiv.org/pdf/1910.11006.pdf) and [LSA64](https://core.ac.uk/download/pdf/76495887.pdf) datasets used for the experiments is, however, the [Attribution-NonCommercial 4.0 International (CC BY-NC 4.0)](https://creativecommons.org/licenses/by-nc/4.0/) license, which allows only non-commercial usage.


[Spoter]: https://github.com/matyasbohacek/spoter
[WLASL]: https://dxli94.github.io/WLASL/
0
__init__.py
Normal file
BIN
assets/accuracy.png
Normal file
Binary file not shown.
After: 10 KiB
BIN
assets/banner.png
Normal file
Binary file not shown.
After: 40 KiB
BIN
assets/scatter_plot.png
Normal file
Binary file not shown.
After: 123 KiB
1
augmentations/__init__.py
Normal file
@@ -0,0 +1 @@
from .augment import augment_arm_joint_rotate, augment_rotate, augment_shear
228
augmentations/augment.py
Normal file
@@ -0,0 +1,228 @@
import math
import logging
import cv2
import random

import numpy as np

from normalization.body_normalization import BODY_IDENTIFIERS
from normalization.hand_normalization import HAND_IDENTIFIERS


HAND_IDENTIFIERS = [id + "_0" for id in HAND_IDENTIFIERS] + [id + "_1" for id in HAND_IDENTIFIERS]
ARM_IDENTIFIERS_ORDER = ["neck", "$side$Shoulder", "$side$Elbow", "$side$Wrist"]


def __random_pass(prob):
    return random.random() < prob

def __numpy_to_dictionary(data_array: np.ndarray) -> dict:
    """
    Supplementary method converting a NumPy array of body landmark data into dictionaries. The array data must match
    the order of the BODY_IDENTIFIERS list.
    """

    output = {}

    for landmark_index, identifier in enumerate(BODY_IDENTIFIERS):
        output[identifier] = data_array[:, landmark_index].tolist()

    return output


def __dictionary_to_numpy(landmarks_dict: dict) -> np.ndarray:
    """
    Supplementary method converting dictionaries of body landmark data into respective NumPy arrays. The resulting
    array will match the order of the BODY_IDENTIFIERS list.
    """

    output = np.empty(shape=(len(landmarks_dict["leftEar"]), len(BODY_IDENTIFIERS), 2))

    for landmark_index, identifier in enumerate(BODY_IDENTIFIERS):
        output[:, landmark_index, 0] = np.array(landmarks_dict[identifier])[:, 0]
        output[:, landmark_index, 1] = np.array(landmarks_dict[identifier])[:, 1]

    return output

def __rotate(origin: tuple, point: tuple, angle: float):
    """
    Rotates a point counterclockwise by a given angle around a given origin.

    :param origin: Landmark in the (X, Y) format of the origin from which to count the angle of rotation
    :param point: Landmark in the (X, Y) format to be rotated
    :param angle: Angle by which the point shall be rotated
    :return: New landmark (coordinates)
    """

    ox, oy = origin
    px, py = point

    qx = ox + math.cos(angle) * (px - ox) - math.sin(angle) * (py - oy)
    qy = oy + math.sin(angle) * (px - ox) + math.cos(angle) * (py - oy)

    return qx, qy

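This is the standard 2D rotation about an arbitrary origin (translate to the origin, rotate, translate back). A quick standalone sanity check of the same formula, outside the module:

```python
import math

def rotate(origin, point, angle):
    # Same formula as __rotate above
    ox, oy = origin
    px, py = point
    qx = ox + math.cos(angle) * (px - ox) - math.sin(angle) * (py - oy)
    qy = oy + math.sin(angle) * (px - ox) + math.cos(angle) * (py - oy)
    return qx, qy

# Rotating (1.0, 0.5) by 90 degrees counterclockwise around the frame center (0.5, 0.5)
qx, qy = rotate((0.5, 0.5), (1.0, 0.5), math.radians(90))
print(round(qx, 6), round(qy, 6))  # 0.5 1.0
```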
def __preprocess_row_sign(sign: dict) -> (dict, dict):
    """
    Supplementary method splitting the single-dictionary skeletal data into two dictionaries of body and hand
    landmarks respectively.
    """

    sign_eval = sign

    if "nose_X" in sign_eval:
        body_landmarks = {identifier: [(x, y) for x, y in zip(sign_eval[identifier + "_X"], sign_eval[identifier + "_Y"])]
                          for identifier in BODY_IDENTIFIERS}
        hand_landmarks = {identifier: [(x, y) for x, y in zip(sign_eval[identifier + "_X"], sign_eval[identifier + "_Y"])]
                          for identifier in HAND_IDENTIFIERS}

    else:
        body_landmarks = {identifier: sign_eval[identifier] for identifier in BODY_IDENTIFIERS}
        hand_landmarks = {identifier: sign_eval[identifier] for identifier in HAND_IDENTIFIERS}

    return body_landmarks, hand_landmarks


def __wrap_sign_into_row(body_identifiers: dict, hand_identifiers: dict) -> dict:
    """
    Supplementary method for merging body and hand data into a single dictionary.
    """

    return {**body_identifiers, **hand_identifiers}

def augment_rotate(sign: dict, angle_range: tuple) -> dict:
    """
    AUGMENTATION TECHNIQUE. All the joint coordinates in each frame are rotated by a random angle up to 13 degrees
    with the center of rotation lying in the center of the frame, which is equal to [0.5; 0.5].

    :param sign: Dictionary with sequential skeletal data of the signing person
    :param angle_range: Tuple containing the angle range (minimal and maximal angle in degrees) from which the angle
                        by which the landmarks will be rotated is randomly chosen

    :return: Dictionary with augmented (by rotation) sequential skeletal data of the signing person
    """

    body_landmarks, hand_landmarks = __preprocess_row_sign(sign)
    angle = math.radians(random.uniform(*angle_range))

    body_landmarks = {key: [__rotate((0.5, 0.5), frame, angle) for frame in value] for key, value in
                      body_landmarks.items()}
    hand_landmarks = {key: [__rotate((0.5, 0.5), frame, angle) for frame in value] for key, value in
                      hand_landmarks.items()}

    return __wrap_sign_into_row(body_landmarks, hand_landmarks)

def augment_shear(sign: dict, type: str, squeeze_ratio: tuple) -> dict:
    """
    AUGMENTATION TECHNIQUE.

    - Squeeze. All the frames are squeezed from both horizontal sides. Two different random proportions up to 15% of
      the original frame's width for both the left and right side are cut.

    - Perspective transformation. The joint coordinates are projected onto a new plane with a spatially defined
      center of projection, which simulates recording the sign video with a slight tilt. Each time, the right or left
      side, as well as the proportion by which both the width and height will be reduced, are chosen randomly. This
      proportion is drawn uniformly from the provided squeeze_ratio range. Subsequently, the new plane is delineated
      by reducing the width at the desired side and the respective vertical edge (height) at both of its adjacent
      corners.

    :param sign: Dictionary with sequential skeletal data of the signing person
    :param type: Type of shear augmentation to perform (either 'squeeze' or 'perspective')
    :param squeeze_ratio: Tuple containing the relative range from which the proportion of the original width will be
                          randomly chosen. These proportions will either be cut from both sides or used to construct
                          the new projection

    :return: Dictionary with augmented (by squeezing or perspective transformation) sequential skeletal data of the
             signing person
    """

    body_landmarks, hand_landmarks = __preprocess_row_sign(sign)

    if type == "squeeze":
        move_left = random.uniform(*squeeze_ratio)
        move_right = random.uniform(*squeeze_ratio)

        src = np.array(((0, 1), (1, 1), (0, 0), (1, 0)), dtype=np.float32)
        dest = np.array(((0 + move_left, 1), (1 - move_right, 1), (0 + move_left, 0), (1 - move_right, 0)),
                        dtype=np.float32)
        mtx = cv2.getPerspectiveTransform(src, dest)

    elif type == "perspective":
        move_ratio = random.uniform(*squeeze_ratio)
        src = np.array(((0, 1), (1, 1), (0, 0), (1, 0)), dtype=np.float32)

        if __random_pass(0.5):
            dest = np.array(((0 + move_ratio, 1 - move_ratio), (1, 1), (0 + move_ratio, 0 + move_ratio), (1, 0)),
                            dtype=np.float32)
        else:
            dest = np.array(((0, 1), (1 - move_ratio, 1 - move_ratio), (0, 0), (1 - move_ratio, 0 + move_ratio)),
                            dtype=np.float32)

        mtx = cv2.getPerspectiveTransform(src, dest)

    else:
        logging.error("Unsupported shear type provided.")
        return {}

    landmarks_array = __dictionary_to_numpy(body_landmarks)
    augmented_landmarks = cv2.perspectiveTransform(np.array(landmarks_array, dtype=np.float32), mtx)

    augmented_zero_landmark = cv2.perspectiveTransform(np.array([[[0, 0]]], dtype=np.float32), mtx)[0][0]
    augmented_landmarks = np.stack([np.where(sub == augmented_zero_landmark, [0, 0], sub)
                                    for sub in augmented_landmarks])

    body_landmarks = __numpy_to_dictionary(augmented_landmarks)

    return __wrap_sign_into_row(body_landmarks, hand_landmarks)

def augment_arm_joint_rotate(sign: dict, probability: float, angle_range: tuple) -> dict:
    """
    AUGMENTATION TECHNIQUE. The joint coordinates of both arms are passed successively, and each subsequent landmark
    is slightly rotated with respect to the current one. The chance of each joint being rotated is given by the
    probability parameter, and the angle of alternation is a uniform random angle from the given range (e.g. up to
    +-4 degrees). This simulates slight, negligible variances in each execution of a sign, which do not change its
    semantic meaning.

    :param sign: Dictionary with sequential skeletal data of the signing person
    :param probability: Probability of each joint being rotated (float from the range [0, 1])
    :param angle_range: Tuple containing the angle range (minimal and maximal angle in degrees) from which the angle
                        by which the landmarks will be rotated is randomly chosen

    :return: Dictionary with augmented (by arm joint rotation) sequential skeletal data of the signing person
    """

    body_landmarks, hand_landmarks = __preprocess_row_sign(sign)

    # Iterate over both directions (both hands)
    for side in ["left", "right"]:
        # Iterate gradually over the landmarks on the arm
        for landmark_index, landmark_origin in enumerate(ARM_IDENTIFIERS_ORDER):
            landmark_origin = landmark_origin.replace("$side$", side)

            # End the process on the current hand if the landmark is not present
            if landmark_origin not in body_landmarks:
                break

            # Perform rotation with the provided probability
            if __random_pass(probability):
                angle = math.radians(random.uniform(*angle_range))

                for to_be_rotated in ARM_IDENTIFIERS_ORDER[landmark_index + 1:]:
                    to_be_rotated = to_be_rotated.replace("$side$", side)

                    # Skip if the landmark is not present
                    if to_be_rotated not in body_landmarks:
                        continue

                    body_landmarks[to_be_rotated] = [__rotate(body_landmarks[landmark_origin][frame_index], frame,
                                                              angle)
                                                     for frame_index, frame in enumerate(body_landmarks[to_be_rotated])]

    return __wrap_sign_into_row(body_landmarks, hand_landmarks)
21
conversion_requirements.txt
Normal file
@@ -0,0 +1,21 @@
bokeh==2.4.3
boto3>=1.9
clearml==1.6.4
ipywidgets==8.0.4
matplotlib==3.5.3
mediapipe==0.8.11
notebook==6.5.2
opencv-python==4.6.0.66
pandas==1.1.5
plotly==5.11.0
scikit-learn==1.0.2
torchvision==0.13.0
tqdm==4.54.1
# ------
requests==2.28.1
onnx==1.12.0
onnx-tf==1.10.0
onnxruntime==1.12.1
tensorflow
tensorflow-probability
123
convert.py
Normal file
@@ -0,0 +1,123 @@
import os
import argparse
import json

import numpy as np
import torch
import onnx
import onnxruntime

try:
    import tensorflow as tf
except ImportError:
    print("Warning: TensorFlow not installed. It is required when exporting to tflite")

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()


def print_final_message(model_path):
    success_msg = f"\033[92mModel converted at {model_path} \033[0m"
    try:
        import requests
        joke = json.loads(requests.request("GET", "https://api.chucknorris.io/jokes/random?category=dev").text)["value"]
        print(f"{success_msg}\n\nNow go read a Chuck Norris joke:\n\033[1m{joke}\033[0m")
    except ImportError:
        print(success_msg)

def convert_tf_saved_model(onnx_model, output_folder):
    from onnx_tf.backend import prepare
    tf_rep = prepare(onnx_model)  # prepare the TF representation
    tf_rep.export_graph(output_folder)  # export the model


def convert_tf_to_lite(model_dir, output_path):
    # Convert the model
    converter = tf.lite.TFLiteConverter.from_saved_model(model_dir)  # path to the SavedModel directory
    # This is needed for TF Select ops: Cast, RealDiv
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops
        tf.lite.OpsSet.SELECT_TF_OPS,  # enable TensorFlow ops
    ]

    tflite_model = converter.convert()

    # Save the model
    with open(output_path, 'wb') as f:
        f.write(tflite_model)

def validate_tflite_output(model_path, input_data, output_array):
    interpreter = tf.lite.Interpreter(model_path=model_path)

    output = interpreter.get_output_details()[0]  # Model has a single output.
    input = interpreter.get_input_details()[0]  # Model has a single input.
    interpreter.resize_tensor_input(input['index'], input_data.shape)
    interpreter.allocate_tensors()

    input_data = tf.convert_to_tensor(input_data, np.float32)
    interpreter.set_tensor(input['index'], input_data)
    interpreter.invoke()

    out = interpreter.get_tensor(output['index'])
    np.testing.assert_allclose(out, output_array, rtol=1e-03, atol=1e-05)

def convert(checkpoint_path, export_tensorflow):
    output_folder = "converted_models"
    model = torch.load(checkpoint_path, map_location='cpu')
    model.eval()

    # Input to the model
    x = torch.randn(1, 10, 54, 2, requires_grad=True)
    numpy_x = to_numpy(x)

    torch_out = model(x)
    numpy_out = to_numpy(torch_out)

    model_path = f"{output_folder}/spoter.onnx"
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)

    # Export the model
    torch.onnx.export(model,                     # model being run
                      x,                         # model input (or a tuple for multiple inputs)
                      model_path,                # where to save the model (can be a file or file-like object)
                      export_params=True,        # store the trained parameter weights inside the model file
                      opset_version=11,          # the ONNX version to export the model to
                      do_constant_folding=True,  # whether to execute constant folding for optimization
                      input_names=['input'],     # the model's input names
                      output_names=['output'],   # the model's output names
                      dynamic_axes={'input': [1]})  # mark the sequence-length axis as dynamic

    # Validate conversion
    onnx_model = onnx.load(model_path)
    onnx.checker.check_model(onnx_model)

    ort_session = onnxruntime.InferenceSession(model_path)

    # compute ONNX Runtime output prediction
    ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
    ort_outs = ort_session.run(None, ort_inputs)

    # compare ONNX Runtime and PyTorch results
    np.testing.assert_allclose(numpy_out, ort_outs[0], rtol=1e-03, atol=1e-05)

    if export_tensorflow:
        saved_model_dir = f"{output_folder}/tf_saved"
        tflite_model_path = f"{output_folder}/spoter.tflite"
        convert_tf_saved_model(onnx_model, saved_model_dir)
        convert_tf_to_lite(saved_model_dir, tflite_model_path)
        validate_tflite_output(tflite_model_path, numpy_x, numpy_out)

    print_final_message(output_folder)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--checkpoint_path', help='Checkpoint Path')
    parser.add_argument('-tf', '--export_tensorflow', help='Export TensorFlow apart from ONNX', action='store_true')
    args = parser.parse_args()
    convert(args.checkpoint_path, args.export_tensorflow)
3
datasets/__init__.py
Normal file
@@ -0,0 +1,3 @@
from .czech_slr_dataset import CzechSLRDataset
from .embedding_dataset import SLREmbeddingDataset
from .datasets_utils import collate_fn_triplet_padd, collate_fn_padd
8
datasets/clearml_dataset_loader.py
Normal file
@@ -0,0 +1,8 @@
from clearml import Dataset
from .dataset_loader import DatasetLoader


class ClearMLDatasetLoader(DatasetLoader):

    def get_dataset_folder(self, dataset_project, dataset_name):
        return Dataset.get(dataset_project=dataset_project, dataset_name=dataset_name).get_local_copy()
72
datasets/czech_slr_dataset.py
Normal file
@@ -0,0 +1,72 @@
import torch
import numpy as np
import torch.utils.data as torch_data

from datasets.datasets_utils import load_dataset, tensor_to_dictionary, dictionary_to_tensor, \
    random_augmentation
from normalization.body_normalization import normalize_single_dict as normalize_single_body_dict
from normalization.hand_normalization import normalize_single_dict as normalize_single_hand_dict


class CzechSLRDataset(torch_data.Dataset):
    """Object representation of the landmarks dataset for loading hand and body joint landmarks, utilizing the
    Torch's built-in Dataset properties"""

    data: [np.ndarray]
    labels: [np.ndarray]

    def __init__(self, dataset_filename: str, num_labels=5, transform=None, augmentations=False,
                 augmentations_prob=0.5, normalize=True):
        """
        Initiates the dataset with the data pre-loaded from the csv file.

        :param dataset_filename: Path to the dataset csv file
        :param transform: Any data transformation to be applied (default: None)
        """

        loaded_data = load_dataset(dataset_filename)
        data, labels = loaded_data[0], loaded_data[1]

        self.data = data
        self.labels = labels
        self.targets = list(labels)
        self.num_labels = num_labels
        self.transform = transform

        self.augmentations = augmentations
        self.augmentations_prob = augmentations_prob
        self.normalize = normalize

    def __getitem__(self, idx):
        """
        Allocates, potentially transforms and returns the item at the desired index.

        :param idx: Index of the item
        :return: Tuple containing both the depth map and the label
        """

        depth_map = torch.from_numpy(np.copy(self.data[idx]))
        label = torch.Tensor([self.labels[idx]])

        depth_map = tensor_to_dictionary(depth_map)

        # Apply potential augmentations
        depth_map = random_augmentation(self.augmentations, self.augmentations_prob, depth_map)

        if self.normalize:
            depth_map = normalize_single_body_dict(depth_map)
            depth_map = normalize_single_hand_dict(depth_map)

        depth_map = dictionary_to_tensor(depth_map)

        # Move the landmark position interval to improve performance
        depth_map = depth_map - 0.5

        if self.transform:
            depth_map = self.transform(depth_map)

        return depth_map, label

    def __len__(self):
        return len(self.labels)
17
datasets/dataset_loader.py
Normal file
@@ -0,0 +1,17 @@
import os


class DatasetLoader():
    """Abstract class that serves to load datasets from different sources (local, ClearML, another tracker)
    """

    def get_dataset_folder(self, dataset_project, dataset_name):
        raise NotImplementedError()


class LocalDatasetLoader(DatasetLoader):

    def get_dataset_folder(self, dataset_project, dataset_name):
        base_folder = os.environ.get("BASE_DATA_FOLDER", "data")
        return os.path.join(base_folder, dataset_name)
133
datasets/datasets_utils.py
Normal file
@@ -0,0 +1,133 @@
import pandas as pd
import ast
import torch
import random
import numpy as np
from torch.nn.utils.rnn import pad_sequence
from random import randrange

from augmentations import augment_arm_joint_rotate, augment_rotate, augment_shear
from normalization.body_normalization import BODY_IDENTIFIERS
from augmentations.augment import HAND_IDENTIFIERS

def load_dataset(file_location: str):

    # Load the dataset csv file
    df = pd.read_csv(file_location, encoding="utf-8")
    df.columns = [item.replace("_left_", "_0_").replace("_right_", "_1_") for item in list(df.columns)]

    labels = df["labels"].to_list()

    data = []

    for row_index, row in df.iterrows():
        current_row = np.empty(shape=(len(ast.literal_eval(row["leftEar_X"])),
                                      len(BODY_IDENTIFIERS + HAND_IDENTIFIERS),
                                      2)
                               )
        for index, identifier in enumerate(BODY_IDENTIFIERS + HAND_IDENTIFIERS):
            current_row[:, index, 0] = ast.literal_eval(row[identifier + "_X"])
            current_row[:, index, 1] = ast.literal_eval(row[identifier + "_Y"])

        data.append(current_row)

    return data, labels

def tensor_to_dictionary(landmarks_tensor: torch.Tensor) -> dict:

    data_array = landmarks_tensor.numpy()
    output = {}

    for landmark_index, identifier in enumerate(BODY_IDENTIFIERS + HAND_IDENTIFIERS):
        output[identifier] = data_array[:, landmark_index]

    return output


def dictionary_to_tensor(landmarks_dict: dict) -> torch.Tensor:

    output = np.empty(shape=(len(landmarks_dict["leftEar"]), len(BODY_IDENTIFIERS + HAND_IDENTIFIERS), 2))

    for landmark_index, identifier in enumerate(BODY_IDENTIFIERS + HAND_IDENTIFIERS):
        output[:, landmark_index, 0] = [frame[0] for frame in landmarks_dict[identifier]]
        output[:, landmark_index, 1] = [frame[1] for frame in landmarks_dict[identifier]]

    return torch.from_numpy(output)

def random_augmentation(augmentations, augmentations_prob, depth_map):
    if augmentations and random.random() < augmentations_prob:
        selected_aug = randrange(4)
        if selected_aug == 0:
            depth_map = augment_arm_joint_rotate(depth_map, 0.3, (-4, 4))
        elif selected_aug == 1:
            depth_map = augment_shear(depth_map, "perspective", (0, 0.1))
        elif selected_aug == 2:
            depth_map = augment_shear(depth_map, "squeeze", (0, 0.15))
        elif selected_aug == 3:
            depth_map = augment_rotate(depth_map, (-13, 13))

    return depth_map

def collate_fn_triplet_padd(batch):
|
||||
'''
|
||||
Padds batch of variable length
|
||||
|
||||
note: it converts things ToTensor manually here since the ToTensor transform
|
||||
assume it takes in images rather than arbitrary tensors.
|
||||
'''
|
||||
# batch: list of length batch_size, each element contains ouput of dataset
|
||||
# MASKING
|
||||
anchor_lengths = [element[0].shape[0] for element in batch]
|
||||
max_anchor_l = max(anchor_lengths)
|
||||
positive_lengths = [element[1].shape[0] for element in batch]
|
||||
max_positive_l = max(positive_lengths)
|
||||
negative_lengths = [element[2].shape[0] for element in batch]
|
||||
max_negative_l = max(negative_lengths)
|
||||
|
||||
anchor_mask = [[False] * anchor_lengths[n] + [True] * (max_anchor_l - anchor_lengths[n])
|
||||
for n in range(len(batch))]
|
||||
positive_mask = [[False] * positive_lengths[n] + [True] * (max_positive_l - positive_lengths[n])
|
||||
for n in range(len(batch))]
|
||||
negative_mask = [[False] * negative_lengths[n] + [True] * (max_negative_l - negative_lengths[n])
|
||||
for n in range(len(batch))]
|
||||
|
||||
# PADDING
|
||||
anchor_batch = [element[0] for element in batch]
|
||||
positive_batch = [element[1] for element in batch]
|
||||
negative_batch = [element[2] for element in batch]
|
||||
|
||||
anchor_batch = pad_sequence(anchor_batch, batch_first=True)
|
||||
positive_batch = pad_sequence(positive_batch, batch_first=True)
|
||||
negative_batch = pad_sequence(negative_batch, batch_first=True)
|
||||
|
||||
return anchor_batch, positive_batch, negative_batch, \
|
||||
torch.Tensor(anchor_mask), torch.Tensor(positive_mask), torch.Tensor(negative_mask)
|
||||
|
||||
|
||||
def collate_fn_padd(batch):
|
||||
'''
|
||||
Padds batch of variable length
|
||||
|
||||
note: it converts things ToTensor manually here since the ToTensor transform
|
||||
assume it takes in images rather than arbitrary tensors.
|
||||
'''
|
||||
# batch: list of length batch_size, each element contains ouput of dataset
|
||||
# MASKING
|
||||
anchor_lengths = [element[0].shape[0] for element in batch]
|
||||
max_anchor_l = max(anchor_lengths)
|
||||
|
||||
anchor_mask = [[False] * anchor_lengths[n] + [True] * (max_anchor_l - anchor_lengths[n])
|
||||
for n in range(len(batch))]
|
||||
|
||||
# PADDING
|
||||
anchor_batch = [element[0] for element in batch]
|
||||
anchor_batch = pad_sequence(anchor_batch, batch_first=True)
|
||||
|
||||
labels = torch.Tensor([element[1] for element in batch])
|
||||
|
||||
return anchor_batch, labels, torch.Tensor(anchor_mask)
|
||||
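The masking logic in the collate functions builds, for each sequence, a boolean mask that is `False` over real frames and `True` over padding, sized to the longest sequence in the batch. A minimal stdlib-only sketch of that rule (the helper name is illustrative, not part of the repository):

```python
def build_padding_masks(lengths):
    """Build boolean padding masks: False marks a real frame, True a padded one."""
    max_len = max(lengths)
    return [[False] * n + [True] * (max_len - n) for n in lengths]

# Three sequences of lengths 3, 1 and 2 are padded to length 3.
masks = build_padding_masks([3, 1, 2])
# masks == [[False, False, False], [False, True, True], [False, False, True]]
```

The same masks are later passed to the transformer as `src_key_padding_mask` so attention ignores the padded frames.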
103
datasets/embedding_dataset.py
Normal file
@@ -0,0 +1,103 @@
import torch
import torch.utils.data as torch_data
from random import sample
from typing import List
import numpy as np

from datasets.datasets_utils import load_dataset, tensor_to_dictionary, dictionary_to_tensor, \
    random_augmentation
from normalization.body_normalization import normalize_single_dict as normalize_single_body_dict
from normalization.hand_normalization import normalize_single_dict as normalize_single_hand_dict


class SLREmbeddingDataset(torch_data.Dataset):
    """Representation of the WLASL dataset that yields (anchor, positive, negative) triplets for
    triplet loss, built on Torch's Dataset interface."""

    data: List[np.ndarray]
    labels: List[np.ndarray]

    def __init__(self, dataset_filename: str, triplet=True, transform=None, augmentations=False,
                 augmentations_prob=0.5, normalize=True):
        """
        Initializes the SLREmbeddingDataset with data loaded from the given CSV file.

        :param dataset_filename: Path to the CSV file
        :param transform: Any data transformation to be applied (default: None)
        """

        loaded_data = load_dataset(dataset_filename)
        data, labels = loaded_data[0], loaded_data[1]

        self.data = data
        self.labels = labels
        self.targets = list(labels)
        self.transform = transform
        self.triplet = triplet
        self.augmentations = augmentations
        self.augmentations_prob = augmentations_prob
        self.normalize = normalize

    def __getitem__(self, idx):
        """
        Allocates, potentially transforms and returns the item at the desired index.

        :param idx: Index of the item
        :return: Triplet (anchor, positive, negative) when ``triplet`` is True, otherwise a
                 (depth map, label) tuple
        """
        depth_map_a = torch.from_numpy(np.copy(self.data[idx]))
        label = torch.Tensor([self.labels[idx]])

        depth_map_a = tensor_to_dictionary(depth_map_a)

        if self.triplet:
            positive_indexes = list(np.where(np.array(self.labels) == self.labels[idx])[0])
            positive_index_sample = sample(positive_indexes, 2)
            positive_index = positive_index_sample[0] if positive_index_sample[0] != idx else positive_index_sample[1]
            negative_indexes = list(np.where(np.array(self.labels) != self.labels[idx])[0])
            negative_index = sample(negative_indexes, 1)[0]
            # TODO: implement hard triplets

            depth_map_p = torch.from_numpy(np.copy(self.data[positive_index]))
            depth_map_n = torch.from_numpy(np.copy(self.data[negative_index]))

            depth_map_p = tensor_to_dictionary(depth_map_p)
            depth_map_n = tensor_to_dictionary(depth_map_n)

        # TODO: Add Data augmentation to positive and negative ?

        # Apply potential augmentations
        depth_map_a = random_augmentation(self.augmentations, self.augmentations_prob, depth_map_a)

        if self.normalize:
            depth_map_a = normalize_single_body_dict(depth_map_a)
            depth_map_a = normalize_single_hand_dict(depth_map_a)
            if self.triplet:
                depth_map_p = normalize_single_body_dict(depth_map_p)
                depth_map_p = normalize_single_hand_dict(depth_map_p)
                depth_map_n = normalize_single_body_dict(depth_map_n)
                depth_map_n = normalize_single_hand_dict(depth_map_n)

        depth_map_a = dictionary_to_tensor(depth_map_a)
        # Move the landmark position interval to improve performance
        depth_map_a = depth_map_a - 0.5

        if self.triplet:
            depth_map_p = dictionary_to_tensor(depth_map_p)
            depth_map_p = depth_map_p - 0.5
            depth_map_n = dictionary_to_tensor(depth_map_n)
            depth_map_n = depth_map_n - 0.5

        if self.transform:
            depth_map_a = self.transform(depth_map_a)
            if self.triplet:
                depth_map_p = self.transform(depth_map_p)
                depth_map_n = self.transform(depth_map_n)

        if self.triplet:
            return depth_map_a, depth_map_p, depth_map_n

        return depth_map_a, label

    def __len__(self):
        return len(self.labels)
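The triplet sampling in `__getitem__` picks a positive index that shares the anchor's label (and is distinct from the anchor) plus a negative index with a different label. A stdlib-only sketch of that selection rule (the function name is illustrative):

```python
from random import sample

def pick_triplet_indexes(labels, idx):
    """Return (positive_index, negative_index) for the anchor at idx."""
    positive_indexes = [i for i, lab in enumerate(labels) if lab == labels[idx]]
    # Sample two candidates so we can discard the anchor itself if drawn.
    positive_index_sample = sample(positive_indexes, 2)
    positive_index = positive_index_sample[0] if positive_index_sample[0] != idx else positive_index_sample[1]
    negative_indexes = [i for i, lab in enumerate(labels) if lab != labels[idx]]
    negative_index = sample(negative_indexes, 1)[0]
    return positive_index, negative_index

pos, neg = pick_triplet_indexes([0, 0, 1, 1, 0], idx=0)
# pos is another index carrying label 0; neg carries label 1.
```

Note this scheme assumes every label occurs at least twice, otherwise `sample(positive_indexes, 2)` raises `ValueError`.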
4
models/__init__.py
Normal file
@@ -0,0 +1,4 @@
from .spoter_model import SPOTER
from .spoter_embedding_model import SPOTER_EMBEDDINGS
from .utils import train_epoch, evaluate, evaluate_top_k, train_epoch_embedding, train_epoch_embedding_online, \
    evaluate_embedding, embeddings_scatter_plot, embeddings_scatter_plot_splits
41
models/spoter_embedding_model.py
Normal file
@@ -0,0 +1,41 @@
import torch
import torch.nn as nn

from models.spoter_model import _get_clones, SPOTERTransformerDecoderLayer


class SPOTER_EMBEDDINGS(nn.Module):
    """
    Embedding variant of the SPOTER (Sign POse-based TransformER) architecture: instead of class
    logits, it produces an embedding vector for a sequence of skeletal data.
    """

    def __init__(self, features, hidden_dim=108, nhead=9, num_encoder_layers=6, num_decoder_layers=6,
                 norm_emb=False, dropout=0.1):
        super().__init__()

        self.pos_encoding = nn.Parameter(torch.rand(1, 1, hidden_dim))  # init positional encoding
        self.class_query = nn.Parameter(torch.rand(1, 1, hidden_dim))
        self.transformer = nn.Transformer(hidden_dim, nhead, num_encoder_layers, num_decoder_layers, dropout=dropout)
        self.linear_embed = nn.Linear(hidden_dim, features)

        # Deactivate the initial attention decoder mechanism
        custom_decoder_layer = SPOTERTransformerDecoderLayer(self.transformer.d_model, self.transformer.nhead, 2048,
                                                            dropout, "relu")
        self.transformer.decoder.layers = _get_clones(custom_decoder_layer, self.transformer.decoder.num_layers)
        self.norm_emb = norm_emb

    def forward(self, inputs, src_masks=None):

        h = torch.transpose(inputs.flatten(start_dim=2), 1, 0).float()
        h = self.transformer(
            self.pos_encoding.repeat(1, h.shape[1], 1) + h,
            self.class_query.repeat(1, h.shape[1], 1),
            src_key_padding_mask=src_masks
        ).transpose(0, 1)
        embedding = self.linear_embed(h)

        if self.norm_emb:
            embedding = nn.functional.normalize(embedding, dim=2)

        return embedding
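`forward` flattens the landmark and coordinate dimensions into one feature axis and moves the sequence axis first before calling the transformer. A stdlib-only sketch of the shape bookkeeping (the 54-landmark, 2-coordinate layout is an assumption based on `BODY_IDENTIFIERS + HAND_IDENTIFIERS`, and 54 * 2 = 108 matches the default `hidden_dim`):

```python
def spoter_embedding_shapes(batch, frames, landmarks=54, coords=2):
    """Track tensor shapes through SPOTER_EMBEDDINGS.forward (sketch only)."""
    flattened = (batch, frames, landmarks * coords)          # inputs.flatten(start_dim=2)
    transposed = (flattened[1], flattened[0], flattened[2])  # torch.transpose(..., 1, 0)
    return transposed

shape = spoter_embedding_shapes(batch=4, frames=120)
# (120, 4, 108): sequence-first layout with 108 features per frame.
```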
66
models/spoter_model.py
Normal file
@@ -0,0 +1,66 @@

import copy
import torch

import torch.nn as nn
from typing import Optional


def _get_clones(mod, n):
    return nn.ModuleList([copy.deepcopy(mod) for _ in range(n)])


class SPOTERTransformerDecoderLayer(nn.TransformerDecoderLayer):
    """
    Edited TransformerDecoderLayer implementation omitting the redundant self-attention operation as opposed to the
    standard implementation.
    """

    def __init__(self, d_model, nhead, dim_feedforward, dropout, activation):
        super(SPOTERTransformerDecoderLayer, self).__init__(d_model, nhead, dim_feedforward, dropout, activation)

        del self.self_attn

    def forward(self, tgt: torch.Tensor, memory: torch.Tensor, tgt_mask: Optional[torch.Tensor] = None,
                memory_mask: Optional[torch.Tensor] = None, tgt_key_padding_mask: Optional[torch.Tensor] = None,
                memory_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:

        tgt = tgt + self.dropout1(tgt)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)

        return tgt


class SPOTER(nn.Module):
    """
    Implementation of the SPOTER (Sign POse-based TransformER) architecture for sign language recognition from a
    sequence of skeletal data.
    """

    def __init__(self, num_classes, hidden_dim=55):
        super().__init__()

        self.row_embed = nn.Parameter(torch.rand(50, hidden_dim))
        self.pos = nn.Parameter(torch.cat([self.row_embed[0].unsqueeze(0).repeat(1, 1, 1)], dim=-1).flatten(0, 1).unsqueeze(0))
        self.class_query = nn.Parameter(torch.rand(1, hidden_dim))
        self.transformer = nn.Transformer(hidden_dim, 9, 6, 6)
        self.linear_class = nn.Linear(hidden_dim, num_classes)

        # Deactivate the initial attention decoder mechanism
        custom_decoder_layer = SPOTERTransformerDecoderLayer(self.transformer.d_model, self.transformer.nhead, 2048,
                                                            0.1, "relu")
        self.transformer.decoder.layers = _get_clones(custom_decoder_layer, self.transformer.decoder.num_layers)

    def forward(self, inputs):
        h = torch.unsqueeze(inputs.flatten(start_dim=1), 1).float()
        h = self.transformer(self.pos + h, self.class_query.unsqueeze(0)).transpose(0, 1)
        res = self.linear_class(h)

        return res
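`_get_clones` relies on `copy.deepcopy` so every decoder layer gets its own independent parameters rather than sharing one set. A stdlib-only illustration of why the deep copy matters (plain dicts stand in for layers here):

```python
import copy

def get_clones(layer, n):
    """Deep-copy a template n times so the clones share no state."""
    return [copy.deepcopy(layer) for _ in range(n)]

template = {"weight": [0.0, 0.0]}
layers = get_clones(template, 3)
layers[0]["weight"][0] = 1.0  # mutating one clone...
# ...leaves the other clones and the template untouched.
```

With a shallow copy, all clones would alias the same inner list and update together, which is exactly what stacked transformer layers must avoid.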
280
models/utils.py
Normal file
@@ -0,0 +1,280 @@
import numpy as np
import torch
from sklearn.metrics import silhouette_score
from sklearn.manifold import TSNE
from training.batch_sorter import sort_batches
from utils import get_logger


def train_epoch(model, dataloader, criterion, optimizer, device, scheduler=None):

    pred_correct, pred_all = 0, 0
    running_loss = 0.0
    model.train(True)
    for i, data in enumerate(dataloader):
        inputs, labels = data
        inputs = inputs.squeeze(0).to(device)
        labels = labels.to(device, dtype=torch.long)

        optimizer.zero_grad()
        outputs = model(inputs).expand(1, -1, -1)
        loss = criterion(outputs[0], labels[0])
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        # Statistics
        if int(torch.argmax(torch.nn.functional.softmax(outputs, dim=2))) == int(labels[0][0]):
            pred_correct += 1
        pred_all += 1

    epoch_loss = running_loss / len(dataloader)
    model.train(False)
    if scheduler:
        scheduler.step(epoch_loss)

    return epoch_loss, pred_correct, pred_all, (pred_correct / pred_all)


def train_epoch_embedding(model, epoch_iters, train_loader, val_loader, criterion, optimizer, device, scheduler=None):

    running_loss = []
    model.train(True)
    for i, (anchor, positive, negative, a_mask, p_mask, n_mask) in enumerate(train_loader):
        optimizer.zero_grad()

        anchor_emb = model(anchor.to(device), a_mask.to(device))
        positive_emb = model(positive.to(device), p_mask.to(device))
        negative_emb = model(negative.to(device), n_mask.to(device))

        loss = criterion(anchor_emb.to(device), positive_emb.to(device), negative_emb.to(device))
        loss.backward()
        optimizer.step()
        running_loss.append(loss.item())

        if i == epoch_iters:
            break

    epoch_loss = np.mean(running_loss)

    # VALIDATION
    model.train(False)
    val_silhouette_coef = evaluate_embedding(model, val_loader, device)

    if scheduler:
        scheduler.step(val_silhouette_coef)

    return epoch_loss, val_silhouette_coef


def train_epoch_embedding_online(model, epoch_iters, train_loader, val_loader, criterion, optimizer, device,
                                 scheduler=None, enable_batch_sorting=False, mini_batch_size=None,
                                 pre_batch_mining_count=1, batching_scheduler=None):

    running_loss = []
    iter_used_triplets = []
    iter_valid_triplets = []
    iter_pct_used = []
    model.train(True)
    mini_batch = mini_batch_size or train_loader.batch_size
    for i, (inputs, labels, masks) in enumerate(train_loader):
        labels_size = labels.size()[0]
        batch_loop_count = int(labels_size / mini_batch)
        if batch_loop_count == 0:
            continue
        # Only run batch sorting on complete mini-batches, trimming any partial remainder
        if enable_batch_sorting:
            if labels_size < train_loader.batch_size:
                trim_count = labels_size % mini_batch
                inputs = inputs[:-trim_count]
                labels = labels[:-trim_count]
                masks = masks[:-trim_count]
            embeddings = None
            with torch.no_grad():
                for j in range(batch_loop_count):
                    batch_embed = compute_batched_embeddings(model, device, inputs, masks, mini_batch, j)
                    if embeddings is None:
                        embeddings = batch_embed
                    else:
                        embeddings = torch.cat([embeddings, batch_embed], dim=0)
            inputs, labels, masks = sort_batches(inputs, labels, masks, embeddings, device,
                                                 mini_batch_size=mini_batch_size, scheduler=batching_scheduler)
            del embeddings
            del batch_embed
            mining_loop_count = pre_batch_mining_count
        else:
            mining_loop_count = 1
        for k in range(mining_loop_count):
            for j in range(batch_loop_count):
                optimizer.zero_grad(set_to_none=True)
                batch_labels = labels[mini_batch * j:mini_batch * (j + 1)]
                if batch_labels.size()[0] == 0:
                    break
                embeddings = compute_batched_embeddings(model, device, inputs, masks, mini_batch, j)
                loss, valid_triplets, used_triplets = criterion(embeddings, batch_labels)

                loss.backward()
                optimizer.step()
                running_loss.append(loss.item())
                if valid_triplets > 0:
                    iter_used_triplets.append(used_triplets)
                    iter_valid_triplets.append(valid_triplets)
                    iter_pct_used.append((used_triplets * 100) / valid_triplets)

        if epoch_iters > 0 and i * batch_loop_count * pre_batch_mining_count >= epoch_iters:
            print("Breaking out because of epoch_iters filter")
            break

    epoch_loss = np.mean(running_loss)
    mean_used_triplets = np.mean(iter_used_triplets)
    triplets_stats = {
        "valid_triplets": np.mean(iter_valid_triplets),
        "used_triplets": mean_used_triplets,
        "pct_used": np.mean(iter_pct_used)
    }

    if batching_scheduler:
        batching_scheduler.step(mean_used_triplets)

    # VALIDATION
    model.train(False)
    with torch.no_grad():
        val_silhouette_coef = evaluate_embedding(model, val_loader, device)

    if scheduler:
        scheduler.step(val_silhouette_coef)

    return epoch_loss, val_silhouette_coef, triplets_stats


def compute_batched_embeddings(model, device, inputs, masks, mini_batch, iteration):
    batch_inputs = inputs[mini_batch * iteration:mini_batch * (iteration + 1)]
    batch_masks = masks[mini_batch * iteration:mini_batch * (iteration + 1)]

    return model(batch_inputs.to(device), batch_masks.to(device)).squeeze(1)


def evaluate(model, dataloader, device, print_stats=False):

    logger = get_logger(__name__)

    pred_correct, pred_all = 0, 0
    stats = {i: [0, 0] for i in range(101)}

    for i, data in enumerate(dataloader):
        inputs, labels = data
        inputs = inputs.squeeze(0).to(device)
        labels = labels.to(device, dtype=torch.long)

        outputs = model(inputs).expand(1, -1, -1)

        # Statistics
        if int(torch.argmax(torch.nn.functional.softmax(outputs, dim=2))) == int(labels[0][0]):
            stats[int(labels[0][0])][0] += 1
            pred_correct += 1

        stats[int(labels[0][0])][1] += 1
        pred_all += 1

    if print_stats:
        stats = {key: value[0] / value[1] for key, value in stats.items() if value[1] != 0}
        print("Label accuracies statistics:")
        print(str(stats) + "\n")
        logger.info("Label accuracies statistics:")
        logger.info(str(stats) + "\n")

    return pred_correct, pred_all, (pred_correct / pred_all)


def evaluate_embedding(model, dataloader, device):
    val_embeddings = []
    labels_emb = []

    for i, (inputs, labels, masks) in enumerate(dataloader):
        inputs = inputs.to(device)
        masks = masks.to(device)

        outputs = model(inputs, masks)
        for n in range(outputs.shape[0]):
            val_embeddings.append(outputs[n, 0].cpu().detach().numpy())
            labels_emb.append(labels.detach().numpy()[n])

    silhouette_coefficient = silhouette_score(
        X=np.array(val_embeddings),
        labels=np.array(labels_emb).reshape(len(labels_emb))
    )

    return silhouette_coefficient


def embeddings_scatter_plot(model, dataloader, device, id_to_label, perplexity=40, n_iter=1000):

    val_embeddings = []
    labels_emb = []

    with torch.no_grad():
        for i, (inputs, labels, masks) in enumerate(dataloader):
            inputs = inputs.to(device)
            masks = masks.to(device)

            outputs = model(inputs, masks)
            for n in range(outputs.shape[0]):
                val_embeddings.append(outputs[n, 0].cpu().detach().numpy())
                labels_emb.append(id_to_label[int(labels.detach().numpy()[n])])

    tsne = TSNE(n_components=2, verbose=0, perplexity=perplexity, n_iter=n_iter)
    tsne_results = tsne.fit_transform(np.array(val_embeddings))

    return tsne_results, labels_emb


def embeddings_scatter_plot_splits(model, dataloaders, device, id_to_label, perplexity=40, n_iter=1000):

    labels_split = {}
    embeddings_split = {}
    splits = list(dataloaders.keys())
    with torch.no_grad():
        for split, dataloader in dataloaders.items():
            labels_str = []
            embeddings = []
            for i, (inputs, labels, masks) in enumerate(dataloader):
                inputs = inputs.to(device)
                masks = masks.to(device)

                outputs = model(inputs, masks)
                for n in range(outputs.shape[0]):
                    embeddings.append(outputs[n, 0].cpu().detach().numpy())
                    labels_str.append(id_to_label[int(labels.detach().numpy()[n])])
            labels_split[split] = labels_str
            embeddings_split[split] = embeddings

    tsne = TSNE(n_components=2, verbose=0, perplexity=perplexity, n_iter=n_iter)
    all_embeddings = np.vstack([embeddings_split[split] for split in splits])
    tsne_results = tsne.fit_transform(all_embeddings)
    tsne_results_dict = {}
    curr_index = 0
    for split in splits:
        len_embeddings = len(embeddings_split[split])
        tsne_results_dict[split] = tsne_results[curr_index: curr_index + len_embeddings]
        curr_index += len_embeddings

    return tsne_results_dict, labels_split


def evaluate_top_k(model, dataloader, device, k=5):

    pred_correct, pred_all = 0, 0

    for i, data in enumerate(dataloader):
        inputs, labels = data
        inputs = inputs.squeeze(0).to(device)
        labels = labels.to(device, dtype=torch.long)

        outputs = model(inputs).expand(1, -1, -1)

        if int(labels[0][0]) in torch.topk(outputs, k).indices.tolist()[0][0]:
            pred_correct += 1

        pred_all += 1

    return pred_correct, pred_all, (pred_correct / pred_all)
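`evaluate_top_k` counts a prediction as correct when the true label appears among the k highest-scoring classes. A stdlib-only sketch of that membership check (the helper name is illustrative):

```python
def top_k_hit(scores, true_label, k=5):
    """True when true_label is among the indices of the k largest scores."""
    top_k = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    return true_label in top_k

hit = top_k_hit([0.1, 0.7, 0.05, 0.15], true_label=3, k=2)
# label 3 has the second-highest score, so this is a top-2 hit.
```

Top-k accuracy is then just the fraction of samples for which this check succeeds.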
92
normalization/blazepose_mapping.py
Normal file
@@ -0,0 +1,92 @@

_BODY_KEYPOINT_MAPPING = {
    "nose": "nose",
    "left_eye": "leftEye",
    "right_eye": "rightEye",
    "left_ear": "leftEar",
    "right_ear": "rightEar",
    "left_shoulder": "leftShoulder",
    "right_shoulder": "rightShoulder",
    "left_elbow": "leftElbow",
    "right_elbow": "rightElbow",
    "left_wrist": "leftWrist",
    "right_wrist": "rightWrist"
}

_HAND_KEYPOINT_MAPPING = {
    "wrist": "wrist",
    "index_finger_tip": "indexTip",
    "index_finger_dip": "indexDIP",
    "index_finger_pip": "indexPIP",
    "index_finger_mcp": "indexMCP",
    "middle_finger_tip": "middleTip",
    "middle_finger_dip": "middleDIP",
    "middle_finger_pip": "middlePIP",
    "middle_finger_mcp": "middleMCP",
    "ring_finger_tip": "ringTip",
    "ring_finger_dip": "ringDIP",
    "ring_finger_pip": "ringPIP",
    "ring_finger_mcp": "ringMCP",
    "pinky_tip": "littleTip",
    "pinky_dip": "littleDIP",
    "pinky_pip": "littlePIP",
    "pinky_mcp": "littleMCP",
    "thumb_tip": "thumbTip",
    "thumb_ip": "thumbIP",
    "thumb_mcp": "thumbMP",
    "thumb_cmc": "thumbCMC"
}


def map_blazepose_keypoint(column):
    # Remove the _x / _y suffix, keeping it for the mapped name
    suffix = column[-2:].upper()
    column = column[:-2]

    if column.startswith("left_hand_"):
        hand = "left"
        finger_name = column[10:]
    elif column.startswith("right_hand_"):
        hand = "right"
        finger_name = column[11:]
    else:
        if column not in _BODY_KEYPOINT_MAPPING:
            return None
        mapped = _BODY_KEYPOINT_MAPPING[column]
        return mapped + suffix

    if finger_name not in _HAND_KEYPOINT_MAPPING:
        return None
    mapped = _HAND_KEYPOINT_MAPPING[finger_name]
    return f"{mapped}_{hand}{suffix}"


def map_blazepose_df(df):
    to_drop = []
    renamings = {}
    for column in df.columns:
        mapped_column = map_blazepose_keypoint(column)
        if mapped_column:
            renamings[column] = mapped_column
        else:
            to_drop.append(column)
    df = df.rename(columns=renamings)

    for index, row in df.iterrows():

        sequence_size = len(row["leftEar_Y"])
        lsx = row["leftShoulder_X"]
        rsx = row["rightShoulder_X"]
        lsy = row["leftShoulder_Y"]
        rsy = row["rightShoulder_Y"]
        neck_x = []
        neck_y = []
        # Treat each element of the sequence (analyzed frame) individually
        for sequence_index in range(sequence_size):
            neck_x.append((float(lsx[sequence_index]) + float(rsx[sequence_index])) / 2)
            neck_y.append((float(lsy[sequence_index]) + float(rsy[sequence_index])) / 2)
        df.loc[index, "neck_X"] = str(neck_x)
        df.loc[index, "neck_Y"] = str(neck_y)

    df.drop(columns=to_drop, inplace=True)
    return df
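`map_blazepose_keypoint` splits a BlazePose column name into an optional hand prefix, a keypoint name, and a coordinate suffix. A trimmed stdlib-only sketch of the same rule (only two mapping entries are shown, purely as an illustration):

```python
_BODY = {"nose": "nose"}
_HAND = {"thumb_tip": "thumbTip"}

def map_keypoint(column):
    """Map a BlazePose column name to the repository's naming scheme (sketch)."""
    suffix = column[-2:].upper()   # "_x" -> "_X"
    column = column[:-2]
    if column.startswith("left_hand_"):
        hand, name = "left", column[10:]
    elif column.startswith("right_hand_"):
        hand, name = "right", column[11:]
    else:
        return _BODY[column] + suffix if column in _BODY else None
    return f"{_HAND[name]}_{hand}{suffix}" if name in _HAND else None

map_keypoint("left_hand_thumb_tip_x")  # -> "thumbTip_left_X"
map_keypoint("nose_y")                 # -> "nose_Y"
```

Columns that map to `None` are later dropped by `map_blazepose_df`.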
241
normalization/body_normalization.py
Normal file
@@ -0,0 +1,241 @@
|
||||
|
||||
from typing import Tuple
|
||||
import pandas as pd
|
||||
from utils import get_logger
|
||||
|
||||
|
||||
BODY_IDENTIFIERS = [
|
||||
"nose",
|
||||
"neck",
|
||||
"rightEye",
|
||||
"leftEye",
|
||||
"rightEar",
|
||||
"leftEar",
|
||||
"rightShoulder",
|
||||
"leftShoulder",
|
||||
"rightElbow",
|
||||
"leftElbow",
|
||||
"rightWrist",
|
||||
"leftWrist"
|
||||
]
|
||||
|
||||
|
||||
def normalize_body_full(df: pd.DataFrame) -> Tuple[pd.DataFrame, list]:
|
||||
"""
|
||||
Normalizes the body position data using the Bohacek-normalization algorithm.
|
||||
|
||||
:param df: pd.DataFrame to be normalized
|
||||
:return: pd.DataFrame with normalized values for body pose
|
||||
"""
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# TODO: Fix division by zero
|
||||
|
||||
normalized_df = pd.DataFrame(columns=df.columns)
|
||||
invalid_row_indexes = []
|
||||
body_landmarks = {"X": [], "Y": []}
|
||||
|
||||
# Construct the relevant identifiers
|
||||
for identifier in BODY_IDENTIFIERS:
|
||||
body_landmarks["X"].append(identifier + "_X")
|
||||
body_landmarks["Y"].append(identifier + "_Y")
|
||||
|
||||
# Iterate over all of the records in the dataset
|
||||
for index, row in df.iterrows():
|
||||
|
||||
sequence_size = len(row["leftEar_Y"])
|
||||
valid_sequence = True
|
||||
original_row = row
|
||||
|
||||
last_starting_point, last_ending_point = None, None
|
||||
|
||||
# Treat each element of the sequence (analyzed frame) individually
|
||||
for sequence_index in range(sequence_size):
|
||||
|
||||
# Prevent from even starting the analysis if some necessary elements are not present
|
||||
if (row["leftShoulder_X"][sequence_index] == 0 or row["rightShoulder_X"][sequence_index] == 0) and \
|
||||
(row["neck_X"][sequence_index] == 0 or row["nose_X"][sequence_index] == 0):
|
||||
if not last_starting_point:
|
||||
valid_sequence = False
|
||||
continue
|
||||
|
||||
else:
|
||||
starting_point, ending_point = last_starting_point, last_ending_point
|
||||
|
||||
else:
|
||||
|
||||
# NOTE:
|
||||
#
|
||||
# While in the paper, it is written that the head metric is calculated by halving the shoulder distance,
|
||||
# this is meant for the distance between the very ends of one's shoulder, as literature studying body
|
||||
# metrics and ratios generally states. The Vision Pose Estimation API, however, seems to be predicting
|
||||
# rather the center of one's shoulder. Based on our experiments and manual reviews of the data,
|
||||
# employing
|
||||
# this as just the plain shoulder distance seems to be more corresponding to the desired metric.
|
||||
#
|
||||
# Please, review this if using other third-party pose estimation libraries.
|
||||
|
||||
if row["leftShoulder_X"][sequence_index] != 0 and row["rightShoulder_X"][sequence_index] != 0:
|
||||
left_shoulder = (row["leftShoulder_X"][sequence_index], row["leftShoulder_Y"][sequence_index])
|
||||
right_shoulder = (row["rightShoulder_X"][sequence_index], row["rightShoulder_Y"][sequence_index])
|
||||
shoulder_distance = ((((left_shoulder[0] - right_shoulder[0]) ** 2) + (
|
||||
(left_shoulder[1] - right_shoulder[1]) ** 2)) ** 0.5)
|
||||
head_metric = shoulder_distance
|
||||
else:
|
||||
neck = (row["neck_X"][sequence_index], row["neck_Y"][sequence_index])
|
||||
nose = (row["nose_X"][sequence_index], row["nose_Y"][sequence_index])
|
||||
neck_nose_distance = ((((neck[0] - nose[0]) ** 2) + ((neck[1] - nose[1]) ** 2)) ** 0.5)
|
||||
head_metric = neck_nose_distance
|
||||
|
||||
# Set the starting and ending point of the normalization bounding box
|
||||
starting_point = [row["neck_X"][sequence_index] - 3 * head_metric,
|
||||
row["leftEye_Y"][sequence_index] + (head_metric / 2)]
|
||||
ending_point = [row["neck_X"][sequence_index] + 3 * head_metric, starting_point[1] - 6 * head_metric]
|
||||
|
||||
last_starting_point, last_ending_point = starting_point, ending_point
|
||||
|
||||
# Ensure that all of the bounding-box-defining coordinates are not out of the picture
|
||||
if starting_point[0] < 0:
|
||||
starting_point[0] = 0
|
||||
if starting_point[1] < 0:
|
||||
starting_point[1] = 0
|
||||
if ending_point[0] < 0:
|
||||
ending_point[0] = 0
|
||||
if ending_point[1] < 0:
|
||||
ending_point[1] = 0
|
||||
|
||||
# Normalize individual landmarks and save the results
|
||||
for identifier in BODY_IDENTIFIERS:
|
||||
key = identifier + "_"
|
||||
|
||||
# Prevent from trying to normalize incorrectly captured points
|
||||
if row[key + "X"][sequence_index] == 0:
|
||||
continue
|
||||
|
||||
normalized_x = (row[key + "X"][sequence_index] - starting_point[0]) / (ending_point[0] -
|
||||
starting_point[0])
|
||||
normalized_y = (row[key + "Y"][sequence_index] - ending_point[1]) / (starting_point[1] -
|
||||
ending_point[1])
|
||||
|
||||
row[key + "X"][sequence_index] = normalized_x
|
||||
row[key + "Y"][sequence_index] = normalized_y
|
||||
|
||||
if valid_sequence:
|
||||
normalized_df = normalized_df.append(row, ignore_index=True)
|
||||
else:
|
||||
logger.warning(" BODY LANDMARKS: One video instance could not be normalized.")
|
||||
normalized_df = normalized_df.append(original_row, ignore_index=True)
|
||||
invalid_row_indexes.append(index)
|
||||
|
||||
logger.info("The normalization of body is finished.")
|
||||
logger.info("\t-> Original size:", df.shape[0])
|
||||
logger.info("\t-> Normalized size:", normalized_df.shape[0])
|
||||
logger.info("\t-> Problematic videos:", len(invalid_row_indexes))
|
||||
|
||||
return normalized_df, invalid_row_indexes
|
||||
|
||||
|
||||
def normalize_single_dict(row: dict):
    """
    Normalizes the skeletal data for a given sequence of frames with the signer's body pose data. The normalization
    follows the definition from our paper.

    :param row: dictionary of key-value pairs with joint identifiers and the corresponding lists (sequences) of
                that particular joint's coordinates
    :return: dictionary with normalized skeletal data (following the same schema as the input data)
    """

    sequence_size = len(row["leftEar"])
    valid_sequence = True
    original_row = row
    logger = get_logger(__name__)

    last_starting_point, last_ending_point = None, None

    # Treat each element of the sequence (analyzed frame) individually
    for sequence_index in range(sequence_size):
        left_shoulder = (row["leftShoulder"][sequence_index][0], row["leftShoulder"][sequence_index][1])
        right_shoulder = (row["rightShoulder"][sequence_index][0], row["rightShoulder"][sequence_index][1])
        neck = (row["neck"][sequence_index][0], row["neck"][sequence_index][1])
        nose = (row["nose"][sequence_index][0], row["nose"][sequence_index][1])

        # Prevent even starting the analysis if some necessary elements are not present
        if (left_shoulder[0] == 0 or right_shoulder[0] == 0
                or (left_shoulder[0] == right_shoulder[0] and left_shoulder[1] == right_shoulder[1])) and (
                neck[0] == 0 or nose[0] == 0 or (neck[0] == nose[0] and neck[1] == nose[1])):
            if not last_starting_point:
                valid_sequence = False
                continue
            else:
                starting_point, ending_point = last_starting_point, last_ending_point

        else:

            # NOTE:
            #
            # While the paper states that the head metric is calculated by halving the shoulder distance,
            # this is meant for the distance between the very ends of one's shoulders, as literature studying body
            # metrics and ratios generally states. The Vision Pose Estimation API, however, seems to predict
            # rather the center of one's shoulder. Based on our experiments and manual reviews of the data, employing
            # this as just the plain shoulder distance seems to correspond better to the desired metric.
            #
            # Please review this if using other third-party pose estimation libraries.

            if left_shoulder[0] != 0 and right_shoulder[0] != 0 and \
                    (left_shoulder[0] != right_shoulder[0] or left_shoulder[1] != right_shoulder[1]):
                shoulder_distance = ((((left_shoulder[0] - right_shoulder[0]) ** 2) + (
                        (left_shoulder[1] - right_shoulder[1]) ** 2)) ** 0.5)
                head_metric = shoulder_distance
            else:
                neck_nose_distance = ((((neck[0] - nose[0]) ** 2) + ((neck[1] - nose[1]) ** 2)) ** 0.5)
                head_metric = neck_nose_distance

            # Set the starting and ending point of the normalization bounding box
            # starting_point = [row["neck"][sequence_index][0] - 3 * head_metric,
            #                   row["leftEye"][sequence_index][1] + (head_metric / 2)]
            starting_point = [row["neck"][sequence_index][0] - 3 * head_metric,
                              row["leftEye"][sequence_index][1] + head_metric]
            ending_point = [row["neck"][sequence_index][0] + 3 * head_metric, starting_point[1] - 6 * head_metric]

            last_starting_point, last_ending_point = starting_point, ending_point

        # Ensure that none of the bounding-box-defining coordinates fall outside the picture
        if starting_point[0] < 0:
            starting_point[0] = 0
        if starting_point[1] < 0:
            starting_point[1] = 0
        if ending_point[0] < 0:
            ending_point[0] = 0
        if ending_point[1] < 0:
            ending_point[1] = 0

        # Normalize individual landmarks and save the results
        for identifier in BODY_IDENTIFIERS:
            key = identifier

            # Prevent trying to normalize incorrectly captured points
            if row[key][sequence_index][0] == 0:
                continue

            if (ending_point[0] - starting_point[0]) == 0 or (starting_point[1] - ending_point[1]) == 0:
                logger.warning("Problematic normalization")
                valid_sequence = False
                break

            normalized_x = (row[key][sequence_index][0] - starting_point[0]) / (ending_point[0] - starting_point[0])
            normalized_y = (row[key][sequence_index][1] - ending_point[1]) / (starting_point[1] - ending_point[1])

            row[key][sequence_index] = list(row[key][sequence_index])

            row[key][sequence_index][0] = normalized_x
            row[key][sequence_index][1] = normalized_y

    if valid_sequence:
        return row
    else:
        return original_row


if __name__ == "__main__":
    pass
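The bounding-box arithmetic in `normalize_single_dict` above can be checked with a small numbers-only sketch. All landmark values below are invented for illustration; this is not the repo's API, just the same geometry spelled out:

```python
# Assumed landmark coordinates for one frame (invented for illustration)
neck = (0.50, 0.40)
left_eye = (0.53, 0.30)
left_shoulder = (0.60, 0.45)
right_shoulder = (0.40, 0.45)

# Head metric: the plain shoulder distance (see the NOTE in the code above)
head_metric = ((left_shoulder[0] - right_shoulder[0]) ** 2 +
               (left_shoulder[1] - right_shoulder[1]) ** 2) ** 0.5

# Sign-space box: 6 head metrics wide and tall, centered on the neck's x
starting_point = [neck[0] - 3 * head_metric, left_eye[1] + head_metric]
ending_point = [neck[0] + 3 * head_metric, starting_point[1] - 6 * head_metric]

# Clamp the box to the picture, exactly as the code above does
starting_point = [max(v, 0) for v in starting_point]
ending_point = [max(v, 0) for v in ending_point]

# Map one landmark into the box
wrist = (0.55, 0.60)
normalized_x = (wrist[0] - starting_point[0]) / (ending_point[0] - starting_point[0])
normalized_y = (wrist[1] - ending_point[1]) / (starting_point[1] - ending_point[1])
print(normalized_x, normalized_y)  # roughly (0.5, 1.2); y > 1 because the wrist lies below the clamped box
```

Note that after clamping, coordinates are not guaranteed to land in [0, 1]; landmarks outside the box map outside that range.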
195
normalization/hand_normalization.py
Normal file
@@ -0,0 +1,195 @@
import pandas as pd
from utils import get_logger


HAND_IDENTIFIERS = [
    "wrist",
    "indexTip",
    "indexDIP",
    "indexPIP",
    "indexMCP",
    "middleTip",
    "middleDIP",
    "middlePIP",
    "middleMCP",
    "ringTip",
    "ringDIP",
    "ringPIP",
    "ringMCP",
    "littleTip",
    "littleDIP",
    "littlePIP",
    "littleMCP",
    "thumbTip",
    "thumbIP",
    "thumbMP",
    "thumbCMC"
]


def normalize_hands_full(df: pd.DataFrame) -> pd.DataFrame:
    """
    Normalizes the hand position data using the Bohacek normalization algorithm.

    :param df: pd.DataFrame to be normalized
    :return: pd.DataFrame with normalized values for hand pose
    """

    logger = get_logger(__name__)
    # TODO: Fix division by zero
    df.columns = [item.replace("_left_", "_0_").replace("_right_", "_1_") for item in list(df.columns)]

    normalized_df = pd.DataFrame(columns=df.columns)

    hand_landmarks = {"X": {0: [], 1: []}, "Y": {0: [], 1: []}}

    # Determine how many hands are present in the dataset
    range_hand_size = 1
    if "wrist_1_X" in df.columns:
        range_hand_size = 2

    # Construct the relevant identifiers
    for identifier in HAND_IDENTIFIERS:
        for hand_index in range(range_hand_size):
            hand_landmarks["X"][hand_index].append(identifier + "_" + str(hand_index) + "_X")
            hand_landmarks["Y"][hand_index].append(identifier + "_" + str(hand_index) + "_Y")

    # Iterate over all of the records in the dataset
    for index, row in df.iterrows():
        # Treat each hand individually
        for hand_index in range(range_hand_size):

            sequence_size = len(row["wrist_" + str(hand_index) + "_X"])

            # Treat each element of the sequence (analyzed frame) individually
            for sequence_index in range(sequence_size):

                # Retrieve all of the X and Y values of the current frame
                landmarks_x_values = [row[key][sequence_index]
                                      for key in hand_landmarks["X"][hand_index] if row[key][sequence_index] != 0]
                landmarks_y_values = [row[key][sequence_index]
                                      for key in hand_landmarks["Y"][hand_index] if row[key][sequence_index] != 0]

                # Prevent even starting the analysis if some necessary elements are not present
                if not landmarks_x_values or not landmarks_y_values:
                    logger.warning(
                        " HAND LANDMARKS: One frame could not be normalized as there is no data present. Record: " +
                        str(index) + ", Frame: " + str(sequence_index))
                    continue

                # Calculate the deltas
                width, height = max(landmarks_x_values) - min(landmarks_x_values), max(landmarks_y_values) - min(
                    landmarks_y_values)
                if width > height:
                    delta_x = 0.1 * width
                    delta_y = delta_x + ((width - height) / 2)
                else:
                    delta_y = 0.1 * height
                    delta_x = delta_y + ((height - width) / 2)

                # Set the starting and ending point of the normalization bounding box
                starting_point = (min(landmarks_x_values) - delta_x, min(landmarks_y_values) - delta_y)
                ending_point = (max(landmarks_x_values) + delta_x, max(landmarks_y_values) + delta_y)

                # Normalize individual landmarks and save the results
                for identifier in HAND_IDENTIFIERS:
                    key = identifier + "_" + str(hand_index) + "_"

                    # Prevent trying to normalize incorrectly captured points
                    if row[key + "X"][sequence_index] == 0 or (ending_point[0] - starting_point[0]) == 0 or \
                            (starting_point[1] - ending_point[1]) == 0:
                        continue

                    normalized_x = (row[key + "X"][sequence_index] - starting_point[0]) / (ending_point[0] -
                                                                                           starting_point[0])
                    normalized_y = (row[key + "Y"][sequence_index] - ending_point[1]) / (starting_point[1] -
                                                                                        ending_point[1])

                    row[key + "X"][sequence_index] = normalized_x
                    row[key + "Y"][sequence_index] = normalized_y

        normalized_df = normalized_df.append(row, ignore_index=True)

    return normalized_df


def normalize_single_dict(row: dict):
    """
    Normalizes the skeletal data for a given sequence of frames with the signer's hand pose data. The normalization
    follows the definition from our paper.

    :param row: dictionary of key-value pairs with joint identifiers and the corresponding lists (sequences) of
                that particular joint's coordinates
    :return: dictionary with normalized skeletal data (following the same schema as the input data)
    """

    hand_landmarks = {0: [], 1: []}

    # Determine how many hands are present in the dataset
    range_hand_size = 1
    if "wrist_1" in row.keys():
        range_hand_size = 2

    # Construct the relevant identifiers
    for identifier in HAND_IDENTIFIERS:
        for hand_index in range(range_hand_size):
            hand_landmarks[hand_index].append(identifier + "_" + str(hand_index))

    # Treat each hand individually
    for hand_index in range(range_hand_size):

        sequence_size = len(row["wrist_" + str(hand_index)])

        # Treat each element of the sequence (analyzed frame) individually
        for sequence_index in range(sequence_size):

            # Retrieve all of the X and Y values of the current frame
            landmarks_x_values = [row[key][sequence_index][0] for key in hand_landmarks[hand_index] if
                                  row[key][sequence_index][0] != 0]
            landmarks_y_values = [row[key][sequence_index][1] for key in hand_landmarks[hand_index] if
                                  row[key][sequence_index][1] != 0]

            # Prevent even starting the analysis if some necessary elements are not present
            if not landmarks_x_values or not landmarks_y_values:
                continue

            # Calculate the deltas
            width, height = max(landmarks_x_values) - min(landmarks_x_values), max(landmarks_y_values) - min(
                landmarks_y_values)
            if width > height:
                delta_x = 0.1 * width
                delta_y = delta_x + ((width - height) / 2)
            else:
                delta_y = 0.1 * height
                delta_x = delta_y + ((height - width) / 2)

            # Set the starting and ending point of the normalization bounding box
            starting_point = (min(landmarks_x_values) - delta_x, min(landmarks_y_values) - delta_y)
            ending_point = (max(landmarks_x_values) + delta_x, max(landmarks_y_values) + delta_y)

            # Normalize individual landmarks and save the results
            for identifier in HAND_IDENTIFIERS:
                key = identifier + "_" + str(hand_index)

                # Prevent trying to normalize incorrectly captured points
                if row[key][sequence_index][0] == 0 or (ending_point[0] - starting_point[0]) == 0 or (
                        starting_point[1] - ending_point[1]) == 0:
                    continue

                normalized_x = (row[key][sequence_index][0] - starting_point[0]) / (ending_point[0] -
                                                                                    starting_point[0])
                normalized_y = (row[key][sequence_index][1] - starting_point[1]) / (ending_point[1] -
                                                                                    starting_point[1])

                row[key][sequence_index] = list(row[key][sequence_index])

                row[key][sequence_index][0] = normalized_x
                row[key][sequence_index][1] = normalized_y

    return row


if __name__ == "__main__":
    pass
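The delta computation above always pads the larger side by 10% and the smaller side by enough extra to make the resulting box square. A quick numbers-only sketch (with invented landmark values) makes that visible:

```python
# Invented nonzero hand-landmark coordinates for one frame
xs = [0.30, 0.42, 0.38]
ys = [0.55, 0.61, 0.70]

width = max(xs) - min(xs)    # 0.12
height = max(ys) - min(ys)   # 0.15

# 10% padding on the larger side; the smaller side is padded further
# so that the resulting box becomes square
if width > height:
    delta_x = 0.1 * width
    delta_y = delta_x + ((width - height) / 2)
else:
    delta_y = 0.1 * height
    delta_x = delta_y + ((height - width) / 2)

starting_point = (min(xs) - delta_x, min(ys) - delta_y)
ending_point = (max(xs) + delta_x, max(ys) + delta_y)

side_x = ending_point[0] - starting_point[0]
side_y = ending_point[1] - starting_point[1]
print(side_x, side_y)  # both sides equal 1.2 * max(width, height)
```

In general, the larger side becomes `h + 2 * 0.1h = 1.2h` and the smaller side `w + 2 * (0.1h + (h - w) / 2) = 1.2h`, so the box is square regardless of the hand's aspect ratio.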
47
normalization/main.py
Normal file
@@ -0,0 +1,47 @@
import os
import ast
import pandas as pd

from normalization.hand_normalization import normalize_hands_full
from normalization.body_normalization import normalize_body_full

DATASET_PATH = './data'
# Load the dataset
df = pd.read_csv(os.path.join(DATASET_PATH, "WLASL_test_15fps.csv"), encoding="utf-8")

# Retrieve metadata
video_size_heights = df["video_size_height"].to_list()
video_size_widths = df["video_size_width"].to_list()

# Delete redundant (non-related) properties
del df["video_size_height"]
del df["video_size_width"]

# Temporarily remove other relevant metadata
labels = df["labels"].to_list()
video_fps = df["video_fps"].to_list()
del df["labels"]
del df["video_fps"]


# Convert the strings into lists
def convert(x): return ast.literal_eval(str(x))


for column in df.columns:
    df[column] = df[column].apply(convert)

# Perform the normalizations
df = normalize_hands_full(df)
df, invalid_row_indexes = normalize_body_full(df)

# Clear lists of items from deleted rows
# labels = [t for i, t in enumerate(labels) if i not in invalid_row_indexes]
# video_fps = [t for i, t in enumerate(video_fps) if i not in invalid_row_indexes]

# Return the metadata back to the dataset
df["labels"] = labels
df["video_fps"] = video_fps

df.to_csv(os.path.join(DATASET_PATH, "WLASL_test_15fps_normalized.csv"), encoding="utf-8", index=False)
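The `convert` helper above exists because list-valued landmark columns survive a CSV round trip only as strings. A self-contained sketch (with an assumed column name) shows what it undoes:

```python
import ast
import io

import pandas as pd

# A tiny frame with one list-valued landmark column (assumed name)
df = pd.DataFrame({"wrist_X": [[0.1, 0.2, 0.3]], "labels": [7]})

# Write and re-read through an in-memory CSV
buffer = io.StringIO()
df.to_csv(buffer, index=False)
buffer.seek(0)
df_back = pd.read_csv(buffer)

# After reading, the cell is a plain string, not a list
assert isinstance(df_back.loc[0, "wrist_X"], str)

# ast.literal_eval recovers the list, exactly as `convert` does
df_back["wrist_X"] = df_back["wrist_X"].apply(lambda x: ast.literal_eval(str(x)))
print(df_back.loc[0, "wrist_X"])  # [0.1, 0.2, 0.3]
```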
411
notebooks/embeddings_evaluation.ipynb
Normal file
@@ -0,0 +1,411 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c20f7fd5",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ada032d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "import os.path as op\n",
    "import pandas as pd\n",
    "import json\n",
    "import base64"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05682e73",
   "metadata": {},
   "outputs": [],
   "source": [
    "sys.path.append(op.abspath('..'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fede7684",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":16:8\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce531994",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "from itertools import chain\n",
    "\n",
    "import torch\n",
    "import multiprocessing\n",
    "from scipy.spatial import distance_matrix\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4a2d672",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "\n",
    "from datasets import SLREmbeddingDataset, collate_fn_padd\n",
    "from datasets.dataset_loader import LocalDatasetLoader\n",
    "from models import embeddings_scatter_plot_splits\n",
    "from models import SPOTER_EMBEDDINGS"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "af8fbe32",
   "metadata": {},
   "source": [
    "## Model and dataset loading"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d9db764",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "seed = 43\n",
    "random.seed(seed)\n",
    "np.random.seed(seed)\n",
    "os.environ[\"PYTHONHASHSEED\"] = str(seed)\n",
    "torch.manual_seed(seed)\n",
    "torch.cuda.manual_seed(seed)\n",
    "torch.cuda.manual_seed_all(seed)\n",
    "torch.backends.cudnn.deterministic = True\n",
    "torch.use_deterministic_algorithms(True)\n",
    "generator = torch.Generator()\n",
    "generator.manual_seed(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71224139",
   "metadata": {},
   "outputs": [],
   "source": [
    "BASE_DATA_FOLDER = '../data/'\n",
    "os.environ[\"BASE_DATA_FOLDER\"] = BASE_DATA_FOLDER\n",
    "device = torch.device(\"cpu\")\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "013d3774",
   "metadata": {},
   "outputs": [],
   "source": [
    "# LOAD MODEL FROM CLEARML\n",
    "# from clearml import InputModel\n",
    "# model = InputModel(model_id='1b736da469b04e91b8451d2342aef6ce')\n",
    "# checkpoint = torch.load(model.get_weights())\n",
    "\n",
    "## Set your path to the checkpoint here\n",
    "CHECKPOINT_PATH = \"../checkpoints/checkpoint_embed_992.pth\"\n",
    "checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)\n",
    "\n",
    "model = SPOTER_EMBEDDINGS(\n",
    "    features=checkpoint[\"config_args\"].vector_length,\n",
    "    hidden_dim=checkpoint[\"config_args\"].hidden_dim,\n",
    "    norm_emb=checkpoint[\"config_args\"].normalize_embeddings,\n",
    ").to(device)\n",
    "\n",
    "model.load_state_dict(checkpoint[\"state_dict\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba6b58f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "SL_DATASET = 'wlasl'  # or 'lsa'\n",
    "if SL_DATASET == 'wlasl':\n",
    "    dataset_name = \"wlasl_mapped_mediapipe_only_landmarks_25fps\"\n",
    "    num_classes = 100\n",
    "    split_dataset_path = \"WLASL100_{}_25fps.csv\"\n",
    "else:\n",
    "    dataset_name = \"lsa64_mapped_mediapipe_only_landmarks_25fps\"\n",
    "    num_classes = 64\n",
    "    split_dataset_path = \"LSA64_{}.csv\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5643a72c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_dataset_loader(loader_name=None):\n",
    "    if loader_name == 'CLEARML':\n",
    "        from datasets.clearml_dataset_loader import ClearMLDatasetLoader\n",
    "        return ClearMLDatasetLoader()\n",
    "    else:\n",
    "        return LocalDatasetLoader()\n",
    "\n",
    "dataset_loader = get_dataset_loader()\n",
    "dataset_project = \"Sign Language Recognition\"\n",
    "batch_size = 1\n",
    "dataset_folder = dataset_loader.get_dataset_folder(dataset_project, dataset_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04a62088",
   "metadata": {},
   "outputs": [],
   "source": [
    "def seed_worker(worker_id):\n",
    "    worker_seed = torch.initial_seed() % 2**32\n",
    "    np.random.seed(worker_seed)\n",
    "    random.seed(worker_seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79c837c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataloaders = {}\n",
    "splits = ['train', 'val']\n",
    "dfs = {}\n",
    "for split in splits:\n",
    "    split_set_path = op.join(dataset_folder, split_dataset_path.format(split))\n",
    "    split_set = SLREmbeddingDataset(split_set_path, triplet=False, augmentations=False)\n",
    "    data_loader = DataLoader(\n",
    "        split_set,\n",
    "        batch_size=batch_size,\n",
    "        shuffle=False,\n",
    "        collate_fn=collate_fn_padd,\n",
    "        pin_memory=torch.cuda.is_available(),\n",
    "        num_workers=multiprocessing.cpu_count(),\n",
    "        worker_init_fn=seed_worker,\n",
    "        generator=generator,\n",
    "    )\n",
    "    dataloaders[split] = data_loader\n",
    "    dfs[split] = pd.read_csv(split_set_path)\n",
    "\n",
    "with open(op.join(dataset_folder, 'id_to_label.json')) as fid:\n",
    "    id_to_label = json.load(fid)\n",
    "id_to_label = {int(key): value for key, value in id_to_label.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b5bda73",
   "metadata": {},
   "outputs": [],
   "source": [
    "labels_split = {}\n",
    "embeddings_split = {}\n",
    "splits = list(dataloaders.keys())\n",
    "with torch.no_grad():\n",
    "    for split, dataloader in dataloaders.items():\n",
    "        labels_str = []\n",
    "        embeddings = []\n",
    "        k = 0\n",
    "        for i, (inputs, labels, masks) in enumerate(dataloader):\n",
    "            k += 1\n",
    "            inputs = inputs.to(device)\n",
    "            masks = masks.to(device)\n",
    "            outputs = model(inputs, masks)\n",
    "            for n in range(outputs.shape[0]):\n",
    "                embeddings.append(outputs[n, 0].cpu().detach().numpy())\n",
    "        embeddings_split[split] = embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0efa0871",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(embeddings_split['train']), len(dfs['train'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab83c6e2",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "for split in splits:\n",
    "    df = dfs[split]\n",
    "    df['embeddings'] = embeddings_split[split]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2951638d",
   "metadata": {},
   "source": [
    "## Compute metrics\n",
    "Here we compute top-1 and top-5 metrics, either using only a class centroid or using the whole dataset to classify vectors.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7399b8ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "for use_centroids, str_use_centroids in zip([True, False],\n",
    "                                            ['Using centroids only', 'Using all embeddings']):\n",
    "\n",
    "    df_val = dfs['val']\n",
    "    df_train = dfs['train']\n",
    "    if use_centroids:\n",
    "        df_train = dfs['train'].groupby('labels')['embeddings'].apply(np.mean).reset_index()\n",
    "    x_train = np.vstack(df_train['embeddings'])\n",
    "    x_val = np.vstack(df_val['embeddings'])\n",
    "\n",
    "    d_mat = distance_matrix(x_val, x_train, p=2)\n",
    "\n",
    "    top5_embs = 0\n",
    "    top5_classes = 0\n",
    "    knn = 0\n",
    "    top1 = 0\n",
    "\n",
    "    len_val_dataset = len(df_val)\n",
    "    good_samples = []\n",
    "\n",
    "    for i in range(d_mat.shape[0]):\n",
    "        true_label = df_val.loc[i, 'labels']\n",
    "        labels = df_train['labels'].values\n",
    "        argsort = np.argsort(d_mat[i])\n",
    "        sorted_labels = labels[argsort]\n",
    "        if sorted_labels[0] == true_label:\n",
    "            top1 += 1\n",
    "            if use_centroids:\n",
    "                good_samples.append(df_val.loc[i, 'video_id'])\n",
    "            else:\n",
    "                good_samples.append((df_val.loc[i, 'video_id'],\n",
    "                                     df_train.loc[argsort[0], 'video_id'],\n",
    "                                     i,\n",
    "                                     argsort[0]))\n",
    "\n",
    "        if true_label == Counter(sorted_labels[:5]).most_common()[0][0]:\n",
    "            knn += 1\n",
    "        if true_label in sorted_labels[:5]:\n",
    "            top5_embs += 1\n",
    "        if true_label in list(dict.fromkeys(sorted_labels))[:5]:\n",
    "            top5_classes += 1\n",
    "\n",
    "    print(str_use_centroids)\n",
    "\n",
    "    print(f'Top-1 accuracy: {100 * top1 / len_val_dataset : 0.2f} %')\n",
    "    if not use_centroids:\n",
    "        print(f'5-nn accuracy: {100 * knn / len_val_dataset : 0.2f} % (Picks the class that appears most often in the 5 closest embeddings)')\n",
    "    print(f'Top-5 embeddings class match: {100 * top5_embs / len_val_dataset: 0.2f} % (Picks any class in the 5 closest embeddings)')\n",
    "    if not use_centroids:\n",
    "        print(f'Top-5 unique class match: {100 * top5_classes / len_val_dataset: 0.2f} % (Picks the 5 closest distinct classes)')\n",
    "    print('\\n' + '#'*32 + '\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d2aaac6c",
   "metadata": {},
   "source": [
    "## Show some examples (only for WLASL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9d1d309",
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import Video"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd2a0cd8",
   "metadata": {},
   "outputs": [],
   "source": [
    "for row in df_train[df_train.label_name == 'thursday'][:3].itertuples():\n",
    "    display(Video(op.join(BASE_DATA_FOLDER, f'wlasl/videos/{row.video_id}.mp4'), embed=True))"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "main_language": "python",
   "notebook_metadata_filter": "-all"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
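The centroid branch of the metrics cell above can be condensed into a tiny self-contained sketch. The embeddings below are made up, and plain NumPy broadcasting replaces `scipy.spatial.distance_matrix` for brevity:

```python
import numpy as np

# Made-up 2-D embeddings standing in for model outputs
train_emb = np.array([[0.0, 0.0], [0.1, 0.1], [1.0, 1.0], [1.1, 0.9]])
train_labels = np.array([0, 0, 1, 1])
val_emb = np.array([[0.05, 0.0], [0.9, 1.0]])
val_labels = np.array([0, 1])

# Collapse every training class to its centroid
classes = np.unique(train_labels)
centroids = np.vstack([train_emb[train_labels == c].mean(axis=0) for c in classes])

# Euclidean distance of every validation embedding to every centroid
d_mat = np.linalg.norm(val_emb[:, None, :] - centroids[None, :, :], axis=2)

# Nearest centroid wins
predictions = classes[np.argmin(d_mat, axis=1)]
top1 = (predictions == val_labels).mean()
print(f'Top-1 accuracy: {100 * top1:0.2f} %')  # Top-1 accuracy: 100.00 %
```

The "all embeddings" variant in the notebook is the same computation with every training embedding kept instead of one centroid per class.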
491
notebooks/visualize_embeddings.ipynb
Normal file
File diff suppressed because one or more lines are too long
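Since this notebook's diff is suppressed, here is only a hedged sketch of how a 2-D embedding scatter is typically prepared: a plain PCA projection on random stand-in vectors. Neither the data nor the projection choice is taken from the notebook itself:

```python
import numpy as np

rng = np.random.default_rng(0)
embeddings = rng.normal(size=(50, 64))  # stand-in for 64-dim model embeddings

# PCA by hand: center the data, then project onto the top-2 right singular vectors
centered = embeddings - embeddings.mean(axis=0)
_, _, vt = np.linalg.svd(centered, full_matrices=False)
points_2d = centered @ vt[:2].T

print(points_2d.shape)  # (50, 2)
```

The resulting `points_2d` can be fed to any scatter-plot routine, colored by label, to inspect how well the embedding space separates the sign classes.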
21
preprocessing.py
Normal file
@@ -0,0 +1,21 @@
from argparse import ArgumentParser
from preprocessing.create_wlasl_landmarks_dataset import parse_create_args, create
from preprocessing.extract_mediapipe_landmarks import parse_extract_args, extract


if __name__ == '__main__':
    main_parser = ArgumentParser()
    subparser = main_parser.add_subparsers(dest="action")
    create_subparser = subparser.add_parser("create")
    extract_subparser = subparser.add_parser("extract")
    parse_create_args(create_subparser)
    parse_extract_args(extract_subparser)

    args = main_parser.parse_args()

    if args.action == "create":
        create(args)
    elif args.action == "extract":
        extract(args)
    else:
        raise ValueError("action command must be either 'create' or 'extract'")
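The subcommand dispatch above can be exercised in isolation. This sketch keeps only the `--landmarks-dataset` flag from `parse_create_args` and omits every other per-command argument:

```python
from argparse import ArgumentParser

main_parser = ArgumentParser()
subparser = main_parser.add_subparsers(dest="action")

# "create" takes the landmarks folder; "extract" is registered with no flags here
create_subparser = subparser.add_parser("create")
create_subparser.add_argument('--landmarks-dataset', '-lmks', required=True)
subparser.add_parser("extract")

# Equivalent to: python3 preprocessing.py create --landmarks-dataset=data/landmarks
args = main_parser.parse_args(["create", "--landmarks-dataset=data/landmarks"])
print(args.action, args.landmarks_dataset)  # create data/landmarks
```

With `dest="action"` set on the subparsers, the chosen command name lands in `args.action`, which is what the dispatch above branches on.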
0
preprocessing/__init__.py
Normal file
155
preprocessing/create_wlasl_landmarks_dataset.py
Normal file
@@ -0,0 +1,155 @@
import os
import os.path as op
import json
import shutil

import cv2
import mediapipe as mp
import numpy as np
import pandas as pd
from utils import get_logger
from tqdm.auto import tqdm
from sklearn.model_selection import train_test_split
from normalization.blazepose_mapping import map_blazepose_df

BASE_DATA_FOLDER = 'data/'

mp_drawing = mp.solutions.drawing_utils
mp_drawing_styles = mp.solutions.drawing_styles
mp_hands = mp.solutions.hands
mp_holistic = mp.solutions.holistic
pose_landmarks = mp_holistic.PoseLandmark
hand_landmarks = mp_holistic.HandLandmark


def get_landmarks_names():
    '''
    Returns landmark names for the MediaPipe Holistic model
    '''
    pose_lmks = ','.join([f'{lmk.name.lower()}_x,{lmk.name.lower()}_y' for lmk in pose_landmarks])
    left_hand_lmks = ','.join([f'left_hand_{lmk.name.lower()}_x,left_hand_{lmk.name.lower()}_y'
                               for lmk in hand_landmarks])
    right_hand_lmks = ','.join([f'right_hand_{lmk.name.lower()}_x,right_hand_{lmk.name.lower()}_y'
                                for lmk in hand_landmarks])
    lmks_names = f'{pose_lmks},{left_hand_lmks},{right_hand_lmks}'
    return lmks_names


def convert_to_str(arr, precision=6):
    if isinstance(arr, np.ndarray):
        values = []
        for val in arr:
            if val == 0:
                values.append('0')
            else:
                values.append(f'{val:.{precision}f}')
        return f"[{','.join(values)}]"
    else:
        return str(arr)


def parse_create_args(parser):
    parser.add_argument('--landmarks-dataset', '-lmks', required=True,
                        help='Path to folder with landmarks npy files. \
                        You need to run the `extract_mediapipe_landmarks.py` script first')
    parser.add_argument('--dataset-folder', '-df', default='data/wlasl',
                        help='Path to folder where the original `WLASL_v0.3.json` and `id_to_label.json` are stored. \
                        Note that the final CSV files will be saved in this folder too.')
    parser.add_argument('--videos-folder', '-videos', default=None,
                        help='Path to folder with videos. If None, no video information (fps, length, \
                        width and height) will be stored in the final csv file')
    parser.add_argument('--num-classes', '-nc', default=100, type=int, help='Number of classes to use in the WLASL dataset')
    parser.add_argument('--create-new-split', action='store_true')
    parser.add_argument('--test-size', '-ts', default=0.25, type=float,
                        help='Test split percentage size. Only required if --create-new-split is set')


# python3 preprocessing.py create --landmarks-dataset=data/landmarks -videos data/wlasl/videos
def create(args):
    logger = get_logger(__name__)

    landmarks_dataset = args.landmarks_dataset
    videos_folder = args.videos_folder
    dataset_folder = args.dataset_folder
    num_classes = args.num_classes
    test_size = args.test_size

    os.makedirs(dataset_folder, exist_ok=True)

    shutil.copy(os.path.join(BASE_DATA_FOLDER, 'wlasl/id_to_label.json'), dataset_folder)
    shutil.copy(os.path.join(BASE_DATA_FOLDER, 'wlasl/WLASL_v0.3.json'), dataset_folder)

    wlasl_json_fn = op.join(dataset_folder, 'WLASL_v0.3.json')

    with open(wlasl_json_fn) as fid:
        data = json.load(fid)

    video_data = []
    for label_id, datum in enumerate(tqdm(data[:num_classes])):
        instances = []
        for instance in datum['instances']:
            instances.append(instance)
            video_id = instance['video_id']
            print(video_id)
            video_dict = {'video_id': video_id,
                          'label_name': datum['gloss'],
                          'labels': label_id,
                          'split': instance['split']}
            if videos_folder is not None:
                cap = cv2.VideoCapture(op.join(videos_folder, f'{video_id}.mp4'))
                if not cap.isOpened():
                    logger.warning(f'Video {video_id}.mp4 not found')
                    continue
                width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
                height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
                fps = cap.get(cv2.CAP_PROP_FPS)
                length = cap.get(cv2.CAP_PROP_FRAME_COUNT) / float(cap.get(cv2.CAP_PROP_FPS))
                video_info = {'video_width': width,
                              'video_height': height,
                              'fps': fps,
                              'length': length}
                video_dict.update(video_info)
            video_data.append(video_dict)
    df_video = pd.DataFrame(video_data)
    video_ids = df_video['video_id'].unique()
    lmks_data = []
    lmks_names = get_landmarks_names().split(',')
    for video_id in video_ids:
        lmk_fn = op.join(landmarks_dataset, f'{video_id}.npy')
        if not op.exists(lmk_fn):
            logger.warning(f'{lmk_fn} file not found. Skipping')
            continue
        lmk = np.load(lmk_fn).T
        lmks_dict = {'video_id': video_id}
        for lmk_, name in zip(lmk, lmks_names):
            lmks_dict[name] = lmk_
        lmks_data.append(lmks_dict)

    df_lmks = pd.DataFrame(lmks_data)
    print(df_lmks)
    df = pd.merge(df_video, df_lmks)
    print(df)
    aux_columns = ['split', 'video_id', 'labels', 'label_name']
    if videos_folder is not None:
        aux_columns += ['video_width', 'video_height', 'fps', 'length']
    df_aux = df[aux_columns]
    df = map_blazepose_df(df)
    df = pd.concat([df, df_aux], axis=1)
    if args.create_new_split:
        df_train, df_test = train_test_split(df, test_size=test_size, stratify=df['labels'], random_state=42)
    else:
        print(df['split'].unique())
        df_train = df[(df['split'] == 'train') | (df['split'] == 'val')]
        df_test = df[df['split'] == 'test']

    print(f'Num classes: {num_classes}')
    print(df_train['labels'].value_counts())
    assert set(df_train['labels'].unique()) == set(df_test['labels'].unique()), \
        'The labels for the train and test dataframes differ. We recommend downloading the dataset again, ' \
        'or using the --create-new-split flag'
    for split, df_split in zip(['train', 'val'],
|
||||
[df_train, df_test]):
|
||||
fn_out = op.join(dataset_folder, f'WLASL{num_classes}_{split}.csv')
|
||||
(df_split.reset_index(drop=True)
|
||||
.applymap(convert_to_str)
|
||||
.to_csv(fn_out, index=False))
|
||||
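The `--create-new-split` branch delegates to scikit-learn's stratified `train_test_split`, which keeps each label's train/test ratio roughly constant. As an illustration only, here is a pure-Python sketch of that guarantee (the helper `stratified_split` is hypothetical, not part of this repo):

```python
import random
from collections import defaultdict


def stratified_split(items, labels, test_size=0.25, seed=42):
    """Split items so every label appears in both splits with ~the same ratio."""
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for item, label in zip(items, labels):
        by_label[label].append(item)
    train, test = [], []
    for label, group in by_label.items():
        rng.shuffle(group)
        n_test = max(1, int(len(group) * test_size))  # at least one test item per label
        test.extend(group[:n_test])
        train.extend(group[n_test:])
    return train, test


items = list(range(40))
labels = [i % 4 for i in items]  # 4 classes, 10 items each
train, test = stratified_split(items, labels)
print(len(train), len(test))  # → 32 8
```

Because every class contributes to both splits, the label-set assertion at the end of `create` cannot fail for a split produced this way.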
154
preprocessing/extract_mediapipe_landmarks.py
Normal file
@@ -0,0 +1,154 @@
import os
import os.path as op
from itertools import chain
from collections import namedtuple
import glob

import cv2
import numpy as np
import mediapipe as mp
from tqdm.auto import tqdm

# Import drawing_utils and drawing_styles.
mp_drawing = mp.solutions.drawing_utils
mp_drawing_styles = mp.solutions.drawing_styles
mp_holistic = mp.solutions.holistic
mp_pose = mp.solutions.pose

LEN_LANDMARKS_POSE = len(mp_holistic.PoseLandmark)
LEN_LANDMARKS_HAND = len(mp_holistic.HandLandmark)
TOTAL_LANDMARKS = LEN_LANDMARKS_POSE + 2 * LEN_LANDMARKS_HAND

Landmark = namedtuple("Landmark", ["x", "y"])


class LandmarksResults:
    """
    Wrapper for landmarks results. When a body part is not detected, its landmarks are filled with 0.
    """

    def __init__(
        self,
        results,
        num_landmarks_pose=LEN_LANDMARKS_POSE,
        num_landmarks_hand=LEN_LANDMARKS_HAND,
    ):
        self.results = results
        self.num_landmarks_pose = num_landmarks_pose
        self.num_landmarks_hand = num_landmarks_hand

    @property
    def pose_landmarks(self):
        if self.results.pose_landmarks is None:
            return [Landmark(0, 0)] * self.num_landmarks_pose
        else:
            return self.results.pose_landmarks.landmark

    @property
    def left_hand_landmarks(self):
        if self.results.left_hand_landmarks is None:
            return [Landmark(0, 0)] * self.num_landmarks_hand
        else:
            return self.results.left_hand_landmarks.landmark

    @property
    def right_hand_landmarks(self):
        if self.results.right_hand_landmarks is None:
            return [Landmark(0, 0)] * self.num_landmarks_hand
        else:
            return self.results.right_hand_landmarks.landmark


def get_landmarks(image_orig, holistic, debug=False):
    """
    Runs landmarks detection for a single image.
    Returns: flat list of interleaved x, y landmark coordinates
    """
    # Convert the BGR image to RGB before processing.
    image = cv2.cvtColor(image_orig, cv2.COLOR_BGR2RGB)
    results = LandmarksResults(holistic.process(image))
    if debug:
        lmks_pose = []
        for lmk in results.pose_landmarks:
            lmks_pose.append(lmk.x)
            lmks_pose.append(lmk.y)
        assert (
            len(lmks_pose) == 2 * LEN_LANDMARKS_POSE
        ), f"{len(lmks_pose)} != {2 * LEN_LANDMARKS_POSE}"

        lmks_left_hand = []
        for lmk in results.left_hand_landmarks:
            lmks_left_hand.append(lmk.x)
            lmks_left_hand.append(lmk.y)
        assert (
            len(lmks_left_hand) == 2 * LEN_LANDMARKS_HAND
        ), f"{len(lmks_left_hand)} != {2 * LEN_LANDMARKS_HAND}"

        lmks_right_hand = []
        for lmk in results.right_hand_landmarks:
            lmks_right_hand.append(lmk.x)
            lmks_right_hand.append(lmk.y)
        assert (
            len(lmks_right_hand) == 2 * LEN_LANDMARKS_HAND
        ), f"{len(lmks_right_hand)} != {2 * LEN_LANDMARKS_HAND}"
    landmarks = []
    for lmk in chain(
        results.pose_landmarks,
        results.left_hand_landmarks,
        results.right_hand_landmarks,
    ):
        landmarks.append(lmk.x)
        landmarks.append(lmk.y)
    assert (
        len(landmarks) == TOTAL_LANDMARKS * 2
    ), f"{len(landmarks)} != {TOTAL_LANDMARKS * 2}"
    return landmarks


def parse_extract_args(parser):
    parser.add_argument(
        "--videos-folder",
        "-videos",
        help="Path of folder with videos to extract landmarks from",
        required=True,
    )
    parser.add_argument(
        "--output-landmarks",
        "-lmks",
        help="Path of output folder where landmarks npy files will be saved",
        required=True,
    )


# python3 preprocessing.py -videos=data/wlasl/videos_25fps/ -lmks=data/landmarks
def extract(args):
    landmarks_output = args.output_landmarks
    videos_folder = args.videos_folder
    os.makedirs(landmarks_output, exist_ok=True)
    for fn_video in tqdm(sorted(glob.glob(op.join(videos_folder, "*mp4")))):
        cap = cv2.VideoCapture(fn_video)
        ret, image_orig = cap.read()
        height, width = image_orig.shape[:2]
        landmarks_video = []
        with tqdm(total=int(cap.get(cv2.CAP_PROP_FRAME_COUNT))) as pbar:
            with mp_holistic.Holistic(
                static_image_mode=False,
                min_detection_confidence=0.5,
                model_complexity=2,
            ) as holistic:
                while ret:
                    try:
                        landmarks = get_landmarks(image_orig, holistic)
                    except Exception as e:
                        print(e)
                        # Re-run with debug=True so the failing assertion is reported.
                        landmarks = get_landmarks(image_orig, holistic, debug=True)
                    ret, image_orig = cap.read()
                    landmarks_video.append(landmarks)
                    pbar.update(1)
        landmarks_video = np.vstack(landmarks_video)
        np.save(
            op.join(landmarks_output, op.basename(fn_video).split(".")[0]),
            landmarks_video,
        )
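The flat vector built by `get_landmarks` is just pose, left-hand, and right-hand landmarks concatenated with x/y interleaved, and zero-filled placeholders for missing parts. A self-contained sketch of that layout (toy landmark counts, no MediaPipe dependency; `flatten_landmarks` is illustrative, not the repo's function):

```python
from collections import namedtuple
from itertools import chain

Landmark = namedtuple("Landmark", ["x", "y"])

# Toy counts; the real script takes these from mp.solutions.holistic.
N_POSE, N_HAND = 4, 3


def flatten_landmarks(pose, left_hand, right_hand):
    """Concatenate the parts and interleave x/y, zero-filling missing parts."""
    pose = pose or [Landmark(0, 0)] * N_POSE
    left_hand = left_hand or [Landmark(0, 0)] * N_HAND
    right_hand = right_hand or [Landmark(0, 0)] * N_HAND
    flat = []
    for lmk in chain(pose, left_hand, right_hand):
        flat.extend([lmk.x, lmk.y])
    return flat


vec = flatten_landmarks([Landmark(0.5, 0.5)] * N_POSE, None, None)
print(len(vec))  # → 20, i.e. 2 * (N_POSE + 2 * N_HAND)
```

This fixed-length layout is what lets `extract` stack one row per frame with `np.vstack` even when a hand drops out of view.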
14
requirements.txt
Normal file
@@ -0,0 +1,14 @@
bokeh==2.4.3
boto3>=1.9
clearml==1.6.4
ipywidgets==8.0.4
matplotlib==3.5.3
mediapipe==0.8.11
notebook==6.5.2
opencv-python==4.6.0.66
pandas==1.1.5
plotly==5.11.0
scikit-learn==1.0.2
torchvision==0.13.0
tqdm==4.54.1
0
tests/__init__.py
Normal file
104
tests/test_batch_sorter.py
Normal file
@@ -0,0 +1,104 @@
import unittest
# from traceback_with_variables import activate_by_import  # noqa
import torch
from training.batch_sorter import BatchGrouper, sort_batches, get_scaled_distances, get_dist_tuple_list


class TestBatchSorting(unittest.TestCase):

    def get_sorted_dists(self):
        device = get_device()
        embeddings = torch.rand(32 * 8, 8).to(device)
        labels = torch.rand(32 * 8, 1)
        scaled_dist = get_scaled_distances(embeddings, labels, device)
        # Get vector of (row, column, dist)
        dist_list = get_dist_tuple_list(scaled_dist)

        A = dist_list.cpu().detach().numpy()
        return A[:, A[-1, :].argsort()[::-1]]

    def setUp(self) -> None:
        dists = self.get_sorted_dists()
        self.grouper = BatchGrouper(sorted_dists=dists, total_items=32 * 8, mini_batch_size=32)
        return super().setUp()

    def test_assigns_and_merges(self):
        group0 = self.grouper.create_or_get_group()
        self.grouper.assign_group(1, group0)
        self.grouper.assign_group(2, group0)

        group1 = self.grouper.create_or_get_group()
        self.grouper.assign_group(3, group1)
        self.grouper.assign_group(4, group1)

        # Merge groups
        self.grouper.merge_groups(group0, group1)
        self.assertEqual(len(self.grouper.groups[group0]), 4)
        self.assertFalse(group1 in self.grouper.groups)
        self.assertEqual(self.grouper.item_to_group[3], group0)
        self.assertEqual(self.grouper.item_to_group[4], group0)

    def test_full_groups(self):
        group0 = self.grouper.create_or_get_group()
        for i in range(30):
            self.grouper.assign_group(i, group0)

        self.assertFalse(self.grouper.group_is_full(group0))
        initial_group_len = len(self.grouper.groups[group0])

        group1 = self.grouper.create_or_get_group()
        for i in range(30, 33):
            self.grouper.assign_group(i, group1)

        self.grouper.merge_groups(group0, group1)
        # Assert no merge was done
        self.assertEqual(len(self.grouper.groups[group0]), initial_group_len)
        self.assertTrue(group1 in self.grouper.groups)
        self.assertEqual(self.grouper.item_to_group[31], group1)
        self.assertEqual(self.grouper.item_to_group[32], group1)

    def test_replace_groups(self):
        group0 = self.grouper.create_or_get_group()
        for i in range(20):
            self.grouper.assign_group(i, group0)

        group1 = self.grouper.create_or_get_group()
        for i in range(20, 23):
            self.grouper.assign_group(i, group1)

        group2 = self.grouper.create_or_get_group()
        for i in range(23, 30):
            self.grouper.assign_group(i, group2)

        self.grouper.merge_groups(group1, group0)
        self.assertEqual(len(self.grouper.groups[group0]), 23)
        self.assertTrue(group1 in self.grouper.groups)
        self.assertFalse(group2 in self.grouper.groups)
        self.assertEqual(len(self.grouper.groups[group1]), 7)


def get_device():
    device = torch.device("cpu")
    return device


def test_get_scaled_distances():
    device = get_device()
    emb = torch.rand(4, 3)
    labels = torch.tensor([0, 1, 2, 2])
    distances = get_scaled_distances(emb, labels, device)
    assert torch.all(distances >= 0)
    assert torch.all(distances <= 1)


def test_batch_sorter_indices():
    device = get_device()
    inputs = torch.rand(32 * 16, 1000)
    labels = torch.rand(32 * 16, 1)
    masks = torch.rand(32 * 16, 100)
    embeddings = torch.rand(32 * 16, 32).to(device)

    i_out, l_out, m_out = sort_batches(inputs, labels, masks, embeddings, device)
    first_match_index = torch.all(inputs == i_out[0], dim=1).nonzero(as_tuple=True)[0][0]
    assert torch.all(labels[first_match_index] == l_out[0])
    assert torch.all(masks[first_match_index] == m_out[0])
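`get_sorted_dists` above builds (row, column, distance) triples from the strict lower triangle of the distance matrix (that is what `torch.tril_indices` with `offset=-1` enumerates) and then orders them hardest-first with `argsort()[::-1]`. For intuition only, the same idea without torch/numpy (hypothetical helper, not part of the repo):

```python
def lower_triangle_pairs(dist):
    """Enumerate (i, j, dist[i][j]) for the strict lower triangle of a square
    matrix, sorted with the largest (hardest) distance first."""
    pairs = []
    for i in range(len(dist)):
        for j in range(i):  # j < i: strict lower triangle, each pair once
            pairs.append((i, j, dist[i][j]))
    return sorted(pairs, key=lambda t: t[2], reverse=True)


dist = [
    [0.0, 0.2, 0.9],
    [0.2, 0.0, 0.5],
    [0.9, 0.5, 0.0],
]
print(lower_triangle_pairs(dist))  # → [(2, 0, 0.9), (2, 1, 0.5), (1, 0, 0.2)]
```

Restricting to the lower triangle avoids counting each symmetric pair twice and skips the zero self-distances on the diagonal.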
0
tracking/__init__.py
Normal file
21
tracking/clearml_tracker.py
Normal file
@@ -0,0 +1,21 @@
from clearml import Task, Logger
from .tracker import Tracker


class ClearMLTracker(Tracker):

    def __init__(self, project_name=None, experiment_name=None):
        self.task = Task.current_task() or Task.init(project_name=project_name, task_name=experiment_name)

    def execute_remotely(self, queue_name):
        self.task.execute_remotely(queue_name=queue_name)

    def log_scalar_metric(self, metric, series, iteration, value):
        Logger.current_logger().report_scalar(metric, series, iteration=iteration, value=value)

    def log_chart(self, title, series, iteration, figure):
        Logger.current_logger().report_plotly(title=title, series=series, iteration=iteration, figure=figure)

    def finish_run(self):
        self.task.mark_completed()
        self.task.close()
28
tracking/tracker.py
Normal file
@@ -0,0 +1,28 @@

class Tracker:

    def __init__(self, project_name, experiment_name):
        super().__init__()

    def execute_remotely(self, queue_name):
        pass

    def track_config(self, configs):
        # Used to track configuration parameters of an experiment run
        pass

    def track_artifacts(self, filepath):
        # Used to track artifacts like model weights
        pass

    def log_scalar_metric(self, metric, series, iteration, value):
        pass

    def log_chart(self, title, series, iteration, figure):
        pass

    def finish_run(self):
        pass

    def get_callback(self):
        pass
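Because `Tracker` is a no-op base class, training runs with tracking disabled need no special-casing: the code calls the same methods either way, and a backend only overrides what it needs (as `ClearMLTracker` does). A hedged sketch of that pattern with a hypothetical `ConsoleTracker` (illustrative only, not part of the repo; the base class is inlined here so the example is self-contained):

```python
class Tracker:
    # Minimal copy of the no-op base-class interface for this sketch.
    def log_scalar_metric(self, metric, series, iteration, value):
        pass

    def finish_run(self):
        pass


class ConsoleTracker(Tracker):
    """Prints metrics instead of sending them to an experiment-tracking backend."""

    def __init__(self):
        self.rows = []

    def log_scalar_metric(self, metric, series, iteration, value):
        self.rows.append((metric, series, iteration, value))
        print(f"[{iteration:>4}] {metric}/{series} = {value}")


tracker = ConsoleTracker()
tracker.log_scalar_metric("train_loss", "loss", 1, 0.42)
tracker.finish_run()  # inherited no-op: nothing to flush for the console backend
```

The training loop never has to check which backend is active; swapping trackers is a constructor change.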
287
train.py
Normal file
@@ -0,0 +1,287 @@

from datetime import datetime
import os
import os.path as op
import argparse
import json
from datasets.dataset_loader import LocalDatasetLoader
from tracking.tracker import Tracker
import torch
import multiprocessing
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from pathlib import Path
import copy

from datasets import CzechSLRDataset, SLREmbeddingDataset, collate_fn_triplet_padd, collate_fn_padd
from models import SPOTER, SPOTER_EMBEDDINGS, train_epoch, evaluate, train_epoch_embedding, \
    train_epoch_embedding_online, evaluate_embedding
from training.online_batch_mining import BatchAllTripletLoss
from training.batching_scheduler import BatchingScheduler
from training.gaussian_noise import GaussianNoise
from training.train_utils import train_setup, create_embedding_scatter_plots
from training.train_arguments import get_default_args
from utils import get_logger
try:
    # Needed for argparse patching in case clearml is used
    import clearml  # noqa
except ImportError:
    pass


PROJECT_NAME = "spoter"
CLEARML = "clearml"


def is_pre_batch_sorting_enabled(args):
    return args.start_mining_hard is not None and args.start_mining_hard > 0


def get_tracker(tracker_name, project, experiment_name):
    if tracker_name == CLEARML:
        from tracking.clearml_tracker import ClearMLTracker
        return ClearMLTracker(project_name=project, experiment_name=experiment_name)
    else:
        return Tracker(project_name=project, experiment_name=experiment_name)


def get_dataset_loader(loader_name):
    if loader_name == CLEARML:
        from datasets.clearml_dataset_loader import ClearMLDatasetLoader
        return ClearMLDatasetLoader()
    else:
        return LocalDatasetLoader()


def build_data_loader(dataset, batch_size, shuffle, collate_fn, generator):
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn,
                      generator=generator, pin_memory=torch.cuda.is_available(), num_workers=multiprocessing.cpu_count())


def train(args, tracker: Tracker):
    tracker.execute_remotely(queue_name="default")
    # Initialize all the random seeds
    gen = train_setup(args.seed, args.experiment_name)
    os.environ['EXPERIMENT_NAME'] = args.experiment_name
    logger = get_logger(args.experiment_name)

    # Set device to CUDA only if applicable
    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda")

    # Construct the model
    if not args.classification_model:
        slrt_model = SPOTER_EMBEDDINGS(
            features=args.vector_length,
            hidden_dim=args.hidden_dim,
            norm_emb=args.normalize_embeddings,
            dropout=args.dropout
        )
        model_type = 'embed'
        if args.hard_triplet_mining == "None":
            cel_criterion = nn.TripletMarginLoss(margin=args.triplet_loss_margin, p=2)
        elif args.hard_triplet_mining == "in_batch":
            cel_criterion = BatchAllTripletLoss(
                device=device,
                margin=args.triplet_loss_margin,
                filter_easy_triplets=bool(args.filter_easy_triplets)
            )
    else:
        slrt_model = SPOTER(num_classes=args.num_classes, hidden_dim=args.hidden_dim)
        model_type = 'classif'
        cel_criterion = nn.CrossEntropyLoss()
    slrt_model.to(device)

    if args.optimizer == "SGD":
        optimizer = optim.SGD(slrt_model.parameters(), lr=args.lr)
    elif args.optimizer == "ADAM":
        optimizer = optim.Adam(slrt_model.parameters(), lr=args.lr)

    if args.scheduler_factor > 0:
        mode = 'min' if args.classification_model else 'max'
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode=mode,
            factor=args.scheduler_factor,
            patience=args.scheduler_patience
        )
    else:
        scheduler = None

    if args.hard_mining_scheduler_triplets_threshold > 0:
        batching_scheduler = BatchingScheduler(triplets_threshold=args.hard_mining_scheduler_triplets_threshold)
    else:
        batching_scheduler = None

    # Ensure that the paths for checkpointing and for images both exist
    Path("out-checkpoints/" + args.experiment_name + "/").mkdir(parents=True, exist_ok=True)
    Path("out-img/").mkdir(parents=True, exist_ok=True)

    # Training set
    transform = transforms.Compose([GaussianNoise(args.gaussian_mean, args.gaussian_std)])
    dataset_loader = get_dataset_loader(args.dataset_loader)
    dataset_folder = dataset_loader.get_dataset_folder(args.dataset_project, args.dataset_name)
    training_set_path = op.join(dataset_folder, args.training_set_path)

    with open(op.join(dataset_folder, 'id_to_label.json')) as fid:
        id_to_label = json.load(fid)
    id_to_label = {int(key): value for key, value in id_to_label.items()}

    if not args.classification_model:
        batch_size = args.batch_size
        val_batch_size = args.batch_size
        if args.hard_triplet_mining == "None":
            train_set = SLREmbeddingDataset(training_set_path, triplet=True, transform=transform, augmentations=True,
                                            augmentations_prob=args.augmentations_prob)
            collate_fn_train = collate_fn_triplet_padd
        elif args.hard_triplet_mining == "in_batch":
            train_set = SLREmbeddingDataset(training_set_path, triplet=False, transform=transform, augmentations=True,
                                            augmentations_prob=args.augmentations_prob)
            collate_fn_train = collate_fn_padd
            if is_pre_batch_sorting_enabled(args):
                batch_size *= args.hard_mining_pre_batch_multipler
        train_val_set = SLREmbeddingDataset(training_set_path, triplet=False)
        # Train dataloader for validation
        train_val_loader = build_data_loader(train_val_set, val_batch_size, False, collate_fn_padd, gen)
    else:
        train_set = CzechSLRDataset(training_set_path, transform=transform, augmentations=True)
        batch_size = 1
        val_batch_size = 1
        collate_fn_train = None

    train_loader = build_data_loader(train_set, batch_size, True, collate_fn_train, gen)

    # Validation set
    validation_set_path = op.join(dataset_folder, args.validation_set_path)

    if args.classification_model:
        val_set = CzechSLRDataset(validation_set_path)
        collate_fn_val = None
    else:
        val_set = SLREmbeddingDataset(validation_set_path, triplet=False)
        collate_fn_val = collate_fn_padd

    val_loader = build_data_loader(val_set, val_batch_size, False, collate_fn_val, gen)

    # MARK: TRAINING
    train_acc, val_acc = 0, 0
    losses, train_accs, val_accs = [], [], []
    lr_progress = []
    top_val_acc = -999
    top_model_saved = True

    logger.info(f"Starting {args.experiment_name}...\n\n")

    if is_pre_batch_sorting_enabled(args):
        mini_batch_size = int(batch_size / args.hard_mining_pre_batch_multipler)
    else:
        mini_batch_size = None
    enable_batch_sorting = False
    pre_batch_mining_count = 1
    for epoch in range(1, args.epochs + 1):
        start_time = datetime.now()
        if not args.classification_model:
            train_kwargs = {"model": slrt_model,
                            "epoch_iters": args.epoch_iters,
                            "train_loader": train_loader,
                            "val_loader": val_loader,
                            "criterion": cel_criterion,
                            "optimizer": optimizer,
                            "device": device,
                            "scheduler": scheduler if epoch >= args.scheduler_warmup else None,
                            }
            if args.hard_triplet_mining == "None":
                train_loss, val_silhouette_coef = train_epoch_embedding(**train_kwargs)
            elif args.hard_triplet_mining == "in_batch":
                if epoch == args.start_mining_hard:
                    enable_batch_sorting = True
                    pre_batch_mining_count = args.hard_mining_pre_batch_mining_count
                train_kwargs.update(dict(enable_batch_sorting=enable_batch_sorting,
                                         mini_batch_size=mini_batch_size,
                                         pre_batch_mining_count=pre_batch_mining_count,
                                         batching_scheduler=batching_scheduler if enable_batch_sorting else None))

                train_loss, val_silhouette_coef, triplets_stats = train_epoch_embedding_online(**train_kwargs)

                tracker.log_scalar_metric("triplets", "valid_triplets", epoch, triplets_stats["valid_triplets"])
                tracker.log_scalar_metric("triplets", "used_triplets", epoch, triplets_stats["used_triplets"])
                tracker.log_scalar_metric("triplets_pct", "pct_used", epoch, triplets_stats["pct_used"])
            tracker.log_scalar_metric("train_loss", "loss", epoch, train_loss)
            losses.append(train_loss)

            # Calculate the silhouette coefficient on the train dataset
            silhouette_coefficient_train = evaluate_embedding(slrt_model, train_val_loader, device)

            tracker.log_scalar_metric("silhouette_coefficient", "train", epoch, silhouette_coefficient_train)
            train_accs.append(silhouette_coefficient_train)

            val_accs.append(val_silhouette_coef)
            tracker.log_scalar_metric("silhouette_coefficient", "val", epoch, val_silhouette_coef)

        else:
            train_loss, _, _, train_acc = train_epoch(slrt_model, train_loader, cel_criterion, optimizer, device)
            tracker.log_scalar_metric("train_loss", "loss", epoch, train_loss)
            tracker.log_scalar_metric("acc", "train", epoch, train_acc)
            losses.append(train_loss)
            train_accs.append(train_acc)

            _, _, val_acc = evaluate(slrt_model, val_loader, device)
            val_accs.append(val_acc)
            tracker.log_scalar_metric("acc", "val", epoch, val_acc)

        logger.info(f"Epoch time: {datetime.now() - start_time}")
        logger.info(f"[{epoch}] TRAIN loss: {train_loss} acc: {train_accs[-1]}")
        logger.info(f"[{epoch}] VALIDATION acc: {val_accs[-1]}")

        lr_progress.append(optimizer.param_groups[0]["lr"])
        tracker.log_scalar_metric("lr", "lr", epoch, lr_progress[-1])

        if val_accs[-1] > top_val_acc:
            top_val_acc = val_accs[-1]
            top_model_name = "checkpoint_" + model_type + "_" + str(epoch) + ".pth"
            top_model_dict = {
                "name": top_model_name,
                "epoch": epoch,
                "val_acc": val_accs[-1],
                "config_args": args,
                "state_dict": copy.deepcopy(slrt_model.state_dict()),
            }
            top_model_saved = False

        # Save the checkpoint if it is the best on validation so far
        if args.save_checkpoints_every > 0 and epoch % args.save_checkpoints_every == 0 and not top_model_saved:
            torch.save(
                top_model_dict,
                "out-checkpoints/" + args.experiment_name + "/" + top_model_name
            )
            top_model_saved = True
            logger.info("Saved new best checkpoint: " + top_model_name)

    # Save the top model if periodic checkpoints are disabled or the best model was not yet written
    if not top_model_saved:
        torch.save(
            top_model_dict,
            "out-checkpoints/" + args.experiment_name + "/" + top_model_name
        )
        logger.info("Saved new best checkpoint: " + top_model_name)

    # Log scatter plots
    if not args.classification_model and args.hard_triplet_mining == "in_batch":
        logger.info("Generating Scatter Plot.")
        best_model = slrt_model
        best_model.load_state_dict(top_model_dict["state_dict"])
        create_embedding_scatter_plots(tracker, best_model, train_loader, val_loader, device, id_to_label, epoch,
                                       top_model_name)
    logger.info("The experiment is finished.")


if __name__ == '__main__':
    parser = argparse.ArgumentParser("", parents=[get_default_args()], add_help=False)
    args = parser.parse_args()
    tracker = get_tracker(args.tracker, PROJECT_NAME, args.experiment_name)
    train(args, tracker)
    tracker.finish_run()
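The checkpointing logic in `train` keeps the best model in memory (`top_model_dict`) and writes it to disk lazily, guarded by the `top_model_saved` flag: periodic saves only happen when a new best exists, and a final flush covers the case where periodic saving is disabled (as in `train.sh`, which passes `--save_checkpoints_every -1`). A simplified sketch of just that bookkeeping, with a plain list standing in for `torch.save` (names here are illustrative, not the repo's):

```python
def run_checkpointing(val_metrics, save_every=2):
    """Track the best validation metric; 'save' lazily every `save_every` epochs,
    and flush the best model at the end if it was never written."""
    saved = []
    top_val, top_name, pending = float("-inf"), None, False
    for epoch, val in enumerate(val_metrics, start=1):
        if val > top_val:  # new best: remember it, but do not write yet
            top_val, top_name, pending = val, f"checkpoint_{epoch}.pth", True
        if save_every > 0 and epoch % save_every == 0 and pending:
            saved.append(top_name)  # stands in for torch.save(top_model_dict, ...)
            pending = False
    if pending:  # periodic saving disabled or best arrived after the last save
        saved.append(top_name)
    return saved


print(run_checkpointing([0.1, 0.3, 0.2, 0.5, 0.4]))  # → ['checkpoint_2.pth', 'checkpoint_4.pth']
```

Deferring the write this way avoids serializing a large `state_dict` on every improving epoch while still guaranteeing the best model ends up on disk.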
24
train.sh
Executable file
@@ -0,0 +1,24 @@
#!/bin/sh
python -m train \
    --save_checkpoints_every -1 \
    --experiment_name "augment_rotate_75_x8" \
    --epochs 10 \
    --optimizer "SGD" \
    --lr 0.001 \
    --batch_size 32 \
    --dataset_name "wlasl" \
    --training_set_path "WLASL100_train.csv" \
    --validation_set_path "WLASL100_test.csv" \
    --vector_length 32 \
    --epoch_iters -1 \
    --scheduler_factor 0 \
    --hard_triplet_mining "in_batch" \
    --filter_easy_triplets \
    --triplet_loss_margin 1 \
    --dropout 0.2 \
    --start_mining_hard=200 \
    --hard_mining_pre_batch_multipler=16 \
    --hard_mining_pre_batch_mining_count=5 \
    --augmentations_prob=0.75 \
    --hard_mining_scheduler_triplets_threshold=0
    # --normalize_embeddings
0
training/__init__.py
Normal file
215
training/batch_sorter.py
Normal file
@@ -0,0 +1,215 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
import numpy as np
|
||||
from typing import Optional
|
||||
from .batching_scheduler import BatchingScheduler
|
||||
import torch
|
||||
|
||||
logger = logging.getLogger("BatchGrouper")
|
||||
|
||||
|
||||
class BatchGrouper:
|
||||
"""
|
||||
Will cluster all `total_items` into `max_groups` clusters based on distances in
|
||||
`sorted_dists`. Each group has `mini_batch_size` elements and these elements are just integers in
|
||||
range 0...total_items.
|
||||
|
||||
Distances between these items are expected to be scaled to 0...1 in a way that distances for two items in the
|
||||
same class are higher if closer to 1, while distances between elements of different classes are higher if closer
|
||||
to 0.
|
||||
|
||||
The logic is picking the highest value distance and assigning both items to the same cluster/group if possible.
|
||||
This might include merging 2 clusters.
|
||||
There are a few threshold to limit the computational cost. If the scaled distance between a pair is below
|
||||
`dist_threshold`, or more than `assign_threshold` percent of items have been assigned to the groups, we stop and
|
||||
assign the remanining items to the groups that have space left.
|
||||
"""
|
||||
# Counters
|
||||
next_group = 0
|
||||
items_assigned = 0
|
||||
|
||||
# Thresholds
|
||||
dist_threshold = 0.5
|
||||
assign_threshold = 0.80
|
||||
|
||||
def __init__(self, sorted_dists, total_items, mini_batch_size=32, dist_threshold=0.5, assign_threshold=0.8) -> None:
|
||||
self.sorted_dists = sorted_dists
|
||||
self.total_items = total_items
|
||||
self.mini_batch_size = mini_batch_size
|
||||
self.max_groups = int(total_items / mini_batch_size)
|
||||
self.groups = {}
|
||||
self.item_to_group = {}
|
||||
self.items_assigned = 0
|
||||
self.next_group = 0
|
||||
self.dist_threshold = dist_threshold
|
||||
self.assign_threshold = assign_threshold
|
||||
|
||||
def cluster_items(self):
|
||||
"""Main function of this class. Does the clustering explained in class docstring.
|
||||
|
||||
:raises e: _description_
|
||||
:return _type_: _description_
|
||||
"""
|
||||
for i in range(self.sorted_dists.shape[-1]): # and some other conditions are unmet
|
||||
a, b, dist = self.sorted_dists[:, i]
|
||||
a, b = int(a), int(b)
|
||||
if dist < self.dist_threshold or self.items_assigned > self.total_items * self.assign_threshold:
|
||||
logger.info(f"Breaking with dist: {dist}, and {self.items_assigned} items assigned")
|
||||
break
|
||||
if a not in self.item_to_group and b not in self.item_to_group:
|
||||
g = self.create_or_get_group()
|
||||
self.assign_group(a, g)
|
||||
self.assign_group(b, g)
|
||||
elif a not in self.item_to_group:
                if not self.group_is_full(self.item_to_group[b]):
                    self.assign_group(a, self.item_to_group[b])
            elif b not in self.item_to_group:
                if not self.group_is_full(self.item_to_group[a]):
                    self.assign_group(b, self.item_to_group[a])
            else:
                grp_a = self.item_to_group[a]
                grp_b = self.item_to_group[b]
                self.merge_groups(grp_a, grp_b)
        self.assign_remaining_items()
        return list(np.concatenate(list(self.groups.values())).flat)

    def assign_group(self, item, group):
        """Assigns `item` to group `group`."""
        self.item_to_group[item] = group
        self.groups[group].append(item)
        self.items_assigned += 1

    def create_or_get_group(self):
        """Creates a new group if the current group count is less than max_groups.
        Otherwise returns the first group with space left.

        :return int: The group id
        """
        if self.next_group < self.max_groups:
            group = self.next_group
            self.groups[group] = []
            self.next_group += 1
        else:
            for i in range(self.next_group):
                if len(self.groups[i]) <= self.mini_batch_size - 2:
                    group = i
                    break  # out of the for loop
        return group

    def group_is_full(self, group):
        return len(self.groups[group]) == self.mini_batch_size

    def can_merge_groups(self, grp_a, grp_b):
        return grp_a != grp_b and (len(self.groups[grp_a]) + len(self.groups[grp_b]) < self.mini_batch_size)

    def merge_groups(self, grp_a, grp_b):
        """Merges two groups together, if possible. Otherwise does nothing."""
        if grp_a > grp_b:
            grp_a, grp_b = grp_b, grp_a
        if self.can_merge_groups(grp_a, grp_b):
            logger.debug(f"MERGE {grp_a} with {grp_b}: {len(self.groups[grp_a])} {len(self.groups[grp_b])}")
            for b in self.groups[grp_b]:
                self.item_to_group[b] = grp_a
            self.groups[grp_a].extend(self.groups[grp_b])
            self.groups[grp_b] = []
            self.replace_group(grp_b)

    def replace_group(self, group):
        """Replaces a group with the last one in the list.

        :param int group: Group to replace
        """
        grp_to_change = self.next_group - 1
        if grp_to_change != group:
            for item in self.groups[grp_to_change]:
                self.item_to_group[item] = group
            self.groups[group] = self.groups[grp_to_change]
        del self.groups[grp_to_change]
        self.next_group -= 1

    def assign_remaining_items(self):
        """Assigns remaining items into groups."""
        grp_pointer = 0
        i = 0
        logger.info(f"Assigning rest of items: {self.items_assigned} of {self.total_items}")
        while i < self.total_items:
            if i not in self.item_to_group:
                if grp_pointer not in self.groups:
                    # This would happen if a group is still empty at this stage
                    assert grp_pointer < self.max_groups
                    new_group = self.create_or_get_group()
                    assert new_group == grp_pointer
                if len(self.groups[grp_pointer]) < self.mini_batch_size:
                    self.assign_group(i, grp_pointer)
                    i += 1
                else:
                    grp_pointer += 1
            else:
                i += 1

def get_dist_tuple_list(dist_matrix):
    batch_size = dist_matrix.size()[0]
    indices = torch.tril_indices(batch_size, batch_size, offset=-1)
    values = dist_matrix[indices[0], indices[1]].cpu()
    return torch.cat([indices, values.unsqueeze(0)], dim=0)


def get_scaled_distances(embeddings, labels, device, same_label_factor=1):
    """Returns the distance matrix between all embeddings, scaled to the 0-1 range where 0 is good and 1 is bad.
    This means that small distances between embeddings of the same class will be close to 0, while small distances
    between embeddings of different classes will be close to 1.

    :param torch.Tensor embeddings: Embeddings of batch items
    :param torch.Tensor labels: Labels associated with the embeddings
    :param torch.device device: Device to run on (cuda or cpu)
    :param int same_label_factor: Multiplies the weight of same-class distances, allowing to give more or less
        importance to these compared to distinct-class distances, defaults to 1 (which means equal weight)
    :return torch.Tensor: Scaled distance matrix
    """
    # Get pairwise distance matrix, shape: (batch_size, batch_size)
    distance_matrix = torch.cdist(embeddings, embeddings, p=2)
    labels = labels.to(device)
    labels_equal = (labels.unsqueeze(0) == labels.unsqueeze(1)).squeeze()
    labels_distinct = torch.logical_not(labels_equal)
    pos_dist = distance_matrix * labels_equal
    neg_dist = distance_matrix * labels_distinct

    # Use some scaling to bring both to a range of 0-1; closer to 1 is harder
    pos_max = pos_dist.max()
    neg_max = neg_dist.max()
    pos_dist = pos_dist / pos_max * same_label_factor
    neg_dist = 1 * labels_distinct - (neg_dist / neg_max)
    return pos_dist + neg_dist

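The scaling rule above can be checked by hand on a toy pair. A dependency-free sketch of the same rule (hypothetical distances and maxima, no torch): same-class pairs score `d / pos_max`, distinct-class pairs score `1 - d / neg_max`, so a score near 1 always means "hard".

```python
# Pure-Python sketch of the scaling rule used in get_scaled_distances:
# same-class pairs score d / pos_max (near 0 is easy, near 1 is hard),
# distinct-class pairs score 1 - d / neg_max (near 0 is easy, near 1 is hard).
def scaled_score(dist, same_label, pos_max, neg_max, same_label_factor=1):
    if same_label:
        return dist / pos_max * same_label_factor
    return 1 - dist / neg_max

# A close same-class pair is easy (score near 0) ...
easy_pos = scaled_score(0.1, True, pos_max=2.0, neg_max=4.0)
# ... while an equally close distinct-class pair is hard (score near 1).
hard_neg = scaled_score(0.1, False, pos_max=2.0, neg_max=4.0)
print(easy_pos, hard_neg)
```

This is why `sort_batches` below can sort a single combined list in descending order and treat the top of the list as the hardest pairs, regardless of whether a pair shares a class.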

def sort_batches(inputs, labels, masks, embeddings, device, mini_batch_size=32,
                 scheduler: Optional[BatchingScheduler] = None):
    start = datetime.now()

    same_label_factor = scheduler.get_scaling_same_label_factor() if scheduler else 1
    scaled_dist = get_scaled_distances(embeddings, labels, device, same_label_factor)
    # Get matrix of (row, column, dist) triples
    dist_list = get_dist_tuple_list(scaled_dist)

    dist_list = dist_list.cpu().detach().numpy()
    # Sort distances in descending order by the last row (the distance)
    sorted_dists = dist_list[:, dist_list[-1, :].argsort()[::-1]]

    # Loop through the list, assigning both items of each pair to the same group
    dist_threshold = scheduler.get_dist_threshold() if scheduler else 0.5
    grouper = BatchGrouper(sorted_dists, total_items=labels.size()[0], mini_batch_size=mini_batch_size,
                           dist_threshold=dist_threshold)
    indices = torch.tensor(grouper.cluster_items()).type(torch.IntTensor)
    final_inputs = torch.index_select(inputs, dim=0, index=indices)
    final_labels = torch.index_select(labels, dim=0, index=indices)
    final_masks = torch.index_select(masks, dim=0, index=indices)

    logger.info(f"Batch sorting took: {datetime.now() - start}")
    return final_inputs, final_labels, final_masks
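The pair-driven grouping that `BatchGrouper` performs can be illustrated with a stdlib-only sketch. This is a simplified stand-in, not the repository class: it visits pairs hardest-first and keeps both members in the same group while the group has capacity, then fills in untouched items, ignoring the group-count cap and group merging of the real implementation.

```python
# Simplified stand-in for the BatchGrouper idea (not the repository class):
# visit pairs hardest-first and keep both members of each pair in the same
# group while the group still has capacity.
def group_pairs(pairs, n_items, group_size):
    item_to_group, groups = {}, []
    for a, b, _dist in pairs:
        if a not in item_to_group and b not in item_to_group:
            groups.append([a, b])
            item_to_group[a] = item_to_group[b] = len(groups) - 1
        elif a in item_to_group and b not in item_to_group:
            g = item_to_group[a]
            if len(groups[g]) < group_size:
                groups[g].append(b)
                item_to_group[b] = g
        elif b in item_to_group and a not in item_to_group:
            g = item_to_group[b]
            if len(groups[g]) < group_size:
                groups[g].append(a)
                item_to_group[a] = g
    # Any untouched items fill up the remaining slots
    for i in range(n_items):
        if i not in item_to_group:
            groups.append([i])
            item_to_group[i] = len(groups) - 1
    return groups

# Pairs already sorted by descending "hardness"
pairs = [(0, 1, 0.9), (1, 2, 0.8), (3, 4, 0.7)]
print(group_pairs(pairs, n_items=6, group_size=3))
```

The effect is that each mini-batch contains the items most likely to produce hard triplets for the loss below.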
62
training/batching_scheduler.py
Normal file
@@ -0,0 +1,62 @@
from collections import deque

import numpy as np


class BatchingScheduler():
    """Acts as a scheduler for the batching algorithm."""

    def __init__(self, decay_factor=0.8, min_threshold=0.2, triplets_threshold=10, cooldown=10) -> None:
        # Internal state
        self._step_count = 0
        self._dist_threshold = 0.5
        self._last_used_triplets = deque([], 5)
        self._scaling_same_label_factor = 1
        self._last_update_step = -10

        # Parameters
        self.decay_factor = decay_factor
        self.min_threshold = min_threshold
        self.triplets_threshold = triplets_threshold
        self.cooldown = cooldown

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`."""
        return {key: value for key, value in self.__dict__.items()}

    def load_state_dict(self, state_dict):
        """Loads the scheduler's state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def step(self, used_triplets):
        self._step_count += 1
        self._last_used_triplets.append(used_triplets)
        if (np.mean(self._last_used_triplets) < self.triplets_threshold and
                self._last_update_step + self.cooldown <= self._step_count):
            if self._dist_threshold > self.min_threshold:
                print(f"Updating dist_threshold at {self._step_count} ({np.mean(self._last_used_triplets)})")
                self.update_dist_threshold()
            if self._scaling_same_label_factor > 0.6:
                print(f"Updating scale factor at {self._step_count} ({np.mean(self._last_used_triplets)})")
                self.update_scale_factor()
            self._last_update_step = self._step_count

    def update_scale_factor(self):
        self._scaling_same_label_factor = max(self._scaling_same_label_factor * 0.9, 0.6)
        print(f"Updating scaling factor to {self._scaling_same_label_factor}")

    def update_dist_threshold(self):
        self._dist_threshold = max(self.min_threshold, self._dist_threshold * self.decay_factor)
        print(f"Updated dist_threshold to {self._dist_threshold}")

    def get_dist_threshold(self) -> float:
        return self._dist_threshold

    def get_scaling_same_label_factor(self) -> float:
        return self._scaling_same_label_factor
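The `update_dist_threshold` rule above is just an exponential decay clamped at `min_threshold`; a stand-alone sketch of that rule, starting from the scheduler's default threshold of 0.5:

```python
# Stand-alone sketch of the clamped exponential decay applied to dist_threshold.
def decay(threshold, decay_factor=0.8, min_threshold=0.2):
    return max(min_threshold, threshold * decay_factor)

threshold = 0.5
history = []
for _ in range(6):
    threshold = decay(threshold)
    history.append(round(threshold, 4))
print(history)
```

After four updates the threshold bottoms out at `min_threshold` and further `step` calls with few used triplets can only lower the same-label scaling factor.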
18
training/gaussian_noise.py
Normal file
@@ -0,0 +1,18 @@
import torch


class GaussianNoise(object):
    def __init__(self, mean=0., std=1.):
        self.std = std
        self.mean = mean

    def __call__(self, tensor):
        return tensor + torch.randn(tensor.size()) * self.std + self.mean

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)


if __name__ == "__main__":
    pass
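A dependency-free illustration of the same transform, using `random.gauss` in place of `torch.randn` purely for demonstration (the landmark values below are made up):

```python
import random

# Dependency-free illustration of the GaussianNoise transform:
# add zero-centred Gaussian jitter to every coordinate.
def add_gaussian_noise(values, mean=0.0, std=0.001, rng=None):
    rng = rng or random.Random(0)  # seeded here only to make the demo repeatable
    return [v + rng.gauss(mean, std) for v in values]

landmarks = [0.25, 0.5, 0.75]
noisy = add_gaussian_noise(landmarks)
```

With the default `std=0.001` (matching the `--gaussian_std` training default), the jitter is tiny relative to normalized landmark coordinates, so the augmentation perturbs poses without changing the sign.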
105
training/online_batch_mining.py
Normal file
@@ -0,0 +1,105 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


eps = 1e-8  # an arbitrary small value used for numerical stability tricks

# Adapted from https://qdrant.tech/articles/triplet-loss/


class BatchAllTripletLoss(nn.Module):
    """Uses all valid triplets to compute the Triplet loss.

    Args:
        margin: Margin value in the Triplet Loss equation
    """

    def __init__(self, device, margin=1., filter_easy_triplets=True):
        super().__init__()
        self.margin = margin
        self.device = device
        self.filter_easy_triplets = filter_easy_triplets

    def get_triplet_mask(self, labels):
        """Computes a mask for valid triplets.

        Args:
            labels: Batch of integer labels. shape: (batch_size,)
        Returns:
            Mask tensor indicating which triplets are actually valid. Shape: (batch_size, batch_size, batch_size)
            A triplet is valid if:
            `labels[i] == labels[j] and labels[i] != labels[k]`
            and `i`, `j`, `k` are distinct.
        """
        # step 1 - get a mask for distinct indices

        # shape: (batch_size, batch_size)
        indices_equal = torch.eye(labels.size()[0], dtype=torch.bool, device=labels.device)
        indices_not_equal = torch.logical_not(indices_equal)
        # shape: (batch_size, batch_size, 1)
        i_not_equal_j = indices_not_equal.unsqueeze(2)
        # shape: (batch_size, 1, batch_size)
        i_not_equal_k = indices_not_equal.unsqueeze(1)
        # shape: (1, batch_size, batch_size)
        j_not_equal_k = indices_not_equal.unsqueeze(0)
        # shape: (batch_size, batch_size, batch_size)
        distinct_indices = torch.logical_and(torch.logical_and(i_not_equal_j, i_not_equal_k), j_not_equal_k)

        # step 2 - get a mask for valid anchor-positive-negative triplets

        # shape: (batch_size, batch_size)
        labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
        # shape: (batch_size, batch_size, 1)
        i_equal_j = labels_equal.unsqueeze(2)
        # shape: (batch_size, 1, batch_size)
        i_equal_k = labels_equal.unsqueeze(1)
        # shape: (batch_size, batch_size, batch_size)
        valid_indices = torch.logical_and(i_equal_j, torch.logical_not(i_equal_k))

        # step 3 - combine the two masks
        mask = torch.logical_and(distinct_indices, valid_indices)

        return mask

    def forward(self, embeddings, labels, filter_easy_triplets=True):
        """Computes the loss value.

        Args:
            embeddings: Batch of embeddings, e.g., output of the encoder. shape: (batch_size, embedding_dim)
            labels: Batch of integer labels associated with the embeddings. shape: (batch_size,)
        Returns:
            Scalar loss value.
        """
        # step 1 - get distance matrix
        # shape: (batch_size, batch_size)
        distance_matrix = torch.cdist(embeddings, embeddings, p=2)

        # step 2 - compute loss values for all triplets by applying broadcasting to the distance matrix

        # shape: (batch_size, batch_size, 1)
        anchor_positive_dists = distance_matrix.unsqueeze(2)
        # shape: (batch_size, 1, batch_size)
        anchor_negative_dists = distance_matrix.unsqueeze(1)
        # get loss values for all possible n^3 triplets
        # shape: (batch_size, batch_size, batch_size)
        triplet_loss = anchor_positive_dists - anchor_negative_dists + self.margin

        # step 3 - filter out invalid or easy triplets by setting their loss values to 0

        # shape: (batch_size, batch_size, batch_size)
        mask = self.get_triplet_mask(labels)
        valid_triplets = mask.sum()
        triplet_loss *= mask.to(self.device)
        # easy triplets have negative loss values
        triplet_loss = F.relu(triplet_loss)

        if self.filter_easy_triplets:
            # step 4 - compute scalar loss value by averaging positive losses
            num_positive_losses = (triplet_loss > eps).float().sum()
            # Factor in how many triplets were used compared to batch_size (used_triplets * 3 / batch_size).
            # The effect should be similar to LR decay, but penalizing batches with fewer hard triplets.
            percent_used_factor = min(1.0, num_positive_losses * 3 / labels.size()[0])

            triplet_loss = triplet_loss.sum() / (num_positive_losses + eps) * percent_used_factor
            return triplet_loss, valid_triplets, int(num_positive_losses)
        else:
            triplet_loss = triplet_loss.sum() / (valid_triplets + eps)
            return triplet_loss, valid_triplets, valid_triplets
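The validity rule that `get_triplet_mask` encodes with broadcasting is easy to verify by brute force; a stdlib-only sketch counting valid (anchor, positive, negative) triplets for a toy label vector:

```python
from itertools import product

# Brute-force count of valid (anchor, positive, negative) triplets:
# labels[i] == labels[j], labels[i] != labels[k], with i, j, k distinct.
def count_valid_triplets(labels):
    n = len(labels)
    return sum(
        1
        for i, j, k in product(range(n), repeat=3)
        if i != j and i != k and j != k
        and labels[i] == labels[j] and labels[i] != labels[k]
    )

print(count_valid_triplets([0, 0, 1]))
```

This count is exactly what `mask.sum()` returns as `valid_triplets` in `forward`, so the sketch doubles as a sanity check for the broadcast version.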
84
training/train_arguments.py
Normal file
@@ -0,0 +1,84 @@
import argparse


def get_default_args():
    parser = argparse.ArgumentParser(add_help=False)

    parser.add_argument("--experiment_name", type=str, default="lsa_64_spoter",
                        help="Name of the experiment after which the logs and plots will be named")
    parser.add_argument("--num_classes", type=int, default=100, help="Number of classes to be recognized by the model")
    parser.add_argument("--hidden_dim", type=int, default=108,
                        help="Hidden dimension of the underlying Transformer model")
    parser.add_argument("--seed", type=int, default=379,
                        help="Seed with which to initialize all the random components of the training")

    # Embeddings
    parser.add_argument("--classification_model", action='store_true', default=False,
                        help="Select SPOTER model to train, pass only for the original classification model")
    parser.add_argument("--vector_length", type=int, default=32,
                        help="Number of features used in the embedding vector")
    parser.add_argument("--epoch_iters", type=int, default=-1,
                        help="Iterations per epoch while training embeddings. Will loop through the dataset once if -1")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size during training and validation")
    parser.add_argument("--hard_triplet_mining", type=str, default=None,
                        help="Strategy to select hard triplets, options: [None, in_batch]")
    parser.add_argument("--triplet_loss_margin", type=float, default=1,
                        help="Margin used in the triplet loss (see documentation)")
    parser.add_argument("--normalize_embeddings", action='store_true', default=False,
                        help="Normalize model output to keep the vector length at one")
    parser.add_argument("--filter_easy_triplets", action='store_true', default=False,
                        help="Filter easy triplets in online in-batch triplet mining")

    # Data
    parser.add_argument("--dataset_name", type=str, default="", help="Dataset name")
    parser.add_argument("--dataset_project", type=str, default="Sign Language Recognition", help="Dataset project name")
    parser.add_argument("--training_set_path", type=str, default="",
                        help="Path to the training dataset CSV file (relative to the dataset root)")
    parser.add_argument("--validation_set_path", type=str, default="", help="Path to the validation dataset CSV file")
    parser.add_argument("--dataset_loader", type=str, default="local",
                        help="Dataset loader to use, options: [clearml, local]")

    # Training hyperparameters
    parser.add_argument("--epochs", type=int, default=1300, help="Number of epochs to train the model for")
    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate for the model training")
    parser.add_argument("--dropout", type=float, default=0.1,
                        help="Dropout used in the transformer layer")
    parser.add_argument("--augmentations_prob", type=float, default=0.5, help="How often to use data augmentation")

    # Checkpointing
    parser.add_argument("--save_checkpoints_every", type=int, default=-1,
                        help="Determines every how many epochs the weight checkpoints are saved. If -1, only the best \
                             model after the final epoch")

    # Optimizer
    parser.add_argument("--optimizer", type=str, default="SGD",
                        help="Optimizer used during training, options: [SGD, ADAM]")

    # Tracker
    parser.add_argument("--tracker", type=str, default="none",
                        help="Experiment tracker to use, options: [clearml, none]")

    # Scheduler
    parser.add_argument("--scheduler_factor", type=float, default=0,
                        help="Factor for the ReduceLROnPlateau scheduler")
    parser.add_argument("--scheduler_patience", type=int, default=10,
                        help="Patience for the ReduceLROnPlateau scheduler")
    parser.add_argument("--scheduler_warmup", type=int, default=400,
                        help="Warmup epochs before the scheduler starts")

    # Gaussian noise normalization
    parser.add_argument("--gaussian_mean", type=float, default=0, help="Mean parameter for the Gaussian noise layer")
    parser.add_argument("--gaussian_std", type=float, default=0.001,
                        help="Standard deviation parameter for the Gaussian noise layer")

    # Batch Sorting
    parser.add_argument("--start_mining_hard", type=int, default=None, help="On which epoch to start hard mining")
    parser.add_argument("--hard_mining_pre_batch_multipler", type=int, default=16,
                        help="How many batches should be computed at once")
    parser.add_argument("--hard_mining_pre_batch_mining_count", type=int, default=5,
                        help="How many times to loop through a list of computed batches")
    parser.add_argument("--hard_mining_scheduler_triplets_threshold", type=float, default=0,
                        help="Enables the batch grouping scheduler if > 0. Defines the threshold for when to decay the \
                             distance threshold of the batch sorter")

    return parser
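A minimal sketch of how a default parser like this is typically consumed. The stand-in parser below mirrors only two of the flags defined above, and `parse_args` is given an explicit argv purely for illustration:

```python
import argparse

# Minimal stand-in for get_default_args(): a parser with two of the flags
# defined above, fed an explicit argv for illustration.
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--normalize_embeddings", action="store_true", default=False)

args = parser.parse_args(["--batch_size", "64", "--normalize_embeddings"])
print(args.batch_size, args.normalize_embeddings)
```

Because the real parser is built with `add_help=False`, a training script can take it as a parent parser and layer its own arguments on top without a `-h` conflict.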
71
training/train_utils.py
Normal file
@@ -0,0 +1,71 @@
import os
import random

import numpy as np
import pandas as pd
import plotly.express as px
import torch

from models import embeddings_scatter_plot, embeddings_scatter_plot_splits


def train_setup(seed, experiment_name):
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    g = torch.Generator()
    g.manual_seed(seed)
    return g


def create_embedding_scatter_plots(tracker, model, train_loader, val_loader, device, id_to_label, epoch, model_name):
    tsne_results, labels = embeddings_scatter_plot(model, train_loader, device, id_to_label, perplexity=40, n_iter=1000)

    df = pd.DataFrame({'x': tsne_results[:, 0],
                       'y': tsne_results[:, 1],
                       'label': labels})
    fig = px.scatter(df, y="y", x="x", color="label")

    tracker.log_chart(
        title="Training Scatter Plot with Best Model: " + model_name,
        series="Scatter Plot",
        iteration=epoch,
        figure=fig,
    )

    tsne_results, labels = embeddings_scatter_plot(model, val_loader, device, id_to_label, perplexity=40, n_iter=1000)

    df = pd.DataFrame({'x': tsne_results[:, 0],
                       'y': tsne_results[:, 1],
                       'label': labels})
    fig = px.scatter(df, y="y", x="x", color="label")

    tracker.log_chart(
        title="Validation Scatter Plot with Best Model: " + model_name,
        series="Scatter Plot",
        iteration=epoch,
        figure=fig,
    )

    dataloaders = {'train': train_loader,
                   'val': val_loader}
    splits = list(dataloaders.keys())
    tsne_results_splits, labels_splits = embeddings_scatter_plot_splits(model, dataloaders,
                                                                        device, id_to_label, perplexity=40, n_iter=1000)
    tsne_results = np.vstack([tsne_results_splits[split] for split in splits])
    labels = np.concatenate([labels_splits[split] for split in splits])
    split = np.concatenate([[split] * len(labels_splits[split]) for split in splits])
    df = pd.DataFrame({'x': tsne_results[:, 0],
                       'y': tsne_results[:, 1],
                       'label': labels,
                       'split': split})
    fig = px.scatter(df, y="y", x="x", color="label", symbol='split')
    tracker.log_chart(
        title="Scatter Plot of train and val with Best Model: " + model_name,
        series="Scatter Plot",
        iteration=epoch,
        figure=fig,
    )
40
utils.py
Normal file
@@ -0,0 +1,40 @@
import logging
import os


class CustomFormatter(logging.Formatter):

    grey = "\x1b[38;20m"
    yellow = "\x1b[33;20m"
    red = "\x1b[31;20m"
    bold_red = "\x1b[31;1m"
    reset = "\x1b[0m"
    custom_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s (%(filename)s:%(lineno)d)"

    FORMATS = {
        logging.DEBUG: grey + custom_format + reset,
        logging.INFO: grey + custom_format + reset,
        logging.WARNING: yellow + custom_format + reset,
        logging.ERROR: red + custom_format + reset,
        logging.CRITICAL: bold_red + custom_format + reset
    }

    def format(self, record):
        log_fmt = self.FORMATS.get(record.levelno)
        formatter = logging.Formatter(log_fmt)
        return formatter.format(record)


def get_logger(name):
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    # Console handler with coloured output
    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
    ch.setFormatter(CustomFormatter())
    # File handler, named after the current experiment
    file_handler = logging.FileHandler(os.getenv('EXPERIMENT_NAME', 'run') + ".log")
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(CustomFormatter())
    logger.addHandler(ch)
    logger.addHandler(file_handler)
    return logger
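The per-level colouring boils down to picking a format string by `record.levelno`; a stdlib-only sketch of that pattern, with made-up format strings and the ANSI colours omitted:

```python
import logging

# Simplified version of the CustomFormatter pattern: choose a format
# string per log level (colours omitted, formats are illustrative).
class LevelFormatter(logging.Formatter):
    FORMATS = {
        logging.INFO: "%(levelname)s | %(message)s",
        logging.ERROR: "!! %(levelname)s | %(message)s",
    }

    def format(self, record):
        return logging.Formatter(self.FORMATS[record.levelno]).format(record)

fmt = LevelFormatter()
record = logging.LogRecord("demo", logging.ERROR, __file__, 1, "boom", None, None)
print(fmt.format(record))
```

Building a fresh `logging.Formatter` per call keeps the example simple; a cached formatter per level would avoid the repeated construction.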
8
web/README.md
Normal file
@@ -0,0 +1,8 @@
# SPOTER Web

To test the Spoter model in the browser, follow these steps:
* Convert your latest PyTorch model to ONNX by running `python convert.py`. This is best done inside the Docker container. You will need to install additional dependencies for the conversion (see the commented lines in requirements.txt).
* The ONNX file should be generated in the `web` folder; otherwise copy it there.
* Run `npx light-server -s . -p 8080` in the `web` folder. (`npx` comes with `npm`.)

Enjoy!
61
web/index.html
Normal file
@@ -0,0 +1,61 @@
<!DOCTYPE html>
<html>
<head>
    <title>ONNX Runtime JavaScript examples: Quick Start - Web (using script tag)</title>
</head>
<body>
    <button id="start-test">Start Test</button>
    <p id="output"></p>
    <!-- import ONNX Runtime Web from CDN -->
    <script src="https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/ort.min.js"></script>
    <script>
        async function setupButtons() {
            let test_button = document.querySelector("#start-test");

            test_button.addEventListener('click', async function() {
                main();
            });
        }
        // use an async context to call onnxruntime functions.
        async function main() {
            try {
                // create a new session and load the SPOTER model.
                const session = await ort.InferenceSession.create('./spoter.onnx');

                // Number of frames
                const N = 100;

                // prepare inputs: a tensor needs its corresponding TypedArray as data.
                // The model expects N frames of 54 (x, y) landmarks, i.e. 108 floats per frame.
                const dataA = new Float32Array(108 * N);
                dataA.fill(0.4);
                const tensorA = new ort.Tensor('float32', dataA, [1, N, 54, 2]);
                console.log(tensorA);

                // prepare feeds: use model input names as keys.
                const feeds = { input: tensorA };

                // feed inputs and run
                const startTime = new Date();
                const results = await session.run(feeds);
                // read from results
                const dataC = results.output.data;
                const endTime = new Date();
                let output = document.querySelector("#output");

                const timeDiff = endTime - startTime; // in ms
                output.innerText = `Data of result tensor 'output':\n ${dataC}` + "\nInference took " + timeDiff + " ms";

            } catch (e) {
                let output = document.querySelector("#output");
                output.innerText = `failed to inference ONNX model: ${e}.`;
            }
        }
        setupButtons();
    </script>
</body>
</html>