From 665825ad207a900fb8b4a1eb579a609ceb551369 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sat, 9 Nov 2024 16:58:51 +0530 Subject: [PATCH 01/15] initial commit --- .../models/xLSTMTime/__init__.py | 0 .../models/xLSTMTime/mLSTM/__init__.py | 0 .../models/xLSTMTime/mLSTM/cell.py | 105 +++++++++++ .../models/xLSTMTime/mLSTM/layer.py | 85 +++++++++ .../models/xLSTMTime/mLSTM/network.py | 27 +++ .../models/xLSTMTime/sLSTM/__init__.py | 0 .../models/xLSTMTime/sLSTM/cell.py | 92 +++++++++ .../models/xLSTMTime/sLSTM/layer.py | 82 +++++++++ .../models/xLSTMTime/sLSTM/network.py | 38 ++++ .../models/xLSTMTime/xLSTMTime.py | 174 ++++++++++++++++++ 10 files changed, 603 insertions(+) create mode 100644 pytorch_forecasting/models/xLSTMTime/__init__.py create mode 100644 pytorch_forecasting/models/xLSTMTime/mLSTM/__init__.py create mode 100644 pytorch_forecasting/models/xLSTMTime/mLSTM/cell.py create mode 100644 pytorch_forecasting/models/xLSTMTime/mLSTM/layer.py create mode 100644 pytorch_forecasting/models/xLSTMTime/mLSTM/network.py create mode 100644 pytorch_forecasting/models/xLSTMTime/sLSTM/__init__.py create mode 100644 pytorch_forecasting/models/xLSTMTime/sLSTM/cell.py create mode 100644 pytorch_forecasting/models/xLSTMTime/sLSTM/layer.py create mode 100644 pytorch_forecasting/models/xLSTMTime/sLSTM/network.py create mode 100644 pytorch_forecasting/models/xLSTMTime/xLSTMTime.py diff --git a/pytorch_forecasting/models/xLSTMTime/__init__.py b/pytorch_forecasting/models/xLSTMTime/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pytorch_forecasting/models/xLSTMTime/mLSTM/__init__.py b/pytorch_forecasting/models/xLSTMTime/mLSTM/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pytorch_forecasting/models/xLSTMTime/mLSTM/cell.py b/pytorch_forecasting/models/xLSTMTime/mLSTM/cell.py new file mode 100644 index 000000000..bb7e55b4f --- /dev/null +++ b/pytorch_forecasting/models/xLSTMTime/mLSTM/cell.py @@ -0,0 +1,105 @@ +import torch +import torch.nn as nn +import math + + +class mLSTMCell(nn.Module): + def __init__(self, input_size, hidden_size, dropout=0.2, layer_norm=True, device=None): + super(mLSTMCell, self).__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.layer_norm = layer_norm + + self.device = device if device is not None else torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + self.Wq = nn.Linear(input_size, hidden_size) + self.Wk = nn.Linear(input_size, hidden_size) + self.Wv = nn.Linear(input_size, hidden_size) + + self.Wi = nn.Linear(input_size, hidden_size) + self.Wf = nn.Linear(input_size, hidden_size) + self.Wo = nn.Linear(input_size, hidden_size) + + self.Wq.to(self.device) + self.Wk.to(self.device) + self.Wv.to(self.device) + self.Wi.to(self.device) + self.Wf.to(self.device) + self.Wo.to(self.device) + + self.dropout = nn.Dropout(dropout) + self.dropout.to(self.device) + + if layer_norm: + self.ln_q = nn.LayerNorm(hidden_size) + self.ln_k = nn.LayerNorm(hidden_size) + self.ln_v = nn.LayerNorm(hidden_size) + self.ln_i = nn.LayerNorm(hidden_size) + self.ln_f = nn.LayerNorm(hidden_size) + self.ln_o = nn.LayerNorm(hidden_size) + + self.ln_q.to(self.device) + self.ln_k.to(self.device) + self.ln_v.to(self.device) + self.ln_i.to(self.device) + self.ln_f.to(self.device) + self.ln_o.to(self.device) + + self.sigmoid = nn.Sigmoid() + self.tanh = nn.Tanh() + + def forward(self, x, h_prev, c_prev, n_prev): + + x = x.to(self.device) + h_prev = h_prev.to(self.device) + c_prev = c_prev.to(self.device) + 
n_prev = n_prev.to(self.device) + + + batch_size = x.size(0) + assert x.dim() == 2, f"Input should be 2D (batch_size, input_size), got {x.dim()}D" + assert h_prev.size() == (batch_size, self.hidden_size), f"h_prev shape mismatch: {h_prev.size()}" + assert c_prev.size() == (batch_size, self.hidden_size), f"c_prev shape mismatch: {c_prev.size()}" + assert n_prev.size() == (batch_size, self.hidden_size), f"n_prev shape mismatch: {n_prev.size()}" + + + x = self.dropout(x) + h_prev = self.dropout(h_prev) + + q = self.Wq(x) + k = self.Wk(x) / math.sqrt(self.hidden_size) + v = self.Wv(x) + + if self.layer_norm: + q = self.ln_q(q) + k = self.ln_k(k) + v = self.ln_v(v) + + i = self.sigmoid(self.ln_i(self.Wi(x)) if self.layer_norm else self.Wi(x)) + f = self.sigmoid(self.ln_f(self.Wf(x)) if self.layer_norm else self.Wf(x)) + o = self.sigmoid(self.ln_o(self.Wo(x)) if self.layer_norm else self.Wo(x)) + + k_expanded = k.unsqueeze(-1) + v_expanded = v.unsqueeze(-2) + + kv_interaction = k_expanded @ v_expanded + + kv_sum = kv_interaction.sum(dim=1) + + c = f * c_prev + i * kv_sum + n = f * n_prev + i * k + + epsilon = 1e-8 + normalized_n = n / (torch.norm(n, dim=-1, keepdim=True) + epsilon) + h = o * self.tanh(c * normalized_n) + + return h, c, n + + def init_hidden(self, batch_size): + """ + Initialize hidden, cell, and normalization states. + """ + shape = (batch_size, self.hidden_size) + return (torch.zeros(shape, device=self.device), + torch.zeros(shape, device=self.device), + torch.zeros(shape, device=self.device)) \ No newline at end of file diff --git a/pytorch_forecasting/models/xLSTMTime/mLSTM/layer.py b/pytorch_forecasting/models/xLSTMTime/mLSTM/layer.py new file mode 100644 index 000000000..cef6dcbcc --- /dev/null +++ b/pytorch_forecasting/models/xLSTMTime/mLSTM/layer.py @@ -0,0 +1,85 @@ +import torch +import torch.nn as nn +from .cell import mLSTMCell + + +class mLSTMLayer(nn.Module): + def __init__(self, input_size, hidden_size, num_layers, dropout=0.2, layer_norm=True, residual_conn=True, + device=None): + super(mLSTMLayer, self).__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.num_layers = num_layers + self.layer_norm = layer_norm + self.residual_conn = residual_conn + self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + self.dropout = nn.Dropout(dropout).to(self.device) + + self.cells = nn.ModuleList([ + mLSTMCell(input_size if i == 0 else hidden_size, hidden_size, dropout, layer_norm, self.device) + for i in range(num_layers) + ]) + + def init_hidden(self, batch_size): + """ + Initialize hidden, cell, and normalization states for all layers. + """ + hidden_states, cell_states, norm_states = zip(*[ + self.cells[i].init_hidden(batch_size) for i in range(self.num_layers) + ]) + + + return ( + torch.stack(hidden_states).to(self.device), + torch.stack(cell_states).to(self.device), + torch.stack(norm_states).to(self.device) + ) + + def forward(self, x, h=None, c=None, n=None): + """ + Forward pass for the mLSTM layer. 
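Reviewer note on the cell update above: unlike the full mLSTM of the xLSTM paper, which carries a (hidden × hidden) matrix memory and reads it out with the query, this variant sums the k vᵀ outer product back down to a vector, and the learned query `q` is computed but never consumed by the state update. A minimal standalone sketch of one step (random tensors stand in for the learned projections; this is not part of the patch):

```python
import torch

batch, hidden = 4, 8
k = torch.randn(batch, hidden) / hidden**0.5             # key, scaled like Wk(x) / sqrt(hidden_size)
v = torch.randn(batch, hidden)                           # value
i, f, o = (torch.rand(batch, hidden) for _ in range(3))  # sigmoid gate activations
c_prev = torch.zeros(batch, hidden)
n_prev = torch.zeros(batch, hidden)

# Outer product k v^T -> (batch, hidden, hidden), then summed over dim=1,
# collapsing the matrix interaction back to a (batch, hidden) vector.
kv_sum = (k.unsqueeze(-1) @ v.unsqueeze(-2)).sum(dim=1)

c = f * c_prev + i * kv_sum                          # cell state
n = f * n_prev + i * k                               # normalizer state
n_hat = n / (n.norm(dim=-1, keepdim=True) + 1e-8)    # epsilon guards the all-zero start
h = o * torch.tanh(c * n_hat)                        # hidden state
print(h.shape)  # torch.Size([4, 8])
```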
+ """ + + x = x.to(self.device).transpose(0, 1) + batch_size, seq_len, _ = x.size() + + + if h is None or c is None or n is None: + h, c, n = self.init_hidden(batch_size) + + + outputs = [] + + for t in range(seq_len): + layer_input = x[:, t, :] + next_hidden_states = [] + next_cell_states = [] + next_norm_states = [] + + for i, cell in enumerate(self.cells): + + h_i, c_i, n_i = cell(layer_input, h[i], c[i], n[i]) + + if self.residual_conn and i > 0: + h_i = h_i + layer_input + + layer_input = h_i + + next_hidden_states.append(h_i) + next_cell_states.append(c_i) + next_norm_states.append(n_i) + + h = torch.stack(next_hidden_states).to(self.device) + c = torch.stack(next_cell_states).to(self.device) + n = torch.stack(next_norm_states).to(self.device) + + outputs.append(h[-1]) + + + output = torch.stack(outputs, dim=1) + + + output = output.transpose(0, 1) + + return output, (h, c, n) diff --git a/pytorch_forecasting/models/xLSTMTime/mLSTM/network.py b/pytorch_forecasting/models/xLSTMTime/mLSTM/network.py new file mode 100644 index 000000000..fc732b210 --- /dev/null +++ b/pytorch_forecasting/models/xLSTMTime/mLSTM/network.py @@ -0,0 +1,27 @@ +import torch.nn as nn +import torch +from .layer import mLSTMLayer + +class mLSTMNetwork(nn.Module): + def __init__(self, input_size, hidden_size, num_layers, output_size, dropout=0.0, use_layer_norm=True, + use_residual=True, device=None): + super(mLSTMNetwork, self).__init__() + self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + self.mlstm_layer = mLSTMLayer(input_size, hidden_size, num_layers, dropout, use_layer_norm, use_residual, + self.device) + self.fc = nn.Linear(hidden_size, output_size) + + def forward(self, x, h=None, c=None, n=None): + """ + Forward pass through the mLSTM network. 
+ """ + output, (h, c, n) = self.mlstm_layer(x, h, c, n) + + output = self.fc(output[-1]) + + return output, (h, c, n) + + def init_hidden(self, batch_size): + """Initialize hidden, cell, and normalization states.""" + return self.mlstm_layer.init_hidden(batch_size) diff --git a/pytorch_forecasting/models/xLSTMTime/sLSTM/__init__.py b/pytorch_forecasting/models/xLSTMTime/sLSTM/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pytorch_forecasting/models/xLSTMTime/sLSTM/cell.py b/pytorch_forecasting/models/xLSTMTime/sLSTM/cell.py new file mode 100644 index 000000000..44360866a --- /dev/null +++ b/pytorch_forecasting/models/xLSTMTime/sLSTM/cell.py @@ -0,0 +1,92 @@ +import torch +import torch.nn as nn +import math + +class sLSTMCell(nn.Module): + """Stabilized LSTM Cell""" + def __init__(self, input_size, hidden_size, dropout=0.0, use_layer_norm=True, device=None): + super(sLSTMCell, self).__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.dropout = dropout + self.use_layer_norm = use_layer_norm + self.eps = 1e-6 + + self.device = device if device is not None else torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + self.input_weights = nn.Linear(input_size, 4 * hidden_size).to(self.device) + self.hidden_weights = nn.Linear(hidden_size, 4 * hidden_size).to(self.device) + + if use_layer_norm: + self.ln_cell = nn.LayerNorm(hidden_size).to(self.device) + self.ln_hidden = nn.LayerNorm(hidden_size).to(self.device) + self.ln_input = nn.LayerNorm(4 * hidden_size).to(self.device) + self.ln_hidden_update = nn.LayerNorm(4 * hidden_size).to(self.device) + + self.dropout_layer = nn.Dropout(dropout).to(self.device) + + + self.reset_parameters() + + + self.grad_clip = 5.0 + + self.tanh = nn.Tanh() + self.sigmoid = nn.Sigmoid() + + self.to(self.device) + + def reset_parameters(self): + """Initialize parameters using Xavier/Glorot initialization""" + std = 1.0 / math.sqrt(self.hidden_size) + for weight in self.parameters(): + weight.data.uniform_(-std, std) + + def normalized_exp_gate(self, pre_gate): + """Compute normalized exponential gate activation""" + centered = pre_gate - torch.mean(pre_gate, dim=1, keepdim=True) + exp_val = torch.exp(torch.clamp(centered, min=-5.0, max=5.0)) + normalizer = torch.sum(exp_val, dim=1, keepdim=True) + self.eps + return exp_val / normalizer + + def forward(self, x, h_prev, c_prev): + """Forward pass with stabilized exponential gating""" + x = x.to(self.device) + h_prev = h_prev.to(self.device) + c_prev = c_prev.to(self.device) + + x = self.dropout_layer(x) + h_prev = self.dropout_layer(h_prev) + + gates_x = self.input_weights(x) + gates_h = self.hidden_weights(h_prev) + + if self.use_layer_norm: + gates_x = self.ln_input(gates_x) + gates_h = self.ln_hidden_update(gates_h) + + gates = gates_x + gates_h + i, f, g, o = gates.chunk(4, dim=1) + + i = self.normalized_exp_gate(i) + f = self.normalized_exp_gate(f) + gate_sum = i + f + i = i / (gate_sum + self.eps) + f = f / (gate_sum + self.eps) + + c_tilde = self.tanh(g) + c = f * c_prev + i * c_tilde + if self.use_layer_norm: + c = self.ln_cell(c) + + o = self.sigmoid(o) + c_out = self.tanh(c) + if self.use_layer_norm: + c_out = self.ln_hidden(c_out) + h = o * c_out + + return h, c + + def init_hidden(self, batch_size): + return (torch.zeros(batch_size, self.hidden_size, device=self.device), + torch.zeros(batch_size, self.hidden_size, device=self.device)) \ No newline at end of file diff --git a/pytorch_forecasting/models/xLSTMTime/sLSTM/layer.py 
b/pytorch_forecasting/models/xLSTMTime/sLSTM/layer.py new file mode 100644 index 000000000..612df40da --- /dev/null +++ b/pytorch_forecasting/models/xLSTMTime/sLSTM/layer.py @@ -0,0 +1,82 @@ +import torch +import torch.nn as nn +from .cell import sLSTMCell + +class sLSTMLayer(nn.Module): + """ + Enhanced sLSTM Layer that supports multiple sLSTM cells across timesteps and residual connections. + """ + + def __init__(self, input_size, hidden_size, num_layers=1, dropout=0.0, use_layer_norm=True, use_residual=True, device=None): + super(sLSTMLayer, self).__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.num_layers = num_layers + self.dropout = dropout + self.use_layer_norm = use_layer_norm + self.use_residual = use_residual + self.device = device if device else torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + self.cells = nn.ModuleList([sLSTMCell( + input_size if layer == 0 else hidden_size, + hidden_size, + dropout=dropout, + use_layer_norm=use_layer_norm, + device=self.device + ) for layer in range(num_layers)]) + + if self.use_residual: + self.res_proj = nn.ModuleList([nn.Linear(hidden_size, hidden_size, bias=False).to(self.device) for _ in range(num_layers)]) + + if self.use_layer_norm: + self.layer_norm_layers = nn.ModuleList([nn.LayerNorm(hidden_size).to(self.device) for _ in range(num_layers)]) + + def forward(self, x, h=None, c=None): + """ + Forward pass through the sLSTM layer for each time step in sequence. + Args: + x: input tensor (seq_len, batch_size, input_size) + h: initial hidden states (num_layers, batch_size, hidden_size) + c: initial cell states (num_layers, batch_size, hidden_size) + Returns: + output: tensor of hidden states (seq_len, batch_size, hidden_size) + (h, c): final hidden and cell states + """ + seq_len, batch_size, _ = x.size() + + if h is None or c is None: + h, c = self.init_hidden(batch_size) + + x = x.to(self.device) + h = [hi.to(self.device) for hi in h] + c = [ci.to(self.device) for ci in c] + + outputs = [] + + for t in range(seq_len): + input_t = x[t] + for layer in range(self.num_layers): + h[layer], c[layer] = self.cells[layer](input_t, h[layer], c[layer]) + + if self.use_residual: + h[layer] = h[layer] + self.res_proj[layer](input_t) + + if self.use_layer_norm: + h[layer] = self.layer_norm_layers[layer](h[layer]) + + input_t = h[layer] + outputs.append(h[-1]) + + output = torch.stack(outputs) + + h = [hi.detach() for hi in h] + c = [ci.detach() for ci in c] + + return output, (h, c) + + def init_hidden(self, batch_size): + """Initialize hidden and cell states for each layer.""" + return ([torch.zeros(batch_size, self.hidden_size, device=self.device) for _ in range(self.num_layers)], + [torch.zeros(batch_size, self.hidden_size, device=self.device) for _ in range(self.num_layers)]) + + diff --git a/pytorch_forecasting/models/xLSTMTime/sLSTM/network.py b/pytorch_forecasting/models/xLSTMTime/sLSTM/network.py new file mode 100644 index 000000000..d8fb59b30 --- /dev/null +++ b/pytorch_forecasting/models/xLSTMTime/sLSTM/network.py @@ -0,0 +1,38 @@ +import torch.nn as nn +import torch +from .layer import sLSTMLayer + +class sLSTMNetwork(nn.Module): + """ + Stabilized LSTM Network with multiple sLSTM layers. 
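Reviewer note on the sLSTM cell's `normalized_exp_gate` above: centring, clamping to ±5, and renormalising makes each exponential gate behave like a bounded softmax over the hidden dimension, and the follow-up `i / (i + f)` renormalisation in `forward` turns the input and forget gates into a convex combination. A self-contained sketch of why this stays finite where a raw `exp` gate would overflow (it mirrors the method rather than importing it):

```python
import torch

def normalized_exp_gate(pre_gate, eps=1e-6):
    # Centre, clamp, exponentiate, normalise - as in sLSTMCell.normalized_exp_gate.
    centered = pre_gate - pre_gate.mean(dim=1, keepdim=True)
    exp_val = torch.exp(torch.clamp(centered, min=-5.0, max=5.0))
    return exp_val / (exp_val.sum(dim=1, keepdim=True) + eps)

pre_i = torch.tensor([[100.0, 500.0, 1000.0]])  # torch.exp(1000.0) would be inf
i, f = normalized_exp_gate(pre_i), normalized_exp_gate(-pre_i)

gate_sum = i + f                    # second normalisation step from forward()
i, f = i / (gate_sum + 1e-6), f / (gate_sum + 1e-6)

assert torch.isfinite(i).all() and torch.isfinite(f).all()
print(i + f)                        # ~1.0 everywhere: a convex input/forget split
```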
+ """ + def __init__(self, input_size, hidden_size, num_layers, output_size, dropout=0.0, use_layer_norm=True, device=None): + super(sLSTMNetwork, self).__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.num_layers = num_layers + self.output_size = output_size + self.dropout = dropout + self.device = device if device else torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + self.slstm_layer = sLSTMLayer(input_size, hidden_size, num_layers, dropout, use_layer_norm, self.device) + self.fc = nn.Linear(hidden_size, output_size).to(self.device) + + def forward(self, x, h=None, c=None): + """ + Forward pass through the sLSTM network. + Args: + x: input tensor (seq_len, batch_size, input_size) + h: initial hidden states (num_layers, batch_size, hidden_size) + c: initial cell states (num_layers, batch_size, hidden_size) + Returns: + output: tensor of output predictions (seq_len, batch_size, output_size) + (h, c): final hidden and cell states + """ + output, (h, c) = self.slstm_layer(x, h, c) + output = self.fc(output[-1]) + return output, (h, c) + + def init_hidden(self, batch_size): + """Initialize hidden and cell states for the entire network.""" + return self.slstm_layer.init_hidden(batch_size) \ No newline at end of file diff --git a/pytorch_forecasting/models/xLSTMTime/xLSTMTime.py b/pytorch_forecasting/models/xLSTMTime/xLSTMTime.py new file mode 100644 index 000000000..86ef45fb8 --- /dev/null +++ b/pytorch_forecasting/models/xLSTMTime/xLSTMTime.py @@ -0,0 +1,174 @@ +import torch +import torch.nn as nn +from typing import Optional, Tuple, Union, Literal +from mLSTM.network import mLSTMNetwork +from sLSTM.network import sLSTMNetwork + + +class SeriesDecomposition(nn.Module): + """Implements series decomposition using learnable moving averages.""" + + def __init__(self, kernel_size: int): + super(SeriesDecomposition, self).__init__() + self.kernel_size = kernel_size + self.padding = kernel_size // 2 + self.avg_pool = nn.AvgPool1d(kernel_size=kernel_size, stride=1, padding=self.padding) + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Decomposes input series into trend and seasonal components. + + Args: + x: Input tensor of shape (batch_size, seq_len, n_features) + + Returns: + Tuple of (trend_component, seasonal_component) + """ + + batch_size, seq_len, n_features = x.shape + x_reshaped = x.reshape(batch_size * n_features, 1, seq_len) + + + trend = self.avg_pool(x_reshaped) + + + trend = trend.reshape(batch_size, seq_len, n_features) + seasonal = x - trend + + return trend, seasonal + + +class xLSTMTime(nn.Module): + """ + Implementation of xLSTMTime architecture for time series forecasting. + """ + + def __init__( + self, + input_size: int, + hidden_size: int, + output_size: int, + xlstm_type: Literal['slstm', 'mlstm'], + num_layers: int = 1, + decomposition_kernel: int = 25, + dropout: float = 0.1, + device: Optional[torch.device] = None + ): + """ + Initialize xLSTMTime model. 
+ + Args: + input_size: Number of input features + hidden_size: Size of hidden layers + output_size: Number of output features + xlstm_type: Type of LSTM to use ('slstm' or 'mlstm') + num_layers: Number of LSTM layers + decomposition_kernel: Kernel size for series decomposition + dropout: Dropout rate + device: Torch device to use + """ + super(xLSTMTime, self).__init__() + + if xlstm_type not in ['slstm', 'mlstm']: + raise ValueError("xlstm_type must be either 'slstm' or 'mlstm'") + + self.xlstm_type = xlstm_type + self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + + self.decomposition = SeriesDecomposition(decomposition_kernel) + self.input_linear = nn.Linear(input_size * 2, hidden_size) + + self.batch_norm = nn.BatchNorm1d(hidden_size) + + if xlstm_type == 'mlstm': + self.lstm = mLSTMNetwork( + input_size=hidden_size, + hidden_size=hidden_size, + num_layers=num_layers, + output_size=hidden_size, + dropout=dropout, + device=self.device + ) + else: # slstm + self.lstm = sLSTMNetwork( + input_size=hidden_size, + hidden_size=hidden_size, + num_layers=num_layers, + output_size=hidden_size, + dropout=dropout, + device=self.device + ) + self.output_linear = nn.Linear(hidden_size, output_size) + + self.instance_norm = nn.InstanceNorm1d(output_size) + + def forward( + self, + x: torch.Tensor, + hidden_states: Optional[ + Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]] = None + ) -> Tuple[torch.Tensor, Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]: + """ + Forward pass through the network. + + Args: + x: Input tensor of shape (batch_size, seq_len, input_size) + hidden_states: Initial hidden states for LSTM + + Returns: + Tuple of (output, hidden_states) + """ + batch_size, seq_len, _ = x.shape + + trend, seasonal = self.decomposition(x) + + x = torch.cat([trend, seasonal], dim=-1) + + x = self.input_linear(x) + + # Reshape for batch norm + x = x.transpose(1, 2) + x = self.batch_norm(x) + x = x.transpose(1, 2) + + if hidden_states is None: + hidden_states = self.lstm.init_hidden(batch_size) + + x = x.transpose(0, 1) + output, hidden_states = self.lstm(x, *hidden_states) + + if isinstance(output, tuple): + output = output[0] + + + if output.dim() == 2: + output = output.unsqueeze(0) + + + output = self.output_linear(output) + + output = output.transpose(1, 2) + output = self.instance_norm(output) + output = output.transpose(1, 2) + + return output, hidden_states + + def predict( + self, + x: torch.Tensor, + hidden_states: Optional[ + Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]] = None + ) -> torch.Tensor: + """ + Make predictions using the model. 
+ + Args: + x: Input tensor + hidden_states: Optional initial hidden states + + Returns: + Predictions tensor + """ + output, _ = self.forward(x, hidden_states) + return output \ No newline at end of file From 5e57d347f3584fb82d2b276b0b6b0a8fd8d40f6a Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sat, 9 Nov 2024 17:11:20 +0530 Subject: [PATCH 02/15] linting --- .../models/xLSTMTime/mLSTM/cell.py | 12 ++--- .../models/xLSTMTime/mLSTM/layer.py | 30 +++++------ .../models/xLSTMTime/mLSTM/network.py | 21 ++++++-- .../models/xLSTMTime/sLSTM/cell.py | 12 +++-- .../models/xLSTMTime/sLSTM/layer.py | 44 +++++++++------ .../models/xLSTMTime/sLSTM/network.py | 6 ++- .../models/xLSTMTime/xLSTMTime.py | 53 +++++++++---------- 7 files changed, 100 insertions(+), 78 deletions(-) diff --git a/pytorch_forecasting/models/xLSTMTime/mLSTM/cell.py b/pytorch_forecasting/models/xLSTMTime/mLSTM/cell.py index bb7e55b4f..f0178f4af 100644 --- a/pytorch_forecasting/models/xLSTMTime/mLSTM/cell.py +++ b/pytorch_forecasting/models/xLSTMTime/mLSTM/cell.py @@ -10,7 +10,7 @@ def __init__(self, input_size, hidden_size, dropout=0.2, layer_norm=True, device self.hidden_size = hidden_size self.layer_norm = layer_norm - self.device = device if device is not None else torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.device = device if device is not None else torch.device("cuda" if torch.cuda.is_available() else "cpu") self.Wq = nn.Linear(input_size, hidden_size) self.Wk = nn.Linear(input_size, hidden_size) @@ -55,14 +55,12 @@ def forward(self, x, h_prev, c_prev, n_prev): c_prev = c_prev.to(self.device) n_prev = n_prev.to(self.device) - batch_size = x.size(0) assert x.dim() == 2, f"Input should be 2D (batch_size, input_size), got {x.dim()}D" assert h_prev.size() == (batch_size, self.hidden_size), f"h_prev shape mismatch: {h_prev.size()}" assert c_prev.size() == (batch_size, self.hidden_size), f"c_prev shape mismatch: {c_prev.size()}" assert n_prev.size() == (batch_size, self.hidden_size), f"n_prev shape mismatch: {n_prev.size()}" - x = self.dropout(x) h_prev = self.dropout(h_prev) @@ -100,6 +98,8 @@ def init_hidden(self, batch_size): Initialize hidden, cell, and normalization states. 
""" shape = (batch_size, self.hidden_size) - return (torch.zeros(shape, device=self.device), - torch.zeros(shape, device=self.device), - torch.zeros(shape, device=self.device)) \ No newline at end of file + return ( + torch.zeros(shape, device=self.device), + torch.zeros(shape, device=self.device), + torch.zeros(shape, device=self.device), + ) diff --git a/pytorch_forecasting/models/xLSTMTime/mLSTM/layer.py b/pytorch_forecasting/models/xLSTMTime/mLSTM/layer.py index cef6dcbcc..7b04ece76 100644 --- a/pytorch_forecasting/models/xLSTMTime/mLSTM/layer.py +++ b/pytorch_forecasting/models/xLSTMTime/mLSTM/layer.py @@ -4,36 +4,38 @@ class mLSTMLayer(nn.Module): - def __init__(self, input_size, hidden_size, num_layers, dropout=0.2, layer_norm=True, residual_conn=True, - device=None): + def __init__( + self, input_size, hidden_size, num_layers, dropout=0.2, layer_norm=True, residual_conn=True, device=None + ): super(mLSTMLayer, self).__init__() self.input_size = input_size self.hidden_size = hidden_size self.num_layers = num_layers self.layer_norm = layer_norm self.residual_conn = residual_conn - self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu") self.dropout = nn.Dropout(dropout).to(self.device) - self.cells = nn.ModuleList([ - mLSTMCell(input_size if i == 0 else hidden_size, hidden_size, dropout, layer_norm, self.device) - for i in range(num_layers) - ]) + self.cells = nn.ModuleList( + [ + mLSTMCell(input_size if i == 0 else hidden_size, hidden_size, dropout, layer_norm, self.device) + for i in range(num_layers) + ] + ) def init_hidden(self, batch_size): """ Initialize hidden, cell, and normalization states for all layers. """ - hidden_states, cell_states, norm_states = zip(*[ - self.cells[i].init_hidden(batch_size) for i in range(self.num_layers) - ]) - + hidden_states, cell_states, norm_states = zip( + *[self.cells[i].init_hidden(batch_size) for i in range(self.num_layers)] + ) return ( torch.stack(hidden_states).to(self.device), torch.stack(cell_states).to(self.device), - torch.stack(norm_states).to(self.device) + torch.stack(norm_states).to(self.device), ) def forward(self, x, h=None, c=None, n=None): @@ -44,11 +46,9 @@ def forward(self, x, h=None, c=None, n=None): x = x.to(self.device).transpose(0, 1) batch_size, seq_len, _ = x.size() - if h is None or c is None or n is None: h, c, n = self.init_hidden(batch_size) - outputs = [] for t in range(seq_len): @@ -76,10 +76,8 @@ def forward(self, x, h=None, c=None, n=None): outputs.append(h[-1]) - output = torch.stack(outputs, dim=1) - output = output.transpose(0, 1) return output, (h, c, n) diff --git a/pytorch_forecasting/models/xLSTMTime/mLSTM/network.py b/pytorch_forecasting/models/xLSTMTime/mLSTM/network.py index fc732b210..65a9414d3 100644 --- a/pytorch_forecasting/models/xLSTMTime/mLSTM/network.py +++ b/pytorch_forecasting/models/xLSTMTime/mLSTM/network.py @@ -2,14 +2,25 @@ import torch from .layer import mLSTMLayer + class mLSTMNetwork(nn.Module): - def __init__(self, input_size, hidden_size, num_layers, output_size, dropout=0.0, use_layer_norm=True, - use_residual=True, device=None): + def __init__( + self, + input_size, + hidden_size, + num_layers, + output_size, + dropout=0.0, + use_layer_norm=True, + use_residual=True, + device=None, + ): super(mLSTMNetwork, self).__init__() - self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.device = device or torch.device("cuda" if 
torch.cuda.is_available() else "cpu") - self.mlstm_layer = mLSTMLayer(input_size, hidden_size, num_layers, dropout, use_layer_norm, use_residual, - self.device) + self.mlstm_layer = mLSTMLayer( + input_size, hidden_size, num_layers, dropout, use_layer_norm, use_residual, self.device + ) self.fc = nn.Linear(hidden_size, output_size) def forward(self, x, h=None, c=None, n=None): diff --git a/pytorch_forecasting/models/xLSTMTime/sLSTM/cell.py b/pytorch_forecasting/models/xLSTMTime/sLSTM/cell.py index 44360866a..6ff0fdf60 100644 --- a/pytorch_forecasting/models/xLSTMTime/sLSTM/cell.py +++ b/pytorch_forecasting/models/xLSTMTime/sLSTM/cell.py @@ -2,8 +2,10 @@ import torch.nn as nn import math + class sLSTMCell(nn.Module): """Stabilized LSTM Cell""" + def __init__(self, input_size, hidden_size, dropout=0.0, use_layer_norm=True, device=None): super(sLSTMCell, self).__init__() self.input_size = input_size @@ -12,7 +14,7 @@ def __init__(self, input_size, hidden_size, dropout=0.0, use_layer_norm=True, de self.use_layer_norm = use_layer_norm self.eps = 1e-6 - self.device = device if device is not None else torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.device = device if device is not None else torch.device("cuda" if torch.cuda.is_available() else "cpu") self.input_weights = nn.Linear(input_size, 4 * hidden_size).to(self.device) self.hidden_weights = nn.Linear(hidden_size, 4 * hidden_size).to(self.device) @@ -25,10 +27,8 @@ def __init__(self, input_size, hidden_size, dropout=0.0, use_layer_norm=True, de self.dropout_layer = nn.Dropout(dropout).to(self.device) - self.reset_parameters() - self.grad_clip = 5.0 self.tanh = nn.Tanh() @@ -88,5 +88,7 @@ def forward(self, x, h_prev, c_prev): return h, c def init_hidden(self, batch_size): - return (torch.zeros(batch_size, self.hidden_size, device=self.device), - torch.zeros(batch_size, self.hidden_size, device=self.device)) \ No newline at end of file + return ( + torch.zeros(batch_size, self.hidden_size, device=self.device), + torch.zeros(batch_size, self.hidden_size, device=self.device), + ) diff --git a/pytorch_forecasting/models/xLSTMTime/sLSTM/layer.py b/pytorch_forecasting/models/xLSTMTime/sLSTM/layer.py index 612df40da..2d48d31e3 100644 --- a/pytorch_forecasting/models/xLSTMTime/sLSTM/layer.py +++ b/pytorch_forecasting/models/xLSTMTime/sLSTM/layer.py @@ -2,12 +2,15 @@ import torch.nn as nn from .cell import sLSTMCell + class sLSTMLayer(nn.Module): """ Enhanced sLSTM Layer that supports multiple sLSTM cells across timesteps and residual connections. 
""" - def __init__(self, input_size, hidden_size, num_layers=1, dropout=0.0, use_layer_norm=True, use_residual=True, device=None): + def __init__( + self, input_size, hidden_size, num_layers=1, dropout=0.0, use_layer_norm=True, use_residual=True, device=None + ): super(sLSTMLayer, self).__init__() self.input_size = input_size self.hidden_size = hidden_size @@ -15,21 +18,30 @@ def __init__(self, input_size, hidden_size, num_layers=1, dropout=0.0, use_layer self.dropout = dropout self.use_layer_norm = use_layer_norm self.use_residual = use_residual - self.device = device if device else torch.device('cuda' if torch.cuda.is_available() else 'cpu') - - self.cells = nn.ModuleList([sLSTMCell( - input_size if layer == 0 else hidden_size, - hidden_size, - dropout=dropout, - use_layer_norm=use_layer_norm, - device=self.device - ) for layer in range(num_layers)]) + self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu") + + self.cells = nn.ModuleList( + [ + sLSTMCell( + input_size if layer == 0 else hidden_size, + hidden_size, + dropout=dropout, + use_layer_norm=use_layer_norm, + device=self.device, + ) + for layer in range(num_layers) + ] + ) if self.use_residual: - self.res_proj = nn.ModuleList([nn.Linear(hidden_size, hidden_size, bias=False).to(self.device) for _ in range(num_layers)]) + self.res_proj = nn.ModuleList( + [nn.Linear(hidden_size, hidden_size, bias=False).to(self.device) for _ in range(num_layers)] + ) if self.use_layer_norm: - self.layer_norm_layers = nn.ModuleList([nn.LayerNorm(hidden_size).to(self.device) for _ in range(num_layers)]) + self.layer_norm_layers = nn.ModuleList( + [nn.LayerNorm(hidden_size).to(self.device) for _ in range(num_layers)] + ) def forward(self, x, h=None, c=None): """ @@ -76,7 +88,7 @@ def forward(self, x, h=None, c=None): def init_hidden(self, batch_size): """Initialize hidden and cell states for each layer.""" - return ([torch.zeros(batch_size, self.hidden_size, device=self.device) for _ in range(self.num_layers)], - [torch.zeros(batch_size, self.hidden_size, device=self.device) for _ in range(self.num_layers)]) - - + return ( + [torch.zeros(batch_size, self.hidden_size, device=self.device) for _ in range(self.num_layers)], + [torch.zeros(batch_size, self.hidden_size, device=self.device) for _ in range(self.num_layers)], + ) diff --git a/pytorch_forecasting/models/xLSTMTime/sLSTM/network.py b/pytorch_forecasting/models/xLSTMTime/sLSTM/network.py index d8fb59b30..b2b3a3904 100644 --- a/pytorch_forecasting/models/xLSTMTime/sLSTM/network.py +++ b/pytorch_forecasting/models/xLSTMTime/sLSTM/network.py @@ -2,10 +2,12 @@ import torch from .layer import sLSTMLayer + class sLSTMNetwork(nn.Module): """ Stabilized LSTM Network with multiple sLSTM layers. 
""" + def __init__(self, input_size, hidden_size, num_layers, output_size, dropout=0.0, use_layer_norm=True, device=None): super(sLSTMNetwork, self).__init__() self.input_size = input_size @@ -13,7 +15,7 @@ def __init__(self, input_size, hidden_size, num_layers, output_size, dropout=0.0 self.num_layers = num_layers self.output_size = output_size self.dropout = dropout - self.device = device if device else torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu") self.slstm_layer = sLSTMLayer(input_size, hidden_size, num_layers, dropout, use_layer_norm, self.device) self.fc = nn.Linear(hidden_size, output_size).to(self.device) @@ -35,4 +37,4 @@ def forward(self, x, h=None, c=None): def init_hidden(self, batch_size): """Initialize hidden and cell states for the entire network.""" - return self.slstm_layer.init_hidden(batch_size) \ No newline at end of file + return self.slstm_layer.init_hidden(batch_size) diff --git a/pytorch_forecasting/models/xLSTMTime/xLSTMTime.py b/pytorch_forecasting/models/xLSTMTime/xLSTMTime.py index 86ef45fb8..984272a75 100644 --- a/pytorch_forecasting/models/xLSTMTime/xLSTMTime.py +++ b/pytorch_forecasting/models/xLSTMTime/xLSTMTime.py @@ -28,10 +28,8 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: batch_size, seq_len, n_features = x.shape x_reshaped = x.reshape(batch_size * n_features, 1, seq_len) - trend = self.avg_pool(x_reshaped) - trend = trend.reshape(batch_size, seq_len, n_features) seasonal = x - trend @@ -44,15 +42,15 @@ class xLSTMTime(nn.Module): """ def __init__( - self, - input_size: int, - hidden_size: int, - output_size: int, - xlstm_type: Literal['slstm', 'mlstm'], - num_layers: int = 1, - decomposition_kernel: int = 25, - dropout: float = 0.1, - device: Optional[torch.device] = None + self, + input_size: int, + hidden_size: int, + output_size: int, + xlstm_type: Literal["slstm", "mlstm"], + num_layers: int = 1, + decomposition_kernel: int = 25, + dropout: float = 0.1, + device: Optional[torch.device] = None, ): """ Initialize xLSTMTime model. 
@@ -69,26 +67,25 @@ def __init__( """ super(xLSTMTime, self).__init__() - if xlstm_type not in ['slstm', 'mlstm']: + if xlstm_type not in ["slstm", "mlstm"]: raise ValueError("xlstm_type must be either 'slstm' or 'mlstm'") self.xlstm_type = xlstm_type - self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu') - + self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu") self.decomposition = SeriesDecomposition(decomposition_kernel) self.input_linear = nn.Linear(input_size * 2, hidden_size) self.batch_norm = nn.BatchNorm1d(hidden_size) - if xlstm_type == 'mlstm': + if xlstm_type == "mlstm": self.lstm = mLSTMNetwork( input_size=hidden_size, hidden_size=hidden_size, num_layers=num_layers, output_size=hidden_size, dropout=dropout, - device=self.device + device=self.device, ) else: # slstm self.lstm = sLSTMNetwork( @@ -97,17 +94,18 @@ def __init__( num_layers=num_layers, output_size=hidden_size, dropout=dropout, - device=self.device + device=self.device, ) self.output_linear = nn.Linear(hidden_size, output_size) self.instance_norm = nn.InstanceNorm1d(output_size) def forward( - self, - x: torch.Tensor, - hidden_states: Optional[ - Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]] = None + self, + x: torch.Tensor, + hidden_states: Optional[ + Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] + ] = None, ) -> Tuple[torch.Tensor, Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]: """ Forward pass through the network. @@ -141,11 +139,9 @@ def forward( if isinstance(output, tuple): output = output[0] - if output.dim() == 2: output = output.unsqueeze(0) - output = self.output_linear(output) output = output.transpose(1, 2) @@ -155,10 +151,11 @@ def forward( return output, hidden_states def predict( - self, - x: torch.Tensor, - hidden_states: Optional[ - Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]] = None + self, + x: torch.Tensor, + hidden_states: Optional[ + Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] + ] = None, ) -> torch.Tensor: """ Make predictions using the model. 
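With the linting pass applied, the plain `nn.Module` version of the model can be exercised end to end. A sketch (it relies on the package-absolute imports that land in the next commit, and the shapes follow the asserts in the test suite added there):

```python
import torch
from pytorch_forecasting.models.xLSTMTime.xLSTMTime import xLSTMTime

model = xLSTMTime(
    input_size=7,           # features per timestep
    hidden_size=64,
    output_size=1,
    xlstm_type="slstm",     # or "mlstm" for the matrix-memory variant
    decomposition_kernel=25,
)

x = torch.randn(8, 48, 7)      # (batch_size, seq_len, input_size)
prediction, hidden = model(x)  # hidden states are zero-initialised internally

print(prediction.shape)  # torch.Size([1, 8, 1]) - one step per batch element
```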
@@ -171,4 +168,4 @@ def predict( Predictions tensor """ output, _ = self.forward(x, hidden_states) - return output \ No newline at end of file + return output From e498848d5a2c6d2df94120af1a6499eda8f9200f Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sun, 10 Nov 2024 01:14:27 +0530 Subject: [PATCH 03/15] adding some tests and a little in debug in `sLSTM` structure --- .../models/xLSTMTime/mLSTM/layer.py | 2 +- .../models/xLSTMTime/mLSTM/network.py | 2 +- .../models/xLSTMTime/sLSTM/layer.py | 24 +- .../models/xLSTMTime/sLSTM/network.py | 4 +- .../models/xLSTMTime/xLSTMTime.py | 4 +- tests/test_models/test_xlstmtime.py | 278 ++++++++++++++++++ 6 files changed, 299 insertions(+), 15 deletions(-) create mode 100644 tests/test_models/test_xlstmtime.py diff --git a/pytorch_forecasting/models/xLSTMTime/mLSTM/layer.py b/pytorch_forecasting/models/xLSTMTime/mLSTM/layer.py index 7b04ece76..7ae1e0027 100644 --- a/pytorch_forecasting/models/xLSTMTime/mLSTM/layer.py +++ b/pytorch_forecasting/models/xLSTMTime/mLSTM/layer.py @@ -1,6 +1,6 @@ import torch import torch.nn as nn -from .cell import mLSTMCell +from pytorch_forecasting.models.xLSTMTime.mLSTM.cell import mLSTMCell class mLSTMLayer(nn.Module): diff --git a/pytorch_forecasting/models/xLSTMTime/mLSTM/network.py b/pytorch_forecasting/models/xLSTMTime/mLSTM/network.py index 65a9414d3..86024fe5b 100644 --- a/pytorch_forecasting/models/xLSTMTime/mLSTM/network.py +++ b/pytorch_forecasting/models/xLSTMTime/mLSTM/network.py @@ -1,6 +1,6 @@ import torch.nn as nn import torch -from .layer import mLSTMLayer +from pytorch_forecasting.models.xLSTMTime.mLSTM.layer import mLSTMLayer class mLSTMNetwork(nn.Module): diff --git a/pytorch_forecasting/models/xLSTMTime/sLSTM/layer.py b/pytorch_forecasting/models/xLSTMTime/sLSTM/layer.py index 2d48d31e3..cddff40b0 100644 --- a/pytorch_forecasting/models/xLSTMTime/sLSTM/layer.py +++ b/pytorch_forecasting/models/xLSTMTime/sLSTM/layer.py @@ -1,6 +1,6 @@ import torch import torch.nn as nn -from .cell import sLSTMCell +from pytorch_forecasting.models.xLSTMTime.sLSTM.cell import sLSTMCell class sLSTMLayer(nn.Module): @@ -20,6 +20,10 @@ def __init__( self.use_residual = use_residual self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.input_projection = None + if self.use_residual and input_size != hidden_size: + self.input_projection = nn.Linear(input_size, hidden_size, bias=False).to(self.device) + self.cells = nn.ModuleList( [ sLSTMCell( @@ -33,11 +37,6 @@ def __init__( ] ) - if self.use_residual: - self.res_proj = nn.ModuleList( - [nn.Linear(hidden_size, hidden_size, bias=False).to(self.device) for _ in range(num_layers)] - ) - if self.use_layer_norm: self.layer_norm_layers = nn.ModuleList( [nn.LayerNorm(hidden_size).to(self.device) for _ in range(num_layers)] @@ -67,16 +66,23 @@ def forward(self, x, h=None, c=None): for t in range(seq_len): input_t = x[t] + layer_input = input_t + for layer in range(self.num_layers): - h[layer], c[layer] = self.cells[layer](input_t, h[layer], c[layer]) + h[layer], c[layer] = self.cells[layer](layer_input, h[layer], c[layer]) if self.use_residual: - h[layer] = h[layer] + self.res_proj[layer](input_t) + if layer == 0 and self.input_projection is not None: + residual = self.input_projection(layer_input) + else: + residual = layer_input if layer_input.size(-1) == self.hidden_size else 0 + h[layer] = h[layer] + residual if self.use_layer_norm: h[layer] = self.layer_norm_layers[layer](h[layer]) - input_t = h[layer] + layer_input = h[layer] + 
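Reviewer note on the rewritten residual path above: the old per-layer `res_proj` applied a learned projection everywhere; the new code only projects where shapes actually disagree, i.e. at layer 0 when `input_size != hidden_size`. A sketch of the two cases:

```python
import torch
import torch.nn as nn

input_size, hidden_size, batch = 10, 64, 4
input_projection = nn.Linear(input_size, hidden_size, bias=False)  # layer-0 shim only

x_t = torch.randn(batch, input_size)   # raw features entering layer 0
h0 = torch.randn(batch, hidden_size)   # layer-0 cell output
h0 = h0 + input_projection(x_t)        # projected residual: shapes now agree

h1 = torch.randn(batch, hidden_size)   # a deeper layer's cell output
h1 = h1 + h0                           # identity residual: no projection needed
print(h0.shape, h1.shape)
```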
             outputs.append(h[-1])
 
         output = torch.stack(outputs)
diff --git a/pytorch_forecasting/models/xLSTMTime/sLSTM/network.py b/pytorch_forecasting/models/xLSTMTime/sLSTM/network.py
index b2b3a3904..2075c314c 100644
--- a/pytorch_forecasting/models/xLSTMTime/sLSTM/network.py
+++ b/pytorch_forecasting/models/xLSTMTime/sLSTM/network.py
@@ -1,6 +1,6 @@
 import torch.nn as nn
 import torch
-from .layer import sLSTMLayer
+from pytorch_forecasting.models.xLSTMTime.sLSTM.layer import sLSTMLayer
 
 
 class sLSTMNetwork(nn.Module):
@@ -17,7 +17,7 @@ def __init__(self, input_size, hidden_size, num_layers, output_size, dropout=0.0
         self.dropout = dropout
         self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-        self.slstm_layer = sLSTMLayer(input_size, hidden_size, num_layers, dropout, use_layer_norm, self.device)
+        self.slstm_layer = sLSTMLayer(input_size, hidden_size, num_layers, dropout, use_layer_norm, device=self.device)
         self.fc = nn.Linear(hidden_size, output_size).to(self.device)
 
     def forward(self, x, h=None, c=None):
diff --git a/pytorch_forecasting/models/xLSTMTime/xLSTMTime.py b/pytorch_forecasting/models/xLSTMTime/xLSTMTime.py
index 984272a75..f5f13fe1b 100644
--- a/pytorch_forecasting/models/xLSTMTime/xLSTMTime.py
+++ b/pytorch_forecasting/models/xLSTMTime/xLSTMTime.py
@@ -1,8 +1,8 @@
 import torch
 import torch.nn as nn
 from typing import Optional, Tuple, Union, Literal
-from mLSTM.network import mLSTMNetwork
-from sLSTM.network import sLSTMNetwork
+from pytorch_forecasting.models.xLSTMTime.mLSTM.network import mLSTMNetwork
+from pytorch_forecasting.models.xLSTMTime.sLSTM.network import sLSTMNetwork
 
 
 class SeriesDecomposition(nn.Module):
diff --git a/tests/test_models/test_xlstmtime.py b/tests/test_models/test_xlstmtime.py
new file mode 100644
index 000000000..7c6e0e01b
--- /dev/null
+++ b/tests/test_models/test_xlstmtime.py
@@ -0,0 +1,276 @@
+import pytest
+import torch
+import torch.nn as nn
+from typing import Tuple, List
+
+from pytorch_forecasting.models.xLSTMTime.mLSTM.cell import mLSTMCell
+from pytorch_forecasting.models.xLSTMTime.mLSTM.layer import mLSTMLayer
+from pytorch_forecasting.models.xLSTMTime.mLSTM.network import mLSTMNetwork
+from pytorch_forecasting.models.xLSTMTime.sLSTM.cell import sLSTMCell
+from pytorch_forecasting.models.xLSTMTime.sLSTM.layer import sLSTMLayer
+from pytorch_forecasting.models.xLSTMTime.sLSTM.network import sLSTMNetwork
+from pytorch_forecasting.models.xLSTMTime.xLSTMTime import xLSTMTime, SeriesDecomposition
+
+
+# Fixtures for common test parameters
+@pytest.fixture
+def device():
+    return torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+@pytest.fixture
+def batch_size():
+    return 32
+
+
+@pytest.fixture
+def seq_length():
+    return 24
+
+
+@pytest.fixture
+def input_size():
+    return 10
+
+
+@pytest.fixture
+def hidden_size():
+    return 64
+
+
+@pytest.fixture
+def output_size():
+    return 5
+
+
+@pytest.fixture
+def sample_input(batch_size, seq_length, input_size, device):
+    return torch.randn(batch_size, seq_length, input_size).to(device)
+
+
+# Test Series Decomposition
+class TestSeriesDecomposition:
+    def test_initialization(self):
+        kernel_size = 25
+        decomp = SeriesDecomposition(kernel_size)
+        assert decomp.kernel_size == kernel_size
+        assert decomp.padding == kernel_size // 2
+
+    def test_forward_shape(self, sample_input):
+        kernel_size = 25
+        decomp = SeriesDecomposition(kernel_size)
+        trend, seasonal = decomp(sample_input)
+
+        assert trend.shape
== sample_input.shape + assert seasonal.shape == sample_input.shape + + def test_decomposition_sum(self, sample_input): + kernel_size = 25 + decomp = SeriesDecomposition(kernel_size) + trend, seasonal = decomp(sample_input) + + # Check if trend + seasonal approximately equals input + torch.testing.assert_close(trend + seasonal, sample_input, rtol=1e-4, atol=1e-4) + + +# Test mLSTM Components +class TestMLSTM: + def test_cell_initialization(self, input_size, hidden_size, device): + cell = mLSTMCell(input_size, hidden_size, device=device) + assert cell.input_size == input_size + assert cell.hidden_size == hidden_size + + def test_cell_forward(self, batch_size, input_size, hidden_size, device): + cell = mLSTMCell(input_size, hidden_size, device=device) + x = torch.randn(1, batch_size, input_size).to(device) # Add sequence dimension + h_prev = torch.randn(batch_size, hidden_size).to(device) + c_prev = torch.randn(batch_size, hidden_size).to(device) + n_prev = torch.randn(batch_size, hidden_size).to(device) + + h, c, n = cell(x[0], h_prev, c_prev, n_prev) # Use first timestep + assert h.shape == (batch_size, hidden_size) + assert c.shape == (batch_size, hidden_size) + assert n.shape == (batch_size, hidden_size) + + def test_layer_initialization(self, input_size, hidden_size, device): + layer = mLSTMLayer(input_size, hidden_size, num_layers=2, device=device) + assert layer.input_size == input_size + assert layer.hidden_size == hidden_size + assert len(layer.cells) == 2 + + def test_network_forward(self, sample_input, input_size, hidden_size, output_size, device): + network = mLSTMNetwork(input_size, hidden_size, num_layers=2, output_size=output_size, device=device) + # Transpose input to seq_len, batch_size, input_size + x = sample_input.transpose(0, 1) + output, (h, c, n) = network(x) + # Transpose output back to batch_size, seq_len, output_size + output = output.transpose(0, 1) + assert output.shape == (output_size, sample_input.shape[0]) + + +# Test sLSTM Components +class TestSLSTM: + def test_cell_initialization(self, input_size, hidden_size, device): + cell = sLSTMCell(input_size, hidden_size, device=device) + assert cell.input_size == input_size + assert cell.hidden_size == hidden_size + + def test_cell_forward(self, batch_size, input_size, hidden_size, device): + cell = sLSTMCell(input_size, hidden_size, device=device) + x = torch.randn(1, batch_size, input_size).to(device) # Add sequence dimension + h_prev = torch.randn(batch_size, hidden_size).to(device) + c_prev = torch.randn(batch_size, hidden_size).to(device) + + h, c = cell(x[0], h_prev, c_prev) # Use first timestep + assert h.shape == (batch_size, hidden_size) + assert c.shape == (batch_size, hidden_size) + + def test_layer_initialization(self, input_size, hidden_size, device): + layer = sLSTMLayer(input_size, hidden_size, num_layers=2, device=device) + assert layer.input_size == input_size + assert layer.hidden_size == hidden_size + assert len(layer.cells) == 2 + + def test_network_forward(self, sample_input, input_size, hidden_size, output_size, device): + network = sLSTMNetwork(input_size, hidden_size, num_layers=2, output_size=output_size, device=device) + # Transpose input to seq_len, batch_size, input_size + x = sample_input.transpose(0, 1) + output, (h, c) = network(x) + # Transpose output back to batch_size, seq_len, output_size + output = output.transpose(0, 1) + assert output.shape == (output_size, sample_input.shape[0]) + + +# Test xLSTMTime +class TestXLSTMTime: + @pytest.mark.parametrize("xlstm_type", ["mlstm", 
"slstm"]) + def test_initialization(self, input_size, hidden_size, output_size, xlstm_type, device): + model = xLSTMTime( + input_size=input_size, + hidden_size=hidden_size, + output_size=output_size, + xlstm_type=xlstm_type, + device=device, + ) + assert isinstance(model.decomposition, SeriesDecomposition) + assert isinstance(model.input_linear, nn.Linear) + assert isinstance(model.output_linear, nn.Linear) + + @pytest.mark.parametrize("xlstm_type", ["mlstm", "slstm"]) + def test_forward(self, sample_input, input_size, hidden_size, output_size, xlstm_type, device): + model = xLSTMTime( + input_size=input_size, + hidden_size=hidden_size, + output_size=output_size, + xlstm_type=xlstm_type, + device=device, + ) + output, hidden_states = model(sample_input) + # Check output shape is batch_size, seq_len, output_size + assert output.shape == (1, sample_input.shape[0], output_size) + + if xlstm_type == "mlstm": + assert len(hidden_states) == 3 # h, c, n for mLSTM + h, c, n = hidden_states + assert h.shape == (1, sample_input.shape[0], hidden_size) + assert c.shape == (1, sample_input.shape[0], hidden_size) + assert n.shape == (1, sample_input.shape[0], hidden_size) + else: + assert len(hidden_states) == 2 # h, c for sLSTM + h, c = hidden_states + assert torch.stack(h).shape == (1, sample_input.shape[0], hidden_size) + assert torch.stack(c).shape == (1, sample_input.shape[0], hidden_size) + + @pytest.mark.parametrize("xlstm_type", ["mlstm", "slstm"]) + def test_predict(self, sample_input, input_size, hidden_size, output_size, xlstm_type, device): + model = xLSTMTime( + input_size=input_size, + hidden_size=hidden_size, + output_size=output_size, + xlstm_type=xlstm_type, + device=device, + ) + predictions = model.predict(sample_input) + assert predictions.shape == (1, sample_input.shape[0], output_size) + + def test_invalid_xlstm_type(self, input_size, hidden_size, output_size, device): + with pytest.raises(ValueError, match="xlstm_type must be either 'slstm' or 'mlstm'"): + xLSTMTime( + input_size=input_size, + hidden_size=hidden_size, + output_size=output_size, + xlstm_type="invalid_type", + device=device, + ) + + +# Test edge cases and error handling +class TestEdgeCases: + @pytest.mark.parametrize("xlstm_type", ["mlstm", "slstm"]) + def test_single_sequence_length(self, batch_size, input_size, hidden_size, output_size, xlstm_type, device): + model = xLSTMTime( + input_size=input_size, + hidden_size=hidden_size, + output_size=output_size, + xlstm_type=xlstm_type, + device=device, + ) + single_step = torch.randn(batch_size, 1, input_size).to(device) + output, hidden_states = model(single_step) + assert output.shape == (1, batch_size, output_size) + + if xlstm_type == "mlstm": + h, c, n = hidden_states + assert h.shape == (1, batch_size, hidden_size) + assert c.shape == (1, batch_size, hidden_size) + assert n.shape == (1, batch_size, hidden_size) + else: # slstm + h, c = hidden_states + assert torch.stack(h).shape == (1, batch_size, hidden_size) + assert torch.stack(c).shape == (1, batch_size, hidden_size) + + @pytest.mark.parametrize("xlstm_type", ["mlstm", "slstm"]) + def test_input_nan_handling(self, batch_size, input_size, hidden_size, output_size, xlstm_type, device): + """Test model behavior with NaN inputs""" + model = xLSTMTime( + input_size=input_size, + hidden_size=hidden_size, + output_size=output_size, + xlstm_type=xlstm_type, + device=device, + ) + + # Create input with some NaN values + nan_input = torch.randn(batch_size, 24, input_size).to(device) + nan_input[0, 0, 0] = float("nan") 
# Insert a NaN value + + try: + output, _ = model(nan_input) + # If we reach here, check if output contains NaN + assert torch.isnan(output).any(), "Expected NaN in output with NaN input" + except Exception as e: + # Model should either propagate NaN or raise an exception + assert isinstance(e, (RuntimeError, ValueError)), "Expected RuntimeError or ValueError with NaN input" + + @pytest.mark.parametrize("xlstm_type", ["mlstm", "slstm"]) + def test_numerical_stability(self, batch_size, input_size, hidden_size, output_size, xlstm_type, device): + """Test model behavior with extreme input values""" + model = xLSTMTime( + input_size=input_size, + hidden_size=hidden_size, + output_size=output_size, + xlstm_type=xlstm_type, + device=device, + ) + + # Test with very large values + large_input = torch.full((batch_size, 24, input_size), 1e10).to(device) + output_large, _ = model(large_input) + assert not torch.isnan(output_large).any(), "NaN in output with large input values" + assert not torch.isinf(output_large).any(), "Inf in output with large input values" + + # Test with very small values + small_input = torch.full((batch_size, 24, input_size), 1e-10).to(device) + output_small, _ = model(small_input) + assert not torch.isnan(output_small).any(), "NaN in output with small input values" + assert not torch.isinf(output_small).any(), "Inf in output with small input values" From 38e4c9c935e7f0f8dd40e5fed1253b051876147c Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Thu, 12 Dec 2024 21:56:13 +0530 Subject: [PATCH 04/15] new baseclass implementation --- .../models/xLSTMTime/sLSTM/__init__.py | 0 .../models/xLSTMTime/xLSTMTime.py | 171 ------------------ .../models/x_lstm_time/__init__.py | 134 ++++++++++++++ .../m_lstm}/__init__.py | 0 .../mLSTM => x_lstm_time/m_lstm}/cell.py | 0 .../mLSTM => x_lstm_time/m_lstm}/layer.py | 4 +- .../mLSTM => x_lstm_time/m_lstm}/network.py | 4 +- .../mLSTM => x_lstm_time/s_lstm}/__init__.py | 0 .../sLSTM => x_lstm_time/s_lstm}/cell.py | 0 .../sLSTM => x_lstm_time/s_lstm}/layer.py | 6 +- .../sLSTM => x_lstm_time/s_lstm}/network.py | 6 +- .../models/x_lstm_time/series_decomp.py | 31 ++++ tests/test_models/test_xlstmtime.py | 28 ++- 13 files changed, 188 insertions(+), 196 deletions(-) delete mode 100644 pytorch_forecasting/models/xLSTMTime/sLSTM/__init__.py delete mode 100644 pytorch_forecasting/models/xLSTMTime/xLSTMTime.py create mode 100644 pytorch_forecasting/models/x_lstm_time/__init__.py rename pytorch_forecasting/models/{xLSTMTime => x_lstm_time/m_lstm}/__init__.py (100%) rename pytorch_forecasting/models/{xLSTMTime/mLSTM => x_lstm_time/m_lstm}/cell.py (100%) rename pytorch_forecasting/models/{xLSTMTime/mLSTM => x_lstm_time/m_lstm}/layer.py (95%) rename pytorch_forecasting/models/{xLSTMTime/mLSTM => x_lstm_time/m_lstm}/network.py (88%) rename pytorch_forecasting/models/{xLSTMTime/mLSTM => x_lstm_time/s_lstm}/__init__.py (100%) rename pytorch_forecasting/models/{xLSTMTime/sLSTM => x_lstm_time/s_lstm}/cell.py (100%) rename pytorch_forecasting/models/{xLSTMTime/sLSTM => x_lstm_time/s_lstm}/layer.py (92%) rename pytorch_forecasting/models/{xLSTMTime/sLSTM => x_lstm_time/s_lstm}/network.py (88%) create mode 100644 pytorch_forecasting/models/x_lstm_time/series_decomp.py diff --git a/pytorch_forecasting/models/xLSTMTime/sLSTM/__init__.py b/pytorch_forecasting/models/xLSTMTime/sLSTM/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/pytorch_forecasting/models/xLSTMTime/xLSTMTime.py b/pytorch_forecasting/models/xLSTMTime/xLSTMTime.py deleted 
file mode 100644 index f5f13fe1b..000000000 --- a/pytorch_forecasting/models/xLSTMTime/xLSTMTime.py +++ /dev/null @@ -1,171 +0,0 @@ -import torch -import torch.nn as nn -from typing import Optional, Tuple, Union, Literal -from pytorch_forecasting.models.xLSTMTime.mLSTM.network import mLSTMNetwork -from pytorch_forecasting.models.xLSTMTime.sLSTM.network import sLSTMNetwork - - -class SeriesDecomposition(nn.Module): - """Implements series decomposition using learnable moving averages.""" - - def __init__(self, kernel_size: int): - super(SeriesDecomposition, self).__init__() - self.kernel_size = kernel_size - self.padding = kernel_size // 2 - self.avg_pool = nn.AvgPool1d(kernel_size=kernel_size, stride=1, padding=self.padding) - - def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Decomposes input series into trend and seasonal components. - - Args: - x: Input tensor of shape (batch_size, seq_len, n_features) - - Returns: - Tuple of (trend_component, seasonal_component) - """ - - batch_size, seq_len, n_features = x.shape - x_reshaped = x.reshape(batch_size * n_features, 1, seq_len) - - trend = self.avg_pool(x_reshaped) - - trend = trend.reshape(batch_size, seq_len, n_features) - seasonal = x - trend - - return trend, seasonal - - -class xLSTMTime(nn.Module): - """ - Implementation of xLSTMTime architecture for time series forecasting. - """ - - def __init__( - self, - input_size: int, - hidden_size: int, - output_size: int, - xlstm_type: Literal["slstm", "mlstm"], - num_layers: int = 1, - decomposition_kernel: int = 25, - dropout: float = 0.1, - device: Optional[torch.device] = None, - ): - """ - Initialize xLSTMTime model. - - Args: - input_size: Number of input features - hidden_size: Size of hidden layers - output_size: Number of output features - xlstm_type: Type of LSTM to use ('slstm' or 'mlstm') - num_layers: Number of LSTM layers - decomposition_kernel: Kernel size for series decomposition - dropout: Dropout rate - device: Torch device to use - """ - super(xLSTMTime, self).__init__() - - if xlstm_type not in ["slstm", "mlstm"]: - raise ValueError("xlstm_type must be either 'slstm' or 'mlstm'") - - self.xlstm_type = xlstm_type - self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu") - - self.decomposition = SeriesDecomposition(decomposition_kernel) - self.input_linear = nn.Linear(input_size * 2, hidden_size) - - self.batch_norm = nn.BatchNorm1d(hidden_size) - - if xlstm_type == "mlstm": - self.lstm = mLSTMNetwork( - input_size=hidden_size, - hidden_size=hidden_size, - num_layers=num_layers, - output_size=hidden_size, - dropout=dropout, - device=self.device, - ) - else: # slstm - self.lstm = sLSTMNetwork( - input_size=hidden_size, - hidden_size=hidden_size, - num_layers=num_layers, - output_size=hidden_size, - dropout=dropout, - device=self.device, - ) - self.output_linear = nn.Linear(hidden_size, output_size) - - self.instance_norm = nn.InstanceNorm1d(output_size) - - def forward( - self, - x: torch.Tensor, - hidden_states: Optional[ - Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] - ] = None, - ) -> Tuple[torch.Tensor, Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]: - """ - Forward pass through the network. 
- - Args: - x: Input tensor of shape (batch_size, seq_len, input_size) - hidden_states: Initial hidden states for LSTM - - Returns: - Tuple of (output, hidden_states) - """ - batch_size, seq_len, _ = x.shape - - trend, seasonal = self.decomposition(x) - - x = torch.cat([trend, seasonal], dim=-1) - - x = self.input_linear(x) - - # Reshape for batch norm - x = x.transpose(1, 2) - x = self.batch_norm(x) - x = x.transpose(1, 2) - - if hidden_states is None: - hidden_states = self.lstm.init_hidden(batch_size) - - x = x.transpose(0, 1) - output, hidden_states = self.lstm(x, *hidden_states) - - if isinstance(output, tuple): - output = output[0] - - if output.dim() == 2: - output = output.unsqueeze(0) - - output = self.output_linear(output) - - output = output.transpose(1, 2) - output = self.instance_norm(output) - output = output.transpose(1, 2) - - return output, hidden_states - - def predict( - self, - x: torch.Tensor, - hidden_states: Optional[ - Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] - ] = None, - ) -> torch.Tensor: - """ - Make predictions using the model. - - Args: - x: Input tensor - hidden_states: Optional initial hidden states - - Returns: - Predictions tensor - """ - output, _ = self.forward(x, hidden_states) - return output diff --git a/pytorch_forecasting/models/x_lstm_time/__init__.py b/pytorch_forecasting/models/x_lstm_time/__init__.py new file mode 100644 index 000000000..392213ebb --- /dev/null +++ b/pytorch_forecasting/models/x_lstm_time/__init__.py @@ -0,0 +1,134 @@ +from pytorch_forecasting.models.base_model import AutoRegressiveBaseModel +from copy import copy +from typing import Optional, Tuple, Union, Literal, Dict +from pytorch_forecasting.metrics import SMAPE, Metric +import torch +from torch import nn +from pytorch_forecasting.models.x_lstm_time.s_lstm.network import sLSTMNetwork +from pytorch_forecasting.models.x_lstm_time.m_lstm.network import mLSTMNetwork +from pytorch_forecasting.models.x_lstm_time.series_decomp import SeriesDecomposition + + +class xLSTMTime(AutoRegressiveBaseModel): + + def __init__( + self, + input_size: int, + hidden_size: int, + output_size: int, + xlstm_type: Literal["slstm", "mlstm"] = "slstm", + num_layers: int = 1, + decomposition_kernel: int = 25, + input_projection_size: Optional[int] = None, + dropout: float = 0.1, + loss: Metric = SMAPE(), + device: Optional[torch.device] = None, + **kwargs, + ): + + if device is None: + device = "cpu" + if "target" in kwargs: + del kwargs["target"] + if "target_lags" in kwargs: + del kwargs["target_lags"] + self.save_hyperparameters() + super().__init__(loss=loss, **kwargs) + + if xlstm_type not in ["slstm", "mlstm"]: + raise ValueError("xlstm_type must be either 'slstm' or 'mlstm'") + + self.xlstm_type = xlstm_type + self._device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.to(self._device) + + self.decomposition = SeriesDecomposition(decomposition_kernel) + self.batch_norm = nn.BatchNorm1d(hidden_size) + + self.input_projection_size = input_projection_size or hidden_size + self.input_linear = None + + if xlstm_type == "mlstm": + self.lstm = mLSTMNetwork( + input_size=hidden_size, + hidden_size=hidden_size, + num_layers=num_layers, + output_size=hidden_size, + dropout=dropout, + device=self.device, + ) + else: # slstm + self.lstm = sLSTMNetwork( + input_size=hidden_size, + hidden_size=hidden_size, + num_layers=num_layers, + output_size=hidden_size, + dropout=dropout, + device=self.device, + ) + + self.output_linear = 
nn.Linear(hidden_size, output_size) + self.instance_norm = nn.InstanceNorm1d(output_size) + + def forward( + self, + x: Dict[str, torch.Tensor], + hidden_states: Optional[ + Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] + ] = None, + ) -> Dict[str, torch.Tensor]: + encoder_cont = x["encoder_cont"] + batch_size, seq_len, n_features = encoder_cont.shape + + trend, seasonal = self.decomposition(encoder_cont) + + x = torch.cat([trend, seasonal], dim=-1) + concatenated_features = x.shape[-1] + + if self.input_linear is None: + self.input_linear = nn.Linear(concatenated_features, self.input_projection_size).to(self.device) + + x = self.input_linear(x) + + x = x.transpose(1, 2) + x = self.batch_norm(x) + x = x.transpose(1, 2) + + if hidden_states is None: + hidden_states = self.lstm.init_hidden(batch_size) + + x = x.transpose(0, 1) + output, hidden_states = self.lstm(x, *hidden_states) + + if isinstance(output, tuple): + output = output[0] + + if output.dim() == 2: + output = output.unsqueeze(0) + + output = self.output_linear(output) + + output = output.transpose(1, 2) + output = self.instance_norm(output) + output = output.transpose(1, 2) + + output = output[0, ..., : self.hparams.output_size] + return self.to_network_output(prediction=output) + + def predict( + self, + x: torch.Tensor, + hidden_states: Optional[ + Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] + ] = None, + ) -> torch.Tensor: + + output, _ = self.forward(x, hidden_states) + return output + + @classmethod + def from_dataset(cls, dataset, **kwargs): + new_kwargs = copy(kwargs) + new_kwargs.update(cls.deduce_default_output_parameters(dataset, kwargs, SMAPE())) + + return super().from_dataset(dataset, **kwargs) diff --git a/pytorch_forecasting/models/xLSTMTime/__init__.py b/pytorch_forecasting/models/x_lstm_time/m_lstm/__init__.py similarity index 100% rename from pytorch_forecasting/models/xLSTMTime/__init__.py rename to pytorch_forecasting/models/x_lstm_time/m_lstm/__init__.py diff --git a/pytorch_forecasting/models/xLSTMTime/mLSTM/cell.py b/pytorch_forecasting/models/x_lstm_time/m_lstm/cell.py similarity index 100% rename from pytorch_forecasting/models/xLSTMTime/mLSTM/cell.py rename to pytorch_forecasting/models/x_lstm_time/m_lstm/cell.py diff --git a/pytorch_forecasting/models/xLSTMTime/mLSTM/layer.py b/pytorch_forecasting/models/x_lstm_time/m_lstm/layer.py similarity index 95% rename from pytorch_forecasting/models/xLSTMTime/mLSTM/layer.py rename to pytorch_forecasting/models/x_lstm_time/m_lstm/layer.py index 7ae1e0027..d3fe98ded 100644 --- a/pytorch_forecasting/models/xLSTMTime/mLSTM/layer.py +++ b/pytorch_forecasting/models/x_lstm_time/m_lstm/layer.py @@ -1,6 +1,6 @@ import torch import torch.nn as nn -from pytorch_forecasting.models.xLSTMTime.mLSTM.cell import mLSTMCell +from pytorch_forecasting.models.x_lstm_time.m_lstm.cell import mLSTMCell class mLSTMLayer(nn.Module): @@ -40,7 +40,7 @@ def init_hidden(self, batch_size): def forward(self, x, h=None, c=None, n=None): """ - Forward pass for the mLSTM layer. + Forward pass for the m_lstm layer. 
""" x = x.to(self.device).transpose(0, 1) diff --git a/pytorch_forecasting/models/xLSTMTime/mLSTM/network.py b/pytorch_forecasting/models/x_lstm_time/m_lstm/network.py similarity index 88% rename from pytorch_forecasting/models/xLSTMTime/mLSTM/network.py rename to pytorch_forecasting/models/x_lstm_time/m_lstm/network.py index 86024fe5b..bf89192d4 100644 --- a/pytorch_forecasting/models/xLSTMTime/mLSTM/network.py +++ b/pytorch_forecasting/models/x_lstm_time/m_lstm/network.py @@ -1,6 +1,6 @@ import torch.nn as nn import torch -from pytorch_forecasting.models.xLSTMTime.mLSTM.layer import mLSTMLayer +from pytorch_forecasting.models.x_lstm_time.m_lstm.layer import mLSTMLayer class mLSTMNetwork(nn.Module): @@ -25,7 +25,7 @@ def __init__( def forward(self, x, h=None, c=None, n=None): """ - Forward pass through the mLSTM network. + Forward pass through the m_lstm network. """ output, (h, c, n) = self.mlstm_layer(x, h, c, n) diff --git a/pytorch_forecasting/models/xLSTMTime/mLSTM/__init__.py b/pytorch_forecasting/models/x_lstm_time/s_lstm/__init__.py similarity index 100% rename from pytorch_forecasting/models/xLSTMTime/mLSTM/__init__.py rename to pytorch_forecasting/models/x_lstm_time/s_lstm/__init__.py diff --git a/pytorch_forecasting/models/xLSTMTime/sLSTM/cell.py b/pytorch_forecasting/models/x_lstm_time/s_lstm/cell.py similarity index 100% rename from pytorch_forecasting/models/xLSTMTime/sLSTM/cell.py rename to pytorch_forecasting/models/x_lstm_time/s_lstm/cell.py diff --git a/pytorch_forecasting/models/xLSTMTime/sLSTM/layer.py b/pytorch_forecasting/models/x_lstm_time/s_lstm/layer.py similarity index 92% rename from pytorch_forecasting/models/xLSTMTime/sLSTM/layer.py rename to pytorch_forecasting/models/x_lstm_time/s_lstm/layer.py index cddff40b0..305b3328a 100644 --- a/pytorch_forecasting/models/xLSTMTime/sLSTM/layer.py +++ b/pytorch_forecasting/models/x_lstm_time/s_lstm/layer.py @@ -1,11 +1,11 @@ import torch import torch.nn as nn -from pytorch_forecasting.models.xLSTMTime.sLSTM.cell import sLSTMCell +from pytorch_forecasting.models.x_lstm_time.s_lstm.cell import sLSTMCell class sLSTMLayer(nn.Module): """ - Enhanced sLSTM Layer that supports multiple sLSTM cells across timesteps and residual connections. + Enhanced s_lstm Layer that supports multiple s_lstm cells across timesteps and residual connections. """ def __init__( @@ -44,7 +44,7 @@ def __init__( def forward(self, x, h=None, c=None): """ - Forward pass through the sLSTM layer for each time step in sequence. + Forward pass through the s_lstm layer for each time step in sequence. Args: x: input tensor (seq_len, batch_size, input_size) h: initial hidden states (num_layers, batch_size, hidden_size) diff --git a/pytorch_forecasting/models/xLSTMTime/sLSTM/network.py b/pytorch_forecasting/models/x_lstm_time/s_lstm/network.py similarity index 88% rename from pytorch_forecasting/models/xLSTMTime/sLSTM/network.py rename to pytorch_forecasting/models/x_lstm_time/s_lstm/network.py index 2075c314c..1d46f5cff 100644 --- a/pytorch_forecasting/models/xLSTMTime/sLSTM/network.py +++ b/pytorch_forecasting/models/x_lstm_time/s_lstm/network.py @@ -1,11 +1,11 @@ import torch.nn as nn import torch -from pytorch_forecasting.models.xLSTMTime.sLSTM.layer import sLSTMLayer +from pytorch_forecasting.models.x_lstm_time.s_lstm.layer import sLSTMLayer class sLSTMNetwork(nn.Module): """ - Stabilized LSTM Network with multiple sLSTM layers. + Stabilized LSTM Network with multiple s_lstm layers. 
""" def __init__(self, input_size, hidden_size, num_layers, output_size, dropout=0.0, use_layer_norm=True, device=None): @@ -22,7 +22,7 @@ def __init__(self, input_size, hidden_size, num_layers, output_size, dropout=0.0 def forward(self, x, h=None, c=None): """ - Forward pass through the sLSTM network. + Forward pass through the s_lstm network. Args: x: input tensor (seq_len, batch_size, input_size) h: initial hidden states (num_layers, batch_size, hidden_size) diff --git a/pytorch_forecasting/models/x_lstm_time/series_decomp.py b/pytorch_forecasting/models/x_lstm_time/series_decomp.py new file mode 100644 index 000000000..4268482dd --- /dev/null +++ b/pytorch_forecasting/models/x_lstm_time/series_decomp.py @@ -0,0 +1,31 @@ +from typing import Tuple +import torch +from torch import nn + + +class SeriesDecomposition(nn.Module): + """Implements series decomposition using learnable moving averages.""" + + def __init__(self, kernel_size: int): + super(SeriesDecomposition, self).__init__() + self.kernel_size = kernel_size + self.padding = kernel_size // 2 + self.avg_pool = nn.AvgPool1d(kernel_size=kernel_size, stride=1, padding=self.padding) + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Decomposes input series into trend and seasonal components. + + Args: + x: Input tensor of shape (batch_size, seq_len, n_features) + + Returns: + Tuple of (trend_component, seasonal_component) + """ + batch_size, seq_len, n_features = x.shape + x_reshaped = x.reshape(batch_size * n_features, 1, seq_len) + trend = self.avg_pool(x_reshaped) + trend = trend.reshape(batch_size, seq_len, n_features) + seasonal = x - trend + + return trend, seasonal diff --git a/tests/test_models/test_xlstmtime.py b/tests/test_models/test_xlstmtime.py index 7c6e0e01b..b611c3a4c 100644 --- a/tests/test_models/test_xlstmtime.py +++ b/tests/test_models/test_xlstmtime.py @@ -1,17 +1,15 @@ import pytest import torch import torch.nn as nn -from typing import Tuple, List -from sympy.stats.sampling.sample_numpy import numpy - -from pytorch_forecasting.models.xLSTMTime.mLSTM.cell import mLSTMCell -from pytorch_forecasting.models.xLSTMTime.mLSTM.layer import mLSTMLayer -from pytorch_forecasting.models.xLSTMTime.mLSTM.network import mLSTMNetwork -from pytorch_forecasting.models.xLSTMTime.sLSTM.cell import sLSTMCell -from pytorch_forecasting.models.xLSTMTime.sLSTM.layer import sLSTMLayer -from pytorch_forecasting.models.xLSTMTime.sLSTM.network import sLSTMNetwork -from pytorch_forecasting.models.xLSTMTime.xLSTMTime import xLSTMTime, SeriesDecomposition +from pytorch_forecasting.models.x_lstm_time.m_lstm.cell import mLSTMCell +from pytorch_forecasting.models.x_lstm_time.m_lstm.layer import mLSTMLayer +from pytorch_forecasting.models.x_lstm_time.m_lstm.network import mLSTMNetwork +from pytorch_forecasting.models.x_lstm_time.s_lstm.cell import sLSTMCell +from pytorch_forecasting.models.x_lstm_time.s_lstm.layer import sLSTMLayer +from pytorch_forecasting.models.x_lstm_time.s_lstm.network import sLSTMNetwork +from pytorch_forecasting.models.x_lstm_time.series_decomp import SeriesDecomposition +from pytorch_forecasting.models.x_lstm_time import xLSTMTime # Fixtures for common test parameters @@ -75,7 +73,7 @@ def test_decomposition_sum(self, sample_input): torch.testing.assert_close(trend + seasonal, sample_input, rtol=1e-4, atol=1e-4) -# Test mLSTM Components +# Test m_lstm Components class TestMLSTM: def test_cell_initialization(self, input_size, hidden_size, device): cell = mLSTMCell(input_size, hidden_size, 
device=device) @@ -110,7 +108,7 @@ def test_network_forward(self, sample_input, input_size, hidden_size, output_siz assert output.shape == (output_size, sample_input.shape[0]) -# Test sLSTM Components +# Test s_lstm Components class TestSLSTM: def test_cell_initialization(self, input_size, hidden_size, device): cell = sLSTMCell(input_size, hidden_size, device=device) @@ -143,7 +141,7 @@ def test_network_forward(self, sample_input, input_size, hidden_size, output_siz assert output.shape == (output_size, sample_input.shape[0]) -# Test xLSTMTime +# Test x_lstm_time class TestXLSTMTime: @pytest.mark.parametrize("xlstm_type", ["mlstm", "slstm"]) def test_initialization(self, input_size, hidden_size, output_size, xlstm_type, device): @@ -172,13 +170,13 @@ def test_forward(self, sample_input, input_size, hidden_size, output_size, xlstm assert output.shape == (1, sample_input.shape[0], output_size) if xlstm_type == "mlstm": - assert len(hidden_states) == 3 # h, c, n for mLSTM + assert len(hidden_states) == 3 # h, c, n for m_lstm h, c, n = hidden_states assert h.shape == (1, sample_input.shape[0], hidden_size) assert c.shape == (1, sample_input.shape[0], hidden_size) assert n.shape == (1, sample_input.shape[0], hidden_size) else: - assert len(hidden_states) == 2 # h, c for sLSTM + assert len(hidden_states) == 2 # h, c for s_lstm h, c = hidden_states assert torch.stack(h).shape == (1, sample_input.shape[0], hidden_size) assert torch.stack(c).shape == (1, sample_input.shape[0], hidden_size) From a72c8c62c6b47660985e736419b0d3412d75f88c Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Fri, 13 Dec 2024 23:18:36 +0530 Subject: [PATCH 05/15] Update __init__.py --- pytorch_forecasting/models/x_lstm_time/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytorch_forecasting/models/x_lstm_time/__init__.py b/pytorch_forecasting/models/x_lstm_time/__init__.py index 392213ebb..b9de84cb9 100644 --- a/pytorch_forecasting/models/x_lstm_time/__init__.py +++ b/pytorch_forecasting/models/x_lstm_time/__init__.py @@ -26,8 +26,6 @@ def __init__( **kwargs, ): - if device is None: - device = "cpu" if "target" in kwargs: del kwargs["target"] if "target_lags" in kwargs: From b3b3e55e837ae24b67ac7d4956596b864eba2555 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Mon, 23 Dec 2024 16:14:30 +0530 Subject: [PATCH 06/15] little debug in `predict` method --- .../models/x_lstm_time/__init__.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/pytorch_forecasting/models/x_lstm_time/__init__.py b/pytorch_forecasting/models/x_lstm_time/__init__.py index b9de84cb9..674b5e76c 100644 --- a/pytorch_forecasting/models/x_lstm_time/__init__.py +++ b/pytorch_forecasting/models/x_lstm_time/__init__.py @@ -114,15 +114,16 @@ def forward( return self.to_network_output(prediction=output) def predict( - self, - x: torch.Tensor, - hidden_states: Optional[ - Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] - ] = None, + self, + x: Dict[str, torch.Tensor], + hidden_states: Optional[ + Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] + ] = None, ) -> torch.Tensor: - output, _ = self.forward(x, hidden_states) - return output + network_output = self.forward(x, hidden_states) + prediction = network_output["prediction"] + return prediction @classmethod def from_dataset(cls, dataset, **kwargs): From 87f4ff46d67576e917a7714390808b9cdfd9b70a Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Tue, 24 Dec 2024 18:53:35 +0530 Subject: 
[PATCH 07/15] trying the baseclass predict function and removing the test files --- .../models/x_lstm_time/__init__.py | 12 - tests/test_models/test_xlstmtime.py | 276 ------------------ 2 files changed, 288 deletions(-) delete mode 100644 tests/test_models/test_xlstmtime.py diff --git a/pytorch_forecasting/models/x_lstm_time/__init__.py b/pytorch_forecasting/models/x_lstm_time/__init__.py index 674b5e76c..5c65f36b1 100644 --- a/pytorch_forecasting/models/x_lstm_time/__init__.py +++ b/pytorch_forecasting/models/x_lstm_time/__init__.py @@ -113,18 +113,6 @@ def forward( output = output[0, ..., : self.hparams.output_size] return self.to_network_output(prediction=output) - def predict( - self, - x: Dict[str, torch.Tensor], - hidden_states: Optional[ - Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] - ] = None, - ) -> torch.Tensor: - - network_output = self.forward(x, hidden_states) - prediction = network_output["prediction"] - return prediction - @classmethod def from_dataset(cls, dataset, **kwargs): new_kwargs = copy(kwargs) diff --git a/tests/test_models/test_xlstmtime.py b/tests/test_models/test_xlstmtime.py deleted file mode 100644 index b611c3a4c..000000000 --- a/tests/test_models/test_xlstmtime.py +++ /dev/null @@ -1,276 +0,0 @@ -import pytest -import torch -import torch.nn as nn - -from pytorch_forecasting.models.x_lstm_time.m_lstm.cell import mLSTMCell -from pytorch_forecasting.models.x_lstm_time.m_lstm.layer import mLSTMLayer -from pytorch_forecasting.models.x_lstm_time.m_lstm.network import mLSTMNetwork -from pytorch_forecasting.models.x_lstm_time.s_lstm.cell import sLSTMCell -from pytorch_forecasting.models.x_lstm_time.s_lstm.layer import sLSTMLayer -from pytorch_forecasting.models.x_lstm_time.s_lstm.network import sLSTMNetwork -from pytorch_forecasting.models.x_lstm_time.series_decomp import SeriesDecomposition -from pytorch_forecasting.models.x_lstm_time import xLSTMTime - - -# Fixtures for common test parameters -@pytest.fixture -def device(): - return torch.device("cuda" if torch.cuda.is_available() else "cpu") - - -@pytest.fixture -def batch_size(): - return 32 - - -@pytest.fixture -def seq_length(): - return 24 - - -@pytest.fixture -def input_size(): - return 10 - - -@pytest.fixture -def hidden_size(): - return 64 - - -@pytest.fixture -def output_size(): - return 5 - - -@pytest.fixture -def sample_input(batch_size, seq_length, input_size, device): - return torch.randn(batch_size, seq_length, input_size).to(device) - - -# Test Series Decomposition -class TestSeriesDecomposition: - def test_initialization(self): - kernel_size = 25 - decomp = SeriesDecomposition(kernel_size) - assert decomp.kernel_size == kernel_size - assert decomp.padding == kernel_size // 2 - - def test_forward_shape(self, sample_input): - kernel_size = 25 - decomp = SeriesDecomposition(kernel_size) - trend, seasonal = decomp(sample_input) - - assert trend.shape == sample_input.shape - assert seasonal.shape == sample_input.shape - - def test_decomposition_sum(self, sample_input): - kernel_size = 25 - decomp = SeriesDecomposition(kernel_size) - trend, seasonal = decomp(sample_input) - - # Check if trend + seasonal approximately equals input - torch.testing.assert_close(trend + seasonal, sample_input, rtol=1e-4, atol=1e-4) - - -# Test m_lstm Components -class TestMLSTM: - def test_cell_initialization(self, input_size, hidden_size, device): - cell = mLSTMCell(input_size, hidden_size, device=device) - assert cell.input_size == input_size - assert cell.hidden_size == 
hidden_size - - def test_cell_forward(self, batch_size, input_size, hidden_size, device): - cell = mLSTMCell(input_size, hidden_size, device=device) - x = torch.randn(1, batch_size, input_size).to(device) # Add sequence dimension - h_prev = torch.randn(batch_size, hidden_size).to(device) - c_prev = torch.randn(batch_size, hidden_size).to(device) - n_prev = torch.randn(batch_size, hidden_size).to(device) - - h, c, n = cell(x[0], h_prev, c_prev, n_prev) # Use first timestep - assert h.shape == (batch_size, hidden_size) - assert c.shape == (batch_size, hidden_size) - assert n.shape == (batch_size, hidden_size) - - def test_layer_initialization(self, input_size, hidden_size, device): - layer = mLSTMLayer(input_size, hidden_size, num_layers=2, device=device) - assert layer.input_size == input_size - assert layer.hidden_size == hidden_size - assert len(layer.cells) == 2 - - def test_network_forward(self, sample_input, input_size, hidden_size, output_size, device): - network = mLSTMNetwork(input_size, hidden_size, num_layers=2, output_size=output_size, device=device) - # Transpose input to seq_len, batch_size, input_size - x = sample_input.transpose(0, 1) - output, (h, c, n) = network(x) - # Transpose output back to batch_size, seq_len, output_size - output = output.transpose(0, 1) - assert output.shape == (output_size, sample_input.shape[0]) - - -# Test s_lstm Components -class TestSLSTM: - def test_cell_initialization(self, input_size, hidden_size, device): - cell = sLSTMCell(input_size, hidden_size, device=device) - assert cell.input_size == input_size - assert cell.hidden_size == hidden_size - - def test_cell_forward(self, batch_size, input_size, hidden_size, device): - cell = sLSTMCell(input_size, hidden_size, device=device) - x = torch.randn(1, batch_size, input_size).to(device) # Add sequence dimension - h_prev = torch.randn(batch_size, hidden_size).to(device) - c_prev = torch.randn(batch_size, hidden_size).to(device) - - h, c = cell(x[0], h_prev, c_prev) # Use first timestep - assert h.shape == (batch_size, hidden_size) - assert c.shape == (batch_size, hidden_size) - - def test_layer_initialization(self, input_size, hidden_size, device): - layer = sLSTMLayer(input_size, hidden_size, num_layers=2, device=device) - assert layer.input_size == input_size - assert layer.hidden_size == hidden_size - assert len(layer.cells) == 2 - - def test_network_forward(self, sample_input, input_size, hidden_size, output_size, device): - network = sLSTMNetwork(input_size, hidden_size, num_layers=2, output_size=output_size, device=device) - # Transpose input to seq_len, batch_size, input_size - x = sample_input.transpose(0, 1) - output, (h, c) = network(x) - # Transpose output back to batch_size, seq_len, output_size - output = output.transpose(0, 1) - assert output.shape == (output_size, sample_input.shape[0]) - - -# Test x_lstm_time -class TestXLSTMTime: - @pytest.mark.parametrize("xlstm_type", ["mlstm", "slstm"]) - def test_initialization(self, input_size, hidden_size, output_size, xlstm_type, device): - model = xLSTMTime( - input_size=input_size, - hidden_size=hidden_size, - output_size=output_size, - xlstm_type=xlstm_type, - device=device, - ) - assert isinstance(model.decomposition, SeriesDecomposition) - assert isinstance(model.input_linear, nn.Linear) - assert isinstance(model.output_linear, nn.Linear) - - @pytest.mark.parametrize("xlstm_type", ["mlstm", "slstm"]) - def test_forward(self, sample_input, input_size, hidden_size, output_size, xlstm_type, device): - model = xLSTMTime( - 
input_size=input_size, - hidden_size=hidden_size, - output_size=output_size, - xlstm_type=xlstm_type, - device=device, - ) - output, hidden_states = model(sample_input) - # Check output shape is batch_size, seq_len, output_size - assert output.shape == (1, sample_input.shape[0], output_size) - - if xlstm_type == "mlstm": - assert len(hidden_states) == 3 # h, c, n for m_lstm - h, c, n = hidden_states - assert h.shape == (1, sample_input.shape[0], hidden_size) - assert c.shape == (1, sample_input.shape[0], hidden_size) - assert n.shape == (1, sample_input.shape[0], hidden_size) - else: - assert len(hidden_states) == 2 # h, c for s_lstm - h, c = hidden_states - assert torch.stack(h).shape == (1, sample_input.shape[0], hidden_size) - assert torch.stack(c).shape == (1, sample_input.shape[0], hidden_size) - - @pytest.mark.parametrize("xlstm_type", ["mlstm", "slstm"]) - def test_predict(self, sample_input, input_size, hidden_size, output_size, xlstm_type, device): - model = xLSTMTime( - input_size=input_size, - hidden_size=hidden_size, - output_size=output_size, - xlstm_type=xlstm_type, - device=device, - ) - predictions = model.predict(sample_input) - assert predictions.shape == (1, sample_input.shape[0], output_size) - - def test_invalid_xlstm_type(self, input_size, hidden_size, output_size, device): - with pytest.raises(ValueError, match="xlstm_type must be either 'slstm' or 'mlstm'"): - xLSTMTime( - input_size=input_size, - hidden_size=hidden_size, - output_size=output_size, - xlstm_type="invalid_type", - device=device, - ) - - -# Test edge cases and error handling -class TestEdgeCases: - @pytest.mark.parametrize("xlstm_type", ["mlstm", "slstm"]) - def test_single_sequence_length(self, batch_size, input_size, hidden_size, output_size, xlstm_type, device): - model = xLSTMTime( - input_size=input_size, - hidden_size=hidden_size, - output_size=output_size, - xlstm_type=xlstm_type, - device=device, - ) - single_step = torch.randn(batch_size, 1, input_size).to(device) - output, hidden_states = model(single_step) - assert output.shape == (1, batch_size, output_size) - - if xlstm_type == "mlstm": - h, c, n = hidden_states - assert h.shape == (1, batch_size, hidden_size) - assert c.shape == (1, batch_size, hidden_size) - assert n.shape == (1, batch_size, hidden_size) - else: # slstm - h, c = hidden_states - assert torch.stack(h).shape == (1, batch_size, hidden_size) - assert torch.stack(c).shape == (1, batch_size, hidden_size) - - @pytest.mark.parametrize("xlstm_type", ["mlstm", "slstm"]) - def test_input_nan_handling(self, batch_size, input_size, hidden_size, output_size, xlstm_type, device): - """Test model behavior with NaN inputs""" - model = xLSTMTime( - input_size=input_size, - hidden_size=hidden_size, - output_size=output_size, - xlstm_type=xlstm_type, - device=device, - ) - - # Create input with some NaN values - nan_input = torch.randn(batch_size, 24, input_size).to(device) - nan_input[0, 0, 0] = float("nan") # Insert a NaN value - - try: - output, _ = model(nan_input) - # If we reach here, check if output contains NaN - assert torch.isnan(output).any(), "Expected NaN in output with NaN input" - except Exception as e: - # Model should either propagate NaN or raise an exception - assert isinstance(e, (RuntimeError, ValueError)), "Expected RuntimeError or ValueError with NaN input" - - @pytest.mark.parametrize("xlstm_type", ["mlstm", "slstm"]) - def test_numerical_stability(self, batch_size, input_size, hidden_size, output_size, xlstm_type, device): - """Test model behavior with extreme input 
values""" - model = xLSTMTime( - input_size=input_size, - hidden_size=hidden_size, - output_size=output_size, - xlstm_type=xlstm_type, - device=device, - ) - - # Test with very large values - large_input = torch.full((batch_size, 24, input_size), 1e10).to(device) - output_large, _ = model(large_input) - assert not torch.isnan(output_large).any(), "NaN in output with large input values" - assert not torch.isinf(output_large).any(), "Inf in output with large input values" - - # Test with very small values - small_input = torch.full((batch_size, 24, input_size), 1e-10).to(device) - output_small, _ = model(small_input) - assert not torch.isnan(output_small).any(), "NaN in output with small input values" - assert not torch.isinf(output_small).any(), "Inf in output with small input values" From a6b2da98bd6778bca9b6e814a8242aaad73b6942 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Mon, 6 Jan 2025 17:43:08 +0530 Subject: [PATCH 08/15] refactor `__init__.py` --- .../models/x_lstm_time/__init__.py | 122 +-------------- .../models/x_lstm_time/series_decomp.py | 31 ---- .../models/x_lstm_time/x_lstm.py | 148 ++++++++++++++++++ 3 files changed, 150 insertions(+), 151 deletions(-) delete mode 100644 pytorch_forecasting/models/x_lstm_time/series_decomp.py create mode 100644 pytorch_forecasting/models/x_lstm_time/x_lstm.py diff --git a/pytorch_forecasting/models/x_lstm_time/__init__.py b/pytorch_forecasting/models/x_lstm_time/__init__.py index 5c65f36b1..7ebcbd7cc 100644 --- a/pytorch_forecasting/models/x_lstm_time/__init__.py +++ b/pytorch_forecasting/models/x_lstm_time/__init__.py @@ -1,121 +1,3 @@ -from pytorch_forecasting.models.base_model import AutoRegressiveBaseModel -from copy import copy -from typing import Optional, Tuple, Union, Literal, Dict -from pytorch_forecasting.metrics import SMAPE, Metric -import torch -from torch import nn -from pytorch_forecasting.models.x_lstm_time.s_lstm.network import sLSTMNetwork -from pytorch_forecasting.models.x_lstm_time.m_lstm.network import mLSTMNetwork -from pytorch_forecasting.models.x_lstm_time.series_decomp import SeriesDecomposition +from pytorch_forecasting.models.x_lstm_time.x_lstm import xLSTMTime - -class xLSTMTime(AutoRegressiveBaseModel): - - def __init__( - self, - input_size: int, - hidden_size: int, - output_size: int, - xlstm_type: Literal["slstm", "mlstm"] = "slstm", - num_layers: int = 1, - decomposition_kernel: int = 25, - input_projection_size: Optional[int] = None, - dropout: float = 0.1, - loss: Metric = SMAPE(), - device: Optional[torch.device] = None, - **kwargs, - ): - - if "target" in kwargs: - del kwargs["target"] - if "target_lags" in kwargs: - del kwargs["target_lags"] - self.save_hyperparameters() - super().__init__(loss=loss, **kwargs) - - if xlstm_type not in ["slstm", "mlstm"]: - raise ValueError("xlstm_type must be either 'slstm' or 'mlstm'") - - self.xlstm_type = xlstm_type - self._device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu") - self.to(self._device) - - self.decomposition = SeriesDecomposition(decomposition_kernel) - self.batch_norm = nn.BatchNorm1d(hidden_size) - - self.input_projection_size = input_projection_size or hidden_size - self.input_linear = None - - if xlstm_type == "mlstm": - self.lstm = mLSTMNetwork( - input_size=hidden_size, - hidden_size=hidden_size, - num_layers=num_layers, - output_size=hidden_size, - dropout=dropout, - device=self.device, - ) - else: # slstm - self.lstm = sLSTMNetwork( - input_size=hidden_size, - hidden_size=hidden_size, - num_layers=num_layers, - 
output_size=hidden_size, - dropout=dropout, - device=self.device, - ) - - self.output_linear = nn.Linear(hidden_size, output_size) - self.instance_norm = nn.InstanceNorm1d(output_size) - - def forward( - self, - x: Dict[str, torch.Tensor], - hidden_states: Optional[ - Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] - ] = None, - ) -> Dict[str, torch.Tensor]: - encoder_cont = x["encoder_cont"] - batch_size, seq_len, n_features = encoder_cont.shape - - trend, seasonal = self.decomposition(encoder_cont) - - x = torch.cat([trend, seasonal], dim=-1) - concatenated_features = x.shape[-1] - - if self.input_linear is None: - self.input_linear = nn.Linear(concatenated_features, self.input_projection_size).to(self.device) - - x = self.input_linear(x) - - x = x.transpose(1, 2) - x = self.batch_norm(x) - x = x.transpose(1, 2) - - if hidden_states is None: - hidden_states = self.lstm.init_hidden(batch_size) - - x = x.transpose(0, 1) - output, hidden_states = self.lstm(x, *hidden_states) - - if isinstance(output, tuple): - output = output[0] - - if output.dim() == 2: - output = output.unsqueeze(0) - - output = self.output_linear(output) - - output = output.transpose(1, 2) - output = self.instance_norm(output) - output = output.transpose(1, 2) - - output = output[0, ..., : self.hparams.output_size] - return self.to_network_output(prediction=output) - - @classmethod - def from_dataset(cls, dataset, **kwargs): - new_kwargs = copy(kwargs) - new_kwargs.update(cls.deduce_default_output_parameters(dataset, kwargs, SMAPE())) - - return super().from_dataset(dataset, **kwargs) +__all__ = ["xLSTMTime"] diff --git a/pytorch_forecasting/models/x_lstm_time/series_decomp.py b/pytorch_forecasting/models/x_lstm_time/series_decomp.py deleted file mode 100644 index 4268482dd..000000000 --- a/pytorch_forecasting/models/x_lstm_time/series_decomp.py +++ /dev/null @@ -1,31 +0,0 @@ -from typing import Tuple -import torch -from torch import nn - - -class SeriesDecomposition(nn.Module): - """Implements series decomposition using learnable moving averages.""" - - def __init__(self, kernel_size: int): - super(SeriesDecomposition, self).__init__() - self.kernel_size = kernel_size - self.padding = kernel_size // 2 - self.avg_pool = nn.AvgPool1d(kernel_size=kernel_size, stride=1, padding=self.padding) - - def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Decomposes input series into trend and seasonal components. 
- - Args: - x: Input tensor of shape (batch_size, seq_len, n_features) - - Returns: - Tuple of (trend_component, seasonal_component) - """ - batch_size, seq_len, n_features = x.shape - x_reshaped = x.reshape(batch_size * n_features, 1, seq_len) - trend = self.avg_pool(x_reshaped) - trend = trend.reshape(batch_size, seq_len, n_features) - seasonal = x - trend - - return trend, seasonal diff --git a/pytorch_forecasting/models/x_lstm_time/x_lstm.py b/pytorch_forecasting/models/x_lstm_time/x_lstm.py new file mode 100644 index 000000000..5056306ba --- /dev/null +++ b/pytorch_forecasting/models/x_lstm_time/x_lstm.py @@ -0,0 +1,148 @@ +from pytorch_forecasting.models.base_model import AutoRegressiveBaseModel +from copy import copy +from typing import Optional, Tuple, Union, Literal, Dict +from pytorch_forecasting.metrics import SMAPE, Metric +import torch +from torch import nn +from pytorch_forecasting.models.x_lstm_time.s_lstm.network import sLSTMNetwork +from pytorch_forecasting.models.x_lstm_time.m_lstm.network import mLSTMNetwork + + +class SeriesDecomposition(nn.Module): + """Implements series decomposition using learnable moving averages.""" + + def __init__(self, kernel_size: int): + super(SeriesDecomposition, self).__init__() + self.kernel_size = kernel_size + self.padding = kernel_size // 2 + self.avg_pool = nn.AvgPool1d(kernel_size=kernel_size, stride=1, padding=self.padding) + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Decomposes input series into trend and seasonal components. + + Args: + x: Input tensor of shape (batch_size, seq_len, n_features) + + Returns: + Tuple of (trend_component, seasonal_component) + """ + batch_size, seq_len, n_features = x.shape + x_reshaped = x.reshape(batch_size * n_features, 1, seq_len) + trend = self.avg_pool(x_reshaped) + trend = trend.reshape(batch_size, seq_len, n_features) + seasonal = x - trend + + return trend, seasonal + + +class xLSTMTime(AutoRegressiveBaseModel): + + def __init__( + self, + input_size: int, + hidden_size: int, + output_size: int, + xlstm_type: Literal["slstm", "mlstm"] = "slstm", + num_layers: int = 1, + decomposition_kernel: int = 25, + input_projection_size: Optional[int] = None, + dropout: float = 0.1, + loss: Metric = SMAPE(), + device: Optional[torch.device] = None, + **kwargs, + ): + + if "target" in kwargs: + del kwargs["target"] + if "target_lags" in kwargs: + del kwargs["target_lags"] + self.save_hyperparameters() + super().__init__(loss=loss, **kwargs) + + if xlstm_type not in ["slstm", "mlstm"]: + raise ValueError("xlstm_type must be either 'slstm' or 'mlstm'") + + self.xlstm_type = xlstm_type + self._device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.to(self._device) + + self.decomposition = SeriesDecomposition(decomposition_kernel) + self.batch_norm = nn.BatchNorm1d(hidden_size) + + self.input_projection_size = input_projection_size or hidden_size + self.input_linear = None + + if xlstm_type == "mlstm": + self.lstm = mLSTMNetwork( + input_size=hidden_size, + hidden_size=hidden_size, + num_layers=num_layers, + output_size=hidden_size, + dropout=dropout, + device=self.device, + ) + else: # slstm + self.lstm = sLSTMNetwork( + input_size=hidden_size, + hidden_size=hidden_size, + num_layers=num_layers, + output_size=hidden_size, + dropout=dropout, + device=self.device, + ) + + self.output_linear = nn.Linear(hidden_size, output_size) + self.instance_norm = nn.InstanceNorm1d(output_size) + + def forward( + self, + x: Dict[str, torch.Tensor], + 
hidden_states: Optional[ + Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] + ] = None, + ) -> Dict[str, torch.Tensor]: + encoder_cont = x["encoder_cont"] + batch_size, seq_len, n_features = encoder_cont.shape + + trend, seasonal = self.decomposition(encoder_cont) + + x = torch.cat([trend, seasonal], dim=-1) + concatenated_features = x.shape[-1] + + if self.input_linear is None: + self.input_linear = nn.Linear(concatenated_features, self.input_projection_size).to(self.device) + + x = self.input_linear(x) + + x = x.transpose(1, 2) + x = self.batch_norm(x) + x = x.transpose(1, 2) + + if hidden_states is None: + hidden_states = self.lstm.init_hidden(batch_size) + + x = x.transpose(0, 1) + output, hidden_states = self.lstm(x, *hidden_states) + + if isinstance(output, tuple): + output = output[0] + + if output.dim() == 2: + output = output.unsqueeze(0) + + output = self.output_linear(output) + + output = output.transpose(1, 2) + output = self.instance_norm(output) + output = output.transpose(1, 2) + + output = output[0, ..., : self.hparams.output_size] + return self.to_network_output(prediction=output) + + @classmethod + def from_dataset(cls, dataset, **kwargs): + new_kwargs = copy(kwargs) + new_kwargs.update(cls.deduce_default_output_parameters(dataset, kwargs, SMAPE())) + + return super().from_dataset(dataset, **kwargs) From f67509a5f607fa6794a650ccdb4e5239cc0b2078 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Mon, 6 Jan 2025 18:10:18 +0530 Subject: [PATCH 09/15] linting --- .../models/x_lstm_time/m_lstm/cell.py | 32 +++++++++++++---- .../models/x_lstm_time/m_lstm/layer.py | 22 ++++++++++-- .../models/x_lstm_time/m_lstm/network.py | 15 ++++++-- .../models/x_lstm_time/s_lstm/cell.py | 13 +++++-- .../models/x_lstm_time/s_lstm/layer.py | 36 +++++++++++++++---- .../models/x_lstm_time/s_lstm/network.py | 29 ++++++++++++--- .../models/x_lstm_time/x_lstm.py | 31 +++++++++++----- 7 files changed, 143 insertions(+), 35 deletions(-) diff --git a/pytorch_forecasting/models/x_lstm_time/m_lstm/cell.py b/pytorch_forecasting/models/x_lstm_time/m_lstm/cell.py index f0178f4af..2dfa476d9 100644 --- a/pytorch_forecasting/models/x_lstm_time/m_lstm/cell.py +++ b/pytorch_forecasting/models/x_lstm_time/m_lstm/cell.py @@ -1,16 +1,23 @@ +import math + import torch import torch.nn as nn -import math class mLSTMCell(nn.Module): - def __init__(self, input_size, hidden_size, dropout=0.2, layer_norm=True, device=None): + def __init__( + self, input_size, hidden_size, dropout=0.2, layer_norm=True, device=None + ): super(mLSTMCell, self).__init__() self.input_size = input_size self.hidden_size = hidden_size self.layer_norm = layer_norm - self.device = device if device is not None else torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.device = ( + device + if device is not None + else torch.device("cuda" if torch.cuda.is_available() else "cpu") + ) self.Wq = nn.Linear(input_size, hidden_size) self.Wk = nn.Linear(input_size, hidden_size) @@ -56,10 +63,21 @@ def forward(self, x, h_prev, c_prev, n_prev): n_prev = n_prev.to(self.device) batch_size = x.size(0) - assert x.dim() == 2, f"Input should be 2D (batch_size, input_size), got {x.dim()}D" - assert h_prev.size() == (batch_size, self.hidden_size), f"h_prev shape mismatch: {h_prev.size()}" - assert c_prev.size() == (batch_size, self.hidden_size), f"c_prev shape mismatch: {c_prev.size()}" - assert n_prev.size() == (batch_size, self.hidden_size), f"n_prev shape mismatch: {n_prev.size()}" + assert ( + x.dim() == 2 + ), 
f"Input should be 2D (batch_size, input_size), got {x.dim()}D" + assert h_prev.size() == ( + batch_size, + self.hidden_size, + ), f"h_prev shape mismatch: {h_prev.size()}" + assert c_prev.size() == ( + batch_size, + self.hidden_size, + ), f"c_prev shape mismatch: {c_prev.size()}" + assert n_prev.size() == ( + batch_size, + self.hidden_size, + ), f"n_prev shape mismatch: {n_prev.size()}" x = self.dropout(x) h_prev = self.dropout(h_prev) diff --git a/pytorch_forecasting/models/x_lstm_time/m_lstm/layer.py b/pytorch_forecasting/models/x_lstm_time/m_lstm/layer.py index d3fe98ded..edb6df6ef 100644 --- a/pytorch_forecasting/models/x_lstm_time/m_lstm/layer.py +++ b/pytorch_forecasting/models/x_lstm_time/m_lstm/layer.py @@ -1,11 +1,19 @@ import torch import torch.nn as nn + from pytorch_forecasting.models.x_lstm_time.m_lstm.cell import mLSTMCell class mLSTMLayer(nn.Module): def __init__( - self, input_size, hidden_size, num_layers, dropout=0.2, layer_norm=True, residual_conn=True, device=None + self, + input_size, + hidden_size, + num_layers, + dropout=0.2, + layer_norm=True, + residual_conn=True, + device=None, ): super(mLSTMLayer, self).__init__() self.input_size = input_size @@ -13,13 +21,21 @@ def __init__( self.num_layers = num_layers self.layer_norm = layer_norm self.residual_conn = residual_conn - self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.device = device or torch.device( + "cuda" if torch.cuda.is_available() else "cpu" + ) self.dropout = nn.Dropout(dropout).to(self.device) self.cells = nn.ModuleList( [ - mLSTMCell(input_size if i == 0 else hidden_size, hidden_size, dropout, layer_norm, self.device) + mLSTMCell( + input_size if i == 0 else hidden_size, + hidden_size, + dropout, + layer_norm, + self.device, + ) for i in range(num_layers) ] ) diff --git a/pytorch_forecasting/models/x_lstm_time/m_lstm/network.py b/pytorch_forecasting/models/x_lstm_time/m_lstm/network.py index bf89192d4..d6c989b3d 100644 --- a/pytorch_forecasting/models/x_lstm_time/m_lstm/network.py +++ b/pytorch_forecasting/models/x_lstm_time/m_lstm/network.py @@ -1,5 +1,6 @@ -import torch.nn as nn import torch +import torch.nn as nn + from pytorch_forecasting.models.x_lstm_time.m_lstm.layer import mLSTMLayer @@ -16,10 +17,18 @@ def __init__( device=None, ): super(mLSTMNetwork, self).__init__() - self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.device = device or torch.device( + "cuda" if torch.cuda.is_available() else "cpu" + ) self.mlstm_layer = mLSTMLayer( - input_size, hidden_size, num_layers, dropout, use_layer_norm, use_residual, self.device + input_size, + hidden_size, + num_layers, + dropout, + use_layer_norm, + use_residual, + self.device, ) self.fc = nn.Linear(hidden_size, output_size) diff --git a/pytorch_forecasting/models/x_lstm_time/s_lstm/cell.py b/pytorch_forecasting/models/x_lstm_time/s_lstm/cell.py index 6ff0fdf60..e4c63545e 100644 --- a/pytorch_forecasting/models/x_lstm_time/s_lstm/cell.py +++ b/pytorch_forecasting/models/x_lstm_time/s_lstm/cell.py @@ -1,12 +1,15 @@ +import math + import torch import torch.nn as nn -import math class sLSTMCell(nn.Module): """Stabilized LSTM Cell""" - def __init__(self, input_size, hidden_size, dropout=0.0, use_layer_norm=True, device=None): + def __init__( + self, input_size, hidden_size, dropout=0.0, use_layer_norm=True, device=None + ): super(sLSTMCell, self).__init__() self.input_size = input_size self.hidden_size = hidden_size @@ -14,7 +17,11 @@ def __init__(self, input_size, 
hidden_size, dropout=0.0, use_layer_norm=True, de self.use_layer_norm = use_layer_norm self.eps = 1e-6 - self.device = device if device is not None else torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.device = ( + device + if device is not None + else torch.device("cuda" if torch.cuda.is_available() else "cpu") + ) self.input_weights = nn.Linear(input_size, 4 * hidden_size).to(self.device) self.hidden_weights = nn.Linear(hidden_size, 4 * hidden_size).to(self.device) diff --git a/pytorch_forecasting/models/x_lstm_time/s_lstm/layer.py b/pytorch_forecasting/models/x_lstm_time/s_lstm/layer.py index 305b3328a..a0393223e 100644 --- a/pytorch_forecasting/models/x_lstm_time/s_lstm/layer.py +++ b/pytorch_forecasting/models/x_lstm_time/s_lstm/layer.py @@ -1,5 +1,6 @@ import torch import torch.nn as nn + from pytorch_forecasting.models.x_lstm_time.s_lstm.cell import sLSTMCell @@ -9,7 +10,14 @@ class sLSTMLayer(nn.Module): """ def __init__( - self, input_size, hidden_size, num_layers=1, dropout=0.0, use_layer_norm=True, use_residual=True, device=None + self, + input_size, + hidden_size, + num_layers=1, + dropout=0.0, + use_layer_norm=True, + use_residual=True, + device=None, ): super(sLSTMLayer, self).__init__() self.input_size = input_size @@ -18,11 +26,17 @@ def __init__( self.dropout = dropout self.use_layer_norm = use_layer_norm self.use_residual = use_residual - self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.device = ( + device + if device + else torch.device("cuda" if torch.cuda.is_available() else "cpu") + ) self.input_projection = None if self.use_residual and input_size != hidden_size: - self.input_projection = nn.Linear(input_size, hidden_size, bias=False).to(self.device) + self.input_projection = nn.Linear(input_size, hidden_size, bias=False).to( + self.device + ) self.cells = nn.ModuleList( [ @@ -75,7 +89,11 @@ def forward(self, x, h=None, c=None): if layer == 0 and self.input_projection is not None: residual = self.input_projection(layer_input) else: - residual = layer_input if layer_input.size(-1) == self.hidden_size else 0 + residual = ( + layer_input + if (layer_input.size(-1) == self.hidden_size) + else 0 + ) h[layer] = h[layer] + residual if self.use_layer_norm: @@ -95,6 +113,12 @@ def forward(self, x, h=None, c=None): def init_hidden(self, batch_size): """Initialize hidden and cell states for each layer.""" return ( - [torch.zeros(batch_size, self.hidden_size, device=self.device) for _ in range(self.num_layers)], - [torch.zeros(batch_size, self.hidden_size, device=self.device) for _ in range(self.num_layers)], + [ + torch.zeros(batch_size, self.hidden_size, device=self.device) + for _ in range(self.num_layers) + ], + [ + torch.zeros(batch_size, self.hidden_size, device=self.device) + for _ in range(self.num_layers) + ], ) diff --git a/pytorch_forecasting/models/x_lstm_time/s_lstm/network.py b/pytorch_forecasting/models/x_lstm_time/s_lstm/network.py index 1d46f5cff..d5846a65d 100644 --- a/pytorch_forecasting/models/x_lstm_time/s_lstm/network.py +++ b/pytorch_forecasting/models/x_lstm_time/s_lstm/network.py @@ -1,5 +1,6 @@ -import torch.nn as nn import torch +import torch.nn as nn + from pytorch_forecasting.models.x_lstm_time.s_lstm.layer import sLSTMLayer @@ -8,16 +9,36 @@ class sLSTMNetwork(nn.Module): Stabilized LSTM Network with multiple s_lstm layers. 
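+
+    A minimal usage sketch (sizes are illustrative; the expected input layout
+    is ``(seq_len, batch_size, input_size)``, as in the ``forward`` docstring)::
+
+        net = sLSTMNetwork(input_size=10, hidden_size=64, num_layers=2, output_size=5)
+        x = torch.randn(24, 32, 10)
+        output, (h, c) = net(x)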
""" - def __init__(self, input_size, hidden_size, num_layers, output_size, dropout=0.0, use_layer_norm=True, device=None): + def __init__( + self, + input_size, + hidden_size, + num_layers, + output_size, + dropout=0.0, + use_layer_norm=True, + device=None, + ): super(sLSTMNetwork, self).__init__() self.input_size = input_size self.hidden_size = hidden_size self.num_layers = num_layers self.output_size = output_size self.dropout = dropout - self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.device = ( + device + if device + else torch.device("cuda" if torch.cuda.is_available() else "cpu") + ) - self.slstm_layer = sLSTMLayer(input_size, hidden_size, num_layers, dropout, use_layer_norm, device=self.device) + self.slstm_layer = sLSTMLayer( + input_size, + hidden_size, + num_layers, + dropout, + use_layer_norm, + device=self.device, + ) self.fc = nn.Linear(hidden_size, output_size).to(self.device) def forward(self, x, h=None, c=None): diff --git a/pytorch_forecasting/models/x_lstm_time/x_lstm.py b/pytorch_forecasting/models/x_lstm_time/x_lstm.py index 5056306ba..e88af3920 100644 --- a/pytorch_forecasting/models/x_lstm_time/x_lstm.py +++ b/pytorch_forecasting/models/x_lstm_time/x_lstm.py @@ -1,11 +1,13 @@ -from pytorch_forecasting.models.base_model import AutoRegressiveBaseModel from copy import copy -from typing import Optional, Tuple, Union, Literal, Dict -from pytorch_forecasting.metrics import SMAPE, Metric +from typing import Dict, Literal, Optional, Tuple, Union + import torch from torch import nn -from pytorch_forecasting.models.x_lstm_time.s_lstm.network import sLSTMNetwork + +from pytorch_forecasting.metrics import SMAPE, Metric +from pytorch_forecasting.models.base_model import AutoRegressiveBaseModel from pytorch_forecasting.models.x_lstm_time.m_lstm.network import mLSTMNetwork +from pytorch_forecasting.models.x_lstm_time.s_lstm.network import sLSTMNetwork class SeriesDecomposition(nn.Module): @@ -15,7 +17,9 @@ def __init__(self, kernel_size: int): super(SeriesDecomposition, self).__init__() self.kernel_size = kernel_size self.padding = kernel_size // 2 - self.avg_pool = nn.AvgPool1d(kernel_size=kernel_size, stride=1, padding=self.padding) + self.avg_pool = nn.AvgPool1d( + kernel_size=kernel_size, stride=1, padding=self.padding + ) def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ @@ -64,7 +68,9 @@ def __init__( raise ValueError("xlstm_type must be either 'slstm' or 'mlstm'") self.xlstm_type = xlstm_type - self._device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu") + self._device = device or torch.device( + "cuda" if torch.cuda.is_available() else "cpu" + ) self.to(self._device) self.decomposition = SeriesDecomposition(decomposition_kernel) @@ -99,7 +105,10 @@ def forward( self, x: Dict[str, torch.Tensor], hidden_states: Optional[ - Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] + Union[ + Tuple[torch.Tensor, torch.Tensor], + Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + ] ] = None, ) -> Dict[str, torch.Tensor]: encoder_cont = x["encoder_cont"] @@ -111,7 +120,9 @@ def forward( concatenated_features = x.shape[-1] if self.input_linear is None: - self.input_linear = nn.Linear(concatenated_features, self.input_projection_size).to(self.device) + self.input_linear = nn.Linear( + concatenated_features, self.input_projection_size + ).to(self.device) x = self.input_linear(x) @@ -143,6 +154,8 @@ def forward( @classmethod def from_dataset(cls, 
dataset, **kwargs): new_kwargs = copy(kwargs) - new_kwargs.update(cls.deduce_default_output_parameters(dataset, kwargs, SMAPE())) + new_kwargs.update( + cls.deduce_default_output_parameters(dataset, kwargs, SMAPE()) + ) return super().from_dataset(dataset, **kwargs) From 46a9e74bd401a4b9a206e256695db4e97068202d Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Mon, 6 Jan 2025 18:12:25 +0530 Subject: [PATCH 10/15] Update layer.py --- pytorch_forecasting/models/x_lstm_time/s_lstm/layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_forecasting/models/x_lstm_time/s_lstm/layer.py b/pytorch_forecasting/models/x_lstm_time/s_lstm/layer.py index a0393223e..d32217916 100644 --- a/pytorch_forecasting/models/x_lstm_time/s_lstm/layer.py +++ b/pytorch_forecasting/models/x_lstm_time/s_lstm/layer.py @@ -6,7 +6,7 @@ class sLSTMLayer(nn.Module): """ - Enhanced s_lstm Layer that supports multiple s_lstm cells across timesteps and residual connections. + Enhanced s_lstm Layer that supports multiple s_lstm cells. """ def __init__( From 7e7d91561c83e647543cc688d750d03825f7c041 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Mon, 6 Jan 2025 19:24:51 +0530 Subject: [PATCH 11/15] docs --- .../models/x_lstm_time/m_lstm/cell.py | 63 +++++++++++++++++++ .../models/x_lstm_time/m_lstm/layer.py | 59 ++++++++++++++++- .../models/x_lstm_time/m_lstm/network.py | 60 +++++++++++++++++- .../models/x_lstm_time/s_lstm/cell.py | 63 ++++++++++++++++++- .../models/x_lstm_time/s_lstm/layer.py | 61 ++++++++++++++---- .../models/x_lstm_time/s_lstm/network.py | 57 ++++++++++++++--- .../models/x_lstm_time/x_lstm.py | 11 ++-- 7 files changed, 341 insertions(+), 33 deletions(-) diff --git a/pytorch_forecasting/models/x_lstm_time/m_lstm/cell.py b/pytorch_forecasting/models/x_lstm_time/m_lstm/cell.py index 2dfa476d9..587cc5c36 100644 --- a/pytorch_forecasting/models/x_lstm_time/m_lstm/cell.py +++ b/pytorch_forecasting/models/x_lstm_time/m_lstm/cell.py @@ -5,6 +5,46 @@ class mLSTMCell(nn.Module): + """Implements the Matrix Long Short-Term Memory (mLSTM) Cell. + + Implements the mLSTM algorithm as described in the paper: + (https://arxiv.org/pdf/2407.10240). + + Parameters + ---------- + input_size : int + Size of the input feature vector. + hidden_size : int + Number of hidden units in the LSTM cell. + dropout : float, optional + Dropout rate applied to inputs and hidden states, by default 0.2. + layer_norm : bool, optional + If True, apply Layer Normalization to gates and interactions, by default True. + device : torch.device, optional + Device for computation (CPU or CUDA), by default uses GPU if available. + + + Attributes + ---------- + Wq : nn.Linear + Linear layer for computing the query vector. + Wk : nn.Linear + Linear layer for computing the key vector. + Wv : nn.Linear + Linear layer for computing the value vector. + Wi : nn.Linear + Linear layer for the input gate. + Wf : nn.Linear + Linear layer for the forget gate. + Wo : nn.Linear + Linear layer for the output gate. + dropout : nn.Dropout + Dropout regularization layer. + ln_q, ln_k, ln_v, ln_i, ln_f, ln_o : nn.LayerNorm + Optional layer normalization layers for respective computations. + device : torch.device + Device used for computation. + """ def __init__( self, input_size, hidden_size, dropout=0.2, layer_norm=True, device=None ): @@ -56,6 +96,29 @@ def __init__( self.tanh = nn.Tanh() def forward(self, x, h_prev, c_prev, n_prev): + """Compute the next hidden, cell, and normalized states in the mLSTM cell. 
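+
+        The key/value outer product is accumulated into the cell state through
+        the input and forget gates, and a separate normalizer state rescales
+        the hidden-state readout for numerical stability, following the
+        matrix-memory recurrence of the paper linked in the class docstring.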
+ + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (batch_size, input_size). + h_prev : torch.Tensor + Previous hidden state of shape (batch_size, hidden_size). + c_prev : torch.Tensor + Previous cell state of shape (batch_size, hidden_size). + n_prev : torch.Tensor + Previous normalized state of shape (batch_size, hidden_size). + + Returns + ------- + tuple of torch.Tensor: + h : torch.Tensor + Current hidden state of shape (batch_size, hidden_size). + c : torch.Tensor + Current cell state of shape (batch_size, hidden_size). + n : torch.Tensor + Current normalized state of shape (batch_size, hidden_size). + """ x = x.to(self.device) h_prev = h_prev.to(self.device) diff --git a/pytorch_forecasting/models/x_lstm_time/m_lstm/layer.py b/pytorch_forecasting/models/x_lstm_time/m_lstm/layer.py index edb6df6ef..9f80f6d70 100644 --- a/pytorch_forecasting/models/x_lstm_time/m_lstm/layer.py +++ b/pytorch_forecasting/models/x_lstm_time/m_lstm/layer.py @@ -5,6 +5,37 @@ class mLSTMLayer(nn.Module): + """Implements a mLSTM (Matrix LSTM) layer. + + This class stacks multiple mLSTM cells to form a deep recurrent layer. + It supports residual connections, layer normalization, and dropout. + + Parameters + ---------- + input_size : int + The number of features in the input. + hidden_size : int + The number of features in the hidden state. + num_layers : int + The number of mLSTM layers to stack. + dropout : float, optional + Dropout probability applied to the inputs and intermediate layers, + by default 0.2. + layer_norm : bool, optional + Whether to use layer normalization in each mLSTM cell, by default True. + residual_conn : bool, optional + Whether to enable residual connections between layers, by default True. + device : torch.device, optional + The device to run the computations on + + Attributes + ---------- + cells : nn.ModuleList + A list containing all mLSTM cells in the layer. + dropout : nn.Dropout + Dropout layer applied between layers. + + """ def __init__( self, input_size, @@ -55,8 +86,32 @@ def init_hidden(self, batch_size): ) def forward(self, x, h=None, c=None, n=None): - """ - Forward pass for the m_lstm layer. + """Forward pass through the mLSTM layer. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (batch_size, seq_len, input_size). + h : torch.Tensor, optional + Initial hidden states for all layers, shape (num_layers, batch_size, hidden_size). + If None, initialized to zeros, by default None. + c : torch.Tensor, optional + Initial cell states for all layers, shape (num_layers, batch_size, hidden_size). + If None, initialized to zeros, by default None. + n : torch.Tensor, optional + Initial normalized states for all layers, shape (num_layers, batch_size, hidden_size). + If None, initialized to zeros, by default None. + + Returns + ------- + tuple + output : torch.Tensor + Final output tensor from the last layer, shape (batch_size, seq_len, hidden_size). + (h, c, n) : tuple of torch.Tensor + Final hidden, cell, and normalized states for all layers: + - h : torch.Tensor, shape (num_layers, batch_size, hidden_size). + - c : torch.Tensor, shape (num_layers, batch_size, hidden_size). + - n : torch.Tensor, shape (num_layers, batch_size, hidden_size). 
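+
+        Examples
+        --------
+        A minimal usage sketch (sizes are illustrative)::
+
+            layer = mLSTMLayer(input_size=10, hidden_size=64, num_layers=2)
+            x = torch.randn(32, 24, 10)  # (batch_size, seq_len, input_size)
+            output, (h, c, n) = layer(x)
+            # output: (32, 24, 64); h, c and n: (2, 32, 64) each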
""" x = x.to(self.device).transpose(0, 1) diff --git a/pytorch_forecasting/models/x_lstm_time/m_lstm/network.py b/pytorch_forecasting/models/x_lstm_time/m_lstm/network.py index d6c989b3d..2c4f53d74 100644 --- a/pytorch_forecasting/models/x_lstm_time/m_lstm/network.py +++ b/pytorch_forecasting/models/x_lstm_time/m_lstm/network.py @@ -5,6 +5,38 @@ class mLSTMNetwork(nn.Module): + """Implements the mLSTM Network, a complete model based on stacked mLSTM layers. + + This network combines stacked mLSTM layers and a fully connected output layer. + + Parameters + ---------- + input_size : int + Number of features in the input. + hidden_size : int + Number of features in the hidden state of each mLSTM layer. + num_layers : int + Number of mLSTM layers to stack. + output_size : int + Number of features in the output. + dropout : float, optional + Dropout probability for the mLSTM layers, by default 0.0. + use_layer_norm : bool, optional + Whether to use layer normalization in the mLSTM layers, by default True. + use_residual : bool, optional + Whether to use residual connections in the mLSTM layers, by default True. + device : torch.device, optional + Device to run the computations on + + Attributes + ---------- + mlstm_layer : mLSTMLayer + Stacked mLSTM layers used for processing input sequences. + fc : nn.Linear + Fully connected layer to generate final output. + + + """ def __init__( self, input_size, @@ -33,8 +65,32 @@ def __init__( self.fc = nn.Linear(hidden_size, output_size) def forward(self, x, h=None, c=None, n=None): - """ - Forward pass through the m_lstm network. + """Forward pass through the mLSTM Network. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (batch_size, seq_len, input_size). + h : torch.Tensor, optional + Initial hidden states for all layers, shape (num_layers, batch_size, hidden_size). + If None, initialized to zeros, by default None. + c : torch.Tensor, optional + Initial cell states for all layers, shape (num_layers, batch_size, hidden_size). + If None, initialized to zeros, by default None. + n : torch.Tensor, optional + Initial normalized states for all layers, shape (num_layers, batch_size, hidden_size). + If None, initialized to zeros, by default None. + + Returns + ------- + tuple + output : torch.Tensor + Final output tensor from the fully connected layer, shape (batch_size, output_size). + (h, c, n) : tuple of torch.Tensor + Final hidden, cell, and normalized states for all layers: + - h : torch.Tensor, shape (num_layers, batch_size, hidden_size). + - c : torch.Tensor, shape (num_layers, batch_size, hidden_size). + - n : torch.Tensor, shape (num_layers, batch_size, hidden_size). """ output, (h, c, n) = self.mlstm_layer(x, h, c, n) diff --git a/pytorch_forecasting/models/x_lstm_time/s_lstm/cell.py b/pytorch_forecasting/models/x_lstm_time/s_lstm/cell.py index e4c63545e..b42aec5fd 100644 --- a/pytorch_forecasting/models/x_lstm_time/s_lstm/cell.py +++ b/pytorch_forecasting/models/x_lstm_time/s_lstm/cell.py @@ -5,7 +5,49 @@ class sLSTMCell(nn.Module): - """Stabilized LSTM Cell""" + """Implements the stabilized LSTM cell + + Implements the sLSTM algorithm as described in the paper: + (https://arxiv.org/pdf/2407.10240). + + Parameters + ---------- + input_size : int + Number of input features for the cell. + hidden_size : int + Number of features in the hidden state of the cell. + dropout : float, optional + Dropout probability for the cell's input and hidden state, by default 0.0. 
+    use_layer_norm : bool, optional
+        Whether to use layer normalization for the cell's internal computations, by default True.
+    device : torch.device, optional
+        The device to run the computations on
+
+    Attributes
+    ----------
+    input_weights : nn.Linear
+        Linear layer for processing input features into gate computations.
+    hidden_weights : nn.Linear
+        Linear layer for processing hidden state features into gate computations.
+    ln_cell : nn.LayerNorm
+        Layer normalization for the cell state, applied if use_layer_norm is True.
+    ln_hidden : nn.LayerNorm
+        Layer normalization for the output hidden state, applied if use_layer_norm is True.
+    ln_input : nn.LayerNorm
+        Layer normalization for input gates, applied if use_layer_norm is True.
+    ln_hidden_update : nn.LayerNorm
+        Layer normalization for hidden state gates, applied if use_layer_norm is True.
+    dropout_layer : nn.Dropout
+        Dropout layer applied to inputs and hidden states.
+    grad_clip : float
+        Gradient clipping threshold to improve training stability.
+    eps : float
+        Small constant for numerical stability in calculations.
+    tanh : nn.Tanh
+        Tanh activation function.
+    sigmoid : nn.Sigmoid
+        Sigmoid activation function.
+    """
 
     def __init__(
         self, input_size, hidden_size, dropout=0.0, use_layer_norm=True, device=None
@@ -57,7 +99,24 @@ def normalized_exp_gate(self, pre_gate):
         return exp_val / normalizer
 
     def forward(self, x, h_prev, c_prev):
-        """Forward pass with stabilized exponential gating"""
+        """Forward pass with stabilized exponential gating.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor of shape (batch_size, input_size).
+        h_prev : torch.Tensor
+            Previous hidden state tensor.
+        c_prev : torch.Tensor
+            Previous cell state tensor.
+
+        Returns
+        -------
+        h : torch.Tensor
+            Updated hidden state tensor.
+        c : torch.Tensor
+            Updated cell state tensor.
+        """
         x = x.to(self.device)
         h_prev = h_prev.to(self.device)
         c_prev = c_prev.to(self.device)
diff --git a/pytorch_forecasting/models/x_lstm_time/s_lstm/layer.py b/pytorch_forecasting/models/x_lstm_time/s_lstm/layer.py
index d32217916..0c842966d 100644
--- a/pytorch_forecasting/models/x_lstm_time/s_lstm/layer.py
+++ b/pytorch_forecasting/models/x_lstm_time/s_lstm/layer.py
@@ -5,8 +5,47 @@
 
 class sLSTMLayer(nn.Module):
-    """
-    Enhanced s_lstm Layer that supports multiple s_lstm cells.
+    """Implements the sLSTM Layer, which consists of multiple stacked sLSTM cells.
+
+    This layer is designed for sequence modeling tasks, supporting multiple layers
+    with optional residual connections and layer normalization.
+
+    Parameters
+    ----------
+    input_size : int
+        Number of features in the input.
+    hidden_size : int
+        Number of features in the hidden state of each sLSTM cell.
+    num_layers : int, optional
+        Number of stacked sLSTM layers, by default 1.
+    dropout : float, optional
+        Dropout probability for the input of each sLSTM cell, by default 0.0.
+    use_layer_norm : bool, optional
+        Whether to use layer normalization for each sLSTM cell, by default True.
+    use_residual : bool, optional
+        Whether to use residual connections in each sLSTM layer, by default True.
+    device : torch.device, optional
+        The device to run the computations on
+
+    Attributes
+    ----------
+    cells : nn.ModuleList
+        List of sLSTMCell objects, one for each layer.
+    input_projection : nn.Linear or None
+        Linear layer for projecting input to match hidden state size,
+        used when residual connections are enabled.
+    layer_norm_layers : nn.ModuleList
+        List of LayerNorm layers, one for each sLSTM layer (if use_layer_norm is True).
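+
+    Examples
+    --------
+    A minimal usage sketch; values are illustrative, and the input follows
+    the (seq_len, batch_size, input_size) layout expected by forward:
+
+    >>> import torch
+    >>> layer = sLSTMLayer(input_size=8, hidden_size=16, num_layers=2)
+    >>> x = torch.randn(10, 4, 8)  # (seq_len, batch_size, input_size)
+    >>> output, (h, c) = layer(x)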
""" def __init__( @@ -57,15 +86,25 @@ def __init__( ) def forward(self, x, h=None, c=None): - """ - Forward pass through the s_lstm layer for each time step in sequence. - Args: - x: input tensor (seq_len, batch_size, input_size) - h: initial hidden states (num_layers, batch_size, hidden_size) - c: initial cell states (num_layers, batch_size, hidden_size) - Returns: - output: tensor of hidden states (seq_len, batch_size, hidden_size) - (h, c): final hidden and cell states + """Forward pass through the sLSTM Layer. + + Parameters + ---------- + x : torch.Tensor + The number of features in the input. + h : list of torch.Tensor, optional + Initial hidden states for each layer. + If None, hidden states are initialized to zeros. + c : list of torch.Tensor, optional + Initial cell states for each layer. + If None, cell states are initialized to zeros. + + Returns + ------- + output : torch.Tensor + Tensor containing hidden states for each time step. + (h, c) : tuple of lists + Final hidden and cell states for each layer. """ seq_len, batch_size, _ = x.size() diff --git a/pytorch_forecasting/models/x_lstm_time/s_lstm/network.py b/pytorch_forecasting/models/x_lstm_time/s_lstm/network.py index d5846a65d..60a59b1e8 100644 --- a/pytorch_forecasting/models/x_lstm_time/s_lstm/network.py +++ b/pytorch_forecasting/models/x_lstm_time/s_lstm/network.py @@ -5,8 +5,34 @@ class sLSTMNetwork(nn.Module): - """ - Stabilized LSTM Network with multiple s_lstm layers. + """ Implements the Stabilized LSTM Network with multiple sLSTM layers. + + This network combines sLSTM layers with a fully connected output layer for + prediction. + + Parameters + ---------- + input_size : int + Number of features in the input. + hidden_size : int + Number of features in the hidden state of each sLSTM layer. + num_layers : int + Number of stacked sLSTM layers in the network. + output_size : int + Number of features in the output prediction. + dropout : float, optional + Dropout probability for the input of each sLSTM layer, by default 0.0. + use_layer_norm : bool, optional + Whether to use layer normalization in each sLSTM layer, by default True. + device : torch.device, optional + Device to run the computations on + + Attributes + ---------- + slstm_layer : sLSTMLayer + Stacked sLSTM layers used for processing input sequences. + fc : nn.Linear + Fully connected layer to generate the final output predictions. """ def __init__( @@ -43,14 +69,25 @@ def __init__( def forward(self, x, h=None, c=None): """ - Forward pass through the s_lstm network. - Args: - x: input tensor (seq_len, batch_size, input_size) - h: initial hidden states (num_layers, batch_size, hidden_size) - c: initial cell states (num_layers, batch_size, hidden_size) - Returns: - output: tensor of output predictions (seq_len, batch_size, output_size) - (h, c): final hidden and cell states + Forward pass through the sLSTM network. + + Parameters + ---------- + x : torch.Tensor + The number of features in the input. + h : list of torch.Tensor, optional + Initial hidden states for each layer. + If None, hidden states are initialized to zeros. + c : list of torch.Tensor, optional + Initial cell states for each layer. + If None, cell states are initialized to zeros. + + Returns + ------- + output : torch.Tensor + Tensor containing the final output predictions. + (h, c) : tuple of lists + Final hidden and cell states for each layer. 
""" output, (h, c) = self.slstm_layer(x, h, c) output = self.fc(output[-1]) diff --git a/pytorch_forecasting/models/x_lstm_time/x_lstm.py b/pytorch_forecasting/models/x_lstm_time/x_lstm.py index e88af3920..e8444b3c2 100644 --- a/pytorch_forecasting/models/x_lstm_time/x_lstm.py +++ b/pytorch_forecasting/models/x_lstm_time/x_lstm.py @@ -77,7 +77,11 @@ def __init__( self.batch_norm = nn.BatchNorm1d(hidden_size) self.input_projection_size = input_projection_size or hidden_size - self.input_linear = None + + self.input_linear = nn.Linear( + input_size * 2, + self.input_projection_size + ).to(self.device) if xlstm_type == "mlstm": self.lstm = mLSTMNetwork( @@ -117,12 +121,7 @@ def forward( trend, seasonal = self.decomposition(encoder_cont) x = torch.cat([trend, seasonal], dim=-1) - concatenated_features = x.shape[-1] - if self.input_linear is None: - self.input_linear = nn.Linear( - concatenated_features, self.input_projection_size - ).to(self.device) x = self.input_linear(x) From 31cd4de383b0fb96c597abc75e9a3bfa251a3574 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Mon, 6 Jan 2025 19:34:23 +0530 Subject: [PATCH 12/15] linting --- .../models/x_lstm_time/m_lstm/cell.py | 15 ++++++++------- .../models/x_lstm_time/m_lstm/layer.py | 17 +++++++++-------- .../models/x_lstm_time/m_lstm/network.py | 17 +++++++++-------- .../models/x_lstm_time/s_lstm/cell.py | 6 ++++-- .../models/x_lstm_time/s_lstm/network.py | 2 +- .../models/x_lstm_time/x_lstm.py | 8 +++----- 6 files changed, 34 insertions(+), 31 deletions(-) diff --git a/pytorch_forecasting/models/x_lstm_time/m_lstm/cell.py b/pytorch_forecasting/models/x_lstm_time/m_lstm/cell.py index 587cc5c36..311292fd2 100644 --- a/pytorch_forecasting/models/x_lstm_time/m_lstm/cell.py +++ b/pytorch_forecasting/models/x_lstm_time/m_lstm/cell.py @@ -45,6 +45,7 @@ class mLSTMCell(nn.Module): device : torch.device Device used for computation. """ + def __init__( self, input_size, hidden_size, dropout=0.2, layer_norm=True, device=None ): @@ -101,23 +102,23 @@ def forward(self, x, h_prev, c_prev, n_prev): Parameters ---------- x : torch.Tensor - Input tensor of shape (batch_size, input_size). + The number of features in the input. h_prev : torch.Tensor - Previous hidden state of shape (batch_size, hidden_size). + Previous hidden state c_prev : torch.Tensor - Previous cell state of shape (batch_size, hidden_size). + Previous cell state n_prev : torch.Tensor - Previous normalized state of shape (batch_size, hidden_size). + Previous normalized state Returns ------- tuple of torch.Tensor: h : torch.Tensor - Current hidden state of shape (batch_size, hidden_size). + Current hidden state c : torch.Tensor - Current cell state of shape (batch_size, hidden_size). + Current cell state n : torch.Tensor - Current normalized state of shape (batch_size, hidden_size). + Current normalized state """ x = x.to(self.device) diff --git a/pytorch_forecasting/models/x_lstm_time/m_lstm/layer.py b/pytorch_forecasting/models/x_lstm_time/m_lstm/layer.py index 9f80f6d70..728506a3e 100644 --- a/pytorch_forecasting/models/x_lstm_time/m_lstm/layer.py +++ b/pytorch_forecasting/models/x_lstm_time/m_lstm/layer.py @@ -36,6 +36,7 @@ class mLSTMLayer(nn.Module): Dropout layer applied between layers. """ + def __init__( self, input_size, @@ -91,27 +92,27 @@ def forward(self, x, h=None, c=None, n=None): Parameters ---------- x : torch.Tensor - Input tensor of shape (batch_size, seq_len, input_size). + The number of features in the input. 
h : torch.Tensor, optional - Initial hidden states for all layers, shape (num_layers, batch_size, hidden_size). + Initial hidden states for all layers If None, initialized to zeros, by default None. c : torch.Tensor, optional - Initial cell states for all layers, shape (num_layers, batch_size, hidden_size). + Initial cell states for all layers If None, initialized to zeros, by default None. n : torch.Tensor, optional - Initial normalized states for all layers, shape (num_layers, batch_size, hidden_size). + Initial normalized states for all layers If None, initialized to zeros, by default None. Returns ------- tuple output : torch.Tensor - Final output tensor from the last layer, shape (batch_size, seq_len, hidden_size). + Final output tensor from the last layer (h, c, n) : tuple of torch.Tensor Final hidden, cell, and normalized states for all layers: - - h : torch.Tensor, shape (num_layers, batch_size, hidden_size). - - c : torch.Tensor, shape (num_layers, batch_size, hidden_size). - - n : torch.Tensor, shape (num_layers, batch_size, hidden_size). + - h : torch.Tensor + - c : torch.Tensor + - n : torch.Tensor """ x = x.to(self.device).transpose(0, 1) diff --git a/pytorch_forecasting/models/x_lstm_time/m_lstm/network.py b/pytorch_forecasting/models/x_lstm_time/m_lstm/network.py index 2c4f53d74..89e46450b 100644 --- a/pytorch_forecasting/models/x_lstm_time/m_lstm/network.py +++ b/pytorch_forecasting/models/x_lstm_time/m_lstm/network.py @@ -37,6 +37,7 @@ class mLSTMNetwork(nn.Module): """ + def __init__( self, input_size, @@ -70,27 +71,27 @@ def forward(self, x, h=None, c=None, n=None): Parameters ---------- x : torch.Tensor - Input tensor of shape (batch_size, seq_len, input_size). + The number of features in the input. h : torch.Tensor, optional - Initial hidden states for all layers, shape (num_layers, batch_size, hidden_size). + Initial hidden states for all layers. If None, initialized to zeros, by default None. c : torch.Tensor, optional - Initial cell states for all layers, shape (num_layers, batch_size, hidden_size). + Initial cell states for all layers. If None, initialized to zeros, by default None. n : torch.Tensor, optional - Initial normalized states for all layers, shape (num_layers, batch_size, hidden_size). + Initial normalized states for all layers. If None, initialized to zeros, by default None. Returns ------- tuple output : torch.Tensor - Final output tensor from the fully connected layer, shape (batch_size, output_size). + Final output tensor from the fully connected layer. (h, c, n) : tuple of torch.Tensor Final hidden, cell, and normalized states for all layers: - - h : torch.Tensor, shape (num_layers, batch_size, hidden_size). - - c : torch.Tensor, shape (num_layers, batch_size, hidden_size). - - n : torch.Tensor, shape (num_layers, batch_size, hidden_size). + - h : torch.Tensor + - c : torch.Tensor + - n : torch.Tensor """ output, (h, c, n) = self.mlstm_layer(x, h, c, n) diff --git a/pytorch_forecasting/models/x_lstm_time/s_lstm/cell.py b/pytorch_forecasting/models/x_lstm_time/s_lstm/cell.py index b42aec5fd..c4380356f 100644 --- a/pytorch_forecasting/models/x_lstm_time/s_lstm/cell.py +++ b/pytorch_forecasting/models/x_lstm_time/s_lstm/cell.py @@ -19,7 +19,8 @@ class sLSTMCell(nn.Module): dropout : float, optional Dropout probability for the cell's input and hidden state, by default 0.0. use_layer_norm : bool, optional - Whether to use layer normalization for the cell's internal computations, by default True. 
+ Whether to use layer normalization for the cell's internal computations, + by default True. device : torch.device, optional The device to run the computations on @@ -32,7 +33,8 @@ class sLSTMCell(nn.Module): ln_cell : nn.LayerNorm Layer normalization for the cell state, applied if use_layer_norm is True. ln_hidden : nn.LayerNorm - Layer normalization for the output hidden state, applied if use_layer_norm is True. + Layer normalization for the output hidden state, + applied if use_layer_norm is True. ln_input : nn.LayerNorm Layer normalization for input gates, applied if use_layer_norm is True. ln_hidden_update : nn.LayerNorm diff --git a/pytorch_forecasting/models/x_lstm_time/s_lstm/network.py b/pytorch_forecasting/models/x_lstm_time/s_lstm/network.py index 60a59b1e8..5f94023c5 100644 --- a/pytorch_forecasting/models/x_lstm_time/s_lstm/network.py +++ b/pytorch_forecasting/models/x_lstm_time/s_lstm/network.py @@ -5,7 +5,7 @@ class sLSTMNetwork(nn.Module): - """ Implements the Stabilized LSTM Network with multiple sLSTM layers. + """Implements the Stabilized LSTM Network with multiple sLSTM layers. This network combines sLSTM layers with a fully connected output layer for prediction. diff --git a/pytorch_forecasting/models/x_lstm_time/x_lstm.py b/pytorch_forecasting/models/x_lstm_time/x_lstm.py index e8444b3c2..eb02c7306 100644 --- a/pytorch_forecasting/models/x_lstm_time/x_lstm.py +++ b/pytorch_forecasting/models/x_lstm_time/x_lstm.py @@ -78,10 +78,9 @@ def __init__( self.input_projection_size = input_projection_size or hidden_size - self.input_linear = nn.Linear( - input_size * 2, - self.input_projection_size - ).to(self.device) + self.input_linear = nn.Linear(input_size * 2, self.input_projection_size).to( + self.device + ) if xlstm_type == "mlstm": self.lstm = mLSTMNetwork( @@ -122,7 +121,6 @@ def forward( x = torch.cat([trend, seasonal], dim=-1) - x = self.input_linear(x) x = x.transpose(1, 2) From c72bff908d403ddbe3b30362cf7c096804afb919 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Mon, 6 Jan 2025 21:57:52 +0530 Subject: [PATCH 13/15] Update __init__.py --- pytorch_forecasting/models/x_lstm_time/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_forecasting/models/x_lstm_time/__init__.py b/pytorch_forecasting/models/x_lstm_time/__init__.py index 7ebcbd7cc..0b082eeae 100644 --- a/pytorch_forecasting/models/x_lstm_time/__init__.py +++ b/pytorch_forecasting/models/x_lstm_time/__init__.py @@ -1,3 +1,4 @@ +"""xLSTMTime implementation for forecasting""" from pytorch_forecasting.models.x_lstm_time.x_lstm import xLSTMTime __all__ = ["xLSTMTime"] From 62e97ae5ea549d5659c30f76e229a3d6a8381b3e Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Mon, 6 Jan 2025 21:59:49 +0530 Subject: [PATCH 14/15] Update __init__.py --- pytorch_forecasting/models/x_lstm_time/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_forecasting/models/x_lstm_time/__init__.py b/pytorch_forecasting/models/x_lstm_time/__init__.py index 0b082eeae..fb5bc9892 100644 --- a/pytorch_forecasting/models/x_lstm_time/__init__.py +++ b/pytorch_forecasting/models/x_lstm_time/__init__.py @@ -1,4 +1,5 @@ -"""xLSTMTime implementation for forecasting""" +"""xLSTMTime implementation for forecasting.""" + from pytorch_forecasting.models.x_lstm_time.x_lstm import xLSTMTime __all__ = ["xLSTMTime"] From acb23e747dde0a3ddef2bfd26f8842d042650076 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Tue, 21 Jan 2025 23:26:21 +0530 Subject: [PATCH 15/15] Adding tests --- 
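
Note: these tests assume the shared dataloaders_fixed_window_without_covariates
fixture from the suite's conftest. A typical local run, assuming a standard
pytest setup, would be:

    pytest tests/test_models/test_x_lstm.py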
 tests/test_models/test_x_lstm.py | 112 +++++++++++++++++++++++++++++++
 1 file changed, 112 insertions(+)
 create mode 100644 tests/test_models/test_x_lstm.py

diff --git a/tests/test_models/test_x_lstm.py b/tests/test_models/test_x_lstm.py
new file mode 100644
index 000000000..1392f7b9e
--- /dev/null
+++ b/tests/test_models/test_x_lstm.py
@@ -0,0 +1,112 @@
+import shutil
+
+import lightning.pytorch as pl
+from lightning.pytorch.callbacks import EarlyStopping
+from lightning.pytorch.loggers import TensorBoardLogger
+import pytest
+
+from pytorch_forecasting.metrics import SMAPE
+from pytorch_forecasting.models.x_lstm_time.x_lstm import xLSTMTime
+
+
+def _integration(
+    dataloaders_fixed_window_without_covariates, tmp_path, xlstm_type="slstm", **kwargs
+):
+
+    train_dataloader = dataloaders_fixed_window_without_covariates["train"]
+    val_dataloader = dataloaders_fixed_window_without_covariates["val"]
+    test_dataloader = dataloaders_fixed_window_without_covariates["test"]
+
+    early_stop_callback = EarlyStopping(
+        monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min"
+    )
+
+    logger = TensorBoardLogger(tmp_path)
+    trainer = pl.Trainer(
+        max_epochs=3,
+        gradient_clip_val=0.1,
+        callbacks=[early_stop_callback],
+        enable_checkpointing=True,
+        default_root_dir=tmp_path,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        limit_test_batches=2,
+        logger=logger,
+    )
+
+    model_kwargs = {
+        "input_size": 1,
+        "output_size": 1,
+        "hidden_size": 32,
+        "xlstm_type": xlstm_type,
+        "learning_rate": 0.01,
+        "loss": SMAPE(),
+    }
+
+    model_kwargs.update(kwargs)
+
+    net = xLSTMTime.from_dataset(train_dataloader.dataset, **model_kwargs)
+
+    try:
+
+        trainer.fit(
+            net,
+            train_dataloaders=train_dataloader,
+            val_dataloaders=val_dataloader,
+        )
+
+        test_outputs = trainer.test(net, dataloaders=test_dataloader)
+        assert len(test_outputs) > 0
+
+        net = xLSTMTime.load_from_checkpoint(
+            trainer.checkpoint_callback.best_model_path
+        )
+
+        net.predict(
+            val_dataloader,
+            fast_dev_run=True,
+            return_index=True,
+            return_decoder_lengths=True,
+        )
+    finally:
+        shutil.rmtree(tmp_path, ignore_errors=True)
+
+
+@pytest.mark.parametrize(
+    "kwargs",
+    [
+        {},
+        {"xlstm_type": "mlstm"},
+        {"num_layers": 2},
+        {"xlstm_type": "slstm", "input_projection_size": 32},
+        {
+            "xlstm_type": "mlstm",
+            "decomposition_kernel": 13,
+            "dropout": 0.2,
+        },
+    ],
+)
+def test_integration(dataloaders_fixed_window_without_covariates, tmp_path, kwargs):
+    _integration(dataloaders_fixed_window_without_covariates, tmp_path, **kwargs)
+
+
+@pytest.fixture(scope="session")
+def model(dataloaders_fixed_window_without_covariates):
+    dataset = dataloaders_fixed_window_without_covariates["train"].dataset
+    net = xLSTMTime.from_dataset(
+        dataset,
+        input_size=1,
+        hidden_size=32,
+        output_size=1,
+        xlstm_type="slstm",
+        learning_rate=0.01,
+        loss=SMAPE(),
+    )
+    return net
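+
+
+def test_model_smoke(model):
+    # Minimal smoke check for the fixture above (illustrative sketch): the
+    # fixture should build an xLSTMTime instance from the training dataset
+    # without running any training.
+    assert isinstance(model, xLSTMTime)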