diff --git a/Dockerfile b/Dockerfile index 482d478..220005b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,7 @@ -FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime +FROM pytorch/pytorch:1.13.1-cuda11.6-cudnn8-devel +#FROM getkeops/keops-full:2.1-geomloss0.2.5-cuda11.8-pytorch2.0.0-python3.10 +# FROM pytorch/pytorch:2.1.0-cuda11.8-cudnn8-devel RUN apt-get update RUN apt-get install -y git diff --git a/requirements.txt b/requirements.txt index ccec0b4..e14eea3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,4 +11,6 @@ clearml properscoring nbconvert torchinfo -tabulate \ No newline at end of file +tabulate +einops +opt_einsum \ No newline at end of file diff --git a/src/models/tsdiff_s4/backbones.py b/src/models/tsdiff_s4/backbones.py new file mode 100644 index 0000000..9224dd8 --- /dev/null +++ b/src/models/tsdiff_s4/backbones.py @@ -0,0 +1,172 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +import math + +import torch +from torch import nn + +from src.models.tsdiff_s4.s4 import S4 + + +class SinusoidalPositionEmbeddings(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, time): + device = time.device + half_dim = self.dim // 2 + embeddings = math.log(10000) / (half_dim - 1) + embeddings = torch.exp( + torch.arange(half_dim, device=device) * -embeddings + ) + embeddings = time[:, None] * embeddings[None, :] + embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1) + return embeddings + + +class S4Layer(nn.Module): + def __init__( + self, + d_model, + dropout=0.0, + ): + super().__init__() + self.layer = S4( + d_model=d_model, + d_state=128, + bidirectional=True, + dropout=dropout, + transposed=True, + postact=None, + ) + self.norm = nn.LayerNorm(d_model) + self.dropout = ( + nn.Dropout1d(dropout) if dropout > 0.0 else nn.Identity() + ) + + def forward(self, x): + """ + Input x is shape (B, d_input, L) + """ + z = x + # Prenorm + z = self.norm(z.transpose(-1, -2)).transpose(-1, -2) + # Apply layer: we ignore the state input and output for training + z, _ = self.layer(z) + # Dropout on the output of the layer + z = self.dropout(z) + # Residual connection + x = z + x + return x, None + + def default_state(self, *args, **kwargs): + return self.layer.default_state(*args, **kwargs) + + def step(self, x, state, **kwargs): + z = x + # Prenorm + z = self.norm(z.transpose(-1, -2)).transpose(-1, -2) + # Apply layer + z, state = self.layer.step(z, state, **kwargs) + # Residual connection + x = z + x + return x, state + + +class S4Block(nn.Module): + def __init__(self, d_model, dropout=0.0, expand=2, num_features=0): + super().__init__() + self.s4block = S4Layer(d_model, dropout=dropout) + + self.time_linear = nn.Linear(d_model, d_model) + self.tanh = nn.Tanh() + self.sigm = nn.Sigmoid() + self.out_linear1 = nn.Conv1d( + in_channels=d_model, out_channels=d_model, kernel_size=1 + ) + self.out_linear2 = nn.Conv1d( + in_channels=d_model, out_channels=d_model, kernel_size=1 + ) + self.feature_encoder = nn.Conv1d(num_features, d_model, kernel_size=1) + + def forward(self, x, t, features=None): + t = self.time_linear(t)[:, None, :].repeat(1, x.shape[2], 1) + t = t.transpose(-1, -2) + out, _ = self.s4block(x + t) + if features is not None: + out = out + self.feature_encoder(features) + out = self.tanh(out) * self.sigm(out) + out1 = self.out_linear1(out) + out2 = self.out_linear2(out) + return out1 + x, out2 + + +def Conv1dKaiming(in_channels, out_channels, kernel_size): + layer = nn.Conv1d(in_channels, out_channels, kernel_size) + nn.init.kaiming_normal_(layer.weight) + return layer + + +class BackboneModel(nn.Module): + def __init__( + self, + input_dim, + hidden_dim, + output_dim, + step_emb, + num_residual_blocks, + num_features, + residual_block="s4", + dropout=0.0, + init_skip=True, + ): + super().__init__() + if residual_block == "s4": + residual_block = S4Block + else: + raise ValueError(f"Unknown residual block {residual_block}") + self.input_init = nn.Sequential( + nn.Linear(input_dim, hidden_dim), + nn.ReLU(), + ) + self.time_init = nn.Sequential( + nn.Linear(step_emb, hidden_dim), + nn.SiLU(), + nn.Linear(hidden_dim, hidden_dim), + nn.SiLU(), + ) + self.out_linear = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, output_dim), + ) + residual_blocks = [] + for i in range(num_residual_blocks): + residual_blocks.append( + residual_block( + hidden_dim, num_features=num_features, dropout=dropout + ) + ) + self.residual_blocks = nn.ModuleList(residual_blocks) + self.step_embedding = SinusoidalPositionEmbeddings(step_emb) + self.init_skip = init_skip + + def forward(self, input, t, features=None): + x = self.input_init(input) # B, L ,C + step_emb = self.step_embedding(t) + t = self.time_init(step_emb) + x = x.transpose(-1, -2) + if features is not None: + features = features.transpose(-1, -2) + skips = [] + for layer in self.residual_blocks: + x, skip = layer(x, t, features) + skips.append(skip) + + skip = torch.stack(skips).sum(0) + skip = skip.transpose(-1, -2) + out = self.out_linear(skip) + if self.init_skip: + out = out + input + return out \ No newline at end of file diff --git a/src/models/tsdiff_s4/s4.py b/src/models/tsdiff_s4/s4.py new file mode 100644 index 0000000..7dbc27b --- /dev/null +++ b/src/models/tsdiff_s4/s4.py @@ -0,0 +1,1836 @@ +"""Standalone version of Structured (Sequence) State Space (S4) model.""" + +import logging +from functools import partial +import math +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from pytorch_lightning.utilities import rank_zero_only +from einops import rearrange, repeat +import opt_einsum as oe + +contract = oe.contract +contract_expression = oe.contract_expression + + +def get_logger(name=__name__, level=logging.INFO) -> logging.Logger: + """Initializes multi-GPU-friendly python logger.""" + + logger = logging.getLogger(name) + logger.setLevel(level) + + # this ensures all logging levels get marked with the rank zero decorator + # otherwise logs would get multiplied for each GPU process in multi-GPU setup + for level in ( + "debug", + "info", + "warning", + "error", + "exception", + "fatal", + "critical", + ): + setattr(logger, level, rank_zero_only(getattr(logger, level))) + + return logger + + +log = get_logger(__name__) + +""" Cauchy and Vandermonde kernels """ + +try: # Try CUDA extension + from extensions.cauchy.cauchy import cauchy_mult + + has_cauchy_extension = True +except ImportError: + # log.warning( + # "CUDA extension for cauchy multiplication not found. Install by going to extensions/cauchy/ and running `python setup.py install`. This should speed up end-to-end training by 10-50%" + # ) + has_cauchy_extension = False + +try: # Try pykeops + from pykeops.torch import Genred + + has_pykeops = True + log.info("Pykeops installation found.") + + def _broadcast_dims(*tensors): + max_dim = max([len(tensor.shape) for tensor in tensors]) + tensors = [ + tensor.view((1,) * (max_dim - len(tensor.shape)) + tensor.shape) + for tensor in tensors + ] + return tensors + + def cauchy_conj(v, z, w): + """Pykeops version""" + expr_num = "z * ComplexReal(v) - Real2Complex(Sum(v * w))" + expr_denom = "ComplexMult(z-w, z-Conj(w))" + + cauchy_mult = Genred( + f"ComplexDivide({expr_num}, {expr_denom})", + [ + "v = Vj(2)", + "z = Vi(2)", + "w = Vj(2)", + ], + reduction_op="Sum", + axis=1, + ) + + v, z, w = _broadcast_dims(v, z, w) + v = _c2r(v) + z = _c2r(z) + w = _c2r(w) + + r = 2 * cauchy_mult(v, z, w, backend="GPU") + return _r2c(r) + + def log_vandermonde(v, x, L): + expr = "ComplexMult(v, ComplexExp(ComplexMult(x, l)))" + vandermonde_mult = Genred( + expr, + [ + "v = Vj(2)", + "x = Vj(2)", + "l = Vi(2)", + ], + reduction_op="Sum", + axis=1, + ) + + l = torch.arange(L).to(x) + v, x, l = _broadcast_dims(v, x, l) + v = _c2r(v) + x = _c2r(x) + l = _c2r(l) + + r = vandermonde_mult(v, x, l, backend="GPU") + return 2 * _r2c(r).real + + def log_vandermonde_transpose(u, v, x, L): + """ + u: ... H L + v: ... H N + x: ... H N + Returns: ... H N + + V = Vandermonde(a, L) : (H N L) + contract_L(V * u * v) + """ + expr = "ComplexMult(ComplexMult(v, u), ComplexExp(ComplexMult(x, l)))" + vandermonde_mult = Genred( + expr, + [ + "u = Vj(2)", + "v = Vi(2)", + "x = Vi(2)", + "l = Vj(2)", + ], + reduction_op="Sum", + axis=1, + ) + + l = torch.arange(L).to(x) + u, v, x, l = _broadcast_dims(u, v, x, l) + u = _c2r(u) + v = _c2r(v) + x = _c2r(x) + l = _c2r(l) + + r = vandermonde_mult(u, v, x, l, backend="GPU") + return _r2c(r) + +except ImportError: + has_pykeops = False + if not has_cauchy_extension: + log.warning( + "Falling back on slow Cauchy kernel. Install at least one of pykeops or the CUDA extension for efficiency." + ) + + def cauchy_naive(v, z, w): + """ + v, w: (..., N) + z: (..., L) + returns: (..., L) + """ + cauchy_matrix = v.unsqueeze(-1) / ( + z.unsqueeze(-2) - w.unsqueeze(-1) + ) # (... N L) + return torch.sum(cauchy_matrix, dim=-2) + + # Vandermonde functions + log.warning( + "Falling back on slow Vandermonde kernel. Install pykeops for improved memory efficiency." + ) + + def log_vandermonde(v, x, L): + """ + v: (..., N) + x: (..., N) + returns: (..., L) \sum v x^l + """ + vandermonde_matrix = torch.exp( + x.unsqueeze(-1) * torch.arange(L).to(x) + ) # (... N L) + vandermonde_prod = contract( + "... n, ... n l -> ... l", v, vandermonde_matrix + ) # (... L) + return 2 * vandermonde_prod.real + + def log_vandermonde_transpose(u, v, x, L): + vandermonde_matrix = torch.exp( + x.unsqueeze(-1) * torch.arange(L).to(x) + ) # (... N L) + vandermonde_prod = contract( + "... l, ... n, ... n l -> ... n", + u.to(x), + v.to(x), + vandermonde_matrix, + ) # (... L) + return vandermonde_prod + + +def _conj(x): + return torch.cat([x, x.conj()], dim=-1) + + +_c2r = torch.view_as_real +_r2c = torch.view_as_complex +if tuple(map(int, torch.__version__.split(".")[:2])) >= (1, 10): + + def _resolve_conj(x): + return x.conj().resolve_conj() + +else: + + def _resolve_conj(x): + return x.conj() + + +""" Simple nn.Module components """ + + +def Activation(activation=None, dim=-1): + if activation in [None, "id", "identity", "linear"]: + return nn.Identity() + elif activation == "tanh": + return nn.Tanh() + elif activation == "relu": + return nn.ReLU() + elif activation == "gelu": + return nn.GELU() + elif activation in ["swish", "silu"]: + return nn.SiLU() + elif activation == "glu": + return nn.GLU(dim=dim) + elif activation == "sigmoid": + return nn.Sigmoid() + else: + raise NotImplementedError( + "hidden activation '{}' is not implemented".format(activation) + ) + + +def LinearActivation( + d_input, + d_output, + bias=True, + transposed=False, + activation=None, + activate=False, # Apply activation as part of this module + **kwargs, +): + """Returns a linear nn.Module with control over axes order, initialization, and activation""" + + # Construct core module + linear_cls = partial(nn.Conv1d, kernel_size=1) if transposed else nn.Linear + if activation == "glu": + d_output *= 2 + linear = linear_cls(d_input, d_output, bias=bias, **kwargs) + + if activate and activation is not None: + activation = Activation(activation, dim=-2 if transposed else -1) + linear = nn.Sequential(linear, activation) + return linear + + +class DropoutNd(nn.Module): + def __init__(self, p: float = 0.5, tie=True, transposed=True): + """ + tie: tie dropout mask across sequence lengths (Dropout1d/2d/3d) + """ + super().__init__() + if p < 0 or p >= 1: + raise ValueError( + "dropout probability has to be in [0, 1), " + "but got {}".format(p) + ) + self.p = p + self.tie = tie + self.transposed = transposed + self.binomial = torch.distributions.binomial.Binomial(probs=1 - self.p) + + def forward(self, X): + """X: (batch, dim, lengths...)""" + if self.training: + if not self.transposed: + X = rearrange(X, "b d ... -> b ... d") + mask_shape = ( + X.shape[:2] + (1,) * (X.ndim - 2) if self.tie else X.shape + ) + mask = torch.rand(*mask_shape, device=X.device) < 1.0 - self.p + X = X * mask * (1.0 / (1 - self.p)) + if not self.transposed: + X = rearrange(X, "b ... d -> b d ...") + return X + return X + + +""" Misc functional utilities """ + + +def power(L, A, v=None): + """Compute A^L and the scan sum_i A^i v_i + + A: (..., N, N) + v: (..., N, L) + """ + + I = torch.eye(A.shape[-1]).to(A) # , dtype=A.dtype, device=A.device) + + powers = [A] + l = 1 + while True: + if L % 2 == 1: + I = powers[-1] @ I + L //= 2 + if L == 0: + break + l *= 2 + powers.append(powers[-1] @ powers[-1]) + + if v is None: + return I + + # Invariants: + # powers[-1] := A^l + # l := largest po2 at most L + + # Note that an alternative divide and conquer to compute the reduction is possible and can be embedded into the above loop without caching intermediate powers of A + # We do this reverse divide-and-conquer for efficiency reasons: + # 1) it involves fewer padding steps for non-po2 L + # 2) it involves more contiguous arrays + + # Take care of edge case for non-po2 arrays + # Note that this initial step is a no-op for the case of power of 2 (l == L) + k = v.size(-1) - l + v_ = powers.pop() @ v[..., l:] + v = v[..., :l] + v[..., :k] = v[..., :k] + v_ + + # Handle reduction for power of 2 + while v.size(-1) > 1: + v = rearrange(v, "... (z l) -> ... z l", z=2) + v = v[..., 0, :] + powers.pop() @ v[..., 1, :] + return I, v.squeeze(-1) + + +""" HiPPO utilities """ + + +def transition(measure, N): + """A, B transition matrices for different measures""" + # Legendre (translated) + if measure == "legt": + Q = np.arange(N, dtype=np.float64) + R = (2 * Q + 1) ** 0.5 + j, i = np.meshgrid(Q, Q) + A = R[:, None] * np.where(i < j, (-1.0) ** (i - j), 1) * R[None, :] + B = R[:, None] + A = -A + + # Halve again for timescale correctness + A *= 0.5 + B *= 0.5 + # Legendre (scaled) + elif measure == "legs": + q = np.arange(N, dtype=np.float64) + col, row = np.meshgrid(q, q) + r = 2 * q + 1 + M = -(np.where(row >= col, r, 0) - np.diag(q)) + T = np.sqrt(np.diag(2 * q + 1)) + A = T @ M @ np.linalg.inv(T) + B = np.diag(T)[:, None] + B = ( + B.copy() + ) # Otherwise "UserWarning: given NumPY array is not writeable..." after torch.as_tensor(B) + elif measure == "legsd": + # Essentially equivalent to S4D-LegS + q = np.arange(N, dtype=np.float64) + col, row = np.meshgrid(q, q) + r = 2 * q + 1 + M = -(np.where(row >= col, r, 0) - np.diag(q)) + T = np.sqrt(np.diag(2 * q + 1)) + A = T @ M @ np.linalg.inv(T) + B = np.diag(T)[:, None] + B = ( + B.copy() + ) # Otherwise "UserWarning: given NumPY array is not writeable..." after torch.as_tensor(B) + A += 0.5 * B * B[None, :, 0] + B = B / 2.0 + elif measure in ["fourier_diag", "foud"]: + # Essentially equivalent to S4D-Lin + freqs = np.arange(N // 2) + d = np.stack([freqs, np.zeros(N // 2)], axis=-1).reshape(-1)[:-1] + A = 2 * np.pi * (-np.diag(d, 1) + np.diag(d, -1)) + A = A - 0.5 * np.eye(N) + B = np.zeros(N) + B[0::2] = 2**0.5 + B[0] = 1 + B = B[:, None] + elif measure in ["fourier", "fout"]: + freqs = np.arange(N // 2) + d = np.stack([np.zeros(N // 2), freqs], axis=-1).reshape(-1)[1:] + A = np.pi * (-np.diag(d, 1) + np.diag(d, -1)) + B = np.zeros(N) + B[0::2] = 2**0.5 + B[0] = 1 + + # Subtract off rank correction - this corresponds to the other endpoint u(t-1) in this case + A = A - B[:, None] * B[None, :] + B = B[:, None] + else: + raise NotImplementedError + + return A, B + + +def rank_correction(measure, N, rank=1, dtype=torch.float): + """Return low-rank matrix L such that A + L is normal""" + + if measure == "legs": + assert rank >= 1 + P = torch.sqrt(0.5 + torch.arange(N, dtype=dtype)).unsqueeze( + 0 + ) # (1 N) + elif measure == "legt": + assert rank >= 2 + P = torch.sqrt(1 + 2 * torch.arange(N, dtype=dtype)) # (N) + P0 = P.clone() + P0[0::2] = 0.0 + P1 = P.clone() + P1[1::2] = 0.0 + P = torch.stack([P0, P1], dim=0) # (2 N) + P *= 2 ** ( + -0.5 + ) # Halve the rank correct just like the original matrix was halved + elif measure in ["fourier", "fout"]: + P = torch.zeros(N) + P[0::2] = 2**0.5 + P[0] = 1 + P = P.unsqueeze(0) + elif measure in ["fourier_diag", "foud", "legsd"]: + P = torch.zeros(1, N, dtype=dtype) + else: + raise NotImplementedError + + d = P.size(0) + if rank > d: + P = torch.cat( + [P, torch.zeros(rank - d, N, dtype=dtype)], dim=0 + ) # (rank N) + return P + + +def nplr(measure, N, rank=1, dtype=torch.float, diagonalize_precision=True): + """Return w, p, q, V, B such that + (w - p q^*, B) is unitarily equivalent to the original HiPPO A, B by the matrix V + i.e. A = V[w - p q^*]V^*, B = V B + """ + assert dtype == torch.float or dtype == torch.double + cdtype = torch.cfloat if dtype == torch.float else torch.cdouble + + A, B = transition(measure, N) + A = torch.as_tensor(A, dtype=dtype) # (N, N) + B = torch.as_tensor(B, dtype=dtype)[:, 0] # (N,) + + P = rank_correction(measure, N, rank=rank, dtype=dtype) # (r N) + AP = A + torch.sum(P.unsqueeze(-2) * P.unsqueeze(-1), dim=-3) + + # We require AP to be nearly skew-symmetric + _A = AP + AP.transpose(-1, -2) + if ( + err := torch.sum((_A - _A[0, 0] * torch.eye(N)) ** 2) / N + ) > 1e-5: # if not torch.allclose(_A - _A[0,0]*torch.eye(N), torch.zeros(N, N), atol=1e-5): + print("WARNING: HiPPO matrix not skew symmetric", err) + + # Take advantage of identity + skew-symmetric form to calculate real and imaginary parts separately + # Imaginary part can use eigh instead of eig + w_re = torch.mean(torch.diagonal(AP), -1, keepdim=True) + + # Diagonalize in double precision + if diagonalize_precision: + AP = AP.to(torch.double) + w_im, V = torch.linalg.eigh(AP * -1j) # (..., N) (..., N, N) + if diagonalize_precision: + w_im, V = w_im.to(cdtype), V.to(cdtype) + w = w_re + 1j * w_im + # Check: V w V^{-1} = A + # print("check", V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2)) + + # Only keep half of each conjugate pair + _, idx = torch.sort(w.imag) + w_sorted = w[idx] + V_sorted = V[:, idx] + + # There is an edge case when eigenvalues can be 0, which requires some machinery to handle + # We use a huge hack here: Assume only one pair is 0, and that it is the first row/column of A (only happens in Fourier case) + V = V_sorted[:, : N // 2] + w = w_sorted[: N // 2] + assert ( + w[-2].abs() > 1e-4 + ), "Only 1 zero eigenvalue allowed in diagonal part of A" + if w[-1].abs() < 1e-4: + V[:, -1] = 0.0 + V[0, -1] = 2**-0.5 + V[1, -1] = 2**-0.5 * 1j + + _AP = V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2) + if (err := torch.sum((2 * _AP.real - AP) ** 2) / N) > 1e-5: + print( + "Warning: Diagonalization of A matrix not numerically precise - error", + err, + ) + # print("check", V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2)) + + V_inv = V.conj().transpose(-1, -2) + + B = contract("ij, j -> i", V_inv, B.to(V)) # V^* B + P = contract("ij, ...j -> ...i", V_inv, P.to(V)) # V^* P + + return w, P, B, V + + +def dplr( + scaling, + N, + rank=1, + H=1, + dtype=torch.float, + real_scale=1.0, + imag_scale=1.0, + random_real=False, + random_imag=False, + normalize=False, + diagonal=True, + random_B=False, +): + assert dtype == torch.float or dtype == torch.double + dtype = torch.cfloat if dtype == torch.float else torch.cdouble + + pi = torch.tensor(math.pi) + if random_real: + real_part = torch.rand(H, N // 2) + else: + real_part = 0.5 * torch.ones(H, N // 2) + if random_imag: + imag_part = N // 2 * torch.rand(H, N // 2) + else: + imag_part = repeat(torch.arange(N // 2), "n -> h n", h=H) + + real_part = real_scale * real_part + if scaling == "random": + imag_part = torch.randn(H, N // 2) + elif scaling == "real": + imag_part = 0 * imag_part + real_part = 1 + repeat(torch.arange(N // 2), "n -> h n", h=H) + elif scaling in ["linear", "lin"]: + imag_part = pi * imag_part + elif scaling in [ + "inverse", + "inv", + ]: # Based on asymptotics of the default HiPPO matrix + imag_part = 1 / pi * N * (N / (1 + 2 * imag_part) - 1) + elif scaling in ["inverse2", "inv2"]: + imag_part = 1 / pi * N * (N / (1 + imag_part) - 1) + elif scaling in ["quadratic", "quad"]: + imag_part = 1 / pi * (1 + 2 * imag_part) ** 2 + elif scaling in ["legs", "hippo"]: + w, _, _, _ = nplr("legsd", N) + imag_part = w.imag + + else: + raise NotImplementedError + imag_part = imag_scale * imag_part + w = -real_part + 1j * imag_part + + # Initialize B + if random_B: + B = torch.randn(H, N // 2, dtype=dtype) + else: + B = torch.ones(H, N // 2, dtype=dtype) + + if normalize: + norm = ( + -B / w + ) # (H, N) # Result if you integrate the kernel with constant 1 function + zeta = 2 * torch.sum( + torch.abs(norm) ** 2, dim=-1, keepdim=True + ) # Variance with a random C vector + B = B / zeta**0.5 + + P = torch.randn(rank, H, N // 2, dtype=dtype) + if diagonal: + P = P * 0.0 + V = torch.eye(N, dtype=dtype)[:: N // 2] # Only used in testing + V = repeat(V, "n m -> h n m", h=H) + + return w, P, B, V + + +def ssm(measure, N, R, H, **ssm_args): + """Dispatcher to create single SSM initialization + + N: state size + R: rank (for DPLR parameterization) + H: number of independent SSM copies + """ + + if measure == "dplr": + w, P, B, V = dplr(N=N, rank=R, H=H, **ssm_args) + elif measure.startswith("diag"): + args = measure.split("-") + assert args[0] == "diag" and len(args) > 1 + scaling = args[1] + w, P, B, V = dplr( + scaling=scaling, N=N, rank=R, H=H, diagonal=True, **ssm_args + ) + else: + w, P, B, V = nplr(measure, N, R, **ssm_args) + w = repeat(w, "n -> s n", s=H) + P = repeat(P, "r n -> r s n", s=H) + B = repeat(B, "n -> s n", s=H) + V = repeat(V, "n m -> s n m", s=H) + return w, P, B, V + + +combinations = { + "hippo": ["legs", "fourier"], + "diag": ["diag-inv", "diag-lin"], + "all": ["legs", "fourier", "diag-inv", "diag-lin"], +} + + +def combination(measures, N, R, S, **ssm_args): + if isinstance(measures, str): + measures = ( + combinations[measures] if measures in combinations else [measures] + ) + + assert ( + S % len(measures) == 0 + ), f"{S} independent trainable SSM copies must be multiple of {len(measures)} different measures" + w, P, B, V = zip( + *[ + ssm(measure, N, R, S // len(measures), **ssm_args) + for measure in measures + ] + ) + w = torch.cat(w, dim=0) # (S N) + P = torch.cat(P, dim=1) # (R S N) + B = torch.cat(B, dim=0) # (S N) + V = torch.cat(V, dim=0) # (S N N) + return w, P, B, V + + +class OptimModule(nn.Module): + """Interface for Module that allows registering buffers/parameters with configurable optimizer hyperparameters""" + + def register(self, name, tensor, lr=None): + """Register a tensor with a configurable learning rate and 0 weight decay""" + + if lr == 0.0: + self.register_buffer(name, tensor) + else: + self.register_parameter(name, nn.Parameter(tensor)) + + optim = {"weight_decay": 0.0} + if lr is not None: + optim["lr"] = lr + setattr(getattr(self, name), "_optim", optim) + + +class SSKernelNPLR(OptimModule): + """Stores a representation of and computes the SSKernel function K_L(A^dt, B^dt, C) corresponding to a discretized state space, where A is Normal + Low Rank (NPLR)""" + + @torch.no_grad() + def _setup_C(self, L): + """Construct C~ from C + + Two modes are supported: go directly to length L if self.L is 1, or length is doubled + """ + + if self.L.item() == 0: + if self.verbose: + log.info(f"S4: Initializing kernel to length {L}") + double_length = False + elif L > self.L.item(): # 2*int(self.L) == L: + if self.verbose: + log.info( + f"S4: Doubling length from L = {self.L.item()} to {2*self.L.item()}" + ) + double_length = True + L = self.L.item() # Convenience for the math below + else: + return + + C = _r2c(self.C) + dA, _ = self._setup_state() + dA_L = power(L, dA) + # Multiply C by I - dA_L + C_ = _conj(C) + prod = contract("h m n, c h n -> c h m", dA_L.transpose(-1, -2), C_) + if double_length: + prod = -prod # Multiply by I + dA_L instead + C_ = C_ - prod + C_ = C_[..., : self.N] # Take conjugate pairs again + self.C.copy_(_c2r(C_)) + + self.L = ( + 2 * self.L if double_length else self.L + L + ) # Preserve type/device + + def _omega(self, L, dtype, device, cache=True): + """Calculate (and cache) FFT nodes and their "unprocessed" version with the bilinear transform + This should be called everytime the internal length self.L changes""" + + # Use cached if available + if ( + cache + and hasattr(self, "omega") + and self.omega.size(-1) == L // 2 + 1 + ): + return self.omega, self.z + + omega = torch.tensor( + np.exp(-2j * np.pi / (L)), dtype=dtype, device=device + ) # \omega_{2L} + omega = omega ** torch.arange(0, L // 2 + 1, device=device) + z = 2 * (1 - omega) / (1 + omega) + + # Cache if necessary + if cache: + self.omega = omega + self.z = z + return omega, z + + def __init__( + self, + w, + P, + B, + C, + log_dt, + L=None, # starting/maximum length of kernel + lr=None, + verbose=False, + keops=False, + real_type="exp", # ['none' | 'exp' | 'relu' | sigmoid'] + real_tolerance=1e-3, + bandlimit=None, + ): + """ + L: Maximum length; this module computes an SSM kernel of length L + A is represented by diag(w) - PP^* + w: (S, N) diagonal part + P: (R, S, N) low-rank part + + B: (S, N) + C: (C, H, N) + dt: (H) timescale per feature + lr: [dict | float | None] hook to set lr of special parameters (A, B, dt) + + Dimensions: + N (or d_state): state size + H (or d_model): total SSM copies + S (or n_ssm): number of trainable copies of (A, B, dt); must divide H + R (or rank): rank of low-rank part + C (or channels): system is 1-dim to C-dim + + The forward pass of this Module returns a tensor of shape (C, H, L) + + Note: tensor shape N here denotes half the true state size, because of conjugate symmetry + """ + + super().__init__() + self.verbose = verbose + self.keops = keops + self.bandlimit = bandlimit + self.real_type = real_type + self.real_tolerance = real_tolerance + + # Rank of low-rank correction + self.rank = P.shape[-3] + assert w.size(-1) == P.size(-1) == B.size(-1) == C.size(-1) + self.H = log_dt.size(-1) + self.N = w.size(-1) + + # Check different SSM inits + assert w.size(-2) == P.size(-2) == B.size(-2) # n_ssm + assert self.H % w.size(0) == 0 + self.n_ssm = w.size(0) + self.repeat = self.H // w.size( + 0 + ) # Each trainable SSM needs to be duplicated this many times + + # Broadcast everything to correct shapes + C = C.expand( + torch.broadcast_shapes(C.shape, (1, self.H, self.N)) + ) # (C, H, N) + B = B.unsqueeze(0) # (1, 1, N) + + # Register parameters + self.C = nn.Parameter(_c2r(_resolve_conj(C))) + if lr is None or isinstance(lr, float): + lr_dict = {} + else: + lr_dict, lr = lr, None + self.register("log_dt", log_dt, lr_dict.get("dt", lr)) + self.register("B", _c2r(B), lr_dict.get("B", lr)) + self.register("P", _c2r(P), lr_dict.get("A", lr)) + self.register("inv_w_real", self._w_init(w.real), lr_dict.get("A", lr)) + self.register("w_imag", w.imag, lr_dict.get("A", lr)) + + self.l_max = L + self.register_buffer("L", torch.tensor(0)) # Internal length + + def _w_init(self, w_real): + w_real = torch.clamp(w_real, max=-self.real_tolerance) + if self.real_type == "none": + return -w_real + elif self.real_type == "exp": + return torch.log( + -w_real + ) # Some of the HiPPO methods have real part 0 + elif self.real_type == "relu": + return -w_real + elif self.real_type == "sigmoid": + return torch.logit(-w_real) + elif self.real_type == "softplus": + return torch.log(torch.exp(-w_real) - 1) + else: + raise NotImplementedError + + def _w(self): + # Get the internal w (diagonal) parameter + if self.real_type == "none": + w_real = -self.inv_w_real + elif self.real_type == "exp": + w_real = -torch.exp(self.inv_w_real) + elif self.real_type == "relu": + w_real = -F.relu(self.inv_w_real) + elif self.real_type == "sigmoid": + w_real = -F.sigmoid(self.inv_w_real) + elif self.real_type == "softplus": + w_real = -F.softplus(self.inv_w_real) + else: + raise NotImplementedError + w = w_real + 1j * self.w_imag + return w + + def forward(self, state=None, rate=1.0, L=None): + """ + state: (B, H, N) initial state + rate: sampling rate factor + L: target length + + returns: + (C, H, L) convolution kernel (generally C=1) + (B, H, L) output from initial state + """ + + # Initialize C~ if necessary (done in forward pass so it's on the correct device) + if self.L.item() == 0 and self.l_max is not None and self.l_max > 0: + self._setup_C(self.l_max) + + # Handle sampling rate logic + # The idea is that this kernel's length (in continuous units) is self.L, while we are asked to provide a kernel of length L at (relative) frequency rate + if L is None: + L = round(self.L.item() / rate) + + # Increase the internal length if needed + continuous_L = round(rate * L) + while continuous_L > self.L.item(): + self._setup_C(continuous_L) + discrete_L = round(self.L.item() / rate) + + dt = torch.exp(self.log_dt) * rate + B = _r2c(self.B) + C = _r2c(self.C) + P = _r2c(self.P) + Q = P.conj() + w = self._w() # (n_ssm, N) + + # Address bandlimiting + if self.bandlimit is not None: + freqs = w.imag.abs() / (2 * math.pi) # (H, N) + freqs = dt[:, None] / rate * freqs # (H, N) + mask = torch.where(freqs < self.bandlimit * 0.5, 1, 0) + C = C * mask + + # Get FFT nodes of right length + omega, z = self._omega( + discrete_L, dtype=w.dtype, device=w.device, cache=(rate == 1.0) + ) + + # Broadcast parameters to same hidden features H + B = repeat(B, "1 t n -> 1 (v t) n", v=self.repeat) + P = repeat(P, "r t n -> r (v t) n", v=self.repeat) + Q = repeat(Q, "r t n -> r (v t) n", v=self.repeat) + w = repeat(w, "t n -> (v t) n", v=self.repeat) + + # Augment B + if state is not None: + # Have to "unbilinear" the state to put it into the same "type" as B + # Compute 1/dt * (I + dt/2 A) @ state + + # Can do this without expanding (maybe minor speedup using conj symmetry in theory), but it's easier to read this way + s = _conj(state) if state.size(-1) == self.N else state # (B H N) + sA = s * _conj(w) - contract( # (B H N) + "bhm, rhm, rhn -> bhn", s, _conj(Q), _conj(P) + ) + s = s / dt.unsqueeze(-1) + sA / 2 + s = s[..., : self.N] + + B = torch.cat([s, B], dim=-3) # (B+1, H, N) + + # Incorporate dt into A + w = w * dt.unsqueeze(-1) # (H N) + + # Stack B and p, C and q for convenient batching + B = torch.cat([B, P], dim=-3) # (B+1+R, H, N) + C = torch.cat([C, Q], dim=-3) # (C+R, H, N) + + # Incorporate B and C batch dimensions + v = B.unsqueeze(-3) * C.unsqueeze(-4) # (B+1+R, C+R, H, N) + + # Calculate resolvent at omega + if has_cauchy_extension and z.dtype == torch.cfloat and not self.keops: + r = cauchy_mult(v, z, w, symmetric=True) + elif has_pykeops: + r = cauchy_conj(v, z, w) + else: + r = cauchy_naive(v, z, w) + r = r * dt[None, None, :, None] # (B+1+R, C+R, H, L) + + # Low-rank Woodbury correction + if self.rank == 1: + k_f = r[:-1, :-1, :, :] - r[:-1, -1:, :, :] * r[-1:, :-1, :, :] / ( + 1 + r[-1:, -1:, :, :] + ) + elif self.rank == 2: + r00 = r[: -self.rank, : -self.rank, :, :] + r01 = r[: -self.rank, -self.rank :, :, :] + r10 = r[-self.rank :, : -self.rank, :, :] + r11 = r[-self.rank :, -self.rank :, :, :] + det = (1 + r11[:1, :1, :, :]) * (1 + r11[1:, 1:, :, :]) - r11[ + :1, 1:, :, : + ] * r11[1:, :1, :, :] + s = ( + r01[:, :1, :, :] * (1 + r11[1:, 1:, :, :]) * r10[:1, :, :, :] + + r01[:, 1:, :, :] * (1 + r11[:1, :1, :, :]) * r10[1:, :, :, :] + - r01[:, :1, :, :] * (r11[:1, 1:, :, :]) * r10[1:, :, :, :] + - r01[:, 1:, :, :] * (r11[1:, :1, :, :]) * r10[:1, :, :, :] + ) + s = s / det + k_f = r00 - s + else: + r00 = r[: -self.rank, : -self.rank, :, :] + r01 = r[: -self.rank, -self.rank :, :, :] + r10 = r[-self.rank :, : -self.rank, :, :] + r11 = r[-self.rank :, -self.rank :, :, :] + r11 = rearrange(r11, "a b h n -> h n a b") + r11 = torch.linalg.inv(torch.eye(self.rank, device=r.device) + r11) + r11 = rearrange(r11, "h n a b -> a b h n") + k_f = r00 - torch.einsum( + "i j h n, j k h n, k l h n -> i l h n", r01, r11, r10 + ) + + # Final correction for the bilinear transform + k_f = k_f * 2 / (1 + omega) + + # Move from frequency to coefficients + k = torch.fft.irfft(k_f, n=discrete_L) # (B+1, C, H, L) + + # # Truncate to target length + k = k[..., :L] + + if state is not None: + k_state = k[:-1, :, :, :] # (B, C, H, L) + else: + k_state = None + k_B = k[-1, :, :, :] # (C H L) + + return k_B, k_state + + @torch.no_grad() + def _setup_linear(self): + """Create parameters that allow fast linear stepping of state""" + w = self._w() + B = _r2c(self.B) # (H N) + P = _r2c(self.P) + Q = P.conj() + + # Repeat w shape properly + B = repeat(B, "1 t n -> 1 (v t) n", v=self.repeat) + P = repeat(P, "r t n -> r (v t) n", v=self.repeat) + Q = repeat(Q, "r t n -> r (v t) n", v=self.repeat) + w = repeat(w, "t n -> (v t) n", v=self.repeat) + + # Prepare Linear stepping + dt = torch.exp(self.log_dt) + D = (2.0 / dt.unsqueeze(-1) - w).reciprocal() # (H, N) + R = ( + torch.eye(self.rank, dtype=w.dtype, device=w.device) + + 2 * contract("r h n, h n, s h n -> h r s", Q, D, P).real + ) # (H R R) + Q_D = rearrange(Q * D, "r h n -> h r n") + try: + R = torch.linalg.solve(R, Q_D) # (H R N) + except Exception: + R = torch.tensor( + np.linalg.solve( + R.to(Q_D).contiguous().detach().cpu(), + Q_D.contiguous().detach().cpu(), + ) + ).to(Q_D) + R = rearrange(R, "h r n -> r h n") + + self.step_params = { + "D": D, # (H N) + "R": R, # (R H N) + "P": P, # (R H N) + "Q": Q, # (R H N) + "B": B, # (1 H N) + "E": 2.0 / dt.unsqueeze(-1) + w, # (H N) + } + + def _step_state_linear(self, u=None, state=None): + """ + Version of the step function that has time O(N) instead of O(N^2) per step, which takes advantage of the DPLR form and bilinear discretization. + + Unfortunately, as currently implemented it's about 2x slower because it calls several sequential operations. Perhaps a fused CUDA kernel implementation would be much faster + + u: (H) input + state: (H, N/2) state with conjugate pairs + Optionally, the state can have last dimension N + Returns: same shape as state + """ + C = _r2c(self.C) # View used for dtype/device + + if u is None: # Special case used to find dA + u = torch.zeros(self.H, dtype=C.dtype, device=C.device) + if state is None: # Special case used to find dB + state = torch.zeros(self.H, self.N, dtype=C.dtype, device=C.device) + + step_params = self.step_params.copy() + if ( + state.size(-1) == self.N + ): # Only store half of the conjugate pairs; should be true by default + # There should be a slightly faster way using conjugate symmetry + def contract_fn(p, x, y): + return contract( + "r h n, r h m, ... h m -> ... h n", + _conj(p), + _conj(x), + _conj(y), + )[ + ..., : self.N + ] # inner outer product + + else: + assert state.size(-1) == 2 * self.N + step_params = {k: _conj(v) for k, v in step_params.items()} + + # TODO worth setting up a contract_expression in default_state if we want to use this at inference time for stepping + def contract_fn(p, x, y): + return contract( + "r h n, r h m, ... h m -> ... h n", p, x, y + ) # inner outer product + + D = step_params["D"] # (H N) + E = step_params["E"] # (H N) + R = step_params["R"] # (R H N) + P = step_params["P"] # (R H N) + Q = step_params["Q"] # (R H N) + B = step_params["B"] # (1 H N) + + new_state = E * state - contract_fn(P, Q, state) # (B H N) + new_state = new_state + 2.0 * B * u.unsqueeze(-1) # (B H N) + new_state = D * (new_state - contract_fn(P, R, new_state)) + + return new_state + + def _setup_state(self): + """Construct dA and dB for discretized state equation""" + + # Construct dA and dB by using the stepping + self._setup_linear() + C = _r2c( + self.C + ) # Just returns a view that we use for finding dtype/device + + state = torch.eye( + 2 * self.N, dtype=C.dtype, device=C.device + ).unsqueeze( + -2 + ) # (N 1 N) + dA = self._step_state_linear(state=state) + dA = rearrange(dA, "n h m -> h m n") + + u = C.new_ones(self.H) + dB = self._step_state_linear(u=u) + dB = _conj(dB) + dB = rearrange(dB, "1 h n -> h n") # (H N) + return dA, dB + + def _step_state(self, u, state): + """Must be called after self.default_state() is used to construct an initial state!""" + next_state = self.state_contraction( + self.dA, state + ) + self.input_contraction(self.dB, u) + return next_state + + def _setup_step(self, mode="dense"): + """Set up dA, dB, dC discretized parameters for stepping""" + self.dA, self.dB = self._setup_state() + + # Calculate original C + C = _conj(_r2c(self.C)) # (H C N) + if self.L.item() == 0: + dC = C + else: + # self.C represents C_tilde + dA_L = power(self.L.item(), self.dA) + I = torch.eye(self.dA.size(-1)).to(dA_L) + + dC = torch.linalg.solve( + I - dA_L.transpose(-1, -2), + C.unsqueeze(-1), + ).squeeze(-1) + self.dC = dC + + # Do special preprocessing for different step modes + + self._step_mode = mode + if mode == "linear": + # Linear case: special step function for the state, we need to handle output + # use conjugate symmetry by default, which affects the output projection + self.dC = 2 * self.dC[:, :, : self.N] + elif mode == "diagonal": + # Eigendecomposition of the A matrix + L, V = torch.linalg.eig(self.dA) + V_inv = torch.linalg.inv(V) + # Check that the eigendedecomposition is correct + if self.verbose: + print( + "Diagonalization error:", + torch.dist(V @ torch.diag_embed(L) @ V_inv, self.dA), + ) + + # Change the parameterization to diagonalize + self.dA = L + self.dB = contract("h n m, h m -> h n", V_inv, self.dB) + self.dC = contract("h n m, c h n -> c h m", V, self.dC) + + elif mode == "dense": + pass + else: + raise NotImplementedError( + "NPLR Kernel step mode must be {'dense' | 'linear' | 'diagonal'}" + ) + + def default_state(self, *batch_shape): + C = _r2c(self.C) + N = C.size(-1) + H = C.size(-2) + + # Cache the tensor contractions we will later do, for efficiency + # These are put in this function because they depend on the batch size + step_mode = getattr( + self, "_step_mode", "dense" + ) # Used in default_state, which is called without _setup_step() in forward_state() + if step_mode != "linear": + N *= 2 + + if step_mode == "diagonal": + self.state_contraction = contract_expression( + "h n, ... h n -> ... h n", + (H, N), + batch_shape + (H, N), + ) + else: + # Dense (quadratic) case: expand all terms + self.state_contraction = contract_expression( + "h m n, ... h n -> ... h m", + (H, N, N), + batch_shape + (H, N), + ) + + self.input_contraction = contract_expression( + "h n, ... h -> ... h n", + (H, N), # self.dB.shape + batch_shape + (H,), + ) + + self.output_contraction = contract_expression( + "c h n, ... h n -> ... c h", + (C.shape[0], H, N), # self.dC.shape + batch_shape + (H, N), + ) + + state = torch.zeros(*batch_shape, H, N, dtype=C.dtype, device=C.device) + return state + + def step(self, u, state): + """Must have called self._setup_step() and created state with self.default_state() before calling this""" + + if self._step_mode == "linear": + new_state = self._step_state_linear(u, state) + else: + new_state = self._step_state(u, state) + y = self.output_contraction(self.dC, new_state) + return y.real, new_state + + +class SSKernelDiag(OptimModule): + """Version using (complex) diagonal state matrix (S4D)""" + + def __init__( + self, + A, + B, + C, + log_dt, + L=None, + disc="bilinear", + real_type="exp", + lr=None, + bandlimit=None, + ): + super().__init__() + self.L = L + self.disc = disc + self.bandlimit = bandlimit + self.real_type = real_type + + # Rank of low-rank correction + assert A.size(-1) == C.size(-1) + self.H = log_dt.size(-1) + self.N = A.size(-1) + assert A.size(-2) == B.size(-2) # Number of independent SSMs trained + assert self.H % A.size(-2) == 0 + self.n_ssm = A.size(-2) + self.repeat = self.H // A.size(0) + + self.channels = C.shape[0] + self.C = nn.Parameter(_c2r(_resolve_conj(C))) + + # Register parameters + if lr is None or isinstance(lr, float): + lr_dict = {} + else: + lr_dict, lr = lr, None + + self.register("log_dt", log_dt, lr_dict.get("dt", lr)) + self.register("B", _c2r(B), lr_dict.get("B", lr)) + self.register("inv_A_real", self._A_init(A.real), lr_dict.get("A", lr)) + self.register("A_imag", A.imag, lr_dict.get("A", lr)) + + def _A_init(self, A_real): + A_real = torch.clamp(A_real, max=-1e-4) + if self.real_type == "none": + return -A_real + elif self.real_type == "exp": + return torch.log( + -A_real + ) # Some of the HiPPO methods have real part 0 + elif self.real_type == "relu": + return -A_real + elif self.real_type == "sigmoid": + return torch.logit(-A_real) + elif self.real_type == "softplus": + return torch.log(torch.exp(-A_real) - 1) + else: + raise NotImplementedError + + def _A(self): + # Get the internal A (diagonal) parameter + if self.real_type == "none": + A_real = -self.inv_A_real + elif self.real_type == "exp": + A_real = -torch.exp(self.inv_A_real) + elif self.real_type == "relu": + # JAX version seems to NaN if you alloA 0's, although this code Aas fine Aithout it + A_real = -F.relu(self.inv_A_real) - 1e-4 + elif self.real_type == "sigmoid": + A_real = -F.sigmoid(self.inv_A_real) + elif self.real_type == "softplus": + A_real = -F.softplus(self.inv_A_real) + else: + raise NotImplementedError + A = A_real + 1j * self.A_imag + return A + + def forward(self, L, state=None, rate=1.0, u=None): + """ + state: (B, H, N) initial state + rate: sampling rate factor + L: target length + + returns: + (C, H, L) convolution kernel (generally C=1) + (B, H, L) output from initial state + """ + + dt = torch.exp(self.log_dt) * rate # (H) + C = _r2c(self.C) # (C H N) + A = self._A() # (H N) + + B = _r2c(self.B) + B = repeat(B, "t n -> 1 (v t) n", v=self.repeat) + + if self.bandlimit is not None: + freqs = dt[:, None] / rate * A.imag.abs() / (2 * math.pi) # (H, N) + mask = torch.where(freqs < self.bandlimit * 0.5, 1, 0) + C = C * mask + + # Incorporate dt into A + A = repeat(A, "t n -> (v t) n", v=self.repeat) + dtA = A * dt.unsqueeze(-1) # (H N) + + # Augment B with state + if state is not None: + s = state / dt.unsqueeze(-1) + if self.disc == "bilinear": + s = s * (1.0 + dtA / 2) + elif self.disc == "zoh": + s = s * dtA * dtA.exp() / (dtA.exp() - 1.0) + B = torch.cat([s, B], dim=-3) # (1+B H N) + + C = (B[:, None, :, :] * C).view(-1, self.H, self.N) + if self.disc == "zoh": + # Power up + C = C * (torch.exp(dtA) - 1.0) / A + K = log_vandermonde(C, dtA, L) # (H L) + elif self.disc == "bilinear": + C = ( + C * (1.0 - dtA / 2).reciprocal() * dt.unsqueeze(-1) + ) # or * dtA / A + dA = (1.0 + dtA / 2) / (1.0 - dtA / 2) + K = log_vandermonde(C, dA.log(), L) + elif self.disc == "dss": + # Implementation from DSS meant for case when real eigenvalues can be positive + P = dtA.unsqueeze(-1) * torch.arange(L, device=C.device) # [H N L] + A_gt_0 = A.real > 0 # [N] + if A_gt_0.any(): + with torch.no_grad(): + P_max = dtA * (A_gt_0 * (L - 1)) # [H N] + P = P - P_max.unsqueeze(-1) # [H N L] + S = P.exp() # [H N L] + + dtA_neg = dtA * (1 - 2 * A_gt_0) # [H N] + num = dtA_neg.exp() - 1 # [H N] + den = (dtA_neg * L).exp() - 1 # [H N] + + # Inline reciprocal function for DSS logic + x = den * A + x_conj = _resolve_conj(x) + r = x_conj / (x * x_conj + 1e-7) + + C = C * num * r # [C H N] + K = contract("chn,hnl->chl", C, S).float() + else: + assert False, f"{self.disc} not supported" + + K = K.view(-1, self.channels, self.H, L) # (1+B C H L) + if state is not None: + K_state = K[:-1, :, :, :] # (B C H L) + else: + K_state = None + K = K[-1, :, :, :] # (C H L) + return K, K_state + + def _setup_step(self): + # These methods are organized like this to be compatible with the NPLR kernel interface + dt = torch.exp(self.log_dt) # (H) + B = _r2c(self.B) # (H N) + C = _r2c(self.C) # (C H N) + self.dC = C + A = self._A() # (H N) + + A = repeat(A, "t n -> (v t) n", v=self.repeat) + B = repeat(B, "t n -> (v t) n", v=self.repeat) + + # Incorporate dt into A + dtA = A * dt.unsqueeze(-1) # (H N) + if self.disc == "zoh": + self.dA = torch.exp(dtA) # (H N) + self.dB = B * (torch.exp(dtA) - 1.0) / A # (C H N) + elif self.disc == "bilinear": + self.dA = (1.0 + dtA / 2) / (1.0 - dtA / 2) + self.dB = ( + B * (1.0 - dtA / 2).reciprocal() * dt.unsqueeze(-1) + ) # or * dtA / A + + def default_state(self, *batch_shape): + C = _r2c(self.C) + state = torch.zeros( + *batch_shape, self.H, self.N, dtype=C.dtype, device=C.device + ) + return state + + def step(self, u, state): + next_state = contract( + "h n, b h n -> b h n", self.dA, state + ) + contract("h n, b h -> b h n", self.dB, u) + y = contract("c h n, b h n -> b c h", self.dC, next_state) + return 2 * y.real, next_state + + def forward_state(self, u, state): + self._setup_step() + AL = self.dA ** u.size(-1) + u = u.flip(-1).to(self.dA).contiguous() # (B H L) + v = log_vandermonde_transpose(u, self.dB, self.dA.log(), u.size(-1)) + next_state = AL * state + v + return next_state + + +class SSKernel(nn.Module): + """Wrapper around SSKernel parameterizations. + + The SSKernel is expected to support the interface + forward() + default_state() + _setup_step() + step() + """ + + def __init__( + self, + H, + N=64, + L=None, + measure="legs", + rank=1, + channels=1, + dt_min=0.001, + dt_max=0.1, + deterministic=False, + lr=None, + mode="nplr", + n_ssm=None, + verbose=False, + measure_args={}, + **kernel_args, + ): + """State Space Kernel which computes the convolution kernel $\\bar{K}$ + + H: Number of independent SSM copies; controls the size of the model. Also called d_model in the config. + N: State size (dimensionality of parameters A, B, C). Also called d_state in the config. Generally shouldn't need to be adjusted and doens't affect speed much. + L: Maximum length of convolution kernel, if known. Should work in the majority of cases even if not known. + measure: Options for initialization of (A, B). For NPLR mode, recommendations are "legs", "fout", "hippo" (combination of both). For Diag mode, recommendations are "diag-inv", "diag-lin", "diag-legs", and "diag" (combination of diag-inv and diag-lin) + rank: Rank of low-rank correction for NPLR mode. Needs to be increased for measure "legt" + channels: C channels turns the SSM from a 1-dim to C-dim map; can think of it having C separate "heads" per SSM. This was partly a feature to make it easier to implement bidirectionality; it is recommended to set channels=1 and adjust H to control parameters instead + dt_min, dt_max: min and max values for the step size dt (\Delta) + mode: Which kernel algorithm to use. 'nplr' is the full S4 model; 'diag' is the simpler S4D; 'slow' is a dense version for testing + n_ssm: Number of independent trainable (A, B) SSMs, e.g. n_ssm=1 means all A/B parameters are tied across the H different instantiations of C. n_ssm=None means all H SSMs are completely independent. Generally, changing this option can save parameters but doesn't affect performance or speed much. This parameter must divide H + lr: Passing in a number (e.g. 0.001) sets attributes of SSM parameers (A, B, dt). A custom optimizer hook is needed to configure the optimizer to set the learning rates appropriately for these parameters. + """ + super().__init__() + self.N = N + self.H = H + dtype, cdtype = torch.float, torch.cfloat + self.channels = channels + self.n_ssm = n_ssm if n_ssm is not None else H + self.mode = mode + self.verbose = verbose + self.kernel_args = kernel_args + + # Generate dt + if deterministic: + log_dt = torch.exp( + torch.linspace(math.log(dt_min), math.log(dt_max), H) + ) + else: + log_dt = torch.rand(self.H, dtype=dtype) * ( + math.log(dt_max) - math.log(dt_min) + ) + math.log(dt_min) + + # Compute the preprocessed representation + w, P, B, V = combination( + measure, self.N, rank, self.n_ssm, **measure_args + ) + + # Broadcast C to have H channels + if deterministic: + C = torch.zeros(channels, self.n_ssm, self.N, dtype=cdtype) + C[:, :, :1] = 1.0 + C = contract( + "hmn, chn -> chm", V.conj().transpose(-1, -2), C + ) # V^* C + C = ( + repeat(C, "c t n -> c (v t) n", v=self.n_ssm // C.size(-2)) + .clone() + .contiguous() + ) + else: + C = torch.randn(channels, self.H, self.N // 2, dtype=cdtype) + + # Broadcast other parameters to have n_ssm copies + assert ( + self.n_ssm % B.size(-2) == 0 + and self.n_ssm % P.size(-2) == 0 + and self.n_ssm % w.size(-2) == 0 + ) + # Broadcast tensors to n_ssm copies + # These will be the parameters, so make sure tensors are materialized and contiguous + B = ( + repeat(B, "t n -> (v t) n", v=self.n_ssm // B.size(-2)) + .clone() + .contiguous() + ) + P = ( + repeat(P, "r t n -> r (v t) n", v=self.n_ssm // P.size(-2)) + .clone() + .contiguous() + ) + w = ( + repeat(w, "t n -> (v t) n", v=self.n_ssm // w.size(-2)) + .clone() + .contiguous() + ) + + if mode == "nplr": + self.kernel = SSKernelNPLR( + w, + P, + B, + C, + log_dt, + L=L, + lr=lr, + verbose=verbose, + **kernel_args, + ) + elif mode == "diag": + if not measure.startswith("diag"): + log.warning( + "Diagonal kernel (S4D) activated but initialization is not intended for S4D. Set `measure` to 'diag-lin', 'diag-inv', or 'diag-legs' for the main variants, or 'diag' for a combination of S4D-Lin and S4D-Inv." + ) + C = C * repeat(B, "t n -> (v t) n", v=H // self.n_ssm) + self.kernel = SSKernelDiag( + w, + B, + C, + log_dt, + L=L, + lr=lr, + **kernel_args, + ) + else: + raise NotImplementedError(f"{mode=} is not valid") + + def forward(self, state=None, L=None, rate=1.0): + return self.kernel(state=state, L=L, rate=rate) + + @torch.no_grad() + def forward_state(self, u, state): + """Forward the state through a sequence, i.e. computes the state after passing chunk through SSM + + state: (B, H, N) + u: (B, H, L) + + Returns: (B, H, N) + """ + + if hasattr(self.kernel, "forward_state"): + return self.kernel.forward_state(u, state) + + dA, dB = self.kernel._setup_state() # Construct dA, dB matrices + # dA, dB = self.kernel.dA, self.kernel.dB # (H N N) (H N) + + conj = state.size(-1) != dA.size(-1) + if conj: + state = _conj(state) + + v = contract( + "h n, b h l -> b h n l", dB, u.flip(-1) + ) # dB.unsqueeze(-1) * u.flip(-1).unsqueeze(-2) + AL, v = power(u.size(-1), dA, v) + next_state = contract("h m n, b h n -> b h m", AL, state) + next_state = next_state + v + + if conj: + next_state = next_state[..., : next_state.size(-1) // 2] + return next_state + + def _setup_step(self, **kwargs): + # This method is intended to be private so that setting up an S4 module with + # ``` + # if hasattr(module, 'setup_step'): module.setup_step() + # ``` + # will not trigger this method multiple times + self.kernel._setup_step(**kwargs) + + def step(self, u, state, **kwargs): + y, state = self.kernel.step(u, state, **kwargs) + return y, state + + def default_state(self, *args, **kwargs): + return self.kernel.default_state(*args, **kwargs) + + +class S4(nn.Module): + def __init__( + self, + d_model, + d_state=64, + l_max=None, + channels=1, + bidirectional=False, + # Arguments for position-wise feedforward components + activation="gelu", + postact="glu", + hyper_act=None, + dropout=0.0, + tie_dropout=False, + bottleneck=None, + gate=None, + transposed=True, + verbose=False, + # SSM Kernel arguments + **kernel_args, + ): + """ + d_state: the dimension of the state, also denoted by N + l_max: the maximum kernel length, also denoted by L. Set l_max=None to always use a global kernel + channels: can be interpreted as a number of "heads"; the SSM is a map from a 1-dim to C-dim sequence. It's not recommended to change this unless desperate for things to tune; instead, increase d_model for larger models + bidirectional: if True, convolution kernel will be two-sided + + Position-wise feedforward components: + -------------------- + activation: activation in between SS and FF + postact: activation after FF + hyper_act: use a "hypernetwork" multiplication (experimental) + dropout: standard dropout argument. tie_dropout=True ties the dropout mask across the sequence length, emulating nn.Dropout1d + + Other arguments: + -------------------- + transposed: choose backbone axis ordering of (B, L, H) (if False) or (B, H, L) (if True) [B=batch size, L=sequence length, H=hidden dimension] + gate: add gated activation (GSS) + bottleneck: reduce SSM dimension (GSS) + + See the class SSKernel for the kernel constructor which accepts kernel_args. Relevant options that are worth considering and tuning include "mode" + "measure", "dt_min", "dt_max", "lr" + + Other options are all experimental and should not need to be configured + """ + + super().__init__() + if verbose: + log.info( + f"Constructing S4 (H, N, L) = ({d_model}, {d_state}, {l_max})" + ) + + self.d_model = d_model + self.H = d_model + self.N = d_state + self.L = l_max + self.bidirectional = bidirectional + self.channels = channels + self.transposed = transposed + + self.gate = gate + self.bottleneck = bottleneck + + if bottleneck is not None: + self.H = self.H // bottleneck + self.input_linear = LinearActivation( + self.d_model, + self.H, + transposed=self.transposed, + activation=activation, + activate=True, + ) + + if gate is not None: + self.input_gate = LinearActivation( + self.d_model, + self.d_model * gate, + transposed=self.transposed, + activation=activation, + activate=True, + ) + self.output_gate = LinearActivation( + self.d_model * gate, + self.d_model, + transposed=self.transposed, + activation=None, + activate=False, + ) + + # optional multiplicative modulation GLU-style + # https://arxiv.org/abs/2002.05202 + self.hyper = hyper_act is not None + if self.hyper: + channels *= 2 + self.hyper_activation = Activation(hyper_act) + + self.D = nn.Parameter(torch.randn(channels, self.H)) + + if self.bidirectional: + channels *= 2 + + # SSM Kernel + self.kernel = SSKernel( + self.H, + N=self.N, + L=self.L, + channels=channels, + verbose=verbose, + **kernel_args, + ) + + # Pointwise + self.activation = Activation(activation) + dropout_fn = DropoutNd if tie_dropout else nn.Dropout + self.dropout = dropout_fn(dropout) if dropout > 0.0 else nn.Identity() + # position-wise output transform to mix features + self.output_linear = LinearActivation( + self.H * self.channels, + self.d_model * (1 if self.gate is None else self.gate), + transposed=self.transposed, + activation=postact, + activate=True, + ) + + def forward(self, u, state=None, rate=1.0, lengths=None, **kwargs): + """ + u: (B H L) if self.transposed else (B L H) + state: (H N) never needed unless you know what you're doing + + Returns: same shape as u + """ + if not self.transposed: + u = u.transpose(-1, -2) + L = u.size(-1) + + # Mask out padding tokens + if isinstance(lengths, int): + if lengths != L: + lengths = torch.tensor( + lengths, dtype=torch.long, device=u.device + ) + else: + lengths = None + if lengths is not None: + assert ( + isinstance(lengths, torch.Tensor) + and lengths.ndim == 1 + and lengths.size(0) in [1, u.size(0)] + ) + mask = torch.where( + torch.arange(L, device=lengths.device) + < lengths[:, None, None], + 1.0, + 0.0, + ) + u = u * mask + + if self.gate is not None: + v = self.input_gate(u) + if self.bottleneck is not None: + u = self.input_linear(u) + + # Compute SS Kernel + L_kernel = L if self.L is None else min(L, round(self.L / rate)) + k, k_state = self.kernel( + L=L_kernel, rate=rate, state=state + ) # (C H L) (B C H L) + + # Convolution + if self.bidirectional: + k0, k1 = rearrange(k, "(s c) h l -> s c h l", s=2) + k = F.pad(k0, (0, L)) + F.pad(k1.flip(-1), (L, 0)) + k_f = torch.fft.rfft(k, n=L_kernel + L) # (C H L) + u_f = torch.fft.rfft(u, n=L_kernel + L) # (B H L) + y_f = contract("bhl,chl->bchl", u_f, k_f) + y = torch.fft.irfft(y_f, n=L_kernel + L)[..., :L] # (B C H L) + + # Compute D term in state space equation - essentially a skip connection + y = y + contract("bhl,ch->bchl", u, self.D) + + # Compute state update + if state is not None: + assert ( + not self.bidirectional + ), "Bidirectional not supported with state forwarding" + y = y + k_state # + next_state = self.kernel.forward_state(u, state) + else: + next_state = None + + # Optional hyper-network multiplication + if self.hyper: + y, yh = rearrange(y, "b (s c) h l -> s b c h l", s=2) + y = self.hyper_activation(yh) * y + + # Reshape to flatten channels + y = rearrange(y, "... c h l -> ... (c h) l") + + y = self.dropout(self.activation(y)) + + if not self.transposed: + y = y.transpose(-1, -2) + + y = self.output_linear(y) + + if self.gate is not None: + y = self.output_gate(y * v) + + return y, next_state + + def setup_step(self, **kwargs): + self.kernel._setup_step(**kwargs) + + def step(self, u, state): + """Step one time step as a recurrent model. Intended to be used during validation. + + u: (B H) + state: (B H N) + Returns: output (B H), state (B H N) + """ + assert not self.training + + y, next_state = self.kernel.step(u, state) # (B C H) + y = y + u.unsqueeze(-2) * self.D + y = rearrange(y, "b c h -> b (c h)") + y = self.activation(y) + if self.transposed: + y = self.output_linear(y.unsqueeze(-1)).squeeze(-1) + else: + y = self.output_linear(y) + return y, next_state + + def default_state(self, *batch_shape, device=None): + # kernel is not a SequenceModule so it doesn't need to adhere to same interface + # the kernel will know the device of its own parameters + return self.kernel.default_state(*batch_shape) + + @property + def d_output(self): + return self.d_model \ No newline at end of file diff --git a/src/notebooks/diffusion-training.ipynb b/src/notebooks/diffusion-training.ipynb index c6dbfb2..18ef1cd 100644 --- a/src/notebooks/diffusion-training.ipynb +++ b/src/notebooks/diffusion-training.ipynb @@ -2,20 +2,41 @@ "cells": [ { "cell_type": "code", - "execution_count": 3, + "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/conda/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], "source": [ "import sys\n", "sys.path.append('../..')\n", - "import torch" + "import torch\n", + "\n", + "%load_ext autoreload\n", + "%autoreload 2" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The autoreload extension is already loaded. To reload it, use:\n", + " %reload_ext autoreload\n" + ] + } + ], "source": [ "from src.data import DataProcessor, DataConfig\n", "from src.trainers.quantile_trainer import AutoRegressiveQuantileTrainer, NonAutoRegressiveQuantileRegression\n", @@ -44,7 +65,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -62,8 +83,8 @@ "\n", "data_config.NOMINAL_NET_POSITION = True\n", "\n", - "data_processor = DataProcessor(data_config, path=\"../../\")\n", - "data_processor.set_batch_size(1024)\n", + "data_processor = DataProcessor(data_config, path=\"../../\", lstm=True)\n", + "data_processor.set_batch_size(128)\n", "data_processor.set_full_day_skip(True)" ] }, @@ -222,6 +243,165 @@ "sample_diffusion(new_model, 1, inputs)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Trying out BackboneModel using S4 state space model" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[KeOps] Compiling cuda jit compiler engine ... \n", + "[KeOps] Warning : There were warnings or errors compiling formula :\n", + "/usr/bin/ld: warning: /opt/conda/lib/libstdc++.so: unsupported GNU_PROPERTY_TYPE (5) type: 0xc0010001\n", + "/usr/bin/ld: warning: /opt/conda/lib/libstdc++.so: unsupported GNU_PROPERTY_TYPE (5) type: 0xc0010002\n", + "/usr/bin/ld: warning: /opt/conda/lib/libgcc_s.so.1: unsupported GNU_PROPERTY_TYPE (5) type: 0xc0010001\n", + "/usr/bin/ld: warning: /opt/conda/lib/libgcc_s.so.1: unsupported GNU_PROPERTY_TYPE (5) type: 0xc0010002\n", + "/usr/bin/ld: warning: /opt/conda/lib/libgcc_s.so.1: unsupported GNU_PROPERTY_TYPE (5) type: 0xc0010001\n", + "/usr/bin/ld: warning: /opt/conda/lib/libgcc_s.so.1: unsupported GNU_PROPERTY_TYPE (5) type: 0xc0010002\n", + "\n", + "OK\n", + "[pyKeOps] Compiling nvrtc binder for python ... \n", + "[KeOps] Warning : There were warnings or errors compiling formula :\n", + "/usr/bin/ld: warning: /opt/conda/lib/libstdc++.so: unsupported GNU_PROPERTY_TYPE (5) type: 0xc0010001\n", + "/usr/bin/ld: warning: /opt/conda/lib/libstdc++.so: unsupported GNU_PROPERTY_TYPE (5) type: 0xc0010002\n", + "/usr/bin/ld: warning: /opt/conda/lib/libgcc_s.so.1: unsupported GNU_PROPERTY_TYPE (5) type: 0xc0010001\n", + "/usr/bin/ld: warning: /opt/conda/lib/libgcc_s.so.1: unsupported GNU_PROPERTY_TYPE (5) type: 0xc0010002\n", + "/usr/bin/ld: warning: /opt/conda/lib/libgcc_s.so.1: unsupported GNU_PROPERTY_TYPE (5) type: 0xc0010001\n", + "/usr/bin/ld: warning: /opt/conda/lib/libgcc_s.so.1: unsupported GNU_PROPERTY_TYPE (5) type: 0xc0010002\n", + "\n", + "OK\n" + ] + } + ], + "source": [ + "import sys\n", + "sys.path.append('../..')\n", + "import torch\n", + "\n", + "%load_ext autoreload\n", + "%autoreload 2\n", + "from src.models.tsdiff_s4.backbones import BackboneModel\n", + "from src.trainers.diffusion_trainer import DiffusionTrainer\n", + "\n", + "backbone = BackboneModel(\n", + " input_dim=1,\n", + " hidden_dim=512,\n", + " output_dim=1,\n", + " step_emb=128,\n", + " num_residual_blocks=3,\n", + " num_features=2\n", + ")\n", + "backbone = backbone.to(\"cuda\")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[KeOps] Generating code for formula Sum_Reduction(ComplexMult(Real2Complex(1/ComplexSquareAbs(ComplexMult(Var(1,2,0)-Var(2,2,1),Var(1,2,0)-Conj(Var(2,2,1))))),ComplexMult(Var(1,2,0)*ComplexReal(Var(0,2,1))-Real2Complex(Sum(Var(0,2,1)*Var(2,2,1))),Conj(ComplexMult(Var(1,2,0)-Var(2,2,1),Var(1,2,0)-Conj(Var(2,2,1)))))),0) ... " + ] + }, + { + "ename": "", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[1;31mThe Kernel crashed while executing code in the the current cell or a previous cell. Please review the code in the cell(s) to identify a possible cause of the failure. Click here for more info. View Jupyter log for further details." + ] + } + ], + "source": [ + "# now lets find out what the input shape of the featues and input must be\n", + "\n", + "# input: (B, L, C)\n", + "# features: (B, L, F)\n", + "# time: (B, 1)\n", + "\n", + "# output: (B, L, C)? \n", + "\n", + "input = torch.randn(2, 96, 1).to(\"cuda\")\n", + "features = torch.randn(2, 96, 2).to(\"cuda\")\n", + "times = torch.randn(2).to(\"cuda\")\n", + "\n", + "backbone(input, times, features).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "ename": "KeyError", + "evalue": "'nvrtc'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[3], line 13\u001b[0m\n\u001b[1;32m 10\u001b[0m times \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mtensor([\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m2\u001b[39m)\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcuda\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 11\u001b[0m features \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mrandn(\u001b[38;5;241m2\u001b[39m, \u001b[38;5;241m96\u001b[39m, \u001b[38;5;241m2\u001b[39m)\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcuda\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m---> 13\u001b[0m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtimes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfeatures\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mshape\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", + "File \u001b[0;32m/workspaces/Thesis/src/notebooks/../../src/models/tsdiff_s4/backbones.py:164\u001b[0m, in \u001b[0;36mBackboneModel.forward\u001b[0;34m(self, input, t, features)\u001b[0m\n\u001b[1;32m 162\u001b[0m skips \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m 163\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m layer \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mresidual_blocks:\n\u001b[0;32m--> 164\u001b[0m x, skip \u001b[38;5;241m=\u001b[39m \u001b[43mlayer\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfeatures\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 165\u001b[0m skips\u001b[38;5;241m.\u001b[39mappend(skip)\n\u001b[1;32m 167\u001b[0m skip \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mstack(skips)\u001b[38;5;241m.\u001b[39msum(\u001b[38;5;241m0\u001b[39m)\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", + "File \u001b[0;32m/workspaces/Thesis/src/notebooks/../../src/models/tsdiff_s4/backbones.py:96\u001b[0m, in \u001b[0;36mS4Block.forward\u001b[0;34m(self, x, t, features)\u001b[0m\n\u001b[1;32m 94\u001b[0m t \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtime_linear(t)[:, \u001b[38;5;28;01mNone\u001b[39;00m, :]\u001b[38;5;241m.\u001b[39mrepeat(\u001b[38;5;241m1\u001b[39m, x\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m2\u001b[39m], \u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 95\u001b[0m t \u001b[38;5;241m=\u001b[39m t\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m)\n\u001b[0;32m---> 96\u001b[0m out, _ \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43ms4block\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 97\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m features \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 98\u001b[0m out \u001b[38;5;241m=\u001b[39m out \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfeature_encoder(features)\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", + "File \u001b[0;32m/workspaces/Thesis/src/notebooks/../../src/models/tsdiff_s4/backbones.py:56\u001b[0m, in \u001b[0;36mS4Layer.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 54\u001b[0m z \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnorm(z\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m))\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m)\n\u001b[1;32m 55\u001b[0m \u001b[38;5;66;03m# Apply layer: we ignore the state input and output for training\u001b[39;00m\n\u001b[0;32m---> 56\u001b[0m z, _ \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlayer\u001b[49m\u001b[43m(\u001b[49m\u001b[43mz\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 57\u001b[0m \u001b[38;5;66;03m# Dropout on the output of the layer\u001b[39;00m\n\u001b[1;32m 58\u001b[0m z \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdropout(z)\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", + "File \u001b[0;32m/workspaces/Thesis/src/notebooks/../../src/models/tsdiff_s4/s4.py:1761\u001b[0m, in \u001b[0;36mS4.forward\u001b[0;34m(self, u, state, rate, lengths, **kwargs)\u001b[0m\n\u001b[1;32m 1759\u001b[0m \u001b[38;5;66;03m# Compute SS Kernel\u001b[39;00m\n\u001b[1;32m 1760\u001b[0m L_kernel \u001b[38;5;241m=\u001b[39m L \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mL \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mmin\u001b[39m(L, \u001b[38;5;28mround\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mL \u001b[38;5;241m/\u001b[39m rate))\n\u001b[0;32m-> 1761\u001b[0m k, k_state \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mkernel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1762\u001b[0m \u001b[43m \u001b[49m\u001b[43mL\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mL_kernel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrate\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrate\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstate\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstate\u001b[49m\n\u001b[1;32m 1763\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# (C H L) (B C H L)\u001b[39;00m\n\u001b[1;32m 1765\u001b[0m \u001b[38;5;66;03m# Convolution\u001b[39;00m\n\u001b[1;32m 1766\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbidirectional:\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", + "File \u001b[0;32m/workspaces/Thesis/src/notebooks/../../src/models/tsdiff_s4/s4.py:1549\u001b[0m, in \u001b[0;36mSSKernel.forward\u001b[0;34m(self, state, L, rate)\u001b[0m\n\u001b[1;32m 1548\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, state\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, L\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, rate\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1.0\u001b[39m):\n\u001b[0;32m-> 1549\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mkernel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstate\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstate\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mL\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mL\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrate\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrate\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", + "File \u001b[0;32m/workspaces/Thesis/src/notebooks/../../src/models/tsdiff_s4/s4.py:925\u001b[0m, in \u001b[0;36mSSKernelNPLR.forward\u001b[0;34m(self, state, rate, L)\u001b[0m\n\u001b[1;32m 923\u001b[0m r \u001b[38;5;241m=\u001b[39m cauchy_mult(v, z, w, symmetric\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 924\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m has_pykeops:\n\u001b[0;32m--> 925\u001b[0m r \u001b[38;5;241m=\u001b[39m \u001b[43mcauchy_conj\u001b[49m\u001b[43m(\u001b[49m\u001b[43mv\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mz\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mw\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 926\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 927\u001b[0m r \u001b[38;5;241m=\u001b[39m cauchy_naive(v, z, w)\n", + "File \u001b[0;32m/workspaces/Thesis/src/notebooks/../../src/models/tsdiff_s4/s4.py:89\u001b[0m, in \u001b[0;36mcauchy_conj\u001b[0;34m(v, z, w)\u001b[0m\n\u001b[1;32m 86\u001b[0m z \u001b[38;5;241m=\u001b[39m _c2r(z)\n\u001b[1;32m 87\u001b[0m w \u001b[38;5;241m=\u001b[39m _c2r(w)\n\u001b[0;32m---> 89\u001b[0m r \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m2\u001b[39m \u001b[38;5;241m*\u001b[39m \u001b[43mcauchy_mult\u001b[49m\u001b[43m(\u001b[49m\u001b[43mv\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mz\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mw\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbackend\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mGPU\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 90\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _r2c(r)\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/pykeops/torch/generic/generic_red.py:688\u001b[0m, in \u001b[0;36mGenred.__call__\u001b[0;34m(self, backend, device_id, ranges, out, *args)\u001b[0m\n\u001b[1;32m 686\u001b[0m params\u001b[38;5;241m.\u001b[39mny \u001b[38;5;241m=\u001b[39m ny\n\u001b[1;32m 687\u001b[0m params\u001b[38;5;241m.\u001b[39mout \u001b[38;5;241m=\u001b[39m out\n\u001b[0;32m--> 688\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mGenredAutograd_fun\u001b[49m\u001b[43m(\u001b[49m\u001b[43mparams\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 690\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m postprocess(out, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtorch\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mreduction_op, nout, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mopt_arg, dtype)\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/pykeops/torch/generic/generic_red.py:384\u001b[0m, in \u001b[0;36mGenredAutograd_fun\u001b[0;34m(*inputs)\u001b[0m\n\u001b[1;32m 383\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mGenredAutograd_fun\u001b[39m(\u001b[38;5;241m*\u001b[39minputs):\n\u001b[0;32m--> 384\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mGenredAutograd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mapply\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m[\u001b[38;5;241m0\u001b[39m]\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/torch/autograd/function.py:506\u001b[0m, in \u001b[0;36mFunction.apply\u001b[0;34m(cls, *args, **kwargs)\u001b[0m\n\u001b[1;32m 503\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m torch\u001b[38;5;241m.\u001b[39m_C\u001b[38;5;241m.\u001b[39m_are_functorch_transforms_active():\n\u001b[1;32m 504\u001b[0m \u001b[38;5;66;03m# See NOTE: [functorch vjp and autograd interaction]\u001b[39;00m\n\u001b[1;32m 505\u001b[0m args \u001b[38;5;241m=\u001b[39m _functorch\u001b[38;5;241m.\u001b[39mutils\u001b[38;5;241m.\u001b[39munwrap_dead_wrappers(args)\n\u001b[0;32m--> 506\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mapply\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 508\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39msetup_context \u001b[38;5;241m==\u001b[39m _SingleLevelFunction\u001b[38;5;241m.\u001b[39msetup_context:\n\u001b[1;32m 509\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m 510\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mIn order to use an autograd.Function with functorch transforms \u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m 511\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m(vmap, grad, jvp, jacrev, ...), it must override the setup_context \u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m 512\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mstaticmethod. For more details, please see \u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m 513\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mhttps://pytorch.org/docs/master/notes/extending.func.html\u001b[39m\u001b[38;5;124m'\u001b[39m)\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/pykeops/torch/generic/generic_red.py:295\u001b[0m, in \u001b[0;36mGenredAutograd.forward\u001b[0;34m(*inputs)\u001b[0m\n\u001b[1;32m 293\u001b[0m \u001b[38;5;129m@staticmethod\u001b[39m\n\u001b[1;32m 294\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;241m*\u001b[39minputs):\n\u001b[0;32m--> 295\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mGenredAutograd_base\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/pykeops/torch/generic/generic_red.py:91\u001b[0m, in \u001b[0;36mGenredAutograd_base._forward\u001b[0;34m(params, *args)\u001b[0m\n\u001b[1;32m 85\u001b[0m device_id, device_args \u001b[38;5;241m=\u001b[39m set_device(\n\u001b[1;32m 86\u001b[0m tagCPUGPU, tagHostDevice, params\u001b[38;5;241m.\u001b[39mdevice_id_request, \u001b[38;5;241m*\u001b[39margs\n\u001b[1;32m 87\u001b[0m )\n\u001b[1;32m 89\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mpykeops\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcommon\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mkeops_io\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m keops_binder\n\u001b[0;32m---> 91\u001b[0m myconv \u001b[38;5;241m=\u001b[39m \u001b[43mkeops_binder\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mnvrtc\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mtagCPUGPU\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mcpp\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m(\n\u001b[1;32m 92\u001b[0m tagCPUGPU,\n\u001b[1;32m 93\u001b[0m tag1D2D,\n\u001b[1;32m 94\u001b[0m tagHostDevice,\n\u001b[1;32m 95\u001b[0m use_ranges,\n\u001b[1;32m 96\u001b[0m device_id,\n\u001b[1;32m 97\u001b[0m params\u001b[38;5;241m.\u001b[39mformula,\n\u001b[1;32m 98\u001b[0m params\u001b[38;5;241m.\u001b[39maliases,\n\u001b[1;32m 99\u001b[0m \u001b[38;5;28mlen\u001b[39m(args),\n\u001b[1;32m 100\u001b[0m params\u001b[38;5;241m.\u001b[39mdtype,\n\u001b[1;32m 101\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtorch\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 102\u001b[0m params\u001b[38;5;241m.\u001b[39moptional_flags,\n\u001b[1;32m 103\u001b[0m )\u001b[38;5;241m.\u001b[39mimport_module()\n\u001b[1;32m 105\u001b[0m \u001b[38;5;66;03m# N.B.: KeOps C++ expects contiguous data arrays\u001b[39;00m\n\u001b[1;32m 106\u001b[0m test_contig \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mall\u001b[39m(arg\u001b[38;5;241m.\u001b[39mis_contiguous() \u001b[38;5;28;01mfor\u001b[39;00m arg \u001b[38;5;129;01min\u001b[39;00m args)\n", + "\u001b[0;31mKeyError\u001b[0m: 'nvrtc'" + ] + } + ], + "source": [ + "# inputDim = data_processor.get_input_size()\n", + "learningRate = 0.0001\n", + "epochs=150\n", + "\n", + "#### Model ####\n", + "model = BackboneModel(1, 512, output_dim=1, step_emb=64, num_residual_blocks=4, num_features=2)\n", + "model.to(\"cuda\")\n", + "\n", + "inputs = torch.randn(2, 96, 1).to(\"cuda\")\n", + "times = torch.tensor([0]*2).to(\"cuda\")\n", + "features = torch.randn(2, 96, 2).to(\"cuda\")\n", + "\n", + "model(inputs, times, features).shape\n", + "\n", + "#### Trainer ####\n", + "# trainer = DiffusionTrainer(model, data_processor, \"cuda\")\n", + "# trainer.train(epochs, learningRate, None)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -246,7 +426,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.11" + "version": "3.10.8" } }, "nbformat": 4, diff --git a/src/trainers/diffusion_trainer.py b/src/trainers/diffusion_trainer.py index 6f35ae3..85de501 100644 --- a/src/trainers/diffusion_trainer.py +++ b/src/trainers/diffusion_trainer.py @@ -132,6 +132,8 @@ class DiffusionTrainer: t = self.sample_timesteps(time_series.shape[0]).to(self.device) x_t, noise = self.noise_time_series(time_series, t) + x_t = x_t.unsqueeze(-1) + print(x_t.shape, t.shape, base_pattern.shape) predicted_noise = self.model(x_t, t, base_pattern) loss = criterion(predicted_noise, noise) diff --git a/test.py b/test.py new file mode 100644 index 0000000..b72c153 --- /dev/null +++ b/test.py @@ -0,0 +1,2 @@ +import pykeops +pykeops.test_numpy_bindings() \ No newline at end of file