Fix block swap for sample images

This commit is contained in:
rockerBOO
2025-02-28 14:08:27 -05:00
parent 9647f1e324
commit d6f7e2e20c
3 changed files with 3 additions and 2 deletions

View File

@@ -317,7 +317,6 @@ def denoise(
# this is ignored for schnell
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
model.prepare_block_swap_before_forward()

View File

@@ -604,7 +604,6 @@ def retrieve_timesteps(
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
def denoise(
scheduler,
model: lumina_models.NextDiT,
@@ -648,6 +647,7 @@ def denoise(
"""
for i, t in enumerate(tqdm(timesteps)):
# reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image
current_timestep = 1 - t / scheduler.config.num_train_timesteps
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
@@ -700,6 +700,7 @@ def denoise(
noise_pred = -noise_pred
img = scheduler.step(noise_pred, t, img, return_dict=False)[0]
model.prepare_block_swap_before_forward()
return img

View File

@@ -367,6 +367,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
def setup_parser() -> argparse.ArgumentParser:
parser = train_network.setup_parser()
train_util.add_dit_training_arguments(parser)