simplify and update alpha mask to work with various cases

This commit is contained in:
Kohya S
2024-05-19 21:26:18 +09:00
parent f2dd43e198
commit da6fea3d97
10 changed files with 140 additions and 105 deletions

View File

@@ -11,6 +11,7 @@ import cv2
import torch
from library.device_utils import init_ipex, get_preferred_device
init_ipex()
from torchvision import transforms
@@ -18,8 +19,10 @@ from torchvision import transforms
import library.model_util as model_util
import library.train_util as train_util
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
DEVICE = get_preferred_device()
@@ -89,7 +92,9 @@ def main(args):
# bucketのサイズを計算する
max_reso = tuple([int(t) for t in args.max_resolution.split(",")])
assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"
assert (
len(max_reso) == 2
), f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"
bucket_manager = train_util.BucketManager(
args.bucket_no_upscale, max_reso, args.min_bucket_reso, args.max_bucket_reso, args.bucket_reso_steps
@@ -107,7 +112,7 @@ def main(args):
def process_batch(is_last):
for bucket in bucket_manager.buckets:
if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
train_util.cache_batch_latents(vae, True, bucket, args.flip_aug, False)
train_util.cache_batch_latents(vae, True, bucket, args.flip_aug, args.alpha_mask, False)
bucket.clear()
# 読み込みの高速化のためにDataLoaderを使うオプション
@@ -208,7 +213,9 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
parser.add_argument("--v2", action="store_true", help="not used (for backward compatibility) / 使用されません(互換性のため残してあります)")
parser.add_argument(
"--v2", action="store_true", help="not used (for backward compatibility) / 使用されません(互換性のため残してあります)"
)
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument(
"--max_data_loader_n_workers",
@@ -231,10 +238,16 @@ def setup_parser() -> argparse.ArgumentParser:
help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します",
)
parser.add_argument(
"--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します"
"--bucket_no_upscale",
action="store_true",
help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します",
)
parser.add_argument(
"--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度"
"--mixed_precision",
type=str,
default="no",
choices=["no", "fp16", "bf16"],
help="use mixed precision / 混合精度を使う場合、その精度",
)
parser.add_argument(
"--full_path",
@@ -242,7 +255,15 @@ def setup_parser() -> argparse.ArgumentParser:
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)",
)
parser.add_argument(
"--flip_aug", action="store_true", help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する"
"--flip_aug",
action="store_true",
help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する",
)
parser.add_argument(
"--alpha_mask",
type=str,
default="",
help="save alpha mask for images for loss calculation / 損失計算用に画像のアルファマスクを保存する",
)
parser.add_argument(
"--skip_existing",

View File

@@ -214,11 +214,13 @@ class ConfigSanitizer:
DB_SUBSET_DISTINCT_SCHEMA = {
Required("image_dir"): str,
"is_reg": bool,
"alpha_mask": bool,
}
# FT means FineTuning
FT_SUBSET_DISTINCT_SCHEMA = {
Required("metadata_file"): str,
"image_dir": str,
"alpha_mask": bool,
}
CN_SUBSET_ASCENDABLE_SCHEMA = {
"caption_extension": str,

View File

@@ -479,14 +479,19 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
return noise
def apply_masked_loss(loss, mask_image):
# mask image is -1 to 1. we need to convert it to 0 to 1
# mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel
mask_image = mask_image.to(dtype=loss.dtype)
def apply_masked_loss(loss, batch):
if "conditioning_images" in batch:
# conditioning image is -1 to 1. we need to convert it to 0 to 1
mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel
mask_image = mask_image / 2 + 0.5
elif "alpha_masks" in batch and batch["alpha_masks"] is not None:
# alpha mask is 0 to 1
mask_image = batch["alpha_masks"].to(dtype=loss.dtype)
else:
return loss
# resize to the same size as the loss
mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area")
mask_image = mask_image / 2 + 0.5
loss = loss * mask_image
return loss

View File

@@ -159,9 +159,7 @@ class ImageInfo:
self.text_encoder_outputs1: Optional[torch.Tensor] = None
self.text_encoder_outputs2: Optional[torch.Tensor] = None
self.text_encoder_pool2: Optional[torch.Tensor] = None
self.alpha_mask: Optional[torch.Tensor] = None
self.alpha_mask_flipped: Optional[torch.Tensor] = None
self.use_alpha_mask: bool = False
self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime
class BucketManager:
@@ -364,6 +362,7 @@ class BaseSubset:
def __init__(
self,
image_dir: Optional[str],
alpha_mask: Optional[bool],
num_repeats: int,
shuffle_caption: bool,
caption_separator: str,
@@ -382,9 +381,9 @@ class BaseSubset:
caption_suffix: Optional[str],
token_warmup_min: int,
token_warmup_step: Union[float, int],
alpha_mask: bool,
) -> None:
self.image_dir = image_dir
self.alpha_mask = alpha_mask if alpha_mask is not None else False
self.num_repeats = num_repeats
self.shuffle_caption = shuffle_caption
self.caption_separator = caption_separator
@@ -407,8 +406,6 @@ class BaseSubset:
self.img_count = 0
self.alpha_mask = alpha_mask
class DreamBoothSubset(BaseSubset):
def __init__(
@@ -418,6 +415,7 @@ class DreamBoothSubset(BaseSubset):
class_tokens: Optional[str],
caption_extension: str,
cache_info: bool,
alpha_mask: bool,
num_repeats,
shuffle_caption,
caption_separator: str,
@@ -441,6 +439,7 @@ class DreamBoothSubset(BaseSubset):
super().__init__(
image_dir,
alpha_mask,
num_repeats,
shuffle_caption,
caption_separator,
@@ -479,6 +478,7 @@ class FineTuningSubset(BaseSubset):
self,
image_dir,
metadata_file: str,
alpha_mask: bool,
num_repeats,
shuffle_caption,
caption_separator,
@@ -502,6 +502,7 @@ class FineTuningSubset(BaseSubset):
super().__init__(
image_dir,
alpha_mask,
num_repeats,
shuffle_caption,
caption_separator,
@@ -921,7 +922,7 @@ class BaseDataset(torch.utils.data.Dataset):
logger.info(f"mean ar error (without repeats): {mean_img_ar_error}")
# データ参照用indexを作る。このindexはdatasetのshuffleに用いられる
self.buckets_indices: List(BucketBatchIndex) = []
self.buckets_indices: List[BucketBatchIndex] = []
for bucket_index, bucket in enumerate(self.bucket_manager.buckets):
batch_count = int(math.ceil(len(bucket) / self.batch_size))
for batch_index in range(batch_count):
@@ -991,8 +992,6 @@ class BaseDataset(torch.utils.data.Dataset):
for info in tqdm(image_infos):
subset = self.image_to_subset[info.image_key]
info.use_alpha_mask = subset.alpha_mask
if info.latents_npz is not None: # fine tuning dataset
continue
@@ -1002,7 +1001,9 @@ class BaseDataset(torch.utils.data.Dataset):
if not is_main_process: # store to info only
continue
cache_available = is_disk_cached_latents_is_expected(info.bucket_reso, info.latents_npz, subset.flip_aug)
cache_available = is_disk_cached_latents_is_expected(
info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask
)
if cache_available: # do not add to batch
continue
@@ -1028,7 +1029,7 @@ class BaseDataset(torch.utils.data.Dataset):
# iterate batches: batch doesn't have image, image will be loaded in cache_batch_latents and discarded
logger.info("caching latents...")
for batch in tqdm(batches, smoothing=1, total=len(batches)):
cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.random_crop)
cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)
# weight_dtypeを指定するとText Encoderそのもの、およひ出力がweight_dtypeになる
# SDXLでのみ有効だが、datasetのメソッドとする必要があるので、sdxl_train_util.pyではなくこちらに実装する
@@ -1202,18 +1203,15 @@ class BaseDataset(torch.utils.data.Dataset):
alpha_mask = image_info.alpha_mask
else:
latents = image_info.latents_flipped
alpha_mask = image_info.alpha_mask_flipped
alpha_mask = None if image_info.alpha_mask is None else torch.flip(image_info.alpha_mask, [1])
image = None
elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask = load_latents_from_disk(
image_info.latents_npz
)
latents, original_size, crop_ltrb, flipped_latents, alpha_mask = load_latents_from_disk(image_info.latents_npz)
if flipped:
latents = flipped_latents
alpha_mask = flipped_alpha_mask
alpha_mask = None if alpha_mask is None else alpha_mask[:, ::-1].copy() # copy to avoid negative stride problem
del flipped_latents
del flipped_alpha_mask
latents = torch.FloatTensor(latents)
if alpha_mask is not None:
alpha_mask = torch.FloatTensor(alpha_mask)
@@ -1255,23 +1253,28 @@ class BaseDataset(torch.utils.data.Dataset):
# augmentation
aug = self.aug_helper.get_augmentor(subset.color_aug)
if aug is not None:
img = aug(image=img)["image"]
# augment RGB channels only
img_rgb = img[:, :, :3]
img_rgb = aug(image=img_rgb)["image"]
img[:, :, :3] = img_rgb
if flipped:
img = img[:, ::-1, :].copy() # copy to avoid negative stride problem
if subset.alpha_mask:
if img.shape[2] == 4:
alpha_mask = img[:, :, 3] # [W,H]
alpha_mask = img[:, :, 3] # [H,W]
alpha_mask = transforms.ToTensor()(alpha_mask) # 0-255 -> 0-1
else:
alpha_mask = np.full((im_w, im_h), 255, dtype=np.uint8) # [W,H]
alpha_mask = transforms.ToTensor()(alpha_mask)
alpha_mask = torch.ones((img.shape[0], img.shape[1]), dtype=torch.float32)
else:
alpha_mask = None
img = img[:, :, :3] # remove alpha channel
latents = None
image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
del img
images.append(image)
latents_list.append(latents)
@@ -1361,6 +1364,23 @@ class BaseDataset(torch.utils.data.Dataset):
example["text_encoder_outputs2_list"] = torch.stack(text_encoder_outputs2_list)
example["text_encoder_pool2_list"] = torch.stack(text_encoder_pool2_list)
# if one of alpha_masks is not None, we need to replace None with ones
none_or_not = [x is None for x in alpha_mask_list]
if all(none_or_not):
example["alpha_masks"] = None
elif any(none_or_not):
for i in range(len(alpha_mask_list)):
if alpha_mask_list[i] is None:
if images[i] is not None:
alpha_mask_list[i] = torch.ones((images[i].shape[1], images[i].shape[2]), dtype=torch.float32)
else:
alpha_mask_list[i] = torch.ones(
(latents_list[i].shape[1] * 8, latents_list[i].shape[2] * 8), dtype=torch.float32
)
example["alpha_masks"] = torch.stack(alpha_mask_list)
else:
example["alpha_masks"] = torch.stack(alpha_mask_list)
if images[0] is not None:
images = torch.stack(images)
images = images.to(memory_format=torch.contiguous_format).float()
@@ -1378,8 +1398,6 @@ class BaseDataset(torch.utils.data.Dataset):
example["network_multipliers"] = torch.FloatTensor([self.network_multiplier] * len(captions))
example["alpha_mask"] = torch.stack(alpha_mask_list) if alpha_mask_list[0] is not None else None
if self.debug_dataset:
example["image_keys"] = bucket[image_index : image_index + self.batch_size]
return example
@@ -1393,6 +1411,7 @@ class BaseDataset(torch.utils.data.Dataset):
resized_sizes = []
bucket_reso = None
flip_aug = None
alpha_mask = None
random_crop = None
for image_key in bucket[image_index : image_index + bucket_batch_size]:
@@ -1401,10 +1420,13 @@ class BaseDataset(torch.utils.data.Dataset):
if flip_aug is None:
flip_aug = subset.flip_aug
alpha_mask = subset.alpha_mask
random_crop = subset.random_crop
bucket_reso = image_info.bucket_reso
else:
# TODO そもそも混在してても動くようにしたほうがいい
assert flip_aug == subset.flip_aug, "flip_aug must be same in a batch"
assert alpha_mask == subset.alpha_mask, "alpha_mask must be same in a batch"
assert random_crop == subset.random_crop, "random_crop must be same in a batch"
assert bucket_reso == image_info.bucket_reso, "bucket_reso must be same in a batch"
@@ -1441,6 +1463,7 @@ class BaseDataset(torch.utils.data.Dataset):
example["absolute_paths"] = absolute_paths
example["resized_sizes"] = resized_sizes
example["flip_aug"] = flip_aug
example["alpha_mask"] = alpha_mask
example["random_crop"] = random_crop
example["bucket_reso"] = bucket_reso
return example
@@ -2149,7 +2172,7 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
dataset.disable_token_padding()
def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool):
def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alpha_mask: bool):
expected_latents_size = (reso[1] // 8, reso[0] // 8) # bucket_resoはWxHなので注意
if not os.path.exists(npz_path):
@@ -2167,6 +2190,12 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool):
return False
if npz["latents_flipped"].shape[1:3] != expected_latents_size:
return False
if alpha_mask:
if "alpha_mask" not in npz:
return False
if npz["alpha_mask"].shape[0:2] != reso: # HxW
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
@@ -2177,14 +2206,7 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool):
# 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top)
def load_latents_from_disk(
npz_path,
) -> Tuple[
Optional[torch.Tensor],
Optional[List[int]],
Optional[List[int]],
Optional[torch.Tensor],
Optional[torch.Tensor],
Optional[torch.Tensor],
]:
) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
npz = np.load(npz_path)
if "latents" not in npz:
raise ValueError(f"error: npz is old format. please re-generate {npz_path}")
@@ -2194,20 +2216,15 @@ def load_latents_from_disk(
crop_ltrb = npz["crop_ltrb"].tolist()
flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None
alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None
flipped_alpha_mask = npz["flipped_alpha_mask"] if "flipped_alpha_mask" in npz else None
return latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask
return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
def save_latents_to_disk(
npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None, flipped_alpha_mask=None
):
def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None):
kwargs = {}
if flipped_latents_tensor is not None:
kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy()
if alpha_mask is not None:
kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy()
if flipped_alpha_mask is not None:
kwargs["flipped_alpha_mask"] = flipped_alpha_mask.float().cpu().numpy()
kwargs["alpha_mask"] = alpha_mask # ndarray
np.savez(
npz_path,
latents=latents_tensor.float().cpu().numpy(),
@@ -2398,10 +2415,11 @@ def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset:
def load_image(image_path, alpha=False):
image = Image.open(image_path)
if not image.mode == "RGB":
if alpha:
if alpha:
if not image.mode == "RGBA":
image = image.convert("RGBA")
else:
else:
if not image.mode == "RGB":
image = image.convert("RGB")
img = np.array(image, np.uint8)
return img
@@ -2441,7 +2459,7 @@ def trim_and_resize_if_required(
def cache_batch_latents(
vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, random_crop: bool
vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, use_alpha_mask: bool, random_crop: bool
) -> None:
r"""
requires image_infos to have: absolute_path, bucket_reso, resized_size, latents_npz
@@ -2453,49 +2471,43 @@ def cache_batch_latents(
latents_original_size and latents_crop_ltrb are also set
"""
images = []
alpha_masks = []
alpha_masks: List[np.ndarray] = []
for info in image_infos:
image = load_image(info.absolute_path, info.use_alpha_mask) if info.image is None else np.array(info.image, np.uint8)
image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8)
# TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
if info.use_alpha_mask:
if image.shape[2] == 4:
alpha_mask = image[:, :, 3] # [W,H]
image = image[:, :, :3]
else:
alpha_mask = np.full_like(image[:, :, 0], 255, dtype=np.uint8) # [W,H]
alpha_masks.append(transforms.ToTensor()(alpha_mask))
image = IMAGE_TRANSFORMS(image)
images.append(image)
info.latents_original_size = original_size
info.latents_crop_ltrb = crop_ltrb
if use_alpha_mask:
if image.shape[2] == 4:
alpha_mask = image[:, :, 3] # [H,W]
alpha_mask = alpha_mask.astype(np.float32) / 255.0
else:
alpha_mask = np.ones_like(image[:, :, 0], dtype=np.float32)
else:
alpha_mask = None
alpha_masks.append(alpha_mask)
image = image[:, :, :3] # remove alpha channel if exists
image = IMAGE_TRANSFORMS(image)
images.append(image)
img_tensors = torch.stack(images, dim=0)
img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype)
with torch.no_grad():
latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
if info.use_alpha_mask:
alpha_masks = torch.stack(alpha_masks, dim=0).to("cpu")
else:
alpha_masks = [None] * len(image_infos)
flipped_alpha_masks = [None] * len(image_infos)
if flip_aug:
img_tensors = torch.flip(img_tensors, dims=[3])
with torch.no_grad():
flipped_latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
if info.use_alpha_mask:
flipped_alpha_masks = torch.flip(alpha_masks, dims=[3])
else:
flipped_latents = [None] * len(latents)
flipped_alpha_masks = [None] * len(image_infos)
for info, latent, flipped_latent, alpha_mask, flipped_alpha_mask in zip(
image_infos, latents, flipped_latents, alpha_masks, flipped_alpha_masks
):
for info, latent, flipped_latent, alpha_mask in zip(image_infos, latents, flipped_latents, alpha_masks):
# check NaN
if torch.isnan(latents).any() or (flipped_latent is not None and torch.isnan(flipped_latent).any()):
raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")
@@ -2508,15 +2520,12 @@ def cache_batch_latents(
info.latents_crop_ltrb,
flipped_latent,
alpha_mask,
flipped_alpha_mask,
)
else:
info.latents = latent
if flip_aug:
info.latents_flipped = flipped_latent
info.alpha_mask = alpha_mask
info.alpha_mask_flipped = flipped_alpha_mask
if not HIGH_VRAM:
clean_memory_on_device(vae.device)

