Added ability to finetune models
This commit is contained in:
BIN
checkpoints/checkpoint_embed_3006.pth
Normal file
BIN
checkpoints/checkpoint_embed_3006.pth
Normal file
Binary file not shown.
@@ -30,12 +30,15 @@ def seed_worker(worker_id):
|
||||
|
||||
generator = torch.Generator()
|
||||
generator.manual_seed(seed)
|
||||
import os
|
||||
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='Export embeddings')
|
||||
parser.add_argument('--checkpoint', type=str, default=None, help='Path to checkpoint')
|
||||
parser.add_argument('--output', type=str, default=None, help='Path to output')
|
||||
parser.add_argument('--dataset', type=str, default=None, help='Path to data')
|
||||
parser.add_argument('--format', type=str, default='csv', help='Format of the output file (csv, json)')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
@@ -85,7 +88,9 @@ with torch.no_grad():
|
||||
df = pd.read_csv(args.dataset)
|
||||
df["embeddings"] = embeddings
|
||||
df = df[['embeddings', 'label_name', 'labels']]
|
||||
df['embeddings2'] = df['embeddings'].apply(lambda x: x.tolist())
|
||||
df['embeddings'] = df['embeddings'].apply(lambda x: x.tolist()[0])
|
||||
|
||||
|
||||
df.to_csv(args.output, index=False)
|
||||
if args.format == 'json':
|
||||
df.to_json(args.output, orient='records')
|
||||
elif args.format == 'csv':
|
||||
df.to_csv(args.output, index=False)
|
||||
@@ -12,10 +12,10 @@ output=32
|
||||
# load PyTorch model from .pth file
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
# if torch.cuda.is_available():
|
||||
# device = torch.device("cuda")
|
||||
|
||||
CHECKPOINT_PATH = "out-checkpoints/augment_rotate_75_x8/checkpoint_embed_1105.pth"
|
||||
CHECKPOINT_PATH = "checkpoints/checkpoint_embed_1105.pth"
|
||||
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
|
||||
|
||||
model = SPOTER_EMBEDDINGS(
|
||||
@@ -30,7 +30,10 @@ model.eval()
|
||||
model_export = "onnx"
|
||||
if model_export == "coreml":
|
||||
dummy_input = torch.randn(1, 10, 54, 2)
|
||||
# set device for dummy input
|
||||
dummy_input = dummy_input.to(device)
|
||||
traced_model = torch.jit.trace(model, dummy_input)
|
||||
|
||||
out = traced_model(dummy_input)
|
||||
import coremltools as ct
|
||||
|
||||
@@ -41,10 +44,12 @@ if model_export == "coreml":
|
||||
)
|
||||
|
||||
# Save Core ML model
|
||||
coreml_model.save("models/" + model_name + ".mlmodel")
|
||||
coreml_model.save("out-models/" + model_name + ".mlmodel")
|
||||
else:
|
||||
# create dummy input tensor
|
||||
dummy_input = torch.randn(1, 10, 54, 2)
|
||||
# set device for dummy input
|
||||
dummy_input = dummy_input.to(device)
|
||||
|
||||
# export model to ONNX format
|
||||
output_file = 'models/' + model_name + '.onnx'
|
||||
@@ -52,7 +57,7 @@ else:
|
||||
|
||||
torch.onnx.export(model, # model being run
|
||||
dummy_input, # model input (or a tuple for multiple inputs)
|
||||
'output-models/' + model_name + '.onnx', # where to save the model (can be a file or file-like object)
|
||||
'out-models/' + model_name + '.onnx', # where to save the model (can be a file or file-like object)
|
||||
export_params=True, # store the trained parameter weights inside the model file
|
||||
opset_version=9, # the ONNX version to export the model to
|
||||
do_constant_folding=True, # whether to execute constant folding for optimization
|
||||
|
||||
2
train.py
2
train.py
@@ -77,7 +77,7 @@ def train(args, tracker: Tracker):
|
||||
if not args.classification_model:
|
||||
# if finetune, load the weights from the classification model
|
||||
if args.finetune:
|
||||
checkpoint = torch.load(args.checkpoint, map_location=device)
|
||||
checkpoint = torch.load(args.checkpoint_path, map_location=device)
|
||||
|
||||
slrt_model = SPOTER_EMBEDDINGS(
|
||||
features=checkpoint["config_args"].vector_length,
|
||||
|
||||
24
train.sh
24
train.sh
@@ -1,21 +1,23 @@
|
||||
#!/bin/sh
|
||||
python3 -m train \
|
||||
--save_checkpoints_every 10 \
|
||||
--experiment_name "wlasl" \
|
||||
--epochs 600 \
|
||||
--optimizer "SGD" \
|
||||
--lr 0.001 \
|
||||
--save_checkpoints_every 1 \
|
||||
--experiment_name "Finetune Basic Signs" \
|
||||
--epochs 100 \
|
||||
--optimizer "ADAM" \
|
||||
--lr 0.00001 \
|
||||
--batch_size 16 \
|
||||
--dataset_name "WLASL" \
|
||||
--training_set_path "WLASL100_train.csv" \
|
||||
--validation_set_path "WLASL100_val.csv" \
|
||||
--dataset_name "BasicSigns" \
|
||||
--training_set_path "train.csv" \
|
||||
--validation_set_path "val.csv" \
|
||||
--vector_length 32 \
|
||||
--epoch_iters -1 \
|
||||
--scheduler_factor 0.2 \
|
||||
--hard_triplet_mining "in_batch" \
|
||||
--scheduler_factor 0.05 \
|
||||
--hard_triplet_mining "None" \
|
||||
--filter_easy_triplets \
|
||||
--triplet_loss_margin 2 \
|
||||
--dropout 0.2 \
|
||||
--tracker=clearml \
|
||||
--dataset_loader=clearml \
|
||||
--dataset_project="SpoterEmbedding"
|
||||
--dataset_project="SpoterEmbedding" \
|
||||
--finetune \
|
||||
--checkpoint_path "checkpoints/checkpoint_embed_3006.pth"
|
||||
@@ -81,4 +81,7 @@ def get_default_args():
|
||||
help="Enables batching grouping scheduler if > 0. Defines threshold for when to decay the \
|
||||
distance threshold of the batch sorter")
|
||||
|
||||
parser.add_argument("--finetune", action='store_true', default=False, help="Fintune the model")
|
||||
parser.add_argument("--checkpoint_path", type=str, default="")
|
||||
|
||||
return parser
|
||||
|
||||
@@ -238,6 +238,7 @@ def distance_matrix(keypoints, embeddings, p=2, threshold=1000000):
|
||||
f"{kk}-dimensional vectors")
|
||||
|
||||
if m*n*k <= threshold:
|
||||
print("Using minkowski_distance")
|
||||
return minkowski_distance(x[:,np.newaxis,:],y[np.newaxis,:,:],p)
|
||||
else:
|
||||
result = np.empty((m,n),dtype=float) # FIXME: figure out the best dtype
|
||||
|
||||
Reference in New Issue
Block a user