Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement TimeMixer as an imputation model #499

Merged
merged 3 commits into from
Sep 3, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pypots/imputation/__init__.py
Original file line number Diff line number Diff line change
@@ -36,6 +36,7 @@
from .grud import GRUD
from .stemgnn import StemGNN
from .imputeformer import ImputeFormer
from .timemixer import TimeMixer

# naive imputation methods
from .locf import LOCF
@@ -75,6 +76,7 @@
"GRUD",
"StemGNN",
"ImputeFormer",
"TimeMixer",
# naive imputation methods
"LOCF",
"Mean",
24 changes: 24 additions & 0 deletions pypots/imputation/timemixer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""
The package including the modules of TimeMixer.
Refer to the paper
`Shiyu Wang, Haixu Wu, Xiaoming Shi, Tengge Hu, Huakun Luo, Lintao Ma, James Y. Zhang, and Jun Zhou.
"TimeMixer: Decomposable Multiscale Mixing for Time Series Forecasting".
In ICLR 2024.
<https://openreview.net/pdf?id=7oLshfEIC2>`_
Notes
-----
This implementation is inspired by the official one https://github.com/kwuking/TimeMixer
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause


from .model import TimeMixer

__all__ = [
"TimeMixer",
]
83 changes: 83 additions & 0 deletions pypots/imputation/timemixer/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
"""
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

import torch.nn as nn

from ...nn.functional import (
nonstationary_norm,
nonstationary_denorm,
)
from ...nn.modules.timemixer import BackboneTimeMixer
from ...utils.metrics import calc_mse


class _TimeMixer(nn.Module):
def __init__(
self,
n_layers,
n_steps,
n_features,
d_model,
d_ffn,
dropout,
top_k,
channel_independence,
decomp_method,
moving_avg,
downsampling_layers,
downsampling_window,
apply_nonstationary_norm: bool = False,
):
super().__init__()

self.apply_nonstationary_norm = apply_nonstationary_norm

self.model = BackboneTimeMixer(
task_name="imputation",
n_steps=n_steps,
n_features=n_features,
n_pred_steps=None,
n_pred_features=n_features,
n_layers=n_layers,
d_model=d_model,
d_ffn=d_ffn,
dropout=dropout,
channel_independence=channel_independence,
decomp_method=decomp_method,
top_k=top_k,
moving_avg=moving_avg,
downsampling_layers=downsampling_layers,
downsampling_window=downsampling_window,
downsampling_method="avg",
use_future_temporal_feature=False,
)

def forward(self, inputs: dict, training: bool = True) -> dict:
X, missing_mask = inputs["X"], inputs["missing_mask"]

if self.apply_nonstationary_norm:
# Normalization from Non-stationary Transformer
X, means, stdev = nonstationary_norm(X, missing_mask)

# TimesMixer processing
dec_out = self.model.imputation(X, None)

if self.apply_nonstationary_norm:
# De-Normalization from Non-stationary Transformer
dec_out = nonstationary_denorm(dec_out, means, stdev)

imputed_data = missing_mask * X + (1 - missing_mask) * dec_out
results = {
"imputed_data": imputed_data,
}

if training:
# `loss` is always the item for backward propagating to update the model
loss = calc_mse(dec_out, inputs["X_ori"], inputs["indicating_mask"])
results["loss"] = loss

return results
24 changes: 24 additions & 0 deletions pypots/imputation/timemixer/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""
Dataset class for the imputation model TimeMixer.
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

from typing import Union

from ..saits.data import DatasetForSAITS


class DatasetForTimeMixer(DatasetForSAITS):
"""Actually TimeMixer uses the same data strategy as SAITS, needs MIT for training."""

def __init__(
self,
data: Union[dict, str],
return_X_ori: bool,
return_y: bool,
file_type: str = "hdf5",
rate: float = 0.2,
):
super().__init__(data, return_X_ori, return_y, file_type, rate)
343 changes: 343 additions & 0 deletions pypots/imputation/timemixer/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,343 @@
"""
The implementation of TimeMixer for the partially-observed time-series imputation task.
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

from typing import Union, Optional

import numpy as np
import torch
from torch.utils.data import DataLoader

