diff --git a/library/chroma_models.py b/library/chroma_models.py index b9c54db4..0c93f526 100644 --- a/library/chroma_models.py +++ b/library/chroma_models.py @@ -695,6 +695,7 @@ class Chroma(Flux): input_vec = self.get_input_vec(timesteps, guidance, img.shape[0]) # kohya-ss: I'm not sure why requires_grad is set to True here + # original code: https://github.com/lodestone-rock/flow/blob/c76f63058980d0488826936025889e256a2e0458/src/models/chroma/model.py#L217 input_vec.requires_grad = True mod_vectors = self.distilled_guidance_layer(input_vec) else: