diff --git a/CHANGELOG.md b/CHANGELOG.md
index d7df8c0969..7d95e8d0ff 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -113,6 +113,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
 - Updated the Ray Tune Hyperparameter Optimization example in the [user guide](https://unit8co.github.io/darts/userguide/hyperparameter_optimization.html) to work with the latest `ray` versions (`>=2.31.0`). [#2459](https://github.com/unit8co/darts/pull/2459) by [He Weilin](https://github.com/cnhwl).
 - Indicate that `multi_models=False` induce a lags shift for each step in `output_chunk_length` in `RegressionModel` and `LinearRegressionModel`. [#2511](https://github.com/unit8co/darts/pull/2511) by [Antoine Madrona](https://github.com/madtoinou).
 - Added reference to `timeseries_generation.datetime_attribute_timeseries` in `TimeSeries.add_datetime_attribute` (0-indexing of encoding is enforced). [#2511](https://github.com/unit8co/darts/pull/2511) by [Antoine Madrona](https://github.com/madtoinou).
+- Added `project_first_layer` keyword to `TSMixerModel` (default `True`), allowing the option of projecting to `output_chunk_length` at the end, rather than the beginning, of the model. Setting it to `False` could improve performance when past covariates are more important than future covariates.
 
 **Fixed**
 
diff --git a/darts/models/forecasting/tsmixer_model.py b/darts/models/forecasting/tsmixer_model.py
index 56100ed4c6..027939bfc3 100644
--- a/darts/models/forecasting/tsmixer_model.py
+++ b/darts/models/forecasting/tsmixer_model.py
@@ -267,6 +267,7 @@ def __init__(
         super().__init__()
 
         mixing_input = input_dim
+
         if static_cov_dim != 0:
             self.feature_mixing_static = _FeatureMixing(
                 sequence_length=sequence_length,
@@ -328,6 +329,7 @@ def __init__(
         dropout: float,
         norm_type: Union[str, nn.Module],
         normalize_before: bool,
+        project_first_layer: bool = True,
         **kwargs,
     ) -> None:
         """
@@ -362,6 +364,11 @@ def __init__(
             Type of normalization to use.
         normalize_before
             Whether to apply normalization before or after mixing.
+        project_first_layer
+            Whether to project to the output time dimension at the first layer (default),
+            or at the end of the module. False is recommended if there are
+            no future covariates, while True is recommended if there are
+            important future covariates.
         """
         super().__init__(**kwargs)
         self.input_dim = input_dim
@@ -369,11 +376,16 @@ def __init__(
         self.future_cov_dim = future_cov_dim
         self.static_cov_dim = static_cov_dim
         self.nr_params = nr_params
+        self.project_first_layer = project_first_layer
+
+        self.sequence_length = (
+            self.output_chunk_length if project_first_layer else self.input_chunk_length
+        )
 
         if activation not in ACTIVATIONS:
             raise_log(
                 ValueError(
-                    f"Invalid `activation={activation}`. Must be on of {ACTIVATIONS}."
+                    f"Invalid `activation={activation}`. Must be one of {ACTIVATIONS}."
                 ),
                 logger=logger,
             )
@@ -383,7 +395,7 @@ def __init__(
         if norm_type not in NORMS:
             raise_log(
                 ValueError(
-                    f"Invalid `norm_type={norm_type}`. Must be on of {NORMS}."
+                    f"Invalid `norm_type={norm_type}`. Must be one of {NORMS}."
                 ),
                 logger=logger,
             )
@@ -402,16 +414,25 @@ def __init__(
             "normalize_before": normalize_before,
         }
 
+        # Projects from the input time dimension to the output time dimension
         self.fc_hist = nn.Linear(self.input_chunk_length, self.output_chunk_length)
+
+        # Projects future covariates from the output time dimension back to the input
+        # time dimension (only needed when project_first_layer=False)
+        if not self.project_first_layer:
+            self.fc_future = nn.Linear(
+                self.output_chunk_length, self.input_chunk_length
+            )
+
         self.feature_mixing_hist = _FeatureMixing(
-            sequence_length=self.output_chunk_length,
+            sequence_length=self.sequence_length,
             input_dim=input_dim + past_cov_dim + future_cov_dim,
             output_dim=hidden_size,
             **mixer_params,
         )
         if future_cov_dim:
             self.feature_mixing_future = _FeatureMixing(
-                sequence_length=self.output_chunk_length,
+                sequence_length=self.sequence_length,
                 input_dim=future_cov_dim,
                 output_dim=hidden_size,
                 **mixer_params,
@@ -419,7 +440,7 @@
         else:
             self.feature_mixing_future = None
         self.conditional_mixer = self._build_mixer(
-            prediction_length=self.output_chunk_length,
+            sequence_length=self.sequence_length,
             num_blocks=num_blocks,
             hidden_size=hidden_size,
             future_cov_dim=future_cov_dim,
@@ -430,7 +451,7 @@
 
     @staticmethod
     def _build_mixer(
-        prediction_length: int,
+        sequence_length: int,
         num_blocks: int,
         hidden_size: int,
         future_cov_dim: int,
@@ -448,7 +469,7 @@ def _build_mixer(
             layer = _ConditionalMixerLayer(
                 input_dim=input_dim_block,
                 output_dim=hidden_size,
-                sequence_length=prediction_length,
+                sequence_length=sequence_length,
                 static_cov_dim=static_cov_dim,
                 **kwargs,
             )
@@ -480,6 +501,7 @@ def forward(
         # B: batch size
         # L: input chunk length
        # T: output chunk length
+        # SL: residual block time dimension (T if project_first_layer, L otherwise)
         # C: target components
         # P: past cov features
         # F: future cov features
@@ -491,31 +513,62 @@ def forward(
         # `x`: (B, L, H), `x_future`: (B, T, F), `x_static`: (B, C or 1, S)
         x, x_future, x_static = x_in
 
-        # swap feature and time dimensions (B, L, H) -> (B, H, L)
-        x = _time_to_feature(x)
-        # linear transformations to horizon (B, H, L) -> (B, H, T)
-        x = self.fc_hist(x)
-        # (B, H, T) -> (B, T, H)
-        x = _time_to_feature(x)
-
-        # feature mixing for historical features (B, T, H) -> (B, T, H_S)
+        # If project_first_layer, decoder-style model with residual blocks in the output time dimension
+        # (B, L, H) -> (B, SL, H)
+        if self.project_first_layer:
+            # swap feature and time dimensions (B, L, H) -> (B, H, L)
+            x = _time_to_feature(x)
+            # linear transformation to SL (T in this case)
+            # (B, H, L) -> (B, H, SL)
+            x = self.fc_hist(x)
+            # transpose back
+            # (B, H, SL) -> (B, SL, H)
+            x = _time_to_feature(x)
+
+        # Otherwise, encoder-style model with residual blocks in the input time dimension.
+        # In the original paper this was not implemented for future covariates,
+        # but rather than ignoring them or raising an error we remap them to the input time dimension.
+        # This is suboptimal, but may be useful in some cases.
+        elif self.future_cov_dim:
+            # swap feature and time dimensions (B, L, F) -> (B, F, L)
+            x_future = _time_to_feature(x_future)
+            # linear transformation to SL (L in this case)
+            # (B, F, T) -> (B, F, SL)
+            x_future = self.fc_future(x_future)
+            # transpose back (B, F, SL) -> (B, SL, F)
+            x_future = _time_to_feature(x_future)
+
+        # feature mixing for historical features (B, SL, H) -> (B, SL, H_S)
         x = self.feature_mixing_hist(x)
         if self.future_cov_dim:
-            # feature mixing for future features (B, T, F) -> (B, T, H_S)
+            # feature mixing for future features (B, SL, F) -> (B, SL, H_S)
             x_future = self.feature_mixing_future(x_future)
-            # (B, T, H_S) + (B, T, H_S) -> (B, T, 2*H_S)
+            # (B, SL, H_S) + (B, SL, H_S) -> (B, SL, 2*H_S)
             x = torch.cat([x, x_future], dim=-1)
 
         if self.static_cov_dim:
             # (B, C, S) -> (B, 1, C * S)
             x_static = x_static.reshape(x_static.shape[0], 1, -1)
-            # repeat to match horizon (B, 1, C * S) -> (B, T, C * S)
-            x_static = x_static.repeat(1, self.output_chunk_length, 1)
+            # repeat to match time dim: (B, 1, C * S) -> (B, SL, C * S)
+            x_static = x_static.repeat(1, self.sequence_length, 1)
 
         for mixing_layer in self.conditional_mixer:
-            # conditional mixer layers with static covariates (B, T, 2 * H_S), (B, T, C * S) -> (B, T, H_S)
+            # conditional mixer layers with static covariates (B, SL, 2 * H_S), (B, SL, C * S) -> (B, SL, H_S)
             x = mixing_layer(x, x_static=x_static)
 
+        # If we are still in the input time dimension, we need to project to the output time dimension.
+        # The original paper did not have an fc_out layer (since hidden_size == output_dim there),
+        # so we had to decide where to place the time projection relative to it.
+        # We apply the time projection before fc_out because, while both operations may be very compressive,
+        # we felt it more likely that output_dim << hidden_size than output_chunk_length << input_chunk_length.
+        if not self.project_first_layer:
+            # (B, SL, H_S) -> (B, H_S, SL)
+            x = _time_to_feature(x)
+            # (B, H_S, SL) -> (B, H_S, T)
+            x = self.fc_hist(x)
+            # (B, H_S, T) -> (B, T, H_S)
+            x = _time_to_feature(x)
+
         # linear transformation to generate the forecast (B, T, H_S) -> (B, T, C * N_P)
         x = self.fc_out(x)
         # (B, T, C * N_P) -> (B, T, C, N_P)
@@ -537,6 +590,7 @@ def __init__(
         norm_type: Union[str, nn.Module] = "LayerNorm",
         normalize_before: bool = False,
         use_static_covariates: bool = True,
+        project_first_layer: bool = True,
         **kwargs,
     ) -> None:
         """Time-Series Mixer (TSMixer): An All-MLP Architecture for Time Series.
@@ -591,6 +645,10 @@ def __init__(
             `"LayerNormNoBias", "LayerNorm", "TimeBatchNorm2d"`. Otherwise, must be a custom `nn.Module`.
         normalize_before
             Whether to apply layer normalization before or after mixer layer.
+        project_first_layer
+            Whether to project to the output time dimension at the first layer (default), or at the end of the module.
+            Projecting last is recommended if there are no future covariates, while projecting first is recommended if
+            there are important future covariates.
         use_static_covariates
             Whether the model should use static covariate information in case the input `series` passed to ``fit()`` contain static covariates.
             If ``True``, and static covariates are available at fitting time, will enforce
@@ -774,6 +832,7 @@ def encode_year(idx):
         self.normalize_before = normalize_before
         self.norm_type = norm_type
         self.hidden_size = hidden_size
+        self.project_first_layer = project_first_layer
         self._considers_static_covariates = use_static_covariates
 
     def _create_model(self, train_sample: MixedCovariatesTrainTensorType) -> nn.Module:
@@ -825,6 +884,7 @@ def _create_model(self, train_sample: MixedCovariatesTrainTensorType) -> nn.Modu
             dropout=self.dropout,
             norm_type=self.norm_type,
             normalize_before=self.normalize_before,
+            project_first_layer=self.project_first_layer,
             **self.pl_module_params,
         )
diff --git a/darts/tests/models/forecasting/test_tsmixer.py b/darts/tests/models/forecasting/test_tsmixer.py
index 5eb7f80d57..779347f21f 100644
--- a/darts/tests/models/forecasting/test_tsmixer.py
+++ b/darts/tests/models/forecasting/test_tsmixer.py
@@ -362,3 +362,25 @@ def test_time_batch_norm_2d_gradients(self):
         output.mean().backward()
 
         assert input_tensor.grad is not None
+
+    @pytest.mark.parametrize("project_first_layer", [True, False])
+    def test_project_first(self, project_first_layer):
+        ts = tg.sine_timeseries(length=36, freq="h")
+        input_len = 12
+        output_len = 6
+
+        model = TSMixerModel(
+            input_chunk_length=input_len,
+            output_chunk_length=output_len,
+            n_epochs=1,
+            project_first_layer=project_first_layer,
+            # Cover case of projecting future covs back to input dims
+            add_encoders={"cyclic": {"future": "hour"}},
+            **tfm_kwargs,
+        )
+        model.fit(ts)
+
+        if project_first_layer:
+            assert model.model.sequence_length == output_len
+        else:
+            assert model.model.sequence_length == input_len
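
Usage note (not part of the patch): a minimal sketch of how the new `project_first_layer` keyword could be exercised from user code. The toy series, covariate, and hyperparameter values below are purely illustrative; only `TSMixerModel`, its existing constructor arguments, and the `project_first_layer` keyword added by this diff are assumed.

from darts.models import TSMixerModel
from darts.utils import timeseries_generation as tg

# illustrative toy data: an hourly sine target with an aligned past covariate
series = tg.sine_timeseries(length=200, freq="h")
past_covs = tg.linear_timeseries(length=200, freq="h")

# project_first_layer=False keeps the mixer blocks in the input time dimension
# (24 steps here) and only projects to the 12-step horizon at the end of the
# module, the configuration suggested when past covariates matter more than
# future covariates
model = TSMixerModel(
    input_chunk_length=24,
    output_chunk_length=12,
    project_first_layer=False,
    n_epochs=1,  # short training run, for illustration only
)
model.fit(series, past_covariates=past_covs)
forecast = model.predict(n=12)

With the default `project_first_layer=True`, the same call projects to the output time dimension in the first layer, which the docstring recommends when informative future covariates are available.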