from .core import _TimeMixer
from .data import DatasetForTimeMixer
from ..base import BaseNNImputer
from ...data.checking import key_in_data_set
from ...data.dataset import BaseDataset
from ...optim.adam import Adam
from ...optim.base import Optimizer


class TimeMixer(BaseNNImputer):
"""The PyTorch implementation of the TimeMixer model.
TimeMixer is originally proposed by Wang et al. in :cite:`wang2024timemixer`.
Parameters
----------
n_steps :
The number of time steps in the time-series data sample.
n_features :
The number of features in the time-series data sample.
n_layers :
The number of layers in the TimeMixer model.
d_model :
The dimension of the model.
d_ffn :
The dimension of the feed-forward network.
top_k :
The number of top-k amplitude values to be selected to obtain the most significant frequencies.
dropout :
The dropout rate for the model.
channel_independence :
Whether to use channel independence in the model.
decomp_method :
The decomposition method for the model. It has to be one of ['moving_avg', 'dft_decomp'].
moving_avg :
The window size for moving average decomposition.
downsampling_layers :
The number of downsampling layers in the model.
downsampling_window :
The window size for downsampling.
apply_nonstationary_norm :
Whether to apply non-stationary normalization to the input data for TimeMixer.
Please refer to :cite:`liu2022nonstationary` for details about non-stationary normalization,
which is not the idea of the original TimeMixer paper. Hence, we make it optional and default not to use here.
batch_size :
The batch size for training and evaluating the model.
epochs :
The number of epochs for training the model.
patience :
The patience for the early-stopping mechanism. Given a positive integer, the training process will be
stopped when the model does not perform better after that number of epochs.
Leaving it default as None will disable the early-stopping.
optimizer :
The optimizer for model training.
If not given, will use a default Adam optimizer.
num_workers :
The number of subprocesses to use for data loading.
`0` means data loading will be in the main process, i.e. there won't be subprocesses.
device :
The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them.
If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple),
then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models.
If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the
model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices).
Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future.
saving_path :
The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during
training into a tensorboard file). Will not save if not given.
model_saving_strategy :
The strategy to save model checkpoints. It has to be one of [None, "best", "better", "all"].
No model will be saved when it is set as None.
The "best" strategy will only automatically save the best model after the training finished.
The "better" strategy will automatically save the model during training whenever the model performs
better than in previous epochs.
The "all" strategy will save every model after each epoch training.
verbose :
Whether to print out the training logs during the training process.
"""

def __init__(
self,
n_steps: int,
n_features: int,
n_layers: int,
d_model: int,
d_ffn: int,
top_k: int,
dropout: float = 0,
channel_independence: bool = False,
decomp_method: str = "moving_avg",
moving_avg: int = 5,
downsampling_layers: int = 3,
downsampling_window: int = 2,
apply_nonstationary_norm: bool = False,
batch_size: int = 32,
epochs: int = 100,
patience: int = None,
optimizer: Optional[Optimizer] = Adam(),
num_workers: int = 0,
device: Optional[Union[str, torch.device, list]] = None,
saving_path: str = None,
model_saving_strategy: Optional[str] = "best",
verbose: bool = True,
):
super().__init__(
batch_size,
epochs,
patience,
num_workers,
device,
saving_path,
model_saving_strategy,
verbose,
)

assert decomp_method in [
"moving_avg",
"dft_decomp",
], "decomp_method must be one of ['moving_avg', 'dft_decomp']."

self.n_steps = n_steps
self.n_features = n_features
# model hype-parameters
self.n_layers = n_layers
self.d_model = d_model
self.d_ffn = d_ffn
self.dropout = dropout
self.top_k = top_k
self.channel_independence = channel_independence
self.decomp_method = decomp_method
self.moving_avg = moving_avg
self.downsampling_layers = downsampling_layers
self.downsampling_window = downsampling_window
self.apply_nonstationary_norm = apply_nonstationary_norm

# set up the model
self.model = _TimeMixer(
self.n_layers,
self.n_steps,
self.n_features,
self.d_model,
self.d_ffn,
self.dropout,
self.top_k,
self.channel_independence,
self.decomp_method,
self.moving_avg,
self.downsampling_layers,
self.downsampling_window,
self.apply_nonstationary_norm,
)
self._send_model_to_given_device()
self._print_model_size()

# set up the optimizer
self.optimizer = optimizer
self.optimizer.init_optimizer(self.model.parameters())

