diff --git a/library/train_util.py b/library/train_util.py index 8a54cd0c..67df2258 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5512,11 +5512,18 @@ def prepare_accelerator(args: argparse.Namespace): if args.torch_compile: dynamo_backend = args.dynamo_backend - if args.activation_memory_budget: - logger.info( - f"set torch compile activation memory budget to {args.activation_memory_budget}" - ) - torch._functorch.config.activation_memory_budget = args.activation_memory_budget # type: ignore + if args.activation_memory_budget is not None: # Note: 0 is a valid value. + if 0 <= args.activation_memory_budget <= 1: + logger.info( + f"set torch compile activation memory budget to {args.activation_memory_budget}" + ) + torch._functorch.config.activation_memory_budget = ( # type: ignore + args.activation_memory_budget + ) + else: + raise ValueError( + "activation_memory_budget must be between 0 and 1 (inclusive)" + ) kwargs_handlers = [ (