Skip to content

Commit

Permalink
Replace einops.rearrange with torch native
Browse files Browse the repository at this point in the history
  • Loading branch information
huchenlei authored and ruchej committed Sep 30, 2024
1 parent 70e352f commit f399e16
Showing 1 changed file with 16 additions and 2 deletions.
18 changes: 16 additions & 2 deletions modules/sd_hijack_optimizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,19 @@ def xformers_attention_forward(self, x, context=None, mask=None, **kwargs):
k_in = self.to_k(context_k)
v_in = self.to_v(context_v)

q, k, v = (rearrange(t, 'b n (h d) -> b n h d', h=h) for t in (q_in, k_in, v_in))
def _reshape(t):
"""rearrange(t, 'b n (h d) -> b n h d', h=h).
Using torch native operations to avoid overhead as this function is
called frequently. (70 times/it for SDXL)
"""
b, n, _ = t.shape # Get the batch size (b) and sequence length (n)
d = t.shape[2] // h # Determine the depth per head
return t.reshape(b, n, h, d)

q = _reshape(q_in)
k = _reshape(k_in)
v = _reshape(v_in)

del q_in, k_in, v_in

dtype = q.dtype
Expand All @@ -497,7 +509,9 @@ def xformers_attention_forward(self, x, context=None, mask=None, **kwargs):

out = out.to(dtype)

out = rearrange(out, 'b n h d -> b n (h d)', h=h)
# out = rearrange(out, 'b n h d -> b n (h d)', h=h)
b, n, h, d = out.shape
out = out.reshape(b, n, h * d)
return self.to_out(out)


Expand Down

0 comments on commit f399e16

Please sign in to comment.