def _assemble_input_for_training(self, data: list) -> dict:
(
indices,
X,
missing_mask,
X_ori,
indicating_mask,
) = self._send_data_to_given_device(data)

inputs = {
"X": X,
"missing_mask": missing_mask,
"X_ori": X_ori,
"indicating_mask": indicating_mask,
}

return inputs

def _assemble_input_for_validating(self, data: list) -> dict:
return self._assemble_input_for_training(data)

def _assemble_input_for_testing(self, data: list) -> dict:
indices, X, missing_mask = self._send_data_to_given_device(data)

inputs = {
"X": X,
"missing_mask": missing_mask,
}

return inputs

def fit(
self,
train_set: Union[dict, str],
val_set: Optional[Union[dict, str]] = None,
file_type: str = "hdf5",
) -> None:
# Step 1: wrap the input data with classes Dataset and DataLoader
training_set = DatasetForTimeMixer(
train_set, return_X_ori=False, return_y=False, file_type=file_type
)
training_loader = DataLoader(
training_set,
batch_size=self.batch_size,
shuffle=True,
num_workers=self.num_workers,
)
val_loader = None
if val_set is not None:
if not key_in_data_set("X_ori", val_set):
raise ValueError("val_set must contain 'X_ori' for model validation.")
val_set = DatasetForTimeMixer(
val_set, return_X_ori=True, return_y=False, file_type=file_type
)
val_loader = DataLoader(
val_set,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
)

# Step 2: train the model and freeze it
self._train_model(training_loader, val_loader)
self.model.load_state_dict(self.best_model_dict)
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
self._auto_save_model_if_necessary(confirm_saving=True)

def predict(
self,
test_set: Union[dict, str],
file_type: str = "hdf5",
) -> dict:
"""Make predictions for the input data with the trained model.
Parameters
----------
test_set : dict or str
The dataset for model validating, should be a dictionary including keys as 'X',
or a path string locating a data file supported by PyPOTS (e.g. h5 file).
If it is a dict, X should be array-like of shape [n_samples, sequence length (n_steps), n_features],
which is time-series data for validating, can contain missing values, and y should be array-like of shape
[n_samples], which is classification labels of X.
If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
key-value pairs like a dict, and it has to include keys as 'X' and 'y'.
file_type :
The type of the given file if test_set is a path string.
Returns
-------
file_type :
The dictionary containing the clustering results and latent variables if necessary.
"""
# Step 1: wrap the input data with classes Dataset and DataLoader
self.model.eval() # set the model as eval status to freeze it.
test_set = BaseDataset(
test_set,
return_X_ori=False,
return_X_pred=False,
return_y=False,
file_type=file_type,
)
test_loader = DataLoader(
test_set,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
)
imputation_collector = []

# Step 2: process the data with the model
with torch.no_grad():
for idx, data in enumerate(test_loader):
inputs = self._assemble_input_for_testing(data)
results = self.model.forward(inputs, training=False)
imputation_collector.append(results["imputed_data"])

# Step 3: output collection and return
imputation = torch.cat(imputation_collector).cpu().detach().numpy()
result_dict = {
"imputation": imputation,
}
return result_dict

def impute(
self,
test_set: Union[dict, str],
file_type: str = "hdf5",
) -> np.ndarray:
"""Impute missing values in the given data with the trained model.
Parameters
----------
test_set :
The data samples for testing, should be array-like of shape [n_samples, sequence length (n_steps),
n_features], or a path string locating a data file, e.g. h5 file.
file_type :
The type of the given file if X is a path string.
Returns
-------
array-like, shape [n_samples, sequence length (n_steps), n_features],
Imputed data.
"""

result_dict = self.predict(test_set, file_type=file_type)
return result_dict["imputation"]
23 changes: 23 additions & 0 deletions pypots/nn/modules/timemixer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""
The package including the modules of TimeMixer.
Refer to the paper
`Shiyu Wang, Haixu Wu, Xiaoming Shi, Tengge Hu, Huakun Luo, Lintao Ma, James Y. Zhang, and Jun Zhou.
"TimeMixer: Decomposable Multiscale Mixing for Time Series Forecasting".
In ICLR 2024.
<https://openreview.net/pdf?id=7oLshfEIC2>`_
Notes
-----
This implementation is inspired by the official one https://github.com/kwuking/TimeMixer
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

