From cd85b520e6d430b6c18dd55ede3f1193623e9c6b Mon Sep 17 00:00:00 2001
From: Faran Ahmad
Date: Fri, 17 Jan 2025 09:50:31 -0800
Subject: [PATCH] Support INT4 Dequant onto GPU for Seq INT TBE look up (#3584)

Summary:
Seq INT4 -> INT4 STBE look up is supported in the diff stack: https://www.internalfb.com/diff/D61305978 .

This diff supports:
1. Dequantization of the INT4 -> INT4 STBE look up onto CUDA for all float output types.
2. Extending the dequantization of the INT4 -> INT4 STBE look up on CPU to BF16.

The main gap is handling the dequantization for the case where the scale and bias of the INT4 quantized tensor are stored in the front of the row. For CPU, we only need to add BF16 dequantization based on the output dtype.

This will enable us to reduce the network overhead to the remote embedding server as well as the data transfer from the host onto the GPU.

Differential Revision: D68187234
---
 fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h  |  3 +-
 .../quantize_fused_nbit_rowwise.cu          | 58 ++++++++++------
 .../src/quantize_ops/quantize_ops_cpu.cpp   | 67 +++++++++++++++----
 .../src/quantize_ops/quantize_ops_meta.cpp  |  3 +-
 fbgemm_gpu/test/tbe/inference/common.py     |  6 +-
 .../tbe/inference/failures_dict_fast.json   | 11 ++-
 .../test/tbe/inference/nbit_forward_test.py | 62 ++++++++++++++---
 include/fbgemm/QuantUtils.h                 |  5 +-
 src/QuantUtils.cc                           | 26 +++++--
 9 files changed, 186 insertions(+), 55 deletions(-)

diff --git a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
index 3a22f1c2f1..6fac57c716 100644
--- a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
+++ b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
@@ -432,7 +432,8 @@ at::Tensor fusednbitrowwise_to_half_cpu(
 at::Tensor fusednbitrowwise_to_float_or_half_cpu(
     const at::Tensor& input,
     const int64_t bit_rate,
-    const int64_t output_dtype);
+    const int64_t output_dtype,
+    const bool scale_bias_last);
 
 at::Tensor quantize_mx_cuda(
     const at::Tensor& input,
diff --git a/fbgemm_gpu/src/quantize_ops/quantize_fused_nbit_rowwise.cu b/fbgemm_gpu/src/quantize_ops/quantize_fused_nbit_rowwise.cu
index f29f867ced..e30b9f8e1c 100644
--- a/fbgemm_gpu/src/quantize_ops/quantize_fused_nbit_rowwise.cu
+++ b/fbgemm_gpu/src/quantize_ops/quantize_fused_nbit_rowwise.cu
@@ -74,7 +74,7 @@ __global__ inline void _float_to_fusednbitrowwise_cuda_kernel(
 }
 
 // Fused 4/2-bit rowwise -> FP32/FP16 kernel
-template <typename output_t>
+template <typename output_t, bool scale_bias_last>
 __global__ inline void _fusednbitrowwise_to_float_cuda_kernel(
     const int bit_rate,
     const std::uint8_t* input,
@@ -83,7 +83,6 @@ __global__ inline void _fusednbitrowwise_to_float_cuda_kernel(
     output_t* const output) {
   const int num_elem_per_byte = 8 / bit_rate;
   const int output_columns = (ncols - 2 * sizeof(__half)) * num_elem_per_byte;
-
   int row = (int)blockIdx.y * blockDim.y + threadIdx.y;
   const int col = (int)blockIdx.x * blockDim.x + threadIdx.x;
   const int row_incre = blockDim.y * gridDim.y;
@@ -92,9 +91,14 @@
     const std::uint8_t* input_row = input + row * ncols;
     const __half* input_row_scale_bias = reinterpret_cast<const __half*>(
         input_row +
-        (output_columns + num_elem_per_byte - 1) / num_elem_per_byte);
+        (!scale_bias_last
+             ? 0
+             : (output_columns + num_elem_per_byte - 1) / num_elem_per_byte));
     float scale = __half2float(input_row_scale_bias[0]);
     float bias = __half2float(input_row_scale_bias[1]);
+    if constexpr (!scale_bias_last) {
+      input_row += 2 * sizeof(__half);
+    }
     output_t* output_row = output + row * output_columns;
     std::uint8_t quantized = input_row[col / num_elem_per_byte];
@@ -215,7 +219,8 @@ DLL_PUBLIC Tensor _single_or_half_precision_to_fusednbitrowwise_gpu(
 template <typename output_t>
 Tensor _fusednbitrowwise_to_float_gpu_t(
     const Tensor& input,
-    const int64_t bit_rate) {
+    const int64_t bit_rate,
+    const bool scale_bias_last) {
   TENSOR_ON_CUDA_GPU(input);
   TENSOR_NDIM_EQUALS(input, 2);
   CUDA_DEVICE_GUARD(input);
@@ -245,7 +250,9 @@ Tensor _fusednbitrowwise_to_float_gpu_t(
         {nrows, output_columns}, // 2 = sizeof(bfloat16)
         input.options().dtype(at::kBFloat16));
   } else {
-    TORCH_CHECK(false, "Unsupported output dtype");
+    TORCH_CHECK(
+        false,
+        "Unsupported output dtype within _fusednbitrowwise_to_float_gpu_t");
   }
 
   if (nrows == 0 || output_columns == 0) {
@@ -260,18 +267,25 @@ Tensor _fusednbitrowwise_to_float_gpu_t(
   const auto gridDim_y = cuda_calc_block_count(nrows, blockDim.y);
   const dim3 gridDim(gridDim_x, gridDim_y);
 
+#define DEQUANT_LAUNCH_NBIT(scale_bias_last) \
+  _fusednbitrowwise_to_float_cuda_kernel<scalar_t, scale_bias_last> \
+      <<<gridDim, blockDim, 0, at::cuda::getCurrentCUDAStream()>>>( \
+          bit_rate, \
+          input.data_ptr<std::uint8_t>(), \
+          nrows, \
+          ncols, \
+          output.data_ptr<scalar_t>())
+
   FBGEMM_DISPATCH_FLOATING_TYPES(
       output.scalar_type(), "fusednbitrowwise_to_float_cuda_kernel", [&] {
-        _fusednbitrowwise_to_float_cuda_kernel<scalar_t>
-            <<<gridDim, blockDim, 0, at::cuda::getCurrentCUDAStream()>>>(
-                bit_rate,
-                input.data_ptr<std::uint8_t>(),
-                nrows,
-                ncols,
-                output.data_ptr<scalar_t>());
+        if (scale_bias_last) {
+          DEQUANT_LAUNCH_NBIT(true);
+        } else {
+          DEQUANT_LAUNCH_NBIT(false);
+        }
         C10_CUDA_KERNEL_LAUNCH_CHECK();
       });
-
+#undef DEQUANT_LAUNCH_NBIT
   return output;
 }
 
@@ -286,7 +300,8 @@ Tensor _fusednbitrowwise_to_float_gpu_t(
 DLL_PUBLIC at::Tensor _fusednbitrowwise_to_float_gpu(
     const at::Tensor& input,
     const int64_t bit_rate) {
-  return _fusednbitrowwise_to_float_gpu_t<float>(input, bit_rate);
+  return _fusednbitrowwise_to_float_gpu_t<float>(
+      input, bit_rate, true /* scale_bias_last */);
 }
 
 /// @ingroup quantize-ops-cuda
@@ -301,7 +316,8 @@ DLL_PUBLIC at::Tensor _fusednbitrowwise_to_float_gpu(
 DLL_PUBLIC at::Tensor _fusednbitrowwise_to_half_gpu(
     const at::Tensor& input,
     const int64_t bit_rate) {
-  return _fusednbitrowwise_to_float_gpu_t<at::Half>(input, bit_rate);
+  return _fusednbitrowwise_to_float_gpu_t<at::Half>(
+      input, bit_rate, true /* scale_bias_last */);
 }
 
 /// @ingroup quantize-ops-cuda
@@ -321,19 +337,23 @@ DLL_PUBLIC at::Tensor _fusednbitrowwise_to_half_gpu(
 DLL_PUBLIC at::Tensor _fusednbitrowwise_to_single_or_half_precision_gpu(
     const at::Tensor& input,
     const int64_t bit_rate,
-    const int64_t output_dtype) {
+    const int64_t output_dtype,
+    const bool scale_bias_last) {
   Tensor output;
 
   SparseType output_sparse_dtype = static_cast<SparseType>(output_dtype);
   switch (output_sparse_dtype) {
     case SparseType::FP32:
-      output = _fusednbitrowwise_to_float_gpu_t<float>(input, bit_rate);
+      output = _fusednbitrowwise_to_float_gpu_t<float>(
+          input, bit_rate, scale_bias_last);
       break;
     case SparseType::FP16:
-      output = _fusednbitrowwise_to_float_gpu_t<at::Half>(input, bit_rate);
+      output = _fusednbitrowwise_to_float_gpu_t<at::Half>(
+          input, bit_rate, scale_bias_last);
       break;
     case SparseType::BF16:
-      output = _fusednbitrowwise_to_float_gpu_t<at::BFloat16>(input, bit_rate);
+      output = _fusednbitrowwise_to_float_gpu_t<at::BFloat16>(
+          input, bit_rate, scale_bias_last);
       break;
     default:
       TORCH_CHECK(false);
diff --git a/fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp b/fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp
index f10fd25d6a..b70e82b32b 100644
--- a/fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp
+++ b/fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp
@@ -150,7 +150,9 @@ Tensor _fusednbitrowwise_to_float_cpu(
   return output;
 }
 
-Tensor _fusednbitrowwise_sbfront_to_float_cpu(
+// Both float16 and bfloat16 are of the same storage type, uint16_t.
+template <typename output_t>
+Tensor _fusednbitrowwise_sbfront_to_float_or_half_cpu(
     const Tensor& input,
     const int64_t bit_rate) {
   TENSOR_ON_CPU(input);
@@ -165,15 +167,36 @@
       (ncols - 2 * sizeof(at::Half)) * num_elem_per_byte;
 
   Tensor output;
-  output = at::empty(
-      {nrows, output_columns}, // 4 = sizeof(float)
-      input.options().dtype(at::kFloat));
+  if (std::is_same<output_t, float>::value) {
+    output = at::empty(
+        {nrows, output_columns}, // 4 = sizeof(float)
+        input.options().dtype(at::kFloat));
+  } else if (std::is_same<output_t, at::Half>::value) {
+    output = at::empty(
+        {nrows, output_columns}, // 2 = sizeof(half)
+        input.options().dtype(at::kHalf));
+  } else if (std::is_same<output_t, at::BFloat16>::value) {
+    output = at::empty(
+        {nrows, output_columns}, // 2 = sizeof(half)
+        input.options().dtype(at::kBFloat16));
+  } else {
+    TORCH_CHECK(
+        false,
+        "Unsupported output dtype for _fusednbitrowwise_sbfront_to_float_or_half_cpu");
+  }
 
-  float* output_data = static_cast<float*>(
+  using output_ty = std::conditional_t<
+      std::is_same<output_t, float>::value,
+      float,
+      fbgemm::float16>;
+  output_ty* output_data = static_cast<output_ty*>(
       output.data_ptr()); // output.data_ptr<float>(); -> Yields
                           // unresolved data_ptr symbol.
 
-  fbgemm::FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef<float>(
+  constexpr bool is_float16_bf16 = std::is_same<output_t, at::BFloat16>::value;
+  fbgemm::FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef<
+      output_ty,
+      is_float16_bf16>(
       bit_rate,
       input.data_ptr<uint8_t>(),
       nrows,
@@ -311,7 +334,7 @@ Tensor fusednbitrowwise_to_float_cpu(
 
 /// @ingroup quantize-data-cpu
 /// @brief Dequantize int4/int2 rows with scale and bias stored in the front
-/// into float32.
+/// into float32/float16/BFloat16.
 /// @param input Tensor of int4/int2 rows with scale and bias stored in the
 /// front.
 /// @param bit_rate Bit rate of each element. Should be 4 or 2.
@@ -323,8 +346,25 @@
 /// purpose because its kernel is reference implementation and not optimized.
 Tensor fusednbitrowwise_sbfront_to_float_cpu(
     const Tensor& input,
-    const int64_t bit_rate) {
-  return _fusednbitrowwise_sbfront_to_float_cpu(input, bit_rate);
+    const int64_t bit_rate,
+    const int64_t output_dtype) {
+  SparseType output_sparse_dtype = static_cast<SparseType>(output_dtype);
+  switch (output_sparse_dtype) {
+    case SparseType::FP32:
+      return _fusednbitrowwise_sbfront_to_float_or_half_cpu<float>(
+          input, bit_rate);
+      break;
+    case SparseType::FP16:
+      return _fusednbitrowwise_sbfront_to_float_or_half_cpu<at::Half>(
+          input, bit_rate);
+      break;
+    case SparseType::BF16:
+      return _fusednbitrowwise_sbfront_to_float_or_half_cpu<at::BFloat16>(
+          input, bit_rate);
+      break;
+    default:
+      TORCH_CHECK(false);
+  }
 }
 
 /// @ingroup quantize-data-cpu
@@ -340,7 +380,8 @@ Tensor fusednbitrowwise_to_half_cpu(
 Tensor fusednbitrowwise_to_float_or_half_cpu(
     const Tensor& input,
     const int64_t bit_rate,
-    const int64_t output_dtype) {
+    const int64_t output_dtype,
+    [[maybe_unused]] const bool scale_bias_last) {
   Tensor output;
 
   SparseType output_sparse_dtype = static_cast<SparseType>(output_dtype);
@@ -520,11 +561,11 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.def(
       "FusedNBitRowwiseQuantizedSBHalfToFloat(Tensor input, int bit_rate) -> Tensor");
   m.def(
-      "FusedNBitRowwiseQuantizedSBHalfFrontToFloat(Tensor input, int bit_rate) -> Tensor");
+      "FusedNBitRowwiseQuantizedSBHalfFrontToFloatOrHalf(Tensor input, int bit_rate, int output_dtype) -> Tensor");
   m.def(
       "FusedNBitRowwiseQuantizedSBHalfToHalf(Tensor input, int bit_rate) -> Tensor");
   m.def(
-      "FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf(Tensor input, int bit_rate, int output_dtype=0) -> Tensor");
+      "FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf(Tensor input, int bit_rate, int output_dtype=0, bool scale_bias_last=True) -> Tensor");
   m.def(
       "FloatToHFP8Quantized(Tensor input, int ebits, int exponent_bias, float max_pos) -> Tensor");
   m.def(
@@ -542,7 +583,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
 
 TORCH_LIBRARY_IMPL(fbgemm, QuantizedCPU, m) {
   DISPATCH_TO_QUANTIZED_CPU(
-      "FusedNBitRowwiseQuantizedSBHalfFrontToFloat",
+      "FusedNBitRowwiseQuantizedSBHalfFrontToFloatOrHalf",
       fbgemm_gpu::fusednbitrowwise_sbfront_to_float_cpu);
 }
 
diff --git a/fbgemm_gpu/src/quantize_ops/quantize_ops_meta.cpp b/fbgemm_gpu/src/quantize_ops/quantize_ops_meta.cpp
index f66affea0e..8fc3e2c0cb 100644
--- a/fbgemm_gpu/src/quantize_ops/quantize_ops_meta.cpp
+++ b/fbgemm_gpu/src/quantize_ops/quantize_ops_meta.cpp
@@ -72,7 +72,8 @@ Tensor FloatToFP8RowwiseQuantized_meta(const Tensor& input, bool forward) {
 Tensor fusednbitrowwise_to_float_or_half_meta(
     const Tensor& input,
     const int64_t bit_rate,
-    const int64_t output_dtype) {
+    const int64_t output_dtype,
+    [[maybe_unused]] const bool scale_bias_last) {
   const at::SymIntArrayRef input_sizes = input.sym_sizes();
   const at::SymInt nrows = input_sizes[0];
   // Here we want the number of bytes in a row
diff --git a/fbgemm_gpu/test/tbe/inference/common.py b/fbgemm_gpu/test/tbe/inference/common.py
index b50b1d03f1..8c04441e47 100644
--- a/fbgemm_gpu/test/tbe/inference/common.py
+++ b/fbgemm_gpu/test/tbe/inference/common.py
@@ -351,8 +351,10 @@ def execute_nbit_forward_(  # noqa C901
         f = torch.cat(fs, dim=0).view(-1, D)
 
         if fc2.dtype == torch.quint4x2:
-            fc2_float = torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfFrontToFloat(
-                fc2.cpu(), bit_rate=4
+            fc2_float = (
+                torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfFrontToFloatOrHalf(
+                    fc2.cpu(), bit_rate=4, output_dtype=0
+                )
             )
         else:
             fc2_float = fc2.float()
diff --git a/fbgemm_gpu/test/tbe/inference/failures_dict_fast.json b/fbgemm_gpu/test/tbe/inference/failures_dict_fast.json
index 231eb24f7e..930c0c2a0d 100644
--- a/fbgemm_gpu/test/tbe/inference/failures_dict_fast.json
+++ b/fbgemm_gpu/test/tbe/inference/failures_dict_fast.json
@@ -7,7 +7,8 @@
     "fbgemm::FloatToHFP8Quantized": {},
     "fbgemm::Fused8BitRowwiseQuantizedToFloat": {},
     "fbgemm::Fused8BitRowwiseQuantizedToFloatOrHalf": {},
-    "fbgemm::FusedNBitRowwiseQuantizedSBHalfFrontToFloat": {},
+    "fbgemm::FusedNBitRowwiseQuantizedSBHalfFrontToFloatOrHalf": {},
+    "fbgemm::FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf": {},
     "fbgemm::HFP8QuantizedToFloat": {},
     "fbgemm::asynchronous_complete_cumsum": {},
     "fbgemm::bounds_check_indices": {},
@@ -44,9 +45,17 @@
       "comment": "",
       "status": "xsuccess"
     },
+    "NBitFowardTest.test_faketensor__test_nbit_forward_cpu_gpu_dequantize_parity": {
+      "comment": "this operator outputs torch.quint4x2 tensors which is not compatible with generate_opcheck_tests",
+      "status": "xfail"
+    },
     "NBitFowardTest.test_faketensor__test_nbit_forward_cpu_seq_int4": {
       "comment": "this operator outputs torch.quint4x2 tensors which is not compatible with generate_opcheck_tests",
       "status": "xfail"
+    },
+    "NBitFowardTest.test_schema__test_nbit_forward_cpu_gpu_dequantize_parity": {
+      "comment": "this operator outputs torch.quint4x2 tensors which is not compatible with generate_opcheck_tests",
+      "status": "xfail"
     }
   },
   "fbgemm::int_nbit_split_embedding_uvm_caching_codegen_lookup_function": {
diff --git a/fbgemm_gpu/test/tbe/inference/nbit_forward_test.py b/fbgemm_gpu/test/tbe/inference/nbit_forward_test.py
index 5da2cf680a..efa5934898 100644
--- a/fbgemm_gpu/test/tbe/inference/nbit_forward_test.py
+++ b/fbgemm_gpu/test/tbe/inference/nbit_forward_test.py
@@ -64,6 +64,18 @@
     "test_faketensor__test_nbit_forward_cpu_gpu_dequantize_parity": [
         unittest.skip("Operator not implemented for Meta tensors"),
     ],
+    "test_schema__test_nbit_forward_cpu_gpu_dequantize_parity": [
+        unittest.skip("Operator not implemented for Meta tensors"),
+    ],
+    "test_autograd_registration__test_nbit_forward_cpu_gpu_dequantize_parity": [
+        unittest.skip("Operator not implemented for Meta tensors"),
+    ],
+    "test_aot_dispatch_static__test_nbit_forward_cpu_gpu_dequantize_parity": [
+        unittest.skip("Operator not implemented for Meta tensors"),
+    ],
+    "test_aot_dispatch_dynamic__test_nbit_forward_cpu_gpu_dequantize_parity": [
+        unittest.skip("Operator not implemented for Meta tensors"),
+    ],
     "test_faketensor__test_nbit_forward_cpu_seq_int4": {
         unittest.skip(
             "Operator outputs int4 tensors which do not support opcheck tests"
@@ -755,10 +767,11 @@ def test_nbit_forward_cpu_seq_int4(
         nbit_weights_ty=st.sampled_from(
             [
                 SparseType.INT8,
+                SparseType.INT4,
             ]
         ),
         pooling_mode=st.sampled_from([PoolingMode.NONE]),
-        output_dtype=st.sampled_from([SparseType.BF16, SparseType.FP16]),
+        output_dtype=st.sampled_from([SparseType.FP16, SparseType.BF16]),
         D=st.sampled_from([32, 256, 384, 512, 1024]),
         B=st.integers(min_value=8, max_value=32),
         T=st.integers(min_value=10, max_value=20),
@@ -827,13 +840,21 @@ def test_nbit_forward_cpu_gpu_dequantize_parity(
             (weights, scale_shift) = split_weights[t]
             (ref_weights, ref_scale_shift) = ref_split_weights[t]
             self.assertEqual(weights.size(), ref_weights.size())
-            element_size = SparseType.INT8.bit_rate() / 8.0
+            element_size = (
+                SparseType.INT8.bit_rate()
+                if nbit_weights_ty == SparseType.INT8
+                else SparseType.INT4.bit_rate()
+            ) / 8.0
             rand_tensor = torch.rand(
                 ref_weights.shape[0], int(ref_weights.shape[1] / element_size)
             )
             rand_weights, rand_scale_shift = quantize_embs(
                 rand_tensor,
-                SparseType.INT8,
+                (
+                    SparseType.INT8
+                    if nbit_weights_ty == SparseType.INT8
+                    else SparseType.INT4
+                ),
             )
             ref_weights.copy_(rand_weights)
             weights.copy_(ref_weights)
@@ -861,14 +882,35 @@ def test_nbit_forward_cpu_gpu_dequantize_parity(
             quant_cc_output = quant_cc(indices.int(), offsets.int())
             dequant_cc_output = dequant_cc(indices.int(), offsets.int())
             cuda_device = torch.device("cuda")
-            dequant_output_from_quant_cc = (
-                torch.ops.fbgemm.Fused8BitRowwiseQuantizedToFloatOrHalf(
-                    quant_cc_output.to(cuda_device),
-                    output_dtype.as_int(),
-                    quant_padding_float_type=False,
-                    scale_bias_last=False,
+            if nbit_weights_ty == SparseType.INT8:
+                dequant_output_from_quant_cc = (
+                    torch.ops.fbgemm.Fused8BitRowwiseQuantizedToFloatOrHalf(
+                        quant_cc_output.to(cuda_device),
+                        output_dtype.as_int(),
+                        quant_padding_float_type=False,
+                        scale_bias_last=False,
+                    )
                 )
-            )
+            elif nbit_weights_ty == SparseType.INT4:
+                tensor_gpu = torch.zeros(
+                    (quant_cc_output.shape[0], int((quant_cc_output.shape[1] + 1) / 2)),
+                    dtype=torch.uint8,
+                    device=cuda_device,
+                )
+                tensor_gpu.untyped_storage().copy_(quant_cc_output.untyped_storage())
+                dequant_output_from_quant_cc = (
+                    torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf(
+                        tensor_gpu,
+                        bit_rate=4,
+                        output_dtype=output_dtype.as_int(),
+                        scale_bias_last=False,
+                    )
+                )
+            else:
+                raise ValueError(
+                    "Unsupported nbit_weights_ty in test_nbit_forward_cpu_gpu_dequantize_parity"
+                )
+
             torch.testing.assert_close(
                 dequant_cc_output.cpu(),
                 dequant_output_from_quant_cc.cpu(),
diff --git a/include/fbgemm/QuantUtils.h b/include/fbgemm/QuantUtils.h
index 86d22595fe..137d5b8156 100644
--- a/include/fbgemm/QuantUtils.h
+++ b/include/fbgemm/QuantUtils.h
@@ -300,7 +300,8 @@ FBGEMM_API void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf(
     const uint8_t* input,
     size_t input_rows,
     int input_columns,
-    OutputType* output);
+    OutputType* output,
+    bool scale_bias_last = true);
 
 /**
  * Convert float or half inputs to rowwise quantized (8-bit) outputs.
@@ -360,7 +361,7 @@ FBGEMM_API void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatRef(
  * Same as FusedNBitRowwiseQuantizedSBHalfToFloat but unoptimized.
  * This should not be called directly except in testing.
  */
-template <typename OutputType>
+template <typename OutputType, bool is_float16_bf16 = false>
 FBGEMM_API void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef(
     int bit_rate,
     const uint8_t* input,
diff --git a/src/QuantUtils.cc b/src/QuantUtils.cc
index e6a53253f0..23972fa573 100644
--- a/src/QuantUtils.cc
+++ b/src/QuantUtils.cc
@@ -723,7 +723,7 @@ void FloatOrHalfToFused8BitRowwiseQuantizedSBFloat(
   }
 }
 
-template <typename OutputType>
+template <typename OutputType, bool is_float16_bf16>
 void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef(
     int bit_rate,
     const uint8_t* input,
@@ -733,7 +733,7 @@ void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef(
     bool scale_bias_last) {
   static_assert(
       std::is_same<OutputType, float>() || std::is_same<OutputType, float16>(),
-      "Only float and float16 types are allowed.");
+      "Only float, float16 or bfloat16 types are allowed.");
   int num_elem_per_byte = 8 / bit_rate;
   const int64_t output_columns =
       static_cast<int64_t>(input_columns - 2 * sizeof(float16)) *
@@ -760,7 +760,11 @@ void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef(
       if (std::is_same<OutputType, float>()) {
        output_row[col] = output_value;
       } else {
-        output_row[col] = cpu_float2half_rn(output_value);
+        if constexpr (is_float16_bf16) {
+          output_row[col] = cpu_float2bfloat16(output_value);
+        } else {
+          output_row[col] = cpu_float2half_rn(output_value);
+        }
       }
     }
   }
@@ -772,7 +776,8 @@ void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf(
     const uint8_t* input,
     size_t input_rows,
     int input_columns,
-    OutputType* output) {
+    OutputType* output,
+    [[maybe_unused]] bool scale_bias_last) {
   if (cpuinfo_initialize() && fbgemmHasAvx2Support()) {
 #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
     switch (bit_rate) {
@@ -857,7 +862,15 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf(
       int input_columns, \
       std::uint8_t* output); \
   template FBGEMM_API void \
-  FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef<type>( \
+  FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef<type, true>( \
+      int bit_rate, \
+      const uint8_t* input, \
+      size_t input_rows, \
+      int input_columns, \
+      type* output, \
+      bool scale_bias_last); \
+  template FBGEMM_API void \
+  FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef<type, false>( \
       int bit_rate, \
       const uint8_t* input, \
@@ -869,7 +882,8 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf(
       const uint8_t* input, \
       size_t input_rows, \
       int input_columns, \
-      type* output); \
+      type* output, \
+      bool scale_bias_last); \
   template FBGEMM_API void \
   FloatOrHalfToFused8BitRowwiseQuantizedSBFloatRef<type>( \
       const type* input, \
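
Usage sketch (not part of the patch): the snippet below is a minimal illustration of how the extended dequantization operator registered above can be exercised from Python. It assumes an fbgemm_gpu build that includes this change (so the new scale_bias_last argument exists); output_dtype=0 selects FP32, matching the test code above, and the existing FloatToFusedNBitRowwiseQuantizedSBHalf operator is used to produce rows with the FP16 scale/bias appended at the end, i.e. the default scale_bias_last=True layout. Shapes and tolerances are illustrative only.

# Illustrative sketch; assumes fbgemm_gpu with this patch is installed so the
# torch.ops.fbgemm operators below are registered.
import torch
import fbgemm_gpu  # noqa: F401  # loads the fbgemm operator library

rows, cols, bit_rate = 8, 32, 4
weights = torch.rand(rows, cols, dtype=torch.float32)

# Quantize to fused INT4 rows with an FP16 scale and bias appended at the end of each row.
quantized = torch.ops.fbgemm.FloatToFusedNBitRowwiseQuantizedSBHalf(weights, bit_rate)

# Dequantize with the extended op; scale_bias_last=True matches the layout produced above,
# while rows that store the scale/bias in the front (the new GPU path) would pass False.
dequantized = torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf(
    quantized, bit_rate, output_dtype=0, scale_bias_last=True
)

# The INT4 round trip is lossy, so only check that values are close.
torch.testing.assert_close(dequantized, weights, atol=0.1, rtol=0.1)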