Merge pull request #755 from kohya-ss/dev

add lora_fa
This commit is contained in:
Kohya S
2023-08-13 15:20:49 +09:00
committed by GitHub
3 changed files with 1296 additions and 2 deletions

View File

@@ -22,7 +22,11 @@ __Stable Diffusion web UI now seems to support LoRA trained by ``sd-scripts``.__
The feature of SDXL training is now available in sdxl branch as an experimental feature.
Aug 12, 2023: The feature will be merged into the main branch soon. Following are the changes from the previous version.
Aug 13, 2023: The feature will be merged into the main branch soon. Following are the changes from the previous version.
- LoRA-FA is added experimentally. Specify `--network_module networks.lora_fa` option instead of `--network_module networks.lora`. The trained model can be used as a normal LoRA model.
Aug 12, 2023: Following are the changes from the previous version.
- The default value of noise offset when omitted has been changed to 0 from 0.0357.
- The different learning rates for each U-Net block are now supported. Specify with `--block_lr` option. Specify 23 values separated by commas like `--block_lr 1e-3,1e-3 ... 1e-3`.

1241
networks/lora_fa.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -371,6 +371,8 @@ class PipelineLike:
width: int = 1024,
original_height: int = None,
original_width: int = None,
original_height_negative: int = None,
original_width_negative: int = None,
crop_top: int = 0,
crop_left: int = 0,
num_inference_steps: int = 50,
@@ -505,15 +507,22 @@ class PipelineLike:
original_height = height
if original_width is None:
original_width = width
if original_height_negative is None:
original_height_negative = original_height
if original_width_negative is None:
original_width_negative = original_width
if crop_top is None:
crop_top = 0
if crop_left is None:
crop_left = 0
emb1 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256)
uc_emb1 = sdxl_train_util.get_timestep_embedding(
torch.FloatTensor([original_height_negative, original_width_negative]).unsqueeze(0), 256
)
emb2 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256)
emb3 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([height, width]).unsqueeze(0), 256)
c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype).repeat(batch_size, 1)
uc_vector = c_vector.clone().to(self.device, dtype=text_embeddings.dtype)
uc_vector = torch.cat([uc_emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype).repeat(batch_size, 1)
c_vector = torch.cat([text_pool, c_vector], dim=1)
uc_vector = torch.cat([uncond_pool, uc_vector], dim=1)
@@ -1260,6 +1269,8 @@ class BatchDataExt(NamedTuple):
height: int
original_width: int
original_height: int
original_width_negative: int
original_height_negative: int
crop_left: int
crop_top: int
steps: int
@@ -1820,6 +1831,8 @@ def main(args):
original_width_1st = scale_and_round(ext.original_width)
original_height_1st = scale_and_round(ext.original_height)
original_width_negative_1st = scale_and_round(ext.original_width_negative)
original_height_negative_1st = scale_and_round(ext.original_height_negative)
crop_left_1st = scale_and_round(ext.crop_left)
crop_top_1st = scale_and_round(ext.crop_top)
@@ -1830,6 +1843,8 @@ def main(args):
height_1st,
original_width_1st,
original_height_1st,
original_width_negative_1st,
original_height_negative_1st,
crop_left_1st,
crop_top_1st,
args.highres_fix_steps,
@@ -1897,6 +1912,8 @@ def main(args):
height,
original_width,
original_height,
original_width_negative,
original_height_negative,
crop_left,
crop_top,
steps,
@@ -2020,6 +2037,8 @@ def main(args):
width,
original_height,
original_width,
original_height_negative,
original_width_negative,
crop_top,
crop_left,
steps,
@@ -2060,6 +2079,8 @@ def main(args):
metadata.add_text("clip-prompt", clip_prompt)
metadata.add_text("original-height", str(original_height))
metadata.add_text("original-width", str(original_width))
metadata.add_text("original-height-negative", str(original_height_negative))
metadata.add_text("original-width-negative", str(original_width_negative))
metadata.add_text("crop-top", str(crop_top))
metadata.add_text("crop-left", str(crop_left))
@@ -2123,6 +2144,8 @@ def main(args):
height = args.H
original_width = args.original_width
original_height = args.original_height
original_width_negative = args.original_width_negative
original_height_negative = args.original_height_negative
crop_top = args.crop_top
crop_left = args.crop_left
scale = args.scale
@@ -2165,6 +2188,18 @@ def main(args):
print(f"original height: {original_height}")
continue
m = re.match(r"nw (\d+)", parg, re.IGNORECASE)
if m:
original_width_negative = int(m.group(1))
print(f"original width negative: {original_width_negative}")
continue
m = re.match(r"nh (\d+)", parg, re.IGNORECASE)
if m:
original_height_negative = int(m.group(1))
print(f"original height negative: {original_height_negative}")
continue
m = re.match(r"ct (\d+)", parg, re.IGNORECASE)
if m:
crop_top = int(m.group(1))
@@ -2301,6 +2336,8 @@ def main(args):
height,
original_width,
original_height,
original_width_negative,
original_height_negative,
crop_left,
crop_top,
steps,
@@ -2367,6 +2404,18 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--original_width", type=int, default=None, help="original width for SDXL conditioning / SDXLの条件付けに用いるoriginal widthの値"
)
parser.add_argument(
"--original_height_negative",
type=int,
default=None,
help="original height for SDXL unconditioning / SDXLのネガティブ条件付けに用いるoriginal heightの値",
)
parser.add_argument(
"--original_width_negative",
type=int,
default=None,
help="original width for SDXL unconditioning / SDXLのネガティブ条件付けに用いるoriginal widthの値",
)
parser.add_argument("--crop_top", type=int, default=None, help="crop top for SDXL conditioning / SDXLの条件付けに用いるcrop topの値")
parser.add_argument("--crop_left", type=int, default=None, help="crop left for SDXL conditioning / SDXLの条件付けに用いるcrop leftの値")
parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ")