"""Export SPOTER embeddings for every sample of a dataset CSV.

Loads a trained ``SPOTER_EMBEDDINGS`` checkpoint, runs it over the dataset at
``--dataset`` (in file order) and writes one row per sample with the columns
``embeddings``, ``label_name`` and ``labels`` to ``--output`` as CSV or JSON.
"""
import argparse
import multiprocessing
import os
import random

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader

from datasets.dataset_loader import LocalDatasetLoader
from datasets.embedding_dataset import SLREmbeddingDataset
from datasets import SLREmbeddingDataset, collate_fn_padd
from models.spoter_embedding_model import SPOTER_EMBEDDINGS

# cuBLAS is only deterministic with a fixed workspace config, and the value is
# read when cuBLAS is first initialised, so set it before any CUDA work.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

seed = 43
random.seed(seed)
np.random.seed(seed)
# NOTE: PYTHONHASHSEED is read at interpreter start-up, so this only affects
# Python processes launched afterwards (e.g. spawned DataLoader workers).
os.environ["PYTHONHASHSEED"] = str(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True)


def seed_worker(worker_id):
    """Seed NumPy and ``random`` inside a DataLoader worker.

    Used as ``worker_init_fn``: PyTorch already gives each worker its own torch
    seed, so the other RNGs are derived from it for reproducible loading.

    Args:
        worker_id: Index of the worker (unused; required by the DataLoader API).
    """
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def parse_args():
    """Parse the command-line options of the export script."""
    parser = argparse.ArgumentParser(description='Export embeddings')
    parser.add_argument('--checkpoint', type=str, required=True, help='Path to checkpoint')
    parser.add_argument('--output', type=str, required=True, help='Path to output')
    parser.add_argument('--dataset', type=str, required=True, help='Path to data')
    # Restricting the choices makes a typo fail loudly instead of writing nothing.
    parser.add_argument('--format', type=str, default='csv', choices=('csv', 'json'),
                        help='Format of the output file (csv, json)')
    return parser.parse_args()


def main():
    """Load the checkpoint, embed every dataset sample and write the result."""
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # load the model
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(args.checkpoint, map_location=device)
    config = checkpoint["config_args"]
    model = SPOTER_EMBEDDINGS(
        features=config.vector_length,
        hidden_dim=config.hidden_dim,
        norm_emb=config.normalize_embeddings,
    ).to(device)
    model.load_state_dict(checkpoint["state_dict"])
    # Inference only: switch train-mode-dependent layers (e.g. dropout) to
    # their evaluation behaviour so exported embeddings are not perturbed.
    model.eval()

    # NOTE(review): the loader instance is never used below; kept in case its
    # constructor has side effects the dataset relies on — confirm and drop.
    dataset_loader = LocalDatasetLoader()
    dataset = SLREmbeddingDataset(args.dataset, triplet=False, augmentations=False)

    generator = torch.Generator()
    generator.manual_seed(seed)
    data_loader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,  # output rows must stay aligned with the CSV rows
        collate_fn=collate_fn_padd,
        pin_memory=torch.cuda.is_available(),
        # Set num_workers=0 to load data in the main process instead.
        num_workers=multiprocessing.cpu_count(),
        worker_init_fn=seed_worker,
        generator=generator,
    )

    embeddings = []
    with torch.no_grad():
        for inputs, _labels, masks in data_loader:
            outputs = model(inputs.to(device), masks.to(device))
            # `tolist()[0]` keeps the first entry along each row's leading
            # axis; presumably that axis is a singleton — TODO confirm against
            # the model's output shape. Plain lists serialise cleanly.
            for row in outputs.cpu().numpy():
                embeddings.append(row.tolist()[0])

    df = pd.read_csv(args.dataset)
    if len(embeddings) != len(df):
        raise ValueError(
            f"Produced {len(embeddings)} embeddings for {len(df)} rows of {args.dataset}"
        )
    df["embeddings"] = embeddings
    df = df[['embeddings', 'label_name', 'labels']]

    if args.format == 'json':
        df.to_json(args.output, orient='records')
    else:  # 'csv' — argparse restricts --format to these two values
        df.to_csv(args.output, index=False)


# The guard is required because the DataLoader starts worker processes: under
# the "spawn" start method each worker re-imports this module.
if __name__ == "__main__":
    main()