diff --git a/fine_tune.py b/fine_tune.py index 8b06abda..fa3c81be 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -230,7 +230,8 @@ def train(args): with torch.set_grad_enabled(args.train_text_encoder): # Get the text embedding for conditioning input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder) + encoder_hidden_states = train_util.get_hidden_states( + args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype) # Sample noise that we'll add to the latents noise = torch.randn_like(latents, device=latents.device) diff --git a/train_db.py b/train_db.py index d1ef350c..8c9cdb95 100644 --- a/train_db.py +++ b/train_db.py @@ -155,7 +155,7 @@ def train(args): unet, text_encoder, optimizer, train_dataloader, lr_scheduler) else: unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) - + if not train_text_encoder: text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error @@ -227,7 +227,8 @@ def train(args): # Get the text embedding for conditioning with torch.set_grad_enabled(global_step < args.stop_text_encoder_training): input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder) + encoder_hidden_states = train_util.get_hidden_states( + args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype) # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)