In convit.py file, where does ConVit come from, really? #9

Open
dinhanhx opened this issue Mar 31, 2022 · 15 comments


@dinhanhx

"""
This ConViT is ViT with two-dimensional convolutional MSA, NOT [1]!
[1] d'Ascoli, Stéphane, et al. "Convit: Improving vision transformers with soft convolutional inductive biases."
arXiv preprint arXiv:2103.10697 (2021).
"""

You said it's not the same as the ConViT by d'Ascoli et al. Then where does this ConViT come from? I ask because if I reuse this code, I want to know whom I should cite.

@xxxnell
Owner

xxxnell commented Mar 31, 2022

Hi,

I was inspired by "Convolutional Self-Attention Networks" [2], and implemented the two-dimensional ConViT model for vision tasks from scratch. Yang et al. [2] mainly proposed one-dimensional convolutional transformers for natural language tasks. As far as I know, no official implementation of [2] is provided.

I will add the reference [2] to convit.py. Thank you for your feedback!

[2] Baosong Yang, Longyue Wang, Derek F Wong, Lidia S Chao, and Zhaopeng Tu. "Convolutional self-attention networks". NAACL, 2019.

@dinhanhx
Author

@xxxnell Uhm, so what are the differences between these two attention mechanisms?

import torch.nn as nn
from torch import einsum
from einops import rearrange


class Attention2d(nn.Module):
    """Global multi-head self-attention over a 2D feature map."""

    def __init__(self, dim_in, dim_out=None, *,
                 heads=8, dim_head=64, dropout=0.0, k=1):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5

        inner_dim = dim_head * heads
        dim_out = dim_in if dim_out is None else dim_out

        self.to_q = nn.Conv2d(dim_in, inner_dim * 1, 1, bias=False)
        self.to_kv = nn.Conv2d(dim_in, inner_dim * 2, k, stride=k, bias=False)

        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, dim_out, 1),
            nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()
        )

    def forward(self, x, mask=None):
        b, n, _, y = x.shape  # (batch, channels, height, width)
        qkv = (self.to_q(x), *self.to_kv(x).chunk(2, dim=1))
        q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> b h (x y) d', h=self.heads), qkv)

        # Every query token attends to every key token (global attention).
        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        dots = dots + mask if mask is not None else dots
        attn = dots.softmax(dim=-1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', y=y)
        out = self.to_out(out)

        return out, attn

and

import torch.nn as nn
import torch.nn.functional as F
from torch import einsum
from einops import rearrange, repeat


class ConvAttention2d(nn.Module):
    """Convolutional (local) multi-head self-attention over a 2D feature map."""

    def __init__(self, dim_in, dim_out=None, *,
                 heads=8, dim_head=64, dropout=0.0, k=1,
                 kernel_size=1, dilation=1, padding=0, stride=1):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.conv_args = {
            "kernel_size": kernel_size,
            "dilation": dilation,
            "padding": padding,
            "stride": stride
        }

        inner_dim = dim_head * heads
        dim_out = dim_in if dim_out is None else dim_out

        self.to_q = nn.Conv2d(dim_in, inner_dim * 1, 1, bias=False)
        self.to_kv = nn.Conv2d(dim_in, inner_dim * 2, k, stride=k, bias=False)

        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, dim_out, 1),
            nn.Dropout(dropout)
        )

    def forward(self, x, mask=None):
        b, n, _, y = x.shape  # (batch, channels, height, width)
        q, kv = self.to_q(x), self.to_kv(x).chunk(2, dim=1)
        q = rearrange(q, 'b (h d) x y -> b h (x y) d', h=self.heads)
        q = repeat(q, 'b h n d -> b h n w d', w=self.conv_args["kernel_size"] ** 2)

        # Unfold k and v into local windows, so each query attends only to the
        # tokens inside its convolutional receptive field.
        k, v = map(lambda t: F.unfold(t, **self.conv_args), kv)
        k, v = map(lambda t: rearrange(t, 'b (h d w) n -> b h n w d',
                                       h=self.heads, d=q.shape[-1]), (k, v))

        dots = einsum('b h n w d, b h n w d -> b h n w', q, k) * self.scale
        dots = dots + mask if mask is not None else dots
        attn = dots.softmax(dim=-1)  # softmax over the local window

        out = einsum('b h n w, b h n w d -> b h n d', attn, v)
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', y=y)
        out = self.to_out(out)

        return out, attn

    def extra_repr(self):
        return ", ".join(["%s=%s" % (k, v) for k, v in self.conv_args.items()])

@xxxnell
Owner

xxxnell commented Mar 31, 2022

Attention2d in models/attentions.py is traditional global self-attention. ConvAttention2d in models/convit.py is convolutional self-attention, and it is a kind of local self-attention. ConvAttention2d calculates self-attention only between tokens in convolutional receptive fields (e.g., 3x3) after unfolding the tokens like Conv2d.
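For example, here is a minimal sketch (hypothetical sizes, not test code from this repository) that runs both modules on the same feature map, assuming they are imported from models/attentions.py and models/convit.py as above. The global attention map spans all 8x8 = 64 tokens for every query, while the convolutional one only spans the 3x3 = 9 tokens in each receptive field:

import torch

from models.attentions import Attention2d
from models.convit import ConvAttention2d

x = torch.randn(1, 32, 8, 8)  # (batch, channels, height, width)

global_attn = Attention2d(32, heads=4, dim_head=8)
local_attn = ConvAttention2d(32, heads=4, dim_head=8, kernel_size=3, padding=1)

out_g, attn_g = global_attn(x)
out_l, attn_l = local_attn(x)

print(out_g.shape, attn_g.shape)  # (1, 32, 8, 8) and (1, 4, 64, 64): each token attends to all 64 tokens
print(out_l.shape, attn_l.shape)  # (1, 32, 8, 8) and (1, 4, 64, 9): each token attends to its 3x3 window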

@dinhanhx
Author

I think I understand now. Just one more question: if I use Attention2D in models/attentions.py, I should cite your paper, right?

@xxxnell
Owner

xxxnell commented Apr 1, 2022

Yes. I'd really appreciate it if you would cite my paper.

@dinhanhx
Author

@xxxnell Quick question: which part of your publication mentions Attention2D in models/attentions.py? From what I read, you only mention the MSA from the vanilla Transformer.

@xxxnell
Owner

xxxnell commented Apr 11, 2022

@dinhanhx Oh! Sorry for the confusion. Attention2d in models/attentions.py is almost identical to traditional MSA in vanilla ViT, so I think you should cite the original ViT paper. Please cite my paper only if you have used or modified my code and implementation directly.

@dinhanhx
Author

@xxxnell well I found that your Attention2D in models/attentions.py is kinda similar to this one https://github.com/lucidrains/vit-pytorch/blob/c2aab05ebfb01b9740ba4dae5b515fce1140e97d/vit_pytorch/cvt.py#L70-L102 from CvT: Introducing Convolutions to Vision Transformers. From my understanding, the major difference is the number of CNN layers to project qkv.
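Roughly, the two projection styles I mean look like this (a rough sketch from memory, not the exact code of either repository; I'm assuming the CvT-style projection is a depthwise conv followed by BatchNorm and a pointwise conv):

import torch.nn as nn

dim_in, inner_dim = 32, 64

# Attention2d above: a single 1x1 Conv2d per projection.
to_q_plain = nn.Conv2d(dim_in, inner_dim, 1, bias=False)

# CvT-style convolutional projection (assumed): depthwise conv + BatchNorm + pointwise conv,
# i.e. several CNN layers per projection instead of one.
to_q_cvt = nn.Sequential(
    nn.Conv2d(dim_in, dim_in, kernel_size=3, padding=1, groups=dim_in, bias=False),
    nn.BatchNorm2d(dim_in),
    nn.Conv2d(dim_in, inner_dim, kernel_size=1, bias=False),
)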

@xxxnell
Owner

xxxnell commented Apr 12, 2022

@dinhanhx Ah, I think now I understand what you pointed out! I initially used two Convs for qkv to improve the performance of AlterNet. So there was an experiment and discussion on stride k in self.to_kv in the first draft, but they were removed in the final revision for better readability. As a result, in the context of my paper, I didn't take advantage of the two Convs and the stride attribute for the sake of simplicity, and it also looks good to me to use one Conv instead of two Convs. In addition, since a lot of my implementations are based on https://github.com/lucidrains/vit-pytorch, I think it's also great to cite the original project to use the code.
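Concretely, the two options are roughly these (a minimal sketch, not the exact code in the repository):

import torch.nn as nn

dim_in, inner_dim, k = 32, 64, 2

# Two Convs: a 1x1 Conv for q plus a strided k x k Conv for k and v,
# so the key/value feature map can be downsampled by the stride k.
to_q = nn.Conv2d(dim_in, inner_dim, 1, bias=False)
to_kv = nn.Conv2d(dim_in, inner_dim * 2, k, stride=k, bias=False)

# One Conv: a single 1x1 Conv producing q, k, and v together,
# as in a vanilla ViT-style projection (no separate stride for k and v).
to_qkv = nn.Conv2d(dim_in, inner_dim * 3, 1, bias=False)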

@dinhanhx
Author

dinhanhx commented Apr 12, 2022

@xxxnell It was confusing to me since there are a few convolutional attention mechanisms similar to yours. I had a hard time trying to differentiate them.

If I use AlterNet (theory), I cite your paper.
If I use AlterNet (code), I cite your paper and the original project https://github.com/lucidrains/vit-pytorch.
If I only use Attention2D in models/attentions.py, I cite your paper, the CvT paper, and the original project.

Right?

@xxxnell
Owner

xxxnell commented Apr 12, 2022

@dinhanhx Right. I think what you said is one of the best practices.

@dinhanhx
Author

@dinhanhx Right. I think what you said is one of the best practices.

Thanks for supporting me!

@longyuewangdcu

Thanks for your comments on our "Convolutional SANs" (https://arxiv.org/abs/1904.03107). We are also very happy to see that it inspired your work. Your paper analyzing Vision Transformers is really insightful and interesting.

@longyuewangdcu

Hi,

I was inspired by "Convolutional Self-Attention Networks" [2], and implemented the two-dimensional ConViT model for vision tasks from scratch. Yang et al. [2] mainly proposed one-dimensional convolutional transformers for natural language tasks. As far as I know, no official implementation of [2] is provided.

I will add the reference [2] to convit.py. Thank you for your feedback!

[2] Baosong Yang, Longyue Wang, Derek F Wong, Lidia S Chao, and Zhaopeng Tu. "Convolutional self-attention networks". NAACL, 2019.

We have implemented various SANs including "Convolutional SANs" at:
https://github.com/baosongyang/Context-Aware-SAN/blob/main/layers/attention_conv.py.

@xxxnell
Owner

xxxnell commented Apr 24, 2022

Hi @longyuewangdcu ,

Thank you for the great paper and your kind words. And sorry I missed that implementation. I starred the repository, and I'll take a closer look!
