# spoterembedding/export_embeddings.py
# Last modified: 2023-05-21 20:30:12 +00:00
import multiprocessing
import os
import torch
import argparse
from datasets.dataset_loader import LocalDatasetLoader
from datasets.embedding_dataset import SLREmbeddingDataset
from torch.utils.data import DataLoader
from datasets import SLREmbeddingDataset, collate_fn_padd
from models.spoter_embedding_model import SPOTER_EMBEDDINGS
import numpy as np
import random
import pandas as pd
# ---------------------------------------------------------------------------
# Reproducibility: pin every RNG this pipeline touches so that exported
# embeddings are identical across runs.
# ---------------------------------------------------------------------------
seed = 43
random.seed(seed)
np.random.seed(seed)
# Pins the hash seed for child processes (it has no effect on the
# already-running interpreter, whose hash seed was fixed at startup).
os.environ["PYTHONHASHSEED"] = str(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
# Force deterministic kernel implementations; on CUDA this additionally
# requires CUBLAS_WORKSPACE_CONFIG to be set (done further down in this file).
torch.use_deterministic_algorithms(True)
# Dedicated generator handed to the DataLoader so its sampling/worker seeding
# is decoupled from the global torch RNG state.
generator = torch.Generator()
generator.manual_seed(seed)
def seed_worker(worker_id):
    """Seed NumPy and ``random`` inside each DataLoader worker process.

    PyTorch derives a distinct ``torch.initial_seed()`` per worker; folding it
    into the other RNG libraries keeps any worker-side randomness
    deterministic across runs.

    Args:
        worker_id: Required by the ``worker_init_fn`` protocol; unused here.
    """
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)
    # NOTE: the original version also built a local torch.Generator seeded
    # from the module-level seed, but it was never used — removed as dead code.
# cuBLAS must use a fixed workspace size for its GEMM kernels to be
# deterministic; torch.use_deterministic_algorithms(True) requires this
# variable on CUDA. (The redundant second `import os` was removed — os is
# already imported at the top of the file.)
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
def parse_args(argv=None):
    """Parse command-line options for the embedding export script.

    Args:
        argv: Optional list of argument strings. Defaults to ``sys.argv[1:]``
            (pass an explicit list to make the function testable).

    Returns:
        argparse.Namespace with ``checkpoint``, ``output``, ``dataset`` and
        ``format`` attributes.
    """
    parser = argparse.ArgumentParser(description='Export embeddings')
    parser.add_argument('--checkpoint', type=str, default=None, help='Path to checkpoint')
    parser.add_argument('--output', type=str, default=None, help='Path to output')
    parser.add_argument('--dataset', type=str, default=None, help='Path to data')
    parser.add_argument('--format', type=str, default='csv', help='Format of the output file (csv, json)')
    return parser.parse_args(argv)
args = parse_args()

# Prefer GPU when available; results are identical either way thanks to the
# determinism settings above.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Rebuild the embedding model from the architecture hyperparameters stored in
# the checkpoint, then restore the trained weights.
checkpoint = torch.load(args.checkpoint, map_location=device)
model = SPOTER_EMBEDDINGS(
    features=checkpoint["config_args"].vector_length,
    hidden_dim=checkpoint["config_args"].hidden_dim,
    norm_emb=checkpoint["config_args"].normalize_embeddings,
).to(device)
model.load_state_dict(checkpoint["state_dict"])
# BUGFIX: switch to inference mode. Without eval(), dropout stays active and
# the exported embeddings vary run-to-run despite all the RNG seeding.
model.eval()
# NOTE(review): dataset_loader is never used below — presumably
# LocalDatasetLoader() has construction-time side effects (e.g. materializing
# data locally); confirm, otherwise this line can be deleted.
dataset_loader = LocalDatasetLoader()
# Plain (non-triplet) dataset without augmentations, read from args.dataset.
dataset = SLREmbeddingDataset(args.dataset, triplet=False, augmentations=False)
data_loader = DataLoader(
    dataset,
    # One sample per batch keeps the embedding-to-row alignment trivial.
    batch_size=1,
    # Order must match the source rows so embeddings can be joined back later.
    shuffle=False,
    collate_fn=collate_fn_padd,
    pin_memory=torch.cuda.is_available(),
    #num_workers=0, # Uncomment this line (and comment out next line) if you want to disable multithreading
    num_workers=multiprocessing.cpu_count(),
    # Deterministic per-worker RNG seeding (see seed_worker above).
    worker_init_fn=seed_worker,
    generator=generator,
)
# Run every sample through the model and collect one embedding per row.
embeddings = []
with torch.no_grad():
    for inputs, _labels, masks in data_loader:
        inputs = inputs.to(device)
        masks = masks.to(device)
        outputs = model(inputs, masks)
        # batch_size is 1 today, but iterate over the batch dimension
        # defensively in case that ever changes.
        for n in range(outputs.shape[0]):
            embeddings.append(outputs[n].cpu().numpy())

# Join embeddings back onto the source CSV rows; shuffle=False above
# guarantees iteration order matches file order.
# NOTE(review): assumes args.dataset is a CSV with 'label_name' and 'labels'
# columns — confirm against SLREmbeddingDataset's expected input format.
df = pd.read_csv(args.dataset)
df["embeddings"] = embeddings
df = df[['embeddings', 'label_name', 'labels']]
# NOTE(review): tolist()[0] keeps only the first row of each per-sample
# array — presumably the model emits shape (1, emb_dim); confirm.
df['embeddings'] = df['embeddings'].apply(lambda x: x.tolist()[0])

if args.format == 'json':
    df.to_json(args.output, orient='records')
elif args.format == 'csv':
    df.to_csv(args.output, index=False)
else:
    # Previously an unknown format silently produced no output file.
    raise ValueError("Unsupported output format: %r (expected 'csv' or 'json')" % args.format)