
Example_input_array missing for converting to ONNX #1521

Closed
solalatus opened this issue Jan 27, 2023 · 2 comments
Labels: question (Further information is requested)


@solalatus (Contributor)

Describe the bug
Because the example_input_array attribute is not set on an NHiTS model loaded from a checkpoint, exporting it to ONNX with the to_onnx() function fails.

To Reproduce

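```python
from darts.models import NHiTSModel

loaded_model = NHiTSModel.load_from_checkpoint(...)
loaded_model.fit(...)
# to_onnx() is a PyTorch Lightning method, so it is called on the underlying module
loaded_model.model.to_onnx("filename")
```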

ValueError: Could not export to ONNX since neither `input_sample` nor `model.example_input_array` attribute is set.

Expected behavior
To output an ONNX file.

System:

  • Python version: 3.8
  • darts version: 0.23.1

Additional context
The lack of both input_sample and model.example_input_array also makes it difficult to inspect the network with graph visualization tools, since they typically need a sample input to run a single forward pass.

@solalatus added the bug (Something isn't working) and triage (Issue waiting for triaging) labels on Jan 27, 2023
@madtoinou (Collaborator)

Hi,

Thank you for reporting this bug. It seems to be caused by darts not using the PyTorch Lightning example_input_array attribute: it stores train_sample instead, which cannot be fed directly to the model's forward pass. We should probably process the train_sample tuple into a suitable input and store it in example_input_array.
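The proposed fix (turn train_sample into a forward-pass input and store it in example_input_array) can be sketched at the level of array shapes. This is a minimal NumPy sketch, not darts code: train_sample_to_example_input and FakeModule are hypothetical names, the four-element layout simply follows the unpacking used in the workaround snippet of this comment, and the toy arrays are made up. Each element gets a leading batch axis, past target and past covariates are concatenated along the component axis, and the last element is kept as the second input.

```python
import numpy as np
from typing import Optional, Sequence, Tuple


def train_sample_to_example_input(
    train_sample: Sequence[Optional[np.ndarray]],
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """Hypothetical helper: build an (input_past, static_covariates) pair.

    Assumes the 4-element layout unpacked in the workaround snippet, with
    each array shaped (time, component).
    """
    # add a leading batch axis to every non-None element
    past_target, past_covariates, _future_past_covariates, static_covariates = [
        None if x is None else np.expand_dims(np.asarray(x, dtype=np.float64), 0)
        for x in train_sample
    ]
    # after adding the batch axis, axis 2 is the component axis
    input_past = np.concatenate(
        [a for a in (past_target, past_covariates) if a is not None], axis=2
    )
    return input_past, static_covariates


class FakeModule:
    """Hypothetical stand-in for a LightningModule; only holds the attribute."""

    example_input_array = None


# made-up sample: 24 steps, 1 target component, 3 past covariates
train_sample = (
    np.zeros((24, 1)),
    np.ones((24, 3)),
    np.ones((12, 3)),
    np.zeros((1, 2)),
)
module = FakeModule()
module.example_input_array = train_sample_to_example_input(train_sample)
print(module.example_input_array[0].shape)  # (1, 24, 4)
```

A real fix would do the same with tensors inside darts, right after train_sample is set, so that to_onnx() and graph tools could fall back to example_input_array.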

In the meantime, it's possible to mimic the preprocessing performed by darts with the following snippet (modified from PLPastCovariatesModule._get_batch_prediction, since NHiTS inherits from it):

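```python
import torch

# after unsqueeze(0) each tensor is (batch, time, component)
dim_component = 2
(
    past_target,
    past_covariates,
    future_past_covariates,
    static_covariates,
) = [torch.Tensor(x).unsqueeze(0) if x is not None else None for x in loaded_model.train_sample]

# kept from _get_batch_prediction; not used below
n_past_covs = (
    past_covariates.shape[dim_component] if past_covariates is not None else 0
)

# stack past target and past covariates along the component axis
input_past = torch.cat(
    [ds for ds in [past_target, past_covariates] if ds is not None],
    dim=dim_component,
)

# .double() assumes float64 model weights; it fails if the last element is None
input_sample = [input_past.double(), static_covariates.double()]
loaded_model.model.to_onnx("filename", input_sample=input_sample)
```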

@solalatus (Contributor, Author) commented Feb 6, 2023

Cool that you could take a look! Thanks!

If storing a real sample is a problem because of data leakage or privacy concerns, I think random but well-shaped data could be stored there instead, like here, so maybe doing it that way could help...
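A minimal sketch of that idea, assuming only the shapes and dtypes matter for tracing and export: random_like is a hypothetical helper (not part of darts or PyTorch Lightning) that replaces every array in a sample tuple with seeded standard-normal noise of the same shape and dtype, leaving None entries in place. The example tuple is made up.

```python
import numpy as np
from typing import Optional, Sequence, Tuple


def random_like(
    sample: Sequence[Optional[np.ndarray]], seed: int = 0
) -> Tuple[Optional[np.ndarray], ...]:
    """Hypothetical helper: same shapes and dtypes, random values, None kept."""
    rng = np.random.default_rng(seed)
    return tuple(
        None if x is None
        else rng.standard_normal(np.shape(x)).astype(np.asarray(x).dtype)
        for x in sample
    )


# made-up "real" sample: no covariates, 24-step input, 12-step last element
real_sample = (np.arange(24.0).reshape(24, 1), None, None, np.ones((12, 1)))
dummy = random_like(real_sample)
print([None if d is None else d.shape for d in dummy])  # [(24, 1), None, None, (12, 1)]
```

Storing noise instead of a real training window keeps data out of the saved model; the trade-off is that the stored sample is only useful for shape-driven tasks such as export or graph visualization.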

Should I close this issue, or do you plan to work on it?

@hrzn changed the title from "[BUG] example_input_array attribute not set in NHiTS models" to "[BUG] example_input_array missing for converting to ONNX" on Feb 10, 2023
@hrzn added the question (Further information is requested) label and removed the bug (Something isn't working) and triage (Issue waiting for triaging) labels on Feb 10, 2023
@hrzn changed the title from "[BUG] example_input_array missing for converting to ONNX" to "Example_input_array missing for converting to ONNX" on Feb 10, 2023