Tried policy with diffusion model
This commit is contained in:
@@ -94,7 +94,7 @@ class NrvDataset(Dataset):
|
|||||||
# get indices of all 00:15 timestamps
|
# get indices of all 00:15 timestamps
|
||||||
if self.full_day_skip:
|
if self.full_day_skip:
|
||||||
start_of_day_indices = dataframe[
|
start_of_day_indices = dataframe[
|
||||||
dataframe["datetime"].dt.time != pd.Timestamp("00:15:00").time()
|
dataframe["datetime"].dt.time != pd.Timestamp("00:00:00").time()
|
||||||
].index
|
].index
|
||||||
skip_indices.extend(start_of_day_indices)
|
skip_indices.extend(start_of_day_indices)
|
||||||
skip_indices = list(set(skip_indices))
|
skip_indices = list(set(skip_indices))
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -48,6 +48,7 @@ class DiffusionTrainer:
|
|||||||
"""
|
"""
|
||||||
return torch.randint(low=1, high=self.noise_steps, size=(n,))
|
return torch.randint(low=1, high=self.noise_steps, size=(n,))
|
||||||
|
|
||||||
|
|
||||||
def sample(self, model: DiffusionModel, n: int, inputs: torch.tensor):
|
def sample(self, model: DiffusionModel, n: int, inputs: torch.tensor):
|
||||||
inputs = inputs.repeat(n, 1).to(self.device)
|
inputs = inputs.repeat(n, 1).to(self.device)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|||||||
Reference in New Issue
Block a user