Files
spoterembedding/training/batching_scheduler.py
Mathias Claassen 81bbf66aab Initial codebase (#1)
* Add project code

* Logger improvements

* Improvements to web demo code

* added create_wlasl_landmarks_dataset.py and xtract_mediapipe_landmarks.py

* Fix rotation augmentation

* fixed error in docstring, and removed unnecessary replace -1 -> 0

* Readme updates

* Share base notebooks

* Add notebooks and unify for different datasets

* requirements update

* fixes

* Make evaluate more deterministic

* Allow training with clearml

* refactor preprocessing and apply linter

* Minor fixes

* Minor notebook tweaks

* Readme updates

* Fix PR comments

* Remove unneeded code

* Add banner to Readme

---------

Co-authored-by: Gabriel Lema <gabriel.lema@xmartlabs.com>
2023-03-03 10:07:54 -03:00

63 lines
2.4 KiB
Python

from collections import deque
import numpy as np
class BatchingScheduler:
    """Scheduler for the batching algorithm.

    Tracks a short rolling window of "used triplets" counts reported through
    :meth:`step`. When the window mean falls below ``triplets_threshold`` and
    at least ``cooldown`` steps have elapsed since the previous update, the
    distance threshold and the same-label scaling factor are decayed towards
    their respective lower bounds.

    Args:
        decay_factor: Multiplier applied to the distance threshold on update.
        min_threshold: Lower bound for the distance threshold.
        triplets_threshold: Updates are only considered while the mean of the
            recent used-triplets counts is strictly below this value.
        cooldown: Minimum number of steps between two consecutive updates.
    """

    # Size of the rolling window of used-triplets counts.
    _WINDOW_SIZE = 5
    # Multiplier and lower bound for the same-label scaling factor.
    _SCALE_DECAY = 0.9
    _SCALE_FLOOR = 0.6

    def __init__(self, decay_factor=0.8, min_threshold=0.2, triplets_threshold=10, cooldown=10) -> None:
        # internal vars
        self._step_count = 0
        self._dist_threshold = 0.5
        self._last_used_triplets = deque([], self._WINDOW_SIZE)
        self._scaling_same_label_factor = 1
        # NOTE(review): -10 equals the default cooldown, so the very first
        # step is already eligible for an update. With cooldown > 10 the first
        # update is delayed until step (cooldown - 10) -- confirm whether
        # ``-cooldown`` was intended.
        self._last_update_step = -10
        # Parameters
        self.decay_factor = decay_factor
        self.min_threshold = min_threshold
        self.triplets_threshold = triplets_threshold
        self.cooldown = cooldown

    def state_dict(self):
        """Return the state of the scheduler as a :class:`dict`.

        The rolling window is copied so that the returned snapshot is not
        altered by later calls to :meth:`step`.
        """
        state = dict(self.__dict__)
        state["_last_used_triplets"] = self._last_used_triplets.copy()
        return state

    def load_state_dict(self, state_dict):
        """Load the scheduler's state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`. The rolling window may be
                any iterable (e.g. a list after a JSON round trip); it is
                rebuilt as a bounded deque and never shared with the caller.
        """
        state = dict(state_dict)
        window = state.get("_last_used_triplets")
        if window is not None:
            # Keep the bound of a saved deque; otherwise reuse the current
            # window's bound, falling back to the class default.
            maxlen = getattr(window, "maxlen", None)
            if maxlen is None:
                current = self.__dict__.get("_last_used_triplets")
                maxlen = getattr(current, "maxlen", None) or self._WINDOW_SIZE
            state["_last_used_triplets"] = deque(window, maxlen)
        self.__dict__.update(state)

    def step(self, used_triplets):
        """Record one step and decay the thresholds when appropriate.

        Args:
            used_triplets: Number of triplets used in the current step; it is
                appended to the rolling window.
        """
        self._step_count += 1
        self._last_used_triplets.append(used_triplets)

        # Computed once; the same value is used for the check and the logs.
        recent_mean = np.mean(self._last_used_triplets)
        cooldown_over = self._last_update_step + self.cooldown <= self._step_count
        if recent_mean < self.triplets_threshold and cooldown_over:
            if self._dist_threshold > self.min_threshold:
                print(f"Updating dist_threshold at {self._step_count} ({recent_mean})")
                self.update_dist_threshold()
            if self._scaling_same_label_factor > self._SCALE_FLOOR:
                print(f"Updating scale factor at {self._step_count} ({recent_mean})")
                self.update_scale_factor()
            # The cooldown restarts whenever the trigger condition is met,
            # even if both values were already at their lower bounds.
            self._last_update_step = self._step_count

    def update_scale_factor(self):
        """Decay the same-label scaling factor, clamped at its lower bound."""
        self._scaling_same_label_factor = max(
            self._scaling_same_label_factor * self._SCALE_DECAY, self._SCALE_FLOOR
        )
        print(f"Updating scaling factor to {self._scaling_same_label_factor}")

    def update_dist_threshold(self):
        """Decay the distance threshold, clamped at ``min_threshold``."""
        self._dist_threshold = max(self.min_threshold, self._dist_threshold * self.decay_factor)
        print(f"Updated dist_threshold to {self._dist_threshold}")

    def get_dist_threshold(self) -> float:
        """Return the current distance threshold."""
        return self._dist_threshold

    def get_scaling_same_label_factor(self) -> float:
        """Return the current same-label scaling factor."""
        return self._scaling_same_label_factor