From ef051427df4387ab056c02eddba03c1c6a110fa0 Mon Sep 17 00:00:00 2001 From: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Mon, 16 Feb 2026 07:58:15 +0900 Subject: [PATCH] fix: `str is not "no"` to `str != "no"` --- library/deepspeed_utils.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/library/deepspeed_utils.py b/library/deepspeed_utils.py index a8a05c3a..4daeb254 100644 --- a/library/deepspeed_utils.py +++ b/library/deepspeed_utils.py @@ -96,7 +96,7 @@ def prepare_deepspeed_plugin(args: argparse.Namespace): deepspeed_plugin.deepspeed_config["train_batch_size"] = ( args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"]) ) - + deepspeed_plugin.set_mixed_precision(args.mixed_precision) if args.mixed_precision.lower() == "fp16": deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0 # preventing overflow. @@ -125,18 +125,18 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models): class DeepSpeedWrapper(torch.nn.Module): def __init__(self, **kw_models) -> None: super().__init__() - + self.models = torch.nn.ModuleDict() - - wrap_model_forward_with_torch_autocast = args.mixed_precision is not "no" + + wrap_model_forward_with_torch_autocast = args.mixed_precision != "no" for key, model in kw_models.items(): if isinstance(model, list): model = torch.nn.ModuleList(model) - + if wrap_model_forward_with_torch_autocast: - model = self.__wrap_model_with_torch_autocast(model) - + model = self.__wrap_model_with_torch_autocast(model) + assert isinstance( model, torch.nn.Module ), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}" @@ -151,7 +151,7 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models): return model def __wrap_model_forward_with_torch_autocast(self, model): - + assert hasattr(model, "forward"), f"model must have a forward method." forward_fn = model.forward @@ -161,20 +161,19 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models): device_type = model.device.type except AttributeError: logger.warning( - "[DeepSpeed] model.device is not available. Using get_preferred_device() " - "to determine the device_type for torch.autocast()." - ) + "[DeepSpeed] model.device is not available. Using get_preferred_device() " + "to determine the device_type for torch.autocast()." + ) device_type = get_preferred_device().type - with torch.autocast(device_type = device_type): + with torch.autocast(device_type=device_type): return forward_fn(*args, **kwargs) model.forward = forward return model - + def get_models(self): return self.models - ds_model = DeepSpeedWrapper(**models) return ds_model