Add type checking

This commit is contained in:
Isotr0py
2023-02-12 15:32:38 +08:00
parent 92a1af8024
commit 2b1a3080e7

View File

@@ -267,18 +267,19 @@ def train(args):
text_encoder.train()
# set top parameter requires_grad = True for gradient checkpointing works
text_encoder.text_model.embeddings.requires_grad_(True)
if type(text_encoder) == DDP:
text_encoder.module.text_model.embeddings.requires_grad_(True)
else:
text_encoder.text_model.embeddings.requires_grad_(True)
else:
unet.eval()
text_encoder.eval()
# support DistributedDataParallel
try:
text_encoder = text_encoder.module
unet = unet.module
network = network.module
except:
pass
if type(text_encoder) == DDP:
text_encoder = text_encoder.module
unet = unet.module
network = network.module
network.prepare_grad_etc(text_encoder, unet)