Tried policy with diffusion model

This commit is contained in:
Victor Mylle
2023-12-29 12:30:30 +00:00
parent da3ab3d5b3
commit ef8b5f49ac
3 changed files with 113 additions and 118 deletions

View File

@@ -48,6 +48,7 @@ class DiffusionTrainer:
"""
return torch.randint(low=1, high=self.noise_steps, size=(n,))
def sample(self, model: DiffusionModel, n: int, inputs: torch.tensor):
inputs = inputs.repeat(n, 1).to(self.device)
model.eval()