Skip to content

Commit

Permalink
minor lint, needs fixing and better handling of the smaller case -who…
Browse files Browse the repository at this point in the history
…le line in kernel-
  • Loading branch information
blefaudeux committed Dec 22, 2021
1 parent a12b351 commit fc02cc8
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions xformers/triton/k_mem_efficient_attention.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) Meta, Inc. and its affiliates. All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -143,7 +143,7 @@ def grid(META):

# Epilogue
if tiles_n > 1:
# There were tilse over the N dimension,
# There were tiles over the N dimension,
# so the weights were not correct in real time.

# Let's fix that:
Expand All @@ -162,6 +162,7 @@ def grid(META):
else:
weights = weights_n

# TODO: do this in the kernel if it owns the whole line
qkv = out / weights.unsqueeze(-1)
qkvs.append(qkv)

Expand Down

0 comments on commit fc02cc8

Please sign in to comment.