diff --git a/library/train_util.py b/library/train_util.py index 021d2ccb..7f9fd75a 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -55,7 +55,6 @@ from diffusers import ( from library import custom_train_functions from library.original_unet import UNet2DConditionModel from huggingface_hub import hf_hub_download -import albumentations as albu import numpy as np from PIL import Image import cv2 @@ -285,42 +284,40 @@ class BucketBatchIndex(NamedTuple): class AugHelper: + # albumentationsへの依存をなくしたがとりあえず同じinterfaceを持たせる + def __init__(self): - # prepare all possible augmentators - self.color_aug_method = albu.OneOf( - [ - albu.HueSaturationValue(8, 0, 0, p=0.5), - albu.RandomGamma((95, 105), p=0.5), - ], - p=0.33, - ) + pass - # key: (use_color_aug, use_flip_aug) - # self.augmentors = { - # (True, True): albu.Compose( - # [ - # color_aug_method, - # flip_aug_method, - # ], - # p=1.0, - # ), - # (True, False): albu.Compose( - # [ - # color_aug_method, - # ], - # p=1.0, - # ), - # (False, True): albu.Compose( - # [ - # flip_aug_method, - # ], - # p=1.0, - # ), - # (False, False): None, - # } + def color_aug(self, image: np.ndarray): + # self.color_aug_method = albu.OneOf( + # [ + # albu.HueSaturationValue(8, 0, 0, p=0.5), + # albu.RandomGamma((95, 105), p=0.5), + # ], + # p=0.33, + # ) + hue_shift_limit = 8 - def get_augmentor(self, use_color_aug: bool) -> Optional[albu.Compose]: - return self.color_aug_method if use_color_aug else None + # remove dependency to albumentations + if random.random() <= 0.33: + if random.random() > 0.5: + # hue shift + hsv_img = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) + hue_shift = random.uniform(-hue_shift_limit, hue_shift_limit) + if hue_shift < 0: + hue_shift = 180 + hue_shift + hsv_img[:, :, 0] = (hsv_img[:, :, 0] + hue_shift) % 180 + image = cv2.cvtColor(hsv_img, cv2.COLOR_HSV2BGR) + else: + # random gamma + gamma = random.uniform(0.95, 1.05) + image = np.clip(image**gamma, 0, 255).astype(np.uint8) + + return {"image": image} + + def get_augmentor(self, use_color_aug: bool): # -> Optional[Callable[[np.ndarray], Dict[str, np.ndarray]]]: + return self.color_aug if use_color_aug else None class BaseSubset: @@ -3443,7 +3440,7 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): name = SchedulerType(name) schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] - + if name == SchedulerType.CONSTANT: return wrap_check_needless_num_warmup_steps(schedule_func(optimizer, **lr_scheduler_kwargs)) diff --git a/requirements.txt b/requirements.txt index a1d3e37c..427621d2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ accelerate==0.19.0 transformers==4.30.2 diffusers[torch]==0.18.2 ftfy==6.1.1 -albumentations==1.3.0 +# albumentations==1.3.0 opencv-python==4.7.0.68 einops==0.6.0 pytorch-lightning==1.9.0