diff --git a/library/train_util.py b/library/train_util.py index 15c23f3c..9b0397d7 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5229,7 +5229,7 @@ def sample_images_common( clean_memory_on_device(accelerator.device) torch.set_rng_state(rng_state) - if cuda_rng_state is not None: + if torch.cuda.is_available() and cuda_rng_state is not None: torch.cuda.set_rng_state(cuda_rng_state) vae.to(org_vae_device) @@ -5263,11 +5263,13 @@ def sample_image_inference( if seed is not None: torch.manual_seed(seed) - torch.cuda.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) else: # True random sample image generation torch.seed() - torch.cuda.seed() + if torch.cuda.is_available(): + torch.cuda.seed() scheduler = get_my_scheduler( sample_sampler=sampler_name, @@ -5302,8 +5304,9 @@ def sample_image_inference( controlnet_image=controlnet_image, ) - with torch.cuda.device(torch.cuda.current_device()): - torch.cuda.empty_cache() + if torch.cuda.is_available(): + with torch.cuda.device(torch.cuda.current_device()): + torch.cuda.empty_cache() image = pipeline.latents_to_image(latents)[0]