diff --git a/Dockerfile b/Dockerfile
index 482d478..220005b 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,5 +1,7 @@
-FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime
+FROM pytorch/pytorch:1.13.1-cuda11.6-cudnn8-devel
+#FROM getkeops/keops-full:2.1-geomloss0.2.5-cuda11.8-pytorch2.0.0-python3.10
+# FROM pytorch/pytorch:2.1.0-cuda11.8-cudnn8-devel
RUN apt-get update
RUN apt-get install -y git
diff --git a/requirements.txt b/requirements.txt
index ccec0b4..e14eea3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -11,4 +11,6 @@ clearml
properscoring
nbconvert
torchinfo
-tabulate
\ No newline at end of file
+tabulate
+einops
+opt_einsum
\ No newline at end of file
diff --git a/src/models/tsdiff_s4/backbones.py b/src/models/tsdiff_s4/backbones.py
new file mode 100644
index 0000000..9224dd8
--- /dev/null
+++ b/src/models/tsdiff_s4/backbones.py
@@ -0,0 +1,172 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+import math
+
+import torch
+from torch import nn
+
+from src.models.tsdiff_s4.s4 import S4
+
+
+class SinusoidalPositionEmbeddings(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.dim = dim
+
+ def forward(self, time):
+ device = time.device
+ half_dim = self.dim // 2
+ embeddings = math.log(10000) / (half_dim - 1)
+ embeddings = torch.exp(
+ torch.arange(half_dim, device=device) * -embeddings
+ )
+ embeddings = time[:, None] * embeddings[None, :]
+ embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
+ return embeddings
+
+
+class S4Layer(nn.Module):
+ def __init__(
+ self,
+ d_model,
+ dropout=0.0,
+ ):
+ super().__init__()
+ self.layer = S4(
+ d_model=d_model,
+ d_state=128,
+ bidirectional=True,
+ dropout=dropout,
+ transposed=True,
+ postact=None,
+ )
+ self.norm = nn.LayerNorm(d_model)
+ self.dropout = (
+ nn.Dropout1d(dropout) if dropout > 0.0 else nn.Identity()
+ )
+
+ def forward(self, x):
+ """
+ Input x is shape (B, d_input, L)
+ """
+ z = x
+ # Prenorm
+ z = self.norm(z.transpose(-1, -2)).transpose(-1, -2)
+ # Apply layer: we ignore the state input and output for training
+ z, _ = self.layer(z)
+ # Dropout on the output of the layer
+ z = self.dropout(z)
+ # Residual connection
+ x = z + x
+ return x, None
+
+ def default_state(self, *args, **kwargs):
+ return self.layer.default_state(*args, **kwargs)
+
+ def step(self, x, state, **kwargs):
+ z = x
+ # Prenorm
+ z = self.norm(z.transpose(-1, -2)).transpose(-1, -2)
+ # Apply layer
+ z, state = self.layer.step(z, state, **kwargs)
+ # Residual connection
+ x = z + x
+ return x, state
+
+
+class S4Block(nn.Module):
+ def __init__(self, d_model, dropout=0.0, expand=2, num_features=0):
+ super().__init__()
+ self.s4block = S4Layer(d_model, dropout=dropout)
+
+ self.time_linear = nn.Linear(d_model, d_model)
+ self.tanh = nn.Tanh()
+ self.sigm = nn.Sigmoid()
+ self.out_linear1 = nn.Conv1d(
+ in_channels=d_model, out_channels=d_model, kernel_size=1
+ )
+ self.out_linear2 = nn.Conv1d(
+ in_channels=d_model, out_channels=d_model, kernel_size=1
+ )
+ self.feature_encoder = nn.Conv1d(num_features, d_model, kernel_size=1)
+
+ def forward(self, x, t, features=None):
+ t = self.time_linear(t)[:, None, :].repeat(1, x.shape[2], 1)
+ t = t.transpose(-1, -2)
+ out, _ = self.s4block(x + t)
+ if features is not None:
+ out = out + self.feature_encoder(features)
+ out = self.tanh(out) * self.sigm(out)
+ out1 = self.out_linear1(out)
+ out2 = self.out_linear2(out)
+ return out1 + x, out2
+
+
+def Conv1dKaiming(in_channels, out_channels, kernel_size):
+ layer = nn.Conv1d(in_channels, out_channels, kernel_size)
+ nn.init.kaiming_normal_(layer.weight)
+ return layer
+
+
+class BackboneModel(nn.Module):
+ def __init__(
+ self,
+ input_dim,
+ hidden_dim,
+ output_dim,
+ step_emb,
+ num_residual_blocks,
+ num_features,
+ residual_block="s4",
+ dropout=0.0,
+ init_skip=True,
+ ):
+ super().__init__()
+ if residual_block == "s4":
+ residual_block = S4Block
+ else:
+ raise ValueError(f"Unknown residual block {residual_block}")
+ self.input_init = nn.Sequential(
+ nn.Linear(input_dim, hidden_dim),
+ nn.ReLU(),
+ )
+ self.time_init = nn.Sequential(
+ nn.Linear(step_emb, hidden_dim),
+ nn.SiLU(),
+ nn.Linear(hidden_dim, hidden_dim),
+ nn.SiLU(),
+ )
+ self.out_linear = nn.Sequential(
+ nn.Linear(hidden_dim, hidden_dim),
+ nn.ReLU(),
+ nn.Linear(hidden_dim, output_dim),
+ )
+ residual_blocks = []
+ for i in range(num_residual_blocks):
+ residual_blocks.append(
+ residual_block(
+ hidden_dim, num_features=num_features, dropout=dropout
+ )
+ )
+ self.residual_blocks = nn.ModuleList(residual_blocks)
+ self.step_embedding = SinusoidalPositionEmbeddings(step_emb)
+ self.init_skip = init_skip
+
+ def forward(self, input, t, features=None):
+ x = self.input_init(input) # B, L ,C
+ step_emb = self.step_embedding(t)
+ t = self.time_init(step_emb)
+ x = x.transpose(-1, -2)
+ if features is not None:
+ features = features.transpose(-1, -2)
+ skips = []
+ for layer in self.residual_blocks:
+ x, skip = layer(x, t, features)
+ skips.append(skip)
+
+ skip = torch.stack(skips).sum(0)
+ skip = skip.transpose(-1, -2)
+ out = self.out_linear(skip)
+ if self.init_skip:
+ out = out + input
+ return out
\ No newline at end of file
diff --git a/src/models/tsdiff_s4/s4.py b/src/models/tsdiff_s4/s4.py
new file mode 100644
index 0000000..7dbc27b
--- /dev/null
+++ b/src/models/tsdiff_s4/s4.py
@@ -0,0 +1,1836 @@
+"""Standalone version of Structured (Sequence) State Space (S4) model."""
+
+import logging
+from functools import partial
+import math
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from pytorch_lightning.utilities import rank_zero_only
+from einops import rearrange, repeat
+import opt_einsum as oe
+
+contract = oe.contract
+contract_expression = oe.contract_expression
+
+
+def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
+ """Initializes multi-GPU-friendly python logger."""
+
+ logger = logging.getLogger(name)
+ logger.setLevel(level)
+
+ # this ensures all logging levels get marked with the rank zero decorator
+ # otherwise logs would get multiplied for each GPU process in multi-GPU setup
+ for level in (
+ "debug",
+ "info",
+ "warning",
+ "error",
+ "exception",
+ "fatal",
+ "critical",
+ ):
+ setattr(logger, level, rank_zero_only(getattr(logger, level)))
+
+ return logger
+
+
+log = get_logger(__name__)
+
+""" Cauchy and Vandermonde kernels """
+
+try: # Try CUDA extension
+ from extensions.cauchy.cauchy import cauchy_mult
+
+ has_cauchy_extension = True
+except ImportError:
+ # log.warning(
+ # "CUDA extension for cauchy multiplication not found. Install by going to extensions/cauchy/ and running `python setup.py install`. This should speed up end-to-end training by 10-50%"
+ # )
+ has_cauchy_extension = False
+
+try: # Try pykeops
+ from pykeops.torch import Genred
+
+ has_pykeops = True
+ log.info("Pykeops installation found.")
+
+ def _broadcast_dims(*tensors):
+ max_dim = max([len(tensor.shape) for tensor in tensors])
+ tensors = [
+ tensor.view((1,) * (max_dim - len(tensor.shape)) + tensor.shape)
+ for tensor in tensors
+ ]
+ return tensors
+
+ def cauchy_conj(v, z, w):
+ """Pykeops version"""
+ expr_num = "z * ComplexReal(v) - Real2Complex(Sum(v * w))"
+ expr_denom = "ComplexMult(z-w, z-Conj(w))"
+
+ cauchy_mult = Genred(
+ f"ComplexDivide({expr_num}, {expr_denom})",
+ [
+ "v = Vj(2)",
+ "z = Vi(2)",
+ "w = Vj(2)",
+ ],
+ reduction_op="Sum",
+ axis=1,
+ )
+
+ v, z, w = _broadcast_dims(v, z, w)
+ v = _c2r(v)
+ z = _c2r(z)
+ w = _c2r(w)
+
+ r = 2 * cauchy_mult(v, z, w, backend="GPU")
+ return _r2c(r)
+
+ def log_vandermonde(v, x, L):
+ expr = "ComplexMult(v, ComplexExp(ComplexMult(x, l)))"
+ vandermonde_mult = Genred(
+ expr,
+ [
+ "v = Vj(2)",
+ "x = Vj(2)",
+ "l = Vi(2)",
+ ],
+ reduction_op="Sum",
+ axis=1,
+ )
+
+ l = torch.arange(L).to(x)
+ v, x, l = _broadcast_dims(v, x, l)
+ v = _c2r(v)
+ x = _c2r(x)
+ l = _c2r(l)
+
+ r = vandermonde_mult(v, x, l, backend="GPU")
+ return 2 * _r2c(r).real
+
+ def log_vandermonde_transpose(u, v, x, L):
+ """
+ u: ... H L
+ v: ... H N
+ x: ... H N
+ Returns: ... H N
+
+ V = Vandermonde(a, L) : (H N L)
+ contract_L(V * u * v)
+ """
+ expr = "ComplexMult(ComplexMult(v, u), ComplexExp(ComplexMult(x, l)))"
+ vandermonde_mult = Genred(
+ expr,
+ [
+ "u = Vj(2)",
+ "v = Vi(2)",
+ "x = Vi(2)",
+ "l = Vj(2)",
+ ],
+ reduction_op="Sum",
+ axis=1,
+ )
+
+ l = torch.arange(L).to(x)
+ u, v, x, l = _broadcast_dims(u, v, x, l)
+ u = _c2r(u)
+ v = _c2r(v)
+ x = _c2r(x)
+ l = _c2r(l)
+
+ r = vandermonde_mult(u, v, x, l, backend="GPU")
+ return _r2c(r)
+
+except ImportError:
+ has_pykeops = False
+ if not has_cauchy_extension:
+ log.warning(
+ "Falling back on slow Cauchy kernel. Install at least one of pykeops or the CUDA extension for efficiency."
+ )
+
+ def cauchy_naive(v, z, w):
+ """
+ v, w: (..., N)
+ z: (..., L)
+ returns: (..., L)
+ """
+ cauchy_matrix = v.unsqueeze(-1) / (
+ z.unsqueeze(-2) - w.unsqueeze(-1)
+ ) # (... N L)
+ return torch.sum(cauchy_matrix, dim=-2)
+
+ # Vandermonde functions
+ log.warning(
+ "Falling back on slow Vandermonde kernel. Install pykeops for improved memory efficiency."
+ )
+
+ def log_vandermonde(v, x, L):
+ """
+ v: (..., N)
+ x: (..., N)
+ returns: (..., L) \sum v x^l
+ """
+ vandermonde_matrix = torch.exp(
+ x.unsqueeze(-1) * torch.arange(L).to(x)
+ ) # (... N L)
+ vandermonde_prod = contract(
+ "... n, ... n l -> ... l", v, vandermonde_matrix
+ ) # (... L)
+ return 2 * vandermonde_prod.real
+
+ def log_vandermonde_transpose(u, v, x, L):
+ vandermonde_matrix = torch.exp(
+ x.unsqueeze(-1) * torch.arange(L).to(x)
+ ) # (... N L)
+ vandermonde_prod = contract(
+ "... l, ... n, ... n l -> ... n",
+ u.to(x),
+ v.to(x),
+ vandermonde_matrix,
+ ) # (... L)
+ return vandermonde_prod
+
+
+def _conj(x):
+ return torch.cat([x, x.conj()], dim=-1)
+
+
+_c2r = torch.view_as_real
+_r2c = torch.view_as_complex
+if tuple(map(int, torch.__version__.split(".")[:2])) >= (1, 10):
+
+ def _resolve_conj(x):
+ return x.conj().resolve_conj()
+
+else:
+
+ def _resolve_conj(x):
+ return x.conj()
+
+
+""" Simple nn.Module components """
+
+
+def Activation(activation=None, dim=-1):
+ if activation in [None, "id", "identity", "linear"]:
+ return nn.Identity()
+ elif activation == "tanh":
+ return nn.Tanh()
+ elif activation == "relu":
+ return nn.ReLU()
+ elif activation == "gelu":
+ return nn.GELU()
+ elif activation in ["swish", "silu"]:
+ return nn.SiLU()
+ elif activation == "glu":
+ return nn.GLU(dim=dim)
+ elif activation == "sigmoid":
+ return nn.Sigmoid()
+ else:
+ raise NotImplementedError(
+ "hidden activation '{}' is not implemented".format(activation)
+ )
+
+
+def LinearActivation(
+ d_input,
+ d_output,
+ bias=True,
+ transposed=False,
+ activation=None,
+ activate=False, # Apply activation as part of this module
+ **kwargs,
+):
+ """Returns a linear nn.Module with control over axes order, initialization, and activation"""
+
+ # Construct core module
+ linear_cls = partial(nn.Conv1d, kernel_size=1) if transposed else nn.Linear
+ if activation == "glu":
+ d_output *= 2
+ linear = linear_cls(d_input, d_output, bias=bias, **kwargs)
+
+ if activate and activation is not None:
+ activation = Activation(activation, dim=-2 if transposed else -1)
+ linear = nn.Sequential(linear, activation)
+ return linear
+
+
+class DropoutNd(nn.Module):
+ def __init__(self, p: float = 0.5, tie=True, transposed=True):
+ """
+ tie: tie dropout mask across sequence lengths (Dropout1d/2d/3d)
+ """
+ super().__init__()
+ if p < 0 or p >= 1:
+ raise ValueError(
+ "dropout probability has to be in [0, 1), "
+ "but got {}".format(p)
+ )
+ self.p = p
+ self.tie = tie
+ self.transposed = transposed
+ self.binomial = torch.distributions.binomial.Binomial(probs=1 - self.p)
+
+ def forward(self, X):
+ """X: (batch, dim, lengths...)"""
+ if self.training:
+ if not self.transposed:
+ X = rearrange(X, "b d ... -> b ... d")
+ mask_shape = (
+ X.shape[:2] + (1,) * (X.ndim - 2) if self.tie else X.shape
+ )
+ mask = torch.rand(*mask_shape, device=X.device) < 1.0 - self.p
+ X = X * mask * (1.0 / (1 - self.p))
+ if not self.transposed:
+ X = rearrange(X, "b ... d -> b d ...")
+ return X
+ return X
+
+
+""" Misc functional utilities """
+
+
+def power(L, A, v=None):
+ """Compute A^L and the scan sum_i A^i v_i
+
+ A: (..., N, N)
+ v: (..., N, L)
+ """
+
+ I = torch.eye(A.shape[-1]).to(A) # , dtype=A.dtype, device=A.device)
+
+ powers = [A]
+ l = 1
+ while True:
+ if L % 2 == 1:
+ I = powers[-1] @ I
+ L //= 2
+ if L == 0:
+ break
+ l *= 2
+ powers.append(powers[-1] @ powers[-1])
+
+ if v is None:
+ return I
+
+ # Invariants:
+ # powers[-1] := A^l
+ # l := largest po2 at most L
+
+ # Note that an alternative divide and conquer to compute the reduction is possible and can be embedded into the above loop without caching intermediate powers of A
+ # We do this reverse divide-and-conquer for efficiency reasons:
+ # 1) it involves fewer padding steps for non-po2 L
+ # 2) it involves more contiguous arrays
+
+ # Take care of edge case for non-po2 arrays
+ # Note that this initial step is a no-op for the case of power of 2 (l == L)
+ k = v.size(-1) - l
+ v_ = powers.pop() @ v[..., l:]
+ v = v[..., :l]
+ v[..., :k] = v[..., :k] + v_
+
+ # Handle reduction for power of 2
+ while v.size(-1) > 1:
+ v = rearrange(v, "... (z l) -> ... z l", z=2)
+ v = v[..., 0, :] + powers.pop() @ v[..., 1, :]
+ return I, v.squeeze(-1)
+
+
+""" HiPPO utilities """
+
+
+def transition(measure, N):
+ """A, B transition matrices for different measures"""
+ # Legendre (translated)
+ if measure == "legt":
+ Q = np.arange(N, dtype=np.float64)
+ R = (2 * Q + 1) ** 0.5
+ j, i = np.meshgrid(Q, Q)
+ A = R[:, None] * np.where(i < j, (-1.0) ** (i - j), 1) * R[None, :]
+ B = R[:, None]
+ A = -A
+
+ # Halve again for timescale correctness
+ A *= 0.5
+ B *= 0.5
+ # Legendre (scaled)
+ elif measure == "legs":
+ q = np.arange(N, dtype=np.float64)
+ col, row = np.meshgrid(q, q)
+ r = 2 * q + 1
+ M = -(np.where(row >= col, r, 0) - np.diag(q))
+ T = np.sqrt(np.diag(2 * q + 1))
+ A = T @ M @ np.linalg.inv(T)
+ B = np.diag(T)[:, None]
+ B = (
+ B.copy()
+ ) # Otherwise "UserWarning: given NumPY array is not writeable..." after torch.as_tensor(B)
+ elif measure == "legsd":
+ # Essentially equivalent to S4D-LegS
+ q = np.arange(N, dtype=np.float64)
+ col, row = np.meshgrid(q, q)
+ r = 2 * q + 1
+ M = -(np.where(row >= col, r, 0) - np.diag(q))
+ T = np.sqrt(np.diag(2 * q + 1))
+ A = T @ M @ np.linalg.inv(T)
+ B = np.diag(T)[:, None]
+ B = (
+ B.copy()
+ ) # Otherwise "UserWarning: given NumPY array is not writeable..." after torch.as_tensor(B)
+ A += 0.5 * B * B[None, :, 0]
+ B = B / 2.0
+ elif measure in ["fourier_diag", "foud"]:
+ # Essentially equivalent to S4D-Lin
+ freqs = np.arange(N // 2)
+ d = np.stack([freqs, np.zeros(N // 2)], axis=-1).reshape(-1)[:-1]
+ A = 2 * np.pi * (-np.diag(d, 1) + np.diag(d, -1))
+ A = A - 0.5 * np.eye(N)
+ B = np.zeros(N)
+ B[0::2] = 2**0.5
+ B[0] = 1
+ B = B[:, None]
+ elif measure in ["fourier", "fout"]:
+ freqs = np.arange(N // 2)
+ d = np.stack([np.zeros(N // 2), freqs], axis=-1).reshape(-1)[1:]
+ A = np.pi * (-np.diag(d, 1) + np.diag(d, -1))
+ B = np.zeros(N)
+ B[0::2] = 2**0.5
+ B[0] = 1
+
+ # Subtract off rank correction - this corresponds to the other endpoint u(t-1) in this case
+ A = A - B[:, None] * B[None, :]
+ B = B[:, None]
+ else:
+ raise NotImplementedError
+
+ return A, B
+
+
+def rank_correction(measure, N, rank=1, dtype=torch.float):
+ """Return low-rank matrix L such that A + L is normal"""
+
+ if measure == "legs":
+ assert rank >= 1
+ P = torch.sqrt(0.5 + torch.arange(N, dtype=dtype)).unsqueeze(
+ 0
+ ) # (1 N)
+ elif measure == "legt":
+ assert rank >= 2
+ P = torch.sqrt(1 + 2 * torch.arange(N, dtype=dtype)) # (N)
+ P0 = P.clone()
+ P0[0::2] = 0.0
+ P1 = P.clone()
+ P1[1::2] = 0.0
+ P = torch.stack([P0, P1], dim=0) # (2 N)
+ P *= 2 ** (
+ -0.5
+ ) # Halve the rank correct just like the original matrix was halved
+ elif measure in ["fourier", "fout"]:
+ P = torch.zeros(N)
+ P[0::2] = 2**0.5
+ P[0] = 1
+ P = P.unsqueeze(0)
+ elif measure in ["fourier_diag", "foud", "legsd"]:
+ P = torch.zeros(1, N, dtype=dtype)
+ else:
+ raise NotImplementedError
+
+ d = P.size(0)
+ if rank > d:
+ P = torch.cat(
+ [P, torch.zeros(rank - d, N, dtype=dtype)], dim=0
+ ) # (rank N)
+ return P
+
+
+def nplr(measure, N, rank=1, dtype=torch.float, diagonalize_precision=True):
+ """Return w, p, q, V, B such that
+ (w - p q^*, B) is unitarily equivalent to the original HiPPO A, B by the matrix V
+ i.e. A = V[w - p q^*]V^*, B = V B
+ """
+ assert dtype == torch.float or dtype == torch.double
+ cdtype = torch.cfloat if dtype == torch.float else torch.cdouble
+
+ A, B = transition(measure, N)
+ A = torch.as_tensor(A, dtype=dtype) # (N, N)
+ B = torch.as_tensor(B, dtype=dtype)[:, 0] # (N,)
+
+ P = rank_correction(measure, N, rank=rank, dtype=dtype) # (r N)
+ AP = A + torch.sum(P.unsqueeze(-2) * P.unsqueeze(-1), dim=-3)
+
+ # We require AP to be nearly skew-symmetric
+ _A = AP + AP.transpose(-1, -2)
+ if (
+ err := torch.sum((_A - _A[0, 0] * torch.eye(N)) ** 2) / N
+ ) > 1e-5: # if not torch.allclose(_A - _A[0,0]*torch.eye(N), torch.zeros(N, N), atol=1e-5):
+ print("WARNING: HiPPO matrix not skew symmetric", err)
+
+ # Take advantage of identity + skew-symmetric form to calculate real and imaginary parts separately
+ # Imaginary part can use eigh instead of eig
+ w_re = torch.mean(torch.diagonal(AP), -1, keepdim=True)
+
+ # Diagonalize in double precision
+ if diagonalize_precision:
+ AP = AP.to(torch.double)
+ w_im, V = torch.linalg.eigh(AP * -1j) # (..., N) (..., N, N)
+ if diagonalize_precision:
+ w_im, V = w_im.to(cdtype), V.to(cdtype)
+ w = w_re + 1j * w_im
+ # Check: V w V^{-1} = A
+ # print("check", V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2))
+
+ # Only keep half of each conjugate pair
+ _, idx = torch.sort(w.imag)
+ w_sorted = w[idx]
+ V_sorted = V[:, idx]
+
+ # There is an edge case when eigenvalues can be 0, which requires some machinery to handle
+ # We use a huge hack here: Assume only one pair is 0, and that it is the first row/column of A (only happens in Fourier case)
+ V = V_sorted[:, : N // 2]
+ w = w_sorted[: N // 2]
+ assert (
+ w[-2].abs() > 1e-4
+ ), "Only 1 zero eigenvalue allowed in diagonal part of A"
+ if w[-1].abs() < 1e-4:
+ V[:, -1] = 0.0
+ V[0, -1] = 2**-0.5
+ V[1, -1] = 2**-0.5 * 1j
+
+ _AP = V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2)
+ if (err := torch.sum((2 * _AP.real - AP) ** 2) / N) > 1e-5:
+ print(
+ "Warning: Diagonalization of A matrix not numerically precise - error",
+ err,
+ )
+ # print("check", V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2))
+
+ V_inv = V.conj().transpose(-1, -2)
+
+ B = contract("ij, j -> i", V_inv, B.to(V)) # V^* B
+ P = contract("ij, ...j -> ...i", V_inv, P.to(V)) # V^* P
+
+ return w, P, B, V
+
+
+def dplr(
+ scaling,
+ N,
+ rank=1,
+ H=1,
+ dtype=torch.float,
+ real_scale=1.0,
+ imag_scale=1.0,
+ random_real=False,
+ random_imag=False,
+ normalize=False,
+ diagonal=True,
+ random_B=False,
+):
+ assert dtype == torch.float or dtype == torch.double
+ dtype = torch.cfloat if dtype == torch.float else torch.cdouble
+
+ pi = torch.tensor(math.pi)
+ if random_real:
+ real_part = torch.rand(H, N // 2)
+ else:
+ real_part = 0.5 * torch.ones(H, N // 2)
+ if random_imag:
+ imag_part = N // 2 * torch.rand(H, N // 2)
+ else:
+ imag_part = repeat(torch.arange(N // 2), "n -> h n", h=H)
+
+ real_part = real_scale * real_part
+ if scaling == "random":
+ imag_part = torch.randn(H, N // 2)
+ elif scaling == "real":
+ imag_part = 0 * imag_part
+ real_part = 1 + repeat(torch.arange(N // 2), "n -> h n", h=H)
+ elif scaling in ["linear", "lin"]:
+ imag_part = pi * imag_part
+ elif scaling in [
+ "inverse",
+ "inv",
+ ]: # Based on asymptotics of the default HiPPO matrix
+ imag_part = 1 / pi * N * (N / (1 + 2 * imag_part) - 1)
+ elif scaling in ["inverse2", "inv2"]:
+ imag_part = 1 / pi * N * (N / (1 + imag_part) - 1)
+ elif scaling in ["quadratic", "quad"]:
+ imag_part = 1 / pi * (1 + 2 * imag_part) ** 2
+ elif scaling in ["legs", "hippo"]:
+ w, _, _, _ = nplr("legsd", N)
+ imag_part = w.imag
+
+ else:
+ raise NotImplementedError
+ imag_part = imag_scale * imag_part
+ w = -real_part + 1j * imag_part
+
+ # Initialize B
+ if random_B:
+ B = torch.randn(H, N // 2, dtype=dtype)
+ else:
+ B = torch.ones(H, N // 2, dtype=dtype)
+
+ if normalize:
+ norm = (
+ -B / w
+ ) # (H, N) # Result if you integrate the kernel with constant 1 function
+ zeta = 2 * torch.sum(
+ torch.abs(norm) ** 2, dim=-1, keepdim=True
+ ) # Variance with a random C vector
+ B = B / zeta**0.5
+
+ P = torch.randn(rank, H, N // 2, dtype=dtype)
+ if diagonal:
+ P = P * 0.0
+ V = torch.eye(N, dtype=dtype)[:: N // 2] # Only used in testing
+ V = repeat(V, "n m -> h n m", h=H)
+
+ return w, P, B, V
+
+
+def ssm(measure, N, R, H, **ssm_args):
+ """Dispatcher to create single SSM initialization
+
+ N: state size
+ R: rank (for DPLR parameterization)
+ H: number of independent SSM copies
+ """
+
+ if measure == "dplr":
+ w, P, B, V = dplr(N=N, rank=R, H=H, **ssm_args)
+ elif measure.startswith("diag"):
+ args = measure.split("-")
+ assert args[0] == "diag" and len(args) > 1
+ scaling = args[1]
+ w, P, B, V = dplr(
+ scaling=scaling, N=N, rank=R, H=H, diagonal=True, **ssm_args
+ )
+ else:
+ w, P, B, V = nplr(measure, N, R, **ssm_args)
+ w = repeat(w, "n -> s n", s=H)
+ P = repeat(P, "r n -> r s n", s=H)
+ B = repeat(B, "n -> s n", s=H)
+ V = repeat(V, "n m -> s n m", s=H)
+ return w, P, B, V
+
+
+combinations = {
+ "hippo": ["legs", "fourier"],
+ "diag": ["diag-inv", "diag-lin"],
+ "all": ["legs", "fourier", "diag-inv", "diag-lin"],
+}
+
+
+def combination(measures, N, R, S, **ssm_args):
+ if isinstance(measures, str):
+ measures = (
+ combinations[measures] if measures in combinations else [measures]
+ )
+
+ assert (
+ S % len(measures) == 0
+ ), f"{S} independent trainable SSM copies must be multiple of {len(measures)} different measures"
+ w, P, B, V = zip(
+ *[
+ ssm(measure, N, R, S // len(measures), **ssm_args)
+ for measure in measures
+ ]
+ )
+ w = torch.cat(w, dim=0) # (S N)
+ P = torch.cat(P, dim=1) # (R S N)
+ B = torch.cat(B, dim=0) # (S N)
+ V = torch.cat(V, dim=0) # (S N N)
+ return w, P, B, V
+
+
+class OptimModule(nn.Module):
+ """Interface for Module that allows registering buffers/parameters with configurable optimizer hyperparameters"""
+
+ def register(self, name, tensor, lr=None):
+ """Register a tensor with a configurable learning rate and 0 weight decay"""
+
+ if lr == 0.0:
+ self.register_buffer(name, tensor)
+ else:
+ self.register_parameter(name, nn.Parameter(tensor))
+
+ optim = {"weight_decay": 0.0}
+ if lr is not None:
+ optim["lr"] = lr
+ setattr(getattr(self, name), "_optim", optim)
+
+
+class SSKernelNPLR(OptimModule):
+ """Stores a representation of and computes the SSKernel function K_L(A^dt, B^dt, C) corresponding to a discretized state space, where A is Normal + Low Rank (NPLR)"""
+
+ @torch.no_grad()
+ def _setup_C(self, L):
+ """Construct C~ from C
+
+ Two modes are supported: go directly to length L if self.L is 1, or length is doubled
+ """
+
+ if self.L.item() == 0:
+ if self.verbose:
+ log.info(f"S4: Initializing kernel to length {L}")
+ double_length = False
+ elif L > self.L.item(): # 2*int(self.L) == L:
+ if self.verbose:
+ log.info(
+ f"S4: Doubling length from L = {self.L.item()} to {2*self.L.item()}"
+ )
+ double_length = True
+ L = self.L.item() # Convenience for the math below
+ else:
+ return
+
+ C = _r2c(self.C)
+ dA, _ = self._setup_state()
+ dA_L = power(L, dA)
+ # Multiply C by I - dA_L
+ C_ = _conj(C)
+ prod = contract("h m n, c h n -> c h m", dA_L.transpose(-1, -2), C_)
+ if double_length:
+ prod = -prod # Multiply by I + dA_L instead
+ C_ = C_ - prod
+ C_ = C_[..., : self.N] # Take conjugate pairs again
+ self.C.copy_(_c2r(C_))
+
+ self.L = (
+ 2 * self.L if double_length else self.L + L
+ ) # Preserve type/device
+
+ def _omega(self, L, dtype, device, cache=True):
+ """Calculate (and cache) FFT nodes and their "unprocessed" version with the bilinear transform
+ This should be called everytime the internal length self.L changes"""
+
+ # Use cached if available
+ if (
+ cache
+ and hasattr(self, "omega")
+ and self.omega.size(-1) == L // 2 + 1
+ ):
+ return self.omega, self.z
+
+ omega = torch.tensor(
+ np.exp(-2j * np.pi / (L)), dtype=dtype, device=device
+ ) # \omega_{2L}
+ omega = omega ** torch.arange(0, L // 2 + 1, device=device)
+ z = 2 * (1 - omega) / (1 + omega)
+
+ # Cache if necessary
+ if cache:
+ self.omega = omega
+ self.z = z
+ return omega, z
+
+ def __init__(
+ self,
+ w,
+ P,
+ B,
+ C,
+ log_dt,
+ L=None, # starting/maximum length of kernel
+ lr=None,
+ verbose=False,
+ keops=False,
+ real_type="exp", # ['none' | 'exp' | 'relu' | sigmoid']
+ real_tolerance=1e-3,
+ bandlimit=None,
+ ):
+ """
+ L: Maximum length; this module computes an SSM kernel of length L
+ A is represented by diag(w) - PP^*
+ w: (S, N) diagonal part
+ P: (R, S, N) low-rank part
+
+ B: (S, N)
+ C: (C, H, N)
+ dt: (H) timescale per feature
+ lr: [dict | float | None] hook to set lr of special parameters (A, B, dt)
+
+ Dimensions:
+ N (or d_state): state size
+ H (or d_model): total SSM copies
+ S (or n_ssm): number of trainable copies of (A, B, dt); must divide H
+ R (or rank): rank of low-rank part
+ C (or channels): system is 1-dim to C-dim
+
+ The forward pass of this Module returns a tensor of shape (C, H, L)
+
+ Note: tensor shape N here denotes half the true state size, because of conjugate symmetry
+ """
+
+ super().__init__()
+ self.verbose = verbose
+ self.keops = keops
+ self.bandlimit = bandlimit
+ self.real_type = real_type
+ self.real_tolerance = real_tolerance
+
+ # Rank of low-rank correction
+ self.rank = P.shape[-3]
+ assert w.size(-1) == P.size(-1) == B.size(-1) == C.size(-1)
+ self.H = log_dt.size(-1)
+ self.N = w.size(-1)
+
+ # Check different SSM inits
+ assert w.size(-2) == P.size(-2) == B.size(-2) # n_ssm
+ assert self.H % w.size(0) == 0
+ self.n_ssm = w.size(0)
+ self.repeat = self.H // w.size(
+ 0
+ ) # Each trainable SSM needs to be duplicated this many times
+
+ # Broadcast everything to correct shapes
+ C = C.expand(
+ torch.broadcast_shapes(C.shape, (1, self.H, self.N))
+ ) # (C, H, N)
+ B = B.unsqueeze(0) # (1, 1, N)
+
+ # Register parameters
+ self.C = nn.Parameter(_c2r(_resolve_conj(C)))
+ if lr is None or isinstance(lr, float):
+ lr_dict = {}
+ else:
+ lr_dict, lr = lr, None
+ self.register("log_dt", log_dt, lr_dict.get("dt", lr))
+ self.register("B", _c2r(B), lr_dict.get("B", lr))
+ self.register("P", _c2r(P), lr_dict.get("A", lr))
+ self.register("inv_w_real", self._w_init(w.real), lr_dict.get("A", lr))
+ self.register("w_imag", w.imag, lr_dict.get("A", lr))
+
+ self.l_max = L
+ self.register_buffer("L", torch.tensor(0)) # Internal length
+
+ def _w_init(self, w_real):
+ w_real = torch.clamp(w_real, max=-self.real_tolerance)
+ if self.real_type == "none":
+ return -w_real
+ elif self.real_type == "exp":
+ return torch.log(
+ -w_real
+ ) # Some of the HiPPO methods have real part 0
+ elif self.real_type == "relu":
+ return -w_real
+ elif self.real_type == "sigmoid":
+ return torch.logit(-w_real)
+ elif self.real_type == "softplus":
+ return torch.log(torch.exp(-w_real) - 1)
+ else:
+ raise NotImplementedError
+
+ def _w(self):
+ # Get the internal w (diagonal) parameter
+ if self.real_type == "none":
+ w_real = -self.inv_w_real
+ elif self.real_type == "exp":
+ w_real = -torch.exp(self.inv_w_real)
+ elif self.real_type == "relu":
+ w_real = -F.relu(self.inv_w_real)
+ elif self.real_type == "sigmoid":
+ w_real = -F.sigmoid(self.inv_w_real)
+ elif self.real_type == "softplus":
+ w_real = -F.softplus(self.inv_w_real)
+ else:
+ raise NotImplementedError
+ w = w_real + 1j * self.w_imag
+ return w
+
+ def forward(self, state=None, rate=1.0, L=None):
+ """
+ state: (B, H, N) initial state
+ rate: sampling rate factor
+ L: target length
+
+ returns:
+ (C, H, L) convolution kernel (generally C=1)
+ (B, H, L) output from initial state
+ """
+
+ # Initialize C~ if necessary (done in forward pass so it's on the correct device)
+ if self.L.item() == 0 and self.l_max is not None and self.l_max > 0:
+ self._setup_C(self.l_max)
+
+ # Handle sampling rate logic
+ # The idea is that this kernel's length (in continuous units) is self.L, while we are asked to provide a kernel of length L at (relative) frequency rate
+ if L is None:
+ L = round(self.L.item() / rate)
+
+ # Increase the internal length if needed
+ continuous_L = round(rate * L)
+ while continuous_L > self.L.item():
+ self._setup_C(continuous_L)
+ discrete_L = round(self.L.item() / rate)
+
+ dt = torch.exp(self.log_dt) * rate
+ B = _r2c(self.B)
+ C = _r2c(self.C)
+ P = _r2c(self.P)
+ Q = P.conj()
+ w = self._w() # (n_ssm, N)
+
+ # Address bandlimiting
+ if self.bandlimit is not None:
+ freqs = w.imag.abs() / (2 * math.pi) # (H, N)
+ freqs = dt[:, None] / rate * freqs # (H, N)
+ mask = torch.where(freqs < self.bandlimit * 0.5, 1, 0)
+ C = C * mask
+
+ # Get FFT nodes of right length
+ omega, z = self._omega(
+ discrete_L, dtype=w.dtype, device=w.device, cache=(rate == 1.0)
+ )
+
+ # Broadcast parameters to same hidden features H
+ B = repeat(B, "1 t n -> 1 (v t) n", v=self.repeat)
+ P = repeat(P, "r t n -> r (v t) n", v=self.repeat)
+ Q = repeat(Q, "r t n -> r (v t) n", v=self.repeat)
+ w = repeat(w, "t n -> (v t) n", v=self.repeat)
+
+ # Augment B
+ if state is not None:
+ # Have to "unbilinear" the state to put it into the same "type" as B
+ # Compute 1/dt * (I + dt/2 A) @ state
+
+ # Can do this without expanding (maybe minor speedup using conj symmetry in theory), but it's easier to read this way
+ s = _conj(state) if state.size(-1) == self.N else state # (B H N)
+ sA = s * _conj(w) - contract( # (B H N)
+ "bhm, rhm, rhn -> bhn", s, _conj(Q), _conj(P)
+ )
+ s = s / dt.unsqueeze(-1) + sA / 2
+ s = s[..., : self.N]
+
+ B = torch.cat([s, B], dim=-3) # (B+1, H, N)
+
+ # Incorporate dt into A
+ w = w * dt.unsqueeze(-1) # (H N)
+
+ # Stack B and p, C and q for convenient batching
+ B = torch.cat([B, P], dim=-3) # (B+1+R, H, N)
+ C = torch.cat([C, Q], dim=-3) # (C+R, H, N)
+
+ # Incorporate B and C batch dimensions
+ v = B.unsqueeze(-3) * C.unsqueeze(-4) # (B+1+R, C+R, H, N)
+
+ # Calculate resolvent at omega
+ if has_cauchy_extension and z.dtype == torch.cfloat and not self.keops:
+ r = cauchy_mult(v, z, w, symmetric=True)
+ elif has_pykeops:
+ r = cauchy_conj(v, z, w)
+ else:
+ r = cauchy_naive(v, z, w)
+ r = r * dt[None, None, :, None] # (B+1+R, C+R, H, L)
+
+ # Low-rank Woodbury correction
+ if self.rank == 1:
+ k_f = r[:-1, :-1, :, :] - r[:-1, -1:, :, :] * r[-1:, :-1, :, :] / (
+ 1 + r[-1:, -1:, :, :]
+ )
+ elif self.rank == 2:
+ r00 = r[: -self.rank, : -self.rank, :, :]
+ r01 = r[: -self.rank, -self.rank :, :, :]
+ r10 = r[-self.rank :, : -self.rank, :, :]
+ r11 = r[-self.rank :, -self.rank :, :, :]
+ det = (1 + r11[:1, :1, :, :]) * (1 + r11[1:, 1:, :, :]) - r11[
+ :1, 1:, :, :
+ ] * r11[1:, :1, :, :]
+ s = (
+ r01[:, :1, :, :] * (1 + r11[1:, 1:, :, :]) * r10[:1, :, :, :]
+ + r01[:, 1:, :, :] * (1 + r11[:1, :1, :, :]) * r10[1:, :, :, :]
+ - r01[:, :1, :, :] * (r11[:1, 1:, :, :]) * r10[1:, :, :, :]
+ - r01[:, 1:, :, :] * (r11[1:, :1, :, :]) * r10[:1, :, :, :]
+ )
+ s = s / det
+ k_f = r00 - s
+ else:
+ r00 = r[: -self.rank, : -self.rank, :, :]
+ r01 = r[: -self.rank, -self.rank :, :, :]
+ r10 = r[-self.rank :, : -self.rank, :, :]
+ r11 = r[-self.rank :, -self.rank :, :, :]
+ r11 = rearrange(r11, "a b h n -> h n a b")
+ r11 = torch.linalg.inv(torch.eye(self.rank, device=r.device) + r11)
+ r11 = rearrange(r11, "h n a b -> a b h n")
+ k_f = r00 - torch.einsum(
+ "i j h n, j k h n, k l h n -> i l h n", r01, r11, r10
+ )
+
+ # Final correction for the bilinear transform
+ k_f = k_f * 2 / (1 + omega)
+
+ # Move from frequency to coefficients
+ k = torch.fft.irfft(k_f, n=discrete_L) # (B+1, C, H, L)
+
+ # # Truncate to target length
+ k = k[..., :L]
+
+ if state is not None:
+ k_state = k[:-1, :, :, :] # (B, C, H, L)
+ else:
+ k_state = None
+ k_B = k[-1, :, :, :] # (C H L)
+
+ return k_B, k_state
+
+ @torch.no_grad()
+ def _setup_linear(self):
+ """Create parameters that allow fast linear stepping of state"""
+ w = self._w()
+ B = _r2c(self.B) # (H N)
+ P = _r2c(self.P)
+ Q = P.conj()
+
+ # Repeat w shape properly
+ B = repeat(B, "1 t n -> 1 (v t) n", v=self.repeat)
+ P = repeat(P, "r t n -> r (v t) n", v=self.repeat)
+ Q = repeat(Q, "r t n -> r (v t) n", v=self.repeat)
+ w = repeat(w, "t n -> (v t) n", v=self.repeat)
+
+ # Prepare Linear stepping
+ dt = torch.exp(self.log_dt)
+ D = (2.0 / dt.unsqueeze(-1) - w).reciprocal() # (H, N)
+ R = (
+ torch.eye(self.rank, dtype=w.dtype, device=w.device)
+ + 2 * contract("r h n, h n, s h n -> h r s", Q, D, P).real
+ ) # (H R R)
+ Q_D = rearrange(Q * D, "r h n -> h r n")
+ try:
+ R = torch.linalg.solve(R, Q_D) # (H R N)
+ except Exception:
+ R = torch.tensor(
+ np.linalg.solve(
+ R.to(Q_D).contiguous().detach().cpu(),
+ Q_D.contiguous().detach().cpu(),
+ )
+ ).to(Q_D)
+ R = rearrange(R, "h r n -> r h n")
+
+ self.step_params = {
+ "D": D, # (H N)
+ "R": R, # (R H N)
+ "P": P, # (R H N)
+ "Q": Q, # (R H N)
+ "B": B, # (1 H N)
+ "E": 2.0 / dt.unsqueeze(-1) + w, # (H N)
+ }
+
+ def _step_state_linear(self, u=None, state=None):
+ """
+ Version of the step function that has time O(N) instead of O(N^2) per step, which takes advantage of the DPLR form and bilinear discretization.
+
+ Unfortunately, as currently implemented it's about 2x slower because it calls several sequential operations. Perhaps a fused CUDA kernel implementation would be much faster
+
+ u: (H) input
+ state: (H, N/2) state with conjugate pairs
+ Optionally, the state can have last dimension N
+ Returns: same shape as state
+ """
+ C = _r2c(self.C) # View used for dtype/device
+
+ if u is None: # Special case used to find dA
+ u = torch.zeros(self.H, dtype=C.dtype, device=C.device)
+ if state is None: # Special case used to find dB
+ state = torch.zeros(self.H, self.N, dtype=C.dtype, device=C.device)
+
+ step_params = self.step_params.copy()
+ if (
+ state.size(-1) == self.N
+ ): # Only store half of the conjugate pairs; should be true by default
+ # There should be a slightly faster way using conjugate symmetry
+ def contract_fn(p, x, y):
+ return contract(
+ "r h n, r h m, ... h m -> ... h n",
+ _conj(p),
+ _conj(x),
+ _conj(y),
+ )[
+ ..., : self.N
+ ] # inner outer product
+
+ else:
+ assert state.size(-1) == 2 * self.N
+ step_params = {k: _conj(v) for k, v in step_params.items()}
+
+ # TODO worth setting up a contract_expression in default_state if we want to use this at inference time for stepping
+ def contract_fn(p, x, y):
+ return contract(
+ "r h n, r h m, ... h m -> ... h n", p, x, y
+ ) # inner outer product
+
+ D = step_params["D"] # (H N)
+ E = step_params["E"] # (H N)
+ R = step_params["R"] # (R H N)
+ P = step_params["P"] # (R H N)
+ Q = step_params["Q"] # (R H N)
+ B = step_params["B"] # (1 H N)
+
+ new_state = E * state - contract_fn(P, Q, state) # (B H N)
+ new_state = new_state + 2.0 * B * u.unsqueeze(-1) # (B H N)
+ new_state = D * (new_state - contract_fn(P, R, new_state))
+
+ return new_state
+
+ def _setup_state(self):
+ """Construct dA and dB for discretized state equation"""
+
+ # Construct dA and dB by using the stepping
+ self._setup_linear()
+ C = _r2c(
+ self.C
+ ) # Just returns a view that we use for finding dtype/device
+
+ state = torch.eye(
+ 2 * self.N, dtype=C.dtype, device=C.device
+ ).unsqueeze(
+ -2
+ ) # (N 1 N)
+ dA = self._step_state_linear(state=state)
+ dA = rearrange(dA, "n h m -> h m n")
+
+ u = C.new_ones(self.H)
+ dB = self._step_state_linear(u=u)
+ dB = _conj(dB)
+ dB = rearrange(dB, "1 h n -> h n") # (H N)
+ return dA, dB
+
+ def _step_state(self, u, state):
+ """Must be called after self.default_state() is used to construct an initial state!"""
+ next_state = self.state_contraction(
+ self.dA, state
+ ) + self.input_contraction(self.dB, u)
+ return next_state
+
+ def _setup_step(self, mode="dense"):
+ """Set up dA, dB, dC discretized parameters for stepping"""
+ self.dA, self.dB = self._setup_state()
+
+ # Calculate original C
+ C = _conj(_r2c(self.C)) # (H C N)
+ if self.L.item() == 0:
+ dC = C
+ else:
+ # self.C represents C_tilde
+ dA_L = power(self.L.item(), self.dA)
+ I = torch.eye(self.dA.size(-1)).to(dA_L)
+
+ dC = torch.linalg.solve(
+ I - dA_L.transpose(-1, -2),
+ C.unsqueeze(-1),
+ ).squeeze(-1)
+ self.dC = dC
+
+ # Do special preprocessing for different step modes
+
+ self._step_mode = mode
+ if mode == "linear":
+ # Linear case: special step function for the state, we need to handle output
+ # use conjugate symmetry by default, which affects the output projection
+ self.dC = 2 * self.dC[:, :, : self.N]
+ elif mode == "diagonal":
+ # Eigendecomposition of the A matrix
+ L, V = torch.linalg.eig(self.dA)
+ V_inv = torch.linalg.inv(V)
+ # Check that the eigendedecomposition is correct
+ if self.verbose:
+ print(
+ "Diagonalization error:",
+ torch.dist(V @ torch.diag_embed(L) @ V_inv, self.dA),
+ )
+
+ # Change the parameterization to diagonalize
+ self.dA = L
+ self.dB = contract("h n m, h m -> h n", V_inv, self.dB)
+ self.dC = contract("h n m, c h n -> c h m", V, self.dC)
+
+ elif mode == "dense":
+ pass
+ else:
+ raise NotImplementedError(
+ "NPLR Kernel step mode must be {'dense' | 'linear' | 'diagonal'}"
+ )
+
+ def default_state(self, *batch_shape):
+ C = _r2c(self.C)
+ N = C.size(-1)
+ H = C.size(-2)
+
+ # Cache the tensor contractions we will later do, for efficiency
+ # These are put in this function because they depend on the batch size
+ step_mode = getattr(
+ self, "_step_mode", "dense"
+ ) # Used in default_state, which is called without _setup_step() in forward_state()
+ if step_mode != "linear":
+ N *= 2
+
+ if step_mode == "diagonal":
+ self.state_contraction = contract_expression(
+ "h n, ... h n -> ... h n",
+ (H, N),
+ batch_shape + (H, N),
+ )
+ else:
+ # Dense (quadratic) case: expand all terms
+ self.state_contraction = contract_expression(
+ "h m n, ... h n -> ... h m",
+ (H, N, N),
+ batch_shape + (H, N),
+ )
+
+ self.input_contraction = contract_expression(
+ "h n, ... h -> ... h n",
+ (H, N), # self.dB.shape
+ batch_shape + (H,),
+ )
+
+ self.output_contraction = contract_expression(
+ "c h n, ... h n -> ... c h",
+ (C.shape[0], H, N), # self.dC.shape
+ batch_shape + (H, N),
+ )
+
+ state = torch.zeros(*batch_shape, H, N, dtype=C.dtype, device=C.device)
+ return state
+
+ def step(self, u, state):
+ """Must have called self._setup_step() and created state with self.default_state() before calling this"""
+
+ if self._step_mode == "linear":
+ new_state = self._step_state_linear(u, state)
+ else:
+ new_state = self._step_state(u, state)
+ y = self.output_contraction(self.dC, new_state)
+ return y.real, new_state
+
+
+class SSKernelDiag(OptimModule):
+ """Version using (complex) diagonal state matrix (S4D)"""
+
+ def __init__(
+ self,
+ A,
+ B,
+ C,
+ log_dt,
+ L=None,
+ disc="bilinear",
+ real_type="exp",
+ lr=None,
+ bandlimit=None,
+ ):
+ super().__init__()
+ self.L = L
+ self.disc = disc
+ self.bandlimit = bandlimit
+ self.real_type = real_type
+
+ # Rank of low-rank correction
+ assert A.size(-1) == C.size(-1)
+ self.H = log_dt.size(-1)
+ self.N = A.size(-1)
+ assert A.size(-2) == B.size(-2) # Number of independent SSMs trained
+ assert self.H % A.size(-2) == 0
+ self.n_ssm = A.size(-2)
+ self.repeat = self.H // A.size(0)
+
+ self.channels = C.shape[0]
+ self.C = nn.Parameter(_c2r(_resolve_conj(C)))
+
+ # Register parameters
+ if lr is None or isinstance(lr, float):
+ lr_dict = {}
+ else:
+ lr_dict, lr = lr, None
+
+ self.register("log_dt", log_dt, lr_dict.get("dt", lr))
+ self.register("B", _c2r(B), lr_dict.get("B", lr))
+ self.register("inv_A_real", self._A_init(A.real), lr_dict.get("A", lr))
+ self.register("A_imag", A.imag, lr_dict.get("A", lr))
+
+ def _A_init(self, A_real):
+ A_real = torch.clamp(A_real, max=-1e-4)
+ if self.real_type == "none":
+ return -A_real
+ elif self.real_type == "exp":
+ return torch.log(
+ -A_real
+ ) # Some of the HiPPO methods have real part 0
+ elif self.real_type == "relu":
+ return -A_real
+ elif self.real_type == "sigmoid":
+ return torch.logit(-A_real)
+ elif self.real_type == "softplus":
+ return torch.log(torch.exp(-A_real) - 1)
+ else:
+ raise NotImplementedError
+
+ def _A(self):
+ # Get the internal A (diagonal) parameter
+ if self.real_type == "none":
+ A_real = -self.inv_A_real
+ elif self.real_type == "exp":
+ A_real = -torch.exp(self.inv_A_real)
+ elif self.real_type == "relu":
+ # JAX version seems to NaN if you alloA 0's, although this code Aas fine Aithout it
+ A_real = -F.relu(self.inv_A_real) - 1e-4
+ elif self.real_type == "sigmoid":
+ A_real = -F.sigmoid(self.inv_A_real)
+ elif self.real_type == "softplus":
+ A_real = -F.softplus(self.inv_A_real)
+ else:
+ raise NotImplementedError
+ A = A_real + 1j * self.A_imag
+ return A
+
+ def forward(self, L, state=None, rate=1.0, u=None):
+ """
+ state: (B, H, N) initial state
+ rate: sampling rate factor
+ L: target length
+
+ returns:
+ (C, H, L) convolution kernel (generally C=1)
+ (B, H, L) output from initial state
+ """
+
+ dt = torch.exp(self.log_dt) * rate # (H)
+ C = _r2c(self.C) # (C H N)
+ A = self._A() # (H N)
+
+ B = _r2c(self.B)
+ B = repeat(B, "t n -> 1 (v t) n", v=self.repeat)
+
+ if self.bandlimit is not None:
+ freqs = dt[:, None] / rate * A.imag.abs() / (2 * math.pi) # (H, N)
+ mask = torch.where(freqs < self.bandlimit * 0.5, 1, 0)
+ C = C * mask
+
+ # Incorporate dt into A
+ A = repeat(A, "t n -> (v t) n", v=self.repeat)
+ dtA = A * dt.unsqueeze(-1) # (H N)
+
+ # Augment B with state
+ if state is not None:
+ s = state / dt.unsqueeze(-1)
+ if self.disc == "bilinear":
+ s = s * (1.0 + dtA / 2)
+ elif self.disc == "zoh":
+ s = s * dtA * dtA.exp() / (dtA.exp() - 1.0)
+ B = torch.cat([s, B], dim=-3) # (1+B H N)
+
+ C = (B[:, None, :, :] * C).view(-1, self.H, self.N)
+ if self.disc == "zoh":
+ # Power up
+ C = C * (torch.exp(dtA) - 1.0) / A
+ K = log_vandermonde(C, dtA, L) # (H L)
+ elif self.disc == "bilinear":
+ C = (
+ C * (1.0 - dtA / 2).reciprocal() * dt.unsqueeze(-1)
+ ) # or * dtA / A
+ dA = (1.0 + dtA / 2) / (1.0 - dtA / 2)
+ K = log_vandermonde(C, dA.log(), L)
+ elif self.disc == "dss":
+ # Implementation from DSS meant for case when real eigenvalues can be positive
+ P = dtA.unsqueeze(-1) * torch.arange(L, device=C.device) # [H N L]
+ A_gt_0 = A.real > 0 # [N]
+ if A_gt_0.any():
+ with torch.no_grad():
+ P_max = dtA * (A_gt_0 * (L - 1)) # [H N]
+ P = P - P_max.unsqueeze(-1) # [H N L]
+ S = P.exp() # [H N L]
+
+ dtA_neg = dtA * (1 - 2 * A_gt_0) # [H N]
+ num = dtA_neg.exp() - 1 # [H N]
+ den = (dtA_neg * L).exp() - 1 # [H N]
+
+ # Inline reciprocal function for DSS logic
+ x = den * A
+ x_conj = _resolve_conj(x)
+ r = x_conj / (x * x_conj + 1e-7)
+
+ C = C * num * r # [C H N]
+ K = contract("chn,hnl->chl", C, S).float()
+ else:
+ assert False, f"{self.disc} not supported"
+
+ K = K.view(-1, self.channels, self.H, L) # (1+B C H L)
+ if state is not None:
+ K_state = K[:-1, :, :, :] # (B C H L)
+ else:
+ K_state = None
+ K = K[-1, :, :, :] # (C H L)
+ return K, K_state
+
+ def _setup_step(self):
+ # These methods are organized like this to be compatible with the NPLR kernel interface
+ dt = torch.exp(self.log_dt) # (H)
+ B = _r2c(self.B) # (H N)
+ C = _r2c(self.C) # (C H N)
+ self.dC = C
+ A = self._A() # (H N)
+
+ A = repeat(A, "t n -> (v t) n", v=self.repeat)
+ B = repeat(B, "t n -> (v t) n", v=self.repeat)
+
+ # Incorporate dt into A
+ dtA = A * dt.unsqueeze(-1) # (H N)
+ if self.disc == "zoh":
+ self.dA = torch.exp(dtA) # (H N)
+ self.dB = B * (torch.exp(dtA) - 1.0) / A # (C H N)
+ elif self.disc == "bilinear":
+ self.dA = (1.0 + dtA / 2) / (1.0 - dtA / 2)
+ self.dB = (
+ B * (1.0 - dtA / 2).reciprocal() * dt.unsqueeze(-1)
+ ) # or * dtA / A
+
+ def default_state(self, *batch_shape):
+ C = _r2c(self.C)
+ state = torch.zeros(
+ *batch_shape, self.H, self.N, dtype=C.dtype, device=C.device
+ )
+ return state
+
+ def step(self, u, state):
+ next_state = contract(
+ "h n, b h n -> b h n", self.dA, state
+ ) + contract("h n, b h -> b h n", self.dB, u)
+ y = contract("c h n, b h n -> b c h", self.dC, next_state)
+ return 2 * y.real, next_state
+
+ def forward_state(self, u, state):
+ self._setup_step()
+ AL = self.dA ** u.size(-1)
+ u = u.flip(-1).to(self.dA).contiguous() # (B H L)
+ v = log_vandermonde_transpose(u, self.dB, self.dA.log(), u.size(-1))
+ next_state = AL * state + v
+ return next_state
+
+
+class SSKernel(nn.Module):
+ """Wrapper around SSKernel parameterizations.
+
+ The SSKernel is expected to support the interface
+ forward()
+ default_state()
+ _setup_step()
+ step()
+ """
+
+ def __init__(
+ self,
+ H,
+ N=64,
+ L=None,
+ measure="legs",
+ rank=1,
+ channels=1,
+ dt_min=0.001,
+ dt_max=0.1,
+ deterministic=False,
+ lr=None,
+ mode="nplr",
+ n_ssm=None,
+ verbose=False,
+ measure_args={},
+ **kernel_args,
+ ):
+ """State Space Kernel which computes the convolution kernel $\\bar{K}$
+
+ H: Number of independent SSM copies; controls the size of the model. Also called d_model in the config.
+ N: State size (dimensionality of parameters A, B, C). Also called d_state in the config. Generally shouldn't need to be adjusted and doens't affect speed much.
+ L: Maximum length of convolution kernel, if known. Should work in the majority of cases even if not known.
+ measure: Options for initialization of (A, B). For NPLR mode, recommendations are "legs", "fout", "hippo" (combination of both). For Diag mode, recommendations are "diag-inv", "diag-lin", "diag-legs", and "diag" (combination of diag-inv and diag-lin)
+ rank: Rank of low-rank correction for NPLR mode. Needs to be increased for measure "legt"
+ channels: C channels turns the SSM from a 1-dim to C-dim map; can think of it having C separate "heads" per SSM. This was partly a feature to make it easier to implement bidirectionality; it is recommended to set channels=1 and adjust H to control parameters instead
+ dt_min, dt_max: min and max values for the step size dt (\Delta)
+ mode: Which kernel algorithm to use. 'nplr' is the full S4 model; 'diag' is the simpler S4D; 'slow' is a dense version for testing
+ n_ssm: Number of independent trainable (A, B) SSMs, e.g. n_ssm=1 means all A/B parameters are tied across the H different instantiations of C. n_ssm=None means all H SSMs are completely independent. Generally, changing this option can save parameters but doesn't affect performance or speed much. This parameter must divide H
+ lr: Passing in a number (e.g. 0.001) sets attributes of SSM parameers (A, B, dt). A custom optimizer hook is needed to configure the optimizer to set the learning rates appropriately for these parameters.
+ """
+ super().__init__()
+ self.N = N
+ self.H = H
+ dtype, cdtype = torch.float, torch.cfloat
+ self.channels = channels
+ self.n_ssm = n_ssm if n_ssm is not None else H
+ self.mode = mode
+ self.verbose = verbose
+ self.kernel_args = kernel_args
+
+ # Generate dt
+ if deterministic:
+ log_dt = torch.exp(
+ torch.linspace(math.log(dt_min), math.log(dt_max), H)
+ )
+ else:
+ log_dt = torch.rand(self.H, dtype=dtype) * (
+ math.log(dt_max) - math.log(dt_min)
+ ) + math.log(dt_min)
+
+ # Compute the preprocessed representation
+ w, P, B, V = combination(
+ measure, self.N, rank, self.n_ssm, **measure_args
+ )
+
+ # Broadcast C to have H channels
+ if deterministic:
+ C = torch.zeros(channels, self.n_ssm, self.N, dtype=cdtype)
+ C[:, :, :1] = 1.0
+ C = contract(
+ "hmn, chn -> chm", V.conj().transpose(-1, -2), C
+ ) # V^* C
+ C = (
+ repeat(C, "c t n -> c (v t) n", v=self.n_ssm // C.size(-2))
+ .clone()
+ .contiguous()
+ )
+ else:
+ C = torch.randn(channels, self.H, self.N // 2, dtype=cdtype)
+
+ # Broadcast other parameters to have n_ssm copies
+ assert (
+ self.n_ssm % B.size(-2) == 0
+ and self.n_ssm % P.size(-2) == 0
+ and self.n_ssm % w.size(-2) == 0
+ )
+ # Broadcast tensors to n_ssm copies
+ # These will be the parameters, so make sure tensors are materialized and contiguous
+ B = (
+ repeat(B, "t n -> (v t) n", v=self.n_ssm // B.size(-2))
+ .clone()
+ .contiguous()
+ )
+ P = (
+ repeat(P, "r t n -> r (v t) n", v=self.n_ssm // P.size(-2))
+ .clone()
+ .contiguous()
+ )
+ w = (
+ repeat(w, "t n -> (v t) n", v=self.n_ssm // w.size(-2))
+ .clone()
+ .contiguous()
+ )
+
+ if mode == "nplr":
+ self.kernel = SSKernelNPLR(
+ w,
+ P,
+ B,
+ C,
+ log_dt,
+ L=L,
+ lr=lr,
+ verbose=verbose,
+ **kernel_args,
+ )
+ elif mode == "diag":
+ if not measure.startswith("diag"):
+ log.warning(
+ "Diagonal kernel (S4D) activated but initialization is not intended for S4D. Set `measure` to 'diag-lin', 'diag-inv', or 'diag-legs' for the main variants, or 'diag' for a combination of S4D-Lin and S4D-Inv."
+ )
+ C = C * repeat(B, "t n -> (v t) n", v=H // self.n_ssm)
+ self.kernel = SSKernelDiag(
+ w,
+ B,
+ C,
+ log_dt,
+ L=L,
+ lr=lr,
+ **kernel_args,
+ )
+ else:
+ raise NotImplementedError(f"{mode=} is not valid")
+
+ def forward(self, state=None, L=None, rate=1.0):
+ return self.kernel(state=state, L=L, rate=rate)
+
+ @torch.no_grad()
+ def forward_state(self, u, state):
+ """Forward the state through a sequence, i.e. computes the state after passing chunk through SSM
+
+ state: (B, H, N)
+ u: (B, H, L)
+
+ Returns: (B, H, N)
+ """
+
+ if hasattr(self.kernel, "forward_state"):
+ return self.kernel.forward_state(u, state)
+
+ dA, dB = self.kernel._setup_state() # Construct dA, dB matrices
+ # dA, dB = self.kernel.dA, self.kernel.dB # (H N N) (H N)
+
+ conj = state.size(-1) != dA.size(-1)
+ if conj:
+ state = _conj(state)
+
+ v = contract(
+ "h n, b h l -> b h n l", dB, u.flip(-1)
+ ) # dB.unsqueeze(-1) * u.flip(-1).unsqueeze(-2)
+ AL, v = power(u.size(-1), dA, v)
+ next_state = contract("h m n, b h n -> b h m", AL, state)
+ next_state = next_state + v
+
+ if conj:
+ next_state = next_state[..., : next_state.size(-1) // 2]
+ return next_state
+
+ def _setup_step(self, **kwargs):
+ # This method is intended to be private so that setting up an S4 module with
+ # ```
+ # if hasattr(module, 'setup_step'): module.setup_step()
+ # ```
+ # will not trigger this method multiple times
+ self.kernel._setup_step(**kwargs)
+
+ def step(self, u, state, **kwargs):
+ y, state = self.kernel.step(u, state, **kwargs)
+ return y, state
+
+ def default_state(self, *args, **kwargs):
+ return self.kernel.default_state(*args, **kwargs)
+
+
+class S4(nn.Module):
+ def __init__(
+ self,
+ d_model,
+ d_state=64,
+ l_max=None,
+ channels=1,
+ bidirectional=False,
+ # Arguments for position-wise feedforward components
+ activation="gelu",
+ postact="glu",
+ hyper_act=None,
+ dropout=0.0,
+ tie_dropout=False,
+ bottleneck=None,
+ gate=None,
+ transposed=True,
+ verbose=False,
+ # SSM Kernel arguments
+ **kernel_args,
+ ):
+ """
+ d_state: the dimension of the state, also denoted by N
+ l_max: the maximum kernel length, also denoted by L. Set l_max=None to always use a global kernel
+ channels: can be interpreted as a number of "heads"; the SSM is a map from a 1-dim to C-dim sequence. It's not recommended to change this unless desperate for things to tune; instead, increase d_model for larger models
+ bidirectional: if True, convolution kernel will be two-sided
+
+ Position-wise feedforward components:
+ --------------------
+ activation: activation in between SS and FF
+ postact: activation after FF
+ hyper_act: use a "hypernetwork" multiplication (experimental)
+ dropout: standard dropout argument. tie_dropout=True ties the dropout mask across the sequence length, emulating nn.Dropout1d
+
+ Other arguments:
+ --------------------
+ transposed: choose backbone axis ordering of (B, L, H) (if False) or (B, H, L) (if True) [B=batch size, L=sequence length, H=hidden dimension]
+ gate: add gated activation (GSS)
+ bottleneck: reduce SSM dimension (GSS)
+
+ See the class SSKernel for the kernel constructor which accepts kernel_args. Relevant options that are worth considering and tuning include "mode" + "measure", "dt_min", "dt_max", "lr"
+
+ Other options are all experimental and should not need to be configured
+ """
+
+ super().__init__()
+ if verbose:
+ log.info(
+ f"Constructing S4 (H, N, L) = ({d_model}, {d_state}, {l_max})"
+ )
+
+ self.d_model = d_model
+ self.H = d_model
+ self.N = d_state
+ self.L = l_max
+ self.bidirectional = bidirectional
+ self.channels = channels
+ self.transposed = transposed
+
+ self.gate = gate
+ self.bottleneck = bottleneck
+
+ if bottleneck is not None:
+ self.H = self.H // bottleneck
+ self.input_linear = LinearActivation(
+ self.d_model,
+ self.H,
+ transposed=self.transposed,
+ activation=activation,
+ activate=True,
+ )
+
+ if gate is not None:
+ self.input_gate = LinearActivation(
+ self.d_model,
+ self.d_model * gate,
+ transposed=self.transposed,
+ activation=activation,
+ activate=True,
+ )
+ self.output_gate = LinearActivation(
+ self.d_model * gate,
+ self.d_model,
+ transposed=self.transposed,
+ activation=None,
+ activate=False,
+ )
+
+ # optional multiplicative modulation GLU-style
+ # https://arxiv.org/abs/2002.05202
+ self.hyper = hyper_act is not None
+ if self.hyper:
+ channels *= 2
+ self.hyper_activation = Activation(hyper_act)
+
+ self.D = nn.Parameter(torch.randn(channels, self.H))
+
+ if self.bidirectional:
+ channels *= 2
+
+ # SSM Kernel
+ self.kernel = SSKernel(
+ self.H,
+ N=self.N,
+ L=self.L,
+ channels=channels,
+ verbose=verbose,
+ **kernel_args,
+ )
+
+ # Pointwise
+ self.activation = Activation(activation)
+ dropout_fn = DropoutNd if tie_dropout else nn.Dropout
+ self.dropout = dropout_fn(dropout) if dropout > 0.0 else nn.Identity()
+ # position-wise output transform to mix features
+ self.output_linear = LinearActivation(
+ self.H * self.channels,
+ self.d_model * (1 if self.gate is None else self.gate),
+ transposed=self.transposed,
+ activation=postact,
+ activate=True,
+ )
+
+ def forward(self, u, state=None, rate=1.0, lengths=None, **kwargs):
+ """
+ u: (B H L) if self.transposed else (B L H)
+ state: (H N) never needed unless you know what you're doing
+
+ Returns: same shape as u
+ """
+ if not self.transposed:
+ u = u.transpose(-1, -2)
+ L = u.size(-1)
+
+ # Mask out padding tokens
+ if isinstance(lengths, int):
+ if lengths != L:
+ lengths = torch.tensor(
+ lengths, dtype=torch.long, device=u.device
+ )
+ else:
+ lengths = None
+ if lengths is not None:
+ assert (
+ isinstance(lengths, torch.Tensor)
+ and lengths.ndim == 1
+ and lengths.size(0) in [1, u.size(0)]
+ )
+ mask = torch.where(
+ torch.arange(L, device=lengths.device)
+ < lengths[:, None, None],
+ 1.0,
+ 0.0,
+ )
+ u = u * mask
+
+ if self.gate is not None:
+ v = self.input_gate(u)
+ if self.bottleneck is not None:
+ u = self.input_linear(u)
+
+ # Compute SS Kernel
+ L_kernel = L if self.L is None else min(L, round(self.L / rate))
+ k, k_state = self.kernel(
+ L=L_kernel, rate=rate, state=state
+ ) # (C H L) (B C H L)
+
+ # Convolution
+ if self.bidirectional:
+ k0, k1 = rearrange(k, "(s c) h l -> s c h l", s=2)
+ k = F.pad(k0, (0, L)) + F.pad(k1.flip(-1), (L, 0))
+ k_f = torch.fft.rfft(k, n=L_kernel + L) # (C H L)
+ u_f = torch.fft.rfft(u, n=L_kernel + L) # (B H L)
+ y_f = contract("bhl,chl->bchl", u_f, k_f)
+ y = torch.fft.irfft(y_f, n=L_kernel + L)[..., :L] # (B C H L)
+
+ # Compute D term in state space equation - essentially a skip connection
+ y = y + contract("bhl,ch->bchl", u, self.D)
+
+ # Compute state update
+ if state is not None:
+ assert (
+ not self.bidirectional
+ ), "Bidirectional not supported with state forwarding"
+ y = y + k_state #
+ next_state = self.kernel.forward_state(u, state)
+ else:
+ next_state = None
+
+ # Optional hyper-network multiplication
+ if self.hyper:
+ y, yh = rearrange(y, "b (s c) h l -> s b c h l", s=2)
+ y = self.hyper_activation(yh) * y
+
+ # Reshape to flatten channels
+ y = rearrange(y, "... c h l -> ... (c h) l")
+
+ y = self.dropout(self.activation(y))
+
+ if not self.transposed:
+ y = y.transpose(-1, -2)
+
+ y = self.output_linear(y)
+
+ if self.gate is not None:
+ y = self.output_gate(y * v)
+
+ return y, next_state
+
+ def setup_step(self, **kwargs):
+ self.kernel._setup_step(**kwargs)
+
+ def step(self, u, state):
+ """Step one time step as a recurrent model. Intended to be used during validation.
+
+ u: (B H)
+ state: (B H N)
+ Returns: output (B H), state (B H N)
+ """
+ assert not self.training
+
+ y, next_state = self.kernel.step(u, state) # (B C H)
+ y = y + u.unsqueeze(-2) * self.D
+ y = rearrange(y, "b c h -> b (c h)")
+ y = self.activation(y)
+ if self.transposed:
+ y = self.output_linear(y.unsqueeze(-1)).squeeze(-1)
+ else:
+ y = self.output_linear(y)
+ return y, next_state
+
+ def default_state(self, *batch_shape, device=None):
+ # kernel is not a SequenceModule so it doesn't need to adhere to same interface
+ # the kernel will know the device of its own parameters
+ return self.kernel.default_state(*batch_shape)
+
+ @property
+ def d_output(self):
+ return self.d_model
\ No newline at end of file
diff --git a/src/notebooks/diffusion-training.ipynb b/src/notebooks/diffusion-training.ipynb
index c6dbfb2..18ef1cd 100644
--- a/src/notebooks/diffusion-training.ipynb
+++ b/src/notebooks/diffusion-training.ipynb
@@ -2,20 +2,41 @@
"cells": [
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 1,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/opt/conda/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+ " from .autonotebook import tqdm as notebook_tqdm\n"
+ ]
+ }
+ ],
"source": [
"import sys\n",
"sys.path.append('../..')\n",
- "import torch"
+ "import torch\n",
+ "\n",
+ "%load_ext autoreload\n",
+ "%autoreload 2"
]
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 3,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The autoreload extension is already loaded. To reload it, use:\n",
+ " %reload_ext autoreload\n"
+ ]
+ }
+ ],
"source": [
"from src.data import DataProcessor, DataConfig\n",
"from src.trainers.quantile_trainer import AutoRegressiveQuantileTrainer, NonAutoRegressiveQuantileRegression\n",
@@ -44,7 +65,7 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
@@ -62,8 +83,8 @@
"\n",
"data_config.NOMINAL_NET_POSITION = True\n",
"\n",
- "data_processor = DataProcessor(data_config, path=\"../../\")\n",
- "data_processor.set_batch_size(1024)\n",
+ "data_processor = DataProcessor(data_config, path=\"../../\", lstm=True)\n",
+ "data_processor.set_batch_size(128)\n",
"data_processor.set_full_day_skip(True)"
]
},
@@ -222,6 +243,165 @@
"sample_diffusion(new_model, 1, inputs)"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Trying out BackboneModel using S4 state space model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[KeOps] Compiling cuda jit compiler engine ... \n",
+ "[KeOps] Warning : There were warnings or errors compiling formula :\n",
+ "/usr/bin/ld: warning: /opt/conda/lib/libstdc++.so: unsupported GNU_PROPERTY_TYPE (5) type: 0xc0010001\n",
+ "/usr/bin/ld: warning: /opt/conda/lib/libstdc++.so: unsupported GNU_PROPERTY_TYPE (5) type: 0xc0010002\n",
+ "/usr/bin/ld: warning: /opt/conda/lib/libgcc_s.so.1: unsupported GNU_PROPERTY_TYPE (5) type: 0xc0010001\n",
+ "/usr/bin/ld: warning: /opt/conda/lib/libgcc_s.so.1: unsupported GNU_PROPERTY_TYPE (5) type: 0xc0010002\n",
+ "/usr/bin/ld: warning: /opt/conda/lib/libgcc_s.so.1: unsupported GNU_PROPERTY_TYPE (5) type: 0xc0010001\n",
+ "/usr/bin/ld: warning: /opt/conda/lib/libgcc_s.so.1: unsupported GNU_PROPERTY_TYPE (5) type: 0xc0010002\n",
+ "\n",
+ "OK\n",
+ "[pyKeOps] Compiling nvrtc binder for python ... \n",
+ "[KeOps] Warning : There were warnings or errors compiling formula :\n",
+ "/usr/bin/ld: warning: /opt/conda/lib/libstdc++.so: unsupported GNU_PROPERTY_TYPE (5) type: 0xc0010001\n",
+ "/usr/bin/ld: warning: /opt/conda/lib/libstdc++.so: unsupported GNU_PROPERTY_TYPE (5) type: 0xc0010002\n",
+ "/usr/bin/ld: warning: /opt/conda/lib/libgcc_s.so.1: unsupported GNU_PROPERTY_TYPE (5) type: 0xc0010001\n",
+ "/usr/bin/ld: warning: /opt/conda/lib/libgcc_s.so.1: unsupported GNU_PROPERTY_TYPE (5) type: 0xc0010002\n",
+ "/usr/bin/ld: warning: /opt/conda/lib/libgcc_s.so.1: unsupported GNU_PROPERTY_TYPE (5) type: 0xc0010001\n",
+ "/usr/bin/ld: warning: /opt/conda/lib/libgcc_s.so.1: unsupported GNU_PROPERTY_TYPE (5) type: 0xc0010002\n",
+ "\n",
+ "OK\n"
+ ]
+ }
+ ],
+ "source": [
+ "import sys\n",
+ "sys.path.append('../..')\n",
+ "import torch\n",
+ "\n",
+ "%load_ext autoreload\n",
+ "%autoreload 2\n",
+ "from src.models.tsdiff_s4.backbones import BackboneModel\n",
+ "from src.trainers.diffusion_trainer import DiffusionTrainer\n",
+ "\n",
+ "backbone = BackboneModel(\n",
+ " input_dim=1,\n",
+ " hidden_dim=512,\n",
+ " output_dim=1,\n",
+ " step_emb=128,\n",
+ " num_residual_blocks=3,\n",
+ " num_features=2\n",
+ ")\n",
+ "backbone = backbone.to(\"cuda\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[KeOps] Generating code for formula Sum_Reduction(ComplexMult(Real2Complex(1/ComplexSquareAbs(ComplexMult(Var(1,2,0)-Var(2,2,1),Var(1,2,0)-Conj(Var(2,2,1))))),ComplexMult(Var(1,2,0)*ComplexReal(Var(0,2,1))-Real2Complex(Sum(Var(0,2,1)*Var(2,2,1))),Conj(ComplexMult(Var(1,2,0)-Var(2,2,1),Var(1,2,0)-Conj(Var(2,2,1)))))),0) ... "
+ ]
+ },
+ {
+ "ename": "",
+ "evalue": "",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[1;31mThe Kernel crashed while executing code in the the current cell or a previous cell. Please review the code in the cell(s) to identify a possible cause of the failure. Click here for more info. View Jupyter log for further details."
+ ]
+ }
+ ],
+ "source": [
+ "# now lets find out what the input shape of the featues and input must be\n",
+ "\n",
+ "# input: (B, L, C)\n",
+ "# features: (B, L, F)\n",
+ "# time: (B, 1)\n",
+ "\n",
+ "# output: (B, L, C)? \n",
+ "\n",
+ "input = torch.randn(2, 96, 1).to(\"cuda\")\n",
+ "features = torch.randn(2, 96, 2).to(\"cuda\")\n",
+ "times = torch.randn(2).to(\"cuda\")\n",
+ "\n",
+ "backbone(input, times, features).shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "KeyError",
+ "evalue": "'nvrtc'",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[3], line 13\u001b[0m\n\u001b[1;32m 10\u001b[0m times \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mtensor([\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m2\u001b[39m)\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcuda\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 11\u001b[0m features \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mrandn(\u001b[38;5;241m2\u001b[39m, \u001b[38;5;241m96\u001b[39m, \u001b[38;5;241m2\u001b[39m)\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcuda\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m---> 13\u001b[0m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtimes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfeatures\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mshape\n",
+ "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
+ "File \u001b[0;32m/workspaces/Thesis/src/notebooks/../../src/models/tsdiff_s4/backbones.py:164\u001b[0m, in \u001b[0;36mBackboneModel.forward\u001b[0;34m(self, input, t, features)\u001b[0m\n\u001b[1;32m 162\u001b[0m skips \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m 163\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m layer \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mresidual_blocks:\n\u001b[0;32m--> 164\u001b[0m x, skip \u001b[38;5;241m=\u001b[39m \u001b[43mlayer\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfeatures\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 165\u001b[0m skips\u001b[38;5;241m.\u001b[39mappend(skip)\n\u001b[1;32m 167\u001b[0m skip \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mstack(skips)\u001b[38;5;241m.\u001b[39msum(\u001b[38;5;241m0\u001b[39m)\n",
+ "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
+ "File \u001b[0;32m/workspaces/Thesis/src/notebooks/../../src/models/tsdiff_s4/backbones.py:96\u001b[0m, in \u001b[0;36mS4Block.forward\u001b[0;34m(self, x, t, features)\u001b[0m\n\u001b[1;32m 94\u001b[0m t \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtime_linear(t)[:, \u001b[38;5;28;01mNone\u001b[39;00m, :]\u001b[38;5;241m.\u001b[39mrepeat(\u001b[38;5;241m1\u001b[39m, x\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m2\u001b[39m], \u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 95\u001b[0m t \u001b[38;5;241m=\u001b[39m t\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m)\n\u001b[0;32m---> 96\u001b[0m out, _ \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43ms4block\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 97\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m features \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 98\u001b[0m out \u001b[38;5;241m=\u001b[39m out \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfeature_encoder(features)\n",
+ "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
+ "File \u001b[0;32m/workspaces/Thesis/src/notebooks/../../src/models/tsdiff_s4/backbones.py:56\u001b[0m, in \u001b[0;36mS4Layer.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 54\u001b[0m z \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnorm(z\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m))\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m)\n\u001b[1;32m 55\u001b[0m \u001b[38;5;66;03m# Apply layer: we ignore the state input and output for training\u001b[39;00m\n\u001b[0;32m---> 56\u001b[0m z, _ \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlayer\u001b[49m\u001b[43m(\u001b[49m\u001b[43mz\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 57\u001b[0m \u001b[38;5;66;03m# Dropout on the output of the layer\u001b[39;00m\n\u001b[1;32m 58\u001b[0m z \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdropout(z)\n",
+ "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
+ "File \u001b[0;32m/workspaces/Thesis/src/notebooks/../../src/models/tsdiff_s4/s4.py:1761\u001b[0m, in \u001b[0;36mS4.forward\u001b[0;34m(self, u, state, rate, lengths, **kwargs)\u001b[0m\n\u001b[1;32m 1759\u001b[0m \u001b[38;5;66;03m# Compute SS Kernel\u001b[39;00m\n\u001b[1;32m 1760\u001b[0m L_kernel \u001b[38;5;241m=\u001b[39m L \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mL \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mmin\u001b[39m(L, \u001b[38;5;28mround\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mL \u001b[38;5;241m/\u001b[39m rate))\n\u001b[0;32m-> 1761\u001b[0m k, k_state \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mkernel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1762\u001b[0m \u001b[43m \u001b[49m\u001b[43mL\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mL_kernel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrate\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrate\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstate\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstate\u001b[49m\n\u001b[1;32m 1763\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# (C H L) (B C H L)\u001b[39;00m\n\u001b[1;32m 1765\u001b[0m \u001b[38;5;66;03m# Convolution\u001b[39;00m\n\u001b[1;32m 1766\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbidirectional:\n",
+ "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
+ "File \u001b[0;32m/workspaces/Thesis/src/notebooks/../../src/models/tsdiff_s4/s4.py:1549\u001b[0m, in \u001b[0;36mSSKernel.forward\u001b[0;34m(self, state, L, rate)\u001b[0m\n\u001b[1;32m 1548\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, state\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, L\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, rate\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1.0\u001b[39m):\n\u001b[0;32m-> 1549\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mkernel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstate\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstate\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mL\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mL\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrate\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrate\u001b[49m\u001b[43m)\u001b[49m\n",
+ "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
+ "File \u001b[0;32m/workspaces/Thesis/src/notebooks/../../src/models/tsdiff_s4/s4.py:925\u001b[0m, in \u001b[0;36mSSKernelNPLR.forward\u001b[0;34m(self, state, rate, L)\u001b[0m\n\u001b[1;32m 923\u001b[0m r \u001b[38;5;241m=\u001b[39m cauchy_mult(v, z, w, symmetric\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 924\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m has_pykeops:\n\u001b[0;32m--> 925\u001b[0m r \u001b[38;5;241m=\u001b[39m \u001b[43mcauchy_conj\u001b[49m\u001b[43m(\u001b[49m\u001b[43mv\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mz\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mw\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 926\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 927\u001b[0m r \u001b[38;5;241m=\u001b[39m cauchy_naive(v, z, w)\n",
+ "File \u001b[0;32m/workspaces/Thesis/src/notebooks/../../src/models/tsdiff_s4/s4.py:89\u001b[0m, in \u001b[0;36mcauchy_conj\u001b[0;34m(v, z, w)\u001b[0m\n\u001b[1;32m 86\u001b[0m z \u001b[38;5;241m=\u001b[39m _c2r(z)\n\u001b[1;32m 87\u001b[0m w \u001b[38;5;241m=\u001b[39m _c2r(w)\n\u001b[0;32m---> 89\u001b[0m r \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m2\u001b[39m \u001b[38;5;241m*\u001b[39m \u001b[43mcauchy_mult\u001b[49m\u001b[43m(\u001b[49m\u001b[43mv\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mz\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mw\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbackend\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mGPU\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 90\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _r2c(r)\n",
+ "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/pykeops/torch/generic/generic_red.py:688\u001b[0m, in \u001b[0;36mGenred.__call__\u001b[0;34m(self, backend, device_id, ranges, out, *args)\u001b[0m\n\u001b[1;32m 686\u001b[0m params\u001b[38;5;241m.\u001b[39mny \u001b[38;5;241m=\u001b[39m ny\n\u001b[1;32m 687\u001b[0m params\u001b[38;5;241m.\u001b[39mout \u001b[38;5;241m=\u001b[39m out\n\u001b[0;32m--> 688\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mGenredAutograd_fun\u001b[49m\u001b[43m(\u001b[49m\u001b[43mparams\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 690\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m postprocess(out, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtorch\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mreduction_op, nout, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mopt_arg, dtype)\n",
+ "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/pykeops/torch/generic/generic_red.py:384\u001b[0m, in \u001b[0;36mGenredAutograd_fun\u001b[0;34m(*inputs)\u001b[0m\n\u001b[1;32m 383\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mGenredAutograd_fun\u001b[39m(\u001b[38;5;241m*\u001b[39minputs):\n\u001b[0;32m--> 384\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mGenredAutograd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mapply\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m[\u001b[38;5;241m0\u001b[39m]\n",
+ "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/torch/autograd/function.py:506\u001b[0m, in \u001b[0;36mFunction.apply\u001b[0;34m(cls, *args, **kwargs)\u001b[0m\n\u001b[1;32m 503\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m torch\u001b[38;5;241m.\u001b[39m_C\u001b[38;5;241m.\u001b[39m_are_functorch_transforms_active():\n\u001b[1;32m 504\u001b[0m \u001b[38;5;66;03m# See NOTE: [functorch vjp and autograd interaction]\u001b[39;00m\n\u001b[1;32m 505\u001b[0m args \u001b[38;5;241m=\u001b[39m _functorch\u001b[38;5;241m.\u001b[39mutils\u001b[38;5;241m.\u001b[39munwrap_dead_wrappers(args)\n\u001b[0;32m--> 506\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mapply\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 508\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39msetup_context \u001b[38;5;241m==\u001b[39m _SingleLevelFunction\u001b[38;5;241m.\u001b[39msetup_context:\n\u001b[1;32m 509\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m 510\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mIn order to use an autograd.Function with functorch transforms \u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m 511\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m(vmap, grad, jvp, jacrev, ...), it must override the setup_context \u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m 512\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mstaticmethod. For more details, please see \u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m 513\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mhttps://pytorch.org/docs/master/notes/extending.func.html\u001b[39m\u001b[38;5;124m'\u001b[39m)\n",
+ "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/pykeops/torch/generic/generic_red.py:295\u001b[0m, in \u001b[0;36mGenredAutograd.forward\u001b[0;34m(*inputs)\u001b[0m\n\u001b[1;32m 293\u001b[0m \u001b[38;5;129m@staticmethod\u001b[39m\n\u001b[1;32m 294\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;241m*\u001b[39minputs):\n\u001b[0;32m--> 295\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mGenredAutograd_base\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n",
+ "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/pykeops/torch/generic/generic_red.py:91\u001b[0m, in \u001b[0;36mGenredAutograd_base._forward\u001b[0;34m(params, *args)\u001b[0m\n\u001b[1;32m 85\u001b[0m device_id, device_args \u001b[38;5;241m=\u001b[39m set_device(\n\u001b[1;32m 86\u001b[0m tagCPUGPU, tagHostDevice, params\u001b[38;5;241m.\u001b[39mdevice_id_request, \u001b[38;5;241m*\u001b[39margs\n\u001b[1;32m 87\u001b[0m )\n\u001b[1;32m 89\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mpykeops\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcommon\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mkeops_io\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m keops_binder\n\u001b[0;32m---> 91\u001b[0m myconv \u001b[38;5;241m=\u001b[39m \u001b[43mkeops_binder\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mnvrtc\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mtagCPUGPU\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mcpp\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m(\n\u001b[1;32m 92\u001b[0m tagCPUGPU,\n\u001b[1;32m 93\u001b[0m tag1D2D,\n\u001b[1;32m 94\u001b[0m tagHostDevice,\n\u001b[1;32m 95\u001b[0m use_ranges,\n\u001b[1;32m 96\u001b[0m device_id,\n\u001b[1;32m 97\u001b[0m params\u001b[38;5;241m.\u001b[39mformula,\n\u001b[1;32m 98\u001b[0m params\u001b[38;5;241m.\u001b[39maliases,\n\u001b[1;32m 99\u001b[0m \u001b[38;5;28mlen\u001b[39m(args),\n\u001b[1;32m 100\u001b[0m params\u001b[38;5;241m.\u001b[39mdtype,\n\u001b[1;32m 101\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtorch\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 102\u001b[0m params\u001b[38;5;241m.\u001b[39moptional_flags,\n\u001b[1;32m 103\u001b[0m )\u001b[38;5;241m.\u001b[39mimport_module()\n\u001b[1;32m 105\u001b[0m \u001b[38;5;66;03m# N.B.: KeOps C++ expects contiguous data arrays\u001b[39;00m\n\u001b[1;32m 106\u001b[0m test_contig \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mall\u001b[39m(arg\u001b[38;5;241m.\u001b[39mis_contiguous() \u001b[38;5;28;01mfor\u001b[39;00m arg \u001b[38;5;129;01min\u001b[39;00m args)\n",
+ "\u001b[0;31mKeyError\u001b[0m: 'nvrtc'"
+ ]
+ }
+ ],
+ "source": [
+ "# inputDim = data_processor.get_input_size()\n",
+ "learningRate = 0.0001\n",
+ "epochs=150\n",
+ "\n",
+ "#### Model ####\n",
+ "model = BackboneModel(1, 512, output_dim=1, step_emb=64, num_residual_blocks=4, num_features=2)\n",
+ "model.to(\"cuda\")\n",
+ "\n",
+ "inputs = torch.randn(2, 96, 1).to(\"cuda\")\n",
+ "times = torch.tensor([0]*2).to(\"cuda\")\n",
+ "features = torch.randn(2, 96, 2).to(\"cuda\")\n",
+ "\n",
+ "model(inputs, times, features).shape\n",
+ "\n",
+ "#### Trainer ####\n",
+ "# trainer = DiffusionTrainer(model, data_processor, \"cuda\")\n",
+ "# trainer.train(epochs, learningRate, None)"
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,
@@ -246,7 +426,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.10.11"
+ "version": "3.10.8"
}
},
"nbformat": 4,
diff --git a/src/trainers/diffusion_trainer.py b/src/trainers/diffusion_trainer.py
index 6f35ae3..85de501 100644
--- a/src/trainers/diffusion_trainer.py
+++ b/src/trainers/diffusion_trainer.py
@@ -132,6 +132,8 @@ class DiffusionTrainer:
t = self.sample_timesteps(time_series.shape[0]).to(self.device)
x_t, noise = self.noise_time_series(time_series, t)
+ x_t = x_t.unsqueeze(-1)
+ print(x_t.shape, t.shape, base_pattern.shape)
predicted_noise = self.model(x_t, t, base_pattern)
loss = criterion(predicted_noise, noise)
diff --git a/test.py b/test.py
new file mode 100644
index 0000000..b72c153
--- /dev/null
+++ b/test.py
@@ -0,0 +1,2 @@
+import pykeops
+pykeops.test_numpy_bindings()
\ No newline at end of file