JIT scripting is broken #246

Open
jramapuram opened this issue Mar 22, 2022 · 11 comments · Fixed by #252

@jramapuram

jramapuram commented Mar 22, 2022

🐛 Bug

JIT scripting xformers (running commit 357545a) breaks with the following error:

xformers/components/attention/attention_mask.py", line 128
    def __add__(self, other):
        assert isinstance(other, type(self))
                                 ~~~~~~~~~ <--- HERE
        return AttentionMask(self.values + other.values, is_causal=False)
'AttentionMask.__add__' is being compiled since it was called from '__torch__.xformers.components.attention.attention_mask.AttentionMask'

To Reproduce

Blocks config:

reversible: false
block_type: "encoder"
num_layers: 12
dim_model: 768
layer_norm_style: "pre"

multi_head_config:
  num_heads: 12
  residual_dropout: 0.0
  use_rotary_embeddings: false

  attention:
    name: "scaled_dot_product"
    dropout: 0.0
    causal: false

feedforward_config:
  name: "MLP"
  dropout: 0.0
  activation: "gelu"
  hidden_layer_multiplier: 4

Python code to repro:

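import yaml
import torch

from xformers.factory import xFormer, xFormerConfig

# Hypothetical path: the blocks config above, saved to disk
xformer_config_file = "blocks_config.yaml"

with open(xformer_config_file, "rb") as fileptr:
    model_config = yaml.load(fileptr, Loader=yaml.FullLoader)

# Build the model from the single block config, then try to script it
model = xFormer.from_config(xFormerConfig([model_config]))
torch.jit.script(model)
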
@blefaudeux
Contributor

Ah, TorchScript is annoyingly fragile. This part is being completely rewritten by @fmassa, but in the meantime we can probably remove the assert; it's more of a failsafe, really.

@blefaudeux
Contributor

Thanks for the issue and repro steps @jramapuram, super helpful!

@jramapuram
Author

Indeed; got to love it when you have to be verbose about stuff like this:

w, h = tensor.shape[-2:]  # not jit scriptable
w, h = tensor.shape[-2], tensor.shape[-1]  # jit scriptable.

I mean, I get why, but it doesn't hurt any less :)

@jramapuram
Author

Thanks @blefaudeux !

@jramapuram
Author

Not sure this is resolved, sorry! 😬

RuntimeError: Error inferring type for mask: None: 
builtin cannot be used as a value:
  File "xformers/components/attention/attention_mask.py", line 128
    def __add__(self, other):
        return AttentionMask(self.values + other.values, is_causal=False)
                                           ~~~~~~~~~~~~ <--- HERE
'AttentionMask.__add__' is being compiled since it was called from '__torch__.xformers.components.attention.attention_mask.AttentionMask'
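
The second trace points at `other.values`. A likely explanation: TorchScript assumes that an unannotated argument is a `Tensor`, so `other.values` is resolved as the builtin `Tensor.values` method rather than the mask's attribute, hence "builtin cannot be used as a value". A minimal sketch of a scriptable `__add__`, using a hypothetical, simplified `MaskSketch` class (not the actual xformers `AttentionMask`), annotates `other` with the class itself and drops the `type(self)` check:

```python
import torch


@torch.jit.script
class MaskSketch:
    # Hypothetical, simplified stand-in for an additive attention mask
    def __init__(self, values: torch.Tensor, is_causal: bool = False):
        self.values = values
        self.is_causal = is_causal

    def __add__(self, other: "MaskSketch") -> "MaskSketch":
        # With `other` typed as this class, TorchScript looks up `other.values`
        # as the class attribute instead of the Tensor.values builtin
        return MaskSketch(self.values + other.values, is_causal=False)


@torch.jit.script
def add_masks(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Exercises the compiled __add__ from inside TorchScript
    return (MaskSketch(x) + MaskSketch(y)).values
```

Once `other` carries a type annotation, the `isinstance` assert adds no information for the compiler, which is consistent with simply removing it.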

@blefaudeux
Contributor

blefaudeux commented Mar 25, 2022

OK, sorry about that. I'll write this down as a unit test; I should have done that from the beginning.

@blefaudeux blefaudeux self-assigned this Mar 25, 2022
@blefaudeux blefaudeux reopened this Mar 25, 2022
@erip
Contributor

erip commented Mar 25, 2022

Note that this was discussed previously in #168.

@blefaudeux
Contributor

blefaudeux commented Mar 25, 2022

https://github.com/facebookresearch/xformers/tree/jit_with_test takes care of the attention mask part (as suggested by @erip), but TorchScript chokes on many of the flexible constructs that are hard to get rid of while keeping things intercompatible (even **kwargs is a no-go, for instance, and there are a lot of those in the wrappers). I had forgotten about this initially, but I think that's right: each of the xformers components can mostly be made TorchScript-compatible (except for the newer dispatch bits being worked on by @fmassa), but the programmatic construction cannot easily be, and TorchScript is on the way out anyway. Thoughts?
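
To illustrate the **kwargs point: TorchScript rejects compiled functions that take a variable number of arguments, so a wrapper whose `forward` accepts **kwargs cannot be scripted, while one with every argument spelled out and typed can. A minimal sketch with hypothetical `KwargsWrapper` and `ExplicitWrapper` modules and a hypothetical `is_scriptable` helper (none of these are the actual xformers wrappers):

```python
import torch
from torch import nn


class KwargsWrapper(nn.Module):
    # Hypothetical wrapper in the flexible style: forwards arbitrary keyword arguments
    def __init__(self, inner: nn.Module):
        super().__init__()
        self.inner = inner

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        return self.inner(x, **kwargs)


class ExplicitWrapper(nn.Module):
    # Hypothetical wrapper with every argument spelled out and typed
    def __init__(self, inner: nn.Module):
        super().__init__()
        self.inner = inner

    def forward(self, x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        return self.inner(x) * scale


def is_scriptable(module: nn.Module) -> bool:
    # Returns False if torch.jit.script raises while compiling the module
    try:
        torch.jit.script(module)
    except Exception:
        return False
    return True
```

The explicit version scripts fine, but every caller then has to match its exact signature, which is the intercompatibility cost mentioned in the comment.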

@fmassa
Contributor

fmassa commented Apr 1, 2022

I would stick with my initial comment in #168 (comment): it might be preferable to stay away from TorchScript support. I'm happy to reconsider this decision, but I think TorchScript's days may be numbered in favor of other approaches that work directly on Python.

@erip
Contributor

erip commented Apr 1, 2022

Only slightly related, but is there something we (non-Meta folks) can read about the sunsetting of TorchScript?

@jramapuram
Author

@fmassa: correct me if I'm wrong here, but isn't JIT scripting != module packaging / deployment? I get that most folks use JIT scripting for that use case, but aren't there model optimizations (e.g., inlining for loops) that also happen with JIT scripting and that torch.package doesn't touch?
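
One way to see the distinction being asked about: torch.jit.script compiles Python into TorchScript IR, a graph that the JIT's passes operate on, whereas torch.package bundles Python source and pickled objects without compiling them. A minimal sketch with a hypothetical `add_n_times` function, printing the graph produced for a simple loop:

```python
import torch


@torch.jit.script
def add_n_times(x: torch.Tensor, n: int) -> torch.Tensor:
    # A simple counted loop; TorchScript compiles it to a prim::Loop node
    for _ in range(n):
        x = x + 1
    return x


if __name__ == "__main__":
    print(add_n_times.graph)  # the TorchScript IR graph
    print(add_n_times.code)   # the same IR pretty-printed as Python-like code
```

Whether a given optimization actually fires on that graph depends on the PyTorch version and executor settings, so this only shows what the JIT has to work with, not which passes run.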


4 participants