mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 21:52:27 +00:00
Fix AttributeError: 'T5EncoderModel' object has no attribute 'text_model'
While loading T5 model in GPU.
This commit is contained in:
@@ -540,9 +540,13 @@ class NetworkTrainer:
|
||||
# in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16
|
||||
if t_enc.device.type != "cpu":
|
||||
t_enc.to(dtype=te_weight_dtype)
|
||||
if hasattr(t_enc.text_model, "embeddings"):
|
||||
if hasattr(t_enc, "text_model") and hasattr(t_enc.text_model, "embeddings"):
|
||||
# nn.Embedding not support FP8
|
||||
t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
|
||||
t_enc.text_model.embeddings.to(
|
||||
dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
|
||||
elif hasattr(t_enc, "encoder") and hasattr(t_enc.encoder, "embeddings"):
|
||||
t_enc.encoder.embeddings.to(
|
||||
dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
|
||||
if args.deepspeed:
|
||||
|
||||
Reference in New Issue
Block a user