from .backbone import BackboneTimeMixer

__all__ = [
"BackboneTimeMixer",
]
392 changes: 392 additions & 0 deletions pypots/nn/modules/timemixer/backbone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,392 @@
"""
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

import torch
import torch.nn as nn
import torch.nn.functional as F

from .layers import PastDecomposableMixing
from ..autoformer import SeriesDecompositionBlock
from ..revin import RevIN
from ..transformer.embedding import DataEmbedding


class BackboneTimeMixer(nn.Module):
def __init__(
self,
task_name,
n_steps,
n_features,
n_pred_steps,
n_pred_features,
n_layers,
d_model,
d_ffn,
dropout,
channel_independence,
decomp_method,
top_k,
moving_avg,
downsampling_layers,
downsampling_window,
downsampling_method,
use_future_temporal_feature,
embed="fixed",
freq="h",
n_classes=None,
):
super().__init__()
self.task_name = task_name
self.n_steps = n_steps
self.n_features = n_features
self.n_pred_steps = n_pred_steps
self.n_pred_features = n_pred_features
self.n_layers = n_layers
self.channel_independence = channel_independence
self.downsampling_window = downsampling_window
self.downsampling_layers = downsampling_layers
self.downsampling_method = downsampling_method
self.use_future_temporal_feature = use_future_temporal_feature

self.pdm_blocks = nn.ModuleList(
[
PastDecomposableMixing(
n_steps,
n_pred_steps,
d_model,
d_ffn,
dropout,
channel_independence,
decomp_method,
top_k,
moving_avg,
downsampling_layers,
downsampling_window,
)
for _ in range(n_layers)
]
)
self.preprocess = SeriesDecompositionBlock(moving_avg)

if self.channel_independence == 1:
self.enc_embedding = DataEmbedding(
1, d_model, embed, freq, dropout, with_pos=False
)
else:
self.enc_embedding = DataEmbedding(
n_features, d_model, embed, freq, dropout, with_pos=False
)

self.normalize_layers = torch.nn.ModuleList(
[RevIN(n_features) for _ in range(downsampling_layers + 1)]
)

if task_name == "long_term_forecast" or task_name == "short_term_forecast":
self.predict_layers = torch.nn.ModuleList(
[
torch.nn.Linear(
n_steps // (downsampling_window**i),
n_pred_steps,
)
for i in range(downsampling_layers + 1)
]
)

if self.channel_independence == 1:
self.projection_layer = nn.Linear(d_model, 1, bias=True)
else:
self.projection_layer = nn.Linear(d_model, n_pred_features, bias=True)

self.out_res_layers = torch.nn.ModuleList(
[
torch.nn.Linear(
n_steps // (downsampling_window**i),
n_steps // (downsampling_window**i),
)
for i in range(downsampling_layers + 1)
]
)

self.regression_layers = torch.nn.ModuleList(
[
torch.nn.Linear(
n_steps // (downsampling_window**i),
n_pred_steps,
)
for i in range(downsampling_layers + 1)
]
)
if task_name == "imputation" or task_name == "anomaly_detection":
if self.channel_independence == 1:
self.projection_layer = nn.Linear(d_model, 1, bias=True)
else:
self.projection_layer = nn.Linear(d_model, n_pred_features, bias=True)
if task_name == "classification":
self.act = F.gelu
self.dropout = nn.Dropout(dropout)
self.projection = nn.Linear(d_model * n_steps, n_classes)

def out_projection(self, dec_out, i, out_res):
dec_out = self.projection_layer(dec_out)
out_res = out_res.permute(0, 2, 1)
out_res = self.out_res_layers[i](out_res)
out_res = self.regression_layers[i](out_res).permute(0, 2, 1)
dec_out = dec_out + out_res
return dec_out

def pre_enc(self, x_list):
if self.channel_independence == 1:
return x_list, None
else:
out1_list = []
out2_list = []
for x in x_list:
x_1, x_2 = self.preprocess(x)
out1_list.append(x_1)
out2_list.append(x_2)
return out1_list, out2_list

def __multi_scale_process_inputs(self, x_enc, x_mark_enc):
if self.downsampling_method == "max":
down_pool = torch.nn.MaxPool1d(
self.downsampling_window, return_indices=False
)
elif self.downsampling_method == "avg":
down_pool = torch.nn.AvgPool1d(self.downsampling_window)
elif self.downsampling_method == "conv":
padding = 1 if torch.__version__ >= "1.5.0" else 2
down_pool = nn.Conv1d(
in_channels=self.enc_in,
out_channels=self.enc_in,
kernel_size=3,
padding=padding,
stride=self.downsampling_window,
padding_mode="circular",
bias=False,
)
else:
return x_enc, x_mark_enc
# B,T,C -> B,C,T
x_enc = x_enc.permute(0, 2, 1)

x_enc_ori = x_enc
x_mark_enc_mark_ori = x_mark_enc

x_enc_sampling_list = []
x_mark_sampling_list = []
x_enc_sampling_list.append(x_enc.permute(0, 2, 1))
x_mark_sampling_list.append(x_mark_enc)

for i in range(self.downsampling_layers):
x_enc_sampling = down_pool(x_enc_ori)

x_enc_sampling_list.append(x_enc_sampling.permute(0, 2, 1))
x_enc_ori = x_enc_sampling

if x_mark_enc_mark_ori is not None:
x_mark_sampling_list.append(
x_mark_enc_mark_ori[:, :: self.downsampling_window, :]
)
x_mark_enc_mark_ori = x_mark_enc_mark_ori[
:, :: self.downsampling_window, :
]

x_enc = x_enc_sampling_list
if x_mark_enc_mark_ori is not None:
x_mark_enc = x_mark_sampling_list
else:
x_mark_enc = x_mark_enc

return x_enc, x_mark_enc

def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
if self.use_future_temporal_feature:
if self.channel_independence == 1:
B, T, N = x_enc.size()
x_mark_dec = x_mark_dec.repeat(N, 1, 1)
self.x_mark_dec = self.enc_embedding(None, x_mark_dec)
else:
self.x_mark_dec = self.enc_embedding(None, x_mark_dec)

x_enc, x_mark_enc = self.__multi_scale_process_inputs(x_enc, x_mark_enc)

x_list = []
x_mark_list = []
if x_mark_enc is not None:
for i, x, x_mark in zip(range(len(x_enc)), x_enc, x_mark_enc):
B, T, N = x.size()
x = self.normalize_layers[i](x, "norm")
if self.channel_independence == 1:
x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)
x_mark = x_mark.repeat(N, 1, 1)
x_list.append(x)
x_mark_list.append(x_mark)
else:
for i, x in zip(
range(len(x_enc)),
x_enc,
):
B, T, N = x.size()
x = self.normalize_layers[i](x, "norm")
if self.channel_independence == 1:
x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)
x_list.append(x)

# embedding
enc_out_list = []
x_list = self.pre_enc(x_list)
if x_mark_enc is not None:
for i, x, x_mark in zip(range(len(x_list[0])), x_list[0], x_mark_list):
enc_out = self.enc_embedding(x, x_mark) # [B,T,C]
enc_out_list.append(enc_out)
else:
for i, x in zip(range(len(x_list[0])), x_list[0]):
enc_out = self.enc_embedding(x, None) # [B,T,C]
enc_out_list.append(enc_out)

# Past Decomposable Mixing as encoder for past
for i in range(self.layer):
enc_out_list = self.pdm_blocks[i](enc_out_list)

# Future Multipredictor Mixing as decoder for future
dec_out_list = self.future_multi_mixing(B, enc_out_list, x_list)

dec_out = torch.stack(dec_out_list, dim=-1).sum(-1)
dec_out = self.normalize_layers[0](dec_out, "denorm")
return dec_out

def future_multi_mixing(self, B, enc_out_list, x_list):
dec_out_list = []
if self.channel_independence == 1:
x_list = x_list[0]
for i, enc_out in zip(range(len(x_list)), enc_out_list):
dec_out = self.predict_layers[i](enc_out.permute(0, 2, 1)).permute(
0, 2, 1
) # align temporal dimension
if self.use_future_temporal_feature:
dec_out = dec_out + self.x_mark_dec
dec_out = self.projection_layer(dec_out)
else:
dec_out = self.projection_layer(dec_out)
dec_out = (
dec_out.reshape(B, self.c_out, self.n_pred_steps)
.permute(0, 2, 1)
.contiguous()
)
dec_out_list.append(dec_out)

else:
for i, enc_out, out_res in zip(
range(len(x_list[0])), enc_out_list, x_list[1]
):
dec_out = self.predict_layers[i](enc_out.permute(0, 2, 1)).permute(
0, 2, 1
) # align temporal dimension
dec_out = self.out_projection(dec_out, i, out_res)
dec_out_list.append(dec_out)

return dec_out_list

def classification(self, x_enc, x_mark_enc):
x_enc, _ = self.__multi_scale_process_inputs(x_enc, None)
x_list = x_enc

# embedding
enc_out_list = []
for x in x_list:
enc_out = self.enc_embedding(x, None) # [B,T,C]
enc_out_list.append(enc_out)

# MultiScale-CrissCrossAttention as encoder for past
for i in range(self.layer):
enc_out_list = self.pdm_blocks[i](enc_out_list)

enc_out = enc_out_list[0]
# Output
# the output transformer encoder/decoder embeddings don't include non-linearity
output = self.act(enc_out)
output = self.dropout(output)
# zero-out padding embeddings
output = output * x_mark_enc.unsqueeze(-1)
# (batch_size, n_stepsgth * d_model)
output = output.reshape(output.shape[0], -1)
output = self.projection(output) # (batch_size, num_classes)
return output

def anomaly_detection(self, x_enc):
B, T, N = x_enc.size()
x_enc, _ = self.__multi_scale_process_inputs(x_enc, None)

x_list = []

for i, x in zip(
range(len(x_enc)),
x_enc,
):
B, T, N = x.size()
x = self.normalize_layers[i](x, "norm")
if self.channel_independence == 1:
x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)
x_list.append(x)

# embedding
enc_out_list = []
for x in x_list:
enc_out = self.enc_embedding(x, None) # [B,T,C]
enc_out_list.append(enc_out)

# MultiScale-CrissCrossAttention as encoder for past
for i in range(self.layer):
enc_out_list = self.pdm_blocks[i](enc_out_list)

dec_out = self.projection_layer(enc_out_list[0])
dec_out = dec_out.reshape(B, self.c_out, -1).permute(0, 2, 1).contiguous()

dec_out = self.normalize_layers[0](dec_out, "denorm")
return dec_out

def imputation(self, x_enc, x_mark_enc):

B, T, N = x_enc.size()
x_enc, x_mark_enc = self.__multi_scale_process_inputs(x_enc, x_mark_enc)

x_list = []
x_mark_list = []
if x_mark_enc is not None:
for i, x, x_mark in zip(range(len(x_enc)), x_enc, x_mark_enc):
B, T, N = x.size()
if self.channel_independence == 1:
x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)
x_list.append(x)
x_mark = x_mark.repeat(N, 1, 1)
x_mark_list.append(x_mark)
else:
for i, x in zip(
range(len(x_enc)),
x_enc,
):
B, T, N = x.size()
if self.channel_independence == 1:
x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)
x_list.append(x)

# embedding
enc_out_list = []
for x in x_list:
enc_out = self.enc_embedding(x, None) # [B,T,C]
enc_out_list.append(enc_out)

# MultiScale-CrissCrossAttention as encoder for past
for i in range(self.n_layers):
enc_out_list = self.pdm_blocks[i](enc_out_list)

dec_out = self.projection_layer(enc_out_list[0])
dec_out = (
dec_out.reshape(B, self.n_pred_features, -1).permute(0, 2, 1).contiguous()
)

return dec_out
222 changes: 222 additions & 0 deletions pypots/nn/modules/timemixer/layers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
"""
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

