From c8d209d36c71c28416cb3a5eec43a00b8e129f99 Mon Sep 17 00:00:00 2001 From: ddPn08 Date: Thu, 1 Jun 2023 10:29:48 +0900 Subject: [PATCH 1/9] update diffusers to 1.16 | train_network --- library/attention_processors.py | 227 ++++++++++++++++++++++++++++++++ library/hypernetwork.py | 223 +++++++++++++++++++++++++++++++ library/lpw_stable_diffusion.py | 2 +- library/train_util.py | 205 +--------------------------- networks/lora.py | 2 +- requirements.txt | 6 +- train_network.py | 45 +++---- 7 files changed, 482 insertions(+), 228 deletions(-) create mode 100644 library/attention_processors.py create mode 100644 library/hypernetwork.py diff --git a/library/attention_processors.py b/library/attention_processors.py new file mode 100644 index 00000000..310c2cb1 --- /dev/null +++ b/library/attention_processors.py @@ -0,0 +1,227 @@ +import math +from typing import Any +from einops import rearrange +import torch +from diffusers.models.attention_processor import Attention + + +# flash attention forwards and backwards + +# https://arxiv.org/abs/2205.14135 + +EPSILON = 1e-6 + + +class FlashAttentionFunction(torch.autograd.function.Function): + @staticmethod + @torch.no_grad() + def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): + """Algorithm 2 in the paper""" + + device = q.device + dtype = q.dtype + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + + o = torch.zeros_like(q) + all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device) + all_row_maxes = torch.full( + (*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device + ) + + scale = q.shape[-1] ** -0.5 + + if mask is None: + mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) + else: + mask = rearrange(mask, "b n -> b 1 1 n") + mask = mask.split(q_bucket_size, dim=-1) + + row_splits = zip( + q.split(q_bucket_size, dim=-2), + o.split(q_bucket_size, dim=-2), + mask, + all_row_sums.split(q_bucket_size, dim=-2), + all_row_maxes.split(q_bucket_size, dim=-2), + ) + + for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff + + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + ) + + for k_ind, (kc, vc) in enumerate(col_splits): + k_start_index = k_ind * k_bucket_size + + attn_weights = ( + torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale + ) + + if row_mask is not None: + attn_weights.masked_fill_(~row_mask, max_neg_value) + + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones( + (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device + ).triu(q_start_index - k_start_index + 1) + attn_weights.masked_fill_(causal_mask, max_neg_value) + + block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) + attn_weights -= block_row_maxes + exp_weights = torch.exp(attn_weights) + + if row_mask is not None: + exp_weights.masked_fill_(~row_mask, 0.0) + + block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp( + min=EPSILON + ) + + new_row_maxes = torch.maximum(block_row_maxes, row_maxes) + + exp_values = torch.einsum( + "... i j, ... j d -> ... 
i d", exp_weights, vc + ) + + exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) + exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) + + new_row_sums = ( + exp_row_max_diff * row_sums + + exp_block_row_max_diff * block_row_sums + ) + + oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_( + (exp_block_row_max_diff / new_row_sums) * exp_values + ) + + row_maxes.copy_(new_row_maxes) + row_sums.copy_(new_row_sums) + + ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) + ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) + + return o + + @staticmethod + @torch.no_grad() + def backward(ctx, do): + """Algorithm 4 in the paper""" + + causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args + q, k, v, o, l, m = ctx.saved_tensors + + device = q.device + + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + + dq = torch.zeros_like(q) + dk = torch.zeros_like(k) + dv = torch.zeros_like(v) + + row_splits = zip( + q.split(q_bucket_size, dim=-2), + o.split(q_bucket_size, dim=-2), + do.split(q_bucket_size, dim=-2), + mask, + l.split(q_bucket_size, dim=-2), + m.split(q_bucket_size, dim=-2), + dq.split(q_bucket_size, dim=-2), + ) + + for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff + + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + dk.split(k_bucket_size, dim=-2), + dv.split(k_bucket_size, dim=-2), + ) + + for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): + k_start_index = k_ind * k_bucket_size + + attn_weights = ( + torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale + ) + + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones( + (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device + ).triu(q_start_index - k_start_index + 1) + attn_weights.masked_fill_(causal_mask, max_neg_value) + + exp_attn_weights = torch.exp(attn_weights - mc) + + if row_mask is not None: + exp_attn_weights.masked_fill_(~row_mask, 0.0) + + p = exp_attn_weights / lc + + dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc) + dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc) + + D = (doc * oc).sum(dim=-1, keepdims=True) + ds = p * scale * (dp - D) + + dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc) + dk_chunk = torch.einsum("... i j, ... i d -> ... 
j d", ds, qc) + + dqc.add_(dq_chunk) + dkc.add_(dk_chunk) + dvc.add_(dv_chunk) + + return dq, dk, dv, None, None, None, None + + +class FlashAttnProcessor: + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + ) -> Any: + q_bucket_size = 512 + k_bucket_size = 1024 + + h = attn.heads + q = attn.to_q(hidden_states) + + encoder_hidden_states = ( + encoder_hidden_states + if encoder_hidden_states is not None + else hidden_states + ) + encoder_hidden_states = encoder_hidden_states.to(hidden_states.dtype) + + if hasattr(attn, "hypernetwork") and attn.hypernetwork is not None: + context_k, context_v = attn.hypernetwork.forward( + hidden_states, encoder_hidden_states + ) + context_k = context_k.to(hidden_states.dtype) + context_v = context_v.to(hidden_states.dtype) + else: + context_k = encoder_hidden_states + context_v = encoder_hidden_states + + k = attn.to_k(context_k) + v = attn.to_v(context_v) + del encoder_hidden_states, hidden_states + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) + + out = FlashAttentionFunction.apply( + q, k, v, attention_mask, False, q_bucket_size, k_bucket_size + ) + + out = rearrange(out, "b h n d -> b n (h d)") + + out = attn.to_out[0](out) + out = attn.to_out[1](out) + return out diff --git a/library/hypernetwork.py b/library/hypernetwork.py new file mode 100644 index 00000000..fbd3fb24 --- /dev/null +++ b/library/hypernetwork.py @@ -0,0 +1,223 @@ +import torch +import torch.nn.functional as F +from diffusers.models.attention_processor import ( + Attention, + AttnProcessor2_0, + SlicedAttnProcessor, + XFormersAttnProcessor +) + +try: + import xformers.ops +except: + xformers = None + + +loaded_networks = [] + + +def apply_single_hypernetwork( + hypernetwork, hidden_states, encoder_hidden_states +): + context_k, context_v = hypernetwork.forward(hidden_states, encoder_hidden_states) + return context_k, context_v + + +def apply_hypernetworks(context_k, context_v, layer=None): + if len(loaded_networks) == 0: + return context_v, context_v + for hypernetwork in loaded_networks: + context_k, context_v = hypernetwork.forward(context_k, context_v) + + context_k = context_k.to(dtype=context_k.dtype) + context_v = context_v.to(dtype=context_k.dtype) + + return context_k, context_v + + + +def xformers_forward( + self: XFormersAttnProcessor, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + attention_mask: torch.Tensor = None, +): + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states) + + key = attn.to_k(context_k) + value = attn.to_v(context_v) + + query = attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + hidden_states = xformers.ops.memory_efficient_attention( + query, + key, + value, + attn_bias=attention_mask, + op=self.attention_op, + scale=attn.scale, + ) + hidden_states = hidden_states.to(query.dtype) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # 
linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +def sliced_attn_forward( + self: SlicedAttnProcessor, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + attention_mask: torch.Tensor = None, +): + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + + query = attn.to_q(hidden_states) + dim = query.shape[-1] + query = attn.head_to_batch_dim(query) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states) + + key = attn.to_k(context_k) + value = attn.to_v(context_v) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + batch_size_attention, query_tokens, _ = query.shape + hidden_states = torch.zeros( + (batch_size_attention, query_tokens, dim // attn.heads), + device=query.device, + dtype=query.dtype, + ) + + for i in range(batch_size_attention // self.slice_size): + start_idx = i * self.slice_size + end_idx = (i + 1) * self.slice_size + + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + attn_mask_slice = ( + attention_mask[start_idx:end_idx] if attention_mask is not None else None + ) + + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +def v2_0_forward( + self: AttnProcessor2_0, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, +): + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + inner_dim = hidden_states.shape[-1] + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view( + batch_size, attn.heads, -1, attention_mask.shape[-1] + ) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states) + + key = attn.to_k(context_k) + value = attn.to_v(context_v) + + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + 
hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +def replace_attentions_for_hypernetwork(): + import diffusers.models.attention_processor + + diffusers.models.attention_processor.XFormersAttnProcessor.__call__ = ( + xformers_forward + ) + diffusers.models.attention_processor.SlicedAttnProcessor.__call__ = ( + sliced_attn_forward + ) + diffusers.models.attention_processor.AttnProcessor2_0.__call__ = v2_0_forward diff --git a/library/lpw_stable_diffusion.py b/library/lpw_stable_diffusion.py index 3e04b887..84e1ab15 100644 --- a/library/lpw_stable_diffusion.py +++ b/library/lpw_stable_diffusion.py @@ -464,10 +464,10 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline): tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: SchedulerMixin, - clip_skip: int, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, requires_safety_checker: bool = True, + clip_skip: int = 1, ): super().__init__( vae=vae, diff --git a/library/train_util.py b/library/train_util.py index 46c5c3b2..008ccd64 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -63,6 +63,8 @@ import safetensors.torch from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline import library.model_util as model_util import library.huggingface_util as huggingface_util +from library.attention_processors import FlashAttnProcessor +from library.hypernetwork import replace_attentions_for_hypernetwork # Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う TOKENIZER_PATH = "openai/clip-vit-large-patch14" @@ -1630,209 +1632,14 @@ def get_git_revision_hash() -> str: return "(unknown)" -# flash attention forwards and backwards - -# https://arxiv.org/abs/2205.14135 - - -class FlashAttentionFunction(torch.autograd.function.Function): - @staticmethod - @torch.no_grad() - def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): - """Algorithm 2 in the paper""" - - device = q.device - dtype = q.dtype - max_neg_value = -torch.finfo(q.dtype).max - qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) - - o = torch.zeros_like(q) - all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device) - all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device) - - scale = q.shape[-1] ** -0.5 - - if not exists(mask): - mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) - else: - mask = rearrange(mask, "b n -> b 1 1 n") - mask = mask.split(q_bucket_size, dim=-1) - - row_splits = zip( - q.split(q_bucket_size, dim=-2), - o.split(q_bucket_size, dim=-2), - mask, - all_row_sums.split(q_bucket_size, dim=-2), - all_row_maxes.split(q_bucket_size, dim=-2), - ) - - for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): - q_start_index = ind * q_bucket_size - qk_len_diff - - col_splits = zip( - k.split(k_bucket_size, dim=-2), - v.split(k_bucket_size, dim=-2), - ) - - for k_ind, (kc, vc) in enumerate(col_splits): - k_start_index = k_ind * k_bucket_size - - attn_weights = einsum("... i d, ... j d -> ... 
i j", qc, kc) * scale - - if exists(row_mask): - attn_weights.masked_fill_(~row_mask, max_neg_value) - - if causal and q_start_index < (k_start_index + k_bucket_size - 1): - causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu( - q_start_index - k_start_index + 1 - ) - attn_weights.masked_fill_(causal_mask, max_neg_value) - - block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) - attn_weights -= block_row_maxes - exp_weights = torch.exp(attn_weights) - - if exists(row_mask): - exp_weights.masked_fill_(~row_mask, 0.0) - - block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON) - - new_row_maxes = torch.maximum(block_row_maxes, row_maxes) - - exp_values = einsum("... i j, ... j d -> ... i d", exp_weights, vc) - - exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) - exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) - - new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums - - oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values) - - row_maxes.copy_(new_row_maxes) - row_sums.copy_(new_row_sums) - - ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) - ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) - - return o - - @staticmethod - @torch.no_grad() - def backward(ctx, do): - """Algorithm 4 in the paper""" - - causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args - q, k, v, o, l, m = ctx.saved_tensors - - device = q.device - - max_neg_value = -torch.finfo(q.dtype).max - qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) - - dq = torch.zeros_like(q) - dk = torch.zeros_like(k) - dv = torch.zeros_like(v) - - row_splits = zip( - q.split(q_bucket_size, dim=-2), - o.split(q_bucket_size, dim=-2), - do.split(q_bucket_size, dim=-2), - mask, - l.split(q_bucket_size, dim=-2), - m.split(q_bucket_size, dim=-2), - dq.split(q_bucket_size, dim=-2), - ) - - for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): - q_start_index = ind * q_bucket_size - qk_len_diff - - col_splits = zip( - k.split(k_bucket_size, dim=-2), - v.split(k_bucket_size, dim=-2), - dk.split(k_bucket_size, dim=-2), - dv.split(k_bucket_size, dim=-2), - ) - - for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): - k_start_index = k_ind * k_bucket_size - - attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale - - if causal and q_start_index < (k_start_index + k_bucket_size - 1): - causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu( - q_start_index - k_start_index + 1 - ) - attn_weights.masked_fill_(causal_mask, max_neg_value) - - exp_attn_weights = torch.exp(attn_weights - mc) - - if exists(row_mask): - exp_attn_weights.masked_fill_(~row_mask, 0.0) - - p = exp_attn_weights / lc - - dv_chunk = einsum("... i j, ... i d -> ... j d", p, doc) - dp = einsum("... i d, ... j d -> ... i j", doc, vc) - - D = (doc * oc).sum(dim=-1, keepdims=True) - ds = p * scale * (dp - D) - - dq_chunk = einsum("... i j, ... j d -> ... i d", ds, kc) - dk_chunk = einsum("... i j, ... i d -> ... 
j d", ds, qc) - - dqc.add_(dq_chunk) - dkc.add_(dk_chunk) - dvc.add_(dv_chunk) - - return dq, dk, dv, None, None, None, None - def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers): + replace_attentions_for_hypernetwork() # unet is not used currently, but it is here for future use if mem_eff_attn: - replace_unet_cross_attn_to_memory_efficient() + unet.set_attn_processor(FlashAttnProcessor()) elif xformers: - replace_unet_cross_attn_to_xformers() - - -def replace_unet_cross_attn_to_memory_efficient(): - print("CrossAttention.forward has been replaced to FlashAttention (not xformers)") - flash_func = FlashAttentionFunction - - def forward_flash_attn(self, x, context=None, mask=None): - q_bucket_size = 512 - k_bucket_size = 1024 - - h = self.heads - q = self.to_q(x) - - context = context if context is not None else x - context = context.to(x.dtype) - - if hasattr(self, "hypernetwork") and self.hypernetwork is not None: - context_k, context_v = self.hypernetwork.forward(x, context) - context_k = context_k.to(x.dtype) - context_v = context_v.to(x.dtype) - else: - context_k = context - context_v = context - - k = self.to_k(context_k) - v = self.to_v(context_v) - del context, x - - q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) - - out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size) - - out = rearrange(out, "b h n d -> b n (h d)") - - # diffusers 0.7.0~ わざわざ変えるなよ (;´Д`) - out = self.to_out[0](out) - out = self.to_out[1](out) - return out - - diffusers.models.attention.CrossAttention.forward = forward_flash_attn + unet.enable_xformers_memory_efficient_attention() def replace_unet_cross_attn_to_xformers(): @@ -3458,10 +3265,10 @@ def sample_images( unet=unet, tokenizer=tokenizer, scheduler=scheduler, - clip_skip=args.clip_skip, safety_checker=None, feature_extractor=None, requires_safety_checker=False, + clip_skip=args.clip_skip, ) pipeline.to(device) diff --git a/networks/lora.py b/networks/lora.py index 19fbbbdb..3a475e25 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -665,7 +665,7 @@ class LoRANetwork(torch.nn.Module): NUM_OF_BLOCKS = 12 # フルモデル相当でのup,downの層の数 # is it possible to apply conv_in and conv_out? 
-> yes, newer LoCon supports it (^^;) - UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"] + UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"] UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] LORA_PREFIX_UNET = "lora_unet" diff --git a/requirements.txt b/requirements.txt index 801cf321..96da36a0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,10 @@ -accelerate==0.15.0 -transformers==4.26.0 +accelerate==0.19.0 +transformers==4.29.2 +diffusers[torch]==0.16.1 ftfy==6.1.1 albumentations==1.3.0 opencv-python==4.7.0.68 einops==0.6.0 -diffusers[torch]==0.10.2 pytorch-lightning==1.9.0 bitsandbytes==0.35.0 tensorboard==2.10.1 diff --git a/train_network.py b/train_network.py index cd90b0a2..109d1ff2 100644 --- a/train_network.py +++ b/train_network.py @@ -6,7 +6,6 @@ import os import random import time import json -import toml from multiprocessing import Value from tqdm import tqdm @@ -165,7 +164,7 @@ def train(args): import sys sys.path.append(os.path.dirname(__file__)) - print("import network module:", args.network_module) + accelerator.print("import network module:", args.network_module) network_module = importlib.import_module(args.network_module) if args.base_weights is not None: @@ -176,14 +175,15 @@ def train(args): else: multiplier = args.base_weights_multiplier[i] - print(f"merging module: {weight_path} with multiplier {multiplier}") + accelerator.print(f"merging module: {weight_path} with multiplier {multiplier}") module, weights_sd = network_module.create_network_from_weights( multiplier, weight_path, vae, text_encoder, unet, for_inference=True ) module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu") - print(f"all weights merged: {', '.join(args.base_weights)}") + accelerator.print(f"all weights merged: {', '.join(args.base_weights)}") + # 学習を準備する if cache_latents: vae.to(accelerator.device, dtype=weight_dtype) @@ -225,7 +225,7 @@ def train(args): if args.network_weights is not None: info = network.load_weights(args.network_weights) - print(f"loaded network weights from {args.network_weights}: {info}") + accelerator.print(f"load network weights from {args.network_weights}: {info}") if args.gradient_checkpointing: unet.enable_gradient_checkpointing() @@ -233,13 +233,13 @@ def train(args): network.enable_gradient_checkpointing() # may have no effect # 学習に必要なクラスを準備する - print("preparing optimizer, data loader etc.") + accelerator.print("prepare optimizer, data loader etc.") # 後方互換性を確保するよ try: trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate) except TypeError: - print( + accelerator.print( "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)" ) trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr) @@ -264,8 +264,7 @@ def train(args): args.max_train_steps = args.max_train_epochs * math.ceil( len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps ) - if is_main_process: - print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + accelerator.print(f"override steps. 
steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) @@ -278,7 +277,7 @@ def train(args): assert ( args.mixed_precision == "fp16" ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" - print("enabling full fp16 training.") + accelerator.print("enable full fp16 training.") network.to(weight_dtype) # acceleratorがなんかよろしくやってくれるらしい @@ -338,16 +337,15 @@ def train(args): # TODO: find a way to handle total batch size when there are multiple datasets total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - if is_main_process: - print("running training / 学習開始") - print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") - print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") - print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") - print(f" num epochs / epoch数: {num_train_epochs}") - print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") - # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") - print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") - print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + accelerator.print("running training / 学習開始") + accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") + # accelerator.print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") # TODO refactor metadata creation and move to util metadata = { @@ -572,7 +570,7 @@ def train(args): os.makedirs(args.output_dir, exist_ok=True) ckpt_file = os.path.join(args.output_dir, ckpt_name) - print(f"\nsaving checkpoint: {ckpt_file}") + accelerator.print(f"\nsaving checkpoint: {ckpt_file}") metadata["ss_training_finished_at"] = str(time.time()) metadata["ss_steps"] = str(steps) metadata["ss_epoch"] = str(epoch_no) @@ -584,13 +582,12 @@ def train(args): def remove_model(old_ckpt_name): old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) if os.path.exists(old_ckpt_file): - print(f"removing old checkpoint: {old_ckpt_file}") + accelerator.print(f"removing old checkpoint: {old_ckpt_file}") os.remove(old_ckpt_file) # training loop for epoch in range(num_train_epochs): - if is_main_process: - print(f"\nepoch {epoch+1}/{num_train_epochs}") + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 metadata["ss_epoch"] = str(epoch + 1) From 1f1cae6c5a1c55af9578ddfeda7dcfb0d1e1c2f3 Mon Sep 17 00:00:00 2001 From: ddPn08 Date: Thu, 1 Jun 2023 10:32:34 +0900 Subject: [PATCH 2/9] make the device of `snr_weight` the same as loss --- library/custom_train_functions.py | 2 +- 1 
file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index f32f050e..fa24f9fa 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -14,7 +14,7 @@ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma): all_snr = (alpha / sigma) ** 2 snr = torch.stack([all_snr[t] for t in timesteps]) gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr) - snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float() # from paper + snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device) # from paper loss = loss * snr_weight return loss From 23c4e5cb016a37840919573774838809985ff25a Mon Sep 17 00:00:00 2001 From: ddPn08 Date: Thu, 1 Jun 2023 10:37:23 +0900 Subject: [PATCH 3/9] update diffusers to 1.16 | train_textual_inversion --- train_textual_inversion.py | 59 ++++++++++++++++++-------------------- 1 file changed, 28 insertions(+), 31 deletions(-) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index b73027de..3d028442 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -1,15 +1,12 @@ -import importlib import argparse import gc import math import os -import toml from multiprocessing import Value from tqdm import tqdm import torch from accelerate.utils import set_seed -import diffusers from diffusers import DDPMScheduler import library.train_util as train_util @@ -104,7 +101,7 @@ def train(args): if args.init_word is not None: init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False) if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token: - print( + accelerator.print( f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}" ) else: @@ -118,7 +115,7 @@ def train(args): ), f"tokenizer has same word to token string. 
please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}" token_ids = tokenizer.convert_tokens_to_ids(token_strings) - print(f"tokens are added: {token_ids}") + accelerator.print(f"tokens are added: {token_ids}") assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered" assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}" @@ -130,7 +127,7 @@ def train(args): if init_token_ids is not None: for i, token_id in enumerate(token_ids): token_embeds[token_id] = token_embeds[init_token_ids[i % len(init_token_ids)]] - # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) + # accelerator.print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) # load weights if args.weights is not None: @@ -138,22 +135,22 @@ def train(args): assert len(token_ids) == len( embeddings ), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}" - # print(token_ids, embeddings.size()) + # accelerator.print(token_ids, embeddings.size()) for token_id, embedding in zip(token_ids, embeddings): token_embeds[token_id] = embedding - # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) - print(f"weighs loaded") + # accelerator.print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) + accelerator.print(f"weighs loaded") - print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}") + accelerator.print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}") # データセットを準備する blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False)) if args.dataset_config is not None: - print(f"Load dataset config from {args.dataset_config}") + accelerator.print(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "reg_data_dir", "in_json"] if any(getattr(args, attr) is not None for attr in ignored): - print( + accelerator.print( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) @@ -161,14 +158,14 @@ def train(args): else: use_dreambooth_method = args.in_json is None if use_dreambooth_method: - print("Use DreamBooth method.") + accelerator.print("Use DreamBooth method.") user_config = { "datasets": [ {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)} ] } else: - print("Train with captions.") + accelerator.print("Train with captions.") user_config = { "datasets": [ { @@ -192,7 +189,7 @@ def train(args): # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装 if use_template: - print("use template for training captions. is object: {args.use_object_template}") + accelerator.print("use template for training captions. is object: {args.use_object_template}") templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small replace_to = " ".join(token_strings) captions = [] @@ -216,7 +213,7 @@ def train(args): train_util.debug_dataset(train_dataset_group, show_input_ids=True) return if len(train_dataset_group) == 0: - print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください") + accelerator.print("No data found. 
Please verify arguments / 画像がありません。引数指定を確認してください") return if cache_latents: @@ -246,7 +243,7 @@ def train(args): text_encoder.gradient_checkpointing_enable() # 学習に必要なクラスを準備する - print("prepare optimizer, data loader etc.") + accelerator.print("prepare optimizer, data loader etc.") trainable_params = text_encoder.get_input_embeddings().parameters() _, _, optimizer = train_util.get_optimizer(args, trainable_params) @@ -267,7 +264,7 @@ def train(args): args.max_train_steps = args.max_train_epochs * math.ceil( len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps ) - print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) @@ -284,7 +281,7 @@ def train(args): text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet) index_no_updates = torch.arange(len(tokenizer)) < token_ids[0] - # print(len(index_no_updates), torch.sum(index_no_updates)) + # accelerator.print(len(index_no_updates), torch.sum(index_no_updates)) orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone() # Freeze all parameters except for the token embeddings in text encoder @@ -322,15 +319,15 @@ def train(args): # 学習する total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - print("running training / 学習開始") - print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") - print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") - print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") - print(f" num epochs / epoch数: {num_train_epochs}") - print(f" batch size per device / バッチサイズ: {args.train_batch_size}") - print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") - print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") - print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + accelerator.print("running training / 学習開始") + accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print(f" batch size per device / バッチサイズ: {args.train_batch_size}") + accelerator.print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + accelerator.print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") global_step = 0 @@ -347,7 +344,7 @@ def train(args): os.makedirs(args.output_dir, exist_ok=True) ckpt_file = os.path.join(args.output_dir, ckpt_name) - print(f"\nsaving checkpoint: {ckpt_file}") + accelerator.print(f"\nsaving checkpoint: {ckpt_file}") save_weights(ckpt_file, embs, save_dtype) if args.huggingface_repo_id is not None: huggingface_util.upload(args, ckpt_file, "/" + 
ckpt_name, force_sync_upload=force_sync_upload) @@ -355,12 +352,12 @@ def train(args): def remove_model(old_ckpt_name): old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) if os.path.exists(old_ckpt_file): - print(f"removing old checkpoint: {old_ckpt_file}") + accelerator.print(f"removing old checkpoint: {old_ckpt_file}") os.remove(old_ckpt_file) # training loop for epoch in range(num_train_epochs): - print(f"\nepoch {epoch+1}/{num_train_epochs}") + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 text_encoder.train() From e743ee5d5cf7fa7e8b97ae9c22ffb04813140e70 Mon Sep 17 00:00:00 2001 From: ddPn08 Date: Thu, 1 Jun 2023 10:40:58 +0900 Subject: [PATCH 4/9] update diffusers to 1.16 | dylora --- networks/dylora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/networks/dylora.py b/networks/dylora.py index 90b509df..e5a55d19 100644 --- a/networks/dylora.py +++ b/networks/dylora.py @@ -239,7 +239,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh class DyLoRANetwork(torch.nn.Module): - UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"] + UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"] UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] LORA_PREFIX_UNET = "lora_unet" From 1214f35985b233788d835b7380b4c315f18d59bb Mon Sep 17 00:00:00 2001 From: ddPn08 Date: Thu, 1 Jun 2023 20:15:06 +0900 Subject: [PATCH 5/9] update diffusers to 1.16 | train_db --- train_db.py | 33 +++++++++++++++------------------ 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/train_db.py b/train_db.py index 7ec06354..2fd1c7c5 100644 --- a/train_db.py +++ b/train_db.py @@ -2,18 +2,15 @@ # XXX dropped option: fine_tune import gc -import time import argparse import itertools import math import os -import toml from multiprocessing import Value from tqdm import tqdm import torch from accelerate.utils import set_seed -import diffusers from diffusers import DDPMScheduler import library.train_util as train_util @@ -138,7 +135,7 @@ def train(args): unet.requires_grad_(True) # 念のため追加 text_encoder.requires_grad_(train_text_encoder) if not train_text_encoder: - print("Text Encoder is not trained.") + accelerator.print("Text Encoder is not trained.") if args.gradient_checkpointing: unet.enable_gradient_checkpointing() @@ -150,7 +147,7 @@ def train(args): vae.to(accelerator.device, dtype=weight_dtype) # 学習に必要なクラスを準備する - print("prepare optimizer, data loader etc.") + accelerator.print("prepare optimizer, data loader etc.") if train_text_encoder: trainable_params = itertools.chain(unet.parameters(), text_encoder.parameters()) else: @@ -175,7 +172,7 @@ def train(args): args.max_train_steps = args.max_train_epochs * math.ceil( len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps ) - print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + accelerator.print(f"override steps. 
steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) @@ -191,7 +188,7 @@ def train(args): assert ( args.mixed_precision == "fp16" ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" - print("enable full fp16 training.") + accelerator.print("enable full fp16 training.") unet.to(weight_dtype) text_encoder.to(weight_dtype) @@ -224,15 +221,15 @@ def train(args): # 学習する total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - print("running training / 学習開始") - print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") - print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") - print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") - print(f" num epochs / epoch数: {num_train_epochs}") - print(f" batch size per device / バッチサイズ: {args.train_batch_size}") - print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") - print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") - print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + accelerator.print("running training / 学習開始") + accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print(f" batch size per device / バッチサイズ: {args.train_batch_size}") + accelerator.print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + accelerator.print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") global_step = 0 @@ -247,7 +244,7 @@ def train(args): loss_list = [] loss_total = 0.0 for epoch in range(num_train_epochs): - print(f"\nepoch {epoch+1}/{num_train_epochs}") + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 # 指定したステップ数までText Encoderを学習する:epoch最初の状態 @@ -260,7 +257,7 @@ def train(args): current_step.value = global_step # 指定したステップ数でText Encoderの学習を止める if global_step == args.stop_text_encoder_training: - print(f"stop text encoder training at step {global_step}") + accelerator.print(f"stop text encoder training at step {global_step}") if not args.gradient_checkpointing: text_encoder.train(False) text_encoder.requires_grad_(False) From 4f8ce004772cf11053da921527637512ca47a21c Mon Sep 17 00:00:00 2001 From: ddPn08 Date: Thu, 1 Jun 2023 20:47:54 +0900 Subject: [PATCH 6/9] update diffusers to 1.16 | finetune --- fine_tune.py | 32 +++++++++++++++----------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 154d3be7..61c5d28e 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -5,13 +5,11 @@ import argparse import gc import math import os -import toml from multiprocessing import Value from tqdm import tqdm import torch from accelerate.utils import set_seed -import diffusers from diffusers 
import DDPMScheduler import library.train_util as train_util @@ -128,11 +126,11 @@ def train(args): # モデルに xformers とか memory efficient attention を組み込む if args.diffusers_xformers: - print("Use xformers by Diffusers") + accelerator.print("Use xformers by Diffusers") set_diffusers_xformers_flag(unet, True) else: # Windows版のxformersはfloatで学習できないのでxformersを使わない設定も可能にしておく必要がある - print("Disable Diffusers' xformers") + accelerator.print("Disable Diffusers' xformers") set_diffusers_xformers_flag(unet, False) train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) @@ -157,7 +155,7 @@ def train(args): training_models.append(unet) if args.train_text_encoder: - print("enable text encoder training") + accelerator.print("enable text encoder training") if args.gradient_checkpointing: text_encoder.gradient_checkpointing_enable() training_models.append(text_encoder) @@ -183,7 +181,7 @@ def train(args): params_to_optimize = params # 学習に必要なクラスを準備する - print("prepare optimizer, data loader etc.") + accelerator.print("prepare optimizer, data loader etc.") _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) # dataloaderを準備する @@ -203,7 +201,7 @@ def train(args): args.max_train_steps = args.max_train_epochs * math.ceil( len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps ) - print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) @@ -216,7 +214,7 @@ def train(args): assert ( args.mixed_precision == "fp16" ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" - print("enable full fp16 training.") + accelerator.print("enable full fp16 training.") unet.to(weight_dtype) text_encoder.to(weight_dtype) @@ -246,14 +244,14 @@ def train(args): # 学習する total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - print("running training / 学習開始") - print(f" num examples / サンプル数: {train_dataset_group.num_train_images}") - print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") - print(f" num epochs / epoch数: {num_train_epochs}") - print(f" batch size per device / バッチサイズ: {args.train_batch_size}") - print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") - print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") - print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + accelerator.print("running training / 学習開始") + accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print(f" batch size per device / バッチサイズ: {args.train_batch_size}") + accelerator.print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") 
global_step = 0 @@ -266,7 +264,7 @@ def train(args): accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name) for epoch in range(num_train_epochs): - print(f"\nepoch {epoch+1}/{num_train_epochs}") + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 for m in training_models: From 62d00b4520aaae6076474389c9f61db2c982c9e2 Mon Sep 17 00:00:00 2001 From: ddPn08 Date: Wed, 31 May 2023 14:13:15 +0900 Subject: [PATCH 7/9] add controlnet training --- library/config_util.py | 97 ++++++- library/model_util.py | 76 ++++++ library/train_util.py | 318 ++++++++++++++++++++++ train_controlnet.py | 594 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 1075 insertions(+), 10 deletions(-) create mode 100644 train_controlnet.py diff --git a/library/config_util.py b/library/config_util.py index 98b41751..ae17655c 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -33,8 +33,10 @@ from . import train_util from .train_util import ( DreamBoothSubset, FineTuningSubset, + ControlNetSubset, DreamBoothDataset, FineTuningDataset, + ControlNetDataset, DatasetGroup, ) @@ -70,6 +72,11 @@ class DreamBoothSubsetParams(BaseSubsetParams): class FineTuningSubsetParams(BaseSubsetParams): metadata_file: Optional[str] = None +@dataclass +class ControlNetSubsetParams(BaseSubsetParams): + conditioning_data_dir: str = None + caption_extension: str = ".caption" + @dataclass class BaseDatasetParams: tokenizer: CLIPTokenizer = None @@ -96,6 +103,15 @@ class FineTuningDatasetParams(BaseDatasetParams): bucket_reso_steps: int = 64 bucket_no_upscale: bool = False +@dataclass +class ControlNetDatasetParams(BaseDatasetParams): + batch_size: int = 1 + enable_bucket: bool = False + min_bucket_reso: int = 256 + max_bucket_reso: int = 1024 + bucket_reso_steps: int = 64 + bucket_no_upscale: bool = False + @dataclass class SubsetBlueprint: params: Union[DreamBoothSubsetParams, FineTuningSubsetParams] @@ -103,6 +119,7 @@ class SubsetBlueprint: @dataclass class DatasetBlueprint: is_dreambooth: bool + is_controlnet: bool params: Union[DreamBoothDatasetParams, FineTuningDatasetParams] subsets: Sequence[SubsetBlueprint] @@ -163,6 +180,13 @@ class ConfigSanitizer: Required("metadata_file"): str, "image_dir": str, } + CN_SUBSET_ASCENDABLE_SCHEMA = { + "caption_extension": str, + } + CN_SUBSET_DISTINCT_SCHEMA = { + Required("image_dir"): str, + Required("conditioning_data_dir"): str, + } # datasets schema DATASET_ASCENDABLE_SCHEMA = { @@ -192,8 +216,8 @@ class ConfigSanitizer: "dataset_repeats": "num_repeats", } - def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_dropout: bool) -> None: - assert support_dreambooth or support_finetuning, "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。" + def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None: + assert support_dreambooth or support_finetuning or support_controlnet, "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. 
/ DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。" self.db_subset_schema = self.__merge_dict( self.SUBSET_ASCENDABLE_SCHEMA, @@ -208,6 +232,13 @@ class ConfigSanitizer: self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, ) + self.cn_subset_schema = self.__merge_dict( + self.SUBSET_ASCENDABLE_SCHEMA, + self.CN_SUBSET_DISTINCT_SCHEMA, + self.CN_SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + ) + self.db_dataset_schema = self.__merge_dict( self.DATASET_ASCENDABLE_SCHEMA, self.SUBSET_ASCENDABLE_SCHEMA, @@ -223,13 +254,23 @@ class ConfigSanitizer: {"subsets": [self.ft_subset_schema]}, ) - if support_dreambooth and support_finetuning: + self.cn_dataset_schema = self.__merge_dict( + self.DATASET_ASCENDABLE_SCHEMA, + self.SUBSET_ASCENDABLE_SCHEMA, + self.CN_SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + {"subsets": [self.cn_subset_schema]}, + ) + + if support_dreambooth and support_finetuning and support_controlnet: def validate_flex_dataset(dataset_config: dict): subsets_config = dataset_config.get("subsets", []) + if all(["conditioning_data_dir" in subset for subset in subsets_config]): + return Schema(self.cn_dataset_schema)(dataset_config) # check dataset meets FT style # NOTE: all FT subsets should have "metadata_file" - if all(["metadata_file" in subset for subset in subsets_config]): + elif all(["metadata_file" in subset for subset in subsets_config]): return Schema(self.ft_dataset_schema)(dataset_config) # check dataset meets DB style # NOTE: all DB subsets should have no "metadata_file" @@ -241,13 +282,16 @@ class ConfigSanitizer: self.dataset_schema = validate_flex_dataset elif support_dreambooth: self.dataset_schema = self.db_dataset_schema - else: + elif support_finetuning: self.dataset_schema = self.ft_dataset_schema + elif support_controlnet: + self.dataset_schema = self.cn_dataset_schema self.general_schema = self.__merge_dict( self.DATASET_ASCENDABLE_SCHEMA, self.SUBSET_ASCENDABLE_SCHEMA, self.DB_SUBSET_ASCENDABLE_SCHEMA if support_dreambooth else {}, + self.CN_SUBSET_ASCENDABLE_SCHEMA if support_controlnet else {}, self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, ) @@ -318,7 +362,11 @@ class BlueprintGenerator: # NOTE: if subsets have no "metadata_file", these are DreamBooth datasets/subsets subsets = dataset_config.get("subsets", []) is_dreambooth = all(["metadata_file" not in subset for subset in subsets]) - if is_dreambooth: + is_controlnet = all(["conditioning_data_dir" in subset for subset in subsets]) + if is_controlnet: + subset_params_klass = ControlNetSubsetParams + dataset_params_klass = ControlNetDatasetParams + elif is_dreambooth: subset_params_klass = DreamBoothSubsetParams dataset_params_klass = DreamBoothDatasetParams else: @@ -333,7 +381,7 @@ class BlueprintGenerator: params = self.generate_params_by_fallbacks(dataset_params_klass, [dataset_config, general_config, argparse_config, runtime_params]) - dataset_blueprints.append(DatasetBlueprint(is_dreambooth, params, subset_blueprints)) + dataset_blueprints.append(DatasetBlueprint(is_dreambooth, is_controlnet, params, subset_blueprints)) dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints) @@ -361,10 +409,13 @@ class BlueprintGenerator: def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint): - datasets: List[Union[DreamBoothDataset, FineTuningDataset]] = [] + datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] for 
dataset_blueprint in dataset_group_blueprint.datasets: - if dataset_blueprint.is_dreambooth: + if dataset_blueprint.is_controlnet: + subset_klass = ControlNetSubset + dataset_klass = ControlNetDataset + elif dataset_blueprint.is_dreambooth: subset_klass = DreamBoothSubset dataset_klass = DreamBoothDataset else: @@ -379,6 +430,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu info = "" for i, dataset in enumerate(datasets): is_dreambooth = isinstance(dataset, DreamBoothDataset) + is_controlnet = isinstance(dataset, ControlNetDataset) info += dedent(f"""\ [Dataset {i}] batch_size: {dataset.batch_size} @@ -421,7 +473,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu class_tokens: {subset.class_tokens} caption_extension: {subset.caption_extension} \n"""), " ") - else: + elif not is_controlnet: info += indent(dedent(f"""\ metadata_file: {subset.metadata_file} \n"""), " ") @@ -479,6 +531,31 @@ def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] return subsets_config +def generate_controlnet_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, conditioning_data_dir: Optional[str] = None, caption_extension: str = ".txt"): + def generate(base_dir: Optional[str]): + if base_dir is None: + return [] + + base_dir: Path = Path(base_dir) + if not base_dir.is_dir(): + return [] + + subsets_config = [] + for subdir in base_dir.iterdir(): + if not subdir.is_dir(): + continue + + subset_config = {"image_dir": str(subdir), "conditioning_data_dir": conditioning_data_dir, "caption_extension": caption_extension, "num_repeats": 1} + subsets_config.append(subset_config) + + return subsets_config + + subsets_config = [] + subsets_config += generate(train_data_dir, False) + + return subsets_config + + def load_user_config(file: str) -> dict: file: Path = Path(file) if not file.is_file(): diff --git a/library/model_util.py b/library/model_util.py index 26f72235..bb168653 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -732,6 +732,82 @@ def convert_unet_state_dict_to_sd(v2, unet_state_dict): return new_state_dict +def convert_controlnet_state_dict_to_sd(controlnet_state_dict): + unet_conversion_map = [ + ("time_embed.0.weight", "time_embedding.linear_1.weight"), + ("time_embed.0.bias", "time_embedding.linear_1.bias"), + ("time_embed.2.weight", "time_embedding.linear_2.weight"), + ("time_embed.2.bias", "time_embedding.linear_2.bias"), + ("input_blocks.0.0.weight", "conv_in.weight"), + ("input_blocks.0.0.bias", "conv_in.bias"), + ("middle_block_out.0.weight", "controlnet_mid_block.weight"), + ("middle_block_out.0.bias", "controlnet_mid_block.bias"), + ] + + unet_conversion_map_resnet = [ + ("in_layers.0", "norm1"), + ("in_layers.2", "conv1"), + ("out_layers.0", "norm2"), + ("out_layers.3", "conv2"), + ("emb_layers.1", "time_emb_proj"), + ("skip_connection", "conv_shortcut"), + ] + + unet_conversion_map_layer = [] + for i in range(4): + for j in range(2): + hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." + sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." + unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) + + if i < 3: + hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." + sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." + unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) + + if i < 3: + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." + sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." 
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) + + hf_mid_atn_prefix = "mid_block.attentions.0." + sd_mid_atn_prefix = "middle_block.1." + unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) + + for j in range(2): + hf_mid_res_prefix = f"mid_block.resnets.{j}." + sd_mid_res_prefix = f"middle_block.{2*j}." + unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) + + controlnet_cond_embedding_names = ( + ["conv_in"] + [f"blocks.{i}" for i in range(6)] + ["conv_out"] + ) + for i, hf_prefix in enumerate(controlnet_cond_embedding_names): + hf_prefix = f"controlnet_cond_embedding.{hf_prefix}." + sd_prefix = f"input_hint_block.{i*2}." + unet_conversion_map_layer.append((sd_prefix, hf_prefix)) + + for i in range(12): + hf_prefix = f"controlnet_down_blocks.{i}." + sd_prefix = f"zero_convs.{i}.0." + unet_conversion_map_layer.append((sd_prefix, hf_prefix)) + + mapping = {k: k for k in controlnet_state_dict.keys()} + for sd_name, diffusers_name in unet_conversion_map: + mapping[diffusers_name] = sd_name + for k, v in mapping.items(): + if "resnets" in k: + for sd_part, diffusers_part in unet_conversion_map_resnet: + v = v.replace(diffusers_part, sd_part) + mapping[k] = v + for k, v in mapping.items(): + for sd_part, diffusers_part in unet_conversion_map_layer: + v = v.replace(diffusers_part, sd_part) + mapping[k] = v + new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()} + return new_state_dict + + # ================# # VAE Conversion # # ================# diff --git a/library/train_util.py b/library/train_util.py index 008ccd64..1921c2a4 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -403,6 +403,54 @@ class FineTuningSubset(BaseSubset): return self.metadata_file == other.metadata_file +class ControlNetSubset(BaseSubset): + def __init__( + self, + image_dir: str, + conditioning_data_dir: str, + caption_extension: str, + num_repeats, + shuffle_caption, + keep_tokens, + color_aug, + flip_aug, + face_crop_aug_range, + random_crop, + caption_dropout_rate, + caption_dropout_every_n_epochs, + caption_tag_dropout_rate, + token_warmup_min, + token_warmup_step, + ) -> None: + assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" + + super().__init__( + image_dir, + num_repeats, + shuffle_caption, + keep_tokens, + color_aug, + flip_aug, + face_crop_aug_range, + random_crop, + caption_dropout_rate, + caption_dropout_every_n_epochs, + caption_tag_dropout_rate, + token_warmup_min, + token_warmup_step, + ) + + self.conditioning_data_dir = conditioning_data_dir + self.caption_extension = caption_extension + if self.caption_extension and not self.caption_extension.startswith("."): + self.caption_extension = "." 
+ self.caption_extension + + def __eq__(self, other) -> bool: + if not isinstance(other, ControlNetSubset): + return NotImplemented + return self.image_dir == other.image_dir and self.conditioning_data_dir == other.conditioning_data_dir + + class BaseDataset(torch.utils.data.Dataset): def __init__( self, tokenizer: CLIPTokenizer, max_token_length: int, resolution: Optional[Tuple[int, int]], debug_dataset: bool @@ -1387,6 +1435,274 @@ class FineTuningDataset(BaseDataset): return npz_file_norm, npz_file_flip +class ControlNetDataset(BaseDataset): + def __init__( + self, + subsets: Sequence[ControlNetSubset], + batch_size: int, + tokenizer, + max_token_length, + resolution, + enable_bucket: bool, + min_bucket_reso: int, + max_bucket_reso: int, + bucket_reso_steps: int, + bucket_no_upscale: bool, + debug_dataset) -> None: + super().__init__(tokenizer, max_token_length, resolution, debug_dataset) + self.conditioning_image_data: Dict[str, ImageInfo] = {} + + assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です" + + self.batch_size = batch_size + self.size = min(self.width, self.height) # 短いほう + self.latents_cache = None + + self.num_reg_images = 0 + + self.enable_bucket = enable_bucket + if self.enable_bucket: + assert ( + min(resolution) >= min_bucket_reso + ), f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください" + assert ( + max(resolution) <= max_bucket_reso + ), f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください" + self.min_bucket_reso = min_bucket_reso + self.max_bucket_reso = max_bucket_reso + self.bucket_reso_steps = bucket_reso_steps + self.bucket_no_upscale = bucket_no_upscale + else: + self.min_bucket_reso = None + self.max_bucket_reso = None + self.bucket_reso_steps = None # この情報は使われない + self.bucket_no_upscale = False + + def read_caption(img_path, caption_extension): + # captionの候補ファイル名を作る + base_name = os.path.splitext(img_path)[0] + base_name_face_det = base_name + tokens = base_name.split("_") + if len(tokens) >= 5: + base_name_face_det = "_".join(tokens[:-4]) + cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension] + + caption = None + for cap_path in cap_paths: + if os.path.isfile(cap_path): + with open(cap_path, "rt", encoding="utf-8") as f: + try: + lines = f.readlines() + except UnicodeDecodeError as e: + print(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}") + raise e + assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}" + caption = lines[0].strip() + break + return caption + + def load_controlnet_dir(subset: ControlNetSubset): + if not os.path.isdir(subset.image_dir): + print(f"not directory: {subset.image_dir}") + return [], [] + if not os.path.isdir(subset.conditioning_data_dir): + print(f"not directory: {subset.conditioning_data_dir}") + return [], [] + + img_paths = glob_images(subset.image_dir, "*") + conditioning_img_paths = glob_images(subset.conditioning_data_dir, "*") + img_paths = sorted(img_paths) + conditioning_img_paths = sorted(conditioning_img_paths) + print(f"found directory {subset.image_dir} contains {len(img_paths)} image files") + print(f"found directory {subset.conditioning_data_dir} contains {len(conditioning_img_paths)} image files") + + img_basenames = [os.path.basename(img) for img in img_paths] + conditioning_img_basenames = [os.path.basename(img) for img in conditioning_img_paths] + missing_imgs 
= [] + extra_imgs = [] + + for img in img_basenames: + if img not in conditioning_img_basenames: + missing_imgs.append(img) + for img in conditioning_img_basenames: + if img not in img_basenames: + extra_imgs.append(img) + + assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}" + assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}" + + + # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う + captions = [] + missing_captions = [] + for img_path in img_paths: + cap_for_img = read_caption(img_path, subset.caption_extension) + if cap_for_img is None: + print(f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}") + captions.append("") + missing_captions.append(img_path) + else: + captions.append(cap_for_img) + + self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録 + + if missing_captions: + number_of_missing_captions = len(missing_captions) + number_of_missing_captions_to_show = 5 + remaining_missing_captions = number_of_missing_captions - number_of_missing_captions_to_show + + print( + f"No caption file found for {number_of_missing_captions} images. Training will continue without captions for these images. If class token exists, it will be used. / {number_of_missing_captions}枚の画像にキャプションファイルが見つかりませんでした。これらの画像についてはキャプションなしで学習を続行します。class tokenが存在する場合はそれを使います。" + ) + for i, missing_caption in enumerate(missing_captions): + if i >= number_of_missing_captions_to_show: + print(missing_caption + f"... and {remaining_missing_captions} more") + break + print(missing_caption) + return img_paths, conditioning_img_paths, captions + + print("prepare images.") + num_train_images = 0 + for subset in subsets: + if subset.num_repeats < 1: + print( + f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}" + ) + continue + + if subset in self.subsets: + print( + f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します" + ) + continue + + img_paths, conditioning_img_paths, captions = load_controlnet_dir(subset) + if len(img_paths) < 1: + print(f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します") + continue + + num_train_images += subset.num_repeats * len(img_paths) + + for img_path, cond_img_path, caption in zip(img_paths, conditioning_img_paths, captions): + info = ImageInfo(img_path, subset.num_repeats, caption, False, img_path) + setattr(info, "cond_img_path", cond_img_path) + self.register_image(info, subset) + + subset.img_count = len(img_paths) + self.subsets.append(subset) + + print(f"{num_train_images} train images with repeating.") + self.num_train_images = num_train_images + + self.conditioning_image_transforms = transforms.Compose( + [ + transforms.ToTensor(), + ] + ) + + def __getitem__(self, index): + bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index] + bucket_batch_size = self.buckets_indices[index].bucket_batch_size + image_index = self.buckets_indices[index].batch_index * bucket_batch_size + + loss_weights = [] + captions = [] + input_ids_list = [] + latents_list = [] + images = [] + conditioning_images = [] + + for image_key in bucket[image_index : image_index + bucket_batch_size]: + image_info = self.image_data[image_key] + subset = self.image_to_subset[image_key] + 
loss_weights.append(1.0) + + # image/latentsを処理する + if image_info.latents is not None: # cache_latents=Trueの場合 + latents = image_info.latents if not subset.flip_aug or random.random() < 0.5 else image_info.latents_flipped + image = None + elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合 + latents = self.load_latents_from_npz(image_info, subset.flip_aug and random.random() >= 0.5) + latents = torch.FloatTensor(latents) + image = None + else: + # 画像を読み込み、必要ならcropする + img = self.load_image(image_info.absolute_path) + im_h, im_w = img.shape[0:2] + + if self.enable_bucket: + img = self.trim_and_resize_if_required(subset, img, image_info.bucket_reso, image_info.resized_size) + else: + im_h, im_w = img.shape[0:2] + assert ( + im_h == self.height and im_w == self.width + ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" + + # augmentation + aug = self.aug_helper.get_augmentor(subset.color_aug, subset.flip_aug) + if aug is not None: + img = aug(image=img)["image"] + + latents = None + image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる + + images.append(image) + latents_list.append(latents) + + caption = self.process_caption(subset, image_info.caption) + if self.XTI_layers: + caption_layer = [] + for layer in self.XTI_layers: + token_strings_from = " ".join(self.token_strings) + token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings]) + caption_ = caption.replace(token_strings_from, token_strings_to) + caption_layer.append(caption_) + captions.append(caption_layer) + else: + captions.append(caption) + if not self.token_padding_disabled: # this option might be omitted in future + if self.XTI_layers: + token_caption = self.get_input_ids(caption_layer) + else: + token_caption = self.get_input_ids(caption) + input_ids_list.append(token_caption) + + assert hasattr(image_info, "cond_img_path"), f"conditioning image path is not found: {image_info.absolute_path}" + + cond_img = self.load_image(image_info.cond_img_path) + if self.enable_bucket: + cond_img = self.trim_and_resize_if_required(subset, cond_img, image_info.bucket_reso, image_info.resized_size) + cond_img = self.conditioning_image_transforms(cond_img) + conditioning_images.append(cond_img) + conditioning_images = torch.stack(conditioning_images) + + example = {} + example["loss_weights"] = torch.FloatTensor(loss_weights) + + if self.token_padding_disabled: + # padding=True means pad in the batch + example["input_ids"] = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids + else: + # batch processing seems to be good + example["input_ids"] = torch.stack(input_ids_list) + + if images[0] is not None: + images = torch.stack(images) + images = images.to(memory_format=torch.contiguous_format).float() + else: + images = None + example["images"] = images + + example["latents"] = torch.stack(latents_list) if latents_list[0] is not None else None + example["captions"] = captions + + if self.debug_dataset: + example["image_keys"] = bucket[image_index : image_index + self.batch_size] + + example["conditioning_images"] = conditioning_images.to(memory_format=torch.contiguous_format).float() + + return example + # behave as Dataset mock class DatasetGroup(torch.utils.data.ConcatDataset): def __init__(self, datasets: Sequence[Union[DreamBoothDataset, FineTuningDataset]]): @@ -1636,6 +1952,8 @@ def get_git_revision_hash() -> str: def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, 
xformers): replace_attentions_for_hypernetwork() # unet is not used currently, but it is here for future use + unet.enable_xformers_memory_efficient_attention() + return if mem_eff_attn: unet.set_attn_processor(FlashAttnProcessor()) elif xformers: diff --git a/train_controlnet.py b/train_controlnet.py new file mode 100644 index 00000000..7bcaf03a --- /dev/null +++ b/train_controlnet.py @@ -0,0 +1,594 @@ +import argparse +import gc +import math +import os +import random +import time +from multiprocessing import Value + +from tqdm import tqdm +import torch +from torch.nn.parallel import DistributedDataParallel as DDP +from accelerate.utils import set_seed +from diffusers import DDPMScheduler, ControlNetModel + +import library.model_util as model_util +import library.train_util as train_util +import library.config_util as config_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +import library.huggingface_util as huggingface_util +import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import ( + apply_snr_weight, + pyramid_noise_like, + apply_noise_offset, +) +from diffusers.pipelines.stable_diffusion.convert_from_ckpt import ( + download_controlnet_from_original_ckpt, +) + + +# TODO 他のスクリプトと共通化する +def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler): + logs = { + "loss/current": current_loss, + "loss/average": avr_loss, + "lr": lr_scheduler.get_last_lr()[0], + } + + if args.optimizer_type.lower().startswith("DAdapt".lower()): + logs["lr/d*lr"] = ( + lr_scheduler.optimizers[-1].param_groups[0]["d"] + * lr_scheduler.optimizers[-1].param_groups[0]["lr"] + ) + + return logs + + +def train(args): + session_id = random.randint(0, 2**32) + training_started_at = time.time() + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + + cache_latents = args.cache_latents + use_user_config = args.dataset_config is not None + + if args.seed is None: + args.seed = random.randint(0, 2**32) + set_seed(args.seed) + + tokenizer = train_util.load_tokenizer(args) + + # データセットを準備する + blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) + if use_user_config: + print(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "conditioning_data_dir"] + if any(getattr(args, attr) is not None for attr in ignored): + print( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + user_config = { + "datasets": [ + { + "subsets": config_util.generate_controlnet_subsets_config_by_subdirs( + args.train_data_dir, + args.conditioning_data_dir, + args.caption_extension, + ) + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + train_dataset_group = config_util.generate_dataset_group_by_blueprint( + blueprint.dataset_group + ) + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collater = ( + train_dataset_group if args.max_data_loader_n_workers == 0 else None + ) + collater = train_util.collater_class(current_epoch, current_step, ds_for_collater) + + if args.debug_dataset: + train_util.debug_dataset(train_dataset_group) + return + if len(train_dataset_group) == 0: + print( + "No data found. 
Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" + ) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + + # acceleratorを準備する + print("prepare accelerator") + accelerator, unwrap_model = train_util.prepare_accelerator(args) + is_main_process = accelerator.is_main_process + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + + # モデルを読み込む + text_encoder, vae, unet, _ = train_util.load_target_model( + args, weight_dtype, accelerator + ) + if args.controlnet_model_name_or_path: + if os.path.isfile(args.controlnet_model_name_or_path): + controlnet = download_controlnet_from_original_ckpt( + args.controlnet_model_name_or_path + ) + else: + controlnet = ControlNetModel.from_pretrained( + args.controlnet_model_name_or_path + ) + else: + controlnet = ControlNetModel.from_unet(unet) + + # モデルに xformers とか memory efficient attention を組み込む + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) + + # 学習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=weight_dtype) + vae.requires_grad_(False) + vae.eval() + with torch.no_grad(): + train_dataset_group.cache_latents( + vae, + args.vae_batch_size, + args.cache_latents_to_disk, + accelerator.is_main_process, + ) + vae.to("cpu") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + accelerator.wait_for_everyone() + + if args.gradient_checkpointing: + controlnet.enable_gradient_checkpointing() + + # 学習に必要なクラスを準備する + print("prepare optimizer, data loader etc.") + + trainable_params = controlnet.parameters() + + optimizer_name, optimizer_args, optimizer = train_util.get_optimizer( + args, trainable_params + ) + + # dataloaderを準備する + # DataLoaderのプロセス数:0はメインプロセスになる + n_workers = min( + args.max_data_loader_n_workers, os.cpu_count() - 1 + ) # cpu_count-1 ただし最大で指定された数まで + + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collater, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) + / accelerator.num_processes + / args.gradient_accumulation_steps + ) + if is_main_process: + print( + f"override steps. 
steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを用意する + lr_scheduler = train_util.get_scheduler_fix( + args, optimizer, accelerator.num_processes + ) + + # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + print("enable full fp16 training.") + controlnet.to(weight_dtype) + + # acceleratorがなんかよろしくやってくれるらしい + controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + controlnet, optimizer, train_dataloader, lr_scheduler + ) + + unet.requires_grad_(False) + text_encoder.requires_grad_(False) + unet.to(accelerator.device) + text_encoder.to(accelerator.device) + + # transform DDP after prepare + controlnet = controlnet.module if isinstance(controlnet, DDP) else controlnet + + controlnet.train() + + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=weight_dtype) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / args.gradient_accumulation_steps + ) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = ( + math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + ) + + # 学習する + # TODO: find a way to handle total batch size when there are multiple datasets + + if is_main_process: + print("running training / 学習開始") + print( + f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}" + ) + print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") + print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + print(f" num epochs / epoch数: {num_train_epochs}") + print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + print( + f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}" + ) + print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm( + range(args.max_train_steps), + smoothing=0, + disable=not accelerator.is_local_main_process, + desc="steps", + ) + global_step = 0 + + noise_scheduler = DDPMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + num_train_timesteps=1000, + clip_sample=False, + ) + if accelerator.is_main_process: + accelerator.init_trackers( + "controlnet_train" if args.log_tracker_name is None else args.log_tracker_name + ) + + loss_list = [] + loss_total = 0.0 + del train_dataset_group + + # function for saving/removing + def save_model(ckpt_name, model, steps, epoch_no, force_sync_upload=False): + os.makedirs(args.output_dir, exist_ok=True) + ckpt_file = os.path.join(args.output_dir, ckpt_name) + + print(f"\nsaving checkpoint: {ckpt_file}") + + state_dict = model_util.convert_controlnet_state_dict_to_sd(model.state_dict()) + + if save_dtype is not None: + for key in 
list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v + + if os.path.splitext(ckpt_file)[1] == ".safetensors": + from safetensors.torch import save_file + + save_file(state_dict, ckpt_file) + else: + torch.save(state_dict, ckpt_file) + + if args.huggingface_repo_id is not None: + huggingface_util.upload( + args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload + ) + + def remove_model(old_ckpt_name): + old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) + if os.path.exists(old_ckpt_file): + print(f"removing old checkpoint: {old_ckpt_file}") + os.remove(old_ckpt_file) + + # training loop + for epoch in range(num_train_epochs): + if is_main_process: + print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + with accelerator.accumulate(controlnet): + with torch.no_grad(): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device) + else: + # latentに変換 + latents = vae.encode( + batch["images"].to(dtype=weight_dtype) + ).latent_dist.sample() + latents = latents * 0.18215 + b_size = latents.shape[0] + + input_ids = batch["input_ids"].to(accelerator.device) + encoder_hidden_states = train_util.gethidden_states( + args, input_ids, tokenizer, text_encoder, weight_dtype + ) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents, device=latents.device) + if args.noise_offset: + noise = apply_noise_offset( + latents, noise, args.noise_offset, args.adaptive_noise_scale + ) + elif args.multires_noise_iterations: + noise = pyramid_noise_like( + noise, + latents.device, + args.multires_noise_iterations, + args.multires_noise_discount, + ) + + # Sample a random timestep for each image + timesteps = torch.randint( + 0, + noise_scheduler.config.num_train_timesteps, + (b_size,), + device=latents.device, + ) + timesteps = timesteps.long() + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype) + + with accelerator.autocast(): + down_block_res_samples, mid_block_res_sample = controlnet( + noisy_latents, + timesteps, + encoder_hidden_states=encoder_hidden_states, + controlnet_cond=controlnet_image, + return_dict=False, + ) + + # Predict the noise residual + noise_pred = unet( + noisy_latents, + timesteps, + encoder_hidden_states, + down_block_additional_residuals=[ + sample.to(dtype=weight_dtype) + for sample in down_block_res_samples + ], + mid_block_additional_residual=mid_block_res_sample.to( + dtype=weight_dtype + ), + ).sample + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss( + noise_pred.float(), target.float(), reduction="none" + ) + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + + if args.min_snr_gamma: + loss = apply_snr_weight( + loss, timesteps, noise_scheduler, args.min_snr_gamma + ) + + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + + accelerator.backward(loss) + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = controlnet.parameters() + 
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + train_util.sample_images( + accelerator, + args, + None, + global_step, + accelerator.device, + vae, + tokenizer, + text_encoder, + unet, + ) + + # 指定ステップごとにモデルを保存 + if ( + args.save_every_n_steps is not None + and global_step % args.save_every_n_steps == 0 + ): + accelerator.wait_for_everyone() + if accelerator.is_main_process: + ckpt_name = train_util.get_step_ckpt_name( + args, "." + args.save_model_as, global_step + ) + save_model( + ckpt_name, unwrap_model(controlnet), global_step, epoch + ) + + if args.save_state: + train_util.save_and_remove_state_stepwise( + args, accelerator, global_step + ) + + remove_step_no = train_util.get_remove_step_no( + args, global_step + ) + if remove_step_no is not None: + remove_ckpt_name = train_util.get_step_ckpt_name( + args, "." + args.save_model_as, remove_step_no + ) + remove_model(remove_ckpt_name) + + current_loss = loss.detach().item() + if epoch == 0: + loss_list.append(current_loss) + else: + loss_total -= loss_list[step] + loss_list[step] = current_loss + loss_total += current_loss + avr_loss = loss_total / len(loss_list) + logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if args.logging_dir is not None: + logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if args.logging_dir is not None: + logs = {"loss/epoch": loss_total / len(loss_list)} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + # 指定エポックごとにモデルを保存 + if args.save_every_n_epochs is not None: + saving = (epoch + 1) % args.save_every_n_epochs == 0 and ( + epoch + 1 + ) < num_train_epochs + if is_main_process and saving: + ckpt_name = train_util.get_epoch_ckpt_name( + args, "." + args.save_model_as, epoch + 1 + ) + save_model(ckpt_name, unwrap_model(controlnet), global_step, epoch + 1) + + remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1) + if remove_epoch_no is not None: + remove_ckpt_name = train_util.get_epoch_ckpt_name( + args, "." + args.save_model_as, remove_epoch_no + ) + remove_model(remove_ckpt_name) + + if args.save_state: + train_util.save_and_remove_state_on_epoch_end( + args, accelerator, epoch + 1 + ) + + train_util.sample_images( + accelerator, + args, + epoch + 1, + global_step, + accelerator.device, + vae, + tokenizer, + text_encoder, + unet, + ) + + # end of epoch + if is_main_process: + controlnet = unwrap_model(controlnet) + + accelerator.end_training() + + if is_main_process and args.save_state: + train_util.save_state_on_train_end(args, accelerator) + + del accelerator # この後メモリを使うのでこれは消す + + if is_main_process: + ckpt_name = train_util.get_last_ckpt_name(args, "." 
+ args.save_model_as) + save_model( + ckpt_name, controlnet, global_step, num_train_epochs, force_sync_upload=True + ) + + print("model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, False, True, True) + train_util.add_training_arguments(parser, False) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) + + parser.add_argument( + "--controlnet_model_name_or_path", + type=str, + default=None, + help="controlnet model name or path / controlnetのモデル名またはパス", + ) + parser.add_argument( + "--conditioning_data_dir", + type=str, + default=None, + help="conditioning data directory / 条件付けデータのディレクトリ", + ) + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + args = train_util.read_config_from_file(args, parser) + + train(args) From 3bd00b88c2c13489fac9cc4a5ebf1f394d0f5df9 Mon Sep 17 00:00:00 2001 From: ddPn08 Date: Thu, 1 Jun 2023 09:47:37 +0900 Subject: [PATCH 8/9] support for controlnet in sample output --- library/lpw_stable_diffusion.py | 85 ++++++++++++++++++++++++++++++++- library/model_util.py | 28 +++++++++-- library/train_util.py | 27 ++++++++--- train_controlnet.py | 47 +++++++++++------- 4 files changed, 159 insertions(+), 28 deletions(-) diff --git a/library/lpw_stable_diffusion.py b/library/lpw_stable_diffusion.py index 84e1ab15..88317e30 100644 --- a/library/lpw_stable_diffusion.py +++ b/library/lpw_stable_diffusion.py @@ -6,7 +6,7 @@ import re from typing import Callable, List, Optional, Union import numpy as np -import PIL +import PIL.Image import torch from packaging import version from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer @@ -426,6 +426,59 @@ def preprocess_mask(mask, scale_factor=8): return mask +def prepare_controlnet_image( + image: PIL.Image.Image, + width: int, + height: int, + batch_size: int, + num_images_per_prompt: int, + device: torch.device, + dtype: torch.dtype, + do_classifier_free_guidance: bool = False, + guess_mode: bool = False, +): + if not isinstance(image, torch.Tensor): + if isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + images = [] + + for image_ in image: + image_ = image_.convert("RGB") + image_ = image_.resize( + (width, height), resample=PIL_INTERPOLATION["lanczos"] + ) + image_ = np.array(image_) + image_ = image_[None, :] + images.append(image_) + + image = images + + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline): r""" Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing @@ -707,6 +760,8 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline): 
max_embeddings_multiples: Optional[int] = 3, output_type: Optional[str] = "pil", return_dict: bool = True, + controlnet=None, + controlnet_image=None, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, is_cancelled_callback: Optional[Callable[[], bool]] = None, callback_steps: int = 1, @@ -767,6 +822,11 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline): return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. + controlnet (`diffusers.ControlNetModel`, *optional*): + A controlnet model to be used for the inference. If not provided, controlnet will be disabled. + controlnet_image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*): + `Image`, or tensor representing an image batch, to be used as the starting point for the controlnet + inference. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. @@ -785,6 +845,9 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline): list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ + if controlnet is not None and controlnet_image is None: + raise ValueError("controlnet_image must be provided if controlnet is not None.") + # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor @@ -824,6 +887,10 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline): else: mask = None + if controlnet_image is not None: + controlnet_image = prepare_controlnet_image(controlnet_image, width, height, batch_size, 1, self.device, controlnet.dtype, do_classifier_free_guidance, False) + + # 5. 
set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None) @@ -851,8 +918,22 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline): latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + unet_additional_args = {} + if controlnet is not None: + down_block_res_samples, mid_block_res_sample = controlnet( + latent_model_input, + t, + encoder_hidden_states=text_embeddings, + controlnet_cond=controlnet_image, + conditioning_scale=1.0, + guess_mode=False, + return_dict=False, + ) + unet_additional_args['down_block_additional_residuals'] = down_block_res_samples + unet_additional_args['mid_block_additional_residual'] = mid_block_res_sample + # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, **unet_additional_args).sample # perform guidance if do_classifier_free_guidance: diff --git a/library/model_util.py b/library/model_util.py index bb168653..0764a881 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -731,8 +731,7 @@ def convert_unet_state_dict_to_sd(v2, unet_state_dict): return new_state_dict - -def convert_controlnet_state_dict_to_sd(controlnet_state_dict): +def controlnet_conversion_map(): unet_conversion_map = [ ("time_embed.0.weight", "time_embedding.linear_1.weight"), ("time_embed.0.bias", "time_embedding.linear_1.bias"), @@ -792,6 +791,12 @@ def convert_controlnet_state_dict_to_sd(controlnet_state_dict): sd_prefix = f"zero_convs.{i}.0." 
unet_conversion_map_layer.append((sd_prefix, hf_prefix)) + return unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer + + +def convert_controlnet_state_dict_to_sd(controlnet_state_dict): + unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer = controlnet_conversion_map() + mapping = {k: k for k in controlnet_state_dict.keys()} for sd_name, diffusers_name in unet_conversion_map: mapping[diffusers_name] = sd_name @@ -807,6 +812,23 @@ def convert_controlnet_state_dict_to_sd(controlnet_state_dict): new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()} return new_state_dict +def convert_controlnet_state_dict_to_diffusers(controlnet_state_dict): + unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer = controlnet_conversion_map() + + mapping = {k: k for k in controlnet_state_dict.keys()} + for sd_name, diffusers_name in unet_conversion_map: + mapping[sd_name] = diffusers_name + for k, v in mapping.items(): + for sd_part, diffusers_part in unet_conversion_map_layer: + v = v.replace(sd_part, diffusers_part) + mapping[k] = v + for k, v in mapping.items(): + if "resnets" in v: + for sd_part, diffusers_part in unet_conversion_map_resnet: + v = v.replace(sd_part, diffusers_part) + mapping[k] = v + new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()} + return new_state_dict # ================# # VAE Conversion # @@ -928,7 +950,7 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"): # TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認 -def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None, unet_use_linear_projection_in_v2=False): +def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None, unet_use_linear_projection_in_v2=True): _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device) # Convert the UNet2DConditionModel model. 
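
Note on the conversion helpers above: convert_controlnet_state_dict_to_sd and convert_controlnet_state_dict_to_diffusers are inverses built from the same key table returned by controlnet_conversion_map(), so a checkpoint written by this trainer in the original Stable Diffusion / ControlNet key layout can be loaded back into a diffusers ControlNetModel by renaming keys only. The following is a minimal sketch of that load path, mirroring what train_controlnet.py does later in this patch; it assumes a UNet2DConditionModel named `unet` is already loaded, and "controlnet.safetensors" is a hypothetical checkpoint produced by this script's save_model().

    from safetensors.torch import load_file
    from diffusers import ControlNetModel

    import library.model_util as model_util

    # build a ControlNet with the same configuration as the base UNet,
    # exactly as the trainer does when no pretrained ControlNet is given
    controlnet = ControlNetModel.from_unet(unet)

    # read the SD-layout checkpoint and rename its keys to the diffusers layout
    sd_state_dict = load_file("controlnet.safetensors")  # hypothetical path
    diffusers_state_dict = model_util.convert_controlnet_state_dict_to_diffusers(sd_state_dict)
    controlnet.load_state_dict(diffusers_state_dict)

    # the reverse mapping is what save_model() applies before writing a checkpoint
    sd_layout_again = model_util.convert_controlnet_state_dict_to_sd(controlnet.state_dict())
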
diff --git a/library/train_util.py b/library/train_util.py index 1921c2a4..81dffb1d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1674,7 +1674,6 @@ class ControlNetDataset(BaseDataset): cond_img = self.trim_and_resize_if_required(subset, cond_img, image_info.bucket_reso, image_info.resized_size) cond_img = self.conditioning_image_transforms(cond_img) conditioning_images.append(cond_img) - conditioning_images = torch.stack(conditioning_images) example = {} example["loss_weights"] = torch.FloatTensor(loss_weights) @@ -1699,7 +1698,7 @@ class ControlNetDataset(BaseDataset): if self.debug_dataset: example["image_keys"] = bucket[image_index : image_index + self.batch_size] - example["conditioning_images"] = conditioning_images.to(memory_format=torch.contiguous_format).float() + example["conditioning_images"] = torch.stack(conditioning_images).to(memory_format=torch.contiguous_format).float() return example @@ -3138,13 +3137,13 @@ def prepare_dtype(args: argparse.Namespace): return weight_dtype, save_dtype -def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu"): +def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", unet_use_linear_projection_in_v2=False): name_or_path = args.pretrained_model_name_or_path name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers if load_stable_diffusion_format: print(f"load StableDiffusion checkpoint: {name_or_path}") - text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path, device) + text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path, device, unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2) else: # Diffusers model is loaded to CPU print(f"load Diffusers pretrained models: {name_or_path}") @@ -3172,14 +3171,14 @@ def transform_if_model_is_DDP(text_encoder, unet, network=None): return (model.module if type(model) == DDP else model for model in [text_encoder, unet, network] if model is not None) -def load_target_model(args, weight_dtype, accelerator): +def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False): # load models for each process for pi in range(accelerator.state.num_processes): if pi == accelerator.state.local_process_index: print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}") text_encoder, vae, unet, load_stable_diffusion_format = _load_target_model( - args, weight_dtype, accelerator.device if args.lowram else "cpu" + args, weight_dtype, accelerator.device if args.lowram else "cpu", unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2 ) # work on low-ram device @@ -3493,7 +3492,7 @@ SCHEDLER_SCHEDULE = "scaled_linear" def sample_images( - accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet, prompt_replacement=None + accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet, prompt_replacement=None, controlnet=None ): """ StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した @@ -3609,6 +3608,7 @@ def sample_images( height = prompt.get("height", 512) scale = prompt.get("scale", 7.5) seed = prompt.get("seed") + controlnet_image = prompt.get("controlnet_image") prompt = prompt.get("prompt") else: # prompt = prompt.strip() @@ -3623,6 
+3623,7 @@ def sample_images( width = height = 512 scale = 7.5 seed = None + controlnet_image = None for parg in prompt_args: try: m = re.match(r"w (\d+)", parg, re.IGNORECASE) @@ -3655,6 +3656,12 @@ def sample_images( negative_prompt = m.group(1) continue + m = re.match(r"cn (.+)", parg, re.IGNORECASE) + if m: # negative prompt + controlnet_image = m.group(1) + continue + + except ValueError as ex: print(f"Exception in parsing / 解析エラー: {parg}") print(ex) @@ -3668,6 +3675,10 @@ def sample_images( if negative_prompt is not None: negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) + if controlnet_image is not None: + controlnet_image = Image.open(controlnet_image).convert("RGB") + controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS) + height = max(64, height - height % 8) # round to divisible by 8 width = max(64, width - width % 8) # round to divisible by 8 print(f"prompt: {prompt}") @@ -3683,6 +3694,8 @@ def sample_images( num_inference_steps=sample_steps, guidance_scale=scale, negative_prompt=negative_prompt, + controlnet=controlnet, + controlnet_image=controlnet_image, ).images[0] ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) diff --git a/train_controlnet.py b/train_controlnet.py index 7bcaf03a..263e8813 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -1,5 +1,6 @@ import argparse import gc +import json import math import os import random @@ -11,6 +12,7 @@ import torch from torch.nn.parallel import DistributedDataParallel as DDP from accelerate.utils import set_seed from diffusers import DDPMScheduler, ControlNetModel +from safetensors.torch import load_file import library.model_util as model_util import library.train_util as train_util @@ -26,9 +28,6 @@ from library.custom_train_functions import ( pyramid_noise_like, apply_noise_offset, ) -from diffusers.pipelines.stable_diffusion.convert_from_ckpt import ( - download_controlnet_from_original_ckpt, -) # TODO 他のスクリプトと共通化する @@ -124,19 +123,24 @@ def train(args): # モデルを読み込む text_encoder, vae, unet, _ = train_util.load_target_model( - args, weight_dtype, accelerator + args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=True ) + + controlnet = ControlNetModel.from_unet(unet) + if args.controlnet_model_name_or_path: - if os.path.isfile(args.controlnet_model_name_or_path): - controlnet = download_controlnet_from_original_ckpt( - args.controlnet_model_name_or_path - ) - else: - controlnet = ControlNetModel.from_pretrained( - args.controlnet_model_name_or_path - ) - else: - controlnet = ControlNetModel.from_unet(unet) + filename = args.controlnet_model_name_or_path + if os.path.isfile(filename): + if os.path.splitext(filename)[1] == ".safetensors": + state_dict = load_file(filename) + else: + state_dict = torch.load(filename) + state_dict = model_util.convert_controlnet_state_dict_to_diffusers(state_dict) + controlnet.load_state_dict(state_dict) + elif os.path.isdir(filename): + controlnet = ControlNetModel.from_pretrained(filename) + + # モデルに xformers とか memory efficient attention を組み込む train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) @@ -289,7 +293,9 @@ def train(args): ) if accelerator.is_main_process: accelerator.init_trackers( - "controlnet_train" if args.log_tracker_name is None else args.log_tracker_name + "controlnet_train" + if args.log_tracker_name is None + else args.log_tracker_name ) loss_list = [] @@ -350,7 +356,7 @@ def train(args): b_size = latents.shape[0] input_ids = batch["input_ids"].to(accelerator.device) - 
encoder_hidden_states = train_util.gethidden_states( + encoder_hidden_states = train_util.get_hidden_states( args, input_ids, tokenizer, text_encoder, weight_dtype ) @@ -450,6 +456,7 @@ def train(args): tokenizer, text_encoder, unet, + controlnet=controlnet, ) # 指定ステップごとにモデルを保存 @@ -537,6 +544,7 @@ def train(args): tokenizer, text_encoder, unet, + controlnet=controlnet, ) # end of epoch @@ -569,6 +577,13 @@ def setup_parser() -> argparse.ArgumentParser: config_util.add_config_arguments(parser) custom_train_functions.add_custom_train_arguments(parser) + parser.add_argument( + "--save_model_as", + type=str, + default="safetensors", + choices=[None, "ckpt", "pt", "safetensors"], + help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)", + ) parser.add_argument( "--controlnet_model_name_or_path", type=str, From 1e3daa247bbfe10957d54fb337e2ced8853a0d2d Mon Sep 17 00:00:00 2001 From: ddPn08 Date: Thu, 1 Jun 2023 21:58:45 +0900 Subject: [PATCH 9/9] fix bucketing --- library/train_util.py | 55 +++++++++++++++++---------------------- train_controlnet.py | 60 ++++++++++++++++++++----------------------- 2 files changed, 52 insertions(+), 63 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 81dffb1d..b5e6aa3f 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -754,12 +754,14 @@ class BaseDataset(torch.utils.data.Dataset): img = np.array(image, np.uint8) return img - def trim_and_resize_if_required(self, subset: BaseSubset, image, reso, resized_size): + def trim_and_resize_if_required(self, subset: BaseSubset, image, reso, resized_size, cond_img = None): image_height, image_width = image.shape[0:2] if image_width != resized_size[0] or image_height != resized_size[1]: # リサイズする image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ + if exists(cond_img): + cond_img = cv2.resize(cond_img, resized_size, interpolation=cv2.INTER_AREA) image_height, image_width = image.shape[0:2] if image_width > reso[0]: @@ -767,15 +769,26 @@ class BaseDataset(torch.utils.data.Dataset): p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size) # print("w", trim_size, p) image = image[:, p : p + reso[0]] + if exists(cond_img): + cond_img = cond_img[:, p : p + reso[0]] if image_height > reso[1]: trim_size = image_height - reso[1] p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size) # print("h", trim_size, p) image = image[p : p + reso[1]] + if exists(cond_img): + cond_img = cond_img[p : p + reso[1]] assert ( image.shape[0] == reso[1] and image.shape[1] == reso[0] ), f"internal error, illegal trimmed size: {image.shape}, {reso}" + + if exists(cond_img): + assert ( + cond_img.shape[0] == reso[1] and cond_img.shape[1] == reso[0] + ), f"internal error, illegal trimmed size: {cond_img.shape}, {reso}" + return image, cond_img + return image def is_latent_cacheable(self): @@ -1617,6 +1630,8 @@ class ControlNetDataset(BaseDataset): subset = self.image_to_subset[image_key] loss_weights.append(1.0) + assert hasattr(image_info, "cond_img_path"), f"conditioning image path is not found: {image_info.absolute_path}" + # image/latentsを処理する if image_info.latents is not None: # cache_latents=Trueの場合 latents = image_info.latents if not subset.flip_aug or random.random() < 0.5 else image_info.latents_flipped @@ -1628,10 +1643,11 @@ class ControlNetDataset(BaseDataset): else: # 画像を読み込み、必要ならcropする img = self.load_image(image_info.absolute_path) + cond_img = 
self.load_image(image_info.cond_img_path) im_h, im_w = img.shape[0:2] if self.enable_bucket: - img = self.trim_and_resize_if_required(subset, img, image_info.bucket_reso, image_info.resized_size) + img, cond_img = self.trim_and_resize_if_required(subset, img, image_info.bucket_reso, image_info.resized_size, cond_img=cond_img) else: im_h, im_w = img.shape[0:2] assert ( @@ -1649,41 +1665,18 @@ class ControlNetDataset(BaseDataset): images.append(image) latents_list.append(latents) - caption = self.process_caption(subset, image_info.caption) - if self.XTI_layers: - caption_layer = [] - for layer in self.XTI_layers: - token_strings_from = " ".join(self.token_strings) - token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings]) - caption_ = caption.replace(token_strings_from, token_strings_to) - caption_layer.append(caption_) - captions.append(caption_layer) - else: - captions.append(caption) - if not self.token_padding_disabled: # this option might be omitted in future - if self.XTI_layers: - token_caption = self.get_input_ids(caption_layer) - else: - token_caption = self.get_input_ids(caption) - input_ids_list.append(token_caption) - - assert hasattr(image_info, "cond_img_path"), f"conditioning image path is not found: {image_info.absolute_path}" - - cond_img = self.load_image(image_info.cond_img_path) - if self.enable_bucket: - cond_img = self.trim_and_resize_if_required(subset, cond_img, image_info.bucket_reso, image_info.resized_size) cond_img = self.conditioning_image_transforms(cond_img) conditioning_images.append(cond_img) + caption = self.process_caption(subset, image_info.caption) + captions.append(caption) + token_caption = self.get_input_ids(caption) + input_ids_list.append(token_caption) + example = {} example["loss_weights"] = torch.FloatTensor(loss_weights) - if self.token_padding_disabled: - # padding=True means pad in the batch - example["input_ids"] = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids - else: - # batch processing seems to be good - example["input_ids"] = torch.stack(input_ids_list) + example["input_ids"] = torch.stack(input_ids_list) if images[0] is not None: images = torch.stack(images) diff --git a/train_controlnet.py b/train_controlnet.py index 263e8813..6e4e5bb8 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -141,7 +141,6 @@ def train(args): controlnet = ControlNetModel.from_pretrained(filename) - # モデルに xformers とか memory efficient attention を組み込む train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) @@ -168,11 +167,11 @@ def train(args): controlnet.enable_gradient_checkpointing() # 学習に必要なクラスを準備する - print("prepare optimizer, data loader etc.") + accelerator.print("prepare optimizer, data loader etc.") trainable_params = controlnet.parameters() - optimizer_name, optimizer_args, optimizer = train_util.get_optimizer( + _, _, optimizer = train_util.get_optimizer( args, trainable_params ) @@ -198,10 +197,9 @@ def train(args): / accelerator.num_processes / args.gradient_accumulation_steps ) - if is_main_process: - print( - f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" - ) + accelerator.print( + f"override steps. 
steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) @@ -216,7 +214,7 @@ def train(args): assert ( args.mixed_precision == "fp16" ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" - print("enable full fp16 training.") + accelerator.print("enable full fp16 training.") controlnet.to(weight_dtype) # acceleratorがなんかよろしくやってくれるらしい @@ -258,23 +256,21 @@ def train(args): # 学習する # TODO: find a way to handle total batch size when there are multiple datasets - - if is_main_process: - print("running training / 学習開始") - print( - f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}" - ) - print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") - print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") - print(f" num epochs / epoch数: {num_train_epochs}") - print( - f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" - ) - # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") - print( - f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}" - ) - print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + accelerator.print("running training / 学習開始") + accelerator.print( + f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}" + ) + accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + accelerator.print( + f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}" + ) + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") progress_bar = tqdm( range(args.max_train_steps), @@ -303,11 +299,11 @@ def train(args): del train_dataset_group # function for saving/removing - def save_model(ckpt_name, model, steps, epoch_no, force_sync_upload=False): + def save_model(ckpt_name, model, force_sync_upload=False): os.makedirs(args.output_dir, exist_ok=True) ckpt_file = os.path.join(args.output_dir, ckpt_name) - print(f"\nsaving checkpoint: {ckpt_file}") + accelerator.print(f"\nsaving checkpoint: {ckpt_file}") state_dict = model_util.convert_controlnet_state_dict_to_sd(model.state_dict()) @@ -332,13 +328,13 @@ def train(args): def remove_model(old_ckpt_name): old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) if os.path.exists(old_ckpt_file): - print(f"removing old checkpoint: {old_ckpt_file}") + accelerator.print(f"removing old checkpoint: {old_ckpt_file}") os.remove(old_ckpt_file) # training loop for epoch in range(num_train_epochs): if is_main_process: - print(f"\nepoch {epoch+1}/{num_train_epochs}") + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 for step, batch in enumerate(train_dataloader): @@ -470,7 +466,7 @@ def train(args): args, "." 
+ args.save_model_as, global_step ) save_model( - ckpt_name, unwrap_model(controlnet), global_step, epoch + ckpt_name, unwrap_model(controlnet), ) if args.save_state: @@ -520,7 +516,7 @@ def train(args): ckpt_name = train_util.get_epoch_ckpt_name( args, "." + args.save_model_as, epoch + 1 ) - save_model(ckpt_name, unwrap_model(controlnet), global_step, epoch + 1) + save_model(ckpt_name, unwrap_model(controlnet)) remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1) if remove_epoch_no is not None: @@ -561,7 +557,7 @@ def train(args): if is_main_process: ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) save_model( - ckpt_name, controlnet, global_step, num_train_epochs, force_sync_upload=True + ckpt_name, controlnet, force_sync_upload=True ) print("model saved.")
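
Note on wiring up the ControlNet dataset path: when DreamBooth, fine-tuning, and ControlNet support are all enabled, ConfigSanitizer routes a dataset to the ControlNet schema only if every subset defines conditioning_data_dir, and ControlNetDataset then pairs each training image with the conditioning image of the same filename in that directory. The sketch below builds the same user_config shape that train_controlnet.py falls back to when --dataset_config is not given. The directory paths and the resolution value are placeholders, and `args` / `tokenizer` are assumed to be the argparse namespace and tokenizer prepared earlier in train(args).

    import library.config_util as config_util
    from library.config_util import BlueprintGenerator, ConfigSanitizer

    user_config = {
        "datasets": [
            {
                "resolution": 512,  # placeholder; ControlNetDataset requires a resolution
                "subsets": [
                    {
                        "image_dir": "/path/to/train/images",            # placeholder
                        "conditioning_data_dir": "/path/to/train/cond",  # placeholder, filenames must mirror image_dir
                        "caption_extension": ".txt",
                        "num_repeats": 1,
                    }
                ],
            }
        ]
    }

    # (support_dreambooth, support_finetuning, support_controlnet, support_dropout)
    blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
    blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
    train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
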