Skip to content

Commit

Permalink
Optimize 3-bit packing
Browse files Browse the repository at this point in the history
Summary:
Optimizes 3-bit packing as outlined here: T199311618

Before change:
----------------------------------------------------------------------------------
benchmark_pack_uint_values<3>/128/8           47.0 ns         46.4 ns     15106555
benchmark_pack_uint_values<3>/128/64          6.94 ns         6.90 ns    101226284
benchmark_pack_uint_values<3>/128/128         3.27 ns         3.24 ns    215022716
benchmark_unpack_uint_values<3>/128/8         22.0 ns         21.9 ns     32585572
benchmark_unpack_uint_values<3>/128/64        6.02 ns         5.98 ns    116910230
benchmark_unpack_uint_values<3>/128/128       2.74 ns         2.73 ns    257088291

After change:
----------------------------------------------------------------------------------
benchmark_pack_uint_values<3>/128/8           19.5 ns         19.5 ns     36050883
benchmark_pack_uint_values<3>/128/64          3.90 ns         3.87 ns    181151919
benchmark_pack_uint_values<3>/128/128         1.57 ns         1.57 ns    447247194
benchmark_unpack_uint_values<3>/128/8         20.5 ns         20.4 ns     34490914
benchmark_unpack_uint_values<3>/128/64        3.19 ns         3.11 ns    228019714
benchmark_unpack_uint_values<3>/128/128       1.71 ns         1.70 ns    408587338

For 128 values, packing is 2.08x faster (3.27/1.57) and unpacking is 1.60x faster (2.74/1.71).

Differential Revision: D64010666
  • Loading branch information