import torch
import torch.nn as nn

from ..autoformer import SeriesDecompositionBlock


class DFT_series_decomp(nn.Module):
"""
Series decomposition block
"""

def __init__(self, top_k=5):
super().__init__()
self.top_k = top_k

def forward(self, x):
xf = torch.fft.rfft(x)
freq = abs(xf)
freq[0] = 0
top_k_freq, top_list = torch.topk(freq, self.top_k)
xf[freq <= top_k_freq.min()] = 0
x_season = torch.fft.irfft(xf)
x_trend = x - x_season
return x_season, x_trend


class MultiScaleSeasonMixing(nn.Module):
"""
Bottom-up mixing season pattern
"""

def __init__(
self,
n_steps,
downsampling_window,
downsampling_layers,
):
super().__init__()

self.downsampling_layers = torch.nn.ModuleList(
[
nn.Sequential(
torch.nn.Linear(
n_steps // (downsampling_window**i),
n_steps // (downsampling_window ** (i + 1)),
),
nn.GELU(),
torch.nn.Linear(
n_steps // (downsampling_window ** (i + 1)),
n_steps // (downsampling_window ** (i + 1)),
),
)
for i in range(downsampling_layers)
]
)

def forward(self, season_list):

# mixing high->low
out_high = season_list[0]
out_low = season_list[1]
out_season_list = [out_high.permute(0, 2, 1)]

