Tried policy with diffusion model
This commit is contained in:
@@ -48,6 +48,7 @@ class DiffusionTrainer:
|
||||
"""
|
||||
return torch.randint(low=1, high=self.noise_steps, size=(n,))
|
||||
|
||||
|
||||
def sample(self, model: DiffusionModel, n: int, inputs: torch.tensor):
|
||||
inputs = inputs.repeat(n, 1).to(self.device)
|
||||
model.eval()
|
||||
|
||||
Reference in New Issue
Block a user