Add low-memory attention (still needs to be incorporated)
gahdritz committed Dec 17, 2021
1 parent c4d9f57 commit 9670958
Showing 5 changed files with 253 additions and 10 deletions.
178 changes: 176 additions & 2 deletions openfold/model/primitives.py
@@ -14,7 +14,7 @@
# limitations under the License.

import math
- from typing import Optional, Callable, List
+ from typing import Optional, Callable, List, Tuple, Sequence
import numpy as np

import torch
@@ -24,6 +24,7 @@
from openfold.utils.tensor_utils import (
permute_final_dims,
flatten_final_dims,
_chunk_slice,
)


@@ -217,7 +218,7 @@ def __init__(
self.c_hidden * self.no_heads, self.c_q, init="final"
)

- if self.gating is not None:
+ if self.gating:
self.linear_g = Linear(
self.c_q, self.c_hidden * self.no_heads, init="gating"
)
@@ -370,3 +371,176 @@ def forward(self, m: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
m = self.linear_o(o)

return m


@torch.jit.script
def _lma(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
biases: List[torch.Tensor],
q_chunk_size: int,
kv_chunk_size: int
):
no_q, no_kv = q.shape[-3], k.shape[-3]

# [*, Q, H, C_hidden]
o = q.new_zeros(q.shape)
for q_s in range(0, no_q, q_chunk_size):
q_chunk = q[..., q_s: q_s + q_chunk_size, :, :]
big_bias_chunks = [
b[..., q_s: q_s + q_chunk_size, :] for b in biases
]

maxes = []
weights = []
values = []
for kv_s in range(0, no_kv, kv_chunk_size):
k_chunk = k[..., kv_s: kv_s + kv_chunk_size, :, :]
v_chunk = v[..., kv_s: kv_s + kv_chunk_size, :, :]
small_bias_chunks = [
b[..., kv_s: kv_s + kv_chunk_size] for b in big_bias_chunks
]

a = torch.einsum(
"...qhd,...khd->...hqk", q_chunk, k_chunk
)

for b in small_bias_chunks:
a += b

a = a.transpose(-2, -3)

max_a = torch.max(a, dim=-1, keepdim=True)[0].detach()
exp_a = torch.exp(a - max_a)
exp_v = torch.einsum("...vhf,...qhv->...qhf", v_chunk, exp_a)

maxes.append(max_a.squeeze(-1))
weights.append(torch.sum(exp_a, dim=-1))
values.append(exp_v)

chunk_max = torch.stack(maxes, dim=-3)
chunk_weights = torch.stack(weights, dim=-3)
chunk_values = torch.stack(values, dim=-4)

global_max = torch.max(chunk_max, dim=-3, keepdim=True)[0]
max_diffs = torch.exp(chunk_max - global_max)
chunk_values *= max_diffs.unsqueeze(-1)
chunk_weights *= max_diffs

all_values = torch.sum(chunk_values, dim=-4)
all_weights = torch.sum(chunk_weights.unsqueeze(-1), dim=-4)

q_chunk_out = all_values / all_weights

o[..., q_s: q_s + q_chunk_size, :, :] = q_chunk_out

return o
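The recombination at the end of _lma can be sanity-checked in isolation. The sketch below (not part of this commit) applies the same trick to a single query and two key/value chunks: each chunk contributes a local max, an exponential weight sum, and a weighted value sum, and rescaling by the global max reproduces an ordinary softmax over all keys.

import torch

torch.manual_seed(0)
d = 8
q = torch.randn(d)                        # a single query
k = torch.randn(16, d)                    # 16 keys, processed in two chunks of 8
v = torch.randn(16, d)

def chunk_stats(q, k_chunk, v_chunk):
    a = k_chunk @ q                       # chunk logits
    m = a.max()                           # chunk-local max, for numerical stability
    e = torch.exp(a - m)                  # unnormalized, chunk-stable weights
    return m, e.sum(), e @ v_chunk        # max, weight sum, weighted value sum

stats = [chunk_stats(q, k[i:i + 8], v[i:i + 8]) for i in (0, 8)]
g = max(m for m, _, _ in stats)           # global max across chunks
num = sum(torch.exp(m - g) * val for m, _, val in stats)
den = sum(torch.exp(m - g) * w for m, w, _ in stats)

full = torch.softmax(k @ q, dim=-1) @ v   # reference: ordinary attention
assert torch.allclose(num / den, full, atol=1e-6)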


class LowMemoryAttention(nn.Module):
"""
Standard multi-head attention using AlphaFold's default layer
initialization. Allows multiple bias vectors. Implements Rabe and Staats'
low-memory self-attention algorithm.
"""
def __init__(
self,
c_q: int,
c_k: int,
c_v: int,
c_hidden: int,
no_heads: int,
gating: bool = True,
):
"""
Args:
c_q:
Input dimension of query data
c_k:
Input dimension of key data
c_v:
Input dimension of value data
c_hidden:
Per-head hidden dimension
no_heads:
Number of attention heads
gating:
Whether the output should be gated using query data
Note:
q_chunk_size and kv_chunk_size are supplied to forward() rather than
here; they trade memory for parallelization, and lower values
correspond to lower memory usage.
"""
super().__init__()

self.c_q = c_q
self.c_k = c_k
self.c_v = c_v
self.c_hidden = c_hidden
self.no_heads = no_heads
self.gating = gating

self.linear_q = Linear(
self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot"
)
self.linear_k = Linear(
self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot"
)
self.linear_v = Linear(
self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot"
)
self.linear_o = Linear(
self.c_hidden * self.no_heads, self.c_q, init="final"
)

if self.gating:
self.linear_g = Linear(
self.c_q, self.c_hidden * self.no_heads, init="gating"
)

self.sigmoid = nn.Sigmoid()
self.softmax = nn.Softmax(dim=-1)

def forward(self,
q_x: torch.Tensor,
k_x: torch.Tensor,
v_x: torch.Tensor,
q_chunk_size: int,
kv_chunk_size: int,
biases: Optional[List[torch.Tensor]] = None,
):
if(biases is None):
biases = []
else:
biases = [
b.expand(b.shape[:-2] + (q_x.shape[-2],) + (k_x.shape[-2],))
for b in biases
]

# [*, Q/K/V, H * C_hidden]
q = self.linear_q(q_x)
k = self.linear_k(k_x)
v = self.linear_v(v_x)

# [*, Q/K, H, C_hidden]
q = q.view(q.shape[:-1] + (self.no_heads, -1))
k = k.view(k.shape[:-1] + (self.no_heads, -1))
v = v.view(v.shape[:-1] + (self.no_heads, -1))

q = q / math.sqrt(q.shape[-1])

o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size)

if self.gating:
g = self.sigmoid(self.linear_g(q_x))
# [*, Q, H, C_hidden]
g = g.view(g.shape[:-1] + (self.no_heads, -1))
o = o * g

# [*, Q, H * C_hidden]
o = flatten_final_dims(o, 2)

# [*, Q, C_q]
o = self.linear_o(o)

return o
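LowMemoryAttention takes the same constructor arguments as Attention and is checked against it in the new test file below; the chunk sizes are supplied at call time instead. A minimal usage sketch (not part of this commit; shapes mirror the test, values are illustrative):

import torch
from openfold.model.primitives import LowMemoryAttention

c_hidden, no_heads, n = 32, 4, 2 ** 12
lma = LowMemoryAttention(c_hidden, c_hidden, c_hidden, c_hidden, no_heads)

q_x = torch.rand(1, n, c_hidden)
k_x = torch.rand(1, n, c_hidden)
v_x = torch.rand(1, n, c_hidden)
bias = [torch.rand(no_heads, 1, n)]       # broadcast over queries; expanded in forward()

# Queries are scored 1024 at a time against key/value blocks of 4096,
# so the full [n, n] logit matrix is never materialized at once.
out = lma(q_x, k_x, v_x, 1024, 4096, biases=bias)  # [1, n, c_hidden]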
11 changes: 6 additions & 5 deletions openfold/utils/tensor_utils.py
@@ -124,6 +124,7 @@ def _fetch_dims(tree):
return shapes


@torch.jit.ignore
def _flat_idx_to_idx(
flat_idx: int,
dims: Tuple[int],
@@ -135,6 +136,8 @@ def _flat_idx_to_idx(

return tuple(reversed(idx))


@torch.jit.ignore
def _get_minimal_slice_set(
start: Sequence[int],
end: Sequence[int],
@@ -252,18 +255,19 @@ def lower():
return [tuple(s) for s in slices]


@torch.jit.ignore
def _chunk_slice(
t: torch.Tensor,
flat_start: int,
flat_end: int,
no_batch_dims: int,
- ):
+ ) -> torch.Tensor:
"""
Equivalent to
t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end]
- but without the need for the reshape call, which can be
+ but without the need for the initial reshape call, which can be
memory-intensive in certain situations. The only reshape operations
in this function are performed on sub-tensors that scale with
(flat_end - flat_start), the chunk size.
@@ -281,7 +285,6 @@ def _chunk_slice(
batch_dims,
)

#
sliced_tensors = [t[s] for s in slices]

return torch.cat(
@@ -352,7 +355,6 @@ def _prep_inputs(t):

i = 0
out = None

for _ in range(no_chunks):
# Chunk the input
if(not low_mem):
Expand Down Expand Up @@ -382,7 +384,6 @@ def _prep_inputs(t):
# Put the chunk in its pre-allocated space
out_type = type(output_chunk)
if out_type is dict:

def assign(d1, d2):
for k, v in d1.items():
if type(v) is dict:
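The _chunk_slice helper exercised above is documented as equivalent to slicing a flattened view of the leading batch dimensions. A small check of that contract (not part of this commit; it assumes the helper is imported from openfold.utils.tensor_utils, as primitives.py now does):

import torch
from openfold.utils.tensor_utils import _chunk_slice

t = torch.arange(2 * 3 * 4 * 5, dtype=torch.float).view(2, 3, 4, 5)
no_batch_dims = 2                         # flatten the leading (2, 3) batch dims
flat_start, flat_end = 1, 5               # a chunk of four "rows", crossing a batch boundary

reference = t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end]
chunked = _chunk_slice(t, flat_start, flat_end, no_batch_dims)
assert torch.equal(chunked, reference)    # same values, without reshaping the full tensor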
2 changes: 1 addition & 1 deletion scripts/run_unit_tests.sh
@@ -1,6 +1,6 @@
#!/bin/bash

- #CUDA_VISIBLE_DEVICES="5"
+ CUDA_VISIBLE_DEVICES="0"

python3 -m unittest "$@" || \
echo -e "\nTest(s) failed. Make sure you've installed all Python dependencies."
2 changes: 0 additions & 2 deletions tests/compare_utils.py
@@ -1,7 +1,5 @@
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "4,"

import importlib
import pkgutil
import sys
70 changes: 70 additions & 0 deletions tests/test_primitives.py
@@ -0,0 +1,70 @@
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import numpy as np
import unittest

from openfold.model.primitives import (
Attention,
LowMemoryAttention,
)
from tests.config import consts


class TestLMA(unittest.TestCase):
def test_lma_vs_attention(self):
batch_size = consts.batch_size
c_hidden = 32
n = 2**12
no_heads = 4

q = torch.rand(batch_size, n, c_hidden).cuda()
k = torch.rand(batch_size, n, c_hidden).cuda()
v = torch.rand(batch_size, n, c_hidden).cuda()

bias = [torch.rand(no_heads, 1, n)]
bias = [b.cuda() for b in bias]

gating_fill = torch.rand(c_hidden * no_heads, c_hidden)
o_fill = torch.rand(c_hidden, c_hidden * no_heads)

lma = LowMemoryAttention(
c_hidden, c_hidden, c_hidden, c_hidden, no_heads
).cuda()
a = Attention(
c_hidden, c_hidden, c_hidden, c_hidden, no_heads
).cuda()

with torch.no_grad():
for n, p in lma.named_parameters():
attrs = n.split('.')
param = a
for attr in attrs:
param = getattr(param, attr)
param.copy_(p)

for m in [lma, a]:
m.linear_g.weight.copy_(gating_fill)
m.linear_o.weight.copy_(o_fill)

with torch.no_grad():
l = lma(q, k, v, 1024, 4096, biases=bias)
real = a(q, k, v, biases=bias)

self.assertTrue(torch.max(torch.abs(l - real)) < consts.eps)


if __name__ == "__main__":
unittest.main()
