assert num_buckets == self.num_buckets error #220

Open
gudrb opened this issue Jan 28, 2024 · 6 comments

gudrb commented Jan 28, 2024

I am trying to fine-tune mini_deit_tiny_patch16_224 on a different subtask whose input has a sequence length of 18 (number of patches) with dimension 192.
When the following code runs:

    for blk in self.blocks:
        x = blk(x)

I get an error from line 574 of irpe.py: "assert num_buckets == self.num_buckets".
num_buckets is 50 but self.num_buckets is 49.
Do you know why this happens and how I can fix it?

Contributor

wkcn commented Jan 29, 2024

Hi @gudrb , thanks for your attention to our work!

Does the class token exist in the fine-tuned model?

If the class token exists (use_cls_token=True), please set skip=1 in https://github.com/microsoft/Cream/blob/main/MiniViT/Mini-DeiT/mini_deit_models.py#L16. This means the class token is skipped when computing the relative positional encoding.

If it does not exist (use_cls_token=False), skip should be set to 0.

Author

gudrb commented Jan 30, 2024

Thank you for answering,

I am not using a class token, but I still tried the skip=1 option, and it gives a key error when I load the pretrained model:

    self.v = create_model(
        'mini_deit_tiny_patch16_224',
        pretrained=False,
        num_classes=1000,
        drop_rate=0,
        drop_path_rate=0.1,
        drop_block_rate=None)
    checkpoint = torch.load('./checkpoints/mini_deit_tiny_patch16_224.pth', map_location='cpu')
    self.v.load_state_dict(checkpoint['model'])

I tried feeding a random tensor through the blocks to observe their behavior:

    x = torch.randn((2, 196, 192), device=x.device)
    
    for blk in self.v.blocks:
        x = blk(x)
    x = self.v.norm(x)

and I found that when I change the second dimension of x from 196 to another value N (not 196), I get the error below (num_buckets is 50, and self.num_buckets is 49):

    File "/data/hyounggyu/Mini-DeiT/irpe.py", line 573, in _get_rp_bucket
        assert num_buckets == self.num_buckets

Is it possible to use a pretrained model that utilizes IRPE for a different sequence length, such as a varying number of patches?

Contributor

wkcn commented Jan 30, 2024

@gudrb

> Is it possible to use a pretrained model that utilizes IRPE for a different sequence length, such as a varying number of patches?

Yes. You need to pass the two arguments width and height for rpe_k, rpe_q and rpe_v.

Example: https://github.com/microsoft/Cream/blob/main/iRPE/DETR-with-iRPE/models/rpe_attention/rpe_attention_function.py#L330

iRPE is a 2D relative position encoding. If height and width are not specified, both are set to the square root of the sequence length, which leads to the wrong number of buckets here.

https://github.com/microsoft/Cream/blob/main/MiniViT/Mini-DeiT/irpe.py#L553
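To make the square-root fallback concrete, here is a minimal sketch of the arithmetic only. The helper name `infer_square_grid` is hypothetical, and the assumption that the square root is truncated to an integer is mine; this is not the code in irpe.py. It shows that a sequence of 196 patches maps cleanly onto a square grid, while a sequence of 18 patches does not.

```python
import math


def infer_square_grid(seq_len):
    """Assumed fallback: height = width = truncated sqrt(seq_len)."""
    side = int(math.sqrt(seq_len))
    # Tokens that the inferred square grid does not account for.
    leftover = seq_len - side * side
    return side, side, leftover


for seq_len in (196, 18):
    h, w, leftover = infer_square_grid(seq_len)
    print(f"L={seq_len}: inferred grid {h}x{w}, leftover tokens={leftover}")

# Passing the real grid explicitly leaves no token unaccounted for.
height, width = 9, 2
assert height * width == 18
```

For a 9 x 2 grid the inferred square covers fewer tokens than the sequence actually contains, so the grid shape has to be passed explicitly, as suggested above.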

Author

gudrb commented Jan 30, 2024

Now, it is working. I modified the code for the MiniAttention class from (

) to manually process my patch sequence (9 x 2) from a spectrogram image.

    # image relative position on keys
    if self.rpe_k is not None:
        attn += self.rpe_k(q,9,2)

I hope this is the correct way to utilize the MiniAttention class when fine-tuning the task with a different sequence length.

Thank you.
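Rather than hard-coding 9 and 2, the grid shape could be derived from the input size. The sketch below is only an illustration: `patch_grid` and `check_grid` are hypothetical helpers, and the 144 x 32 input with non-overlapping 16 x 16 patches is an assumed example, not something stated in this thread. Which of the two values counts as height and which as width for a spectrogram is also not stated here, so the helper simply returns rows and columns.

```python
def patch_grid(img_rows, img_cols, patch_size=16):
    """Rows and columns of patches for non-overlapping square patches (assumed)."""
    if img_rows % patch_size or img_cols % patch_size:
        raise ValueError("input size must be a multiple of the patch size")
    return img_rows // patch_size, img_cols // patch_size


def check_grid(seq_len, rows, cols):
    """Fail early if the grid does not match the sequence length."""
    if rows * cols != seq_len:
        raise ValueError(f"grid {rows}x{cols} does not match sequence length {seq_len}")
    return rows, cols


# Assumed example: a 144 x 32 input gives a 9 x 2 patch grid (18 patches).
rows, cols = check_grid(18, *patch_grid(144, 32))
print(rows, cols)
```

The resulting pair would take the place of the literals 9 and 2 in the call above, and the check catches a mismatched grid before it reaches the bucket assertion.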

Author

gudrb commented Feb 7, 2024

Do I need to crop or interpolate pretrained relative positional encoding parameters when the sequence length is changed?

When I use the pretrained Mini-DeiT with positional encodings (both absolute and relative), in the case of absolute positional encoding, if the modified sequence length is shorter or longer than 14, I employ cropping and interpolation, respectively.

            # get the positional embedding from deit model,  reshape it to original 2D shape.
            new_pos_embed = self.v.pos_embed.detach().reshape(1, self.original_num_patches, self.original_embedding_dim).transpose(1, 2).reshape(1, self.original_embedding_dim, self.oringal_hw, self.oringal_hw)
            # cut (from middle) or interpolate the second dimension of the positional embedding
            if t_dim <= self.oringal_hw:
                new_pos_embed = new_pos_embed[:, :, :, int(self.oringal_hw / 2) - int(t_dim / 2): int(self.oringal_hw / 2) - int(t_dim / 2) + t_dim]
            else:
                new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(self.oringal_hw, t_dim), mode='bicubic')
            # cut (from middle) or interpolate the first dimension of the positional embedding
            if f_dim <= self.oringal_hw:
                new_pos_embed = new_pos_embed[:, :, int(self.oringal_hw / 2) - int(f_dim / 2): int(self.oringal_hw / 2) - int(f_dim / 2) + f_dim, :]
            else:
                new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(f_dim, t_dim), mode='bicubic')
            # flatten the positional embedding
            new_pos_embed = new_pos_embed.reshape(1, self.original_embedding_dim, num_patches).transpose(1,2)
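The centre-crop index arithmetic in the snippet above can be checked in isolation. This is a small sketch that reproduces only the slice bounds; `center_crop_bounds` is a hypothetical helper name, and the sizes used are the 14 x 14 pretrained grid and the 9 x 2 target grid mentioned in this thread.

```python
def center_crop_bounds(original, target):
    """Start and stop indices of a centred slice of length `target`."""
    # Same expression as in the snippet above:
    # int(original / 2) - int(target / 2)
    start = int(original / 2) - int(target / 2)
    return start, start + target


# Cropping the 14-wide pretrained grid down to 9 and to 2 positions.
for target in (9, 2):
    start, stop = center_crop_bounds(14, target)
    print(f"target={target}: slice [{start}:{stop}]")
```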

Contributor

wkcn commented Feb 7, 2024

@gudrb No, you don't. Relative position encoding can adapt to a longer sequence.
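One way to see why no cropping or interpolation is needed: a relative encoding indexes its learned table by a bucket of the offset between two positions, not by the absolute position, so the table size does not grow with the sequence. The sketch below uses plain clipping as a stand-in bucketing function; this is an assumption for illustration and not iRPE's actual index function, and the radius of 3 is arbitrary.

```python
def clipped_bucket(offset, radius):
    """Map a relative offset to a bucket id in [0, 2 * radius] by clipping."""
    return max(-radius, min(radius, offset)) + radius


def bucket_ids_1d(length, radius):
    """All bucket ids used by a 1D sequence of the given length."""
    return {clipped_bucket(i - j, radius)
            for i in range(length) for j in range(length)}


radius = 3  # illustrative value, not taken from the thread
for length in (2, 9, 14, 30):
    ids = bucket_ids_1d(length, radius)
    # The ids never leave the fixed table of 2 * radius + 1 entries.
    assert ids <= set(range(2 * radius + 1))
    print(f"length={length}: {len(ids)} of {2 * radius + 1} buckets used")
```

Shorter sequences simply use a subset of the same buckets and longer ones reuse the clipped ends, so the pretrained table can be loaded unchanged.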
