Merge branch 'dev' into sd3

This commit is contained in:
Kohya S
2024-09-29 10:00:01 +09:00
9 changed files with 62 additions and 10 deletions

View File

@@ -710,6 +710,8 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
- bitsandbytes, transformers, accelerate and huggingface_hub are updated.
- If you encounter any issues, please report them.
- There was a bug where the min_bucket_reso/max_bucket_reso in the dataset configuration did not create the correct resolution bucket if it was not divisible by bucket_reso_steps. These values are now warned and automatically rounded to a divisible value. Thanks to Maru-mee for raising the issue. Related PR [#1632](https://github.com/kohya-ss/sd-scripts/pull/1632)
- `bitsandbytes` is updated to 0.44.0. Now you can use `AdEMAMix8bit` and `PagedAdEMAMix8bit` in the training script. PR [#1640](https://github.com/kohya-ss/sd-scripts/pull/1640) Thanks to sdbds!
- There is no abbreviation, so please specify the full path like `--optimizer_type bitsandbytes.optim.AdEMAMix8bit` (not bnb but bitsandbytes).

View File

@@ -128,6 +128,8 @@ These are options related to the configuration of the data set. They cannot be d
* `batch_size`
* This corresponds to the command-line argument `--train_batch_size`.
* `max_bucket_reso`, `min_bucket_reso`
* Specify the maximum and minimum resolutions of the bucket. It must be divisible by `bucket_reso_steps`.
These settings are fixed per dataset. That means that subsets belonging to the same dataset will share these settings. For example, if you want to prepare datasets with different resolutions, you can define them as separate datasets as shown in the example above, and set different resolutions for each.

View File

@@ -118,6 +118,8 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学
* `batch_size`
* コマンドライン引数の `--train_batch_size` と同等です。
* `max_bucket_reso`, `min_bucket_reso`
* bucketの最大、最小解像度を指定します。`bucket_reso_steps` で割り切れる必要があります。
これらの設定はデータセットごとに固定です。
つまり、データセットに所属するサブセットはこれらの設定を共有することになります。

View File

@@ -100,6 +100,8 @@ def train(args):
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
train_dataset_group.verify_bucket_reso_steps(64)
if args.debug_dataset:
train_util.debug_dataset(train_dataset_group)
return

View File

@@ -661,6 +661,34 @@ class BaseDataset(torch.utils.data.Dataset):
self.text_encoder_output_caching_strategy = TextEncoderOutputsCachingStrategy.get_strategy()
self.latents_caching_strategy = LatentsCachingStrategy.get_strategy()
def adjust_min_max_bucket_reso_by_steps(
self, resolution: Tuple[int, int], min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int
) -> Tuple[int, int]:
# make min/max bucket reso to be multiple of bucket_reso_steps
if min_bucket_reso % bucket_reso_steps != 0:
adjusted_min_bucket_reso = min_bucket_reso - min_bucket_reso % bucket_reso_steps
logger.warning(
f"min_bucket_reso is adjusted to be multiple of bucket_reso_steps"
f" / min_bucket_resoがbucket_reso_stepsの倍数になるように調整されました: {min_bucket_reso} -> {adjusted_min_bucket_reso}"
)
min_bucket_reso = adjusted_min_bucket_reso
if max_bucket_reso % bucket_reso_steps != 0:
adjusted_max_bucket_reso = max_bucket_reso + bucket_reso_steps - max_bucket_reso % bucket_reso_steps
logger.warning(
f"max_bucket_reso is adjusted to be multiple of bucket_reso_steps"
f" / max_bucket_resoがbucket_reso_stepsの倍数になるように調整されました: {max_bucket_reso} -> {adjusted_max_bucket_reso}"
)
max_bucket_reso = adjusted_max_bucket_reso
assert (
min(resolution) >= min_bucket_reso
), f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください"
assert (
max(resolution) <= max_bucket_reso
), f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
return min_bucket_reso, max_bucket_reso
def set_seed(self, seed):
self.seed = seed
@@ -1707,12 +1735,9 @@ class DreamBoothDataset(BaseDataset):
self.enable_bucket = enable_bucket
if self.enable_bucket:
assert (
min(resolution) >= min_bucket_reso
), f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください"
assert (
max(resolution) <= max_bucket_reso
), f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
min_bucket_reso, max_bucket_reso = self.adjust_min_max_bucket_reso_by_steps(
resolution, min_bucket_reso, max_bucket_reso, bucket_reso_steps
)
self.min_bucket_reso = min_bucket_reso
self.max_bucket_reso = max_bucket_reso
self.bucket_reso_steps = bucket_reso_steps
@@ -2085,6 +2110,9 @@ class FineTuningDataset(BaseDataset):
self.enable_bucket = enable_bucket
if self.enable_bucket:
min_bucket_reso, max_bucket_reso = self.adjust_min_max_bucket_reso_by_steps(
resolution, min_bucket_reso, max_bucket_reso, bucket_reso_steps
)
self.min_bucket_reso = min_bucket_reso
self.max_bucket_reso = max_bucket_reso
self.bucket_reso_steps = bucket_reso_steps
@@ -4149,8 +4177,20 @@ def add_dataset_arguments(
action="store_true",
help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする",
)
parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度")
parser.add_argument(
"--min_bucket_reso",
type=int,
default=256,
help="minimum resolution for buckets, must be divisible by bucket_reso_steps "
" / bucketの最小解像度、bucket_reso_stepsで割り切れる必要があります",
)
parser.add_argument(
"--max_bucket_reso",
type=int,
default=1024,
help="maximum resolution for buckets, must be divisible by bucket_reso_steps "
" / bucketの最大解像度、bucket_reso_stepsで割り切れる必要があります",
)
parser.add_argument(
"--bucket_reso_steps",
type=int,

View File

@@ -107,6 +107,8 @@ def train(args):
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
train_dataset_group.verify_bucket_reso_steps(64)
if args.debug_dataset:
train_util.debug_dataset(train_dataset_group)
return

View File

@@ -101,6 +101,8 @@ def train(args):
if args.no_token_padding:
train_dataset_group.disable_token_padding()
train_dataset_group.verify_bucket_reso_steps(64)
if args.debug_dataset:
train_util.debug_dataset(train_dataset_group)
return

View File

@@ -96,7 +96,7 @@ class NetworkTrainer:
return logs
def assert_extra_args(self, args, train_dataset_group):
pass
train_dataset_group.verify_bucket_reso_steps(64)
def load_target_model(self, args, weight_dtype, accelerator):
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)

View File

@@ -100,7 +100,7 @@ class TextualInversionTrainer:
self.is_sdxl = False
def assert_extra_args(self, args, train_dataset_group):
pass
train_dataset_group.verify_bucket_reso_steps(64)
def load_target_model(self, args, weight_dtype, accelerator):
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)