From 07ef03d3403d56a05a87afdabbe6e8964cc05a9d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 12 Dec 2023 08:03:27 +0900 Subject: [PATCH] fix controlnet to work with gradual latent --- tools/original_control_net.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/tools/original_control_net.py b/tools/original_control_net.py index cd47bd76..5e1c074e 100644 --- a/tools/original_control_net.py +++ b/tools/original_control_net.py @@ -179,8 +179,21 @@ def call_unet_and_control_net( return original_unet(sample, timestep, encoder_hidden_states) guided_hint = guided_hints[cnet_idx] + + # gradual latent support: match the size of guided_hint to the size of sample + if guided_hint.shape[-2:] != sample.shape[-2:]: + # print(f"guided_hint.shape={guided_hint.shape}, sample.shape={sample.shape}") + org_dtype = guided_hint.dtype + if org_dtype == torch.bfloat16: + guided_hint = guided_hint.to(torch.float32) + guided_hint = torch.nn.functional.interpolate(guided_hint, size=sample.shape[-2:], mode="bicubic") + if org_dtype == torch.bfloat16: + guided_hint = guided_hint.to(org_dtype) + guided_hint = guided_hint.repeat((num_latent_input, 1, 1, 1)) - outs = unet_forward(True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states_for_control_net) + outs = unet_forward( + True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states_for_control_net + ) outs = [o * cnet_info.weight for o in outs] # U-Net