Auto batch size for torch model #2318
base: master
Codecov Report
All modified and coverable lines are covered by tests ✅

Additional details and impacted files

@@            Coverage Diff             @@
##           master    #2318      +/-   ##
==========================================
+ Coverage   93.75%   94.02%   +0.26%
==========================================
  Files         138      138
  Lines       14352    14152     -200
==========================================
- Hits        13456    13306     -150
+ Misses        896      846      -50
Thanks for implementing this @BohdanBilonoh. It already looks pretty good 🚀
I added some suggestions on how we could simplify things a bit, and also on how to support the "predict" method.
epochs: int = 0,
max_samples_per_ts: Optional[int] = None,
num_loader_workers: int = 0,
These are not required, I guess.
Suggested change (remove these lines):
epochs: int = 0,
max_samples_per_ts: Optional[int] = None,
num_loader_workers: int = 0,
epochs: int = 0,
max_samples_per_ts: Optional[int] = None,
num_loader_workers: int = 0,
method: Literal["fit", "validate", "test", "predict"] = "fit",
"test" is not supported in Darts.
"predict" would require a datamodule for prediction.
In my opinion, we should only support "fit" and "predict", since we use the same batch size for train and val.
batch_arg_name
The name of the argument to scale in the model. Defaults to 'batch_size'.
Not required.
Suggested change (remove these lines):
batch_arg_name
The name of the argument to scale in the model. Defaults to 'batch_size'.
self.batch_size = batch_size

def train_dataloader(self):
    return DataLoader(
Since we also use this in `_setup_for_train` and `predict_from_dataset`, it would be good to have this logic in a private method, for example `_build_dataloader()`, that takes a dataset as input and returns the dataloader for "train", "val", or "predict".
def scale_batch_size(
    self,
    series: Union[TimeSeries, Sequence[TimeSeries]],
    val_series: Union[TimeSeries, Sequence[TimeSeries]],
I think we can remove all `val_*` arguments, since we use the same batch size for training and evaluation.
For the `val_dataloader` in the `DataModule`, we can just use `series`, `past_covariates`, `future_covariates` for the input dataset.
@@ -1373,6 +1373,21 @@ def test_lr_find(self):
)
assert scores["worst"] > scores["suggested"]

@pytest.mark.slow
It's not slow.
Suggested change (remove this line):
@pytest.mark.slow
@@ -1373,6 +1373,21 @@ def test_lr_find(self):
)
assert scores["worst"] > scores["suggested"]

@pytest.mark.slow
def test_scale_batch_size(self):
Test this for both method="fit" and method="predict".
CHANGELOG.md
Outdated
@@ -89,6 +89,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
- Renamed the private `_is_probabilistic` property to a public `supports_probabilistic_prediction`.
- Improvements to `DataTransformer`: [#2267](https://github.com/unit8co/darts/pull/2267) by [Alicja Krzeminska-Sciga](https://github.com/alicjakrzeminska).
- `InvertibleDataTransformer` now supports parallelized inverse transformation for `series` being a list of lists of `TimeSeries` (`Sequence[Sequence[TimeSeries]]`). This `series` type represents for example the output from `historical_forecasts()` when using multiple series.
- New method `TorchForecastingModel.scale_batch_size()` that helps to find batch size automatically. [#2318](https://github.com/unit8co/darts/pull/2318) by [Bohdan Bilonoh](https://github.com/BohdanBilonoh)
Suggested change, replacing:
- New method `TorchForecastingModel.scale_batch_size()` that helps to find batch size automatically. [#2318](https://github.com/unit8co/darts/pull/2318) by [Bohdan Bilonoh](https://github.com/BohdanBilonoh)
with:
- Improvements to `TorchForecastingModel`:
  - New method `TorchForecastingModel.scale_batch_size()` to find the maximum batch size for fit and predict before memory would run out. [#2318](https://github.com/unit8co/darts/pull/2318) by [Bohdan Bilonoh](https://github.com/BohdanBilonoh)
@dennisbader could you please help with
eecfd0f to ae7a128
Checklist before merging this PR:
Summary
Auto batch size finding for `TorchForecastingModel`. This is just a wrapper for `lightning.pytorch.tuner.tuning.Tuner.scale_batch_size`.
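For readers unfamiliar with batch size scaling, here is a simplified sketch of the "power" strategy: keep doubling the batch size until a trial runs out of memory, then keep the largest size that worked. This is not Lightning's implementation. The function `scale_batch_size_power` and the callable `try_batch_size` are hypothetical, and only the parameter names `init_val` and `max_trials` are borrowed from `Tuner.scale_batch_size`.

```python
from typing import Callable


def scale_batch_size_power(
    try_batch_size: Callable[[int], None],
    init_val: int = 2,
    max_trials: int = 25,
) -> int:
    """Double the batch size until a trial fails; return the largest that fit."""
    batch_size = init_val
    largest_ok = 0
    for _ in range(max_trials):
        try:
            # Hypothetical trial: runs a few steps with this batch size.
            try_batch_size(batch_size)
        except MemoryError:
            break  # this size no longer fits, stop searching
        largest_ok = batch_size
        batch_size *= 2
    if largest_ok == 0:
        raise RuntimeError("even the initial batch size did not fit in memory")
    return largest_ok


def fake_trial(batch_size: int) -> None:
    """Stand-in for a real training trial: pretend memory runs out above 100."""
    if batch_size > 100:
        raise MemoryError(f"batch size {batch_size} does not fit")


# Sizes 2, 4, ..., 64 succeed and 128 fails, so the result is 64.
largest = scale_batch_size_power(fake_trial)
```

In the PR itself, the search is delegated to Lightning, and the trial corresponds to running the model's "fit" or "predict" loop for a few steps.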