View File

@@ -711,10 +711,8 @@ def train(args):
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
)
if args.masked_loss:
loss = apply_masked_loss(loss, batch["conditioning_images"][:, 0].unsqueeze(1))
if "alpha_mask" in batch and batch["alpha_mask"] is not None:
loss = apply_masked_loss(loss, batch["alpha_mask"])
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
if args.min_snr_gamma:

View File

@@ -17,10 +17,13 @@ from library.config_util import (
BlueprintGenerator,
)
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def cache_to_disk(args: argparse.Namespace) -> None:
train_util.prepare_dataset_args(args, True)
@@ -107,7 +110,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
else:
_, vae, _, _ = train_util.load_target_model(args, weight_dtype, accelerator)
if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える
if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える
vae.set_use_memory_efficient_attention_xformers(args.xformers)
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
@@ -136,6 +139,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
b_size = len(batch["images"])
vae_batch_size = b_size if args.vae_batch_size is None else args.vae_batch_size
flip_aug = batch["flip_aug"]
alpha_mask = batch["alpha_mask"]
random_crop = batch["random_crop"]
bucket_reso = batch["bucket_reso"]
@@ -154,14 +158,16 @@ def cache_to_disk(args: argparse.Namespace) -> None:
image_info.latents_npz = os.path.splitext(absolute_path)[0] + ".npz"
if args.skip_existing:
if train_util.is_disk_cached_latents_is_expected(image_info.bucket_reso, image_info.latents_npz, flip_aug):
if train_util.is_disk_cached_latents_is_expected(
image_info.bucket_reso, image_info.latents_npz, flip_aug, alpha_mask
):
logger.warning(f"Skipping {image_info.latents_npz} because it already exists.")
continue
image_infos.append(image_info)
if len(image_infos) > 0:
train_util.cache_batch_latents(vae, True, image_infos, flip_aug, random_crop)
train_util.cache_batch_latents(vae, True, image_infos, flip_aug, alpha_mask, random_crop)
accelerator.wait_for_everyone()
accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.")

