Fix AttributeError: 'T5EncoderModel' object has no attribute 'text_model'

While loading T5 model in GPU.
This commit is contained in:
DukeG
2024-08-14 19:58:54 +08:00
parent 56d7651f08
commit 9760d097b0

View File

@@ -540,9 +540,13 @@ class NetworkTrainer:
# in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16
if t_enc.device.type != "cpu":
t_enc.to(dtype=te_weight_dtype)
if hasattr(t_enc.text_model, "embeddings"):
if hasattr(t_enc, "text_model") and hasattr(t_enc.text_model, "embeddings"):
# nn.Embedding not support FP8
t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
t_enc.text_model.embeddings.to(
dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
elif hasattr(t_enc, "encoder") and hasattr(t_enc.encoder, "embeddings"):
t_enc.encoder.embeddings.to(
dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
# acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
if args.deepspeed: