
Commit

allow for one to pass in images of any dimension, provided height and width is divisible by block size
lucidrains committed Mar 24, 2021
1 parent 8c38d66 commit e153c4f
Showing 3 changed files with 10 additions and 18 deletions.
1 change: 0 additions & 1 deletion README.md
@@ -19,7 +19,6 @@ from halonet_pytorch import HaloAttention
 
 attn = HaloAttention(
     dim = 512,          # dimension of feature map
-    fmap_size = 32,     # feature map height and width
     block_size = 8,     # neighborhood block size (feature map must be divisible by this)
     halo_size = 4,      # halo size (block receptive field)
     dim_head = 64,      # dimension of each head
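
For context (not part of the commit), a minimal usage sketch of the changed interface: `fmap_size` is no longer passed to the constructor, and the input feature map only needs a height and width divisible by `block_size`. The `heads` argument and the output shape below are assumptions based on the surrounding README, not lines from this diff.

```python
import torch
from halonet_pytorch import HaloAttention

attn = HaloAttention(
    dim = 512,          # dimension of feature map
    block_size = 8,     # neighborhood block size (feature map must be divisible by this)
    halo_size = 4,      # halo size (block receptive field)
    dim_head = 64,      # dimension of each head
    heads = 8           # number of attention heads (assumed, per the README default)
)

# any spatial size works now, as long as height and width are divisible by block_size
fmap = torch.randn(1, 512, 32, 64)
out = attn(fmap)        # expected to come back as (1, 512, 32, 64)
```
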
25 changes: 9 additions & 16 deletions halonet_pytorch/halonet_pytorch.py
@@ -47,15 +47,13 @@ class RelPosEmb(nn.Module):
     def __init__(
         self,
         block_size,
-        fmap_size,
+        rel_size,
         dim_head
     ):
         super().__init__()
-        fmap_size = pair(fmap_size)
-        height, width = fmap_size
+        height = width = rel_size
         scale = dim_head ** -0.5
 
-        self.fmap_size = fmap_size
         self.block_size = block_size
         self.rel_height = nn.Parameter(torch.randn(height * 2 - 1, dim_head) * scale)
         self.rel_width = nn.Parameter(torch.randn(width * 2 - 1, dim_head) * scale)
@@ -79,14 +77,12 @@ def __init__(
         self,
         *,
         dim,
-        fmap_size,
         block_size,
         halo_size,
         dim_head = 64,
         heads = 8
     ):
         super().__init__()
-        assert fmap_size % block_size == 0, 'feature map height or width must be divisible by block size'
         assert halo_size > 0, 'halo size must be greater than 0'
 
         self.dim = dim
@@ -100,24 +96,17 @@ def __init__(
 
         self.rel_pos_emb = RelPosEmb(
             block_size = block_size,
-            fmap_size = block_size + (halo_size * 2),
+            rel_size = block_size + (halo_size * 2),
             dim_head = dim_head
         )
 
         self.to_q = nn.Linear(dim, inner_dim, bias = False)
         self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
         self.to_out = nn.Linear(inner_dim, dim)
 
-        # prepare a mask for removing attention to padding, cached for performance
-
-        mask = torch.ones(1, 1, fmap_size, fmap_size)
-        mask = F.unfold(mask, kernel_size = block_size + (halo_size * 2), stride = block_size, padding = halo_size)
-        mask = repeat(mask, 'b j i -> (b i h) j', h = heads)
-        self.register_buffer('mask', mask == 0)
-
     def forward(self, x):
         b, c, h, w, block, halo, heads, device = *x.shape, self.block_size, self.halo_size, self.heads, x.device
-        assert h == w, 'dimensions of fmap must be same on both sides, for now'
+        assert h % block == 0 and w % block == 0, 'fmap dimensions must be divisible by the block size'
         assert c == self.dim, f'channels for input ({c}) does not equal to the correct dimension ({self.dim})'
 
         # get block neighborhoods, and prepare a halo-ed version (blocks with padding) for deriving key values
@@ -150,7 +139,11 @@ def forward(self, x):
 
         # mask out padding (in the paper, they claim to not need masks, but what about padding?)
 
-        mask = repeat(self.mask, 'h j -> (b h) () j', b = b)
+        mask = torch.ones(1, 1, h, w, device = device)
+        mask = F.unfold(mask, kernel_size = block + (halo * 2), stride = block, padding = halo)
+        mask = repeat(mask, '() j i -> (b i h) () j', b = b, h = heads)
+        mask = mask.bool()
+
         max_neg_value = -torch.finfo(sim.dtype).max
         sim.masked_fill_(mask, max_neg_value)
 
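
Since the feature-map size is no longer known at construction time, the padding mask can no longer be cached in `__init__`; it is rebuilt from the actual `h`, `w` on every forward pass, and `RelPosEmb` is now sized by `rel_size = block_size + 2 * halo_size` (the haloed key/value window) rather than by the feature map. Below is a standalone sketch of that per-call mask computation, using hypothetical example values in place of the locals of `HaloAttention.forward`:

```python
import torch
import torch.nn.functional as F
from einops import repeat

# hypothetical example values standing in for the locals of HaloAttention.forward
b, heads = 2, 8
h, w = 32, 64                 # any size divisible by the block size
block, halo = 8, 4

# ones over the real feature map; F.unfold zero-pads, so zeros mark halo
# positions that fall outside the image
mask = torch.ones(1, 1, h, w)
mask = F.unfold(mask, kernel_size = block + (halo * 2), stride = block, padding = halo)
mask = repeat(mask, '() j i -> (b i h) () j', b = b, h = heads)
mask = mask.bool()            # False exactly where a key/value slot is padding

num_blocks = (h // block) * (w // block)
print(mask.shape)             # (b * num_blocks * heads, 1, (block + 2 * halo) ** 2)
```

Rebuilding the mask per call trades the cached-buffer optimization noted in the removed comment for the flexibility of variable input sizes.
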
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'halonet-pytorch',
   packages = find_packages(),
-  version = '0.0.3',
+  version = '0.0.4',
   license='MIT',
   description = 'HaloNet - Pytorch',
   author = 'Phil Wang',
