dict key error in demo #49

Open
DmitriyPin opened this issue Jul 12, 2023 · 8 comments
Comments

@DmitriyPin

Hello,
When I run the demo, I get the following dictionary key error. I am using PyTorch 2.0.1 and CUDA 11.7. Is there a specific version of PyTorch that I should use? Any ideas on how I can resolve this?
Thank you.

Traceback (most recent call last):
  File "test.py", line 87, in <module>
    pipe.vae.load_state_dict(new_state_dict)
  File "C:\Users\best4\AppData\Roaming\Python\Python37\site-packages\torch\nn\modules\module.py", line 1672, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for AutoencoderKL:
        Missing key(s) in state_dict: "encoder.mid_block.attentions.0.to_q.weight", "encoder.mid_block.attentions.0.to_q.bias", "encoder.mid_block.attentions.0.to_k.weight", "encoder.mid_block.attentions.0.to_k.bias", "encoder.mid_block.attentions.0.to_v.weight", "encoder.mid_block.attentions.0.to_v.bias", "encoder.mid_block.attentions.0.to_out.0.weight", "encoder.mid_block.attentions.0.to_out.0.bias", "decoder.mid_block.attentions.0.to_q.weight", "decoder.mid_block.attentions.0.to_q.bias", "decoder.mid_block.attentions.0.to_k.weight", "decoder.mid_block.attentions.0.to_k.bias", "decoder.mid_block.attentions.0.to_v.weight", "decoder.mid_block.attentions.0.to_v.bias", "decoder.mid_block.attentions.0.to_out.0.weight", "decoder.mid_block.attentions.0.to_out.0.bias".
        Unexpected key(s) in state_dict: "encoder.mid_block.attentions.0.query.weight", "encoder.mid_block.attentions.0.query.bias", "encoder.mid_block.attentions.0.key.weight", "encoder.mid_block.attentions.0.key.bias", "encoder.mid_block.attentions.0.value.weight", "encoder.mid_block.attentions.0.value.bias", "encoder.mid_block.attentions.0.proj_attn.weight", "encoder.mid_block.attentions.0.proj_attn.bias", "decoder.mid_block.attentions.0.query.weight", "decoder.mid_block.attentions.0.query.bias", "decoder.mid_block.attentions.0.key.weight", "decoder.mid_block.attentions.0.key.bias", "decoder.mid_block.attentions.0.value.weight", "decoder.mid_block.attentions.0.value.bias", "decoder.mid_block.attentions.0.proj_attn.weight", "decoder.mid_block.attentions.0.proj_attn.bias".
@VoHoangAnh

Have you fixed it yet?
I have the same problem.

@DmitriyPin
Author

No, I do not have a fix for this.

@hughkhu

hughkhu commented Jul 24, 2023

I have the same problem.

@VoHoangAnh

I replaced the key names in the state dict and it worked.

@LaiaTarres

Same! I have replaced it like this and now it doesn't give me this error anymore:

        name = name.replace('query.', 'to_q.')
        name = name.replace('key.', 'to_k.')
        name = name.replace('value.', 'to_v.')
        name = name.replace('proj_attn.', 'to_out.')
        name = name.replace('.mid_block.attentions.0.to_out.', '.mid_block.attentions.0.to_out.0.')
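For readers asking where these lines go: they rename each key while the new state dict is being built, before `load_state_dict` is called. Below is a minimal, self-contained sketch of that idea using a plain dict in place of the real checkpoint. The function name `rename_attention_keys` and the sample keys are illustrative assumptions (the keys are copied from the traceback above), not code from the repository.

```python
def rename_attention_keys(state_dict):
    """Return a copy of state_dict with the old attention key names renamed.

    Hypothetical helper: it applies the replacements from the comment above
    to every key and leaves the values untouched.
    """
    renamed = {}
    for name, value in state_dict.items():
        name = name.replace('query.', 'to_q.')
        name = name.replace('key.', 'to_k.')
        name = name.replace('value.', 'to_v.')
        name = name.replace('proj_attn.', 'to_out.')
        # The new layout stores the output projection under index 0.
        name = name.replace('.mid_block.attentions.0.to_out.',
                            '.mid_block.attentions.0.to_out.0.')
        renamed[name] = value
    return renamed


# Stand-in for the loaded checkpoint; real values would be tensors.
old = {
    'encoder.mid_block.attentions.0.query.weight': 1,
    'decoder.mid_block.attentions.0.proj_attn.bias': 2,
    'encoder.conv_in.weight': 3,
}
new = rename_attention_keys(old)
print(sorted(new))
```

Apply the renaming only once: running it again on already-renamed keys would turn `to_out.0.` into `to_out.0.0.`.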

@TrainColab

@LaiaTarres Hi Laia, could you please help me find where these key names should be replaced (file name and line)? Thank you so much.

@xavier111222

@LaiaTarres Can you send the code that needs to be changed or replaced? I would love to be able to test it because it really amazes me. Thank you.

@SaharaSheik

I fixed this by modifying the lines that build the new state dict. This is a common issue: all you need to do is change the state dict key names so that they match the keys the model expects:

    for k, v in vae_state_dict.items():
        # Strip the 'module.' prefix, then map the old attention key names to the new ones.
        name = k.replace('module.', '')
        name = name.replace('query', 'to_q')
        name = name.replace('key', 'to_k')
        name = name.replace('value', 'to_v')
        name = name.replace('proj_attn', 'to_out.0')
        new_state_dict[name] = v
    pipe.vae.load_state_dict(new_state_dict)
    pipe.vae = pipe.vae.cuda()
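
Before calling `load_state_dict`, it can help to check that the renamed keys really match what the model expects. The sketch below reproduces the "Missing key(s)" and "Unexpected key(s)" lists from the traceback with plain set differences. The helper name `diff_keys` and the sample keys are hypothetical, chosen only for illustration.

```python
def diff_keys(expected_keys, provided_keys):
    """Compare two collections of state dict keys.

    Returns (missing, unexpected): keys the model expects but the
    checkpoint lacks, and keys the checkpoint has but the model does not.
    """
    expected = set(expected_keys)
    provided = set(provided_keys)
    missing = sorted(expected - provided)
    unexpected = sorted(provided - expected)
    return missing, unexpected


# Illustrative keys only; a real check would use the model's own keys.
expected = ['encoder.mid_block.attentions.0.to_q.weight',
            'encoder.mid_block.attentions.0.to_k.weight']
provided = ['encoder.mid_block.attentions.0.to_q.weight',
            'encoder.mid_block.attentions.0.key.weight']
missing, unexpected = diff_keys(expected, provided)
print('missing:', missing)
print('unexpected:', unexpected)
```

In the demo script this would presumably be called as `diff_keys(pipe.vae.state_dict().keys(), new_state_dict.keys())`; both lists should be empty once every old name has been mapped.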
