35 lines
1.4 KiB
Python
35 lines
1.4 KiB
Python
from sklearn import svm
|
|
|
|
class SVM:
    """Support-vector classifier over embedding vectors with label-name lookup.

    Typical use: construct, call :meth:`set_embeddings` with the training
    table (which also fits the model), then call :meth:`predict` for a
    single sample.

    Attributes:
        label_name_to_label: DataFrame with the unique ``label_name`` /
            ``label`` pairs seen in the training table, or ``None`` before
            :meth:`set_embeddings` is called.
        clf: The fitted ``svm.SVC`` instance, or ``None`` before training.
        embeddings_list: Training embeddings as a list of lists of floats.
        labels: Training labels, aligned with ``embeddings_list``.
        type: Value forwarded to ``svm.SVC`` as ``decision_function_shape``.
    """

    def __init__(self, type: str = "ovo"):
        """Create an untrained classifier.

        Args:
            type: Passed unchanged to ``svm.SVC(decision_function_shape=...)``
                when :meth:`train` runs.
        """
        # NOTE: ``type`` shadows the builtin, but it is part of the public
        # signature, so the name is kept for keyword-argument callers.
        self.label_name_to_label = None
        self.clf = None
        self.embeddings_list = None
        self.labels = None
        self.type = type

    @staticmethod
    def _parse_embedding(text):
        """Convert a bracketed string such as ``"[0.1, 0.2]"`` to a list of floats."""
        # Split on "," only: float() ignores surrounding whitespace, so both
        # "[1, 2]" and "[1,2]" are accepted.
        return [float(token) for token in text.strip()[1:-1].split(",")]

    def set_embeddings(self, embeddings):
        """Load the training table and fit the classifier.

        Args:
            embeddings: DataFrame with the columns ``embeddings2`` (each cell
                a string holding a bracketed, comma-separated list of floats),
                ``labels`` and ``label_name``. The frame is only read; it is
                not modified.

        Raises:
            KeyError: If one of the required columns is missing.
            ValueError: If an ``embeddings2`` cell cannot be parsed as floats.
        """
        # Parse into a fresh list instead of writing a new column into the
        # caller's DataFrame.
        self.embeddings_list = (
            embeddings["embeddings2"].apply(self._parse_embedding).tolist()
        )
        self.labels = embeddings["labels"].tolist()

        # rename() returns a new frame, which avoids assigning ``.columns`` on
        # a column-subset of the input.
        self.label_name_to_label = (
            embeddings[["label_name", "labels"]]
            .rename(columns={"labels": "label"})
            .drop_duplicates()
        )

        self.train()

    def train(self):
        """Fit a new ``svm.SVC`` on ``embeddings_list`` and ``labels``."""
        # probability=True is required for predict_log_proba() in predict().
        self.clf = svm.SVC(decision_function_shape=self.type, probability=True)
        self.clf.fit(self.embeddings_list, self.labels)

    def predict(self, key_points_embeddings):
        """Classify one sample and return its label name and log-probability.

        Args:
            key_points_embeddings: Input forwarded to ``clf.predict`` and
                ``clf.predict_log_proba``. It must describe exactly one
                sample, because the prediction is reduced with ``.item()``.

        Returns:
            A ``(label_name, log_probability)`` tuple, where
            ``log_probability`` is the ``predict_log_proba`` entry of the
            predicted class for that sample.
        """
        predicted = self.clf.predict(key_points_embeddings)
        log_proba = self.clf.predict_log_proba(key_points_embeddings)
        label = predicted.item()

        # Columns of predict_log_proba follow clf.classes_, so the label value
        # has to be translated to its position there; indexing by the raw
        # label is only correct when labels happen to be 0..n-1.
        class_index = self.clf.classes_.tolist().index(label)

        mapping = self.label_name_to_label
        label_name = mapping.loc[mapping["label"] == label, "label_name"].iloc[0]
        return label_name, log_proba[0][class_index]
|