Skip to content

Commit

Permalink
Optimize 3-bit packing
Browse files Browse the repository at this point in the history
Differential Revision: D64010666

Pull Request resolved: #1029
  • Loading branch information
metascroy authored and jainapurva committed Oct 15, 2024
1 parent 101d731 commit fa6d9c3
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 211 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint2.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint3.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint4.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint5.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint6.h>
#include <torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h>
#include <cassert>
Expand Down
2 changes: 2 additions & 0 deletions torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_32_lowbit_values(
vget_high_u8(shifted0),
vget_low_u8(shifted1),
vget_high_u8(shifted1));
break;
case 3:
uint8_t buffer3[32];
vst1q_u8(buffer3, shifted0);
Expand Down Expand Up @@ -185,6 +186,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_lowbit_values(
shifted0_low, shifted0_high, shifted1_low, shifted1_high, packed);
shifted0 = vcombine_u8(shifted0_low, shifted0_high);
shifted1 = vcombine_u8(shifted1_low, shifted1_high);
break;
case 3:
uint8_t buffer3[32];
torchao::bitpacking::internal::unpack_8_uint3_values(buffer3, packed);
Expand Down
Loading

0 comments on commit fa6d9c3

Please sign in to comment.