ONNX Support #553
Comments
Thanks for the suggestion @falkmatt91. Indeed, it would be nice. We are also very open to receiving pull requests if you want to contribute something along these lines. I think most of the work here should revolve around the |
Hi! This is a very interesting request indeed. I am afraid I don't have the experience required to contribute on my own, as it would be my first time. If anyone is interested in a collaboration, I could join, however! By the way, I have developed some darts models for a project that will use ONNX as its interoperability standard. Is there currently any way to deploy those models as ONNX (even if it is not the most efficient one)? @falknerdominik is it possible to store darts models as pkl files and then convert them to ONNX? Would they be fully functional? Does ONNX accept darts.TimeSeries as (example) input? Thanks in advance! |
@hrzn do you have a starting point for this? I might be interested in implementing part of it. @pelekhs ONNX models are saved as a single self-contained file which only allows you to do minimal inference (training data etc. is normally not saved, and the model cannot be retrained, as far as I know). The nice thing is that it can be easily deployed, and the ONNX foundation provides Docker images which can load any ONNX file and output an inference. Inputs/outputs have to be described manually (although we can probably specify the output automatically). |
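To make "described manually" a bit more concrete, here is a minimal sketch of what such a description could look like on the caller side. `TensorSpec`, `nested_shape` and `matches` are hypothetical helpers written for illustration, not part of darts or ONNX, and the `(batch, input_chunk_length, n_components)` layout with 30 past steps is an assumption.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class TensorSpec:
    """Hypothetical description of one model input or output."""
    name: str
    shape: Tuple[Optional[int], ...]  # None marks a dynamic axis, e.g. the batch size
    dtype: str


def nested_shape(data) -> Tuple[int, ...]:
    """Return the shape of a rectangular nested list."""
    shape = []
    while isinstance(data, (list, tuple)):
        shape.append(len(data))
        data = data[0] if data else None
    return tuple(shape)


def matches(spec: TensorSpec, data) -> bool:
    """Check that `data` has the rank and the fixed axis sizes required by `spec`."""
    actual = nested_shape(data)
    if len(actual) != len(spec.shape):
        return False
    return all(want is None or want == got for want, got in zip(spec.shape, actual))


# Assumed layout: (batch, input_chunk_length, n_components) with 30 past steps.
past_spec = TensorSpec(name="input_past", shape=(None, 30, 1), dtype="float32")
window = [[[float(i)] for i in range(30)]]  # one sample, 30 steps, 1 component

print(matches(past_spec, window))
```

A check like this only validates the shape before the data is handed to a runtime; the dtype field is carried along as documentation and is not enforced here.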
After answering a comment on Gitter, I noticed that some recently implemented features could be reused for this feature:
I am bumping this feature up in the backlog; it should become low-hanging fruit. |
Hello, |
I started working on a PR to close this issue but am encountering some difficulties when loading the ONNX model and trying to run inference. If anyone with experience in ONNX (or who managed to make it work for darts models) wants to pick it up, go ahead! |
Hello, has there been any work done on ONNX export functionality for the darts models (the pytorch-lightning ones)? If not - does someone know if there is a work-around for exporting NBeats/TCN to ONNX format? |
Hi @JoonasHL, Made a bit of progress but other features got priority over this one. Still on the roadmap, I need to find some time for it (or happy to let someone else take over this). I wrote a workaround in #1521, but the preparation for inference once in ONNX format is not straightforward yet. Since |
Thanks @madtoinou. I finally got it to work for my use case. I also ran into some issues when running inference with the exported ONNX model. From what I read in the onnxruntime forum, support for the double type seems to be limited. What I did was cast the model and the input data to float32, and using your example in #1521 I got it to work. Sample code:
import numpy as np
import onnxruntime
import torch
from darts.models import NBEATSModel

model = NBEATSModel( ... )
model.fit( ... )
# The underlying PyTorch Lightning module is what gets exported
model_for_onnx = model.model
# Cast model dtype to float32
model_for_onnx = model_for_onnx.to(torch.float32)
dim_component = 2
# Turn each element of the training sample into a tensor with a batch dimension
(
    past_target,
    past_covariates,
    future_past_covariates,
    static_covariates,
) = [torch.Tensor(x).unsqueeze(0) if x is not None else None for x in model.train_sample]
n_past_covs = (
    past_covariates.shape[dim_component] if past_covariates is not None else 0
)
# Concatenate target and past covariates along the component axis
input_past = torch.cat(
    [ds for ds in [past_target, past_covariates] if ds is not None],
    dim=dim_component,
)
# Cast the example inputs to float32 as well
input_sample = [input_past.float(), static_covariates.float()]
model_for_onnx.to_onnx("test_export.onnx", input_sample=input_sample)

# Load the exported model and run inference with onnxruntime
onnx_model_path = "test_export.onnx"
onnx_session = onnxruntime.InferenceSession(onnx_model_path)
# 30 past values (1..20 followed by 1..10), shaped (batch, time steps, components)
input_np_float = np.array(
    list(range(1, 21)) + list(range(1, 11)), dtype=np.float32
).reshape(1, 30, 1)
input_name = onnx_session.get_inputs()[0].name
output_name = onnx_session.get_outputs()[0].name
result = onnx_session.run([output_name], {input_name: input_np_float})
print("Output shape:", result[0].shape)
print("Predictions:", result[0]) |
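The input preparation in the comment above (cast to float32, reshape to `(1, 30, 1)`) can be wrapped in a small helper. This is only a sketch: `make_onnx_input` is a hypothetical name, and it assumes the exported model expects the last `input_chunk_length` values in a `(batch, time steps, components)` layout, as in the sample code.

```python
import numpy as np


def make_onnx_input(values, input_chunk_length, n_components=1):
    """Build a float32 array of shape (1, input_chunk_length, n_components).

    Hypothetical helper: keeps only the most recent `input_chunk_length` steps.
    """
    # Cast to float32 up front, since double inputs caused problems at inference time.
    arr = np.asarray(values, dtype=np.float32).reshape(-1, n_components)
    if arr.shape[0] < input_chunk_length:
        raise ValueError(
            f"need at least {input_chunk_length} time steps, got {arr.shape[0]}"
        )
    window = arr[-input_chunk_length:]
    # Add the leading batch axis.
    return window[np.newaxis, ...]


# A float64 history of 40 values; only the last 30 are kept.
history = np.arange(1, 41, dtype=np.float64)
x = make_onnx_input(history, input_chunk_length=30)
print(x.shape, x.dtype)
```

The resulting array could then be passed to `onnx_session.run` in place of `input_np_float`, provided the window length matches the one used at export time.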
Is your feature request related to a current problem? Please describe.
Using models on other machines or inside containers is a hassle and is currently (according to the documentation) only supported by pickling. Pickling can be error-prone and produces big files containing a lot of unneeded data (> 1 MB for non-baseline models).
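As a rough illustration of the file-size point, the sketch below pickles a toy "model state" with and without its training data. The dictionaries are hypothetical stand-ins and not darts objects; real models will differ in what they store.

```python
import pickle

# Toy stand-in for a fitted model's state (hypothetical, not a darts object).
weights = [0.5, 0.25, 0.125]
training_series = [float(i) for i in range(100_000)]

# State as it might be pickled: parameters plus everything else the object holds.
full_state = {"weights": weights, "training_series": training_series}
# State that inference actually needs.
inference_state = {"weights": weights}

full_size = len(pickle.dumps(full_state))
inference_size = len(pickle.dumps(inference_state))

print(f"full state:      {full_size} bytes")
print(f"inference state: {inference_size} bytes")
```

An inference-only export format avoids carrying the training data along, which is the main appeal described in this request.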
Describe proposed solution
Support ONNX (as is possible for scikit-learn models), which allows saving and loading the model.
Describe potential alternatives
Provide your own way to serialize / deserialize a model (probably not the best idea)