Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/val_sample_weight error for models inherited from RegressionModel #2626

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ but cannot always guarantee backwards compatibility. Changes that may **break co

[Full Changelog](https://github.com/unit8co/darts/compare/0.32.0...master)

- Fix the bug in [#2579 ](https://github.com/unit8co/darts/issues/2579) that causes an error when `val_sample_weight` is set in the CatBoost and XGBoost models.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
- Fix the bug in [#2579 ](https://github.com/unit8co/darts/issues/2579) that causes an error when `val_sample_weight` is set in the CatBoost and XGBoost models.
- Fix a bug in `RegressionModel` when `val_sample_weight` is used with a single timeseries. [#2626](https://github.com/unit8co/darts/pull/2626) by [Kylin Schmidt](https://github.com/kylinschmidt).


### For users of the library:

**Improved**
Expand Down Expand Up @@ -1440,7 +1442,7 @@ ts: TimeSeries = AirPassengers().load()
```python
# Assuming a multivariate TimeSeries named series with 3 columns or variables.
# To apply fn to columns with names '0' and '2':

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you revert this change and the others below?

#old syntax
series.map(fn, cols=['0', '2']) # returned a time series with 3 columns
#new syntax
Expand All @@ -1452,13 +1454,13 @@ ts: TimeSeries = AirPassengers().load()
```python
#old syntax
fillna(series, fill=0)

#new syntax
fill_missing_values(series, fill=0)

#old syntax
auto_fillna(series, **interpolate_kwargs)

#new syntax
fill_missing_values(series, fill='auto', **interpolate_kwargs)
fill_missing_values(series, **interpolate_kwargs) # fill='auto' by default
Expand Down Expand Up @@ -1496,13 +1498,13 @@ ts: TimeSeries = AirPassengers().load()
```python
# old syntax:
backtest_forecasting(forecasting_model, *args, **kwargs)

# new syntax:
forecasting_model.backtest(*args, **kwargs)

# old syntax:
backtest_regression(regression_model, *args, **kwargs)

# new syntax:
regression_model.backtest(*args, **kwargs)
```
Expand All @@ -1511,13 +1513,13 @@ ts: TimeSeries = AirPassengers().load()
```python
# old syntax:
multivariate_model.fit(multivariate_series, target_indices=[0, 1])

# new syntax:
multivariate_model.fit(multivariate_series, multivariate_series[["0", "1"]])

# old syntax:
univariate_model.fit(multivariate_series, component_index=2)

# new syntax:
univariate_model.fit(multivariate_series["2"])
```
Expand Down
2 changes: 1 addition & 1 deletion darts/models/forecasting/regression_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,7 @@ def _add_val_set_to_kwargs(
val_weights = val_weights or None
else:
val_sets = [(val_samples, val_labels)]
val_weights = val_weight
val_weights = [val_weight]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks neat!


val_set_name, val_weight_name = self.val_set_params
return dict(kwargs, **{val_set_name: val_sets, val_weight_name: val_weights})
Expand Down
Loading