for i in range(len(season_list) - 1):
out_low_res = self.downsampling_layers[i](out_high)
out_low = out_low + out_low_res
out_high = out_low
if i + 2 <= len(season_list) - 1:
out_low = season_list[i + 2]
out_season_list.append(out_high.permute(0, 2, 1))

return out_season_list


class MultiScaleTrendMixing(nn.Module):
"""
Top-down mixing trend pattern
"""

def __init__(
self,
n_steps,
downsampling_window,
downsampling_layers,
):
super().__init__()

self.up_sampling_layers = torch.nn.ModuleList(
[
nn.Sequential(
torch.nn.Linear(
n_steps // (downsampling_window ** (i + 1)),
n_steps // (downsampling_window**i),
),
nn.GELU(),
torch.nn.Linear(
n_steps // (downsampling_window**i),
n_steps // (downsampling_window**i),
),
)
for i in reversed(range(downsampling_layers))
]
)

def forward(self, trend_list):

# mixing low->high
trend_list_reverse = trend_list.copy()
trend_list_reverse.reverse()
out_low = trend_list_reverse[0]
out_high = trend_list_reverse[1]
out_trend_list = [out_low.permute(0, 2, 1)]

for i in range(len(trend_list_reverse) - 1):
out_high_res = self.up_sampling_layers[i](out_low)
out_high = out_high + out_high_res
out_low = out_high
if i + 2 <= len(trend_list_reverse) - 1:
out_high = trend_list_reverse[i + 2]
out_trend_list.append(out_low.permute(0, 2, 1))

