diff --git a/networks/flux_extract_lora.py b/networks/flux_extract_lora.py index f1ae8f96..668584e9 100644 --- a/networks/flux_extract_lora.py +++ b/networks/flux_extract_lora.py @@ -80,10 +80,19 @@ def svd( keys.append(key) with open_fn(model_tuned) as f_tuned: - for key in tqdm(keys): + for original_key in tqdm(keys): + # if the key has annoying prefix, remove it + key = original_key + if key.startswith("model.diffusion_model."): + key = key.replace("model.diffusion_model.", "") # normalize key name + # get tensors and calculate difference - value_o = f_org.get_tensor(key) - value_t = f_tuned.get_tensor(key) + value_o = f_org.get_tensor(original_key) # the original model has this key + try: + value_t = f_tuned.get_tensor(original_key) # the tuned model may not have this key + except: + value_t = f_tuned.get_tensor(key) + mat = value_t.to(calc_dtype) - value_o.to(calc_dtype) del value_o, value_t