diff --git a/.flake8 b/.flake8
new file mode 100644
index 0000000..c32d95e
--- /dev/null
+++ b/.flake8
@@ -0,0 +1,6 @@
+[flake8]
+max-line-length = 130
+per-file-ignores =
+    __init__.py: F401
+exclude =
+    .git,__pycache__,
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..93e7add
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,13 @@
+FROM pytorch/pytorch
+
+WORKDIR /app
+COPY ./requirements.txt /app/
+
+RUN pip install -r requirements.txt
+RUN apt-get -y update
+RUN apt-get -y install git
+RUN apt-get install ffmpeg libsm6 libxext6 -y
+
+COPY . /app/
+RUN git config --global --add safe.directory /app
+CMD ./train.sh
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..f3e99e6
--- /dev/null
+++ b/README.md
@@ -0,0 +1,137 @@
+
+
+
+# SPOTER Embeddings
+
+This repository contains code for the SPOTER embedding model.
+
+The model is heavily based on [Spoter], which was presented in
+[Sign Pose-Based Transformer for Word-Level Sign Language Recognition](https://openaccess.thecvf.com/content/WACV2022W/HADCV/html/Bohacek_Sign_Pose-Based_Transformer_for_Word-Level_Sign_Language_Recognition_WACVW_2022_paper.html), with one of the main modifications being
+that this is an embedding model instead of a classification model.
+This allows for several zero-shot tasks on unseen Sign Language datasets from around the world.
+
+
+## Modifications to [SPOTER](https://github.com/matyasbohacek/spoter)
+Here is a list of the main modifications made to the Spoter code and model architecture:
+
+* The output layer is a linear layer, but it is trained using triplet loss instead of CrossEntropyLoss. The output of the model
+is therefore an embedding vector that can be used for several downstream tasks.
+* We started with the keypoints dataset published by Spoter but later created new datasets using BlazePose from MediaPipe (as is done in [Spoter 2](https://arxiv.org/abs/2210.00893)). This improves results considerably.
+* We select batches so that they contain several hard triplets and then compute the loss on all hard triplets found in each batch.
+* Some code refactoring to accommodate the new classes we implemented.
+* A minor code fix to avoid exceptions when using the rotate augmentation.
+
+
+
+
+## Results
+
+![Scatter plot of dataset embeddings](/assets/scatter_plot.png)
+
+We used the silhouette score to measure how well the clusters are defined during the training step.
+The silhouette score is high (close to 1) when the clusters of different classes are well separated from each other, and low (close to -1) in the opposite case.
+Our best model reached 0.7 on the train set and 0.1 on validation.
+
+### Classification accuracy
+While the model was not trained with classification specifically in mind, it can still be used for that purpose.
+Here we report top-1 and top-5 accuracy, calculated by taking the 1 (or 5) nearest vectors of different classes to the target vector.
+
+To estimate the accuracy for LSA, we take a “train” set as given and then classify the holdout set based on the closest vectors from the “train” set.
+This is done using the model trained on the WLASL100 dataset only, to show that our model has zero-shot capabilities.
+
+![Accuracy table](/assets/accuracy.png)
+
+
+
+
+## Get Started
+
+The recommended way of running code from this repo is by using **Docker**.
+
+Clone this repository and run:
+```
+docker build -t spoter_embeddings .
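+# Note: --gpus=all requires the NVIDIA Container Toolkit on the host; drop that flag to run on CPU only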
+docker run --rm -it --entrypoint=bash --gpus=all -v $PWD:/app spoter_embeddings
+```
+
+> Running without specifying the `entrypoint` will train the model with the hyperparameters specified in `train.sh`.
+
+If you prefer running in a **virtual environment** instead, then first install the dependencies:
+
+```shell
+pip install -r requirements.txt
+```
+
+> We tested this using Python 3.7.13. Other versions may work.
+
+To train the model, run `train.sh` in Docker or in your virtual environment.
+
+The hyperparameters with their descriptions can be found in the [train.py](link...) file.
+
+
+## Data
+
+Same as with SPOTER, this model works on top of sequences of signers' skeletal data extracted from videos.
+This means that the input data has a much lower dimension compared to using videos directly, and therefore the model is
+quicker and lighter, while you can still choose any SOTA body pose model to preprocess the videos.
+This makes our model lightweight and able to run in real time (for example, it takes around 40 ms to process a 4-second
+25 FPS video inside a web browser using onnxruntime).
+
+![Sign Language Dataset Overview](http://spoter.signlanguagerecognition.com/img/datasets_overview.gif)
+
+For ready-to-use datasets, refer to the [Spoter] repository.
+
+For best results, we recommend building your own dataset by downloading a sign language video dataset such as [WLASL] and then using the `extract_mediapipe_landmarks.py` and `create_wlasl_landmarks_dataset.py` scripts to create a body keypoints dataset that can be used to train the Spoter embeddings model.
+
+You can run these scripts as follows:
+```bash
+# This will extract landmarks from the downloaded videos
+python3 preprocessing.py extract -videos --output-landmarks
+
+# This will create a dataset (csv file) with the first 100 classes, splitting 20% of it to the test set, and 80% for train
+python3 preprocessing.py create -videos -lmks --dataset-folder= --create-new-split -ts=0.2
+```
+
+
+## Example notebooks
+There are two Jupyter notebooks included in the `notebooks` folder:
+* `embeddings_evaluation.ipynb`: shows how to evaluate a model.
+* `visualize_embeddings.ipynb`: model embeddings visualization, optionally with the embedded input video.
+
+
+## Tracking experiments with ClearML
+The code supports tracking experiments, datasets, and models on a ClearML server.
+If you want to do this, make sure to pass the following arguments to train.py:
+
+```
+ --dataset_loader=clearml
+ --tracker=clearml
+```
+
+Also make sure to correctly configure your `clearml.conf` file.
+If using Docker, you can map it into the container by adding these volumes when running `docker run`:
+
+```
+-v $HOME/clearml.conf:/root/clearml.conf -v $HOME/.clearml:/root/.clearml
+```
+
+## Model conversion
+
+Follow these steps to convert your model to ONNX, TensorFlow or TFLite:
+* Install the additional dependencies listed in `conversion_requirements.txt`. This is best done inside the Docker container.
+* Run `python convert.py -c `. Add `-tf` if you want to export TensorFlow and TFLite models too.
+* The output models should be generated in a folder named `converted_models`.
+
+> You can test your model's performance in a web browser. Check out the README in the [web](/web/) folder.
+
+
+## License
+
+The **code** is published under the [Apache License 2.0](./LICENSE), which allows for both academic and commercial use as long as
+the relevant license and copyright notice are included, our work is cited, and all changes are stated.
+ +The license for the [WLASL](https://arxiv.org/pdf/1910.11006.pdf) and [LSA64](https://core.ac.uk/download/pdf/76495887.pdf) datasets used for experiments is, however, the [Attribution-NonCommercial 4.0 International (CC BY-NC 4.0)](https://creativecommons.org/licenses/by-nc/4.0/) license which allows only for non-commercial usage. + + +[Spoter]: (https://github.com/matyasbohacek/spoter) +[WLASL]: (https://dxli94.github.io/WLASL/) \ No newline at end of file diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/assets/accuracy.png b/assets/accuracy.png new file mode 100644 index 0000000..18095a9 Binary files /dev/null and b/assets/accuracy.png differ diff --git a/assets/banner.png b/assets/banner.png new file mode 100644 index 0000000..641f630 Binary files /dev/null and b/assets/banner.png differ diff --git a/assets/scatter_plot.png b/assets/scatter_plot.png new file mode 100644 index 0000000..9358386 Binary files /dev/null and b/assets/scatter_plot.png differ diff --git a/augmentations/__init__.py b/augmentations/__init__.py new file mode 100644 index 0000000..4130dfc --- /dev/null +++ b/augmentations/__init__.py @@ -0,0 +1 @@ +from .augment import augment_arm_joint_rotate, augment_rotate, augment_shear diff --git a/augmentations/augment.py b/augmentations/augment.py new file mode 100644 index 0000000..98e6a0c --- /dev/null +++ b/augmentations/augment.py @@ -0,0 +1,228 @@ + +import math +import logging +import cv2 +import random + +import numpy as np + +from normalization.body_normalization import BODY_IDENTIFIERS +from normalization.hand_normalization import HAND_IDENTIFIERS + + +HAND_IDENTIFIERS = [id + "_0" for id in HAND_IDENTIFIERS] + [id + "_1" for id in HAND_IDENTIFIERS] +ARM_IDENTIFIERS_ORDER = ["neck", "$side$Shoulder", "$side$Elbow", "$side$Wrist"] + + +def __random_pass(prob): + return random.random() < prob + + +def __numpy_to_dictionary(data_array: np.ndarray) -> dict: + """ + Supplementary method converting a NumPy array of body landmark data into dictionaries. The array data must match the + order of the BODY_IDENTIFIERS list. + """ + + output = {} + + for landmark_index, identifier in enumerate(BODY_IDENTIFIERS): + output[identifier] = data_array[:, landmark_index].tolist() + + return output + + +def __dictionary_to_numpy(landmarks_dict: dict) -> np.ndarray: + """ + Supplementary method converting dictionaries of body landmark data into respective NumPy arrays. The resulting array + will match the order of the BODY_IDENTIFIERS list. + """ + + output = np.empty(shape=(len(landmarks_dict["leftEar"]), len(BODY_IDENTIFIERS), 2)) + + for landmark_index, identifier in enumerate(BODY_IDENTIFIERS): + output[:, landmark_index, 0] = np.array(landmarks_dict[identifier])[:, 0] + output[:, landmark_index, 1] = np.array(landmarks_dict[identifier])[:, 1] + + return output + + +def __rotate(origin: tuple, point: tuple, angle: float): + """ + Rotates a point counterclockwise by a given angle around a given origin. 
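+    For example, rotating the point (1.0, 0.5) by math.pi / 2 radians around the origin (0.5, 0.5) yields (0.5, 1.0).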
+ + :param origin: Landmark in the (X, Y) format of the origin from which to count angle of rotation + :param point: Landmark in the (X, Y) format to be rotated + :param angle: Angle under which the point shall be rotated + :return: New landmarks (coordinates) + """ + + ox, oy = origin + px, py = point + + qx = ox + math.cos(angle) * (px - ox) - math.sin(angle) * (py - oy) + qy = oy + math.sin(angle) * (px - ox) + math.cos(angle) * (py - oy) + + return qx, qy + + +def __preprocess_row_sign(sign: dict) -> (dict, dict): + """ + Supplementary method splitting the single-dictionary skeletal data into two dictionaries of body and hand landmarks + respectively. + """ + + sign_eval = sign + + if "nose_X" in sign_eval: + body_landmarks = {identifier: [(x, y) for x, y in zip(sign_eval[identifier + "_X"], sign_eval[identifier + "_Y"])] + for identifier in BODY_IDENTIFIERS} + hand_landmarks = {identifier: [(x, y) for x, y in zip(sign_eval[identifier + "_X"], sign_eval[identifier + "_Y"])] + for identifier in HAND_IDENTIFIERS} + + else: + body_landmarks = {identifier: sign_eval[identifier] for identifier in BODY_IDENTIFIERS} + hand_landmarks = {identifier: sign_eval[identifier] for identifier in HAND_IDENTIFIERS} + + return body_landmarks, hand_landmarks + + +def __wrap_sign_into_row(body_identifiers: dict, hand_identifiers: dict) -> dict: + """ + Supplementary method for merging body and hand data into a single dictionary. + """ + + return {**body_identifiers, **hand_identifiers} + + +def augment_rotate(sign: dict, angle_range: tuple) -> dict: + """ + AUGMENTATION TECHNIQUE. All the joint coordinates in each frame are rotated by a random angle up to 13 degrees with + the center of rotation lying in the center of the frame, which is equal to [0.5; 0.5]. + + :param sign: Dictionary with sequential skeletal data of the signing person + :param angle_range: Tuple containing the angle range (minimal and maximal angle in degrees) to randomly choose the + angle by which the landmarks will be rotated from + + :return: Dictionary with augmented (by rotation) sequential skeletal data of the signing person + """ + + body_landmarks, hand_landmarks = __preprocess_row_sign(sign) + angle = math.radians(random.uniform(*angle_range)) + + body_landmarks = {key: [__rotate((0.5, 0.5), frame, angle) for frame in value] for key, value in + body_landmarks.items()} + hand_landmarks = {key: [__rotate((0.5, 0.5), frame, angle) for frame in value] for key, value in + hand_landmarks.items()} + + return __wrap_sign_into_row(body_landmarks, hand_landmarks) + + +def augment_shear(sign: dict, type: str, squeeze_ratio: tuple) -> dict: + """ + AUGMENTATION TECHNIQUE. + + - Squeeze. All the frames are squeezed from both horizontal sides. Two different random proportions up to 15% of + the original frame's width for both left and right side are cut. + + - Perspective transformation. The joint coordinates are projected onto a new plane with a spatially defined + center of projection, which simulates recording the sign video with a slight tilt. Each time, the right or left + side, as well as the proportion by which both the width and height will be reduced, are chosen randomly. This + proportion is selected from a uniform distribution on the [0; 1) interval. Subsequently, the new plane is + delineated by reducing the width at the desired side and the respective vertical edge (height) at both of its + adjacent corners. 
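+    For example, augment_shear(sign, "squeeze", (0, 0.15)) cuts a random proportion of up to 15% of the frame width from each side; these are the values used by random_augmentation in datasets/datasets_utils.py.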
+ + :param sign: Dictionary with sequential skeletal data of the signing person + :param type: Type of shear augmentation to perform (either 'squeeze' or 'perspective') + :param squeeze_ratio: Tuple containing the relative range from what the proportion of the original width will be + randomly chosen. These proportions will either be cut from both sides or used to construct the + new projection + + :return: Dictionary with augmented (by squeezing or perspective transformation) sequential skeletal data of the + signing person + """ + + body_landmarks, hand_landmarks = __preprocess_row_sign(sign) + + if type == "squeeze": + move_left = random.uniform(*squeeze_ratio) + move_right = random.uniform(*squeeze_ratio) + + src = np.array(((0, 1), (1, 1), (0, 0), (1, 0)), dtype=np.float32) + dest = np.array(((0 + move_left, 1), (1 - move_right, 1), (0 + move_left, 0), (1 - move_right, 0)), + dtype=np.float32) + mtx = cv2.getPerspectiveTransform(src, dest) + + elif type == "perspective": + + move_ratio = random.uniform(*squeeze_ratio) + src = np.array(((0, 1), (1, 1), (0, 0), (1, 0)), dtype=np.float32) + + if __random_pass(0.5): + dest = np.array(((0 + move_ratio, 1 - move_ratio), (1, 1), (0 + move_ratio, 0 + move_ratio), (1, 0)), + dtype=np.float32) + else: + dest = np.array(((0, 1), (1 - move_ratio, 1 - move_ratio), (0, 0), (1 - move_ratio, 0 + move_ratio)), + dtype=np.float32) + + mtx = cv2.getPerspectiveTransform(src, dest) + + else: + + logging.error("Unsupported shear type provided.") + return {} + + landmarks_array = __dictionary_to_numpy(body_landmarks) + augmented_landmarks = cv2.perspectiveTransform(np.array(landmarks_array, dtype=np.float32), mtx) + + augmented_zero_landmark = cv2.perspectiveTransform(np.array([[[0, 0]]], dtype=np.float32), mtx)[0][0] + augmented_landmarks = np.stack([np.where(sub == augmented_zero_landmark, [0, 0], sub) for sub in augmented_landmarks]) + + body_landmarks = __numpy_to_dictionary(augmented_landmarks) + + return __wrap_sign_into_row(body_landmarks, hand_landmarks) + + +def augment_arm_joint_rotate(sign: dict, probability: float, angle_range: tuple) -> dict: + """ + AUGMENTATION TECHNIQUE. The joint coordinates of both arms are passed successively, and the impending landmark is + slightly rotated with respect to the current one. The chance of each joint to be rotated is 3:10 and the angle of + alternation is a uniform random angle up to +-4 degrees. This simulates slight, negligible variances in each + execution of a sign, which do not change its semantic meaning. 
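+    For reference, random_augmentation in datasets/datasets_utils.py calls this function with probability=0.3 and angle_range=(-4, 4), matching the 3:10 chance and the up to +-4 degree range described above.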
+ + :param sign: Dictionary with sequential skeletal data of the signing person + :param probability: Probability of each joint to be rotated (float from the range [0, 1]) + :param angle_range: Tuple containing the angle range (minimal and maximal angle in degrees) to randomly choose the + angle by which the landmarks will be rotated from + + :return: Dictionary with augmented (by arm joint rotation) sequential skeletal data of the signing person + """ + + body_landmarks, hand_landmarks = __preprocess_row_sign(sign) + + # Iterate over both directions (both hands) + for side in ["left", "right"]: + # Iterate gradually over the landmarks on arm + for landmark_index, landmark_origin in enumerate(ARM_IDENTIFIERS_ORDER): + landmark_origin = landmark_origin.replace("$side$", side) + + # End the process on the current hand if the landmark is not present + if landmark_origin not in body_landmarks: + break + + # Perform rotation by provided probability + if __random_pass(probability): + angle = math.radians(random.uniform(*angle_range)) + + for to_be_rotated in ARM_IDENTIFIERS_ORDER[landmark_index + 1:]: + to_be_rotated = to_be_rotated.replace("$side$", side) + + # Skip if the landmark is not present + if to_be_rotated not in body_landmarks: + continue + + body_landmarks[to_be_rotated] = [__rotate(body_landmarks[landmark_origin][frame_index], frame, + angle) + for frame_index, frame in enumerate(body_landmarks[to_be_rotated])] + + return __wrap_sign_into_row(body_landmarks, hand_landmarks) diff --git a/conversion_requirements.txt b/conversion_requirements.txt new file mode 100644 index 0000000..4d37abe --- /dev/null +++ b/conversion_requirements.txt @@ -0,0 +1,21 @@ +bokeh==2.4.3 +boto3>=1.9 +clearml==1.6.4 +ipywidgets==8.0.4 +matplotlib==3.5.3 +mediapipe==0.8.11 +notebook==6.5.2 +opencv-python==4.6.0.66 +pandas==1.1.5 +pandas==1.1.5 +plotly==5.11.0 +scikit-learn==1.0.2 +torchvision==0.13.0 +tqdm==4.54.1 +# ------ +requests==2.28.1 +onnx==1.12.0 +onnx-tf==1.10.0 +onnxruntime==1.12.1 +tensorflow +tensorflow-probability diff --git a/convert.py b/convert.py new file mode 100644 index 0000000..c5d3932 --- /dev/null +++ b/convert.py @@ -0,0 +1,123 @@ +import os +import argparse +import json + +import numpy as np +import torch +import onnx +import onnxruntime + +try: + import tensorflow as tf +except ImportError: + print("Warning: Tensorflow not installed. This is required when exporting to tflite") + + +def to_numpy(tensor): + return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy() + + +def print_final_message(model_path): + success_msg = f"\033[92mModel converted at {model_path} \033[0m" + try: + import requests + joke = json.loads(requests.request("GET", "https://api.chucknorris.io/jokes/random?category=dev").text)["value"] + print(f"{success_msg}\n\nNow go read a Chuck Norris joke:\n\033[1m{joke}\033[0m") + except ImportError: + print(success_msg) + + +def convert_tf_saved_model(onnx_model, output_folder): + from onnx_tf.backend import prepare + tf_rep = prepare(onnx_model) # prepare tf representation + tf_rep.export_graph(output_folder) # export the model + + +def convert_tf_to_lite(model_dir, output_path): + # Convert the model + converter = tf.lite.TFLiteConverter.from_saved_model(model_dir) # path to the SavedModel directory + # This is needed for TF Select ops: Cast, RealDiv + converter.target_spec.supported_ops = [ + tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops. + tf.lite.OpsSet.SELECT_TF_OPS # enable TensorFlow ops. 
+ ] + + tflite_model = converter.convert() + + # Save the model. + with open(output_path, 'wb') as f: + f.write(tflite_model) + + +def validate_tflite_output(model_path, input_data, output_array): + interpreter = tf.lite.Interpreter(model_path=model_path) + + output = interpreter.get_output_details()[0] # Model has single output. + input = interpreter.get_input_details()[0] # Model has single input. + interpreter.resize_tensor_input(input['index'], input_data.shape) + interpreter.allocate_tensors() + + input_data = tf.convert_to_tensor(input_data, np.float32) + interpreter.set_tensor(input['index'], input_data) + interpreter.invoke() + + out = interpreter.get_tensor(output['index']) + np.testing.assert_allclose(out, output_array, rtol=1e-03, atol=1e-05) + + +def convert(checkpoint_path, export_tensorflow): + output_folder = "converted_models" + model = torch.load(checkpoint_path, map_location='cpu') + model.eval() + + # Input to the model + x = torch.randn(1, 10, 54, 2, requires_grad=True) + numpy_x = to_numpy(x) + + torch_out = model(x) + numpy_out = to_numpy(torch_out) + + model_path = f"{output_folder}/spoter.onnx" + if not os.path.exists(output_folder): + os.makedirs(output_folder) + + # Export the model + torch.onnx.export(model, # model being run + x, # model input (or a tuple for multiple inputs) + model_path, # where to save the model (can be a file or file-like object) + export_params=True, # store the trained parameter weights inside the model file + opset_version=11, # the ONNX version to export the model to + do_constant_folding=True, # whether to execute constant folding for optimization + input_names=['input'], # the model's input names + output_names=['output'], + dynamic_axes={'input': [1]}) # the model's output names + + # Validate conversion + onnx_model = onnx.load(model_path) + onnx.checker.check_model(onnx_model) + + ort_session = onnxruntime.InferenceSession(model_path) + + # compute ONNX Runtime output prediction + ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)} + ort_outs = ort_session.run(None, ort_inputs) + + # compare ONNX Runtime and PyTorch results + np.testing.assert_allclose(numpy_out, ort_outs[0], rtol=1e-03, atol=1e-05) + + if export_tensorflow: + saved_model_dir = f"{output_folder}/tf_saved" + tflite_model_path = f"{output_folder}/spoter.tflite" + convert_tf_saved_model(onnx_model, saved_model_dir) + convert_tf_to_lite(saved_model_dir, tflite_model_path) + validate_tflite_output(tflite_model_path, numpy_x, numpy_out) + + print_final_message(output_folder) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('-c', '--checkpoint_path', help='Checkpoint Path') + parser.add_argument('-tf', '--export_tensorflow', help='Export Tensorflow apart from ONNX', action='store_true') + args = parser.parse_args() + convert(args.checkpoint_path, args.export_tensorflow) diff --git a/datasets/__init__.py b/datasets/__init__.py new file mode 100644 index 0000000..e92e504 --- /dev/null +++ b/datasets/__init__.py @@ -0,0 +1,3 @@ +from .czech_slr_dataset import CzechSLRDataset +from .embedding_dataset import SLREmbeddingDataset +from .datasets_utils import collate_fn_triplet_padd, collate_fn_padd diff --git a/datasets/clearml_dataset_loader.py b/datasets/clearml_dataset_loader.py new file mode 100644 index 0000000..bf41e0b --- /dev/null +++ b/datasets/clearml_dataset_loader.py @@ -0,0 +1,8 @@ +from clearml import Dataset +from .dataset_loader import DatasetLoader + + +class ClearMLDatasetLoader(DatasetLoader): + + def 
get_dataset_folder(self, dataset_project, dataset_name): + return Dataset.get(dataset_project=dataset_project, dataset_name=dataset_name).get_local_copy() diff --git a/datasets/czech_slr_dataset.py b/datasets/czech_slr_dataset.py new file mode 100644 index 0000000..39bd97b --- /dev/null +++ b/datasets/czech_slr_dataset.py @@ -0,0 +1,72 @@ +import torch +import numpy as np +import torch.utils.data as torch_data + +from datasets.datasets_utils import load_dataset, tensor_to_dictionary, dictionary_to_tensor, \ + random_augmentation +from normalization.body_normalization import normalize_single_dict as normalize_single_body_dict +from normalization.hand_normalization import normalize_single_dict as normalize_single_hand_dict + + +class CzechSLRDataset(torch_data.Dataset): + """Advanced object representation of the HPOES dataset for loading hand joints landmarks utilizing the Torch's + built-in Dataset properties""" + + data: [np.ndarray] + labels: [np.ndarray] + + def __init__(self, dataset_filename: str, num_labels=5, transform=None, augmentations=False, + augmentations_prob=0.5, normalize=True): + """ + Initiates the HPOESDataset with the pre-loaded data from the h5 file. + + :param dataset_filename: Path to the h5 file + :param transform: Any data transformation to be applied (default: None) + """ + + loaded_data = load_dataset(dataset_filename) + data, labels = loaded_data[0], loaded_data[1] + + self.data = data + self.labels = labels + self.targets = list(labels) + self.num_labels = num_labels + self.transform = transform + + self.augmentations = augmentations + self.augmentations_prob = augmentations_prob + self.normalize = normalize + + def __getitem__(self, idx): + """ + Allocates, potentially transforms and returns the item at the desired index. 
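+        After optional augmentation and normalization, the landmark coordinates are shifted by -0.5 so that the values are roughly centered around zero.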
+ + :param idx: Index of the item + :return: Tuple containing both the depth map and the label + """ + + depth_map = torch.from_numpy(np.copy(self.data[idx])) + # label = torch.Tensor([self.labels[idx] - 1]) + label = torch.Tensor([self.labels[idx]]) + + depth_map = tensor_to_dictionary(depth_map) + + # Apply potential augmentations + depth_map = random_augmentation(self.augmentations, self.augmentations_prob, depth_map) + + if self.normalize: + depth_map = normalize_single_body_dict(depth_map) + depth_map = normalize_single_hand_dict(depth_map) + + depth_map = dictionary_to_tensor(depth_map) + + # Move the landmark position interval to improve performance + depth_map = depth_map - 0.5 + + if self.transform: + depth_map = self.transform(depth_map) + + return depth_map, label + + def __len__(self): + return len(self.labels) diff --git a/datasets/dataset_loader.py b/datasets/dataset_loader.py new file mode 100644 index 0000000..dd44082 --- /dev/null +++ b/datasets/dataset_loader.py @@ -0,0 +1,17 @@ + +import os + + +class DatasetLoader(): + """Abstract class that serves to load datasets from different sources (local, ClearML, other tracker) + """ + + def get_dataset_folder(self, dataset_project, dataset_name): + return NotImplementedError() + + +class LocalDatasetLoader(DatasetLoader): + + def get_dataset_folder(self, dataset_project, dataset_name): + base_folder = os.environ.get("BASE_DATA_FOLDER", "data") + return os.path.join(base_folder, dataset_name) diff --git a/datasets/datasets_utils.py b/datasets/datasets_utils.py new file mode 100644 index 0000000..e031805 --- /dev/null +++ b/datasets/datasets_utils.py @@ -0,0 +1,133 @@ +import pandas as pd +import ast +import torch +import random +import numpy as np +from torch.nn.utils.rnn import pad_sequence +from random import randrange + +from augmentations import augment_arm_joint_rotate, augment_rotate, augment_shear +from normalization.body_normalization import BODY_IDENTIFIERS +from augmentations.augment import HAND_IDENTIFIERS + + +def load_dataset(file_location: str): + + # Load the datset csv file + df = pd.read_csv(file_location, encoding="utf-8") + df.columns = [item.replace("_left_", "_0_").replace("_right_", "_1_") for item in list(df.columns)] + + # TEMP + labels = df["labels"].to_list() + + data = [] + + for row_index, row in df.iterrows(): + current_row = np.empty(shape=(len(ast.literal_eval(row["leftEar_X"])), + len(BODY_IDENTIFIERS + HAND_IDENTIFIERS), + 2) + ) + for index, identifier in enumerate(BODY_IDENTIFIERS + HAND_IDENTIFIERS): + current_row[:, index, 0] = ast.literal_eval(row[identifier + "_X"]) + current_row[:, index, 1] = ast.literal_eval(row[identifier + "_Y"]) + + data.append(current_row) + + return data, labels + + +def tensor_to_dictionary(landmarks_tensor: torch.Tensor) -> dict: + + data_array = landmarks_tensor.numpy() + output = {} + + for landmark_index, identifier in enumerate(BODY_IDENTIFIERS + HAND_IDENTIFIERS): + output[identifier] = data_array[:, landmark_index] + + return output + + +def dictionary_to_tensor(landmarks_dict: dict) -> torch.Tensor: + + output = np.empty(shape=(len(landmarks_dict["leftEar"]), len(BODY_IDENTIFIERS + HAND_IDENTIFIERS), 2)) + + for landmark_index, identifier in enumerate(BODY_IDENTIFIERS + HAND_IDENTIFIERS): + output[:, landmark_index, 0] = [frame[0] for frame in landmarks_dict[identifier]] + output[:, landmark_index, 1] = [frame[1] for frame in landmarks_dict[identifier]] + + return torch.from_numpy(output) + + +def random_augmentation(augmentations, augmentations_prob, 
depth_map): + if augmentations and random.random() < augmentations_prob: + selected_aug = randrange(4) + if selected_aug == 0: + depth_map = augment_arm_joint_rotate(depth_map, 0.3, (-4, 4)) + elif selected_aug == 1: + depth_map = augment_shear(depth_map, "perspective", (0, 0.1)) + elif selected_aug == 2: + depth_map = augment_shear(depth_map, "squeeze", (0, 0.15)) + elif selected_aug == 3: + depth_map = augment_rotate(depth_map, (-13, 13)) + + return depth_map + + +def collate_fn_triplet_padd(batch): + ''' + Padds batch of variable length + + note: it converts things ToTensor manually here since the ToTensor transform + assume it takes in images rather than arbitrary tensors. + ''' + # batch: list of length batch_size, each element contains ouput of dataset + # MASKING + anchor_lengths = [element[0].shape[0] for element in batch] + max_anchor_l = max(anchor_lengths) + positive_lengths = [element[1].shape[0] for element in batch] + max_positive_l = max(positive_lengths) + negative_lengths = [element[2].shape[0] for element in batch] + max_negative_l = max(negative_lengths) + + anchor_mask = [[False] * anchor_lengths[n] + [True] * (max_anchor_l - anchor_lengths[n]) + for n in range(len(batch))] + positive_mask = [[False] * positive_lengths[n] + [True] * (max_positive_l - positive_lengths[n]) + for n in range(len(batch))] + negative_mask = [[False] * negative_lengths[n] + [True] * (max_negative_l - negative_lengths[n]) + for n in range(len(batch))] + + # PADDING + anchor_batch = [element[0] for element in batch] + positive_batch = [element[1] for element in batch] + negative_batch = [element[2] for element in batch] + + anchor_batch = pad_sequence(anchor_batch, batch_first=True) + positive_batch = pad_sequence(positive_batch, batch_first=True) + negative_batch = pad_sequence(negative_batch, batch_first=True) + + return anchor_batch, positive_batch, negative_batch, \ + torch.Tensor(anchor_mask), torch.Tensor(positive_mask), torch.Tensor(negative_mask) + + +def collate_fn_padd(batch): + ''' + Padds batch of variable length + + note: it converts things ToTensor manually here since the ToTensor transform + assume it takes in images rather than arbitrary tensors. 
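+    Returns the padded batch tensor, the label tensor and a mask tensor in which True/nonzero entries mark padded (invalid) time steps.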
+ ''' + # batch: list of length batch_size, each element contains ouput of dataset + # MASKING + anchor_lengths = [element[0].shape[0] for element in batch] + max_anchor_l = max(anchor_lengths) + + anchor_mask = [[False] * anchor_lengths[n] + [True] * (max_anchor_l - anchor_lengths[n]) + for n in range(len(batch))] + + # PADDING + anchor_batch = [element[0] for element in batch] + anchor_batch = pad_sequence(anchor_batch, batch_first=True) + + labels = torch.Tensor([element[1] for element in batch]) + + return anchor_batch, labels, torch.Tensor(anchor_mask) diff --git a/datasets/embedding_dataset.py b/datasets/embedding_dataset.py new file mode 100644 index 0000000..a6a093b --- /dev/null +++ b/datasets/embedding_dataset.py @@ -0,0 +1,103 @@ +import torch +import torch.utils.data as torch_data +from random import sample +from typing import List +import numpy as np + +from datasets.datasets_utils import load_dataset, tensor_to_dictionary, dictionary_to_tensor, \ + random_augmentation +from normalization.body_normalization import normalize_single_dict as normalize_single_body_dict +from normalization.hand_normalization import normalize_single_dict as normalize_single_hand_dict + + +class SLREmbeddingDataset(torch_data.Dataset): + """Advanced object representation of the WLASL dataset for loading triplet used in triplet loss utilizing the + Torch's built-in Dataset properties""" + + data: List[np.ndarray] + labels: List[np.ndarray] + + def __init__(self, dataset_filename: str, triplet=True, transform=None, augmentations=False, + augmentations_prob=0.5, normalize=True): + """ + Initiates the HPOESDataset with the pre-loaded data from the h5 file. + + :param dataset_filename: Path to the h5 file + :param transform: Any data transformation to be applied (default: None) + """ + + loaded_data = load_dataset(dataset_filename) + data, labels = loaded_data[0], loaded_data[1] + + self.data = data + self.labels = labels + self.targets = list(labels) + self.transform = transform + self.triplet = triplet + self.augmentations = augmentations + self.augmentations_prob = augmentations_prob + self.normalize = normalize + + def __getitem__(self, idx): + """ + Allocates, potentially transforms and returns the item at the desired index. + + :param idx: Index of the item + :return: Tuple containing both the depth map and the label + """ + depth_map_a = torch.from_numpy(np.copy(self.data[idx])) + label = torch.Tensor([self.labels[idx]]) + + depth_map_a = tensor_to_dictionary(depth_map_a) + + if self.triplet: + positive_indexes = list(np.where(np.array(self.labels) == self.labels[idx])[0]) + positive_index_sample = sample(positive_indexes, 2) + positive_index = positive_index_sample[0] if positive_index_sample[0] != idx else positive_index_sample[1] + negative_indexes = list(np.where(np.array(self.labels) != self.labels[idx])[0]) + negative_index = sample(negative_indexes, 1)[0] + # TODO: implement hard triplets + + depth_map_p = torch.from_numpy(np.copy(self.data[positive_index])) + depth_map_n = torch.from_numpy(np.copy(self.data[negative_index])) + + depth_map_p = tensor_to_dictionary(depth_map_p) + depth_map_n = tensor_to_dictionary(depth_map_n) + + # TODO: Add Data augmentation to positive and negative ? 
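+            # Note: only the anchor is augmented below; the positive and negative samples are currently used as-is.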
+ + # Apply potential augmentations + depth_map_a = random_augmentation(self.augmentations, self.augmentations_prob, depth_map_a) + + if self.normalize: + depth_map_a = normalize_single_body_dict(depth_map_a) + depth_map_a = normalize_single_hand_dict(depth_map_a) + if self.triplet: + depth_map_p = normalize_single_body_dict(depth_map_p) + depth_map_p = normalize_single_hand_dict(depth_map_p) + depth_map_n = normalize_single_body_dict(depth_map_n) + depth_map_n = normalize_single_hand_dict(depth_map_n) + + depth_map_a = dictionary_to_tensor(depth_map_a) + # Move the landmark position interval to improve performance + depth_map_a = depth_map_a - 0.5 + + if self.triplet: + depth_map_p = dictionary_to_tensor(depth_map_p) + depth_map_p = depth_map_p - 0.5 + depth_map_n = dictionary_to_tensor(depth_map_n) + depth_map_n = depth_map_n - 0.5 + + if self.transform: + depth_map_a = self.transform(depth_map_a) + if self.triplet: + depth_map_p = self.transform(depth_map_p) + depth_map_n = self.transform(depth_map_n) + + if self.triplet: + return depth_map_a, depth_map_p, depth_map_n + + return depth_map_a, label + + def __len__(self): + return len(self.labels) diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..ef94e51 --- /dev/null +++ b/models/__init__.py @@ -0,0 +1,4 @@ +from .spoter_model import SPOTER +from .spoter_embedding_model import SPOTER_EMBEDDINGS +from .utils import train_epoch, evaluate, evaluate_top_k, train_epoch_embedding, train_epoch_embedding_online, \ + evaluate_embedding, embeddings_scatter_plot, embeddings_scatter_plot_splits diff --git a/models/spoter_embedding_model.py b/models/spoter_embedding_model.py new file mode 100644 index 0000000..2d2a944 --- /dev/null +++ b/models/spoter_embedding_model.py @@ -0,0 +1,41 @@ +import torch +import torch.nn as nn + +from models.spoter_model import _get_clones, SPOTERTransformerDecoderLayer + + +class SPOTER_EMBEDDINGS(nn.Module): + """ + Implementation of the SPOTER (Sign POse-based TransformER) architecture for sign language recognition from sequence + of skeletal data. 
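+    Unlike the original SPOTER classifier, the class-query output of the decoder is projected by a linear layer to an embedding of size `features` (optionally L2-normalized) instead of to class logits.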
+ """ + + def __init__(self, features, hidden_dim=108, nhead=9, num_encoder_layers=6, num_decoder_layers=6, + norm_emb=False, dropout=0.1): + super().__init__() + + self.pos_encoding = nn.Parameter(torch.rand(1, 1, hidden_dim)) # init positional encoding + self.class_query = nn.Parameter(torch.rand(1, 1, hidden_dim)) + self.transformer = nn.Transformer(hidden_dim, nhead, num_encoder_layers, num_decoder_layers, dropout=dropout) + self.linear_embed = nn.Linear(hidden_dim, features) + + # Deactivate the initial attention decoder mechanism + custom_decoder_layer = SPOTERTransformerDecoderLayer(self.transformer.d_model, self.transformer.nhead, 2048, + dropout, "relu") + self.transformer.decoder.layers = _get_clones(custom_decoder_layer, self.transformer.decoder.num_layers) + self.norm_emb = norm_emb + + def forward(self, inputs, src_masks=None): + + h = torch.transpose(inputs.flatten(start_dim=2), 1, 0).float() + h = self.transformer( + self.pos_encoding.repeat(1, h.shape[1], 1) + h, + self.class_query.repeat(1, h.shape[1], 1), + src_key_padding_mask=src_masks + ).transpose(0, 1) + embedding = self.linear_embed(h) + + if self.norm_emb: + embedding = nn.functional.normalize(embedding, dim=2) + + return embedding diff --git a/models/spoter_model.py b/models/spoter_model.py new file mode 100644 index 0000000..ed253ca --- /dev/null +++ b/models/spoter_model.py @@ -0,0 +1,66 @@ + +import copy +import torch + +import torch.nn as nn +from typing import Optional + + +def _get_clones(mod, n): + return nn.ModuleList([copy.deepcopy(mod) for _ in range(n)]) + + +class SPOTERTransformerDecoderLayer(nn.TransformerDecoderLayer): + """ + Edited TransformerDecoderLayer implementation omitting the redundant self-attention operation as opposed to the + standard implementation. + """ + + def __init__(self, d_model, nhead, dim_feedforward, dropout, activation): + super(SPOTERTransformerDecoderLayer, self).__init__(d_model, nhead, dim_feedforward, dropout, activation) + + del self.self_attn + + def forward(self, tgt: torch.Tensor, memory: torch.Tensor, tgt_mask: Optional[torch.Tensor] = None, + memory_mask: Optional[torch.Tensor] = None, tgt_key_padding_mask: Optional[torch.Tensor] = None, + memory_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + + tgt = tgt + self.dropout1(tgt) + tgt = self.norm1(tgt) + tgt2 = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + + return tgt + + +class SPOTER(nn.Module): + """ + Implementation of the SPOTER (Sign POse-based TransformER) architecture for sign language recognition from sequence + of skeletal data. 
+ """ + + def __init__(self, num_classes, hidden_dim=55): + super().__init__() + + self.row_embed = nn.Parameter(torch.rand(50, hidden_dim)) + self.pos = nn.Parameter(torch.cat([self.row_embed[0].unsqueeze(0).repeat(1, 1, 1)], dim=-1).flatten(0, 1).unsqueeze(0)) + self.class_query = nn.Parameter(torch.rand(1, hidden_dim)) + self.transformer = nn.Transformer(hidden_dim, 9, 6, 6) + self.linear_class = nn.Linear(hidden_dim, num_classes) + + # Deactivate the initial attention decoder mechanism + custom_decoder_layer = SPOTERTransformerDecoderLayer(self.transformer.d_model, self.transformer.nhead, 2048, + 0.1, "relu") + self.transformer.decoder.layers = _get_clones(custom_decoder_layer, self.transformer.decoder.num_layers) + + def forward(self, inputs): + h = torch.unsqueeze(inputs.flatten(start_dim=1), 1).float() + h = self.transformer(self.pos + h, self.class_query.unsqueeze(0)).transpose(0, 1) + res = self.linear_class(h) + + return res diff --git a/models/utils.py b/models/utils.py new file mode 100644 index 0000000..4979c79 --- /dev/null +++ b/models/utils.py @@ -0,0 +1,280 @@ +import numpy as np +import torch +from sklearn.metrics import silhouette_score +from sklearn.manifold import TSNE +from training.batch_sorter import sort_batches +from utils import get_logger + + +def train_epoch(model, dataloader, criterion, optimizer, device, scheduler=None): + + pred_correct, pred_all = 0, 0 + running_loss = 0.0 + model.train(True) + for i, data in enumerate(dataloader): + inputs, labels = data + inputs = inputs.squeeze(0).to(device) + labels = labels.to(device, dtype=torch.long) + + optimizer.zero_grad() + outputs = model(inputs).expand(1, -1, -1) + loss = criterion(outputs[0], labels[0]) + loss.backward() + optimizer.step() + running_loss += loss.item() + + # Statistics + if int(torch.argmax(torch.nn.functional.softmax(outputs, dim=2))) == int(labels[0][0]): + pred_correct += 1 + pred_all += 1 + + epoch_loss = running_loss / len(dataloader) + model.train(False) + if scheduler: + scheduler.step(epoch_loss) + + return epoch_loss, pred_correct, pred_all, (pred_correct / pred_all) + + +def train_epoch_embedding(model, epoch_iters, train_loader, val_loader, criterion, optimizer, device, scheduler=None): + + running_loss = [] + model.train(True) + for i, (anchor, positive, negative, a_mask, p_mask, n_mask) in enumerate(train_loader): + optimizer.zero_grad() + + anchor_emb = model(anchor.to(device), a_mask.to(device)) + positive_emb = model(positive.to(device), p_mask.to(device)) + negative_emb = model(negative.to(device), n_mask.to(device)) + + loss = criterion(anchor_emb.to(device), positive_emb.to(device), negative_emb.to(device)) + loss.backward() + optimizer.step() + running_loss.append(loss.item()) + + if i == epoch_iters: + break + + epoch_loss = np.mean(running_loss) + + # VALIDATION + model.train(False) + val_silhouette_coef = evaluate_embedding(model, val_loader, device) + + if scheduler: + scheduler.step(val_silhouette_coef) + + return epoch_loss, val_silhouette_coef + + +def train_epoch_embedding_online(model, epoch_iters, train_loader, val_loader, criterion, optimizer, device, + scheduler=None, enable_batch_sorting=False, mini_batch_size=None, + pre_batch_mining_count=1, batching_scheduler=None): + + running_loss = [] + iter_used_triplets = [] + iter_valid_triplets = [] + iter_pct_used = [] + model.train(True) + mini_batch = mini_batch_size or train_loader.batch_size + for i, (inputs, labels, masks) in enumerate(train_loader): + labels_size = labels.size()[0] + batch_loop_count = 
int(labels_size / mini_batch) + if batch_loop_count == 0: + continue + # Second condition is added so that we only run batch sorting if we have a full batch + if enable_batch_sorting: + if labels_size < train_loader.batch_size: + trim_count = labels_size % mini_batch + inputs = inputs[:-trim_count] + labels = labels[:-trim_count] + masks = masks[:-trim_count] + embeddings = None + with torch.no_grad(): + for j in range(batch_loop_count): + batch_embed = compute_batched_embeddings(model, device, inputs, masks, mini_batch, j) + if embeddings is None: + embeddings = batch_embed + else: + embeddings = torch.cat([embeddings, batch_embed], dim=0) + inputs, labels, masks = sort_batches(inputs, labels, masks, embeddings, device, + mini_batch_size=mini_batch_size, scheduler=batching_scheduler) + del embeddings + del batch_embed + mining_loop_count = pre_batch_mining_count + else: + mining_loop_count = 1 + for k in range(mining_loop_count): + for j in range(batch_loop_count): + optimizer.zero_grad(set_to_none=True) + batch_labels = labels[mini_batch * j:mini_batch * (j + 1)] + if batch_labels.size()[0] == 0: + break + embeddings = compute_batched_embeddings(model, device, inputs, masks, mini_batch, j) + loss, valid_triplets, used_triplets = criterion(embeddings, batch_labels) + + loss.backward() + optimizer.step() + running_loss.append(loss.item()) + if valid_triplets > 0: + iter_used_triplets.append(used_triplets) + iter_valid_triplets.append(valid_triplets) + iter_pct_used.append((used_triplets * 100) / valid_triplets) + + if epoch_iters > 0 and i * batch_loop_count * pre_batch_mining_count >= epoch_iters: + print("Breaking out because of epoch_iters filter") + break + + epoch_loss = np.mean(running_loss) + mean_used_triplets = np.mean(iter_used_triplets) + triplets_stats = { + "valid_triplets": np.mean(iter_valid_triplets), + "used_triplets": mean_used_triplets, + "pct_used": np.mean(iter_pct_used) + } + + if batching_scheduler: + batching_scheduler.step(mean_used_triplets) + + # VALIDATION + model.train(False) + with torch.no_grad(): + val_silhouette_coef = evaluate_embedding(model, val_loader, device) + + if scheduler: + scheduler.step(val_silhouette_coef) + + return epoch_loss, val_silhouette_coef, triplets_stats + + +def compute_batched_embeddings(model, device, inputs, masks, mini_batch, iteration): + batch_inputs = inputs[mini_batch * iteration:mini_batch * (iteration + 1)] + batch_masks = masks[mini_batch * iteration:mini_batch * (iteration + 1)] + + return model(batch_inputs.to(device), batch_masks.to(device)).squeeze(1) + + +def evaluate(model, dataloader, device, print_stats=False): + + logger = get_logger(__name__) + + pred_correct, pred_all = 0, 0 + stats = {i: [0, 0] for i in range(101)} + + for i, data in enumerate(dataloader): + inputs, labels = data + inputs = inputs.squeeze(0).to(device) + labels = labels.to(device, dtype=torch.long) + + outputs = model(inputs).expand(1, -1, -1) + + # Statistics + if int(torch.argmax(torch.nn.functional.softmax(outputs, dim=2))) == int(labels[0][0]): + stats[int(labels[0][0])][0] += 1 + pred_correct += 1 + + stats[int(labels[0][0])][1] += 1 + pred_all += 1 + + if print_stats: + stats = {key: value[0] / value[1] for key, value in stats.items() if value[1] != 0} + print("Label accuracies statistics:") + print(str(stats) + "\n") + logger.info("Label accuracies statistics:") + logger.info(str(stats) + "\n") + + return pred_correct, pred_all, (pred_correct / pred_all) + + +def evaluate_embedding(model, dataloader, device): + val_embeddings = [] + 
labels_emb = [] + + for i, (inputs, labels, masks) in enumerate(dataloader): + inputs = inputs.to(device) + masks = masks.to(device) + + outputs = model(inputs, masks) + for n in range(outputs.shape[0]): + val_embeddings.append(outputs[n, 0].cpu().detach().numpy()) + labels_emb.append(labels.detach().numpy()[n]) + + silhouette_coefficient = silhouette_score( + X=np.array(val_embeddings), + labels=np.array(labels_emb).reshape(len(labels_emb)) + ) + + return silhouette_coefficient + + +def embeddings_scatter_plot(model, dataloader, device, id_to_label, perplexity=40, n_iter=1000): + + val_embeddings = [] + labels_emb = [] + + with torch.no_grad(): + for i, (inputs, labels, masks) in enumerate(dataloader): + inputs = inputs.to(device) + masks = masks.to(device) + + outputs = model(inputs, masks) + for n in range(outputs.shape[0]): + val_embeddings.append(outputs[n, 0].cpu().detach().numpy()) + labels_emb.append(id_to_label[int(labels.detach().numpy()[n])]) + + tsne = TSNE(n_components=2, verbose=0, perplexity=perplexity, n_iter=n_iter) + tsne_results = tsne.fit_transform(np.array(val_embeddings)) + + return tsne_results, labels_emb + + +def embeddings_scatter_plot_splits(model, dataloaders, device, id_to_label, perplexity=40, n_iter=1000): + + labels_split = {} + embeddings_split = {} + splits = list(dataloaders.keys()) + with torch.no_grad(): + for split, dataloader in dataloaders.items(): + labels_str = [] + embeddings = [] + for i, (inputs, labels, masks) in enumerate(dataloader): + inputs = inputs.to(device) + masks = masks.to(device) + + outputs = model(inputs, masks) + for n in range(outputs.shape[0]): + embeddings.append(outputs[n, 0].cpu().detach().numpy()) + labels_str.append(id_to_label[int(labels.detach().numpy()[n])]) + labels_split[split] = labels_str + embeddings_split[split] = embeddings + + tsne = TSNE(n_components=2, verbose=0, perplexity=perplexity, n_iter=n_iter) + all_embeddings = np.vstack([embeddings_split[split] for split in splits]) + tsne_results = tsne.fit_transform(all_embeddings) + tsne_results_dict = {} + curr_index = 0 + for split in splits: + len_embeddings = len(embeddings_split[split]) + tsne_results_dict[split] = tsne_results[curr_index: curr_index + len_embeddings] + curr_index += len_embeddings + + return tsne_results_dict, labels_split + + +def evaluate_top_k(model, dataloader, device, k=5): + + pred_correct, pred_all = 0, 0 + + for i, data in enumerate(dataloader): + inputs, labels = data + inputs = inputs.squeeze(0).to(device) + labels = labels.to(device, dtype=torch.long) + + outputs = model(inputs).expand(1, -1, -1) + + if int(labels[0][0]) in torch.topk(outputs, k).indices.tolist()[0][0]: + pred_correct += 1 + + pred_all += 1 + + return pred_correct, pred_all, (pred_correct / pred_all) diff --git a/normalization/blazepose_mapping.py b/normalization/blazepose_mapping.py new file mode 100644 index 0000000..90a0672 --- /dev/null +++ b/normalization/blazepose_mapping.py @@ -0,0 +1,92 @@ + +_BODY_KEYPOINT_MAPPING = { + "nose": "nose", + "left_eye": "leftEye", + "right_eye": "rightEye", + "left_ear": "leftEar", + "right_ear": "rightEar", + "left_shoulder": "leftShoulder", + "right_shoulder": "rightShoulder", + "left_elbow": "leftElbow", + "right_elbow": "rightElbow", + "left_wrist": "leftWrist", + "right_wrist": "rightWrist" +} + +_HAND_KEYPOINT_MAPPING = { + "wrist": "wrist", + "index_finger_tip": "indexTip", + "index_finger_dip": "indexDIP", + "index_finger_pip": "indexPIP", + "index_finger_mcp": "indexMCP", + "middle_finger_tip": "middleTip", + 
"middle_finger_dip": "middleDIP", + "middle_finger_pip": "middlePIP", + "middle_finger_mcp": "middleMCP", + "ring_finger_tip": "ringTip", + "ring_finger_dip": "ringDIP", + "ring_finger_pip": "ringPIP", + "ring_finger_mcp": "ringMCP", + "pinky_tip": "littleTip", + "pinky_dip": "littleDIP", + "pinky_pip": "littlePIP", + "pinky_mcp": "littleMCP", + "thumb_tip": "thumbTip", + "thumb_ip": "thumbIP", + "thumb_mcp": "thumbMP", + "thumb_cmc": "thumbCMC" +} + + +def map_blazepose_keypoint(column): + # Remove _x, _y suffixes + suffix = column[-2:].upper() + column = column[:-2] + + if column.startswith("left_hand_"): + hand = "left" + finger_name = column[10:] + elif column.startswith("right_hand_"): + hand = "right" + finger_name = column[11:] + else: + if column not in _BODY_KEYPOINT_MAPPING: + return None + mapped = _BODY_KEYPOINT_MAPPING[column] + return mapped + suffix + + if finger_name not in _HAND_KEYPOINT_MAPPING: + return None + mapped = _HAND_KEYPOINT_MAPPING[finger_name] + return f"{mapped}_{hand}{suffix}" + + +def map_blazepose_df(df): + to_drop = [] + renamings = {} + for column in df.columns: + mapped_column = map_blazepose_keypoint(column) + if mapped_column: + renamings[column] = mapped_column + else: + to_drop.append(column) + df = df.rename(columns=renamings) + + for index, row in df.iterrows(): + + sequence_size = len(row["leftEar_Y"]) + lsx = row["leftShoulder_X"] + rsx = row["rightShoulder_X"] + lsy = row["leftShoulder_Y"] + rsy = row["rightShoulder_Y"] + neck_x = [] + neck_y = [] + # Treat each element of the sequence (analyzed frame) individually + for sequence_index in range(sequence_size): + neck_x.append((float(lsx[sequence_index]) + float(rsx[sequence_index])) / 2) + neck_y.append((float(lsy[sequence_index]) + float(rsy[sequence_index])) / 2) + df.loc[index, "neck_X"] = str(neck_x) + df.loc[index, "neck_Y"] = str(neck_y) + + df.drop(columns=to_drop, inplace=True) + return df diff --git a/normalization/body_normalization.py b/normalization/body_normalization.py new file mode 100644 index 0000000..015012f --- /dev/null +++ b/normalization/body_normalization.py @@ -0,0 +1,241 @@ + +from typing import Tuple +import pandas as pd +from utils import get_logger + + +BODY_IDENTIFIERS = [ + "nose", + "neck", + "rightEye", + "leftEye", + "rightEar", + "leftEar", + "rightShoulder", + "leftShoulder", + "rightElbow", + "leftElbow", + "rightWrist", + "leftWrist" +] + + +def normalize_body_full(df: pd.DataFrame) -> Tuple[pd.DataFrame, list]: + """ + Normalizes the body position data using the Bohacek-normalization algorithm. 
+ + :param df: pd.DataFrame to be normalized + :return: pd.DataFrame with normalized values for body pose + """ + logger = get_logger(__name__) + + # TODO: Fix division by zero + + normalized_df = pd.DataFrame(columns=df.columns) + invalid_row_indexes = [] + body_landmarks = {"X": [], "Y": []} + + # Construct the relevant identifiers + for identifier in BODY_IDENTIFIERS: + body_landmarks["X"].append(identifier + "_X") + body_landmarks["Y"].append(identifier + "_Y") + + # Iterate over all of the records in the dataset + for index, row in df.iterrows(): + + sequence_size = len(row["leftEar_Y"]) + valid_sequence = True + original_row = row + + last_starting_point, last_ending_point = None, None + + # Treat each element of the sequence (analyzed frame) individually + for sequence_index in range(sequence_size): + + # Prevent from even starting the analysis if some necessary elements are not present + if (row["leftShoulder_X"][sequence_index] == 0 or row["rightShoulder_X"][sequence_index] == 0) and \ + (row["neck_X"][sequence_index] == 0 or row["nose_X"][sequence_index] == 0): + if not last_starting_point: + valid_sequence = False + continue + + else: + starting_point, ending_point = last_starting_point, last_ending_point + + else: + + # NOTE: + # + # While in the paper, it is written that the head metric is calculated by halving the shoulder distance, + # this is meant for the distance between the very ends of one's shoulder, as literature studying body + # metrics and ratios generally states. The Vision Pose Estimation API, however, seems to be predicting + # rather the center of one's shoulder. Based on our experiments and manual reviews of the data, + # employing + # this as just the plain shoulder distance seems to be more corresponding to the desired metric. + # + # Please, review this if using other third-party pose estimation libraries. 
+ + if row["leftShoulder_X"][sequence_index] != 0 and row["rightShoulder_X"][sequence_index] != 0: + left_shoulder = (row["leftShoulder_X"][sequence_index], row["leftShoulder_Y"][sequence_index]) + right_shoulder = (row["rightShoulder_X"][sequence_index], row["rightShoulder_Y"][sequence_index]) + shoulder_distance = ((((left_shoulder[0] - right_shoulder[0]) ** 2) + ( + (left_shoulder[1] - right_shoulder[1]) ** 2)) ** 0.5) + head_metric = shoulder_distance + else: + neck = (row["neck_X"][sequence_index], row["neck_Y"][sequence_index]) + nose = (row["nose_X"][sequence_index], row["nose_Y"][sequence_index]) + neck_nose_distance = ((((neck[0] - nose[0]) ** 2) + ((neck[1] - nose[1]) ** 2)) ** 0.5) + head_metric = neck_nose_distance + + # Set the starting and ending point of the normalization bounding box + starting_point = [row["neck_X"][sequence_index] - 3 * head_metric, + row["leftEye_Y"][sequence_index] + (head_metric / 2)] + ending_point = [row["neck_X"][sequence_index] + 3 * head_metric, starting_point[1] - 6 * head_metric] + + last_starting_point, last_ending_point = starting_point, ending_point + + # Ensure that all of the bounding-box-defining coordinates are not out of the picture + if starting_point[0] < 0: + starting_point[0] = 0 + if starting_point[1] < 0: + starting_point[1] = 0 + if ending_point[0] < 0: + ending_point[0] = 0 + if ending_point[1] < 0: + ending_point[1] = 0 + + # Normalize individual landmarks and save the results + for identifier in BODY_IDENTIFIERS: + key = identifier + "_" + + # Prevent from trying to normalize incorrectly captured points + if row[key + "X"][sequence_index] == 0: + continue + + normalized_x = (row[key + "X"][sequence_index] - starting_point[0]) / (ending_point[0] - + starting_point[0]) + normalized_y = (row[key + "Y"][sequence_index] - ending_point[1]) / (starting_point[1] - + ending_point[1]) + + row[key + "X"][sequence_index] = normalized_x + row[key + "Y"][sequence_index] = normalized_y + + if valid_sequence: + normalized_df = normalized_df.append(row, ignore_index=True) + else: + logger.warning(" BODY LANDMARKS: One video instance could not be normalized.") + normalized_df = normalized_df.append(original_row, ignore_index=True) + invalid_row_indexes.append(index) + + logger.info("The normalization of body is finished.") + logger.info("\t-> Original size:", df.shape[0]) + logger.info("\t-> Normalized size:", normalized_df.shape[0]) + logger.info("\t-> Problematic videos:", len(invalid_row_indexes)) + + return normalized_df, invalid_row_indexes + + +def normalize_single_dict(row: dict): + """ + Normalizes the skeletal data for a given sequence of frames with signer's body pose data. The normalization follows + the definition from our paper. 
+ + :param row: Dictionary containing key-value pairs with joint identifiers and corresponding lists (sequences) of + that particular joints coordinates + :return: Dictionary with normalized skeletal data (following the same schema as input data) + """ + + sequence_size = len(row["leftEar"]) + valid_sequence = True + original_row = row + logger = get_logger(__name__) + + last_starting_point, last_ending_point = None, None + + # Treat each element of the sequence (analyzed frame) individually + for sequence_index in range(sequence_size): + left_shoulder = (row["leftShoulder"][sequence_index][0], row["leftShoulder"][sequence_index][1]) + right_shoulder = (row["rightShoulder"][sequence_index][0], row["rightShoulder"][sequence_index][1]) + neck = (row["neck"][sequence_index][0], row["neck"][sequence_index][1]) + nose = (row["nose"][sequence_index][0], row["nose"][sequence_index][1]) + # Prevent from even starting the analysis if some necessary elements are not present + if (left_shoulder[0] == 0 or right_shoulder[0] == 0 + or (left_shoulder[0] == right_shoulder[0] and left_shoulder[1] == right_shoulder[1])) and ( + neck[0] == 0 or nose[0] == 0 or (neck[0] == nose[0] and neck[1] == nose[1])): + if not last_starting_point: + valid_sequence = False + continue + + else: + starting_point, ending_point = last_starting_point, last_ending_point + + else: + + # NOTE: + # + # While in the paper, it is written that the head metric is calculated by halving the shoulder distance, + # this is meant for the distance between the very ends of one's shoulder, as literature studying body + # metrics and ratios generally states. The Vision Pose Estimation API, however, seems to be predicting + # rather the center of one's shoulder. Based on our experiments and manual reviews of the data, employing + # this as just the plain shoulder distance seems to be more corresponding to the desired metric. + # + # Please, review this if using other third-party pose estimation libraries. 
+ + if left_shoulder[0] != 0 and right_shoulder[0] != 0 and \ + (left_shoulder[0] != right_shoulder[0] or left_shoulder[1] != right_shoulder[1]): + shoulder_distance = ((((left_shoulder[0] - right_shoulder[0]) ** 2) + ( + (left_shoulder[1] - right_shoulder[1]) ** 2)) ** 0.5) + head_metric = shoulder_distance + else: + neck_nose_distance = ((((neck[0] - nose[0]) ** 2) + ((neck[1] - nose[1]) ** 2)) ** 0.5) + head_metric = neck_nose_distance + + # Set the starting and ending point of the normalization bounding box + # starting_point = [row["neck"][sequence_index][0] - 3 * head_metric, + # row["leftEye"][sequence_index][1] + (head_metric / 2)] + starting_point = [row["neck"][sequence_index][0] - 3 * head_metric, + row["leftEye"][sequence_index][1] + head_metric] + ending_point = [row["neck"][sequence_index][0] + 3 * head_metric, starting_point[1] - 6 * head_metric] + + last_starting_point, last_ending_point = starting_point, ending_point + + # Ensure that all of the bounding-box-defining coordinates are not out of the picture + if starting_point[0] < 0: + starting_point[0] = 0 + if starting_point[1] < 0: + starting_point[1] = 0 + if ending_point[0] < 0: + ending_point[0] = 0 + if ending_point[1] < 0: + ending_point[1] = 0 + + # Normalize individual landmarks and save the results + for identifier in BODY_IDENTIFIERS: + key = identifier + + # Prevent from trying to normalize incorrectly captured points + if row[key][sequence_index][0] == 0: + continue + + if (ending_point[0] - starting_point[0]) == 0 or (starting_point[1] - ending_point[1]) == 0: + logger.warning("Problematic normalization") + valid_sequence = False + break + + normalized_x = (row[key][sequence_index][0] - starting_point[0]) / (ending_point[0] - starting_point[0]) + normalized_y = (row[key][sequence_index][1] - ending_point[1]) / (starting_point[1] - ending_point[1]) + + row[key][sequence_index] = list(row[key][sequence_index]) + + row[key][sequence_index][0] = normalized_x + row[key][sequence_index][1] = normalized_y + + if valid_sequence: + return row + + else: + return original_row + + +if __name__ == "__main__": + pass diff --git a/normalization/hand_normalization.py b/normalization/hand_normalization.py new file mode 100644 index 0000000..8343491 --- /dev/null +++ b/normalization/hand_normalization.py @@ -0,0 +1,195 @@ + +import pandas as pd +from utils import get_logger + + +HAND_IDENTIFIERS = [ + "wrist", + "indexTip", + "indexDIP", + "indexPIP", + "indexMCP", + "middleTip", + "middleDIP", + "middlePIP", + "middleMCP", + "ringTip", + "ringDIP", + "ringPIP", + "ringMCP", + "littleTip", + "littleDIP", + "littlePIP", + "littleMCP", + "thumbTip", + "thumbIP", + "thumbMP", + "thumbCMC" +] + + +def normalize_hands_full(df: pd.DataFrame) -> pd.DataFrame: + """ + Normalizes the hands position data using the Bohacek-normalization algorithm. 
+ + :param df: pd.DataFrame to be normalized + :return: pd.DataFrame with normalized values for hand pose + """ + + logger = get_logger(__name__) + # TODO: Fix division by zero + df.columns = [item.replace("_left_", "_0_").replace("_right_", "_1_") for item in list(df.columns)] + + normalized_df = pd.DataFrame(columns=df.columns) + + hand_landmarks = {"X": {0: [], 1: []}, "Y": {0: [], 1: []}} + + # Determine how many hands are present in the dataset + range_hand_size = 1 + if "wrist_1_X" in df.columns: + range_hand_size = 2 + + # Construct the relevant identifiers + for identifier in HAND_IDENTIFIERS: + for hand_index in range(range_hand_size): + hand_landmarks["X"][hand_index].append(identifier + "_" + str(hand_index) + "_X") + hand_landmarks["Y"][hand_index].append(identifier + "_" + str(hand_index) + "_Y") + + # Iterate over all of the records in the dataset + for index, row in df.iterrows(): + # Treat each hand individually + for hand_index in range(range_hand_size): + + sequence_size = len(row["wrist_" + str(hand_index) + "_X"]) + + # Treat each element of the sequence (analyzed frame) individually + for sequence_index in range(sequence_size): + + # Retrieve all of the X and Y values of the current frame + landmarks_x_values = [row[key][sequence_index] + for key in hand_landmarks["X"][hand_index] if row[key][sequence_index] != 0] + landmarks_y_values = [row[key][sequence_index] + for key in hand_landmarks["Y"][hand_index] if row[key][sequence_index] != 0] + + # Prevent from even starting the analysis if some necessary elements are not present + if not landmarks_x_values or not landmarks_y_values: + logger.warning( + " HAND LANDMARKS: One frame could not be normalized as there is no data present. Record: " + + str(index) + + ", Frame: " + str(sequence_index)) + continue + + # Calculate the deltas + width, height = max(landmarks_x_values) - min(landmarks_x_values), max(landmarks_y_values) - min( + landmarks_y_values) + if width > height: + delta_x = 0.1 * width + delta_y = delta_x + ((width - height) / 2) + else: + delta_y = 0.1 * height + delta_x = delta_y + ((height - width) / 2) + + # Set the starting and ending point of the normalization bounding box + starting_point = (min(landmarks_x_values) - delta_x, min(landmarks_y_values) - delta_y) + ending_point = (max(landmarks_x_values) + delta_x, max(landmarks_y_values) + delta_y) + + # Normalize individual landmarks and save the results + for identifier in HAND_IDENTIFIERS: + key = identifier + "_" + str(hand_index) + "_" + + # Prevent from trying to normalize incorrectly captured points + if row[key + "X"][sequence_index] == 0 or (ending_point[0] - starting_point[0]) == 0 or \ + (starting_point[1] - ending_point[1]) == 0: + continue + + normalized_x = (row[key + "X"][sequence_index] - starting_point[0]) / (ending_point[0] - + starting_point[0]) + normalized_y = (row[key + "Y"][sequence_index] - ending_point[1]) / (starting_point[1] - + ending_point[1]) + + row[key + "X"][sequence_index] = normalized_x + row[key + "Y"][sequence_index] = normalized_y + + normalized_df = normalized_df.append(row, ignore_index=True) + + return normalized_df + + +def normalize_single_dict(row: dict): + """ + Normalizes the skeletal data for a given sequence of frames with signer's hand pose data. The normalization follows + the definition from our paper. 
+ + :param row: Dictionary containing key-value pairs with joint identifiers and corresponding lists (sequences) of + that particular joints coordinates + :return: Dictionary with normalized skeletal data (following the same schema as input data) + """ + + hand_landmarks = {0: [], 1: []} + + # Determine how many hands are present in the dataset + range_hand_size = 1 + if "wrist_1" in row.keys(): + range_hand_size = 2 + + # Construct the relevant identifiers + for identifier in HAND_IDENTIFIERS: + for hand_index in range(range_hand_size): + hand_landmarks[hand_index].append(identifier + "_" + str(hand_index)) + + # Treat each hand individually + for hand_index in range(range_hand_size): + + sequence_size = len(row["wrist_" + str(hand_index)]) + + # Treat each element of the sequence (analyzed frame) individually + for sequence_index in range(sequence_size): + + # Retrieve all of the X and Y values of the current frame + landmarks_x_values = [row[key][sequence_index][0] for key in hand_landmarks[hand_index] if + row[key][sequence_index][0] != 0] + landmarks_y_values = [row[key][sequence_index][1] for key in hand_landmarks[hand_index] if + row[key][sequence_index][1] != 0] + + # Prevent from even starting the analysis if some necessary elements are not present + if not landmarks_x_values or not landmarks_y_values: + continue + + # Calculate the deltas + width, height = max(landmarks_x_values) - min(landmarks_x_values), max(landmarks_y_values) - min( + landmarks_y_values) + if width > height: + delta_x = 0.1 * width + delta_y = delta_x + ((width - height) / 2) + else: + delta_y = 0.1 * height + delta_x = delta_y + ((height - width) / 2) + + # Set the starting and ending point of the normalization bounding box + starting_point = (min(landmarks_x_values) - delta_x, min(landmarks_y_values) - delta_y) + ending_point = (max(landmarks_x_values) + delta_x, max(landmarks_y_values) + delta_y) + + # Normalize individual landmarks and save the results + for identifier in HAND_IDENTIFIERS: + key = identifier + "_" + str(hand_index) + + # Prevent from trying to normalize incorrectly captured points + if row[key][sequence_index][0] == 0 or (ending_point[0] - starting_point[0]) == 0 or ( + starting_point[1] - ending_point[1]) == 0: + continue + + normalized_x = (row[key][sequence_index][0] - starting_point[0]) / (ending_point[0] - + starting_point[0]) + normalized_y = (row[key][sequence_index][1] - starting_point[1]) / (ending_point[1] - + starting_point[1]) + + row[key][sequence_index] = list(row[key][sequence_index]) + + row[key][sequence_index][0] = normalized_x + row[key][sequence_index][1] = normalized_y + + return row + + +if __name__ == "__main__": + pass diff --git a/normalization/main.py b/normalization/main.py new file mode 100644 index 0000000..4c619a2 --- /dev/null +++ b/normalization/main.py @@ -0,0 +1,47 @@ +import os +import ast +import pandas as pd + +from normalization.hand_normalization import normalize_hands_full +from normalization.body_normalization import normalize_body_full + +DATASET_PATH = './data' +# Load the dataset +df = pd.read_csv(os.path.join(DATASET_PATH, "WLASL_test_15fps.csv"), encoding="utf-8") + +# Retrieve metadata +video_size_heights = df["video_size_height"].to_list() +video_size_widths = df["video_size_width"].to_list() + +# Delete redundant (non-related) properties +del df["video_size_height"] +del df["video_size_width"] + +# Temporarily remove other relevant metadata +labels = df["labels"].to_list() +video_fps = df["video_fps"].to_list() +del df["labels"] +del 
df["video_fps"] + +# Convert the strings into lists + + +def convert(x): return ast.literal_eval(str(x)) + + +for column in df.columns: + df[column] = df[column].apply(convert) + +# Perform the normalizations +df = normalize_hands_full(df) +df, invalid_row_indexes = normalize_body_full(df) + +# Clear lists of items from deleted rows +# labels = [t for i, t in enumerate(labels) if i not in invalid_row_indexes] +# video_fps = [t for i, t in enumerate(video_fps) if i not in invalid_row_indexes] + +# Return the metadata back to the dataset +df["labels"] = labels +df["video_fps"] = video_fps + +df.to_csv(os.path.join(DATASET_PATH, "WLASL_test_15fps_normalized.csv"), encoding="utf-8", index=False) diff --git a/notebooks/embeddings_evaluation.ipynb b/notebooks/embeddings_evaluation.ipynb new file mode 100644 index 0000000..1db8999 --- /dev/null +++ b/notebooks/embeddings_evaluation.ipynb @@ -0,0 +1,411 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "c20f7fd5", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ada032d0", + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "import os\n", + "import os.path as op\n", + "import pandas as pd\n", + "import json\n", + "import base64" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "05682e73", + "metadata": {}, + "outputs": [], + "source": [ + "sys.path.append(op.abspath('..'))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fede7684", + "metadata": {}, + "outputs": [], + "source": [ + "os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":16:8\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ce531994", + "metadata": {}, + "outputs": [], + "source": [ + "from collections import Counter\n", + "from itertools import chain\n", + "\n", + "import torch\n", + "import multiprocessing\n", + "from scipy.spatial import distance_matrix\n", + "import numpy as np" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f4a2d672", + "metadata": {}, + "outputs": [], + "source": [ + "from torch.utils.data import DataLoader\n", + "\n", + "from datasets import SLREmbeddingDataset, collate_fn_padd\n", + "from datasets.dataset_loader import LocalDatasetLoader\n", + "from models import embeddings_scatter_plot_splits\n", + "from models import SPOTER_EMBEDDINGS" + ] + }, + { + "cell_type": "markdown", + "id": "af8fbe32", + "metadata": {}, + "source": [ + "## Model and dataset loading" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1d9db764", + "metadata": {}, + "outputs": [], + "source": [ + "import random\n", + "seed = 43\n", + "random.seed(seed)\n", + "np.random.seed(seed)\n", + "os.environ[\"PYTHONHASHSEED\"] = str(seed)\n", + "torch.manual_seed(seed)\n", + "torch.cuda.manual_seed(seed)\n", + "torch.cuda.manual_seed_all(seed)\n", + "torch.backends.cudnn.deterministic = True\n", + "torch.use_deterministic_algorithms(True) \n", + "generator = torch.Generator()\n", + "generator.manual_seed(seed)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "71224139", + "metadata": {}, + "outputs": [], + "source": [ + "BASE_DATA_FOLDER = '../data/'\n", + "os.environ[\"BASE_DATA_FOLDER\"] = BASE_DATA_FOLDER\n", + "device = torch.device(\"cpu\")\n", + "if torch.cuda.is_available():\n", + " device = torch.device(\"cuda\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": 
"013d3774", + "metadata": {}, + "outputs": [], + "source": [ + "# LOAD MODEL FROM CLEARML\n", + "# from clearml import InputModel\n", + "# model = InputModel(model_id='1b736da469b04e91b8451d2342aef6ce')\n", + "# checkpoint = torch.load(model.get_weights())\n", + "\n", + "## Set your path to checkoint here\n", + "CHECKPOINT_PATH = \"../checkpoints/checkpoint_embed_992.pth\"\n", + "checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)\n", + "\n", + "model = SPOTER_EMBEDDINGS(\n", + " features=checkpoint[\"config_args\"].vector_length,\n", + " hidden_dim=checkpoint[\"config_args\"].hidden_dim,\n", + " norm_emb=checkpoint[\"config_args\"].normalize_embeddings,\n", + ").to(device)\n", + "\n", + "model.load_state_dict(checkpoint[\"state_dict\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ba6b58f0", + "metadata": {}, + "outputs": [], + "source": [ + "SL_DATASET = 'wlasl' # or 'lsa'\n", + "if SL_DATASET == 'wlasl':\n", + " dataset_name = \"wlasl_mapped_mediapipe_only_landmarks_25fps\"\n", + " num_classes = 100\n", + " split_dataset_path = \"WLASL100_{}_25fps.csv\"\n", + "else:\n", + " dataset_name = \"lsa64_mapped_mediapipe_only_landmarks_25fps\"\n", + " num_classes = 64\n", + " split_dataset_path = \"LSA64_{}.csv\"\n", + " \n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5643a72c", + "metadata": {}, + "outputs": [], + "source": [ + "def get_dataset_loader(loader_name=None):\n", + " if loader_name == 'CLEARML':\n", + " from datasets.clearml_dataset_loader import ClearMLDatasetLoader\n", + " return ClearMLDatasetLoader()\n", + " else:\n", + " return LocalDatasetLoader()\n", + "\n", + "dataset_loader = get_dataset_loader()\n", + "dataset_project = \"Sign Language Recognition\"\n", + "batch_size = 1\n", + "dataset_folder = dataset_loader.get_dataset_folder(dataset_project, dataset_name)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "04a62088", + "metadata": {}, + "outputs": [], + "source": [ + "def seed_worker(worker_id):\n", + " worker_seed = torch.initial_seed() % 2**32\n", + " np.random.seed(worker_seed)\n", + " random.seed(worker_seed)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "79c837c1", + "metadata": {}, + "outputs": [], + "source": [ + "dataloaders = {}\n", + "splits = ['train', 'val']\n", + "dfs = {}\n", + "for split in splits:\n", + " split_set_path = op.join(dataset_folder, split_dataset_path.format(split))\n", + " split_set = SLREmbeddingDataset(split_set_path, triplet=False, augmentations=False)\n", + " data_loader = DataLoader(\n", + " split_set,\n", + " batch_size=batch_size,\n", + " shuffle=False,\n", + " collate_fn=collate_fn_padd,\n", + " pin_memory=torch.cuda.is_available(),\n", + " num_workers=multiprocessing.cpu_count(),\n", + " worker_init_fn=seed_worker,\n", + " generator=generator,\n", + " )\n", + " dataloaders[split] = data_loader\n", + " dfs[split] = pd.read_csv(split_set_path)\n", + "\n", + "with open(op.join(dataset_folder, 'id_to_label.json')) as fid:\n", + " id_to_label = json.load(fid)\n", + "id_to_label = {int(key): value for key, value in id_to_label.items()}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8b5bda73", + "metadata": {}, + "outputs": [], + "source": [ + "labels_split = {}\n", + "embeddings_split = {}\n", + "splits = list(dataloaders.keys())\n", + "with torch.no_grad():\n", + " for split, dataloader in dataloaders.items():\n", + " labels_str = []\n", + " embeddings = []\n", + " k = 0\n", + " for i, 
(inputs, labels, masks) in enumerate(dataloader):\n", + " k += 1\n", + " inputs = inputs.to(device)\n", + " masks = masks.to(device)\n", + " outputs = model(inputs, masks)\n", + " for n in range(outputs.shape[0]):\n", + " embeddings.append(outputs[n, 0].cpu().detach().numpy())\n", + " embeddings_split[split] = embeddings" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0efa0871", + "metadata": {}, + "outputs": [], + "source": [ + "len(embeddings_split['train']), len(dfs['train'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ab83c6e2", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "for split in splits:\n", + " df = dfs[split]\n", + " df['embeddings'] = embeddings_split[split]" + ] + }, + { + "cell_type": "markdown", + "id": "2951638d", + "metadata": {}, + "source": [ + "## Compute metrics\n", + "Here computing top1 and top5 metrics either by using only a class centroid or by using the whole dataset to classify vectors.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7399b8ae", + "metadata": {}, + "outputs": [], + "source": [ + "for use_centroids, str_use_centroids in zip([True, False],\n", + " ['Using centroids only', 'Using all embeddings']):\n", + "\n", + " df_val = dfs['val']\n", + " df_train = dfs['train']\n", + " if use_centroids:\n", + " df_train = dfs['train'].groupby('labels')['embeddings'].apply(np.mean).reset_index()\n", + " x_train = np.vstack(df_train['embeddings'])\n", + " x_val = np.vstack(df_val['embeddings'])\n", + "\n", + " d_mat = distance_matrix(x_val, x_train, p=2)\n", + "\n", + " top5_embs = 0\n", + " top5_classes = 0\n", + " knn = 0\n", + " top1 = 0\n", + "\n", + " len_val_dataset = len(df_val)\n", + " good_samples = []\n", + "\n", + " for i in range(d_mat.shape[0]):\n", + " true_label = df_val.loc[i, 'labels']\n", + " labels = df_train['labels'].values\n", + " argsort = np.argsort(d_mat[i])\n", + " sorted_labels = labels[argsort]\n", + " if sorted_labels[0] == true_label:\n", + " top1 += 1\n", + " if use_centroids:\n", + " good_samples.append(df_val.loc[i, 'video_id'])\n", + " else:\n", + " good_samples.append((df_val.loc[i, 'video_id'],\n", + " df_train.loc[argsort[0], 'video_id'],\n", + " i,\n", + " argsort[0]))\n", + "\n", + "\n", + " if true_label == Counter(sorted_labels[:5]).most_common()[0][0]:\n", + " knn += 1\n", + " if true_label in sorted_labels[:5]:\n", + " top5_embs += 1\n", + " if true_label in list(dict.fromkeys(sorted_labels))[:5]:\n", + " top5_classes += 1\n", + " else:\n", + " continue\n", + "\n", + "\n", + " print(str_use_centroids)\n", + "\n", + "\n", + " print(f'Top-1 accuracy: {100 * top1 / len_val_dataset : 0.2f} %')\n", + " if not use_centroids:\n", + " print(f'5-nn accuracy: {100 * knn / len_val_dataset : 0.2f} % (Picks the class that appears most often in the 5 closest embeddings)')\n", + " print(f'Top-5 embeddings class match: {100 * top5_embs / len_val_dataset: 0.2f} % (Picks any class in the 5 closest embeddings)')\n", + " if not use_centroids:\n", + " print(f'Top-5 unique class match: {100 * top5_classes / len_val_dataset: 0.2f} % (Picks the 5 closest distinct classes)')\n", + " print('\\n' + '#'*32 + '\\n')" + ] + }, + { + "cell_type": "markdown", + "id": "d2aaac6c", + "metadata": {}, + "source": [ + "## Show some examples (only for WLASL)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b9d1d309", + "metadata": {}, + "outputs": [], + "source": [ + "from IPython.display import Video" + ] + }, + { + 
"cell_type": "code", + "execution_count": null, + "id": "fd2a0cd8", + "metadata": {}, + "outputs": [], + "source": [ + "for row in df_train[df_train.label_name == 'thursday'][:3].itertuples():\n", + " display(Video(op.join(BASE_DATA_FOLDER, f'wlasl/videos/{row.video_id}.mp4'), embed=True))" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/visualize_embeddings.ipynb b/notebooks/visualize_embeddings.ipynb new file mode 100644 index 0000000..c49a9fa --- /dev/null +++ b/notebooks/visualize_embeddings.ipynb @@ -0,0 +1,491 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "8ef5cd92", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "78c4643a", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import sys\n", + "import os.path as op\n", + "import pandas as pd\n", + "import json\n", + "import base64" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "ffba4333", + "metadata": {}, + "outputs": [], + "source": [ + "sys.path.append(op.abspath('..'))" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "5bc81f71", + "metadata": {}, + "outputs": [], + "source": [ + "os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":16:8\"" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "3de8bcf2", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "import torch\n", + "import multiprocessing\n", + "from itertools import chain\n", + "import numpy as np\n", + "import random" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "91a045ba", + "metadata": {}, + "outputs": [], + "source": [ + "from bokeh.io import output_notebook, output_file\n", + "from bokeh.plotting import figure, show\n", + "from bokeh.models import LinearColorMapper, ColumnDataSource\n", + "from bokeh.transform import factor_cmap, factor_mark\n", + "from torch.utils.data import DataLoader\n", + "\n", + "\n", + "from datasets import SLREmbeddingDataset, collate_fn_padd\n", + "from datasets.dataset_loader import LocalDatasetLoader\n", + "from models import embeddings_scatter_plot_splits\n", + "from models import SPOTER_EMBEDDINGS" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "bc50c296", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "seed = 43\n", + "random.seed(seed)\n", + "np.random.seed(seed)\n", + "os.environ[\"PYTHONHASHSEED\"] = str(seed)\n", + "torch.manual_seed(seed)\n", + "torch.cuda.manual_seed(seed)\n", + "torch.cuda.manual_seed_all(seed)\n", + "torch.backends.cudnn.deterministic = True\n", + "torch.use_deterministic_algorithms(True) \n", + "generator = torch.Generator()\n", + "generator.manual_seed(seed)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "82766a17", + "metadata": {}, + "outputs": 
[], + "source": [ + "BASE_DATA_FOLDER = '../data/'\n", + "os.environ[\"BASE_DATA_FOLDER\"] = BASE_DATA_FOLDER\n", + "device = torch.device(\"cpu\")\n", + "if torch.cuda.is_available():\n", + " device = torch.device(\"cuda\")" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "ead15a36", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# LOAD MODEL FROM CLEARML\n", + "# from clearml import InputModel\n", + "# model = InputModel(model_id='1b736da469b04e91b8451d2342aef6ce')\n", + "# checkpoint = torch.load(model.get_weights())\n", + "\n", + "\n", + "CHECKPOINT_PATH = \"../checkpoints/checkpoint_embed_992.pth\"\n", + "checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)\n", + "\n", + "\n", + "model = SPOTER_EMBEDDINGS(\n", + " features=checkpoint[\"config_args\"].vector_length,\n", + " hidden_dim=checkpoint[\"config_args\"].hidden_dim,\n", + " norm_emb=checkpoint[\"config_args\"].normalize_embeddings,\n", + ").to(device)\n", + "\n", + "model.load_state_dict(checkpoint[\"state_dict\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "20f8036d", + "metadata": {}, + "outputs": [], + "source": [ + "SL_DATASET = 'wlasl' # or 'lsa'\n", + "if SL_DATASET == 'wlasl':\n", + " dataset_name = \"wlasl_mapped_mediapipe_only_landmarks_25fps\"\n", + " num_classes = 100\n", + " split_dataset_path = \"WLASL100_{}_25fps.csv\"\n", + "else:\n", + " dataset_name = \"lsa64_mapped_mediapipe_only_landmarks_25fps\"\n", + " num_classes = 64\n", + " split_dataset_path = \"LSA64_{}.csv\"\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "758716b6", + "metadata": {}, + "outputs": [], + "source": [ + "def get_dataset_loader(loader_name=None):\n", + " if loader_name == 'CLEARML':\n", + " from datasets.clearml_dataset_loader import ClearMLDatasetLoader\n", + " return ClearMLDatasetLoader()\n", + " else:\n", + " return LocalDatasetLoader()\n", + "\n", + "dataset_loader = get_dataset_loader()\n", + "dataset_project = \"Sign Language Recognition\"\n", + "batch_size = 1\n", + "dataset_folder = dataset_loader.get_dataset_folder(dataset_project, dataset_name)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "f1527959", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/conda/lib/python3.7/site-packages/sklearn/manifold/_t_sne.py:783: FutureWarning: The default initialization in TSNE will change from 'random' to 'pca' in 1.2.\n", + " FutureWarning,\n", + "/opt/conda/lib/python3.7/site-packages/sklearn/manifold/_t_sne.py:793: FutureWarning: The default learning rate in TSNE will change from 200.0 to 'auto' in 1.2.\n", + " FutureWarning,\n" + ] + } + ], + "source": [ + "dataloaders = {}\n", + "splits = ['train', 'val']\n", + "dfs = {}\n", + "for split in splits:\n", + " split_set_path = op.join(dataset_folder, split_dataset_path.format(split))\n", + " split_set = SLREmbeddingDataset(split_set_path, triplet=False)\n", + " data_loader = DataLoader(\n", + " split_set,\n", + " batch_size=batch_size,\n", + " shuffle=False,\n", + " collate_fn=collate_fn_padd,\n", + " pin_memory=torch.cuda.is_available(),\n", + " num_workers=multiprocessing.cpu_count()\n", + " )\n", + " dataloaders[split] = data_loader\n", + " dfs[split] = pd.read_csv(split_set_path)\n", + "\n", + "with open(op.join(dataset_folder, 'id_to_label.json')) as fid:\n", + " id_to_label = 
json.load(fid)\n", + "id_to_label = {int(key): value for key, value in id_to_label.items()}\n", + "\n", + "tsne_results, labels_results = embeddings_scatter_plot_splits(model,\n", + " dataloaders,\n", + " device,\n", + " id_to_label,\n", + " perplexity=40,\n", + " n_iter=1000)\n", + "\n", + "\n", + "set_labels = list(set(next(chain(labels_results.values()))))" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "3c3af5bf", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [ + { + "data": { + "text/plain": [ + "1533" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dfs = {}\n", + "for split in splits:\n", + " split_set_path = op.join(dataset_folder, split_dataset_path.format(split))\n", + " df = pd.read_csv(split_set_path)\n", + " df['tsne_x'] = tsne_results[split][:, 0]\n", + " df['tsne_y'] = tsne_results[split][:, 1]\n", + " df['split'] = split\n", + " if SL_DATASET == 'wlasl':\n", + " df['video_fn'] = df['video_id'].apply(lambda video_id: os.path.join(BASE_DATA_FOLDER, f'wlasl/videos/{video_id:05d}.mp4'))\n", + " else:\n", + " df['video_fn'] = df['video_id'].apply(lambda video_id: os.path.join(BASE_DATA_FOLDER, f'lsa/videos/{video_id}.mp4'))\n", + " dfs[split] = df\n", + "\n", + "df = pd.concat([dfs['train'].sample(100), dfs['val']]).reset_index(drop=True)\n", + "len(df)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "dccbe1b9", + "metadata": {}, + "outputs": [], + "source": [ + "from tqdm.auto import tqdm\n", + "\n", + "def load_videos(video_list):\n", + " print('loading videos')\n", + " videos = []\n", + " for video_fn in tqdm(video_list):\n", + " if video_fn is None:\n", + " video_data = None\n", + " else:\n", + " with open(video_fn, 'rb') as fid:\n", + " video_data = base64.b64encode(fid.read()).decode()\n", + " videos.append(video_data)\n", + " print('Done loading videos')\n", + " return videos" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "904298f0", + "metadata": {}, + "outputs": [], + "source": [ + "use_img_div = False\n", + "if use_img_div:\n", + " # sample dataframe data to avoid overloading scatter plot with too many videos\n", + " df = df.loc[(df['tsne_x'] > 10) & (df['tsne_x'] < 20)]\n", + " df = df.loc[(df['tsne_y'] > 10) & (df['tsne_y'] < 20)]" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "42832f7c", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + " \n", + " Loading BokehJS ...\n", + "
\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": "(function(root) {\n function now() {\n return new Date();\n }\n\n const force = true;\n\n if (typeof root._bokeh_onload_callbacks === \"undefined\" || force === true) {\n root._bokeh_onload_callbacks = [];\n root._bokeh_is_loading = undefined;\n }\n\nconst JS_MIME_TYPE = 'application/javascript';\n const HTML_MIME_TYPE = 'text/html';\n const EXEC_MIME_TYPE = 'application/vnd.bokehjs_exec.v0+json';\n const CLASS_NAME = 'output_bokeh rendered_html';\n\n /**\n * Render data to the DOM node\n */\n function render(props, node) {\n const script = document.createElement(\"script\");\n node.appendChild(script);\n }\n\n /**\n * Handle when an output is cleared or removed\n */\n function handleClearOutput(event, handle) {\n const cell = handle.cell;\n\n const id = cell.output_area._bokeh_element_id;\n const server_id = cell.output_area._bokeh_server_id;\n // Clean up Bokeh references\n if (id != null && id in Bokeh.index) {\n Bokeh.index[id].model.document.clear();\n delete Bokeh.index[id];\n }\n\n if (server_id !== undefined) {\n // Clean up Bokeh references\n const cmd_clean = \"from bokeh.io.state import curstate; print(curstate().uuid_to_server['\" + server_id + \"'].get_sessions()[0].document.roots[0]._id)\";\n cell.notebook.kernel.execute(cmd_clean, {\n iopub: {\n output: function(msg) {\n const id = msg.content.text.trim();\n if (id in Bokeh.index) {\n Bokeh.index[id].model.document.clear();\n delete Bokeh.index[id];\n }\n }\n }\n });\n // Destroy server and session\n const cmd_destroy = \"import bokeh.io.notebook as ion; ion.destroy_server('\" + server_id + \"')\";\n cell.notebook.kernel.execute(cmd_destroy);\n }\n }\n\n /**\n * Handle when a new output is added\n */\n function handleAddOutput(event, handle) {\n const output_area = handle.output_area;\n const output = handle.output;\n\n // limit handleAddOutput to display_data with EXEC_MIME_TYPE content only\n if ((output.output_type != \"display_data\") || (!Object.prototype.hasOwnProperty.call(output.data, EXEC_MIME_TYPE))) {\n return\n }\n\n const toinsert = output_area.element.find(\".\" + CLASS_NAME.split(' ')[0]);\n\n if (output.metadata[EXEC_MIME_TYPE][\"id\"] !== undefined) {\n toinsert[toinsert.length - 1].firstChild.textContent = output.data[JS_MIME_TYPE];\n // store reference to embed id on output_area\n output_area._bokeh_element_id = output.metadata[EXEC_MIME_TYPE][\"id\"];\n }\n if (output.metadata[EXEC_MIME_TYPE][\"server_id\"] !== undefined) {\n const bk_div = document.createElement(\"div\");\n bk_div.innerHTML = output.data[HTML_MIME_TYPE];\n const script_attrs = bk_div.children[0].attributes;\n for (let i = 0; i < script_attrs.length; i++) {\n toinsert[toinsert.length - 1].firstChild.setAttribute(script_attrs[i].name, script_attrs[i].value);\n toinsert[toinsert.length - 1].firstChild.textContent = bk_div.children[0].textContent\n }\n // store reference to server id on output_area\n output_area._bokeh_server_id = output.metadata[EXEC_MIME_TYPE][\"server_id\"];\n }\n }\n\n function register_renderer(events, OutputArea) {\n\n function append_mime(data, metadata, element) {\n // create a DOM node to render to\n const toinsert = this.create_output_subarea(\n metadata,\n CLASS_NAME,\n EXEC_MIME_TYPE\n );\n this.keyboard_manager.register_events(toinsert);\n // Render to node\n const props = {data: data, metadata: metadata[EXEC_MIME_TYPE]};\n render(props, toinsert[toinsert.length - 1]);\n element.append(toinsert);\n 
return toinsert\n }\n\n /* Handle when an output is cleared or removed */\n events.on('clear_output.CodeCell', handleClearOutput);\n events.on('delete.Cell', handleClearOutput);\n\n /* Handle when a new output is added */\n events.on('output_added.OutputArea', handleAddOutput);\n\n /**\n * Register the mime type and append_mime function with output_area\n */\n OutputArea.prototype.register_mime_type(EXEC_MIME_TYPE, append_mime, {\n /* Is output safe? */\n safe: true,\n /* Index of renderer in `output_area.display_order` */\n index: 0\n });\n }\n\n // register the mime type if in Jupyter Notebook environment and previously unregistered\n if (root.Jupyter !== undefined) {\n const events = require('base/js/events');\n const OutputArea = require('notebook/js/outputarea').OutputArea;\n\n if (OutputArea.prototype.mime_types().indexOf(EXEC_MIME_TYPE) == -1) {\n register_renderer(events, OutputArea);\n }\n }\n if (typeof (root._bokeh_timeout) === \"undefined\" || force === true) {\n root._bokeh_timeout = Date.now() + 5000;\n root._bokeh_failed_load = false;\n }\n\n const NB_LOAD_WARNING = {'data': {'text/html':\n \"
\\n\"+\n \"

\\n\"+\n \"BokehJS does not appear to have successfully loaded. If loading BokehJS from CDN, this \\n\"+\n \"may be due to a slow or bad network connection. Possible fixes:\\n\"+\n \"

\\n\"+\n \"
    \\n\"+\n \"
  • re-rerun `output_notebook()` to attempt to load from CDN again, or
  • \\n\"+\n \"
  • use INLINE resources instead, as so:
  • \\n\"+\n \"
\\n\"+\n \"\\n\"+\n \"from bokeh.resources import INLINE\\n\"+\n \"output_notebook(resources=INLINE)\\n\"+\n \"\\n\"+\n \"
\"}};\n\n function display_loaded() {\n const el = document.getElementById(\"1107\");\n if (el != null) {\n el.textContent = \"BokehJS is loading...\";\n }\n if (root.Bokeh !== undefined) {\n if (el != null) {\n el.textContent = \"BokehJS \" + root.Bokeh.version + \" successfully loaded.\";\n }\n } else if (Date.now() < root._bokeh_timeout) {\n setTimeout(display_loaded, 100)\n }\n }\n\n function run_callbacks() {\n try {\n root._bokeh_onload_callbacks.forEach(function(callback) {\n if (callback != null)\n callback();\n });\n } finally {\n delete root._bokeh_onload_callbacks\n }\n console.debug(\"Bokeh: all callbacks have finished\");\n }\n\n function load_libs(css_urls, js_urls, callback) {\n if (css_urls == null) css_urls = [];\n if (js_urls == null) js_urls = [];\n\n root._bokeh_onload_callbacks.push(callback);\n if (root._bokeh_is_loading > 0) {\n console.debug(\"Bokeh: BokehJS is being loaded, scheduling callback at\", now());\n return null;\n }\n if (js_urls == null || js_urls.length === 0) {\n run_callbacks();\n return null;\n }\n console.debug(\"Bokeh: BokehJS not loaded, scheduling load and callback at\", now());\n root._bokeh_is_loading = css_urls.length + js_urls.length;\n\n function on_load() {\n root._bokeh_is_loading--;\n if (root._bokeh_is_loading === 0) {\n console.debug(\"Bokeh: all BokehJS libraries/stylesheets loaded\");\n run_callbacks()\n }\n }\n\n function on_error(url) {\n console.error(\"failed to load \" + url);\n }\n\n for (let i = 0; i < css_urls.length; i++) {\n const url = css_urls[i];\n const element = document.createElement(\"link\");\n element.onload = on_load;\n element.onerror = on_error.bind(null, url);\n element.rel = \"stylesheet\";\n element.type = \"text/css\";\n element.href = url;\n console.debug(\"Bokeh: injecting link tag for BokehJS stylesheet: \", url);\n document.body.appendChild(element);\n }\n\n for (let i = 0; i < js_urls.length; i++) {\n const url = js_urls[i];\n const element = document.createElement('script');\n element.onload = on_load;\n element.onerror = on_error.bind(null, url);\n element.async = false;\n element.src = url;\n console.debug(\"Bokeh: injecting script tag for BokehJS library: \", url);\n document.head.appendChild(element);\n }\n };\n\n function inject_raw_css(css) {\n const element = document.createElement(\"style\");\n element.appendChild(document.createTextNode(css));\n document.body.appendChild(element);\n }\n\n const js_urls = [\"https://cdn.bokeh.org/bokeh/release/bokeh-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-gl-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-widgets-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-tables-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-mathjax-2.4.3.min.js\"];\n const css_urls = [];\n\n const inline_js = [ function(Bokeh) {\n Bokeh.set_log_level(\"info\");\n },\nfunction(Bokeh) {\n }\n ];\n\n function run_inline_js() {\n if (root.Bokeh !== undefined || force === true) {\n for (let i = 0; i < inline_js.length; i++) {\n inline_js[i].call(root, root.Bokeh);\n }\nif (force === true) {\n display_loaded();\n }} else if (Date.now() < root._bokeh_timeout) {\n setTimeout(run_inline_js, 100);\n } else if (!root._bokeh_failed_load) {\n console.log(\"Bokeh: BokehJS failed to load within specified timeout.\");\n root._bokeh_failed_load = true;\n } else if (force !== true) {\n const cell = $(document.getElementById(\"1107\")).parents('.cell').data().cell;\n cell.output_area.append_execute_result(NB_LOAD_WARNING)\n }\n }\n\n if 
(root._bokeh_is_loading === 0) {\n console.debug(\"Bokeh: BokehJS loaded, going straight to plotting\");\n run_inline_js();\n } else {\n load_libs(css_urls, js_urls, function() {\n console.debug(\"Bokeh: BokehJS plotting callback run at\", now());\n run_inline_js();\n });\n }\n}(window));", + "application/vnd.bokehjs_load.v0+json": "" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "img_div = '''\n", + "
\n", + " \n", + "
\n", + "'''\n", + "TOOLTIPS = f\"\"\"\n", + "
\n", + " {img_div if use_img_div else ''}\n", + "
\n", + " @label_desc - @split\n", + " [#@video_id]\n", + "
\n", + "
\n", + " \n", + "\"\"\"\n", + "cmap = LinearColorMapper(palette=\"Turbo256\", low=0, high=len(set_labels))\n", + "\n", + "output_notebook()\n", + "# or \n", + "# output_file(\"scatter_plot.html\")\n", + "\n", + "p = figure(width=1000,\n", + " height=800,\n", + " tooltips=TOOLTIPS,\n", + " title=f\"Check {'video' if use_img_div else 'label'} by hovering mouse over the dots\")" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "ead4daf7", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "
\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": "(function(root) {\n function embed_document(root) {\n const docs_json = {\"458dc79d-472d-4f36-92df-cc9c2d7d3fb7\":{\"defs\":[],\"roots\":{\"references\":[{\"attributes\":{\"below\":[{\"id\":\"1119\"}],\"center\":[{\"id\":\"1122\"},{\"id\":\"1126\"}],\"height\":800,\"left\":[{\"id\":\"1123\"}],\"renderers\":[{\"id\":\"1148\"}],\"title\":{\"id\":\"1109\"},\"toolbar\":{\"id\":\"1135\"},\"width\":1000,\"x_range\":{\"id\":\"1111\"},\"x_scale\":{\"id\":\"1115\"},\"y_range\":{\"id\":\"1113\"},\"y_scale\":{\"id\":\"1117\"}},\"id\":\"1108\",\"subtype\":\"Figure\",\"type\":\"Plot\"},{\"attributes\":{},\"id\":\"1128\",\"type\":\"WheelZoomTool\"},{\"attributes\":{},\"id\":\"1165\",\"type\":\"BasicTickFormatter\"},{\"attributes\":{},\"id\":\"1163\",\"type\":\"AllLabels\"},{\"attributes\":{},\"id\":\"1124\",\"type\":\"BasicTicker\"},{\"attributes\":{\"data\":{\"label\":[53,51,57,89,19,27,51,42,18,18,60,61,94,35,72,51,80,1,40,50,10,38,87,62,19,36,49,32,75,72,98,7,38,59,14,59,3,94,68,67,2,30,63,3,8,28,82,80,35,46,19,57,82,9,51,69,69,38,5,3,42,26,89,1,92,85,33,92,53,34,42,52,51,18,49,50,22,89,7,24,94,73,31,61,63,23,88,73,76,80,0,98,11,13,82,47,65,73,66,57,71,52,87,62,3,19,89,82,68,34,95,2,34,93,22,93,5,58,92,86,31,52,10,42,36,13,42,1,35,23,89,16,7,9,59,13,45,67,48,74,53,26,11,36,92,63,87,48,49,16,73,7,87,94,93,97,61,44,97,6,80,60,13,59,47,53,5,18,10,13,70,19,35,55,30,81,99,7,86,40,94,12,15,46,0,88,31,62,13,20,65,19,34,7,81,84,48,67,38,78,91,41,2,75,49,43,26,44,18,94,34,26,42,87,93,41,7,76,56,1,19,67,40,16,60,10,81,85,32,39,64,25,62,14,28,39,99,88,73,53,71,14,69,67,3,73,53,83,2,45,85,58,16,32,35,37,16,62,3,4,76,15,17,48,37,99,64,47,69,95,69,5,72,44,24,49,66,61,8,38,49,33,40,53,24,92,21,40,47,50,97,34,57,26,80,72,19,31,5,10,98,81,88,25,11,11,40,82,79,81,25,70,37,73,12,72,15,12,99,48,71,81,3,4,50,32,3,91,50,97,74,9,45,46,41,67,63,37,57,50,8,66,83,13,97,34,54,83,2,48,18,51,5,60,42,5,97,27,20,27,70,37,37,36,51,88,11,90,10,39,85,44,18,41,98,3,29,72,49,96,10,21,39,91,3,15,22,21,79,32,52,92,45,48,29,50,38,70,26,3,67,17,23,51,32,45,66,61,98,17,49,69,68,43,44,5,82,20,84,8,27,80,46,22,70,44,33,50,77,31,32,34,64,39,4,55,16,36,76,59,45,25,70,95,21,66,76,1,35,9,85,57,99,69,50,5,90,90,20,78,96,12,43,70,4,78,93,35,17,35,55,52,99,23,3,10,36,75,52,91,76,35,86,74,78,37,5,28,17,69,10,2,47,95,77,70,36,80,93,43,91,14,22,13,90,71,0,27,55,65,54,40,9,26,0,77,75,23,30,11,22,55,13,33,59,13,56,66,14,68,16,9,78,75,61,85,82,52,66,12,51,38,15,51,8,0,22,83,57,41,48,99,28,11,91,55,12,15,84,9,22,24,41,21,73,64,27,88,1,60,22,64,52,67,33,35,92,17,63,4,2,82,65,56,17,40,52,67,40,77,54,88,29,56,25,42,23,38,7,82,64,44,42,18,20,19,14,29,71,33,69,46,41,59,94,33,49,76,11,42,64,96,66,32,1,12,79,75,15,6,35,2,4,24,27,44,93,23,22,62,50,72,45,78,55,87,58,84,63,6,96,1,9,41,32,65,24,30,73,71,65,87,30,78,8,54,83,77,20,29,68,32,71,27,98,0,80,24,23,34,7,10,9,6,58,14,7,54,8,72,17,0,69,54,85,75,6,85,47,87,72,71,47,68,44,6,56,16,31,33,96,77,86,43,12,49,39,92,26,24,21,5,59,58,4,93,90,86,30,8,40,59,55,79,18,21,70,15,60,86,63,45,29,43,84,89,64,15,45,55,2,47,41,23,48,53,66,4,85,65,15,39,65,84,12,12,42,66,19,40,30,35,53,65,23,1,48,86,90,30,57,43,66,68,5,68,9,5,11,21,92,95,31,95,41,90,79,34,98,43,46,36,5,52,2,90,74,36,29,17,63,99,58,22,56,57,20,23,16,81,37,1,88,94,85,71,14,60,88,18,38,14,21,78,17,94,46,32,69,76,71,47,36,90,37,25,54,98,35,74,17,27,1,18,8,74,3,42,77,23,96,47,58,20,74,5,44,54,34,82,54,31,33,43,2,63,68,83,61,9,0,57,33,54,14,27,79,56,26,45,28,81
,55,68,75,85,27,6,12,25,10,88,20,22,62,71,76,96,77,46,88,57,8,27,90,17,23,79,8,48,66,93,24,2,36,77,48,13,75,90,25,2,9,84,31,20,43,2,36,83,12,25,60,96,14,52,43,51,24,49,20,49,27,76,46,25,6,75,9,65,91,12,89,96,10,86,71,56,33,39,74,46,84,74,1,7,41,56,93,95,3,64,41,64,74,20,40,62,79,16,68,86,8,96,28,91,18,74,24,46,72,53,47,46,64,45,94,11,17,70,4,99,26,23,24,34,99,95,29,97,85,97,92,87,99,82,96,31,47,80,57,34,22,10,25,29,0,46,30,26,91,78,20,29,21,57,62,19,24,38,98,2,67,63,19,32,44,60,24,75,39,1,59,77,83,39,2,28,83,13,84,79,67,97,14,50,93,49,31,84,79,29,84,48,59,68,56,74,10,50,25,30,19,83,31,58,89,64,3,39,37,58,83,29,95,53,26,16,1,25,97,81,68,92,21,42,72,7,13,8,9,91,1,2,79,4,89,23,58,5,11,15,38,28,81,62,5,32,63,53,50,3,7,61,67,14,98,54,1,1,14,18,85,99,69,46,43,20,63,97,37,85,59,58,94,47,22,97,23,9,5,13,87,72,11,22,52,0,13,11,22,0,70,1,39,9,38,29,96,7,14,51,71,74,22,17,48,41,90,61,97,5,36,42,83,45,91,32,68,76,84,78,66,49,67,63,29,5,9,19,50,10,79,51,86,80,2,76,88,7,89,15,4,46,47,41,57,39,16,81,63,17,82,71,27,3,49,64,50,95,61,87,1,21,73,74,24,48,25,44,19,8,20,68,19,6,28,59,98,46,32,53,9,30,76,90,75,73,79,77,38,36,58,83,1,16,46,14,21,35,10,7,96,73,18,70,42,93,24,61,72,64,13,15,58,76,31,65,59,82,90,38,32,42,99,57,33,37,10,69,95,10,41,26,45,33,80,21,51,41,68,22,47,26,88,71,17,24,99,74,24,14,91,47,13,52,36,34,18,36,60,55,27,98,12,57,18,65,60,64,16,79,99,98,40,12,44,1,81,93,48,77,23,62,72,8,35,54,5,94,59,75,28,67,30,67,23,97,12,15,92,84,17,3,60,86,37,27,2,49,49,2,12,13,20,56,74,2,25,14,52,55,84,75,23,91,95,29,16,33,44,92,93,54,29,24,6,4,3,26,45,0,43,44,40,82,9,20,21,11,69,82,5,87,86,70,4,56,34,19,92,48,53,56,34,30,57,75,28,77,88,3,55,66,8,89,53,7,27,25,40,93,40,50,68,25,27,78,35,43,67,66,2,2,54,85,71,37,42,66,63,32,38,34,65,62,45,26,51,39,23,17,81,72,10,83,62,31,53,85,52,25,8,31,80,1,90,89,96,64,11,5,50,31,54,94,78,55,68,34,30,3,35],\"label_desc\":[\"bird\",\"accident\",\"cow\",\"medicine\",\"black\",\"orange\",\"accident\",\"kiss\",\"all\",\"all\",\"doctor\",\"eat\",\"purple\",\"can\",\"short\",\"accident\",\"but\",\"drink\",\"hat\",\"wrong\",\"deaf\",\"fish\",\"jacket\",\"enjoy\",\"black\",\"dog\",\"white\",\"bed\",\"work\",\"short\",\"tell\",\"who\",\"fish\",\"dark\",\"thin\",\"dark\",\"before\",\"purple\",\"pizza\",\"pink\",\"computer\",\"what\",\"forget\",\"before\",\"candy\",\"table\",\"city\",\"but\",\"can\",\"shirt\",\"black\",\"cow\",\"city\",\"cousin\",\"accident\",\"play\",\"play\",\"fish\",\"go\",\"before\",\"kiss\",\"now\",\"medicine\",\"drink\",\"paper\",\"full\",\"blue\",\"paper\",\"bird\",\"bowling\",\"kiss\",\"apple\",\"accident\",\"all\",\"white\",\"wrong\",\"hot\",\"medicine\",\"who\",\"many\",\"purple\",\"time\",\"woman\",\"eat\",\"forget\",\"like\",\"letter\",\"time\",\"africa\",\"but\",\"book\",\"tell\",\"fine\",\"no\",\"city\",\"study\",\"last\",\"time\",\"meet\",\"cow\",\"secretary\",\"apple\",\"jacket\",\"enjoy\",\"before\",\"black\",\"medicine\",\"city\",\"pizza\",\"bowling\",\"right\",\"computer\",\"bowling\",\"pull\",\"hot\",\"pull\",\"go\",\"dance\",\"paper\",\"how\",\"woman\",\"apple\",\"deaf\",\"kiss\",\"dog\",\"no\",\"kiss\",\"drink\",\"can\",\"like\",\"medicine\",\"year\",\"who\",\"cousin\",\"dark\",\"no\",\"man\",\"pink\",\"tall\",\"want\",\"bird\",\"now\",\"fine\",\"dog\",\"paper\",\"forget\",\"jacket\",\"tall\",\"white\",\"year\",\"time\",\"who\",\"jacket\",\"purple\",\"pull\",\"son\",\"eat\",\"later\",\"son\",\"clothes\",\"but\",\"doctor\",\"no\",\"dark\",\"study\",\"bird\",\"go\",\"all\",\"deaf\",\"no\",\"school\",\"black\",\"can\",\"color\",\"what\",\"cheat\",\"thurs
day\",\"who\",\"how\",\"hat\",\"purple\",\"help\",\"walk\",\"shirt\",\"book\",\"letter\",\"woman\",\"enjoy\",\"no\",\"cool\",\"last\",\"black\",\"bowling\",\"who\",\"cheat\",\"decide\",\"tall\",\"pink\",\"fish\",\"birthday\",\"paint\",\"hearing\",\"computer\",\"work\",\"white\",\"language\",\"now\",\"later\",\"all\",\"purple\",\"bowling\",\"now\",\"kiss\",\"jacket\",\"pull\",\"hearing\",\"who\",\"africa\",\"corn\",\"drink\",\"black\",\"pink\",\"hat\",\"year\",\"doctor\",\"deaf\",\"cheat\",\"full\",\"bed\",\"graduate\",\"give\",\"mother\",\"enjoy\",\"thin\",\"table\",\"graduate\",\"thursday\",\"letter\",\"time\",\"bird\",\"secretary\",\"thin\",\"play\",\"pink\",\"before\",\"time\",\"bird\",\"cook\",\"computer\",\"man\",\"full\",\"dance\",\"year\",\"bed\",\"can\",\"family\",\"year\",\"enjoy\",\"before\",\"chair\",\"africa\",\"walk\",\"yes\",\"tall\",\"family\",\"thursday\",\"give\",\"study\",\"play\",\"right\",\"play\",\"go\",\"short\",\"later\",\"many\",\"white\",\"meet\",\"eat\",\"candy\",\"fish\",\"white\",\"blue\",\"hat\",\"bird\",\"many\",\"paper\",\"finish\",\"hat\",\"study\",\"wrong\",\"son\",\"bowling\",\"cow\",\"now\",\"but\",\"short\",\"black\",\"woman\",\"go\",\"deaf\",\"tell\",\"cheat\",\"letter\",\"mother\",\"fine\",\"fine\",\"hat\",\"city\",\"brown\",\"cheat\",\"mother\",\"school\",\"family\",\"time\",\"help\",\"short\",\"walk\",\"help\",\"thursday\",\"tall\",\"secretary\",\"cheat\",\"before\",\"chair\",\"wrong\",\"bed\",\"before\",\"paint\",\"wrong\",\"son\",\"want\",\"cousin\",\"man\",\"shirt\",\"hearing\",\"pink\",\"forget\",\"family\",\"cow\",\"wrong\",\"candy\",\"meet\",\"cook\",\"no\",\"son\",\"bowling\",\"change\",\"cook\",\"computer\",\"tall\",\"all\",\"accident\",\"go\",\"doctor\",\"kiss\",\"go\",\"son\",\"orange\",\"cool\",\"orange\",\"school\",\"family\",\"family\",\"dog\",\"accident\",\"letter\",\"fine\",\"need\",\"deaf\",\"graduate\",\"full\",\"later\",\"all\",\"hearing\",\"tell\",\"before\",\"thanksgiving\",\"short\",\"white\",\"same\",\"deaf\",\"finish\",\"graduate\",\"paint\",\"before\",\"walk\",\"hot\",\"finish\",\"brown\",\"bed\",\"apple\",\"paper\",\"man\",\"tall\",\"thanksgiving\",\"wrong\",\"fish\",\"school\",\"now\",\"before\",\"pink\",\"yes\",\"like\",\"accident\",\"bed\",\"man\",\"meet\",\"eat\",\"tell\",\"yes\",\"white\",\"play\",\"pizza\",\"language\",\"later\",\"go\",\"city\",\"cool\",\"decide\",\"candy\",\"orange\",\"but\",\"shirt\",\"hot\",\"school\",\"later\",\"blue\",\"wrong\",\"basketball\",\"woman\",\"bed\",\"bowling\",\"give\",\"graduate\",\"chair\",\"color\",\"year\",\"dog\",\"africa\",\"dark\",\"man\",\"mother\",\"school\",\"right\",\"finish\",\"meet\",\"africa\",\"drink\",\"can\",\"cousin\",\"full\",\"cow\",\"thursday\",\"play\",\"wrong\",\"go\",\"need\",\"need\",\"cool\",\"birthday\",\"same\",\"help\",\"language\",\"school\",\"chair\",\"birthday\",\"pull\",\"can\",\"yes\",\"can\",\"color\",\"apple\",\"thursday\",\"like\",\"before\",\"deaf\",\"dog\",\"work\",\"apple\",\"paint\",\"africa\",\"can\",\"how\",\"want\",\"birthday\",\"family\",\"go\",\"table\",\"yes\",\"play\",\"deaf\",\"computer\",\"study\",\"right\",\"basketball\",\"school\",\"dog\",\"but\",\"pull\",\"language\",\"paint\",\"thin\",\"hot\",\"no\",\"need\",\"secretary\",\"book\",\"orange\",\"color\",\"last\",\"change\",\"hat\",\"cousin\",\"now\",\"book\",\"basketball\",\"work\",\"like\",\"what\",\"fine\",\"hot\",\"color\",\"no\",\"blue\",\"dark\",\"no\",\"corn\",\"meet\",\"thin\",\"pizza\",\"year\",\"cousin\",\"birthday\",\"work\",\"eat\",\"full\",\"city\",\"apple\",\"meet\",\"help\
",\"accident\",\"fish\",\"walk\",\"accident\",\"candy\",\"book\",\"hot\",\"cook\",\"cow\",\"hearing\",\"tall\",\"thursday\",\"table\",\"fine\",\"paint\",\"color\",\"help\",\"walk\",\"decide\",\"cousin\",\"hot\",\"many\",\"hearing\",\"finish\",\"time\",\"give\",\"orange\",\"letter\",\"drink\",\"doctor\",\"hot\",\"give\",\"apple\",\"pink\",\"blue\",\"can\",\"paper\",\"yes\",\"forget\",\"chair\",\"computer\",\"city\",\"last\",\"corn\",\"yes\",\"hat\",\"apple\",\"pink\",\"hat\",\"basketball\",\"change\",\"letter\",\"thanksgiving\",\"corn\",\"mother\",\"kiss\",\"like\",\"fish\",\"who\",\"city\",\"give\",\"later\",\"kiss\",\"all\",\"cool\",\"black\",\"thin\",\"thanksgiving\",\"secretary\",\"blue\",\"play\",\"shirt\",\"hearing\",\"dark\",\"purple\",\"blue\",\"white\",\"africa\",\"fine\",\"kiss\",\"give\",\"same\",\"meet\",\"bed\",\"drink\",\"help\",\"brown\",\"work\",\"walk\",\"clothes\",\"can\",\"computer\",\"chair\",\"many\",\"orange\",\"later\",\"pull\",\"like\",\"hot\",\"enjoy\",\"wrong\",\"short\",\"man\",\"birthday\",\"color\",\"jacket\",\"dance\",\"decide\",\"forget\",\"clothes\",\"same\",\"drink\",\"cousin\",\"hearing\",\"bed\",\"last\",\"many\",\"what\",\"time\",\"secretary\",\"last\",\"jacket\",\"what\",\"birthday\",\"candy\",\"change\",\"cook\",\"basketball\",\"cool\",\"thanksgiving\",\"pizza\",\"bed\",\"secretary\",\"orange\",\"tell\",\"book\",\"but\",\"many\",\"like\",\"bowling\",\"who\",\"deaf\",\"cousin\",\"clothes\",\"dance\",\"thin\",\"who\",\"change\",\"candy\",\"short\",\"yes\",\"book\",\"play\",\"change\",\"full\",\"work\",\"clothes\",\"full\",\"study\",\"jacket\",\"short\",\"secretary\",\"study\",\"pizza\",\"later\",\"clothes\",\"corn\",\"year\",\"woman\",\"blue\",\"same\",\"basketball\",\"how\",\"language\",\"help\",\"white\",\"graduate\",\"paper\",\"now\",\"many\",\"finish\",\"go\",\"dark\",\"dance\",\"chair\",\"pull\",\"need\",\"how\",\"what\",\"candy\",\"hat\",\"dark\",\"color\",\"brown\",\"all\",\"finish\",\"school\",\"walk\",\"doctor\",\"how\",\"forget\",\"man\",\"thanksgiving\",\"language\",\"decide\",\"medicine\",\"give\",\"walk\",\"man\",\"color\",\"computer\",\"study\",\"hearing\",\"like\",\"tall\",\"bird\",\"meet\",\"chair\",\"full\",\"last\",\"walk\",\"graduate\",\"last\",\"decide\",\"help\",\"help\",\"kiss\",\"meet\",\"black\",\"hat\",\"what\",\"can\",\"bird\",\"last\",\"like\",\"drink\",\"tall\",\"how\",\"need\",\"what\",\"cow\",\"language\",\"meet\",\"pizza\",\"go\",\"pizza\",\"cousin\",\"go\",\"fine\",\"finish\",\"paper\",\"right\",\"woman\",\"right\",\"hearing\",\"need\",\"brown\",\"bowling\",\"tell\",\"language\",\"shirt\",\"dog\",\"go\",\"apple\",\"computer\",\"need\",\"want\",\"dog\",\"thanksgiving\",\"yes\",\"forget\",\"thursday\",\"dance\",\"hot\",\"corn\",\"cow\",\"cool\",\"like\",\"year\",\"cheat\",\"family\",\"drink\",\"letter\",\"purple\",\"full\",\"secretary\",\"thin\",\"doctor\",\"letter\",\"all\",\"fish\",\"thin\",\"finish\",\"birthday\",\"yes\",\"purple\",\"shirt\",\"bed\",\"play\",\"africa\",\"secretary\",\"study\",\"dog\",\"need\",\"family\",\"mother\",\"change\",\"tell\",\"can\",\"want\",\"yes\",\"orange\",\"drink\",\"all\",\"candy\",\"want\",\"before\",\"kiss\",\"basketball\",\"like\",\"same\",\"study\",\"dance\",\"cool\",\"want\",\"go\",\"later\",\"change\",\"bowling\",\"city\",\"change\",\"woman\",\"blue\",\"language\",\"computer\",\"forget\",\"pizza\",\"cook\",\"eat\",\"cousin\",\"book\",\"cow\",\"blue\",\"change\",\"thin\",\"orange\",\"brown\",\"corn\",\"now\",\"man\",\"table\",\"cheat\",\"color\",\"pizza\",\"work\",\"full\",\"orange\",\"c
lothes\",\"help\",\"mother\",\"deaf\",\"letter\",\"cool\",\"hot\",\"enjoy\",\"secretary\",\"africa\",\"same\",\"basketball\",\"shirt\",\"letter\",\"cow\",\"candy\",\"orange\",\"need\",\"yes\",\"like\",\"brown\",\"candy\",\"tall\",\"meet\",\"pull\",\"many\",\"computer\",\"dog\",\"basketball\",\"tall\",\"no\",\"work\",\"need\",\"mother\",\"computer\",\"cousin\",\"decide\",\"woman\",\"cool\",\"language\",\"computer\",\"dog\",\"cook\",\"help\",\"mother\",\"doctor\",\"same\",\"thin\",\"apple\",\"language\",\"accident\",\"many\",\"white\",\"cool\",\"white\",\"orange\",\"africa\",\"shirt\",\"mother\",\"clothes\",\"work\",\"cousin\",\"last\",\"paint\",\"help\",\"medicine\",\"same\",\"deaf\",\"how\",\"secretary\",\"corn\",\"blue\",\"graduate\",\"want\",\"shirt\",\"decide\",\"want\",\"drink\",\"who\",\"hearing\",\"corn\",\"pull\",\"right\",\"before\",\"give\",\"hearing\",\"give\",\"want\",\"cool\",\"hat\",\"enjoy\",\"brown\",\"year\",\"pizza\",\"how\",\"candy\",\"same\",\"table\",\"paint\",\"all\",\"want\",\"many\",\"shirt\",\"short\",\"bird\",\"study\",\"shirt\",\"give\",\"man\",\"purple\",\"fine\",\"yes\",\"school\",\"chair\",\"thursday\",\"now\",\"like\",\"many\",\"bowling\",\"thursday\",\"right\",\"thanksgiving\",\"son\",\"full\",\"son\",\"paper\",\"jacket\",\"thursday\",\"city\",\"same\",\"woman\",\"study\",\"but\",\"cow\",\"bowling\",\"hot\",\"deaf\",\"mother\",\"thanksgiving\",\"book\",\"shirt\",\"what\",\"now\",\"paint\",\"birthday\",\"cool\",\"thanksgiving\",\"finish\",\"cow\",\"enjoy\",\"black\",\"many\",\"fish\",\"tell\",\"computer\",\"pink\",\"forget\",\"black\",\"bed\",\"later\",\"doctor\",\"many\",\"work\",\"graduate\",\"drink\",\"dark\",\"basketball\",\"cook\",\"graduate\",\"computer\",\"table\",\"cook\",\"no\",\"decide\",\"brown\",\"pink\",\"son\",\"thin\",\"wrong\",\"pull\",\"white\",\"woman\",\"decide\",\"brown\",\"thanksgiving\",\"decide\",\"tall\",\"dark\",\"pizza\",\"corn\",\"want\",\"deaf\",\"wrong\",\"mother\",\"what\",\"black\",\"cook\",\"woman\",\"dance\",\"medicine\",\"give\",\"before\",\"graduate\",\"family\",\"dance\",\"cook\",\"thanksgiving\",\"right\",\"bird\",\"now\",\"year\",\"drink\",\"mother\",\"son\",\"cheat\",\"pizza\",\"paper\",\"finish\",\"kiss\",\"short\",\"who\",\"no\",\"candy\",\"cousin\",\"paint\",\"drink\",\"computer\",\"brown\",\"chair\",\"medicine\",\"like\",\"dance\",\"go\",\"fine\",\"walk\",\"fish\",\"table\",\"cheat\",\"enjoy\",\"go\",\"bed\",\"forget\",\"bird\",\"wrong\",\"before\",\"who\",\"eat\",\"pink\",\"thin\",\"tell\",\"change\",\"drink\",\"drink\",\"thin\",\"all\",\"full\",\"thursday\",\"play\",\"shirt\",\"language\",\"cool\",\"forget\",\"son\",\"family\",\"full\",\"dark\",\"dance\",\"purple\",\"study\",\"hot\",\"son\",\"like\",\"cousin\",\"go\",\"no\",\"jacket\",\"short\",\"fine\",\"hot\",\"apple\",\"book\",\"no\",\"fine\",\"hot\",\"book\",\"school\",\"drink\",\"graduate\",\"cousin\",\"fish\",\"thanksgiving\",\"same\",\"who\",\"thin\",\"accident\",\"secretary\",\"want\",\"hot\",\"yes\",\"tall\",\"hearing\",\"need\",\"eat\",\"son\",\"go\",\"dog\",\"kiss\",\"cook\",\"man\",\"paint\",\"bed\",\"pizza\",\"africa\",\"decide\",\"birthday\",\"meet\",\"white\",\"pink\",\"forget\",\"thanksgiving\",\"go\",\"cousin\",\"black\",\"wrong\",\"deaf\",\"brown\",\"accident\",\"how\",\"but\",\"computer\",\"africa\",\"letter\",\"who\",\"medicine\",\"walk\",\"chair\",\"shirt\",\"study\",\"hearing\",\"cow\",\"graduate\",\"year\",\"cheat\",\"forget\",\"yes\",\"city\",\"secretary\",\"orange\",\"before\",\"white\",\"give\",\"wrong\",\"right\",\"eat\",\"jacket\",\"drink\
",\"finish\",\"time\",\"want\",\"many\",\"tall\",\"mother\",\"later\",\"black\",\"candy\",\"cool\",\"pizza\",\"black\",\"clothes\",\"table\",\"dark\",\"tell\",\"shirt\",\"bed\",\"bird\",\"cousin\",\"what\",\"africa\",\"need\",\"work\",\"time\",\"brown\",\"basketball\",\"fish\",\"dog\",\"dance\",\"cook\",\"drink\",\"year\",\"shirt\",\"thin\",\"finish\",\"can\",\"deaf\",\"who\",\"same\",\"time\",\"all\",\"school\",\"kiss\",\"pull\",\"many\",\"eat\",\"short\",\"give\",\"no\",\"walk\",\"dance\",\"africa\",\"woman\",\"last\",\"dark\",\"city\",\"need\",\"fish\",\"bed\",\"kiss\",\"thursday\",\"cow\",\"blue\",\"family\",\"deaf\",\"play\",\"right\",\"deaf\",\"hearing\",\"now\",\"man\",\"blue\",\"but\",\"finish\",\"accident\",\"hearing\",\"pizza\",\"hot\",\"study\",\"now\",\"letter\",\"secretary\",\"yes\",\"many\",\"thursday\",\"want\",\"many\",\"thin\",\"paint\",\"study\",\"no\",\"apple\",\"dog\",\"bowling\",\"all\",\"dog\",\"doctor\",\"color\",\"orange\",\"tell\",\"help\",\"cow\",\"all\",\"last\",\"doctor\",\"give\",\"year\",\"brown\",\"thursday\",\"tell\",\"hat\",\"help\",\"later\",\"drink\",\"cheat\",\"pull\",\"tall\",\"basketball\",\"like\",\"enjoy\",\"short\",\"candy\",\"can\",\"change\",\"go\",\"purple\",\"dark\",\"work\",\"table\",\"pink\",\"what\",\"pink\",\"like\",\"son\",\"help\",\"walk\",\"paper\",\"decide\",\"yes\",\"before\",\"doctor\",\"how\",\"family\",\"orange\",\"computer\",\"white\",\"white\",\"computer\",\"help\",\"no\",\"cool\",\"corn\",\"want\",\"computer\",\"mother\",\"thin\",\"apple\",\"color\",\"decide\",\"work\",\"like\",\"paint\",\"right\",\"thanksgiving\",\"year\",\"blue\",\"later\",\"paper\",\"pull\",\"change\",\"thanksgiving\",\"many\",\"clothes\",\"chair\",\"before\",\"now\",\"man\",\"book\",\"language\",\"later\",\"hat\",\"city\",\"cousin\",\"cool\",\"finish\",\"fine\",\"play\",\"city\",\"go\",\"jacket\",\"how\",\"school\",\"chair\",\"corn\",\"bowling\",\"black\",\"paper\",\"tall\",\"bird\",\"corn\",\"bowling\",\"what\",\"cow\",\"work\",\"table\",\"basketball\",\"letter\",\"before\",\"color\",\"meet\",\"candy\",\"medicine\",\"bird\",\"who\",\"orange\",\"mother\",\"hat\",\"pull\",\"hat\",\"wrong\",\"pizza\",\"mother\",\"orange\",\"birthday\",\"can\",\"language\",\"pink\",\"meet\",\"computer\",\"computer\",\"change\",\"full\",\"secretary\",\"family\",\"kiss\",\"meet\",\"forget\",\"bed\",\"fish\",\"bowling\",\"last\",\"enjoy\",\"man\",\"now\",\"accident\",\"graduate\",\"like\",\"yes\",\"cheat\",\"short\",\"deaf\",\"cook\",\"enjoy\",\"woman\",\"bird\",\"full\",\"apple\",\"mother\",\"candy\",\"woman\",\"but\",\"drink\",\"need\",\"medicine\",\"same\",\"give\",\"fine\",\"go\",\"wrong\",\"woman\",\"change\",\"purple\",\"birthday\",\"color\",\"pizza\",\"bowling\",\"what\",\"before\",\"can\"],\"split\":[\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",
\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"tra
in\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",
\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"train\",\"tra
in\",\"train\",\"train\",\"train\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\",\"val\"],\"video_id\":[6343,639,13697,35460,6476,40115,631,31757,1999,1992,17023,18323,70247,8949,70355,65009,8435,17722,69359,64089,14888,22121,30843,19259,6478,68035,63209,5632,63792,51236,57278,63227,22114,14685,57943,14669,5740,45438,42960,42835,12332,62968,65761,70348,68018,56552,10904,8429,8936,51069,6455,13698,10894,13633,629,69433,68133,22115,24970,5729,31753,70345,35453,17712,41037,23771,6839,68127,6334,65241,31758,2999,633,1998,63203,64091,28
123,35462,63231,34829,45432,58499,63673,18332,68054,33266,68089,58502,1386,8431,70266,57273,21887,38536,10902,55369,32257,58498,35517,13695,50042,3000,30835,70051,5724,6474,35457,10899,42959,7400,48107,12320,7394,45271,28125,66351,24946,65434,41026,28205,63668,65085,14885,31755,17095,38544,31749,17727,65294,33269,35467,64222,63225,13636,14671,38524,34737,42841,66591,62251,70107,38982,21886,65507,41034,22954,30849,56844,63202,64221,58508,69534,30831,45436,69439,66531,18325,32319,53268,11310,8437,17026,69411,14674,55370,6331,24948,1991,68033,38482,49595,6481,8945,11773,62987,10149,66637,63242,28204,26712,45443,27209,62164,51071,7075,32954,66798,19267,38530,13196,32263,6472,7395,63239,65341,65449,56846,42829,22124,6355,40837,26974,12335,63789,63208,68085,39000,32322,1988,45434,7402,39003,31750,30834,45265,26972,63232,1384,65408,17713,6477,42830,26719,64212,17016,14900,10158,23766,5638,25318,24643,36932,68046,57942,56564,25324,58370,32946,58497,6332,50036,57950,43180,66296,5730,58503,6341,13160,12316,34732,23774,70152,64210,65162,8955,20983,64209,19264,5731,9849,1385,62169,64281,56849,20987,58367,24651,68162,43173,48108,43170,24954,51225,66014,34826,69533,68099,18331,8927,22109,63214,6843,26721,6335,34839,41008,68050,26717,66575,64094,53279,7389,13699,38994,8424,51233,6484,63672,24955,14898,57282,65342,32949,36946,68048,21883,26715,10901,7961,10160,36930,49577,20982,69511,27216,51223,62163,27215,58368,56850,50040,10157,5744,65328,64085,5637,5728,40847,64086,53273,62257,13641,66099,51056,26983,70246,70376,20981,13681,70132,8926,66112,13154,38534,53269,7391,9953,13158,12336,56843,1987,625,24857,17014,31756,24973,53274,69422,13200,40117,69455,20986,20989,17097,632,32953,21869,37891,14883,25333,23782,32338,2000,26975,57284,5746,57641,51224,63212,49186,14896,21943,25323,40841,68007,62168,28110,21954,7963,65161,65084,41025,69395,56845,57632,64097,22130,49599,39006,5734,68132,69546,33279,635,5642,34734,35523,18333,57277,64295,63211,43166,42974,32163,32337,24941,10893,13198,70119,8916,40114,8433,51067,69368,49597,32320,6833,64088,5233,63679,5639,7401,24648,25326,9869,11772,66816,17085,1392,14676,34743,36934,49596,48120,70361,35520,65029,17734,8938,13640,23775,13710,58361,43174,64090,24960,37883,37886,13197,6363,49178,69364,32167,49602,9855,6360,45261,8952,64294,69257,11777,3002,58362,33274,5739,14895,17091,63806,3001,40835,1388,8950,28201,62249,6368,20978,24971,70323,64280,43169,14884,12329,55365,48126,69225,49598,65506,68016,45252,32158,40844,57919,28115,38532,37889,50048,7076,40126,11768,32260,9957,26688,13646,39001,70212,5234,63803,33273,62984,21872,28107,11752,38541,6822,14675,38527,13333,35515,57934,42966,64224,13632,6366,63793,69307,65792,10900,68003,35521,27213,634,22126,62152,618,8917,7099,28108,13167,13696,68070,56839,66639,56563,21885,40836,11767,27217,62160,15037,13635,28112,34825,26980,21942,68171,69343,40122,32959,68042,17020,28109,24638,3003,69430,69238,8942,41029,64293,22952,9851,12330,10888,32255,69282,64287,26726,3008,42840,26739,5241,9945,32945,57647,65409,69402,31765,33283,22120,63229,10892,24645,32325,31767,1995,13216,65200,57939,57638,50039,6842,43175,51064,26981,69290,69440,6834,63205,1387,21890,31759,24642,49184,35506,5641,17723,65891,7973,63799,62170,11316,8948,12328,9854,34830,70310,32333,45270,33285,28116,19265,64082,51227,66098,70359,11770,30832,14623,15035,22965,11309,49176,17721,13648,70016,5628,32246,34824,62975,58488,50037,32248,30842,62988,6365,8921,9968,13156,5227,69281,57631,42972,5630,50043,40123,57283,69241,8434,34823,33268,7397,63236,14855,13643,68024,14633,57940,68
183,70379,8929,51231,64275,7074,43179,9967,23772,63790,11311,23776,55366,30840,51232,50050,55364,66297,32335,11305,13323,68190,63669,6835,49181,5231,28214,32154,27219,63210,25339,70211,39002,34822,21951,24965,65440,68032,9856,45262,37882,28202,62964,8920,26724,65439,11774,7967,65043,21949,70207,68177,17017,28203,22962,34742,57628,32164,15034,35458,24636,62172,34738,11780,68028,55361,26976,33271,56842,6337,35512,70230,23767,32249,62171,25329,32254,15039,27207,27214,31746,35514,6471,26713,69531,8951,65187,68086,33282,17710,56841,69370,37892,62970,13704,66008,35513,42962,24950,42969,13630,69345,21870,21944,41035,48106,63666,48124,26986,37888,7957,65242,57291,32156,51072,69298,24943,69213,12313,37881,62254,17083,57629,64284,22953,58359,14628,28122,13325,13707,65403,69389,64211,10161,20976,17729,32956,45435,23768,50052,66606,17015,70325,1986,22116,57936,21933,6367,64292,45440,51054,5634,43176,1382,50041,55375,17076,70237,20992,36929,9950,57288,8947,68178,64283,40119,17709,69206,8915,62253,5733,31754,65145,68093,49180,55368,14625,13214,62259,24952,32324,9966,7393,69269,9960,63665,6845,66007,12317,68053,42953,70030,18316,13634,68011,13706,6840,9955,57941,40124,65263,13334,38990,34746,56560,10146,11775,42967,63802,23777,40121,11315,65889,36939,14899,32951,13201,28121,19258,50044,1398,49179,5230,51066,32950,13701,8909,40118,37887,68192,70299,7962,65300,66592,35516,45267,34834,12319,17087,5232,56837,38540,63788,37884,36942,12315,13638,15031,63667,13209,68084,12318,17092,13165,27218,36944,70049,49173,57953,3005,32160,626,34831,63201,13202,63207,40116,1391,51063,36933,11314,63795,13639,32253,40843,27206,35455,49185,14886,28210,50046,13337,6841,70176,62241,51060,15043,62247,17730,63237,26985,13329,45268,48117,5732,24655,26973,24660,62245,13213,26714,19261,7971,64219,42958,28211,70326,49183,56567,40840,1912,69524,34828,51070,51220,6339,55373,51068,24639,34744,68137,21878,64296,68145,9850,58366,38991,33286,34837,7399,58363,48115,57642,70335,70029,53270,41032,30841,58365,10898,49182,63662,55372,8426,13705,7392,70270,14893,36938,57639,7069,51061,62966,69413,40816,6371,13199,57643,21952,69283,19269,6480,34833,22117,57276,12312,42831,22960,6486,5631,66015,17018,34836,63804,25330,17733,14683,5243,13161,25321,12311,56566,68029,38539,15033,69252,42836,53271,57935,64092,45264,70245,63664,65450,7968,57634,15040,56852,14672,42977,13328,62250,14887,64087,36940,62982,69236,13164,63677,14621,35463,24657,5727,25325,20980,14624,13162,57635,68142,6330,39004,66818,17728,36941,53277,10159,42963,41033,21950,31764,51221,63226,38533,8918,13637,40845,17711,12326,70242,68019,35454,33277,14627,24951,21891,62159,22125,56557,10166,19266,24962,65163,22963,69233,64095,5741,63240,18324,42833,57948,57287,9956,69302,17731,57949,1996,23778,66638,43167,51059,32155,70271,22964,53275,20988,23779,14673,14630,45433,55363,28120,53258,33278,65415,24940,38531,68079,51226,21884,28119,70309,7070,38529,70234,28074,68012,49600,65540,25322,13647,22113,57633,49188,63238,57937,623,66441,62248,28111,64297,56838,26971,37890,18335,66532,24947,17093,31762,13157,66097,40842,5636,69431,1393,15038,6369,35511,63191,42843,22961,57645,24961,13642,6482,64093,65445,7966,627,70295,8425,12338,1395,32947,63228,35452,62158,9848,51081,55362,65884,13708,65843,64213,10151,22955,64288,10896,50038,66246,5749,68182,24649,64084,48105,68044,30830,17720,21955,66644,62256,69396,56848,36936,32326,6483,65298,13203,42971,70244,11313,56556,14680,57285,51057,5629,6326,13631,62944,1383,37885,63791,70356,7960,5238,69325,17090,14622,13168,70173,64218,51058,57933,21941,8937,14894,63219
,49174,58504,68001,49606,31763,45263,34832,18329,51235,24652,68110,62173,14631,1394,63675,66010,14682,10895,37879,65731,5644,31752,70026,13703,65216,69316,14903,43168,48109,14882,26984,38999,34733,6832,8421,21953,628,26982,42964,28118,55371,68114,32955,50045,64291,34827,66640,62246,70308,66607,40834,55356,38525,2997,17084,7383,1997,17086,17019,68027,68122,57286,65890,13702,2003,32250,17013,24641,64201,7969,58360,57289,26723,27208,32321,17724,10148,70378,56840,5229,33267,19255,51206,8925,8935,9970,24956,66355,14681,66804,56579,42827,62967,42832,33281,53276,27221,62175,41030,15041,64300,5747,17007,28212,20979,70249,12333,63204,63200,12327,27194,38538,13217,13326,62244,12314,36927,57947,3006,11778,15032,63769,33270,68125,48114,57630,70306,68010,32334,41028,45269,9963,57640,34835,11330,70263,5750,38995,34736,7068,32146,32323,26722,65362,70332,13208,21945,21874,43171,65363,24969,30833,28187,49603,9847,13309,7390,6473,41027,56835,6340,13327,7388,62965,13700,63801,56558,5239,32948,5743,11769,35509,8924,35456,6338,63230,40130,36931,26741,45273,68068,70296,42956,66147,40129,6359,8944,32157,42838,35518,12331,12306,9949,23773,50049,65677,31751,35519,22967,5633,22127,7396,32261,19260,34685,38997,624,25332,33280,64298,10147,66469,70015,13155,19257,68187,6333,23769,65086,36937,8919,63676,8432,17725,37894,35461,49175,24640,21871,65824,68189,66799,9954,45439,70357,69274,42961,7398,62979,5742,8946],\"x\":{\"__ndarray__\":\"TfY7QVYQC8FdZHpBebgoQaGjIsFc+ZDA3zPVQdhVgcGzihK/hLGwPj3knUGz7QPCe74Lwon6CUIErTfBODP5wLKrj8CHSAXBY+xdQYcGa8Fp+4ZAznt7P+atEUI9vqvA5/IawrUNdMEdPfTBNcyvwfgrFELFpy7B5PWrwTQT08FyURDAvpbnQEEbCz9GbaBAgeWuwLpoDcIXZB7B54PAwd5FicDJxvRBBC++wd6GosHbhQfCItfMQTsZ4sBdRofA+3ECQtW/G8D1MLzA9xYhQlwZW8DChYjBVsMMwcS07UEEqvZBYdq+wDMfmUGD0dNBAQSIwc9mCEJpcCRBXMdVwZSQoj/87FtBIHyiwagUHUBW9SNB+QeYP70HicFzr/bAlUBdwTlK8T8FP5PBfH5mwdWj3MFZdCBBK4TAwW3Q3UH7dsM/dT0HQiMS/8E6/wPCTL2ewYTnBsIwHwhCPsEHQtUmw8B7T2lBRtCdQcldrMED+sPBat2LQctaxcB6f71BxQCbwVNQCUJ0mGtB6zVxQSk6k0GzJvbAAiQQQt6tBMHWoIHB1OYVwpXtIEHVZ7vAZQyJwSrBFUDgt77B1OggwDNFrjxLGBNCX3HmwcW9FEKvK4pBKufXQdd4Vz9The5B98bHwVncG0KBAudApFSOwNHxe8EQvQQ/2/KMwbpHU8EXuQNCJXr/wSz0KEFnoCBCscnOwY8RjMECIgJBdAT0PUGLicHe5ejB2RU7QSorB0KcrClBMVXxQSVvxsH5Jw3BCJqWPyJ8xcGA5RhC5Y82QQjN98GhKSFCS+UKQiNb0sHvBBFChtIHwg50FEJdNgDB/vkBwsr9esC35gXBsOK8QUDXR0FGEJpBrS/av/ULvUAAxcpBtYcsQUAVlEEfQp1BouyAQM7Btb/2t8lAEhEZwjprB0LtZpzB/OL/Qe8T+0F4WOHBS+3MwYv+80FpBBrCXb4KwuVBakF7AKpBTwjsv1P+o0F1rAxC3tzKwYoCysCBWVW/egbwwHVehsF+GxjC0tQDQJRk1cEbrANCD1c3QZyRX8FyBvDBHEVEwLfbo8EgD01BP9dmwQrcjcCkThxCs/CcwRXOxEHyof9B5DzcQNr0nL9HwwnCuw9YP5KaCkIqSfrBEfUKQtX0EUKsMXjBIgfKwZ+GN8HEw1LBm5dPwSl7NcFx/erBxQ0BQWVdI0KwSZdBh0WQQDceAkKJE3hB8kTRwfEQB0FGBYTA4JjLwauws8C/ieS9F2zCQbEOTcE3yebBzAUJQqK2CEJGjTFBIJ1vQRuVQr7eXPxBuT/owfl/1MAqQwpCgNc7Qbn7xD+upV7AINqDwR8TZ8GS4M1BnSYhQg+v4sEuUP9Bwru+wE4MI0IQmbfALBeFwb3Jp0GmRtLAlGOuQdAOqcBHbSJA5G+UwPuE1sFkZXjANN3JQSTR+EFY2a/BjOjzQUU5lkHQRjDBNFzgQGXfzUFebfbBbfg0Qe+HBsIt3M/BjOfZPrFZ88GR6+rBTmMIwtzoK0HpvdNB//c/QGJcxD/AOM7Ao6TNQTZqTcEidRHBLC2PP2UFikFJhwZCef44QZv6FME+MxnCjnCEwQmzlUE6uXxAxUy8wTJi/0Fe9QtCpsjHwe7+0MHH+IlBYeYWwo0olMBrZ6jB+VYAQl10xcGiX7pAEYajwM2KCkIRbnVBGLQfwWnkr0F0vXpB/9uMwXeZIkCR76ZBJHWwQcFZ7cC256lBUVhYwbyD2sFhRW7Bwj9UQYOYgsG5LwzBwWgNQimzecED1YbBX/7wv4rFgMFi0/bBsVHEwYpvkcDy7oJB3EJdwYGVBsL8ai1BY0ymPbsonr96XArBpF0cQVz7EELPDbE/WchpwPPKZkDGQ1e/l3hFwTA8oUHkbJxBt8FjwfO6nkEPwvbA6xO0wJT/kMEnb6fAbJ7qQA7wicA4MLPA+QuBwZiWD8Hg9g5Ca0rMwWc74b4JGQq/YNlvQUTGfEE9ktRAkOjWv6AzgcERM6nBqHV3wWMGs8GVpijBSY1gwSggEELSTYRA21jUwXdpgUE230hBGoihwXfMqkFEftbBaTZ/wdxqq8FO3tzBSHsBwSfhfUCH8YBBdPlfQc83H8Hy5H7BS2wywKlU20Bak
gtCLrejwUsk4cEGpqRACOv8wQTNBsFYf9bB8niBwbhGKkFPLgLCiAaswcec9kA7A/zBdNjzQazvOMG29MFBeLffQEUImkEqkLvAbZmqwYjMSUED8A7COZPSwA+Lf8DXxGnB5SzfwcJgtkBnp+dAqa7owW1SesEb1MlAhAPGwYUr28E4hw1AD0mVQcPCaEFJ1aRBN4KOwZx1JkKRyWvBBeM6wdBTDEHDwI3BBFZNv1LLsUAoEkLAT+WLP6FsM0H4ip3BuJdEwTt7BELXg3XB8OtcwfqogUHRYeHBZ9HyQQ+QZ8GKppLB9u8lvgTfb79XfZLBa+sTwlv+uUDbOGVBhfTOQbcGo0DB+KVB2Da3wV7nEkJStwdCy/H9QAaABkLXo/XB8I6twNG4uMECtJ3BBFjCP2AGjUCZuy5Am4YbQvnzB8Gv1SrAQS+lwW38BUIIrO5ByoUFQqf7FMK7+5zABGmmQYelEcKQPtpAFnr8QcwXRkCttP+/YVXLQRtV4L90pdFAMHfHQCqHdsEKCWRB4HnxQSQu1kFvQFxB6IMkQovdecFcdqI/4B5RvpNznEHC4qJB+ohewBCX8cFA8IPBiigQQvsn0MGF4ozBAJcHQlRlFMJ8B8FAg+8VQoCcwUBeJf1BHDDuwesq4sHF8ZXBbwqqP/ZElcEefutAAVkGPxRWSME7eX5BsgSuPqyvF8FfaSRCZsmHwaOzvMFp9RdCytIAwiYnbUHf+o7ABfL7wLY+NUF243pBcSzZQSzWpMDF2qhBXSnRQRQ06MHqGZ1BLojdwadfQkCZLHdBViaCwTM4SEBr3ufBvofOQbHexcHMzhrAWSDywSgtQEElR6RBWQgrQWxBisEa9ePBgHXFQUWln8EIVkg/F0kHQjbZosAZ2J3AjMWLQfcJWsGvd6NB1KriweTBmcGBZgjB+9TpwSmn48FPlAhC/0cUQJjlx0BmQ77BcjCkQbpO7r82JXbADyaNwYPMb8FoqTzBzucbwoNs5sDRZvDBTacbwkaxw0BnwQVCgAUPQkfoRMHVPSbBmgW/wQWZg8ELnf/BIlqnwEA4vcEQ+6vAAU+fwNbH6UAJSIjBtZnpv/nipsGophnCpOcjQjBng8EavZtBhWfiwafq90G8ZGfBy/GCwZzIBEGBaArCEz3bwYwr/cFVmK7AKcfFwR+/j8EVQIjAtPqwQG1DhUFaftzBzRCHwclobkH1X6zBQqQZQp0drkHtrMRBgWcFQsvQHcCmlqlB4ZXDQf1cgMCWGOxAvPgWQrIMAMLL7nnBgfXhwNQHXcHkZBzBS5GFwYqgvsE2eOnB6PUTQqW900H7fDhBR6W2wT0mw0Exas5AoZgHwfdzhMG/7xLCnDPUwfbcfsFIftdBCkcJQo5gCkKbEqNBd1SAQWgNFkLJ98fBXxehwZmGxMFG6g5ChpJFQEVy5kD3bh5Cash+wV1KL8Ehc+TBPzmYQTJAg8AQeZ7BXR2iQW2wSkFzZNdBUPcAwq/V4T+M0cTBwBuwwbrnhsHpYcFB9AzSQfc3mb9P38/BoQkOQpnuCcIJQ5hBjvJSwKCC00E9bfxBMlQSQmXcasHX3xpCzxS8QTbpmMHXVs1BxWgRQql1ocBVyadBXLHIQbBNHsHbTNdAXejGQQdtF8J6diBCqZXLwTxk8MFRT71A74C9QC269EGiyMNBojZaQUnr+MFgXWRBAYoIP9NhCEI2IMFB00+JwR+ioEFQPQVBro7ZQUs/qUFNCxNCf20lvp6r80FOt/5BiJDdwbbhGMI2Nv1AARwHwrCTtsHC8OY9G4GGwcRQsUB8u6JB/hWnQbZe8UGUgMbB8xpvwQIzh8FfQ8FBaeo5QTzmN0HFfZXASlGwQZrDh8FSiPPB8OtvwNwmwkFVKXfBTcaYwQ7oYUHwvSNBzJQ5Qbd6oEEdwnVB1DOFwZzerEFCFnZBBjZ4webAQ0FiYWhBnE5kQYXChMGG3IlBr4sUwrK1FcKEPgBCD5MEQmCQP0EEHYHBCG8Cwsm6Z8FcbVbBkL/wQQN4SD9Ad/tBCal4QbPZx0GEWitBWxImwTe3eMEmA3bBu72NwXTQlEEP883BOKZGwHc9BUCG4Z/BKXbLwcQAp8FW9nHBi3Kev/nIRr8a759BHKnQwZMTwUFMau+/jsl/wQirZ0Fg2/DAZUEEwIDRPL9e8QlCw9lwwUeWN8HNqPpAbCa2wVU+oMF3T9RBNNPmwS6zKMHiKJFBhUO9wQFi/cGFGCNC3o8BQvslgsDSL9fBe5UPQrrlDcIU94dBA+6fQeG/+L56rp5BjL0QQnnvFb+slvG/IfIPwbL+nT6FHLXB8c7IQB5TC8KR6FrB+zLbwarV90GEdDBAR2ShQUAqxkFgqG7BP2CYv8CZ08B8+D6/wNkPQoAPtMEMyApCRg8OQi7RFEGTXYfAcHZOwfEOcL8N1QvCfW0KQoZpD8EB2YPBF2DbQOLjA8L3qfBAZxy7QYUV0UFA9rjB/i8PQmgZicHdEexAWtcJQtffrD/sdrnAEyoNQt8uy8EHhOrBeynKQTSRh8C+6cDBL/N1wA+0SUCya27BmPKMwVUgmsBvvYJBpeemwZpXEELQTue/CSxwwNbetsE2SGnB1pkIQlv4jMFW+M1Bn+P/QTn8BsLv2ra/ZJ0YQpxXYUEM7rjAkxfGQUK4aUEvq8fBcLNXQGcoDUIRQIPBpM7RwWHgu8DLE5NBcog1wSwRhkD3ZbZA2YZJwRIjFkKrAGFBz0egwQgvlMD/UIW/X2HpQOf/BcJHjrHBp83ZwQeOK0GzVopBKD0QQlvi1EGS/LPAr36GwQWjskBWOkNB0aWoP9LgGEIa88a/lunJwW+uGMBCOIrBKhwqQbLVpUBwyqXBdo3fQbFXYsCBN2HB5PiAP59YWkGFKsTBTo/aQFUDwkBBvAc/CsIOwTiDwUHhatpBdE5Pwaga9cFLBZPBZ5t7wYmCf8BrYjrBxaZoQVuUxMEmM7hBdeMVQqqOjMG2GYbBXVBMQVU7c0HXWSdBE7HCQJ5jhEATqYJB4AzRQGjCF8L3GebBN2thQYkXCEIOPmdBdFg3QTnuDkIEGUzBIQrFwfyveMGSWFzBRAEWQh0NpcHT+6DBWqWHwDZcX8Etu3rAhj4IQgMbb8EEvMnBMP2zwEUQwcE8HB5CLowcwSRy70Gt2wrC7K36QNTGx0EXuU9BgUW9v9yWCkKf3t1BZYz0v8ZrMMF7JThBWnnIQZpmcsGAeda+IcWLwWATCMIAO4DBFB3wQPyrvkDtUqxBQuqxwXtdBkITjAbC+125QU037j9D0uDB6SUjQlREccGwaADB/GpZQdk/CcG/1vE/LrAWQs9nZ8H5f17AUDPtQP6mbMGyLsZBezo+QS/CfUHZBBhBVA7XwfbV579pj8fBGtN3wc2lzb5wFlnAo+39QApaCELy12BBJMC4wQqsp8GhDEXBDjmzP39ghEEn+s7AcA4awsgNz0HKFdG/l2OlwXVqP8CXee7B/H2pwb8XGsJodNXBpFeJv8oqmUHhgNRBVgIaQoSyf0ErTkTBk/4PQeon10DuJug/
YppmQc4WDMBp6c5BE38pQIv8uz5VjzpBDi65wbpS18EVHgXBp5NYP/QwXMH1O9hBP5ASwjmawsH+rkNB9mzDwZrhlsGNI0ZB5WM1Qff2qkBOdojBhCcewelLCkLKVx5AhpZpwTJUxcFWFQBC+y8Wwngj/r3yucjBgtvSQQySHUHuyLnAyW8HwbphiEFPPcHA6LjYQeRAOz9uw4/BZ2ylwYPnOkHLPgtCglcbQv86qj9F7MLBsa4BwULY+kEgJ6bBo5aSPwUNgL55zYXBp3IiwUi7xMFffjXB+HzgwS7bj8HUZF5BHWR2wXIfPMB9CLbBsemlQSL7JEFK8fnBMyPYQfrufkGXANfBecmmQcWFj8BMyMpBJXMAQiKCs8Cb2Z/B5FLkwVcCvcGsBjJB2NxcwV0ajMFeHtPBmNYAwn7MxcH4yAvBbYa8wdh4F0Kwjg7BxxVNwRy/C8F5G4K/jRt9QXyR4MHi4gBCjnEbQqg170GOXAZCIIS5wTB1/MARg7DAwD2GQRndGkHM4NNBIFaFwbXRPkGosQnC6WCZv+CjAcIKmIbBCbKlQTyLQ79HwxtCSae4QNxpy8FGAArC/AzgwYuIZ0Ep4ai7ZOO8wSiTfsE+Gj+/CiuWQEJKd0Htl6VBK02IwfDy20FGxonBQDeoQRNg0sHDOEK/hqEiwa3Tr0HMZwVCayc1wZD/AEFmPSRBISyEwbwbLcB2D+PBnMwUwc7aWUFe4obBgdyIwXapVL+luonBDxVZQeSKl8AnB6XBXG28wZ1dEMF0HQnC8S5HQTs8dcGvSfDBTOTBwS+hi8HlzJHBSC1swVWJy8GkhYbBoH6jQNqAt8FRdVHByldRwfRTWUFLKUPAbi64wWKdY0HXTMTBVqA3QbgD0EHxraVBCn4YwH8IxEG1l4nBBi1kQaH7ckHFAx9Cn9vhQVSLtMHrJqhA4LXnwHGbDUKT253A4FSfwa0nQsEcsnTAJ6BywZc8+b+buJjBki0NQsxUBMFCQ4DBVN8IQscoBkKjmAdCUFNfQbHGxcETbM9AyMnLwVne28EV/tDBMuHCwOTwUcGkSb5B4RvRQZKtAcB6Jl3BIFzuv9R00sGqKUVBdfJjwRDK4UH3c6fBFhkuwHg0GEI/GhHC8WLCwb9j4EFEqL1A6kKlvg6l20GTycI/pGDFQdeeG0KcljLAU67cP75Sgz+gdgFCirfEv5GuZMET5n1B5QILQk1JC0DnYLpA5k2OwWJuzD+EsutBRbibwaNxi0EFSbHA53jOvyqGrEGoR9RBvTe5wbJqz8GBdIfB8pHwQFoJG8IW+prAglZZvzb14MEk2IjB48z7weDLhEEX9ufBGxa8wHL25kAc3ptBnRS0wfORn0C+93TBuDUxQbNpjMFUSufBH1dtQd/fij8w2UfB2LNvwWsfPMEd6OzB6czCQZ+SBEKnAQxCnErjQWzk50B1hclBHFGFwTxGisHIDRLCKlNMPgkvjsDlRb5Bo4pCvr4SF8FKR2/BMYwVwHpSEEBbA7rBQvGhQafAkcE4uKHAtlhewbRuXUHw/mlBg2rPvpmLBkLa8aZBWcKqwPtMHkIc37LBhpbowWfp1MDhkhjCLwRhQeHNRUB7QS1BcrTXQV3++8GwkeLAFJgAQdtrAsLmlytASl6DQBIIs8HYPQRCr/ARQpbcg0GjWA3C4fkSQZvLFkKUlMxBtayYwZLQFMLpcOzBWiACwqEc8MBGYXBBDmKoQePTNUBs+BvCNuHzQE4HYkG3upRB3QwSQiDjvMB41fLBuzP5QXtv7cEWzvfBdvZvwNnZVUHpRCVBcxzGwX+5FcGLqfpBodcxwOOxkr9SPhxCIQkBwQmJ7cFNw4VBqHcXQheyokCCClxB1SZLwOnlF8E8hJxBOzxxweisyEDcw2rA/MgUQr4G+0HPihDBAWLUQSLhukF42zvA9i1yQdXpBkLJoLS/3YNTv54LbUEax7xAMIpIQbkHtMABUphBeEuRwcN/Qj/YFMjBHWvvQQKg08Cpg5pB4DMQQg2pgL6nb+NAVGM2QeAXdcF6MS1B8Y4awsim2r/OC2TBCx4lQbpjncAijxlAvk8DQtLQiEGXihpCqGPKQXuduEHpG5FBhrSfwa6k6cHCQDNB0VWhwdpydEAoaSlBil7PwWbr2cBZgMrB34YXwmTeFEJDOovBchrKwapDtj7zI7zBmxjXwIPRrMFitARCirfAQZU328E5lolBS94kwDVGRcDl2wdCASyDQV2ps0GCnODA0QmDwQ6lIkE0zLvBx3DfwWuej0EY+ZhAFb2pwRRSqsCxVGnBpOAEQqAk9sAmD2xBY1T/wVtcpkHBWodBZ3Y2wZ9kzEEEEQhAPSHQwO1EhsFPjjdB32RwQTKTHkL4cLzBfwAJwmmt98Hx2X/AqXKJwb0rD8A7hh1BWJXCQBwrzUFAP6jBJIycQQwOa8H8QMTBFIIfQmlnDsKQzeJAxFYFwskOQ0Aeo40/0U65QYhri8FqHAhC\",\"dtype\":\"float32\",\"order\":\"little\",\"shape\":[1533]},\"y\":{\"__ndarray__\":\"bt7Rwe+lR8F82rTBB0SFQSoWBcK6menBWoiLQZBW8EAQnpJAZqyLQKWxQUEEJsRAZBc6wZCKw0Gcc9ZB1JhFwVhRYEHP8cLBdhoUwCn0hMFhAEhBCO8awfUlWkEB4eLAfezAwbNQEcDmaR3BKcpYQecsrUFLkthBMSnswaDM/cGdcirBEOurvy+q8MHnWVW+VXwQQA1aVcH354fBQl7ZwXVy2UHdswjBFdeMwaLWUcHs3sHBSYeDwCRuxr5dtlRBb7W4QbePocHX6AHCeJSkQWRJvz9DvkBBPdNKwfLMgkHxg35BugkhwcRv4EFw1pFBiua/QL1wjMHRqY5BnZ5sQfn7n0ERH0/B6NoOQHC8qEHAusfBxLI2wdCu+UDxeQnC8J6PwW5mE0EEbqvBGyqKwRumMMFO/YVBOVHlwTAFesGOWP3AY1zVQcnqocCce8FAltqRwT1vuMDLc5ZBq0fOQYrAz8HY5OtBNKXKvyS/78HbhOo/nYMFQglplL5+XzNB7naQwWGt1UGKGc1BMd29we6ypkGUNQvCU91ZQfw59MBEq5fB9aSzwWN7gUHbjH6/oUp6wSebZcHE+7VBnXDRQW+OXcEj+8E/hP0fweNJtT8H+/FBV329Pm0prEF22UFBzpgwwLRt8j4vsCBBfY0QP0cNMMDy1gxC4Df0QGjFeEG65rZBtb60wM4Mi0F4zDxBv5HxwQbEREF74CDASQIMQssCr0EA18zBBoTbQb/AE8HMSNvBMDx8wYDu+T+UUflBYHmlQf/BjMGJe1BBsIHeQYf4PsEe3jpBHZTdQZ1W98GiAVRBB5hBwYZICz8rDapBtFnKQOyTMkGjlqFBjUyoQL1++EEMezlBMvMMQsMkYr+n3zRBVrPQwf/K7EF6GlpB/t1bQZwwBkLZ+ZpBZFrHwXwptkEEwrvBIfQMwTmNtz/hP0FA6/32wVrCSEHAXpzBrktAwStOUk
GVBlLB5ZCawZw6AsCG44pBgyOhwOMC7sCb/ARCaHkQwm4TN8F0EMDBrDBNwd0p9MFHemQ/oMqtQPs8ikBts8zBs9AfwaDtxcHKub3AdQkGwpCp1kFu+6BByMnGwZBLr0Fx/3zB12DdQbRQjkDzxj/BFIZbwfMYhcG37eZA2hFhQc89wj9MsgzCqH/uwRb9CsENdRPC3sJmQU0tCcI37MXB+dSXwXdPMUFPHkBBQehfQeAzyT9x/kvB4HZMQfzenEGBaUjB/4AZQT2h5cCtZPXB0rNfwGBZx8H4AiZAOSGKQQss20Hvyc3BImnWQU268MGHcYRBeEnRwVSkIcEUZ9BBIVjXwc/qAkHAlcxBN+GvQdoY5cHDACC/FD8uQR85dEG9CL9BYzSTQJlAOkGPKfHAdiOdwbxF9D9FR8rB3Y1cwauC1MGhc9hBokieQNDtnkB20m3BXOMsQcIgiUH4FLZB8R1+Qbbn5UGWZttBIpPUQYU3b8H2sR3BaGXVQZhrez8FGNPBDcpCwQVCK8HqZPRA995bQNyQ1cHlUG7BVg6mQYxaFsDUM8nBrDIPQT39i8EG6Z9BGMBbwRPnsMFmiXnBetz4Qde0y0EifLzBLITVwQb75UG2KU5B/yrdwV1Slj9PnYlBs8oHQRMazj9fI4xBawCiwbXYh74mNYlBo2mMP52TLUFg4qFBVbSYQDbH1kGnI15BUrfOQfTmUcFby1tBYuT7vzRW2EE1s8hB9pc4wUmKLcFtNA9A5cZ8wa1kgEFxGJnBYEHQwFs9lMH7MatB228awZ9NeUDX3rFBxd2VwRUnEMI8S8rBugCBweinjkA3iLLBKz55wVoDxcHXDM5BJinsQBwxB0Ky56RBIdBAQeJp+UBuz/dAAD7fQcyE1kF4qKNA/66YwTL83EFrMzxBTbrDQBK15kE1xp1BJYPrwci29cE3K/rB1ZyWQVrsnkAFmqpArfbIvxMsRcGv6pRBzTY1P1BPc0H29k9Bxh6PQTptTcFfDuJB7JTcQBG/DsLy4d3BvLyTwf7CzsGVRdNB3i69wd/kz0E1mFtBHKE8wdsKn0G2pbXA/L9WwdgNWsHciBfBu+26we4YakENK2ZBYC8Jwlo/n0GWSYnAWiDgQaLw/MGTsIfBkvApwdbLn0GEkoLBFf9kwSbpy8FJUqbBNkTIwFZiT8HO/3tBe8qrQbh6yEHdl8lAHhDxwSOZscGfKxbBLD6DQRrJgsHAaLhBP/fgQfgO3kGf6hi+fgKZP7zGj0AJcrrB0j7rwXDEXUEH9d3BKZwawVVfnEGyldxBE0bgQJAcjcFQHxZAw568wOeuaUFtIFTBlMUzwZJDn0GPyvA//smqwc+FLUGnIiDAd1gSwQ1I+b7Vda1B2OC/QZQnoUFUpA1CDERAwAlWy0H+T45Bc2VoQeNWw0GYj4FAW7PfwW52vsG0Uuw/DA6IQZpekMFEWInBb892QWmBb0GnUPTBmfuqwHGjJsHtN1lBt3GhQVBLnUFvuNY/PKi1wWTrej/TJL9BLnKvwUK1tkFGba7Be1AGwpT6tUE9C9XBcRd7QM93WkFb7G3BPkCqQWy/DcJ5sSU/zsuVQfKYwkG6R0BBGEdBwXwz/sCpL4xAFpvfQfaWC8GErbfBJPFsQarHVUH+K9NBVfE2QQkuEUJjgfY/pjSfQfV+9L/WIO5BwR0uQb2RlkGJ8dLAHdZWQVLtr8EcLAtCxqdvQc5YvkH+NfK/Ntrxwd5VqsHn7jrBMiXmQCH8mcE4y1pBJ+2HwZJAGME/q+w/H2qnQST7O8FkgxXBOn7OwBORJcF48L7BsbgMQp51FUDyzP6/ebkNQiO6B8Jhuc5BbOv0wRgtjMHgADtBa65UQSsxq8Gi96hBc8zJQJgrPsGLYUW+HQQMwuaBxkHV/F1B6ASKQV0EHcEuh0HB9JiNQTJavsGMJee/u+ccwZDQy0D9abjBjIwTwlFx10G4m0NAoz2awL95tkGHSPI/TKCrwdu4WUG5vkrBQg1hQFDrOEGdZwvB6P+Cwcs35MEHNwrAMNfXQfNKgcF2RfDB7iX1QQ1NWEFlBDlBQ4MuwQjOBcBiVwvCQsjIwQj2BEG5pcNB/2ykQTWNpMEFwH3BmRKNP26y2kHxBVG9CLadwdWPAMJSM/vA6MCdwS3jC8K5C8jBetWjwZqZOkDHQSpBo6yMQc2W/MF71Zc/zUUWQcXB70A3zdbA+P4bwUU08sFTlKM7EUeLwY5Y2EE0ig1B8giBQIt5JEBJhLrBxchHQa55sMGq8L5BGqn+QM8peUEtHtHBcgcTwpHbeL9fljbBXksKQSuAHMHvGddB2h+5QW303kA1glLBEIstwYf9z0FfM29B5qvlwaYKSEHxcHlBylegQbxBacEwO8lAoru+QSRczkGBMSdAGMWDwX0m68EIBN1BUI/IPrX458DVMNjBvtX5wFnAiMFG28ZB4ImsQR1aoMH8wqzBzMZWQV0h3r4aPZNAIkKSwUkOsUBFjjDBlznCwYeAPUEBsRPBCQFKQd7nM8F6Q3DBfxZuwW+fy0EcarxBXhJpwapfWEFggcFB2ejCweD028E+9A9BVw7OQCfMF0Cu4/FAUPOrwTG+isF7dXBBuF+pQUqo8MEer+XBdh7tvwXS9kEwC4rB2wq0wAZ4VsEQpf7BV0HXwVpTTUFETL1AgNgGPqL088EYePrBLn3ZQNfLusERq0PBeabxQSFWTkFqg4lBvGzVQMUP+MEGMalBg9mmQGaasEG7zBVBcdhkQe440EFOvcdBkgZNQWu6gcHN7dtBYbPSQPx9zsHr3yVBAN6xwK+LBUDVVDzBuqIbQBGjN0HJQKxBhZJWQfxbMsFmUqRB+xWbQYsHiMFwR4XBxjSnvrXW6UFlsKK/nK4cP0p0jD+A2Ek+LRNmQZdtS0HqoBbBOx3FwaQHocFw+7G/dxenP+zvdUEtFUtAt2M+PmbvlEHd6EjB2KYyQXN+PEGKZ4rBRPioQZLcs8HvprBBxW99QCOSe0H3/IfBqqRKwSfIqUGzAbDBNzXRQc3WL0FyKhHC5HbTwSuT20HxnszBeMLLQQZVzj/IAUfBY4M6wSTYSME7tIlBF54wwTRgoECDv19BrtJfQUu/okAoT8hBBJC1wfQBosHyFQnBCLOzQXek0MGeBTHB6xPLwEeKZEEzB4pA1cdYQSZibkHBsBDBi0G8wQJfr0E7CsZBE7GEwbfKXL9/IvLBgkVOQRBs20FJ5QRA7sZPQZuSoUHAiaZBIVaJwHCPuUHfWBDCWSNlQZDbokEhEFjBuCvgwSItpUEC/I3BYvwAwNuj9kHBUg7CHujNQcB/fEEByifB0m48wAUABMIuqrTBI9aPwXFstD8XA50/LAoTwbjKnT/oqbTBEqgZQJ/228C18CtBY7xnP622YUDcEgPCqbFtQeNGN8FSO0bBLte6QUre6cH0WzdBNT+KQYlSiUATQkDBlV30we16xr+RHLrB3/GkwcM4S8HA89XBtgR5QVpWhUE5xQvBMxizQeSHK0FXN5q/4sGEQaUPrECCSoVBk2LWQOGX68FgB
r5B2aobwSYRxMFBlvXBFmZhQcpXnEC1lb3BjBoqwZZpNcG+2J5AqiEzQDeXvcALvi7BLvg2QRKNHT/XcTpAFvYVwXgAlEBmHSFB/tQSQWjMRcGpHtC+ZCcMQYYDbcDVxd9AGm2sQV6KzkFBT4HBYuAwQavK3ECqn8TBOWo5Qf18HUBZfbvBmiZ2QCzG4kBkBuHBwKzwwWvaZUGSHQzCFDCDwUfIpkHcCK3AubzIPuPZrj9YJQxC0zKtQUJKRsGBq/nBzfnRQOkOQ0G7iP9AwwdWQVxBj0H59cXB0rogwRGLAsFo7ntBfQQAwVYhMcHr1bE/cuPMwf9MIUFAmbvBeQy1wZPP8sGyFmVBEga5wZGi3cApQXlB+Oe+wR9m40EeOshBvzFrP6Z7Y8Eo9sZBIbZywD72DUCDJeJBPGkNQsb3r0GcHXlBlSDzQHu/10EsjRBA82VeQFWkM8F8rds/QuuYQUDG2kHejC7AAYDaQMlOUUFOdhtBIupQQWK/HcFjie3BL2oFwmV1rEFKYopBcDarwftFLMG+wfPBgU3OwWGU9cHyI53BJe0WwUNIs0BvyJ9A3NKiQbu1FUCgTKHBiGLQwPgST0FZYYdBx+EowTFYUUHhr21BRousQZf4zMHcdQdBKe6kQQokOMEinw7BCte1QBglLMG/bG5BaLj4wWDHD8IzYRPC7hh4Pzi1t0HljFvBQnF0wc+zBMKckoLB1O0LwS/duMHWfnpBUN4FwcXHQ0G88SBBvfuDwZ8yYkFPkLjBhTMkwelqccDV5LfA+jCXQAbcEsEGRnrBJ76YwfL42kFFm9PBZxE8QUWY0sEX40vBea2nQT+JUcE5LsnBd0q7wXqgl0FhqRtAJz0QQKSEh8Hf0s/AwAxhwSuORMErZso/i+kNQYXkscE056VBhq5RwdFBn0FkBqVBwOlPQSJiycFOSEc+Z1Umwd5S2cF8ZThBlrT3QQ7cucHAjz9BNd4dwanoS0FpuxJB7iniwWKkb78LrqDBGFOXwWRUfsEmO7nAbb22wfpvUEA13/zB7NgBwAtUt8GINAjB767CwZBHcMH4bS7BxFYBwsNj2EHC8dvBkxOXwauZtMEDV0ZBWy/SQeJcJ0HICGTBfmmnQfRTkUHdsWJBkS7qvuAP4z/uRdhAeqiiQTkNyUEFbW/AL9PRQLzeA0K5LJpA8rxrQUZ3xcEdsZ1BW+XswXs5f8GOEuG/DPQOwZHUwsBzV4xAGnRBQWJW48Ft2J9A5AjfQX9og79jc3/BaGkNwo4bHcHOnVBBIQSCwV/t3EBCHBPBGvm4weeiykBO5LHAViacP6GHjUEokXDBfhQwwU4lkUGoOKBAH5Q8vqqCzUDhYcvB6467QSblyMH5S4bBxmEwQXrAacG1FCRB8MehQRGRGj5DNs3BTESlQSq6x7/3AQFBo7XPQacN78HeMoVBynjBwdu7FUCPDsrAwBtUQSUl1UHxUXFBBfxTPwYXi0H8bb3A4GnCP+LM5kElsRlA8yZSwbzEIMFbOz7ApWuOPyoN0MDMpAs/AUdfQXDXiMFmfMzBtUCGwdRHosENiPTBuvfYQCtC2cFd3fnBwNvfwRaLNEFvccDBr5xqQYT2+cEc3n5A++I5wSltUECy3nxBWfVjQZXNjEHCFOFBUf+MwTuGmkG8d7tAlPM3wZemwz7OsbQ/QdwTwlt3rMBIxKg/3lWjQcYI7sD1j1tBm73lQXIwEEIcsU9Bf/0dwfemmz+59LE/Ie/wP5o4dkFHcgNCyxCrweBf3MFprQK/nNakQeFcwMG07A1BsQgqQaLLlUEvYb/BNl1kQFnh7MHjUfDBxJNHwbaBskFKEkbBYdkHwdRQrcGAmeJBIQUQwl5bb0FtorvBZmiiQZgk+EGOLgbAA8vhQM13xEBhMLdBNZGrwPQvtcBc7UnB+LYpQX+69MCBWr/AQ+PTQW1/zMFaRdbBLiKIwXOS0MHwF4rBqCJ8QJBancECFI/BmdtQQbm0s0Cq4ZfB5kbMQTkf+0HexNlBjwwmQcEq5kH9A+3BcqyLQbrEKUGJcwE+RFidwWDqJkFEUw/CYe6/waiPl0H4iUZBc+UwPzOqksF7mafBgpKCvlVGz0By1uvBekJiwRFoysG+6mXBKEmFwbMGE0Lu279A8KlqQYy2mMCocrvB9+zdQTK2DcHcrPjABn7VQa8pkUBxUNxBn7udwdB29sHvFipA5HeSwVBmzEExq89Ac+VfwPo9JkAox5S/71DbwT76QEF07rnBP31WQf7cwT+OPpFBMfxYQQHgqEGqvwjBQv8xQfpUmkGKS4hBabJMwZCVGL82iOpA2jQ9wJyaP0FteZ7BuOHrwfJR1L/4HsJB82xHQSxxAcKEKvpBZ7nZQdh6H0Fu8o9B3HvvQPR55MEFvnrBaXytwQC0BUL8enrB5o4KQhZlS8GnUwc/8ekkQRjoGEDOEDXBcl2iv3NXMcEsTFJB8JspwXBQcUHYOgVBfv3PwTcIrsFB59RAupO7QMrvN0HqtzXB51UrQasBT0FBWA7CnChJQYCD2cHKpJ5APlLsQczL0b9kM5nBfTMUwtOVdcEWcyzBZ908QYRQgMEGcoVBLLpNQdl7ksF+onrB1s2nQTpcA8IjhCHBzBLtwWCZnj9CtT5BZt4IQi6EA8KZwFTAdhQ6wTR4IUFF8NDBqjE5QUogtsGhivfBMsuhv2fsQUGZdLzB7ItJQFxtKUETNkRB9MtmwW/TPUFlo4BBW72vP+4Qf8EvkqTBQgVKQeLSSEEqmsPBy0rlv5hjtsFg7cVB6pMfQEzP4sC6tTfBw5cdwaF23sFehrdBBMIPQRdT9kHaYzvBtrLIOx0wskHDblfAghnJwVJk58BaAs3Bc+nSwKZhpEGXEjtBXL1mwR0To0Hi56rBoimzwRuC5cBLwnNBa1kVQJPstECqFN/BAVDIQcwizcA9FBzBl6vEQfoFY0HW0Uc/J9evP8UXDsJcihrBdf/gQcKmwUGt3jdBeL4OwujnsMEIjfXAaPWeQbMrGcFmdKXAa/QMQiIC/MHZUrlBJvq3wRbA1EGebTe/MRlEP67+PUGEOgTCQL9jwaZpp0CtYDLB5xkCwTK+g8HIGAJCXhrbPf4A6kFvpUVBsXLHwfjEkz5m5t8/tYmOwYlk4r+lScVBI9R3QZiQCz8Ttt9BefRhQUjLBkEN2p1B+IcpP9ZH58G6/T1ByL+wwbe4X0BoxI1A6KDOwSt3B8KL8E3BsVMSwZxEssF/1bNBGtxdwByeskFsQzjB+NCRwbeVrsGen8VBB3G0wTldA0DiT9TBFqvywQEkB8JV789APOqjwcKSAr+PxYZB45hZQRVcG8Fe8FZBYNYFwkSzvsHDNcxBSHqoQRhv7MEeQtJBOCTeQUuH4EE5exFBRT42wWC7skFZ3b5ANJ0QQb7j10HFJ4HBhMZ2QQj2MMEms2LB1ZlpwRKassDJpOPBNpFVwW9wO8Hjdo9BZ1bBwI4JU0Fzq5lBZVbYQd9Uvr+YuN9A3QTpwBog18F3EsfB+mFLwaPQ7kAB18NAxkPBwT684MCwtl5B
f45WwHrwd0FATolBJocuwb0LjMH2CWxBXQzmQbq0isEVPrjAxIUUQRcvQMGpsEZBTkrSP2qNQ8EIMlbBBYZUQX+ansGMEbNB\",\"dtype\":\"float32\",\"order\":\"little\",\"shape\":[1533]}},\"selected\":{\"id\":\"1168\"},\"selection_policy\":{\"id\":\"1167\"}},\"id\":\"1143\",\"type\":\"ColumnDataSource\"},{\"attributes\":{},\"id\":\"1111\",\"type\":\"DataRange1d\"},{\"attributes\":{\"coordinates\":null,\"data_source\":{\"id\":\"1143\"},\"glyph\":{\"id\":\"1145\"},\"group\":null,\"hover_glyph\":null,\"muted_glyph\":{\"id\":\"1147\"},\"nonselection_glyph\":{\"id\":\"1146\"},\"view\":{\"id\":\"1149\"}},\"id\":\"1148\",\"type\":\"GlyphRenderer\"},{\"attributes\":{\"tools\":[{\"id\":\"1127\"},{\"id\":\"1128\"},{\"id\":\"1129\"},{\"id\":\"1130\"},{\"id\":\"1131\"},{\"id\":\"1132\"},{\"id\":\"1134\"}]},\"id\":\"1135\",\"type\":\"Toolbar\"},{\"attributes\":{\"fill_color\":{\"field\":\"label\",\"transform\":{\"id\":\"1106\"}},\"line_color\":{\"field\":\"label\",\"transform\":{\"id\":\"1106\"}},\"size\":{\"value\":10},\"x\":{\"field\":\"x\"},\"y\":{\"field\":\"y\"}},\"id\":\"1145\",\"type\":\"Scatter\"},{\"attributes\":{},\"id\":\"1131\",\"type\":\"ResetTool\"},{\"attributes\":{},\"id\":\"1120\",\"type\":\"BasicTicker\"},{\"attributes\":{\"high\":100,\"low\":0,\"palette\":[\"#30123b\",\"#311542\",\"#32184a\",\"#341b51\",\"#351e58\",\"#36215f\",\"#372365\",\"#38266c\",\"#392972\",\"#3a2c79\",\"#3b2f7f\",\"#3c3285\",\"#3c358b\",\"#3d3791\",\"#3e3a96\",\"#3f3d9c\",\"#4040a1\",\"#4043a6\",\"#4145ab\",\"#4148b0\",\"#424bb5\",\"#434eba\",\"#4350be\",\"#4353c2\",\"#4456c7\",\"#4458cb\",\"#455bce\",\"#455ed2\",\"#4560d6\",\"#4563d9\",\"#4666dd\",\"#4668e0\",\"#466be3\",\"#466de6\",\"#4670e8\",\"#4673eb\",\"#4675ed\",\"#4678f0\",\"#467af2\",\"#467df4\",\"#467ff6\",\"#4682f8\",\"#4584f9\",\"#4587fb\",\"#4589fc\",\"#448cfd\",\"#438efd\",\"#4291fe\",\"#4193fe\",\"#4096fe\",\"#3f98fe\",\"#3e9bfe\",\"#3c9dfd\",\"#3ba0fc\",\"#39a2fc\",\"#38a5fb\",\"#36a8f9\",\"#34aaf8\",\"#33acf6\",\"#31aff5\",\"#2fb1f3\",\"#2db4f1\",\"#2bb6ef\",\"#2ab9ed\",\"#28bbeb\",\"#26bde9\",\"#25c0e6\",\"#23c2e4\",\"#21c4e1\",\"#20c6df\",\"#1ec9dc\",\"#1dcbda\",\"#1ccdd7\",\"#1bcfd4\",\"#1ad1d2\",\"#19d3cf\",\"#18d5cc\",\"#18d7ca\",\"#17d9c7\",\"#17dac4\",\"#17dcc2\",\"#17debf\",\"#18e0bd\",\"#18e1ba\",\"#19e3b8\",\"#1ae4b6\",\"#1be5b4\",\"#1de7b1\",\"#1ee8af\",\"#20e9ac\",\"#22eba9\",\"#24eca6\",\"#27eda3\",\"#29eea0\",\"#2cef9d\",\"#2ff09a\",\"#32f197\",\"#35f394\",\"#38f491\",\"#3bf48d\",\"#3ff58a\",\"#42f687\",\"#46f783\",\"#4af880\",\"#4df97c\",\"#51f979\",\"#55fa76\",\"#59fb72\",\"#5dfb6f\",\"#61fc6c\",\"#65fc68\",\"#69fd65\",\"#6dfd62\",\"#71fd5f\",\"#74fe5c\",\"#78fe59\",\"#7cfe56\",\"#80fe53\",\"#84fe50\",\"#87fe4d\",\"#8bfe4b\",\"#8efe48\",\"#92fe46\",\"#95fe44\",\"#98fe42\",\"#9bfd40\",\"#9efd3e\",\"#a1fc3d\",\"#a4fc3b\",\"#a6fb3a\",\"#a9fb39\",\"#acfa37\",\"#aef937\",\"#b1f836\",\"#b3f835\",\"#b6f735\",\"#b9f534\",\"#bbf434\",\"#bef334\",\"#c0f233\",\"#c3f133\",\"#c5ef33\",\"#c8ee33\",\"#caed33\",\"#cdeb34\",\"#cfea34\",\"#d1e834\",\"#d4e735\",\"#d6e535\",\"#d8e335\",\"#dae236\",\"#dde036\",\"#dfde36\",\"#e1dc37\",\"#e3da37\",\"#e5d838\",\"#e7d738\",\"#e8d538\",\"#ead339\",\"#ecd139\",\"#edcf39\",\"#efcd39\",\"#f0cb3a\",\"#f2c83a\",\"#f3c63a\",\"#f4c43a\",\"#f6c23a\",\"#f7c039\",\"#f8be39\",\"#f9bc39\",\"#f9ba38\",\"#fab737\",\"#fbb537\",\"#fbb336\",\"#fcb035\",\"#fcae34\",\"#fdab33\",\"#fda932\",\"#fda631\",\"#fda330\",\"#fea12f\",\"#fe9e2e\",\"#fe9b2d\",\"#fe982c\",\"#fd952b\",\"#fd9229\",\"#fd8f28\",\"#fd8c27\",\"#fc8926\",\"#fc8624\",\"#fb8323\",\"#fb80
22\",\"#fa7d20\",\"#fa7a1f\",\"#f9771e\",\"#f8741c\",\"#f7711b\",\"#f76e1a\",\"#f66b18\",\"#f56817\",\"#f46516\",\"#f36315\",\"#f26014\",\"#f15d13\",\"#ef5a11\",\"#ee5810\",\"#ed550f\",\"#ec520e\",\"#ea500d\",\"#e94d0d\",\"#e84b0c\",\"#e6490b\",\"#e5460a\",\"#e3440a\",\"#e24209\",\"#e04008\",\"#de3e08\",\"#dd3c07\",\"#db3a07\",\"#d93806\",\"#d73606\",\"#d63405\",\"#d43205\",\"#d23005\",\"#d02f04\",\"#ce2d04\",\"#cb2b03\",\"#c92903\",\"#c72803\",\"#c52602\",\"#c32402\",\"#c02302\",\"#be2102\",\"#bb1f01\",\"#b91e01\",\"#b61c01\",\"#b41b01\",\"#b11901\",\"#ae1801\",\"#ac1601\",\"#a91501\",\"#a61401\",\"#a31201\",\"#a01101\",\"#9d1001\",\"#9a0e01\",\"#970d01\",\"#940c01\",\"#910b01\",\"#8e0a01\",\"#8b0901\",\"#870801\",\"#840701\",\"#810602\",\"#7d0502\",\"#7a0402\"]},\"id\":\"1106\",\"type\":\"LinearColorMapper\"},{\"attributes\":{\"axis\":{\"id\":\"1119\"},\"coordinates\":null,\"group\":null,\"ticker\":null},\"id\":\"1122\",\"type\":\"Grid\"},{\"attributes\":{\"axis\":{\"id\":\"1123\"},\"coordinates\":null,\"dimension\":1,\"group\":null,\"ticker\":null},\"id\":\"1126\",\"type\":\"Grid\"},{\"attributes\":{},\"id\":\"1168\",\"type\":\"Selection\"},{\"attributes\":{},\"id\":\"1132\",\"type\":\"HelpTool\"},{\"attributes\":{\"callback\":null,\"tooltips\":\"\\n
\\n \\n
\\n @label_desc - @split\\n [#@video_id]\\n
\\n
\\n \\n\"},\"id\":\"1134\",\"type\":\"HoverTool\"},{\"attributes\":{},\"id\":\"1113\",\"type\":\"DataRange1d\"},{\"attributes\":{\"bottom_units\":\"screen\",\"coordinates\":null,\"fill_alpha\":0.5,\"fill_color\":\"lightgrey\",\"group\":null,\"left_units\":\"screen\",\"level\":\"overlay\",\"line_alpha\":1.0,\"line_color\":\"black\",\"line_dash\":[4,4],\"line_width\":2,\"right_units\":\"screen\",\"syncable\":false,\"top_units\":\"screen\"},\"id\":\"1133\",\"type\":\"BoxAnnotation\"},{\"attributes\":{\"coordinates\":null,\"formatter\":{\"id\":\"1162\"},\"group\":null,\"major_label_policy\":{\"id\":\"1163\"},\"ticker\":{\"id\":\"1124\"}},\"id\":\"1123\",\"type\":\"LinearAxis\"},{\"attributes\":{\"overlay\":{\"id\":\"1133\"}},\"id\":\"1129\",\"type\":\"BoxZoomTool\"},{\"attributes\":{\"fill_alpha\":{\"value\":0.1},\"fill_color\":{\"field\":\"label\",\"transform\":{\"id\":\"1106\"}},\"hatch_alpha\":{\"value\":0.1},\"line_alpha\":{\"value\":0.1},\"line_color\":{\"field\":\"label\",\"transform\":{\"id\":\"1106\"}},\"size\":{\"value\":10},\"x\":{\"field\":\"x\"},\"y\":{\"field\":\"y\"}},\"id\":\"1146\",\"type\":\"Scatter\"},{\"attributes\":{\"coordinates\":null,\"group\":null,\"text\":\"Check label by hovering mouse over the dots\"},\"id\":\"1109\",\"type\":\"Title\"},{\"attributes\":{},\"id\":\"1167\",\"type\":\"UnionRenderers\"},{\"attributes\":{\"coordinates\":null,\"formatter\":{\"id\":\"1165\"},\"group\":null,\"major_label_policy\":{\"id\":\"1166\"},\"ticker\":{\"id\":\"1120\"}},\"id\":\"1119\",\"type\":\"LinearAxis\"},{\"attributes\":{},\"id\":\"1130\",\"type\":\"SaveTool\"},{\"attributes\":{},\"id\":\"1117\",\"type\":\"LinearScale\"},{\"attributes\":{},\"id\":\"1162\",\"type\":\"BasicTickFormatter\"},{\"attributes\":{},\"id\":\"1115\",\"type\":\"LinearScale\"},{\"attributes\":{\"source\":{\"id\":\"1143\"}},\"id\":\"1149\",\"type\":\"CDSView\"},{\"attributes\":{},\"id\":\"1127\",\"type\":\"PanTool\"},{\"attributes\":{},\"id\":\"1166\",\"type\":\"AllLabels\"},{\"attributes\":{\"fill_alpha\":{\"value\":0.2},\"fill_color\":{\"field\":\"label\",\"transform\":{\"id\":\"1106\"}},\"hatch_alpha\":{\"value\":0.2},\"line_alpha\":{\"value\":0.2},\"line_color\":{\"field\":\"label\",\"transform\":{\"id\":\"1106\"}},\"size\":{\"value\":10},\"x\":{\"field\":\"x\"},\"y\":{\"field\":\"y\"}},\"id\":\"1147\",\"type\":\"Scatter\"}],\"root_ids\":[\"1108\"]},\"title\":\"Bokeh Application\",\"version\":\"2.4.3\"}};\n const render_items = [{\"docid\":\"458dc79d-472d-4f36-92df-cc9c2d7d3fb7\",\"root_ids\":[\"1108\"],\"roots\":{\"1108\":\"74ea627a-4164-4f24-bbf8-b9b7c4c5836c\"}}];\n root.Bokeh.embed.embed_items_notebook(docs_json, render_items);\n }\n if (root.Bokeh !== undefined) {\n embed_document(root);\n } else {\n let attempts = 0;\n const timer = setInterval(function(root) {\n if (root.Bokeh !== undefined) {\n clearInterval(timer);\n embed_document(root);\n } else {\n attempts++;\n if (attempts > 100) {\n clearInterval(timer);\n console.log(\"Bokeh: ERROR: Unable to run BokehJS code because BokehJS library is missing\");\n }\n }\n }, 10, root)\n }\n})(window);", + "application/vnd.bokehjs_exec.v0+json": "" + }, + "metadata": { + "application/vnd.bokehjs_exec.v0+json": { + "id": "1108" + } + }, + "output_type": "display_data" + } + ], + "source": [ + "column_data = dict(\n", + " x=df['tsne_x'],\n", + " y=df['tsne_y'],\n", + " label=df['labels'],\n", + " label_desc=df['label_name'],\n", + " split=df['split'],\n", + " video_id=df['video_id']\n", + ")\n", + "\n", + "if use_img_div:\n", + " emb_videos = 
load_videos(df['video_fn'])\n", + " column_data[\"videos\"] = emb_videos\n", + "source = ColumnDataSource(data=column_data)\n", + "\n", + "p.scatter('x', 'y',\n", + " size=10,\n", + " source=source,\n", + " fill_color={\"field\": 'label', \"transform\": cmap},\n", + " line_color={\"field\": 'label', \"transform\": cmap}, \n", + " #legend_label={\"field\": 'split', \"transform\": lambda x: df['split']},\n", + "# marker={\"field\": 'split'}\n", + " )\n", + "\n", + "show(p)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1d761766", + "metadata": {}, + "outputs": [], + "source": [ + "df" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1c73f195", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/preprocessing.py b/preprocessing.py new file mode 100644 index 0000000..34ac917 --- /dev/null +++ b/preprocessing.py @@ -0,0 +1,21 @@ +from argparse import ArgumentParser +from preprocessing.create_wlasl_landmarks_dataset import parse_create_args, create +from preprocessing.extract_mediapipe_landmarks import parse_extract_args, extract + + +if __name__ == '__main__': + main_parser = ArgumentParser() + subparser = main_parser.add_subparsers(dest="action") + create_subparser = subparser.add_parser("create") + extract_subparser = subparser.add_parser("extract") + parse_create_args(create_subparser) + parse_extract_args(extract_subparser) + + args = main_parser.parse_args() + + if args.action == "create": + create(args) + elif args.action == "extract": + extract(args) + else: + raise ValueError("action command must be either 'create' or 'extract'") diff --git a/preprocessing/__init__.py b/preprocessing/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/preprocessing/create_wlasl_landmarks_dataset.py b/preprocessing/create_wlasl_landmarks_dataset.py new file mode 100644 index 0000000..a457c89 --- /dev/null +++ b/preprocessing/create_wlasl_landmarks_dataset.py @@ -0,0 +1,155 @@ +import os +import os.path as op +import json +import shutil + +import cv2 +import mediapipe as mp +import numpy as np +import pandas as pd +from utils import get_logger +from tqdm.auto import tqdm +from sklearn.model_selection import train_test_split +from normalization.blazepose_mapping import map_blazepose_df + +BASE_DATA_FOLDER = 'data/' + +mp_drawing = mp.solutions.drawing_utils +mp_drawing_styles = mp.solutions.drawing_styles +mp_hands = mp.solutions.hands +mp_holistic = mp.solutions.holistic +pose_landmarks = mp_holistic.PoseLandmark +hand_landmarks = mp_holistic.HandLandmark + + +def get_landmarks_names(): + ''' + Returns landmark names for mediapipe holistic model + ''' + pose_lmks = ','.join([f'{lmk.name.lower()}_x,{lmk.name.lower()}_y' for lmk in pose_landmarks]) + left_hand_lmks = ','.join([f'left_hand_{lmk.name.lower()}_x,left_hand_{lmk.name.lower()}_y' + for lmk in hand_landmarks]) + right_hand_lmks = ','.join([f'right_hand_{lmk.name.lower()}_x,right_hand_{lmk.name.lower()}_y' + for lmk in 
hand_landmarks]) + lmks_names = f'{pose_lmks},{left_hand_lmks},{right_hand_lmks}' + return lmks_names + + +def convert_to_str(arr, precision=6): + if isinstance(arr, np.ndarray): + values = [] + for val in arr: + if val == 0: + values.append('0') + else: + values.append(f'{val:.{precision}f}') + return f"[{','.join(values)}]" + else: + return str(arr) + + +def parse_create_args(parser): + parser.add_argument('--landmarks-dataset', '-lmks', required=True, + help='Path to folder with landmarks npy files. \ + You need to run `extract_mediapipe_landmarks.py` script first') + parser.add_argument('--dataset-folder', '-df', default='data/wlasl', + help='Path to folder where original `WLASL_v0.3.json` and `id_to_label.json` are stored. \ + Note that final CSV files will be saved in this folder too.') + parser.add_argument('--videos-folder', '-videos', default=None, + help='Path to folder with videos. If None, then no information of videos (fps, length, \ + width and height) will be stored in final csv file') + parser.add_argument('--num-classes', '-nc', default=100, type=int, help='Number of classes to use in WLASL dataset') + parser.add_argument('--create-new-split', action='store_true') + parser.add_argument('--test-size', '-ts', default=0.25, type=float, + help='Test split percentage size. Only required if --create-new-split is set') + + +# python3 preprocessing.py --landmarks-dataset=data/landmarks -videos data/wlasl/videos +def create(args): + logger = get_logger(__name__) + + landmarks_dataset = args.landmarks_dataset + videos_folder = args.videos_folder + dataset_folder = args.dataset_folder + num_classes = args.num_classes + test_size = args.test_size + + os.makedirs(dataset_folder, exist_ok=True) + + shutil.copy(os.path.join(BASE_DATA_FOLDER, 'wlasl/id_to_label.json'), dataset_folder) + shutil.copy(os.path.join(BASE_DATA_FOLDER, 'wlasl/WLASL_v0.3.json'), dataset_folder) + + wlasl_json_fn = op.join(dataset_folder, 'WLASL_v0.3.json') + + with open(wlasl_json_fn) as fid: + data = json.load(fid) + + video_data = [] + for label_id, datum in enumerate(tqdm(data[:num_classes])): + instances = [] + for instance in datum['instances']: + instances.append(instance) + video_id = instance['video_id'] + print(video_id) + video_dict = {'video_id': video_id, + 'label_name': datum['gloss'], + 'labels': label_id, + 'split': instance['split']} + if videos_folder is not None: + cap = cv2.VideoCapture(op.join(videos_folder, f'{video_id}.mp4')) + if not cap.isOpened(): + logger.warning(f'Video {video_id}.mp4 not found') + continue + width = cap.get(cv2.CAP_PROP_FRAME_WIDTH) + height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT) + fps = cap.get(cv2.CAP_PROP_FPS) + length = cap.get(cv2.CAP_PROP_FRAME_COUNT) / float(cap.get(cv2.CAP_PROP_FPS)) + video_info = {'video_width': width, + 'video_height': height, + 'fps': fps, + 'length': length} + video_dict.update(video_info) + video_data.append(video_dict) + df_video = pd.DataFrame(video_data) + video_ids = df_video['video_id'].unique() + lmks_data = [] + lmks_names = get_landmarks_names().split(',') + for video_id in video_ids: + lmk_fn = op.join(landmarks_dataset, f'{video_id}.npy') + if not op.exists(lmk_fn): + logger.warning(f'{lmk_fn} file not found. 
Skipping') + continue + lmk = np.load(lmk_fn).T + lmks_dict = {'video_id': video_id} + for lmk_, name in zip(lmk, lmks_names): + lmks_dict[name] = lmk_ + lmks_data.append(lmks_dict) + + df_lmks = pd.DataFrame(lmks_data) + print(df_lmks) + df = pd.merge(df_video, df_lmks) + print(df) + aux_columns = ['split', 'video_id', 'labels', 'label_name'] + if videos_folder is not None: + aux_columns += ['video_width', 'video_height', 'fps', 'length'] + df_aux = df[aux_columns] + df = map_blazepose_df(df) + df = pd.concat([df, df_aux], axis=1) + if args.create_new_split: + df_train, df_test = train_test_split(df, test_size=test_size, stratify=df['labels'], random_state=42) + else: + print(df['split'].unique()) + df_train = df[(df['split'] == 'train') | (df['split'] == 'val')] + df_test = df[df['split'] == 'test'] + + print(f'Num classes: {num_classes}') + print(df_train['labels'].value_counts()) + assert set(df_train['labels'].unique()) == set(df_test['labels'].unique( + )), 'The labels for train and test dataframe are different. We recommend to download the dataset again, or to use \ + the --create-new-split flag' + for split, df_split in zip(['train', 'val'], + [df_train, df_test]): + fn_out = op.join(dataset_folder, f'WLASL{num_classes}_{split}.csv') + (df_split.reset_index(drop=True) + .applymap(convert_to_str) + .to_csv(fn_out, index=False)) diff --git a/preprocessing/extract_mediapipe_landmarks.py b/preprocessing/extract_mediapipe_landmarks.py new file mode 100644 index 0000000..6d63076 --- /dev/null +++ b/preprocessing/extract_mediapipe_landmarks.py @@ -0,0 +1,154 @@ +import os +import os.path as op +from itertools import chain +from collections import namedtuple +import glob + +import cv2 +import numpy as np +import mediapipe as mp +from tqdm.auto import tqdm + +# Import drawing_utils and drawing_styles. +mp_drawing = mp.solutions.drawing_utils +mp_drawing_styles = mp.solutions.drawing_styles +mp_holistic = mp.solutions.holistic +mp_pose = mp.solutions.pose + +LEN_LANDMARKS_POSE = len(mp_holistic.PoseLandmark) +LEN_LANDMARKS_HAND = len(mp_holistic.HandLandmark) +TOTAL_LANDMARKS = LEN_LANDMARKS_POSE + 2 * LEN_LANDMARKS_HAND + +Landmark = namedtuple("Landmark", ["x", "y"]) + + +class LandmarksResults: + """ + Wrapper for landmarks results. When not available it fills with 0 + """ + + def __init__( + self, + results, + num_landmarks_pose=LEN_LANDMARKS_POSE, + num_landmarks_hand=LEN_LANDMARKS_HAND, + ): + self.results = results + self.num_landmarks_pose = num_landmarks_pose + self.num_landmarks_hand = num_landmarks_hand + + @property + def pose_landmarks(self): + if self.results.pose_landmarks is None: + return [Landmark(0, 0)] * self.num_landmarks_pose + else: + return self.results.pose_landmarks.landmark + + @property + def left_hand_landmarks(self): + if self.results.left_hand_landmarks is None: + return [Landmark(0, 0)] * self.num_landmarks_hand + else: + return self.results.left_hand_landmarks.landmark + + @property + def right_hand_landmarks(self): + if self.results.right_hand_landmarks is None: + return [Landmark(0, 0)] * self.num_landmarks_hand + else: + return self.results.right_hand_landmarks.landmark + + +def get_landmarks(image_orig, holistic, debug=False): + """ + Runs landmarks detection for single image + Returns: list of landmarks + """ + # Convert the BGR image to RGB before processing. 
+ image = cv2.cvtColor(image_orig, cv2.COLOR_BGR2RGB) + results = LandmarksResults(holistic.process(image)) + if debug: + lmks_pose = [] + for lmk in results.pose_landmarks: + lmks_pose.append(lmk.x) + lmks_pose.append(lmk.y) + assert len(lmks_pose) == LEN_LANDMARKS_POSE + + lmks_left_hand = [] + + for lmk in results.left_hand_landmarks: + lmks_left_hand.append(lmk.x) + lmks_left_hand.append(lmk.y) + + assert ( + len(lmks_left_hand) == 2 * LEN_LANDMARKS_HAND + ), f"{len(lmks_left_hand)} != {2 * LEN_LANDMARKS_HAND}" + + lmks_right_hand = [] + + for lmk in results.right_hand_landmarks: + lmks_right_hand.append(lmk.x) + lmks_right_hand.append(lmk.y), + + assert ( + len(lmks_right_hand) == 2 * LEN_LANDMARKS_HAND + ), f"{len(lmks_right_hand)} != {2 * LEN_LANDMARKS_HAND}" + landmarks = [] + for lmk in chain( + results.pose_landmarks, + results.left_hand_landmarks, + results.right_hand_landmarks, + ): + landmarks.append(lmk.x) + landmarks.append(lmk.y) + assert ( + len(landmarks) == TOTAL_LANDMARKS * 2 + ), f"{len(landmarks)} != {TOTAL_LANDMARKS * 2}" + return landmarks + + +def parse_extract_args(parser): + parser.add_argument( + "--videos-folder", + "-videos", + help="Path of folder with videos to extract landmarks from", + required=True, + ) + parser.add_argument( + "--output-landmarks", + "-lmks", + help="Path of output folder where landmarks npy files will be saved", + required=True, + ) + + +# python3 preprocessing.py -videos=data/wlasl/videos_25fps/ -lmks=data/landmarks +def extract(args): + landmarks_output = args.output_landmarks + videos_folder = args.videos_folder + os.makedirs(landmarks_output, exist_ok=True) + for fn_video in tqdm(sorted(glob.glob(op.join(videos_folder, "*mp4")))): + cap = cv2.VideoCapture(fn_video) + ret, image_orig = cap.read() + height, width = image_orig.shape[:2] + landmarks_video = [] + with tqdm(total=int(cap.get(cv2.CAP_PROP_FRAME_COUNT))) as pbar: + with mp_holistic.Holistic( + static_image_mode=False, + min_detection_confidence=0.5, + model_complexity=2, + ) as holistic: + while ret: + try: + landmarks = get_landmarks(image_orig, holistic) + except Exception as e: + print(e) + landmarks = get_landmarks(image_orig, holistic, debug=True) + ret, image_orig = cap.read() + landmarks_video.append(landmarks) + pbar.update(1) + landmarks_video = np.vstack(landmarks_video) + np.save( + op.join(landmarks_output, op.basename(fn_video).split(".")[0]), + landmarks_video, + ) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..2e081ab --- /dev/null +++ b/requirements.txt @@ -0,0 +1,14 @@ +bokeh==2.4.3 +boto3>=1.9 +clearml==1.6.4 +ipywidgets==8.0.4 +matplotlib==3.5.3 +mediapipe==0.8.11 +notebook==6.5.2 +opencv-python==4.6.0.66 +pandas==1.1.5 +pandas==1.1.5 +plotly==5.11.0 +scikit-learn==1.0.2 +torchvision==0.13.0 +tqdm==4.54.1 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_batch_sorter.py b/tests/test_batch_sorter.py new file mode 100644 index 0000000..2e3b00f --- /dev/null +++ b/tests/test_batch_sorter.py @@ -0,0 +1,104 @@ +import unittest +# from traceback_with_variables import activate_by_import #noqa +import torch +from training.batch_sorter import BatchGrouper, sort_batches, get_scaled_distances, get_dist_tuple_list + + +class TestBatchSorting(unittest.TestCase): + + def get_sorted_dists(self): + device = get_device() + embeddings = torch.rand(32*8, 8).to(device) + labels = torch.rand(32*8, 1) + scaled_dist = get_scaled_distances(embeddings, labels, device) + # 
Get vector of (row, column, dist) + dist_list = get_dist_tuple_list(scaled_dist) + + A = dist_list.cpu().detach().numpy() + return A[:, A[-1, :].argsort()[::-1]] + + def setUp(self) -> None: + dists = self.get_sorted_dists() + self.grouper = BatchGrouper(sorted_dists=dists, total_items=32*8, mini_batch_size=32) + return super().setUp() + + def test_assigns_and_merges(self): + group0 = self.grouper.create_or_get_group() + self.grouper.assign_group(1, group0) + self.grouper.assign_group(2, group0) + + group1 = self.grouper.create_or_get_group() + self.grouper.assign_group(3, group1) + self.grouper.assign_group(4, group1) + + # Merge groups + self.grouper.merge_groups(group0, group1) + self.assertEqual(len(self.grouper.groups[group0]), 4) + self.assertFalse(group1 in self.grouper.groups) + self.assertEqual(self.grouper.item_to_group[3], group0) + self.assertEqual(self.grouper.item_to_group[4], group0) + + def test_full_groups(self): + group0 = self.grouper.create_or_get_group() + for i in range(30): + self.grouper.assign_group(i, group0) + + self.assertFalse(self.grouper.group_is_full(group0)) + initial_group_len = len(self.grouper.groups[group0]) + + group1 = self.grouper.create_or_get_group() + for i in range(30, 33): + self.grouper.assign_group(i, group1) + + self.grouper.merge_groups(group0, group1) + # Assert no merge done + self.assertEqual(len(self.grouper.groups[group0]), initial_group_len) + self.assertTrue(group1 in self.grouper.groups) + self.assertEqual(self.grouper.item_to_group[31], group1) + self.assertEqual(self.grouper.item_to_group[32], group1) + + def test_replace_groups(self): + group0 = self.grouper.create_or_get_group() + for i in range(20): + self.grouper.assign_group(i, group0) + + group1 = self.grouper.create_or_get_group() + for i in range(20, 23): + self.grouper.assign_group(i, group1) + + group2 = self.grouper.create_or_get_group() + for i in range(23, 30): + self.grouper.assign_group(i, group2) + + self.grouper.merge_groups(group1, group0) + self.assertEqual(len(self.grouper.groups[group0]), 23) + self.assertTrue(group1 in self.grouper.groups) + self.assertFalse(group2 in self.grouper.groups) + self.assertEqual(len(self.grouper.groups[group1]), 7) + + +def get_device(): + device = torch.device("cpu") + return device + + +def test_get_scaled_distances(): + device = get_device() + emb = torch.rand(4, 3) + labels = torch.tensor([0, 1, 2, 2]) + distances = get_scaled_distances(emb, labels, device) + assert torch.all(distances >= 0) + assert torch.all(distances <= 1) + + +def test_batch_sorter_indices(): + device = get_device() + inputs = torch.rand(32*16, 1000) + labels = torch.rand(32*16, 1) + masks = torch.rand(32*16, 100) + embeddings = torch.rand(32*16, 32).to(device) + + i_out, l_out, m_out = sort_batches(inputs, labels, masks, embeddings, device) + first_match_index = torch.all(inputs == i_out[0], dim=1).nonzero(as_tuple=True)[0][0] + assert torch.all(labels[first_match_index] == l_out[0]) + assert torch.all(masks[first_match_index] == m_out[0]) diff --git a/tracking/__init__.py b/tracking/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tracking/clearml_tracker.py b/tracking/clearml_tracker.py new file mode 100644 index 0000000..b91ddfa --- /dev/null +++ b/tracking/clearml_tracker.py @@ -0,0 +1,21 @@ +from clearml import Task, Logger +from .tracker import Tracker + + +class ClearMLTracker(Tracker): + + def __init__(self, project_name=None, experiment_name=None): + self.task = Task.current_task() or Task.init(project_name=project_name, 
task_name=experiment_name) + + def execute_remotely(self, queue_name): + self.task.execute_remotely(queue_name=queue_name) + + def log_scalar_metric(self, metric, series, iteration, value): + Logger.current_logger().report_scalar(metric, series, iteration=iteration, value=value) + + def log_chart(self, title, series, iteration, figure): + Logger.current_logger().report_plotly(title=title, series=series, iteration=iteration, figure=figure) + + def finish_run(self): + self.task.mark_completed() + self.task.close() diff --git a/tracking/tracker.py b/tracking/tracker.py new file mode 100644 index 0000000..a98297b --- /dev/null +++ b/tracking/tracker.py @@ -0,0 +1,28 @@ + +class Tracker: + + def __init__(self, project_name, experiment_name): + super().__init__() + + def execute_remotely(self, queue_name): + pass + + def track_config(self, configs): + # Used to track configuration parameters of an experiment run + pass + + def track_artifacts(self, filepath): + # Used to track artifacts like model weights + pass + + def log_scalar_metric(self, metric, series, iteration, value): + pass + + def log_chart(self, title, series, iteration, figure): + pass + + def finish_run(self): + pass + + def get_callback(self): + pass diff --git a/train.py b/train.py new file mode 100644 index 0000000..76fabfe --- /dev/null +++ b/train.py @@ -0,0 +1,287 @@ + +from datetime import datetime +import os +import os.path as op +import argparse +import json +from datasets.dataset_loader import LocalDatasetLoader +from tracking.tracker import Tracker +import torch +import multiprocessing +import torch.nn as nn +import torch.optim as optim +# import matplotlib.pyplot as plt +from torchvision import transforms +from torch.utils.data import DataLoader +from pathlib import Path +import copy + +from datasets import CzechSLRDataset, SLREmbeddingDataset, collate_fn_triplet_padd, collate_fn_padd +from models import SPOTER, SPOTER_EMBEDDINGS, train_epoch, evaluate, train_epoch_embedding, \ + train_epoch_embedding_online, evaluate_embedding +from training.online_batch_mining import BatchAllTripletLoss +from training.batching_scheduler import BatchingScheduler +from training.gaussian_noise import GaussianNoise +from training.train_utils import train_setup, create_embedding_scatter_plots +from training.train_arguments import get_default_args +from utils import get_logger +try: + # Needed for argparse patching in case clearml is used + import clearml # noqa +except ImportError: + pass + + +PROJECT_NAME = "spoter" +CLEARML = "clearml" + + +def is_pre_batch_sorting_enabled(args): + return args.start_mining_hard is not None and args.start_mining_hard > 0 + + +def get_tracker(tracker_name, project, experiment_name): + if tracker_name == CLEARML: + from tracking.clearml_tracker import ClearMLTracker + return ClearMLTracker(project_name=project, experiment_name=experiment_name) + else: + return Tracker(project_name=project, experiment_name=experiment_name) + + +def get_dataset_loader(loader_name): + if loader_name == CLEARML: + from datasets.clearml_dataset_loader import ClearMLDatasetLoader + return ClearMLDatasetLoader() + else: + return LocalDatasetLoader() + + +def build_data_loader(dataset, batch_size, shuffle, collate_fn, generator): + return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn, + generator=generator, pin_memory=torch.cuda.is_available(), num_workers=multiprocessing.cpu_count()) + + +def train(args, tracker: Tracker): + tracker.execute_remotely(queue_name="default") + # Initialize all the 
random seeds + gen = train_setup(args.seed, args.experiment_name) + os.environ['EXPERIMENT_NAME'] = args.experiment_name + logger = get_logger(args.experiment_name) + + # Set device to CUDA only if applicable + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda") + + # Construct the model + if not args.classification_model: + slrt_model = SPOTER_EMBEDDINGS( + features=args.vector_length, + hidden_dim=args.hidden_dim, + norm_emb=args.normalize_embeddings, + dropout=args.dropout + ) + model_type = 'embed' + if args.hard_triplet_mining == "None": + cel_criterion = nn.TripletMarginLoss(margin=args.triplet_loss_margin, p=2) + elif args.hard_triplet_mining == "in_batch": + cel_criterion = BatchAllTripletLoss( + device=device, + margin=args.triplet_loss_margin, + filter_easy_triplets=bool(args.filter_easy_triplets) + ) + else: + slrt_model = SPOTER(num_classes=args.num_classes, hidden_dim=args.hidden_dim) + model_type = 'classif' + cel_criterion = nn.CrossEntropyLoss() + slrt_model.to(device) + + if args.optimizer == "SGD": + optimizer = optim.SGD(slrt_model.parameters(), lr=args.lr) + elif args.optimizer == "ADAM": + optimizer = optim.Adam(slrt_model.parameters(), lr=args.lr) + + if args.scheduler_factor > 0: + mode = 'min' if args.classification_model else 'max' + scheduler = optim.lr_scheduler.ReduceLROnPlateau( + optimizer, + mode=mode, + factor=args.scheduler_factor, + patience=args.scheduler_patience + ) + else: + scheduler = None + + if args.hard_mining_scheduler_triplets_threshold > 0: + batching_scheduler = BatchingScheduler(triplets_threshold=args.hard_mining_scheduler_triplets_threshold) + else: + batching_scheduler = None + + # Ensure that the path for checkpointing and for images both exist + Path("out-checkpoints/" + args.experiment_name + "/").mkdir(parents=True, exist_ok=True) + Path("out-img/").mkdir(parents=True, exist_ok=True) + + # Training set + transform = transforms.Compose([GaussianNoise(args.gaussian_mean, args.gaussian_std)]) + dataset_loader = get_dataset_loader(args.dataset_loader) + dataset_folder = dataset_loader.get_dataset_folder(args.dataset_project, args.dataset_name) + training_set_path = op.join(dataset_folder, args.training_set_path) + + with open(op.join(dataset_folder, 'id_to_label.json')) as fid: + id_to_label = json.load(fid) + id_to_label = {int(key): value for key, value in id_to_label.items()} + + if not args.classification_model: + batch_size = args.batch_size + val_batch_size = args.batch_size + if args.hard_triplet_mining == "None": + train_set = SLREmbeddingDataset(training_set_path, triplet=True, transform=transform, augmentations=True, + augmentations_prob=args.augmentations_prob) + collate_fn_train = collate_fn_triplet_padd + elif args.hard_triplet_mining == "in_batch": + train_set = SLREmbeddingDataset(training_set_path, triplet=False, transform=transform, augmentations=True, + augmentations_prob=args.augmentations_prob) + collate_fn_train = collate_fn_padd + if is_pre_batch_sorting_enabled(args): + batch_size *= args.hard_mining_pre_batch_multipler + train_val_set = SLREmbeddingDataset(training_set_path, triplet=False) + # Train dataloader for validation + train_val_loader = build_data_loader(train_val_set, val_batch_size, False, collate_fn_padd, gen) + else: + train_set = CzechSLRDataset(training_set_path, transform=transform, augmentations=True) + batch_size = 1 + val_batch_size = 1 + collate_fn_train = None + + train_loader = build_data_loader(train_set, batch_size, True, collate_fn_train, gen) + + # 
Validation set + validation_set_path = op.join(dataset_folder, args.validation_set_path) + + if args.classification_model: + val_set = CzechSLRDataset(validation_set_path) + collate_fn_val = None + else: + val_set = SLREmbeddingDataset(validation_set_path, triplet=False) + collate_fn_val = collate_fn_padd + + val_loader = build_data_loader(val_set, val_batch_size, False, collate_fn_val, gen) + + # MARK: TRAINING + train_acc, val_acc = 0, 0 + losses, train_accs, val_accs = [], [], [] + lr_progress = [] + top_val_acc = -999 + top_model_saved = True + + logger.info("Starting " + args.experiment_name + "...\n\n") + + if is_pre_batch_sorting_enabled(args): + mini_batch_size = int(batch_size / args.hard_mining_pre_batch_multipler) + else: + mini_batch_size = None + enable_batch_sorting = False + pre_batch_mining_count = 1 + for epoch in range(1, args.epochs + 1): + start_time = datetime.now() + if not args.classification_model: + train_kwargs = {"model": slrt_model, + "epoch_iters": args.epoch_iters, + "train_loader": train_loader, + "val_loader": val_loader, + "criterion": cel_criterion, + "optimizer": optimizer, + "device": device, + "scheduler": scheduler if epoch >= args.scheduler_warmup else None, + } + if args.hard_triplet_mining == "None": + train_loss, val_silhouette_coef = train_epoch_embedding(**train_kwargs) + elif args.hard_triplet_mining == "in_batch": + if epoch == args.start_mining_hard: + enable_batch_sorting = True + pre_batch_mining_count = args.hard_mining_pre_batch_mining_count + train_kwargs.update(dict(enable_batch_sorting=enable_batch_sorting, + mini_batch_size=mini_batch_size, + pre_batch_mining_count=pre_batch_mining_count, + batching_scheduler=batching_scheduler if enable_batch_sorting else None)) + + train_loss, val_silhouette_coef, triplets_stats = train_epoch_embedding_online(**train_kwargs) + + tracker.log_scalar_metric("triplets", "valid_triplets", epoch, triplets_stats["valid_triplets"]) + tracker.log_scalar_metric("triplets", "used_triplets", epoch, triplets_stats["used_triplets"]) + tracker.log_scalar_metric("triplets_pct", "pct_used", epoch, triplets_stats["pct_used"]) + tracker.log_scalar_metric("train_loss", "loss", epoch, train_loss) + losses.append(train_loss) + + # calculate acc on train dataset + silhouette_coefficient_train = evaluate_embedding(slrt_model, train_val_loader, device) + + tracker.log_scalar_metric("silhouette_coefficient", "train", epoch, silhouette_coefficient_train) + train_accs.append(silhouette_coefficient_train) + + val_accs.append(val_silhouette_coef) + tracker.log_scalar_metric("silhouette_coefficient", "val", epoch, val_silhouette_coef) + + else: + train_loss, _, _, train_acc = train_epoch(slrt_model, train_loader, cel_criterion, optimizer, device) + tracker.log_scalar_metric("train_loss", "loss", epoch, train_loss) + tracker.log_scalar_metric("acc", "train", epoch, train_acc) + losses.append(train_loss) + train_accs.append(train_acc) + + _, _, val_acc = evaluate(slrt_model, val_loader, device) + val_accs.append(val_acc) + tracker.log_scalar_metric("acc", "val", epoch, val_acc) + + logger.info(f"Epoch time: {datetime.now() - start_time}") + logger.info("[" + str(epoch) + "] TRAIN loss: " + str(train_loss) + " acc: " + str(train_accs[-1])) + logger.info("[" + str(epoch) + "] VALIDATION acc: " + str(val_accs[-1])) + + lr_progress.append(optimizer.param_groups[0]["lr"]) + tracker.log_scalar_metric("lr", "lr", epoch, lr_progress[-1]) + + if val_accs[-1] > top_val_acc: + top_val_acc = val_accs[-1] + top_model_name = "checkpoint_" + 
model_type + "_" + str(epoch) + ".pth" + top_model_dict = { + "name": top_model_name, + "epoch": epoch, + "val_acc": val_accs[-1], + "config_args": args, + "state_dict": copy.deepcopy(slrt_model.state_dict()), + } + top_model_saved = False + + # Save checkpoint if it is the best on validation and delete previous checkpoints + if args.save_checkpoints_every > 0 and epoch % args.save_checkpoints_every == 0 and not top_model_saved: + torch.save( + top_model_dict, + "out-checkpoints/" + args.experiment_name + "/" + top_model_name + ) + top_model_saved = True + logger.info("Saved new best checkpoint: " + top_model_name) + + # save top model if checkpoints are disabled + if not top_model_saved: + torch.save( + top_model_dict, + "out-checkpoints/" + args.experiment_name + "/" + top_model_name + ) + logger.info("Saved new best checkpoint: " + top_model_name) + + # Log scatter plots + if not args.classification_model and args.hard_triplet_mining == "in_batch": + logger.info("Generating Scatter Plot.") + best_model = slrt_model + best_model.load_state_dict(top_model_dict["state_dict"]) + create_embedding_scatter_plots(tracker, best_model, train_loader, val_loader, device, id_to_label, epoch, + top_model_name) + logger.info("The experiment is finished.") + + +if __name__ == '__main__': + parser = argparse.ArgumentParser("", parents=[get_default_args()], add_help=False) + args = parser.parse_args() + tracker = get_tracker(args.tracker, PROJECT_NAME, args.experiment_name) + train(args, tracker) + tracker.finish_run() diff --git a/train.sh b/train.sh new file mode 100755 index 0000000..73f0c3a --- /dev/null +++ b/train.sh @@ -0,0 +1,24 @@ +#!/bin/sh +python -m train \ + --save_checkpoints_every -1 \ + --experiment_name "augment_rotate_75_x8" \ + --epochs 10 \ + --optimizer "SGD" \ + --lr 0.001 \ + --batch_size 32 \ + --dataset_name "wlasl" \ + --training_set_path "WLASL100_train.csv" \ + --validation_set_path "WLASL100_test.csv" \ + --vector_length 32 \ + --epoch_iters -1 \ + --scheduler_factor 0 \ + --hard_triplet_mining "in_batch" \ + --filter_easy_triplets \ + --triplet_loss_margin 1 \ + --dropout 0.2 \ + --start_mining_hard=200 \ + --hard_mining_pre_batch_multipler=16 \ + --hard_mining_pre_batch_mining_count=5 \ + --augmentations_prob=0.75 \ + --hard_mining_scheduler_triplets_threshold=0 \ + # --normalize_embeddings \ diff --git a/training/__init__.py b/training/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/training/batch_sorter.py b/training/batch_sorter.py new file mode 100644 index 0000000..6891365 --- /dev/null +++ b/training/batch_sorter.py @@ -0,0 +1,215 @@ +import logging +from datetime import datetime +import numpy as np +from typing import Optional +from .batching_scheduler import BatchingScheduler +import torch + +logger = logging.getLogger("BatchGrouper") + + +class BatchGrouper: + """ + Will cluster all `total_items` into `max_groups` clusters based on distances in + `sorted_dists`. Each group has `mini_batch_size` elements and these elements are just integers in + range 0...total_items. + + Distances between these items are expected to be scaled to 0...1 in a way that distances for two items in the + same class are higher if closer to 1, while distances between elements of different classes are higher if closer + to 0. + + The logic is picking the highest value distance and assigning both items to the same cluster/group if possible. + This might include merging 2 clusters. + There are a few threshold to limit the computational cost. 
If the scaled distance between a pair is below + `dist_threshold`, or more than `assign_threshold` percent of items have been assigned to the groups, we stop and + assign the remanining items to the groups that have space left. + """ + # Counters + next_group = 0 + items_assigned = 0 + + # Thresholds + dist_threshold = 0.5 + assign_threshold = 0.80 + + def __init__(self, sorted_dists, total_items, mini_batch_size=32, dist_threshold=0.5, assign_threshold=0.8) -> None: + self.sorted_dists = sorted_dists + self.total_items = total_items + self.mini_batch_size = mini_batch_size + self.max_groups = int(total_items / mini_batch_size) + self.groups = {} + self.item_to_group = {} + self.items_assigned = 0 + self.next_group = 0 + self.dist_threshold = dist_threshold + self.assign_threshold = assign_threshold + + def cluster_items(self): + """Main function of this class. Does the clustering explained in class docstring. + + :raises e: _description_ + :return _type_: _description_ + """ + for i in range(self.sorted_dists.shape[-1]): # and some other conditions are unmet + a, b, dist = self.sorted_dists[:, i] + a, b = int(a), int(b) + if dist < self.dist_threshold or self.items_assigned > self.total_items * self.assign_threshold: + logger.info(f"Breaking with dist: {dist}, and {self.items_assigned} items assigned") + break + if a not in self.item_to_group and b not in self.item_to_group: + g = self.create_or_get_group() + self.assign_group(a, g) + self.assign_group(b, g) + elif a not in self.item_to_group: + if not self.group_is_full(self.item_to_group[b]): + self.assign_group(a, self.item_to_group[b]) + elif b not in self.item_to_group: + if not self.group_is_full(self.item_to_group[a]): + self.assign_group(b, self.item_to_group[a]) + else: + grp_a = self.item_to_group[a] + grp_b = self.item_to_group[b] + self.merge_groups(grp_a, grp_b) + self.assign_remaining_items() + return list(np.concatenate(list(self.groups.values())).flat) + + def assign_group(self, item, group): + """Assigns `item` to group `group` + """ + self.item_to_group[item] = group + self.groups[group].append(item) + self.items_assigned += 1 + + def create_or_get_group(self): + """Creates a new group if current group count is less than max_groups. + Otherwise returns first group with space left. + + :return int: The group id + """ + if self.next_group < self.max_groups: + group = self.next_group + self.groups[group] = [] + self.next_group += 1 + else: + for i in range(self.next_group): + if len(self.groups[i]) <= self.mini_batch_size - 2: + group = i + break # out of the for loop + return group + + def group_is_full(self, group): + return len(self.groups[group]) == self.mini_batch_size + + def can_merge_groups(self, grp_a, grp_b): + return grp_a != grp_b and (len(self.groups[grp_a]) + len(self.groups[grp_b]) < self.mini_batch_size) + + def merge_groups(self, grp_a, grp_b): + """Will merge two groups together, if possible. Otherwise does nothing. 
+ """ + if grp_a > grp_b: + grp_a, grp_b = grp_b, grp_a + if self.can_merge_groups(grp_a, grp_b): + logger.debug(f"MERGE {grp_a} with {grp_b}: {len(self.groups[grp_a])} {len(self.groups[grp_b])}") + for b in self.groups[grp_b]: + self.item_to_group[b] = grp_a + self.groups[grp_a].extend(self.groups[grp_b]) + self.groups[grp_b] = [] + self.replace_group(grp_b) + + def replace_group(self, group): + """Replace a group with the last one in the list + + :param int group: Group to replace + """ + grp_to_change = self.next_group - 1 + if grp_to_change != group: + for item in self.groups[grp_to_change]: + self.item_to_group[item] = group + self.groups[group] = self.groups[grp_to_change] + del self.groups[grp_to_change] + self.next_group -= 1 + + def assign_remaining_items(self): + """ Assign remaining items into groups + """ + grp_pointer = 0 + i = 0 + logger.info(f"Assigning rest of items: {self.items_assigned} of {self.total_items}") + while i < self.total_items: + if i not in self.item_to_group: + if grp_pointer not in self.groups: + # This would happen if a group is still empty at this stage + assert grp_pointer < self.max_groups + new_group = self.create_or_get_group() + assert new_group == grp_pointer + if len(self.groups[grp_pointer]) < self.mini_batch_size: + self.assign_group(i, grp_pointer) + i += 1 + else: + grp_pointer += 1 + else: + i += 1 + + +def get_dist_tuple_list(dist_matrix): + batch_size = dist_matrix.size()[0] + indices = torch.tril_indices(batch_size, batch_size, offset=-1) + values = dist_matrix[indices[0], indices[1]].cpu() + return torch.cat([indices, values.unsqueeze(0)], dim=0) + + +def get_scaled_distances(embeddings, labels, device, same_label_factor=1): + """Returns distance matrix between all embeddings scaled to the 0-1 range where 0 is good and 1 is bad. 
+ This means that small distances for embeddings of the same class will be close to 0 while small distances for + embeddings of different classes will be close to 1 + + :param _type_ embeddings: Embeddings of batch items + :param _type_ labels: Labels associated to the embeddings + :param _type_ device: Device to run on (cuda or cpu) + :param int same_label_factor: Multiplies the weight of same-class distances allowing to give more or less importance + to these compared to distinct-class distances, defaults to 1 (which means equal weight) + :return torch.Tensor: Scaled distance matrix + """ + # Get pairwise distance matrix + distance_matrix = torch.cdist(embeddings, embeddings, p=2) + # Get list of tuples with emb_A, emb_B, dist ordered by greater for same label and smaller for diff label + # shape: (batch_size, batch_size) + labels = labels.to(device) + labels_equal = (labels.unsqueeze(0) == labels.unsqueeze(1)).squeeze() + labels_distinct = torch.logical_not(labels_equal) + pos_dist = distance_matrix * labels_equal + neg_dist = distance_matrix * labels_distinct + + # Use some scaling to bring both to a range of 0-1 + pos_max = pos_dist.max() + neg_max = neg_dist.max() + # Closer to 1 is harder + pos_dist = pos_dist / pos_max * same_label_factor + neg_dist = 1 * labels_distinct - (neg_dist / neg_max) + return pos_dist + neg_dist + + +def sort_batches(inputs, labels, masks, embeddings, device, mini_batch_size=32, + scheduler: Optional[BatchingScheduler] = None): + start = datetime.now() + + same_label_factor = scheduler.get_scaling_same_label_factor() if scheduler else 1 + scaled_dist = get_scaled_distances(embeddings, labels, device, same_label_factor) + # Get vector of (row, column, dist) + dist_list = get_dist_tuple_list(scaled_dist) + + dist_list = dist_list.cpu().detach().numpy() + # Sort distances descending by last row + sorted_dists = dist_list[:, dist_list[-1, :].argsort()[::-1]] + + # Loop through list assigning both items to same group + dist_threshold = scheduler.get_dist_threshold() if scheduler else 0.5 + grouper = BatchGrouper(sorted_dists, total_items=labels.size()[0], mini_batch_size=mini_batch_size, + dist_threshold=dist_threshold) + indices = torch.tensor(grouper.cluster_items()).type(torch.IntTensor) + final_inputs = torch.index_select(inputs, dim=0, index=indices) + final_labels = torch.index_select(labels, dim=0, index=indices) + final_masks = torch.index_select(masks, dim=0, index=indices) + + logger.info(f"Batch sorting took: {datetime.now() - start}") + return final_inputs, final_labels, final_masks diff --git a/training/batching_scheduler.py b/training/batching_scheduler.py new file mode 100644 index 0000000..9324dad --- /dev/null +++ b/training/batching_scheduler.py @@ -0,0 +1,62 @@ +from collections import deque +import numpy as np + + +class BatchingScheduler(): + """ This class acts as scheduler for the batching algorithm + """ + + def __init__(self, decay_factor=0.8, min_threshold=0.2, triplets_threshold=10, cooldown=10) -> None: + # internal vars + self._step_count = 0 + self._dist_threshold = 0.5 + self._last_used_triplets = deque([], 5) + self._scaling_same_label_factor = 1 + self._last_update_step = -10 + + # Parameters + self.decay_factor = decay_factor + self.min_threshold = min_threshold + self.triplets_threshold = triplets_threshold + self.cooldown = cooldown + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. 
+ """ + return {key: value for key, value in self.__dict__.items()} + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + self.__dict__.update(state_dict) + + def step(self, used_triplets): + self._step_count += 1 + self._last_used_triplets.append(used_triplets) + if (np.mean(self._last_used_triplets) < self.triplets_threshold and + self._last_update_step + self.cooldown <= self._step_count): + if self._dist_threshold > self.min_threshold: + print(f"Updating dist_threshold at {self._step_count} ({np.mean(self._last_used_triplets)})") + self.update_dist_threshold() + if self._scaling_same_label_factor > 0.6: + print(f"Updating scale factor at {self._step_count} ({np.mean(self._last_used_triplets)})") + self.update_scale_factor() + self._last_update_step = self._step_count + + def update_scale_factor(self): + self._scaling_same_label_factor = max(self._scaling_same_label_factor * 0.9, 0.6) + print(f"Updating scaling factor to {self._scaling_same_label_factor}") + + def update_dist_threshold(self): + self._dist_threshold = max(self.min_threshold, self._dist_threshold * self.decay_factor) + print(f"Updated dist_threshold to {self._dist_threshold}") + + def get_dist_threshold(self) -> float: + return self._dist_threshold + + def get_scaling_same_label_factor(self) -> float: + return self._scaling_same_label_factor diff --git a/training/gaussian_noise.py b/training/gaussian_noise.py new file mode 100644 index 0000000..7ca8889 --- /dev/null +++ b/training/gaussian_noise.py @@ -0,0 +1,18 @@ + +import torch + + +class GaussianNoise(object): + def __init__(self, mean=0., std=1.): + self.std = std + self.mean = mean + + def __call__(self, tensor): + return tensor + torch.randn(tensor.size()) * self.std + self.mean + + def __repr__(self): + return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std) + + +if __name__ == "__main__": + pass diff --git a/training/online_batch_mining.py b/training/online_batch_mining.py new file mode 100644 index 0000000..5d0cf7a --- /dev/null +++ b/training/online_batch_mining.py @@ -0,0 +1,105 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +eps = 1e-8 # an arbitrary small value to be used for numerical stability tricks + +# Adapted from https://qdrant.tech/articles/triplet-loss/ + + +class BatchAllTripletLoss(nn.Module): + """Uses all valid triplets to compute Triplet loss + Args: + margin: Margin value in the Triplet Loss equation + """ + + def __init__(self, device, margin=1., filter_easy_triplets=True): + super().__init__() + self.margin = margin + self.device = device + self.filter_easy_triplets = filter_easy_triplets + + def get_triplet_mask(self, labels): + """compute a mask for valid triplets + Args: + labels: Batch of integer labels. shape: (batch_size,) + Returns: + Mask tensor to indicate which triplets are actually valid. Shape: (batch_size, batch_size, batch_size) + A triplet is valid if: + `labels[i] == labels[j] and labels[i] != labels[k]` + and `i`, `j`, `k` are different. 
+ """ + # step 1 - get a mask for distinct indices + + # shape: (batch_size, batch_size) + indices_equal = torch.eye(labels.size()[0], dtype=torch.bool, device=labels.device) + indices_not_equal = torch.logical_not(indices_equal) + # shape: (batch_size, batch_size, 1) + i_not_equal_j = indices_not_equal.unsqueeze(2) + # shape: (batch_size, 1, batch_size) + i_not_equal_k = indices_not_equal.unsqueeze(1) + # shape: (1, batch_size, batch_size) + j_not_equal_k = indices_not_equal.unsqueeze(0) + # Shape: (batch_size, batch_size, batch_size) + distinct_indices = torch.logical_and(torch.logical_and(i_not_equal_j, i_not_equal_k), j_not_equal_k) + + # step 2 - get a mask for valid anchor-positive-negative triplets + + # shape: (batch_size, batch_size) + labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1) + # shape: (batch_size, batch_size, 1) + i_equal_j = labels_equal.unsqueeze(2) + # shape: (batch_size, 1, batch_size) + i_equal_k = labels_equal.unsqueeze(1) + # shape: (batch_size, batch_size, batch_size) + valid_indices = torch.logical_and(i_equal_j, torch.logical_not(i_equal_k)) + + # step 3 - combine two masks + mask = torch.logical_and(distinct_indices, valid_indices) + + return mask + + def forward(self, embeddings, labels, filter_easy_triplets=True): + """computes loss value. + Args: + embeddings: Batch of embeddings, e.g., output of the encoder. shape: (batch_size, embedding_dim) + labels: Batch of integer labels associated with embeddings. shape: (batch_size,) + Returns: + Scalar loss value. + """ + # step 1 - get distance matrix + # shape: (batch_size, batch_size) + distance_matrix = torch.cdist(embeddings, embeddings, p=2) + + # step 2 - compute loss values for all triplets by applying broadcasting to distance matrix + + # shape: (batch_size, batch_size, 1) + anchor_positive_dists = distance_matrix.unsqueeze(2) + # shape: (batch_size, 1, batch_size) + anchor_negative_dists = distance_matrix.unsqueeze(1) + # get loss values for all possible n^3 triplets + # shape: (batch_size, batch_size, batch_size) + triplet_loss = anchor_positive_dists - anchor_negative_dists + self.margin + + # step 3 - filter out invalid or easy triplets by setting their loss values to 0 + + # shape: (batch_size, batch_size, batch_size) + mask = self.get_triplet_mask(labels) + valid_triplets = mask.sum() + triplet_loss *= mask.to(self.device) + # easy triplets have negative loss values + triplet_loss = F.relu(triplet_loss) + + if self.filter_easy_triplets: + # step 4 - compute scalar loss value by averaging positive losses + num_positive_losses = (triplet_loss > eps).float().sum() + # We want to factor in how many triplets were used compared to batch_size (used_triplets * 3 / batch_size) + # The effect of this should be similar to LR decay but penalizing batches with fewer hard triplets + percent_used_factor = min(1.0, num_positive_losses * 3 / labels.size()[0]) + + triplet_loss = triplet_loss.sum() / (num_positive_losses + eps) * percent_used_factor + return triplet_loss, valid_triplets, int(num_positive_losses) + else: + triplet_loss = triplet_loss.sum() / (valid_triplets + eps) + return triplet_loss, valid_triplets, valid_triplets diff --git a/training/train_arguments.py b/training/train_arguments.py new file mode 100644 index 0000000..5980aff --- /dev/null +++ b/training/train_arguments.py @@ -0,0 +1,84 @@ +import argparse + + +def get_default_args(): + parser = argparse.ArgumentParser(add_help=False) + + parser.add_argument("--experiment_name", type=str, default="lsa_64_spoter", + help="Name of the 
experiment after which the logs and plots will be named")
+    parser.add_argument("--num_classes", type=int, default=100, help="Number of classes to be recognized by the model")
+    parser.add_argument("--hidden_dim", type=int, default=108,
+                        help="Hidden dimension of the underlying Transformer model")
+    parser.add_argument("--seed", type=int, default=379,
+                        help="Seed with which to initialize all the random components of the training")
+
+    # Embeddings
+    parser.add_argument("--classification_model", action='store_true', default=False,
+                        help="Train the original SPOTER classification model instead of the embedding model")
+    parser.add_argument("--vector_length", type=int, default=32,
+                        help="Number of features used in the embedding vector")
+    parser.add_argument("--epoch_iters", type=int, default=-1,
+                        help="Iterations per epoch while training embeddings. Will loop through the dataset once if -1")
+    parser.add_argument("--batch_size", type=int, default=32, help="Batch size used during training and validation")
+    parser.add_argument("--hard_triplet_mining", type=str, default="None",
+                        help="Strategy to select hard triplets, options: [None, in_batch]")
+    parser.add_argument("--triplet_loss_margin", type=float, default=1,
+                        help="Margin value used in the triplet loss")
+    parser.add_argument("--normalize_embeddings", action='store_true', default=False,
+                        help="Normalize the model output so the embedding vector has unit length")
+    parser.add_argument("--filter_easy_triplets", action='store_true', default=False,
+                        help="Filter out easy triplets when using online in-batch triplet mining")
+
+    # Data
+    parser.add_argument("--dataset_name", type=str, default="", help="Dataset name")
+    parser.add_argument("--dataset_project", type=str, default="Sign Language Recognition", help="Dataset project name")
+    parser.add_argument("--training_set_path", type=str, default="",
+                        help="Path to the training dataset CSV file (relative to the dataset root)")
+    parser.add_argument("--validation_set_path", type=str, default="", help="Path to the validation dataset CSV file")
+    parser.add_argument("--dataset_loader", type=str, default="local",
+                        help="Dataset loader to use, options: [clearml, local]")
+
+    # Training hyperparameters
+    parser.add_argument("--epochs", type=int, default=1300, help="Number of epochs to train the model for")
+    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate for the model training")
+    parser.add_argument("--dropout", type=float, default=0.1,
+                        help="Dropout rate used in the Transformer layers")
+    parser.add_argument("--augmentations_prob", type=float, default=0.5,
+                        help="Probability of applying data augmentation to a sample")
+
+    # Checkpointing
+    parser.add_argument("--save_checkpoints_every", type=int, default=-1,
+                        help="Determines every how many epochs the weight checkpoints are saved. 
If -1 only best model \ + after final epoch") + + # Optimizer + parser.add_argument("--optimizer", type=str, default="SGD", + help="Optimizer used during training, options: [SGD, ADAM]") + + # Tracker + parser.add_argument("--tracker", type=str, default="none", + help="Experiment tracker to use, options: [clearml, none]") + + # Scheduler + parser.add_argument("--scheduler_factor", type=float, default=0, + help="Factor for the ReduceLROnPlateau scheduler") + parser.add_argument("--scheduler_patience", type=int, default=10, + help="Patience for the ReduceLROnPlateau scheduler") + parser.add_argument("--scheduler_warmup", type=int, default=400, + help="Warmup epochs before scheduler starts") + + # Gaussian noise normalization + parser.add_argument("--gaussian_mean", type=float, default=0, help="Mean parameter for Gaussian noise layer") + parser.add_argument("--gaussian_std", type=float, default=0.001, + help="Standard deviation parameter for Gaussian noise layer") + + # Batch Sorting + parser.add_argument("--start_mining_hard", type=int, default=None, help="On which epoch to start hard mining") + parser.add_argument("--hard_mining_pre_batch_multipler", type=int, default=16, + help="How many batches should be computed at once") + parser.add_argument("--hard_mining_pre_batch_mining_count", type=int, default=5, + help="How many times to loop through a list of computed batches") + parser.add_argument("--hard_mining_scheduler_triplets_threshold", type=float, default=0, + help="Enables batching grouping scheduler if > 0. Defines threshold for when to decay the \ + distance threshold of the batch sorter") + + return parser diff --git a/training/train_utils.py b/training/train_utils.py new file mode 100644 index 0000000..8069abd --- /dev/null +++ b/training/train_utils.py @@ -0,0 +1,71 @@ +import os +import random +import numpy as np +import pandas as pd +import plotly.express as px +import torch + +from models import embeddings_scatter_plot, embeddings_scatter_plot_splits + + +def train_setup(seed, experiment_name): + random.seed(seed) + np.random.seed(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + g = torch.Generator() + g.manual_seed(seed) + return g + + +def create_embedding_scatter_plots(tracker, model, train_loader, val_loader, device, id_to_label, epoch, model_name): + tsne_results, labels = embeddings_scatter_plot(model, train_loader, device, id_to_label, perplexity=40, n_iter=1000) + + df = pd.DataFrame({'x': tsne_results[:, 0], + 'y': tsne_results[:, 1], + 'label': labels}) + fig = px.scatter(df, y="y", x="x", color="label") + + tracker.log_chart( + title="Training Scatter Plot with Best Model: " + model_name, + series="Scatter Plot", + iteration=epoch, + figure=fig + ) + + tsne_results, labels = embeddings_scatter_plot(model, val_loader, device, id_to_label, perplexity=40, n_iter=1000) + + df = pd.DataFrame({'x': tsne_results[:, 0], + 'y': tsne_results[:, 1], + 'label': labels}) + fig = px.scatter(df, y="y", x="x", color="label") + + tracker.log_chart( + title="Validation Scatter Plot with Best Model: " + model_name, + series="Scatter Plot", + iteration=epoch, + figure=fig, + ) + + dataloaders = {'train': train_loader, + 'val': val_loader} + splits = list(dataloaders.keys()) + tsne_results_splits, labels_splits = embeddings_scatter_plot_splits(model, dataloaders, + device, id_to_label, perplexity=40, n_iter=1000) + tsne_results = 
np.vstack([tsne_results_splits[split] for split in splits]) + labels = np.concatenate([labels_splits[split] for split in splits]) + split = np.concatenate([[split]*len(labels_splits[split]) for split in splits]) + df = pd.DataFrame({'x': tsne_results[:, 0], + 'y': tsne_results[:, 1], + 'label': labels, + 'split': split}) + fig = px.scatter(df, y="y", x="x", color="label", symbol='split') + tracker.log_chart( + title="Scatter Plot of train and val with Best Model: " + model_name, + series="Scatter Plot", + iteration=epoch, + figure=fig, + ) diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..0949518 --- /dev/null +++ b/utils.py @@ -0,0 +1,40 @@ +import logging +import os + + +class CustomFormatter(logging.Formatter): + + grey = "\x1b[38;20m" + yellow = "\x1b[33;20m" + red = "\x1b[31;20m" + bold_red = "\x1b[31;1m" + reset = "\x1b[0m" + custom_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s (%(filename)s:%(lineno)d)" + + FORMATS = { + logging.DEBUG: grey + custom_format + reset, + logging.INFO: grey + custom_format + reset, + logging.WARNING: yellow + custom_format + reset, + logging.ERROR: red + custom_format + reset, + logging.CRITICAL: bold_red + custom_format + reset + } + + def format(self, record): + log_fmt = self.FORMATS.get(record.levelno) + formatter = logging.Formatter(log_fmt) + return formatter.format(record) + + +def get_logger(name): + logger = logging.getLogger(name) + logger.setLevel(logging.INFO) + # create console handler with a higher log level + ch = logging.StreamHandler() + ch.setLevel(logging.DEBUG) + ch.setFormatter(CustomFormatter()) + file_handler = logging.FileHandler(os.getenv('EXPERIMENT_NAME', 'run') + ".log") + file_handler.setLevel(logging.DEBUG) + file_handler.setFormatter(CustomFormatter()) + logger.addHandler(ch) + logger.addHandler(file_handler) + return logger diff --git a/web/README.md b/web/README.md new file mode 100644 index 0000000..0490fe0 --- /dev/null +++ b/web/README.md @@ -0,0 +1,8 @@ +# SPOTER Web + +To test Spoter model in the web, follow these steps: +* Convert your latest Pytorch model to Onnx by running `python convert.py`. This is best done inside the Docker container. You will need to install additional dependencies for the conversions (see commented lines in requirements.txt) +* The ONNX should be generated in the `web` folder, otherwise copy it there. +* run `npx light-server -s . -p 8080` in the `web` folder. (`npx` comes with `npm`) + +Enjoy! diff --git a/web/index.html b/web/index.html new file mode 100644 index 0000000..4bd6a86 --- /dev/null +++ b/web/index.html @@ -0,0 +1,61 @@ + + +
+<!-- ONNX Runtime JavaScript examples: Quick Start - Web (using script tag) -->
+<!-- [demo page markup omitted] -->
\ No newline at end of file