mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 21:52:27 +00:00
Merge pull request #509 from kohya-ss/dev
.toml for sample generation etc.
This commit is contained in:
12
README.md
12
README.md
@@ -28,6 +28,8 @@ The scripts are tested with PyTorch 1.12.1 and 1.13.0, Diffusers 0.10.2.
|
||||
|
||||
Most of the documents are written in Japanese.
|
||||
|
||||
[English translation by darkstorm2150 is here](https://github.com/darkstorm2150/sd-scripts#links-to-usage-documentation). Thanks to darkstorm2150!
|
||||
|
||||
* [Training guide - common](./docs/train_README-ja.md) : data preparation, options etc...
|
||||
* [Chinese version](./docs/train_README-zh.md)
|
||||
* [Dataset config](./docs/config_README-ja.md)
|
||||
@@ -138,6 +140,16 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
|
||||
|
||||
## Change History
|
||||
|
||||
### 15 May 2023, 2023/05/15
|
||||
|
||||
- Added [English translation of documents](https://github.com/darkstorm2150/sd-scripts#links-to-usage-documentation) by darkstorm2150. Thank you very much!
|
||||
- The prompt for sample generation during training can now be specified in `.toml` or `.json`. [PR #504]((https://github.com/kohya-ss/sd-scripts/pull/504) Thanks to Linaqruf!
|
||||
- For details on prompt description, please see the PR.
|
||||
|
||||
- darkstorm2150氏に[ドキュメント類を英訳](https://github.com/darkstorm2150/sd-scripts#links-to-usage-documentation)していただきました。ありがとうございます!
|
||||
- 学習中のサンプル生成のプロンプトを`.toml`または`.json`で指定可能になりました。 [PR #504](https://github.com/kohya-ss/sd-scripts/pull/504) Linaqruf氏に感謝します。
|
||||
- プロンプト記述の詳細は当該PRをご覧ください。
|
||||
|
||||
### 11 May 2023, 2023/05/11
|
||||
|
||||
- Added an option `--dim_from_weights` to `train_network.py` to automatically determine the dim(rank) from the weight file. [PR #491](https://github.com/kohya-ss/sd-scripts/pull/491) Thanks to AI-Casanova!
|
||||
|
||||
@@ -19,6 +19,9 @@ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
|
||||
return loss
|
||||
|
||||
|
||||
# TODO train_utilと分散しているのでどちらかに寄せる
|
||||
|
||||
|
||||
def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted_captions: bool = True):
|
||||
parser.add_argument(
|
||||
"--min_snr_gamma",
|
||||
@@ -347,7 +350,7 @@ def get_weighted_text_embeddings(
|
||||
|
||||
# https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2
|
||||
def pyramid_noise_like(noise, device, iterations=6, discount=0.4):
|
||||
b, c, w, h = noise.shape # EDIT: w and h get over-written, rename for a different variant!
|
||||
b, c, w, h = noise.shape # EDIT: w and h get over-written, rename for a different variant!
|
||||
u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device)
|
||||
for i in range(iterations):
|
||||
r = random.random() * 2 + 2 # Rather than always going 2x,
|
||||
@@ -369,7 +372,65 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
|
||||
|
||||
# multiply adaptive noise scale to the mean value and add it to the noise offset
|
||||
noise_offset = noise_offset + adaptive_noise_scale * latent_mean
|
||||
noise_offset = torch.clamp(noise_offset, 0.0, None) # in case of adaptive noise scale is negative
|
||||
noise_offset = torch.clamp(noise_offset, 0.0, None) # in case of adaptive noise scale is negative
|
||||
|
||||
noise = noise + noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
|
||||
return noise
|
||||
|
||||
|
||||
"""
|
||||
##########################################
|
||||
# Perlin Noise
|
||||
def rand_perlin_2d(device, shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
|
||||
delta = (res[0] / shape[0], res[1] / shape[1])
|
||||
d = (shape[0] // res[0], shape[1] // res[1])
|
||||
|
||||
grid = (
|
||||
torch.stack(
|
||||
torch.meshgrid(torch.arange(0, res[0], delta[0], device=device), torch.arange(0, res[1], delta[1], device=device)),
|
||||
dim=-1,
|
||||
)
|
||||
% 1
|
||||
)
|
||||
angles = 2 * torch.pi * torch.rand(res[0] + 1, res[1] + 1, device=device)
|
||||
gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
|
||||
|
||||
tile_grads = (
|
||||
lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
|
||||
.repeat_interleave(d[0], 0)
|
||||
.repeat_interleave(d[1], 1)
|
||||
)
|
||||
dot = lambda grad, shift: (
|
||||
torch.stack((grid[: shape[0], : shape[1], 0] + shift[0], grid[: shape[0], : shape[1], 1] + shift[1]), dim=-1)
|
||||
* grad[: shape[0], : shape[1]]
|
||||
).sum(dim=-1)
|
||||
|
||||
n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
|
||||
n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
|
||||
n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
|
||||
n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
|
||||
t = fade(grid[: shape[0], : shape[1]])
|
||||
return 1.414 * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
|
||||
|
||||
|
||||
def rand_perlin_2d_octaves(device, shape, res, octaves=1, persistence=0.5):
|
||||
noise = torch.zeros(shape, device=device)
|
||||
frequency = 1
|
||||
amplitude = 1
|
||||
for _ in range(octaves):
|
||||
noise += amplitude * rand_perlin_2d(device, shape, (frequency * res[0], frequency * res[1]))
|
||||
frequency *= 2
|
||||
amplitude *= persistence
|
||||
return noise
|
||||
|
||||
|
||||
def perlin_noise(noise, device, octaves):
|
||||
_, c, w, h = noise.shape
|
||||
perlin = lambda: rand_perlin_2d_octaves(device, (w, h), (4, 4), octaves)
|
||||
noise_perlin = []
|
||||
for _ in range(c):
|
||||
noise_perlin.append(perlin())
|
||||
noise_perlin = torch.stack(noise_perlin).unsqueeze(0) # (1, c, w, h)
|
||||
noise += noise_perlin # broadcast for each batch
|
||||
return noise / noise.std() # Scaled back to roughly unit variance
|
||||
"""
|
||||
|
||||
@@ -2127,6 +2127,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
default=None,
|
||||
help="enable multires noise with this number of iterations (if enabled, around 6-10 is recommended) / Multires noiseを有効にしてこのイテレーション数を設定する(有効にする場合は6-10程度を推奨)",
|
||||
)
|
||||
# parser.add_argument(
|
||||
# "--perlin_noise",
|
||||
# type=int,
|
||||
# default=None,
|
||||
# help="enable perlin noise and set the octaves / perlin noiseを有効にしてoctavesをこの値に設定する",
|
||||
# )
|
||||
parser.add_argument(
|
||||
"--multires_noise_discount",
|
||||
type=float,
|
||||
@@ -2211,15 +2217,21 @@ def verify_training_args(args: argparse.Namespace):
|
||||
"cache_latents_to_disk is enabled, so cache_latents is also enabled / cache_latents_to_diskが有効なため、cache_latentsを有効にします"
|
||||
)
|
||||
|
||||
# noise_offset, perlin_noise, multires_noise_iterations cannot be enabled at the same time
|
||||
# Listを使って数えてもいいけど並べてしまえ
|
||||
if args.noise_offset is not None and args.multires_noise_iterations is not None:
|
||||
raise ValueError(
|
||||
"noise_offset and multires_noise_iterations cannot be enabled at the same time / noise_offsetとmultires_noise_iterationsを同時に有効にすることはできません"
|
||||
"noise_offset and multires_noise_iterations cannot be enabled at the same time / noise_offsetとmultires_noise_iterationsを同時に有効にできません"
|
||||
)
|
||||
# if args.noise_offset is not None and args.perlin_noise is not None:
|
||||
# raise ValueError("noise_offset and perlin_noise cannot be enabled at the same time / noise_offsetとperlin_noiseは同時に有効にできません")
|
||||
# if args.perlin_noise is not None and args.multires_noise_iterations is not None:
|
||||
# raise ValueError(
|
||||
# "perlin_noise and multires_noise_iterations cannot be enabled at the same time / perlin_noiseとmultires_noise_iterationsを同時に有効にできません"
|
||||
# )
|
||||
|
||||
if args.adaptive_noise_scale is not None and args.noise_offset is None:
|
||||
raise ValueError(
|
||||
"adaptive_noise_scale requires noise_offset / adaptive_noise_scaleを使用するにはnoise_offsetが必要です"
|
||||
)
|
||||
raise ValueError("adaptive_noise_scale requires noise_offset / adaptive_noise_scaleを使用するにはnoise_offsetが必要です")
|
||||
|
||||
|
||||
def add_dataset_arguments(
|
||||
@@ -2918,11 +2930,11 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu"):
|
||||
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
|
||||
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
|
||||
if load_stable_diffusion_format:
|
||||
print("load StableDiffusion checkpoint")
|
||||
print(f"load StableDiffusion checkpoint: {name_or_path}")
|
||||
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path, device)
|
||||
else:
|
||||
# Diffusers model is loaded to CPU
|
||||
print("load Diffusers pretrained models")
|
||||
print(f"load Diffusers pretrained models: {name_or_path}")
|
||||
try:
|
||||
pipe = StableDiffusionPipeline.from_pretrained(name_or_path, tokenizer=None, safety_checker=None)
|
||||
except EnvironmentError as ex:
|
||||
@@ -3291,8 +3303,21 @@ def sample_images(
|
||||
vae.to(device)
|
||||
|
||||
# read prompts
|
||||
with open(args.sample_prompts, "rt", encoding="utf-8") as f:
|
||||
prompts = f.readlines()
|
||||
|
||||
# with open(args.sample_prompts, "rt", encoding="utf-8") as f:
|
||||
# prompts = f.readlines()
|
||||
|
||||
if args.sample_prompts.endswith(".txt"):
|
||||
with open(args.sample_prompts, "r") as f:
|
||||
lines = f.readlines()
|
||||
prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
|
||||
elif args.sample_prompts.endswith(".toml"):
|
||||
with open(args.sample_prompts, "r") as f:
|
||||
data = toml.load(f)
|
||||
prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]]
|
||||
elif args.sample_prompts.endswith(".json"):
|
||||
with open(args.sample_prompts, "r") as f:
|
||||
prompts = json.load(f)
|
||||
|
||||
# schedulerを用意する
|
||||
sched_init_args = {}
|
||||
@@ -3362,53 +3387,63 @@ def sample_images(
|
||||
for i, prompt in enumerate(prompts):
|
||||
if not accelerator.is_main_process:
|
||||
continue
|
||||
prompt = prompt.strip()
|
||||
if len(prompt) == 0 or prompt[0] == "#":
|
||||
continue
|
||||
|
||||
# subset of gen_img_diffusers
|
||||
prompt_args = prompt.split(" --")
|
||||
prompt = prompt_args[0]
|
||||
negative_prompt = None
|
||||
sample_steps = 30
|
||||
width = height = 512
|
||||
scale = 7.5
|
||||
seed = None
|
||||
for parg in prompt_args:
|
||||
try:
|
||||
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
|
||||
if m:
|
||||
width = int(m.group(1))
|
||||
continue
|
||||
if isinstance(prompt, dict):
|
||||
negative_prompt = prompt.get("negative_prompt")
|
||||
sample_steps = prompt.get("sample_steps", 30)
|
||||
width = prompt.get("width", 512)
|
||||
height = prompt.get("height", 512)
|
||||
scale = prompt.get("scale", 7.5)
|
||||
seed = prompt.get("seed")
|
||||
prompt = prompt.get("prompt")
|
||||
else:
|
||||
# prompt = prompt.strip()
|
||||
# if len(prompt) == 0 or prompt[0] == "#":
|
||||
# continue
|
||||
|
||||
m = re.match(r"h (\d+)", parg, re.IGNORECASE)
|
||||
if m:
|
||||
height = int(m.group(1))
|
||||
continue
|
||||
# subset of gen_img_diffusers
|
||||
prompt_args = prompt.split(" --")
|
||||
prompt = prompt_args[0]
|
||||
negative_prompt = None
|
||||
sample_steps = 30
|
||||
width = height = 512
|
||||
scale = 7.5
|
||||
seed = None
|
||||
for parg in prompt_args:
|
||||
try:
|
||||
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
|
||||
if m:
|
||||
width = int(m.group(1))
|
||||
continue
|
||||
|
||||
m = re.match(r"d (\d+)", parg, re.IGNORECASE)
|
||||
if m:
|
||||
seed = int(m.group(1))
|
||||
continue
|
||||
m = re.match(r"h (\d+)", parg, re.IGNORECASE)
|
||||
if m:
|
||||
height = int(m.group(1))
|
||||
continue
|
||||
|
||||
m = re.match(r"s (\d+)", parg, re.IGNORECASE)
|
||||
if m: # steps
|
||||
sample_steps = max(1, min(1000, int(m.group(1))))
|
||||
continue
|
||||
m = re.match(r"d (\d+)", parg, re.IGNORECASE)
|
||||
if m:
|
||||
seed = int(m.group(1))
|
||||
continue
|
||||
|
||||
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # scale
|
||||
scale = float(m.group(1))
|
||||
continue
|
||||
m = re.match(r"s (\d+)", parg, re.IGNORECASE)
|
||||
if m: # steps
|
||||
sample_steps = max(1, min(1000, int(m.group(1))))
|
||||
continue
|
||||
|
||||
m = re.match(r"n (.+)", parg, re.IGNORECASE)
|
||||
if m: # negative prompt
|
||||
negative_prompt = m.group(1)
|
||||
continue
|
||||
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # scale
|
||||
scale = float(m.group(1))
|
||||
continue
|
||||
|
||||
except ValueError as ex:
|
||||
print(f"Exception in parsing / 解析エラー: {parg}")
|
||||
print(ex)
|
||||
m = re.match(r"n (.+)", parg, re.IGNORECASE)
|
||||
if m: # negative prompt
|
||||
negative_prompt = m.group(1)
|
||||
continue
|
||||
|
||||
except ValueError as ex:
|
||||
print(f"Exception in parsing / 解析エラー: {parg}")
|
||||
print(ex)
|
||||
|
||||
if seed is not None:
|
||||
torch.manual_seed(seed)
|
||||
|
||||
11
train_db.py
11
train_db.py
@@ -23,7 +23,14 @@ from library.config_util import (
|
||||
BlueprintGenerator,
|
||||
)
|
||||
import library.custom_train_functions as custom_train_functions
|
||||
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like, apply_noise_offset
|
||||
from library.custom_train_functions import (
|
||||
apply_snr_weight,
|
||||
get_weighted_text_embeddings,
|
||||
pyramid_noise_like,
|
||||
apply_noise_offset,
|
||||
)
|
||||
|
||||
# perlin_noise,
|
||||
|
||||
|
||||
def train(args):
|
||||
@@ -274,6 +281,8 @@ def train(args):
|
||||
noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale)
|
||||
elif args.multires_noise_iterations:
|
||||
noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount)
|
||||
# elif args.perlin_noise:
|
||||
# noise = perlin_noise(noise, latents.device, args.perlin_noise) # only shape of noise is used currently
|
||||
|
||||
# Get the text embedding for conditioning
|
||||
with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):
|
||||
|
||||
@@ -347,6 +347,7 @@ def train(args):
|
||||
"ss_noise_offset": args.noise_offset,
|
||||
"ss_multires_noise_iterations": args.multires_noise_iterations,
|
||||
"ss_multires_noise_discount": args.multires_noise_discount,
|
||||
"ss_adaptive_noise_scale": args.adaptive_noise_scale,
|
||||
"ss_training_comment": args.training_comment, # will not be updated after training
|
||||
"ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
|
||||
"ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""),
|
||||
|
||||
Reference in New Issue
Block a user