View File

@@ -359,10 +359,8 @@ def train(args):
target = noise
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
if args.masked_loss:
loss = apply_masked_loss(loss, batch["conditioning_images"][:, 0].unsqueeze(1))
if "alpha_mask" in batch and batch["alpha_mask"] is not None:
loss = apply_masked_loss(loss, batch["alpha_mask"])
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight

View File

@@ -774,7 +774,9 @@ class NetworkTrainer:
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
"network_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
"network_train" if args.log_tracker_name is None else args.log_tracker_name,
config=train_util.get_sanitized_config_or_none(args),
init_kwargs=init_kwargs,
)
loss_recorder = train_util.LossRecorder()
@@ -902,10 +904,8 @@ class NetworkTrainer:
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
)
if args.masked_loss:
loss = apply_masked_loss(loss, batch["conditioning_images"][:, 0].unsqueeze(1))
if "alpha_mask" in batch and batch["alpha_mask"] is not None:
loss = apply_masked_loss(loss, batch["alpha_mask"])
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight

View File

@@ -589,10 +589,8 @@ class TextualInversionTrainer:
target = noise
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
if args.masked_loss:
loss = apply_masked_loss(loss, batch["conditioning_images"][:, 0].unsqueeze(1))
if "alpha_mask" in batch and batch["alpha_mask"] is not None:
loss = apply_masked_loss(loss, batch["alpha_mask"])
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight

View File

@@ -474,10 +474,8 @@ def train(args):
target = noise
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
if args.masked_loss:
loss = apply_masked_loss(loss, batch["conditioning_images"][:, 0].unsqueeze(1))
if "alpha_mask" in batch and batch["alpha_mask"] is not None:
loss = apply_masked_loss(loss, batch["alpha_mask"])
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight