remove dependency for albumenations

This commit is contained in:
Kohya S
2023-07-30 16:29:53 +09:00
parent 496c3f2732
commit f61996b425
2 changed files with 33 additions and 36 deletions

View File

@@ -55,7 +55,6 @@ from diffusers import (
from library import custom_train_functions
from library.original_unet import UNet2DConditionModel
from huggingface_hub import hf_hub_download
import albumentations as albu
import numpy as np
from PIL import Image
import cv2
@@ -285,42 +284,40 @@ class BucketBatchIndex(NamedTuple):
class AugHelper:
# albumentationsへの依存をなくしたがとりあえず同じinterfaceを持たせる
def __init__(self):
# prepare all possible augmentators
self.color_aug_method = albu.OneOf(
[
albu.HueSaturationValue(8, 0, 0, p=0.5),
albu.RandomGamma((95, 105), p=0.5),
],
p=0.33,
)
pass
# key: (use_color_aug, use_flip_aug)
# self.augmentors = {
# (True, True): albu.Compose(
# [
# color_aug_method,
# flip_aug_method,
# ],
# p=1.0,
# ),
# (True, False): albu.Compose(
# [
# color_aug_method,
# ],
# p=1.0,
# ),
# (False, True): albu.Compose(
# [
# flip_aug_method,
# ],
# p=1.0,
# ),
# (False, False): None,
# }
def color_aug(self, image: np.ndarray):
# self.color_aug_method = albu.OneOf(
# [
# albu.HueSaturationValue(8, 0, 0, p=0.5),
# albu.RandomGamma((95, 105), p=0.5),
# ],
# p=0.33,
# )
hue_shift_limit = 8
def get_augmentor(self, use_color_aug: bool) -> Optional[albu.Compose]:
return self.color_aug_method if use_color_aug else None
# remove dependency to albumentations
if random.random() <= 0.33:
if random.random() > 0.5:
# hue shift
hsv_img = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
hue_shift = random.uniform(-hue_shift_limit, hue_shift_limit)
if hue_shift < 0:
hue_shift = 180 + hue_shift
hsv_img[:, :, 0] = (hsv_img[:, :, 0] + hue_shift) % 180
image = cv2.cvtColor(hsv_img, cv2.COLOR_HSV2BGR)
else:
# random gamma
gamma = random.uniform(0.95, 1.05)
image = np.clip(image**gamma, 0, 255).astype(np.uint8)
return {"image": image}
def get_augmentor(self, use_color_aug: bool): # -> Optional[Callable[[np.ndarray], Dict[str, np.ndarray]]]:
return self.color_aug if use_color_aug else None
class BaseSubset:
@@ -3443,7 +3440,7 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
name = SchedulerType(name)
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
if name == SchedulerType.CONSTANT:
return wrap_check_needless_num_warmup_steps(schedule_func(optimizer, **lr_scheduler_kwargs))

View File

@@ -2,7 +2,7 @@ accelerate==0.19.0
transformers==4.30.2
diffusers[torch]==0.18.2
ftfy==6.1.1
albumentations==1.3.0
# albumentations==1.3.0
opencv-python==4.7.0.68
einops==0.6.0
pytorch-lightning==1.9.0