out_trend_list.reverse()
return out_trend_list


class PastDecomposableMixing(nn.Module):
def __init__(
self,
n_steps,
n_pred_steps,
d_model,
d_ffn,
dropout,
channel_independence,
decomp_method,
top_k,
moving_avg,
downsampling_layers,
downsampling_window,
):
super().__init__()
self.n_steps = n_steps
self.n_pred_steps = n_pred_steps
self.downsampling_window = downsampling_window

self.layer_norm = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
self.channel_independence = channel_independence

if decomp_method == "moving_avg":
self.decompsition = SeriesDecompositionBlock(moving_avg)
elif decomp_method == "dft_decomp":
self.decompsition = DFT_series_decomp(top_k)
else:
raise ValueError("decompsition is error")

if channel_independence == 0:
self.cross_layer = nn.Sequential(
nn.Linear(in_features=d_model, out_features=d_ffn),
nn.GELU(),
nn.Linear(in_features=d_ffn, out_features=d_model),
)

# Mixing season
self.mixing_multi_scale_season = MultiScaleSeasonMixing(
n_steps,
downsampling_window,
downsampling_layers,
)

# Mxing trend
self.mixing_multi_scale_trend = MultiScaleTrendMixing(
n_steps,
downsampling_window,
downsampling_layers,
)

self.out_cross_layer = nn.Sequential(
nn.Linear(in_features=d_model, out_features=d_ffn),
nn.GELU(),
nn.Linear(in_features=d_ffn, out_features=d_model),
)

def forward(self, x_list):
length_list = []
for x in x_list:
_, T, _ = x.size()
length_list.append(T)

# Decompose to obtain the season and trend
season_list = []
trend_list = []
for x in x_list:
season, trend = self.decompsition(x)
if self.channel_independence == 0:
season = self.cross_layer(season)
trend = self.cross_layer(trend)
season_list.append(season.permute(0, 2, 1))
trend_list.append(trend.permute(0, 2, 1))

