
How to log/store xLSTM-based models? | Issue with MLflow and pickle logging #59

Open
meteoDaniel opened this issue Nov 15, 2024 · 0 comments


meteoDaniel commented Nov 15, 2024

Hey there,

I really like your work on xLSTM, and I am currently implementing the xLSTM-Mixer.

After fitting the model, I ran into the following issue:

src/engine/modules/modelling/pytorch_core.py:459: in _execution_function
    mlflow.pytorch.log_model(
/usr/local/lib/python3.10/site-packages/mlflow/pytorch/__init__.py:295: in log_model
    return Model.log(
/usr/local/lib/python3.10/site-packages/mlflow/models/model.py:725: in log
    flavor.save_model(path=local_path, mlflow_model=mlflow_model, **kwargs)
/usr/local/lib/python3.10/site-packages/mlflow/pytorch/__init__.py:479: in save_model
    torch.save(pytorch_model, model_path, pickle_module=pickle_module, **kwargs)
/usr/local/lib/python3.10/site-packages/torch/serialization.py:850: in save
    _save(
/usr/local/lib/python3.10/site-packages/torch/serialization.py:1088: in _save
    pickler.dump(obj)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <cloudpickle.cloudpickle.Pickler object at 0x7d589e98d600>
obj = PytorchxLSTMMixer(
  (validation_metric_function): RMSELoss(
    (mse): MSELoss()
  )
  (mlp_in): Sequential(
    (0):..._2_seq_var): Rearrange('batch var seq -> batch seq var')
  (Linear): Linear(in_features=2, out_features=1, bias=True)
)

    def dump(self, obj):
        try:
>           return super().dump(obj)
E           TypeError: cannot pickle 'slstm_HS4BS8NH2NS4DBfDRbDWbDGbDSbDAfNG4SA1GRCV0GRC0d0FCV0FC0d0.sLSTMFunc' object

/usr/local/lib/python3.10/site-packages/cloudpickle/cloudpickle.py:1295: TypeError

As you can see, I use the model in an MLflow environment, and MLflow tries to log the model with pickle. I think the compiled C++ extension used by the sLSTM model cannot be pickled here.
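The failure can be reproduced without PyTorch: `pickle` (and `cloudpickle`, which the traceback shows MLflow using here) serializes the whole object graph, so a single attribute that wraps a native object is enough to make the dump fail. The sketch below is a minimal, stdlib-only illustration, not the xLSTM code: `KernelHandle`, `NaiveBlock`, `PicklableBlock` and `try_dump` are hypothetical names, and a `threading.Lock` stands in for the compiled `sLSTMFunc` object. It also shows one generic workaround, assuming the native object can be rebuilt from plain configuration: drop it in `__getstate__` and recreate it in `__setstate__`.

```python
import pickle
import threading


class KernelHandle:
    """Hypothetical stand-in for a compiled-extension object (like the
    sLSTMFunc in the traceback); it holds a resource pickle rejects."""

    def __init__(self, config: str):
        self.config = config
        # threading.Lock cannot be pickled, much like a C++ extension object.
        self._native = threading.Lock()


class NaiveBlock:
    """Hypothetical layer that stores the handle as a plain attribute."""

    def __init__(self, config: str, weights: list):
        self.config = config
        self.weights = weights
        self.kernel = KernelHandle(config)


class PicklableBlock(NaiveBlock):
    """Same layer, but the handle is excluded from its pickled state."""

    def __getstate__(self):
        # Copy the attributes and drop the unpicklable handle.
        state = self.__dict__.copy()
        del state["kernel"]
        return state

    def __setstate__(self, state):
        # Restore plain attributes, then rebuild the handle from the config.
        self.__dict__.update(state)
        self.kernel = KernelHandle(self.config)


def try_dump(obj):
    """Return the pickled bytes, or the TypeError message if pickling fails."""
    try:
        return pickle.dumps(obj)
    except TypeError as exc:
        return str(exc)


if __name__ == "__main__":
    print(try_dump(NaiveBlock("cfg", [0.1, 0.2])))
    # -> cannot pickle '_thread.lock' object

    restored = pickle.loads(pickle.dumps(PicklableBlock("cfg", [0.1, 0.2])))
    print(restored.weights, type(restored.kernel).__name__)
    # -> [0.1, 0.2] KernelHandle
```

Whether this carries over to the sLSTM layer depends on its kernel being regenerable from its configuration at load time, which is an assumption here; applying it to a `torch.nn.Module` subclass would also have to cooperate with the module's own pickling hooks, which this sketch does not test.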

Is there a way to use a pure-PyTorch version of xLSTM, without the compiled extension?
Or do you have any other ideas?
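Another common pattern, sketched here rather than confirmed for xLSTM or MLflow, is to persist only plain data (configuration plus weights) and rebuild the object on load. In PyTorch this corresponds to saving `model.state_dict()` instead of the whole module and restoring it with `load_state_dict`. The stdlib-only sketch below uses a hypothetical `StatefulLayer` whose `state_dict`/`from_state_dict` methods only mirror PyTorch's naming; its native handle (again a `threading.Lock` stand-in) is never serialized.

```python
import json
import threading


class StatefulLayer:
    """Hypothetical layer: plain config and weights plus a native handle
    (a threading.Lock standing in for a compiled kernel object)."""

    def __init__(self, config: dict, weights=None):
        self.config = config
        if weights is None:
            weights = [0.0] * config["size"]
        self.weights = list(weights)
        self._kernel = threading.Lock()  # never serialized

    def state_dict(self) -> dict:
        # Only JSON-friendly data, with no reference to the native handle.
        return {"config": self.config, "weights": self.weights}

    @classmethod
    def from_state_dict(cls, state: dict) -> "StatefulLayer":
        # Rebuild the layer (and a fresh native handle) from plain data.
        return cls(state["config"], state["weights"])


if __name__ == "__main__":
    layer = StatefulLayer({"size": 3}, [0.5, -1.0, 2.0])
    payload = json.dumps(layer.state_dict())  # safe to store as an artifact
    restored = StatefulLayer.from_state_dict(json.loads(payload))
    print(restored.weights)  # -> [0.5, -1.0, 2.0]
```

The trade-off is that loading requires the model class (and, for xLSTM, the extension build) to be available in the loading environment, since only the data is stored.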

Best regards and thanks in advance.
Daniel
