diff --git a/lumina_train.py b/lumina_train.py
index 0a91f4a0..a333427d 100644
--- a/lumina_train.py
+++ b/lumina_train.py
@@ -294,7 +294,7 @@ def train(args):
     # load lumina
     nextdit = lumina_util.load_lumina_model(
         args.pretrained_model_name_or_path,
-        loading_dtype,
+        weight_dtype,
         torch.device("cpu"),
         disable_mmap=args.disable_mmap_load_safetensors,
         use_flash_attn=args.use_flash_attn,
@@ -494,6 +494,8 @@ def train(args):

     clean_memory_on_device(accelerator.device)

+    is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
+
     if args.deepspeed:
         ds_model = deepspeed_utils.prepare_deepspeed_model(args, nextdit=nextdit)
         # most of the ZeRO stages use optimizer partitioning, so we have to prepare the optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
@@ -739,7 +741,7 @@ def train(args):
                 with accelerator.autocast():
                     # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it, but I want to keep the inputs the same for the model for testing)
                     model_pred = nextdit(
-                        x=img,  # image latents (B, C, H, W)
+                        x=noisy_model_input,  # image latents (B, C, H, W)
                         t=timesteps / 1000,  # timesteps are divided by 1000 to match what the model expects
                         cap_feats=gemma2_hidden_states,  # Gemma2 hidden states as caption features
                         cap_mask=gemma2_attn_mask.to(
@@ -751,8 +753,8 @@ def train(args):
                     args, model_pred, noisy_model_input, sigmas
                 )

-                # flow matching loss: this is different from SD3
-                target = noise - latents
+                # flow matching loss
+                target = latents - noise

                 # calculate loss
                 huber_c = train_util.get_huber_threshold_if_needed(
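
Note on the new `is_swapping_blocks` flag: this diff only introduces the boolean, not the swapping itself, but in sd-scripts `--blocks_to_swap` generally means keeping some transformer blocks off the GPU and streaming them in only while they run. The following is a toy illustration of that idea, not the project's actual offloading code; `forward_with_block_swap` is a hypothetical helper:

```python
import torch
import torch.nn as nn

def forward_with_block_swap(blocks: nn.ModuleList, x: torch.Tensor,
                            blocks_to_swap: int, device: torch.device) -> torch.Tensor:
    """Toy sketch (NOT sd-scripts' real implementation): hold the first
    `blocks_to_swap` blocks on CPU and move each to the GPU only for its own
    forward pass, trading transfer time for lower peak VRAM."""
    for i, block in enumerate(blocks):
        swap = i < blocks_to_swap
        if swap:
            block.to(device)   # bring weights onto the GPU just in time
        x = block(x)
        if swap:
            block.to("cpu")    # evict immediately to free VRAM for later blocks
    return x
```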
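
Note on `target = latents - noise`: one plausible reading of the sign flip is that Lumina's flow-matching convention places pure noise at t=0 and clean data at t=1, so the model predicts the velocity pointing from noise toward data; SD3/FLUX use the opposite orientation, which is why the old target was `noise - latents`. A minimal sketch under that assumption (a linear interpolant; `make_flow_matching_pair` is a hypothetical helper, not part of lumina_train.py):

```python
import torch

def make_flow_matching_pair(latents: torch.Tensor, noise: torch.Tensor, t: torch.Tensor):
    """Return (noisy_model_input, target) for one training step.

    Assumes t=0 is pure noise and t=1 is clean data (the convention this PR
    appears to use). latents, noise: (B, C, H, W); t: (B,) in [0, 1].
    """
    t = t.view(-1, 1, 1, 1)                               # broadcast over C, H, W
    noisy_model_input = t * latents + (1.0 - t) * noise   # x_t = t*x1 + (1-t)*x0
    target = latents - noise                              # dx_t/dt, constant on a linear path
    return noisy_model_input, target
```

Differentiating the same linear path under the SD3/FLUX orientation (t=0 data, t=1 noise) yields `noise - latents`, matching the line this PR replaces.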