mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
remove dependency for albumenations
This commit is contained in:
@@ -55,7 +55,6 @@ from diffusers import (
|
||||
from library import custom_train_functions
|
||||
from library.original_unet import UNet2DConditionModel
|
||||
from huggingface_hub import hf_hub_download
|
||||
import albumentations as albu
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import cv2
|
||||
@@ -285,42 +284,40 @@ class BucketBatchIndex(NamedTuple):
|
||||
|
||||
|
||||
class AugHelper:
|
||||
# albumentationsへの依存をなくしたがとりあえず同じinterfaceを持たせる
|
||||
|
||||
def __init__(self):
|
||||
# prepare all possible augmentators
|
||||
self.color_aug_method = albu.OneOf(
|
||||
[
|
||||
albu.HueSaturationValue(8, 0, 0, p=0.5),
|
||||
albu.RandomGamma((95, 105), p=0.5),
|
||||
],
|
||||
p=0.33,
|
||||
)
|
||||
pass
|
||||
|
||||
# key: (use_color_aug, use_flip_aug)
|
||||
# self.augmentors = {
|
||||
# (True, True): albu.Compose(
|
||||
# [
|
||||
# color_aug_method,
|
||||
# flip_aug_method,
|
||||
# ],
|
||||
# p=1.0,
|
||||
# ),
|
||||
# (True, False): albu.Compose(
|
||||
# [
|
||||
# color_aug_method,
|
||||
# ],
|
||||
# p=1.0,
|
||||
# ),
|
||||
# (False, True): albu.Compose(
|
||||
# [
|
||||
# flip_aug_method,
|
||||
# ],
|
||||
# p=1.0,
|
||||
# ),
|
||||
# (False, False): None,
|
||||
# }
|
||||
def color_aug(self, image: np.ndarray):
|
||||
# self.color_aug_method = albu.OneOf(
|
||||
# [
|
||||
# albu.HueSaturationValue(8, 0, 0, p=0.5),
|
||||
# albu.RandomGamma((95, 105), p=0.5),
|
||||
# ],
|
||||
# p=0.33,
|
||||
# )
|
||||
hue_shift_limit = 8
|
||||
|
||||
def get_augmentor(self, use_color_aug: bool) -> Optional[albu.Compose]:
|
||||
return self.color_aug_method if use_color_aug else None
|
||||
# remove dependency to albumentations
|
||||
if random.random() <= 0.33:
|
||||
if random.random() > 0.5:
|
||||
# hue shift
|
||||
hsv_img = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
|
||||
hue_shift = random.uniform(-hue_shift_limit, hue_shift_limit)
|
||||
if hue_shift < 0:
|
||||
hue_shift = 180 + hue_shift
|
||||
hsv_img[:, :, 0] = (hsv_img[:, :, 0] + hue_shift) % 180
|
||||
image = cv2.cvtColor(hsv_img, cv2.COLOR_HSV2BGR)
|
||||
else:
|
||||
# random gamma
|
||||
gamma = random.uniform(0.95, 1.05)
|
||||
image = np.clip(image**gamma, 0, 255).astype(np.uint8)
|
||||
|
||||
return {"image": image}
|
||||
|
||||
def get_augmentor(self, use_color_aug: bool): # -> Optional[Callable[[np.ndarray], Dict[str, np.ndarray]]]:
|
||||
return self.color_aug if use_color_aug else None
|
||||
|
||||
|
||||
class BaseSubset:
|
||||
@@ -3443,7 +3440,7 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
|
||||
|
||||
name = SchedulerType(name)
|
||||
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
||||
|
||||
|
||||
if name == SchedulerType.CONSTANT:
|
||||
return wrap_check_needless_num_warmup_steps(schedule_func(optimizer, **lr_scheduler_kwargs))
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ accelerate==0.19.0
|
||||
transformers==4.30.2
|
||||
diffusers[torch]==0.18.2
|
||||
ftfy==6.1.1
|
||||
albumentations==1.3.0
|
||||
# albumentations==1.3.0
|
||||
opencv-python==4.7.0.68
|
||||
einops==0.6.0
|
||||
pytorch-lightning==1.9.0
|
||||
|
||||
Reference in New Issue
Block a user