# bottom-up season mixing
out_season_list = self.mixing_multi_scale_season(season_list)
# top-down trend mixing
out_trend_list = self.mixing_multi_scale_trend(trend_list)

out_list = []
for ori, out_season, out_trend, length in zip(
x_list, out_season_list, out_trend_list, length_list
):
out = out_season + out_trend
if self.channel_independence:
out = ori + self.out_cross_layer(out)
out_list.append(out[:, :length, :])

return out_list
128 changes: 128 additions & 0 deletions tests/imputation/timemixer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
"""
Test cases for TimeMixer imputation model.
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause


import os.path
import unittest

import numpy as np
import pytest

from pypots.imputation import TimeMixer
from pypots.optim import Adam
from pypots.utils.logging import logger
from pypots.utils.metrics import calc_mse
from tests.global_test_config import (
DATA,
EPOCHS,
DEVICE,
TRAIN_SET,
VAL_SET,
TEST_SET,
GENERAL_H5_TRAIN_SET_PATH,
GENERAL_H5_VAL_SET_PATH,
GENERAL_H5_TEST_SET_PATH,
RESULT_SAVING_DIR_FOR_IMPUTATION,
check_tb_and_model_checkpoints_existence,
)


class TestTimeMixer(unittest.TestCase):
logger.info("Running tests for an imputation model TimeMixer...")

# set the log and model saving path
saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "TimeMixer")
model_save_name = "saved_timemixer_model.pypots"

# initialize an Adam optimizer
optimizer = Adam(lr=0.001, weight_decay=1e-5)

# initialize a TimeMixer model
timemixer = TimeMixer(
DATA["n_steps"],
DATA["n_features"],
n_layers=2,
top_k=5,
d_model=512,
d_ffn=512,
dropout=0.1,
epochs=EPOCHS,
saving_path=saving_path,
optimizer=optimizer,
device=DEVICE,
)

@pytest.mark.xdist_group(name="imputation-timemixer")
def test_0_fit(self):
self.timemixer.fit(TRAIN_SET, VAL_SET)

@pytest.mark.xdist_group(name="imputation-timemixer")
def test_1_impute(self):
imputation_results = self.timemixer.predict(TEST_SET)
assert not np.isnan(
imputation_results["imputation"]
).any(), "Output still has missing values after running impute()."

test_MSE = calc_mse(
imputation_results["imputation"],
DATA["test_X_ori"],
DATA["test_X_indicating_mask"],
)
logger.info(f"TimeMixer test_MSE: {test_MSE}")

@pytest.mark.xdist_group(name="imputation-timemixer")
def test_2_parameters(self):
assert hasattr(self.timemixer, "model") and self.timemixer.model is not None

assert (
hasattr(self.timemixer, "optimizer")
and self.timemixer.optimizer is not None
)

assert hasattr(self.timemixer, "best_loss")
self.assertNotEqual(self.timemixer.best_loss, float("inf"))

assert (
hasattr(self.timemixer, "best_model_dict")
and self.timemixer.best_model_dict is not None
)

@pytest.mark.xdist_group(name="imputation-timemixer")
def test_3_saving_path(self):
# whether the root saving dir exists, which should be created by save_log_into_tb_file
assert os.path.exists(
self.saving_path
), f"file {self.saving_path} does not exist"

# check if the tensorboard file and model checkpoints exist
check_tb_and_model_checkpoints_existence(self.timemixer)

# save the trained model into file, and check if the path exists
saved_model_path = os.path.join(self.saving_path, self.model_save_name)
self.timemixer.save(saved_model_path)

# test loading the saved model, not necessary, but need to test
self.timemixer.load(saved_model_path)

@pytest.mark.xdist_group(name="imputation-timemixer")
def test_4_lazy_loading(self):
self.timemixer.fit(GENERAL_H5_TRAIN_SET_PATH, GENERAL_H5_VAL_SET_PATH)
imputation_results = self.timemixer.predict(GENERAL_H5_TEST_SET_PATH)
assert not np.isnan(
imputation_results["imputation"]
).any(), "Output still has missing values after running impute()."

test_MSE = calc_mse(
imputation_results["imputation"],
DATA["test_X_ori"],
DATA["test_X_indicating_mask"],
)
logger.info(f"Lazy-loading TimeMixer test_MSE: {test_MSE}")


if __name__ == "__main__":
unittest.main()