Compare commits
7 Commits
WES-129-ba
...
dev
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
40c16548b2 | ||
|
|
17251edfda | ||
|
|
bfef06d720 | ||
|
|
7cf35d7357 | ||
|
|
65d478ef1b | ||
|
|
cd9cc8ce8b | ||
|
|
0af9320571 |
BIN
models/model_A-Z_v2.onnx
Normal file
BIN
models/model_A-Z_v2.onnx
Normal file
Binary file not shown.
Binary file not shown.
@@ -37,10 +37,27 @@ def circle_intersection(x0, y0, r0, x1, y1, r1):
|
|||||||
|
|
||||||
class MirrorKeypoints:
|
class MirrorKeypoints:
|
||||||
def __call__(self, sample):
|
def __call__(self, sample):
|
||||||
|
if sample.shape[0] == 0:
|
||||||
|
return sample
|
||||||
if random.random() > 0.5:
|
if random.random() > 0.5:
|
||||||
return sample
|
return sample
|
||||||
# flip the keypoints tensor
|
|
||||||
sample = 1 - sample
|
# flip the x coordinates
|
||||||
|
sample[:, :, 0] *= -1
|
||||||
|
|
||||||
|
# switch hands (left becomes right and vice versa)
|
||||||
|
left, right, n = 12, 33, 21
|
||||||
|
if isinstance(sample, np.ndarray): # For testing purposes only
|
||||||
|
sample[:, left:left+n, :], sample[:, right:right+n, :] = sample[: , right:right+n, :], sample[:, left:left+n, :].copy()
|
||||||
|
else:
|
||||||
|
sample[:, left:left+n, :], sample[:, right:right+n, :] = sample[: , right:right+n, :], sample[:, left:left+n, :].clone()
|
||||||
|
|
||||||
|
# switch pose keypoints
|
||||||
|
sample[:, [1, 2], :] = sample[:, [2, 1], :] #eye
|
||||||
|
sample[:, [3, 4], :] = sample[:, [4, 3], :] #ear
|
||||||
|
sample[:, [6, 7], :] = sample[:, [7, 6], :] #shoulder
|
||||||
|
sample[:, [8, 9], :] = sample[:, [9, 8], :] #elbow
|
||||||
|
sample[:, [10, 11], :] = sample[:, [11, 10], :] #wrist
|
||||||
|
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
@@ -125,3 +142,15 @@ class NoiseAugmentation:
|
|||||||
# add noise to the keypoints
|
# add noise to the keypoints
|
||||||
sample = sample + torch.randn(sample.shape) * self.noise
|
sample = sample + torch.randn(sample.shape) * self.noise
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
|
# augmentation to rotate all keypoints around 0,0
|
||||||
|
class RotateAugmentation:
|
||||||
|
def __call__(self, sample):
|
||||||
|
# generate a random angle between -13 and 13 degrees
|
||||||
|
angle_max = 13.0
|
||||||
|
angle = math.radians(random.uniform(a=-angle_max, b=angle_max))
|
||||||
|
# rotate the keypoints around 0.0
|
||||||
|
new_sample = sample
|
||||||
|
new_sample[:, :, 0] = sample[:, :, 0]*math.cos(angle) - sample[:, :, 1]*math.sin(angle)
|
||||||
|
new_sample[:, :, 1] = sample[:, :, 0]*math.sin(angle) + sample[:, :, 1]*math.cos(angle)
|
||||||
|
return new_sample
|
||||||
@@ -1,95 +0,0 @@
|
|||||||
import os
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from sklearn.model_selection import train_test_split
|
|
||||||
|
|
||||||
from src.identifiers import LANDMARKS
|
|
||||||
from src.keypoint_extractor import KeypointExtractor
|
|
||||||
|
|
||||||
|
|
||||||
class BasicsDataset(torch.utils.data.Dataset):
|
|
||||||
def __init__(self, data_folder: str, bad_data_folder: str = "", subset:str="train", keypoints_identifier: dict = None, transform=None):
|
|
||||||
|
|
||||||
|
|
||||||
# list files with path in the datafolder ending with .mp4
|
|
||||||
files = [data_folder + f for f in os.listdir(data_folder) if f.endswith(".mp4")]
|
|
||||||
|
|
||||||
# append files from bad data folder
|
|
||||||
if bad_data_folder != "":
|
|
||||||
files += [bad_data_folder + f for f in os.listdir(bad_data_folder) if f.endswith(".mp4")]
|
|
||||||
|
|
||||||
labels = [f.split("/")[-1].split("!")[0] for f in files]
|
|
||||||
train_test = [f.split("/")[-1].split("!")[1] for f in files]
|
|
||||||
|
|
||||||
# count the number of each label
|
|
||||||
self.label_mapping, counts = np.unique(labels, return_counts=True)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# map the labels to their integer
|
|
||||||
labels = [np.where(self.label_mapping == label)[0][0] for label in labels]
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: make split for train and val and test when enough data is available
|
|
||||||
if subset == "train":
|
|
||||||
# mask for train data
|
|
||||||
mask = np.array(train_test) == "train"
|
|
||||||
elif subset == "test":
|
|
||||||
mask = np.array(train_test) == "test"
|
|
||||||
|
|
||||||
# filter data and labels
|
|
||||||
self.data = np.array(files)[mask]
|
|
||||||
self.labels = np.array(labels)[mask]
|
|
||||||
|
|
||||||
# filter data by subset
|
|
||||||
self.transform = transform
|
|
||||||
self.subset = subset
|
|
||||||
self.keypoint_extractor = KeypointExtractor()
|
|
||||||
if keypoints_identifier:
|
|
||||||
self.keypoints_to_keep = [f"{i}_{j}" for i in keypoints_identifier.values() for j in ["x", "y"]]
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.data)
|
|
||||||
|
|
||||||
def __getitem__(self, index):
|
|
||||||
# get i th element from ordered dict
|
|
||||||
video_name = self.data[index]
|
|
||||||
|
|
||||||
cache_name = video_name.split("/")[-1].split(".")[0] + ".npy"
|
|
||||||
|
|
||||||
# check if cache_name file exists
|
|
||||||
if not os.path.isfile(os.path.join("cache_processed", cache_name)):
|
|
||||||
|
|
||||||
|
|
||||||
# get the keypoints for the video (normalizations: minxmax, bohacek)
|
|
||||||
keypoints_df = self.keypoint_extractor.extract_keypoints_from_video(video_name, normalize="bohacek")
|
|
||||||
|
|
||||||
# filter the keypoints by the identified subset
|
|
||||||
if self.keypoints_to_keep:
|
|
||||||
keypoints_df = keypoints_df[self.keypoints_to_keep]
|
|
||||||
|
|
||||||
current_row = np.empty(shape=(keypoints_df.shape[0], keypoints_df.shape[1] // 2, 2))
|
|
||||||
for i in range(0, keypoints_df.shape[1], 2):
|
|
||||||
current_row[:, i // 2, 0] = keypoints_df.iloc[:, i]
|
|
||||||
current_row[:, i // 2, 1] = keypoints_df.iloc[:, i + 1]
|
|
||||||
|
|
||||||
# check if cache_processed folder exists
|
|
||||||
if not os.path.isdir("cache_processed"):
|
|
||||||
os.mkdir("cache_processed")
|
|
||||||
|
|
||||||
# save the processed data to a file
|
|
||||||
np.save(os.path.join("cache_processed", cache_name), current_row)
|
|
||||||
|
|
||||||
else:
|
|
||||||
current_row = np.load(os.path.join("cache_processed", cache_name))
|
|
||||||
|
|
||||||
# get the label
|
|
||||||
label = self.labels[index]
|
|
||||||
# data to tensor
|
|
||||||
data = torch.from_numpy(current_row)
|
|
||||||
|
|
||||||
if self.transform:
|
|
||||||
data = self.transform(data)
|
|
||||||
|
|
||||||
return data, label
|
|
||||||
@@ -7,7 +7,7 @@ from src.model import SPOTER
|
|||||||
from src.identifiers import LANDMARKS
|
from src.identifiers import LANDMARKS
|
||||||
|
|
||||||
# set parameters of the model
|
# set parameters of the model
|
||||||
model_name = 'model_A-Z'
|
model_name = 'model_A-Z_v2'
|
||||||
num_classes = 26
|
num_classes = 26
|
||||||
|
|
||||||
# load PyTorch model from .pth file
|
# load PyTorch model from .pth file
|
||||||
|
|||||||
27
src/model.py
27
src/model.py
@@ -1,7 +1,6 @@
|
|||||||
### SPOTER model implementation from the paper "SPOTER: Sign Pose-based Transformer for Sign Language Recognition from Sequence of Skeletal Data"
|
### SPOTER model implementation from the paper "SPOTER: Sign Pose-based Transformer for Sign Language Recognition from Sequence of Skeletal Data"
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
import math
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -39,19 +38,6 @@ class SPOTERTransformerDecoderLayer(nn.TransformerDecoderLayer):
|
|||||||
|
|
||||||
return tgt
|
return tgt
|
||||||
|
|
||||||
class PositionalEmbedding(nn.Module):
|
|
||||||
def __init__(self, d_model, max_len=60):
|
|
||||||
super().__init__()
|
|
||||||
pe = torch.zeros(max_len, d_model)
|
|
||||||
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
|
||||||
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
|
|
||||||
pe[:, 0::2] = torch.sin(position * div_term)
|
|
||||||
pe[:, 1::2] = torch.cos(position * div_term)
|
|
||||||
pe = pe.unsqueeze(0).transpose(0, 1)
|
|
||||||
self.register_buffer('pe', pe)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return x + self.pe[:x.size(0), :]
|
|
||||||
|
|
||||||
class SPOTER(nn.Module):
|
class SPOTER(nn.Module):
|
||||||
"""
|
"""
|
||||||
@@ -62,9 +48,8 @@ class SPOTER(nn.Module):
|
|||||||
def __init__(self, num_classes, hidden_dim=55):
|
def __init__(self, num_classes, hidden_dim=55):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
self.row_embed = nn.Parameter(torch.rand(50, hidden_dim))
|
||||||
self.pos = PositionalEmbedding(hidden_dim)
|
self.pos = nn.Parameter(torch.cat([self.row_embed[0].unsqueeze(0).repeat(1, 1, 1)], dim=-1).flatten(0, 1).unsqueeze(0))
|
||||||
|
|
||||||
self.class_query = nn.Parameter(torch.rand(1, hidden_dim))
|
self.class_query = nn.Parameter(torch.rand(1, hidden_dim))
|
||||||
self.transformer = nn.Transformer(hidden_dim, 9, 6, 6)
|
self.transformer = nn.Transformer(hidden_dim, 9, 6, 6)
|
||||||
self.linear_class = nn.Linear(hidden_dim, num_classes)
|
self.linear_class = nn.Linear(hidden_dim, num_classes)
|
||||||
@@ -76,13 +61,7 @@ class SPOTER(nn.Module):
|
|||||||
|
|
||||||
def forward(self, inputs):
|
def forward(self, inputs):
|
||||||
h = torch.unsqueeze(inputs.flatten(start_dim=1), 1).float()
|
h = torch.unsqueeze(inputs.flatten(start_dim=1), 1).float()
|
||||||
# add positional encoding
|
h = self.transformer(self.pos + h, self.class_query.unsqueeze(0)).transpose(0, 1)
|
||||||
h = self.pos(h)
|
|
||||||
|
|
||||||
# add class query
|
|
||||||
h = self.transformer(h, self.class_query.unsqueeze(0)).transpose(0, 1)
|
|
||||||
|
|
||||||
# get class prediction
|
|
||||||
res = self.linear_class(h)
|
res = self.linear_class(h)
|
||||||
|
|
||||||
return res
|
return res
|
||||||
18
src/train.py
18
src/train.py
@@ -8,7 +8,7 @@ import torch.optim as optim
|
|||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
|
|
||||||
from src.augmentations import MirrorKeypoints, Z_augmentation, NoiseAugmentation
|
from src.augmentations import MirrorKeypoints, Z_augmentation, NoiseAugmentation, RotateAugmentation
|
||||||
from src.datasets.finger_spelling_dataset import FingerSpellingDataset
|
from src.datasets.finger_spelling_dataset import FingerSpellingDataset
|
||||||
from src.identifiers import LANDMARKS
|
from src.identifiers import LANDMARKS
|
||||||
from src.model import SPOTER
|
from src.model import SPOTER
|
||||||
@@ -29,13 +29,17 @@ def train():
|
|||||||
g = torch.Generator()
|
g = torch.Generator()
|
||||||
g.manual_seed(379)
|
g.manual_seed(379)
|
||||||
|
|
||||||
device = torch.device("cuda:0")
|
|
||||||
|
|
||||||
spoter_model = SPOTER(num_classes=26, hidden_dim=len(LANDMARKS) *2)
|
spoter_model = SPOTER(num_classes=26, hidden_dim=len(LANDMARKS) *2)
|
||||||
|
|
||||||
|
# use cuda if available
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda:0")
|
||||||
|
else:
|
||||||
|
device = torch.device("cpu")
|
||||||
|
|
||||||
spoter_model.train(True)
|
spoter_model.train(True)
|
||||||
spoter_model.to(device)
|
spoter_model.to(device)
|
||||||
|
|
||||||
|
|
||||||
criterion = nn.CrossEntropyLoss()
|
criterion = nn.CrossEntropyLoss()
|
||||||
criterion_bad = CustomLoss()
|
criterion_bad = CustomLoss()
|
||||||
optimizer = optim.Adam(spoter_model.parameters(), lr=0.00001)
|
optimizer = optim.Adam(spoter_model.parameters(), lr=0.00001)
|
||||||
@@ -45,7 +49,7 @@ def train():
|
|||||||
if not os.path.exists("checkpoints"):
|
if not os.path.exists("checkpoints"):
|
||||||
os.makedirs("checkpoints")
|
os.makedirs("checkpoints")
|
||||||
|
|
||||||
transform = transforms.Compose([MirrorKeypoints(), NoiseAugmentation(noise=0.1)])
|
transform = transforms.Compose([MirrorKeypoints(), NoiseAugmentation(noise=0.1), RotateAugmentation()])
|
||||||
|
|
||||||
train_set = FingerSpellingDataset("data/fingerspelling/data/", bad_data_folder="", keypoints_identifier=LANDMARKS, subset="train", transform=transform)
|
train_set = FingerSpellingDataset("data/fingerspelling/data/", bad_data_folder="", keypoints_identifier=LANDMARKS, subset="train", transform=transform)
|
||||||
train_loader = DataLoader(train_set, shuffle=True, generator=g)
|
train_loader = DataLoader(train_set, shuffle=True, generator=g)
|
||||||
@@ -124,9 +128,9 @@ def train():
|
|||||||
if val_acc > best_val_acc:
|
if val_acc > best_val_acc:
|
||||||
best_val_acc = val_acc
|
best_val_acc = val_acc
|
||||||
epochs_without_improvement = 0
|
epochs_without_improvement = 0
|
||||||
if epoch > 55:
|
if epoch > 45:
|
||||||
top_val_acc = val_acc
|
top_val_acc = val_acc
|
||||||
top_train_acc = train_acc
|
top_train_acc = pred_correct / pred_all
|
||||||
checkpoint_index = epoch
|
checkpoint_index = epoch
|
||||||
torch.save(spoter_model.state_dict(), f"checkpoints/spoter_{epoch}.pth")
|
torch.save(spoter_model.state_dict(), f"checkpoints/spoter_{epoch}.pth")
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -1,152 +0,0 @@
|
|||||||
import os
|
|
||||||
import random
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.optim as optim
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
from torchvision import transforms
|
|
||||||
|
|
||||||
from src.augmentations import MirrorKeypoints, Z_augmentation, NoiseAugmentation
|
|
||||||
from src.datasets.basics_dataset import BasicsDataset
|
|
||||||
from src.identifiers import LANDMARKS
|
|
||||||
from src.model import SPOTER
|
|
||||||
from src.loss_function import CustomLoss
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
|
||||||
|
|
||||||
def train():
|
|
||||||
writer = SummaryWriter()
|
|
||||||
random.seed(379)
|
|
||||||
np.random.seed(379)
|
|
||||||
os.environ['PYTHONHASHSEED'] = str(379)
|
|
||||||
torch.manual_seed(379)
|
|
||||||
torch.cuda.manual_seed(379)
|
|
||||||
torch.cuda.manual_seed_all(379)
|
|
||||||
torch.backends.cudnn.deterministic = True
|
|
||||||
g = torch.Generator()
|
|
||||||
g.manual_seed(379)
|
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
device = torch.device("cuda:0")
|
|
||||||
else:
|
|
||||||
device = torch.device("cpu")
|
|
||||||
|
|
||||||
spoter_model = SPOTER(num_classes=15, hidden_dim=len(LANDMARKS) *2)
|
|
||||||
spoter_model.train(True)
|
|
||||||
spoter_model.to(device)
|
|
||||||
|
|
||||||
|
|
||||||
criterion = nn.CrossEntropyLoss()
|
|
||||||
criterion_bad = CustomLoss()
|
|
||||||
optimizer = optim.Adam(spoter_model.parameters(), lr=0.00001)
|
|
||||||
scheduler = None
|
|
||||||
|
|
||||||
# check if checkpoints folder exists
|
|
||||||
if not os.path.exists("checkpoints"):
|
|
||||||
os.makedirs("checkpoints")
|
|
||||||
|
|
||||||
transform = transforms.Compose([NoiseAugmentation(noise=0.1)])
|
|
||||||
|
|
||||||
train_set = BasicsDataset("data/basics/data/", bad_data_folder="", keypoints_identifier=LANDMARKS, subset="train", transform=transform)
|
|
||||||
train_loader = DataLoader(train_set, shuffle=True, generator=g)
|
|
||||||
|
|
||||||
val_set = BasicsDataset("data/basics/data/", bad_data_folder="", keypoints_identifier=LANDMARKS, subset="test")
|
|
||||||
val_loader = DataLoader(val_set, shuffle=True, generator=g)
|
|
||||||
|
|
||||||
|
|
||||||
train_acc, val_acc = 0, 0
|
|
||||||
lr_progress = []
|
|
||||||
top_train_acc, top_val_acc = 0, 0
|
|
||||||
checkpoint_index = 0
|
|
||||||
|
|
||||||
epochs_without_improvement = 0
|
|
||||||
best_val_acc = 0
|
|
||||||
|
|
||||||
for epoch in range(300):
|
|
||||||
|
|
||||||
running_loss = 0.0
|
|
||||||
pred_correct, pred_all = 0, 0
|
|
||||||
|
|
||||||
# train
|
|
||||||
for i, (inputs, labels) in enumerate(train_loader):
|
|
||||||
# skip videos that are too short
|
|
||||||
if inputs.shape[1] < 20:
|
|
||||||
continue
|
|
||||||
|
|
||||||
inputs = inputs.squeeze(0).to(device)
|
|
||||||
labels = labels.to(device, dtype=torch.long)
|
|
||||||
|
|
||||||
optimizer.zero_grad()
|
|
||||||
outputs = spoter_model(inputs).expand(1, -1, -1)
|
|
||||||
loss = criterion(outputs[0], labels)
|
|
||||||
|
|
||||||
loss.backward()
|
|
||||||
optimizer.step()
|
|
||||||
running_loss += loss
|
|
||||||
|
|
||||||
if int(torch.argmax(torch.nn.functional.softmax(outputs, dim=2))) == int(labels[0]):
|
|
||||||
pred_correct += 1
|
|
||||||
pred_all += 1
|
|
||||||
|
|
||||||
|
|
||||||
if scheduler:
|
|
||||||
scheduler.step(running_loss.item() / (len(train_loader)) )
|
|
||||||
|
|
||||||
writer.add_scalar("Loss/train", loss, epoch)
|
|
||||||
writer.add_scalar("Accuracy/train", (pred_correct / pred_all), epoch)
|
|
||||||
|
|
||||||
# validate and print val acc
|
|
||||||
val_pred_correct, val_pred_all = 0, 0
|
|
||||||
val_loss = 0.0
|
|
||||||
with torch.no_grad():
|
|
||||||
for i, (inputs, labels) in enumerate(val_loader):
|
|
||||||
inputs = inputs.squeeze(0).to(device)
|
|
||||||
labels = labels.to(device, dtype=torch.long)
|
|
||||||
|
|
||||||
outputs = spoter_model(inputs).expand(1, -1, -1)
|
|
||||||
|
|
||||||
# calculate loss
|
|
||||||
val_loss += criterion(outputs[0], labels)
|
|
||||||
|
|
||||||
if int(torch.argmax(torch.nn.functional.softmax(outputs, dim=2))) == int(labels[0]):
|
|
||||||
val_pred_correct += 1
|
|
||||||
val_pred_all += 1
|
|
||||||
|
|
||||||
val_acc = (val_pred_correct / val_pred_all)
|
|
||||||
|
|
||||||
writer.add_scalar("Loss/val", val_loss, epoch)
|
|
||||||
writer.add_scalar("Accuracy/val", val_acc, epoch)
|
|
||||||
|
|
||||||
|
|
||||||
print(f"Epoch: {epoch} | Train Acc: {(pred_correct / pred_all)} | Val Acc: {val_acc}")
|
|
||||||
|
|
||||||
# save checkpoint and update epochs_without_improvement
|
|
||||||
if val_acc > best_val_acc:
|
|
||||||
best_val_acc = val_acc
|
|
||||||
epochs_without_improvement = 0
|
|
||||||
if epoch > 20:
|
|
||||||
top_val_acc = val_acc
|
|
||||||
top_train_acc = train_acc
|
|
||||||
checkpoint_index = epoch
|
|
||||||
torch.save(spoter_model.state_dict(), f"checkpoints/spoter_{epoch}.pth")
|
|
||||||
else:
|
|
||||||
epochs_without_improvement += 1
|
|
||||||
|
|
||||||
# early stopping
|
|
||||||
if epochs_without_improvement >= 40:
|
|
||||||
print("Early stopping due to no improvement in validation accuracy for 40 epochs.")
|
|
||||||
break
|
|
||||||
|
|
||||||
lr_progress.append(optimizer.param_groups[0]['lr'])
|
|
||||||
|
|
||||||
print(f"Best val acc: {top_val_acc} | Best train acc: {top_train_acc} | Epoch: {checkpoint_index}")
|
|
||||||
writer.flush()
|
|
||||||
writer.close()
|
|
||||||
|
|
||||||
|
|
||||||
# Path: src/train.py
|
|
||||||
if __name__ == "__main__":
|
|
||||||
train()
|
|
||||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -26,8 +26,8 @@ frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
|||||||
|
|
||||||
keypoints = []
|
keypoints = []
|
||||||
|
|
||||||
spoter_model = SPOTER(num_classes=19, hidden_dim=len(LANDMARKS) * 2)
|
spoter_model = SPOTER(num_classes=26, hidden_dim=len(LANDMARKS) * 2)
|
||||||
spoter_model.load_state_dict(torch.load('checkpoints/spoter_80.pth', map_location=torch.device('cpu')))
|
spoter_model.load_state_dict(torch.load('models/model_A-Z_v2.pth', map_location=torch.device('cpu')))
|
||||||
|
|
||||||
# get values of the landmarks as a list of integers
|
# get values of the landmarks as a list of integers
|
||||||
values = []
|
values = []
|
||||||
|
|||||||
Reference in New Issue
Block a user