metascroy authored and facebook-github-bot committed Oct 7, 2024
1 parent dec0313 commit 5387317
Show file tree
Hide file tree
Showing 4 changed files with 377 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -352,9 +352,9 @@ void pack_uint_values<3>(
int variant) {
constexpr int nbit = 3;
pack_uint_odd_bit_values(
torchao::bitpacking::internal::pack_8_uint3_values,
torchao::bitpacking::internal::vec_pack_64_uint3_values,
torchao::bitpacking::internal::vec_pack_128_uint3_values,
torchao::bitpacking::internal::pack_8_uint3_values_v2,
torchao::bitpacking::internal::vec_pack_64_uint3_values_v2,
torchao::bitpacking::internal::vec_pack_128_uint3_values_v2,
nbit,
packed,
unpacked,
Expand All @@ -373,9 +373,9 @@ void unpack_uint_values<3>(
int variant) {
constexpr int nbit = 3;
unpack_uint_odd_bit_values(
torchao::bitpacking::internal::unpack_8_uint3_values,
torchao::bitpacking::internal::vec_unpack_64_uint3_values,
torchao::bitpacking::internal::vec_unpack_128_uint3_values,
torchao::bitpacking::internal::unpack_8_uint3_values_v2,
torchao::bitpacking::internal::vec_unpack_64_uint3_values_v2,
torchao::bitpacking::internal::vec_unpack_128_uint3_values_v2,
nbit,
unpacked,
packed,
Expand Down
26 changes: 14 additions & 12 deletions torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,17 +109,18 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_32_lowbit_values(
vget_high_u8(shifted0),
vget_low_u8(shifted1),
vget_high_u8(shifted1));
break;
case 3:
uint8_t buffer3[32];
vst1q_u8(buffer3, shifted0);
vst1q_u8(buffer3 + 16, shifted1);

torchao::bitpacking::internal::pack_8_uint3_values(packed, buffer3);
torchao::bitpacking::internal::pack_8_uint3_values(
torchao::bitpacking::internal::pack_8_uint3_values_v2(packed, buffer3);
torchao::bitpacking::internal::pack_8_uint3_values_v2(
packed + 3, buffer3 + 8);
torchao::bitpacking::internal::pack_8_uint3_values(
torchao::bitpacking::internal::pack_8_uint3_values_v2(
packed + 6, buffer3 + 16);
torchao::bitpacking::internal::pack_8_uint3_values(
torchao::bitpacking::internal::pack_8_uint3_values_v2(
packed + 9, buffer3 + 24);
break;
case 4:
Expand Down Expand Up @@ -185,14 +186,15 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_lowbit_values(
shifted0_low, shifted0_high, shifted1_low, shifted1_high, packed);
shifted0 = vcombine_u8(shifted0_low, shifted0_high);
shifted1 = vcombine_u8(shifted1_low, shifted1_high);
break;
case 3:
uint8_t buffer3[32];
torchao::bitpacking::internal::unpack_8_uint3_values(buffer3, packed);
torchao::bitpacking::internal::unpack_8_uint3_values(
torchao::bitpacking::internal::unpack_8_uint3_values_v2(buffer3, packed);
torchao::bitpacking::internal::unpack_8_uint3_values_v2(
buffer3 + 8, packed + 3);
torchao::bitpacking::internal::unpack_8_uint3_values(
torchao::bitpacking::internal::unpack_8_uint3_values_v2(
buffer3 + 16, packed + 6);
torchao::bitpacking::internal::unpack_8_uint3_values(
torchao::bitpacking::internal::unpack_8_uint3_values_v2(
buffer3 + 24, packed + 9);
shifted0 = vld1q_u8(buffer3);
shifted1 = vld1q_u8(buffer3 + 16);
Expand Down Expand Up @@ -258,7 +260,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_64_lowbit_values(
packed, shifted0, shifted1, shifted2, shifted3);
break;
case 3:
torchao::bitpacking::internal::vec_pack_64_uint3_values(
torchao::bitpacking::internal::vec_pack_64_uint3_values_v2(
packed, shifted0, shifted1, shifted2, shifted3);
break;
case 4:
Expand Down Expand Up @@ -309,7 +311,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_lowbit_values(
shifted0, shifted1, shifted2, shifted3, packed);
break;
case 3:
torchao::bitpacking::internal::vec_unpack_64_uint3_values(
torchao::bitpacking::internal::vec_unpack_64_uint3_values_v2(
shifted0, shifted1, shifted2, shifted3, packed);
break;
case 4:
Expand Down Expand Up @@ -387,7 +389,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_128_lowbit_values(
packed + 16, shifted4, shifted5, shifted6, shifted7);
break;
case 3:
torchao::bitpacking::internal::vec_pack_128_uint3_values(
torchao::bitpacking::internal::vec_pack_128_uint3_values_v2(
packed,
shifted0,
shifted1,
Expand Down Expand Up @@ -478,7 +480,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_128_lowbit_values(
shifted4, shifted5, shifted6, shifted7, packed + 16);
break;
case 3:
torchao::bitpacking::internal::vec_unpack_128_uint3_values(
torchao::bitpacking::internal::vec_unpack_128_uint3_values_v2(
shifted0,
shifted1,
shifted2,
Expand Down
247 changes: 247 additions & 0 deletions torchao/experimental/kernels/cpu/aarch64/bitpacking/uint3.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,33 @@ TORCHAO_ALWAYS_INLINE inline void pack_8_uint3_values(
((unpacked[6] & 3) << 4) | ((unpacked[7] & 3) << 6);
}

TORCHAO_ALWAYS_INLINE inline void pack_8_uint3_values_v2(
    uint8_t* packed,
    const uint8_t* unpacked) {
  // Packs 8 uint3 values, read one per byte from unpacked[0..7], into the
  // 3 bytes packed[0..2].
  //
  // Calling the inputs v0..v7 and listing the most significant bits first,
  // the output bytes are laid out as:
  //   packed[0] = [ low 2 bits of v6   | v0 | v1 ]
  //   packed[1] = [ low 2 bits of v7   | v2 | v3 ]
  //   packed[2] = [ top bit of v6, v7  | v4 | v5 ]
  //
  // v1, v3 and v5 are OR-ed in without a mask, so any bit set above bit 2
  // in those input bytes is carried into the packed byte as-is.
  //
  // Input is 8 bytes
  // Output is 24 bits = 3 bytes
  constexpr int kLowTwoBits = 0x3;
  constexpr int kTopBit = 0x4;
  constexpr int kAllThreeBits = 0x7;

  // packed[0]: bits 7..6 <- v6 low bits, bits 5..3 <- v0, bits 2..0 <- v1
  const int v6_low_field = (unpacked[6] & kLowTwoBits) << 6;
  const int v0_field = (unpacked[0] & kAllThreeBits) << 3;
  packed[0] = static_cast<uint8_t>(v6_low_field | v0_field | unpacked[1]);

  // packed[1]: bits 7..6 <- v7 low bits, bits 5..3 <- v2, bits 2..0 <- v3
  const int v7_low_field = (unpacked[7] & kLowTwoBits) << 6;
  const int v2_field = (unpacked[2] & kAllThreeBits) << 3;
  packed[1] = static_cast<uint8_t>(v7_low_field | v2_field | unpacked[3]);

  // packed[2]: bit 7 <- v6 top bit, bit 6 <- v7 top bit,
  //            bits 5..3 <- v4, bits 2..0 <- v5
  const int v6_top_field = (unpacked[6] & kTopBit) << 5;
  const int v7_top_field = (unpacked[7] & kTopBit) << 4;
  const int v4_field = (unpacked[4] & kAllThreeBits) << 3;
  packed[2] = static_cast<uint8_t>(
      v6_top_field | v7_top_field | v4_field | unpacked[5]);
}

TORCHAO_ALWAYS_INLINE inline void unpack_8_uint3_values(
uint8_t* unpacked,
const uint8_t* packed) {
Expand All @@ -70,6 +97,31 @@ TORCHAO_ALWAYS_INLINE inline void unpack_8_uint3_values(
unpacked[7] = ((b2 & 128) >> 5) | ((b10_1 & 192) >> 6);
}

TORCHAO_ALWAYS_INLINE inline void unpack_8_uint3_values_v2(
    uint8_t* unpacked,
    const uint8_t* packed) {
  // Inverse of pack_8_uint3_values_v2: expands the 3 bytes packed[0..2]
  // into 8 values written one per byte to unpacked[0..7].
  //
  // Input is 24 bits = 3 bytes
  // Output is 8 bytes

  // All packed bytes are read up front, before any output byte is written.
  const uint8_t bytes[3] = {packed[0], packed[1], packed[2]};

  // The low 6 bits of packed byte i carry two whole values:
  // bits 5..3 -> unpacked[2 * i], bits 2..0 -> unpacked[2 * i + 1].
  for (int i = 0; i < 3; ++i) {
    unpacked[2 * i] = static_cast<uint8_t>((bytes[i] >> 3) & 7);
    unpacked[2 * i + 1] = static_cast<uint8_t>(bytes[i] & 7);
  }

  // The last two values are split across bytes: their low 2 bits sit in
  // bits 7..6 of bytes[0] / bytes[1], and their top bits sit in bit 7 /
  // bit 6 of bytes[2] (shifted down so they land on bit 2).
  unpacked[6] = static_cast<uint8_t>((bytes[0] >> 6) | ((bytes[2] >> 5) & 4));
  unpacked[7] = static_cast<uint8_t>((bytes[1] >> 6) | ((bytes[2] >> 4) & 4));
}

TORCHAO_ALWAYS_INLINE inline void vec_pack_64_uint3_values(
uint8_t* packed,
const uint8x16_t& unpacked0,
Expand Down Expand Up @@ -135,6 +187,47 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_64_uint3_values(
vst1_u8(packed + 16, b10_1);
}

TORCHAO_ALWAYS_INLINE inline void vec_pack_64_uint3_values_v2(
    uint8_t* packed,
    const uint8x16_t& unpacked0, // 0, 1
    const uint8x16_t& unpacked1, // 2, 3
    const uint8x16_t& unpacked2, // 4, 5
    const uint8x16_t& unpacked3) { // 6, 7
  // Vectorized form of pack_8_uint3_values_v2, working on 8 lanes at once.
  //
  // The low and high 8-lane halves of unpacked0..unpacked3 take the roles
  // of unpacked[0]..unpacked[7] in the scalar routine (see the "0, 1" ...
  // "6, 7" annotations on the parameters). Each section below is preceded
  // by the scalar statement it vectorizes.
  //
  // Input is 64 bytes
  // Output is 3*64 = 192 bits = 24 bytes, stored as three 8-byte vectors at
  // packed, packed + 8 and packed + 16.
  const uint8x8_t low_two_bits = vdup_n_u8(3);
  const uint8x8_t top_bit = vdup_n_u8(4);
  const uint8x8_t all_three_bits = vdup_n_u8(7);

  // packed[0] = ((unpacked[6] & 3) << 6) | ((unpacked[0] & 7) << 3) |
  //             unpacked[1];
  {
    const uint8x8_t v6_low_field =
        vshl_n_u8(vand_u8(vget_low_u8(unpacked3), low_two_bits), 6);
    const uint8x8_t v0_field =
        vshl_n_u8(vand_u8(vget_low_u8(unpacked0), all_three_bits), 3);
    vst1_u8(
        packed,
        vorr_u8(vorr_u8(v6_low_field, v0_field), vget_high_u8(unpacked0)));
  }

  // packed[1] = ((unpacked[7] & 3) << 6) | ((unpacked[2] & 7) << 3) |
  //             unpacked[3];
  {
    const uint8x8_t v7_low_field =
        vshl_n_u8(vand_u8(vget_high_u8(unpacked3), low_two_bits), 6);
    const uint8x8_t v2_field =
        vshl_n_u8(vand_u8(vget_low_u8(unpacked1), all_three_bits), 3);
    vst1_u8(
        packed + 8,
        vorr_u8(vorr_u8(v7_low_field, v2_field), vget_high_u8(unpacked1)));
  }

  // packed[2] = ((unpacked[6] & 4) << 5) | ((unpacked[7] & 4) << 4) |
  //             ((unpacked[4] & 7) << 3) | unpacked[5];
  {
    const uint8x8_t v6_top_field =
        vshl_n_u8(vand_u8(vget_low_u8(unpacked3), top_bit), 5);
    const uint8x8_t v7_top_field =
        vshl_n_u8(vand_u8(vget_high_u8(unpacked3), top_bit), 4);
    const uint8x8_t v4_field =
        vshl_n_u8(vand_u8(vget_low_u8(unpacked2), all_three_bits), 3);
    const uint8x8_t top_bits_and_v4 =
        vorr_u8(vorr_u8(v6_top_field, v7_top_field), v4_field);
    vst1_u8(packed + 16, vorr_u8(top_bits_and_v4, vget_high_u8(unpacked2)));
  }
}

TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_uint3_values(
uint8x16_t& unpacked0,
uint8x16_t& unpacked1,
Expand Down Expand Up @@ -203,6 +296,62 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_uint3_values(
unpacked3 = vcombine_u8(unpacked_tmp0, unpacked_tmp1);
}

TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_uint3_values_v2(
    uint8x16_t& unpacked0, // 0, 1
    uint8x16_t& unpacked1, // 2, 3
    uint8x16_t& unpacked2, // 4, 5
    uint8x16_t& unpacked3, // 6, 7
    const uint8_t* packed) {
  // Unpacks data packed by vec_pack_64_uint3_values_v2.
  //
  // This is the vectorized form of unpack_8_uint3_values_v2, working on
  // 8 lanes at once; read that function first. The low and high halves of
  // unpacked0..unpacked3 take the roles of unpacked[0]..unpacked[7] there,
  // and each statement below is preceded by the scalar code it vectorizes.
  //
  // Input is 3*64 = 192 bits = 24 bytes
  // Output is 64 bytes

  // All three packed vectors are loaded before any output is written.
  const uint8x8_t byte0 = vld1_u8(packed);
  const uint8x8_t byte1 = vld1_u8(packed + 8);
  const uint8x8_t byte2 = vld1_u8(packed + 16);

  const uint8x8_t all_three_bits = vdup_n_u8(7);
  const uint8x8_t top_bit = vdup_n_u8(4);

  // unpacked[0] = ((b0 >> 3) & 7);
  // unpacked[1] = b0 & 7;
  unpacked0 = vcombine_u8(
      vand_u8(vshr_n_u8(byte0, 3), all_three_bits),
      vand_u8(byte0, all_three_bits));

  // unpacked[2] = ((b1 >> 3) & 7);
  // unpacked[3] = b1 & 7;
  unpacked1 = vcombine_u8(
      vand_u8(vshr_n_u8(byte1, 3), all_three_bits),
      vand_u8(byte1, all_three_bits));

  // unpacked[4] = ((b2 >> 3) & 7);
  // unpacked[5] = b2 & 7;
  unpacked2 = vcombine_u8(
      vand_u8(vshr_n_u8(byte2, 3), all_three_bits),
      vand_u8(byte2, all_three_bits));

  // unpacked[6] = (b0 >> 6) | ((b2 >> 5) & 4);
  const uint8x8_t value6 =
      vorr_u8(vshr_n_u8(byte0, 6), vand_u8(vshr_n_u8(byte2, 5), top_bit));

  // unpacked[7] = (b1 >> 6) | ((b2 >> 4) & 4);
  const uint8x8_t value7 =
      vorr_u8(vshr_n_u8(byte1, 6), vand_u8(vshr_n_u8(byte2, 4), top_bit));

  unpacked3 = vcombine_u8(value6, value7);
}

TORCHAO_ALWAYS_INLINE inline void vec_pack_128_uint3_values(
uint8_t* packed,
const uint8x16_t& unpacked0,
Expand Down Expand Up @@ -266,6 +415,51 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_128_uint3_values(
vst1q_u8(packed + 32, b10_1);
}

TORCHAO_ALWAYS_INLINE inline void vec_pack_128_uint3_values_v2(
    uint8_t* packed,
    const uint8x16_t& unpacked0,
    const uint8x16_t& unpacked1,
    const uint8x16_t& unpacked2,
    const uint8x16_t& unpacked3,
    const uint8x16_t& unpacked4,
    const uint8x16_t& unpacked5,
    const uint8x16_t& unpacked6,
    const uint8x16_t& unpacked7) {
  // Vectorized form of pack_8_uint3_values_v2, working on 16 lanes at once.
  //
  // unpacked0..unpacked7 take the roles of unpacked[0]..unpacked[7] in the
  // scalar routine; read that function first. Each section below is
  // preceded by the scalar statement it vectorizes.
  //
  // Input is 128 bytes
  // Output is 3*128 = 384 bits = 48 bytes, stored as three 16-byte vectors
  // at packed, packed + 16 and packed + 32.
  const uint8x16_t low_two_bits = vdupq_n_u8(3);
  const uint8x16_t top_bit = vdupq_n_u8(4);
  const uint8x16_t all_three_bits = vdupq_n_u8(7);

  // packed[0] = ((unpacked[6] & 3) << 6) | ((unpacked[0] & 7) << 3) |
  //             unpacked[1];
  {
    const uint8x16_t v6_low_field =
        vshlq_n_u8(vandq_u8(unpacked6, low_two_bits), 6);
    const uint8x16_t v0_field =
        vshlq_n_u8(vandq_u8(unpacked0, all_three_bits), 3);
    vst1q_u8(packed, vorrq_u8(vorrq_u8(v6_low_field, v0_field), unpacked1));
  }

  // packed[1] = ((unpacked[7] & 3) << 6) | ((unpacked[2] & 7) << 3) |
  //             unpacked[3];
  {
    const uint8x16_t v7_low_field =
        vshlq_n_u8(vandq_u8(unpacked7, low_two_bits), 6);
    const uint8x16_t v2_field =
        vshlq_n_u8(vandq_u8(unpacked2, all_three_bits), 3);
    vst1q_u8(
        packed + 16, vorrq_u8(vorrq_u8(v7_low_field, v2_field), unpacked3));
  }

  // packed[2] = ((unpacked[6] & 4) << 5) | ((unpacked[7] & 4) << 4) |
  //             ((unpacked[4] & 7) << 3) | unpacked[5];
  {
    const uint8x16_t v6_top_field = vshlq_n_u8(vandq_u8(unpacked6, top_bit), 5);
    const uint8x16_t v7_top_field = vshlq_n_u8(vandq_u8(unpacked7, top_bit), 4);
    const uint8x16_t v4_field =
        vshlq_n_u8(vandq_u8(unpacked4, all_three_bits), 3);
    const uint8x16_t top_bits_and_v4 =
        vorrq_u8(vorrq_u8(v6_top_field, v7_top_field), v4_field);
    vst1q_u8(packed + 32, vorrq_u8(top_bits_and_v4, unpacked5));
  }
}

TORCHAO_ALWAYS_INLINE inline void vec_unpack_128_uint3_values(
uint8x16_t& unpacked0,
uint8x16_t& unpacked1,
Expand Down Expand Up @@ -329,6 +523,59 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_128_uint3_values(
vorrq_u8(unpacked7, vshrq_n_u8(vandq_u8(b10_1, vdupq_n_u8(192)), 6));
}

TORCHAO_ALWAYS_INLINE inline void vec_unpack_128_uint3_values_v2(
    uint8x16_t& unpacked0,
    uint8x16_t& unpacked1,
    uint8x16_t& unpacked2,
    uint8x16_t& unpacked3,
    uint8x16_t& unpacked4,
    uint8x16_t& unpacked5,
    uint8x16_t& unpacked6,
    uint8x16_t& unpacked7,
    const uint8_t* packed) {
  // Unpacks data packed by vec_pack_128_uint3_values_v2.
  //
  // This is the vectorized form of unpack_8_uint3_values_v2, working on
  // 16 lanes at once; read that function first. unpacked0..unpacked7 take
  // the roles of unpacked[0]..unpacked[7] there, and each statement below
  // is preceded by the scalar code it vectorizes.
  //
  // Input is 3*128 = 384 bits = 48 bytes
  // Output is 128 bytes

  // All three packed vectors are loaded before any output is written.
  const uint8x16_t byte0 = vld1q_u8(packed);
  const uint8x16_t byte1 = vld1q_u8(packed + 16);
  const uint8x16_t byte2 = vld1q_u8(packed + 32);

  const uint8x16_t all_three_bits = vdupq_n_u8(7);
  const uint8x16_t top_bit = vdupq_n_u8(4);

  // unpacked[0] = ((b0 >> 3) & 7);
  unpacked0 = vandq_u8(vshrq_n_u8(byte0, 3), all_three_bits);

  // unpacked[1] = b0 & 7;
  unpacked1 = vandq_u8(byte0, all_three_bits);

  // unpacked[2] = ((b1 >> 3) & 7);
  unpacked2 = vandq_u8(vshrq_n_u8(byte1, 3), all_three_bits);

  // unpacked[3] = b1 & 7;
  unpacked3 = vandq_u8(byte1, all_three_bits);

  // unpacked[4] = ((b2 >> 3) & 7);
  unpacked4 = vandq_u8(vshrq_n_u8(byte2, 3), all_three_bits);

  // unpacked[5] = b2 & 7;
  unpacked5 = vandq_u8(byte2, all_three_bits);

  // unpacked[6] = (b0 >> 6) | ((b2 >> 5) & 4);
  unpacked6 =
      vorrq_u8(vshrq_n_u8(byte0, 6), vandq_u8(vshrq_n_u8(byte2, 5), top_bit));

  // unpacked[7] = (b1 >> 6) | ((b2 >> 4) & 4);
  unpacked7 =
      vorrq_u8(vshrq_n_u8(byte1, 6), vandq_u8(vshrq_n_u8(byte2, 4), top_bit));
}

} // namespace internal
} // namespace bitpacking
} // namespace torchao
Expand Down
Loading

0 comments on commit 5387317

Please sign in to comment.