diff --git a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
index 3a22f1c2f1..6fac57c716 100644
--- a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
+++ b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
@@ -432,7 +432,8 @@ at::Tensor fusednbitrowwise_to_half_cpu(
 at::Tensor fusednbitrowwise_to_float_or_half_cpu(
     const at::Tensor& input,
     const int64_t bit_rate,
-    const int64_t output_dtype);
+    const int64_t output_dtype,
+    const bool scale_bias_last);
 
 at::Tensor quantize_mx_cuda(
     const at::Tensor& input,
diff --git a/fbgemm_gpu/src/quantize_ops/quantize_fused_nbit_rowwise.cu b/fbgemm_gpu/src/quantize_ops/quantize_fused_nbit_rowwise.cu
index f29f867ced..e30b9f8e1c 100644
--- a/fbgemm_gpu/src/quantize_ops/quantize_fused_nbit_rowwise.cu
+++ b/fbgemm_gpu/src/quantize_ops/quantize_fused_nbit_rowwise.cu
@@ -74,7 +74,7 @@ __global__ inline void _float_to_fusednbitrowwise_cuda_kernel(
 }
 
 // Fused 4/2-bit rowwise -> FP32/FP16 kernel
-template <typename output_t>
+template <typename output_t, bool scale_bias_last>
 __global__ inline void _fusednbitrowwise_to_float_cuda_kernel(
     const int bit_rate,
     const std::uint8_t* input,
@@ -83,7 +83,6 @@ __global__ inline void _fusednbitrowwise_to_float_cuda_kernel(
     output_t* const output) {
   const int num_elem_per_byte = 8 / bit_rate;
   const int output_columns = (ncols - 2 * sizeof(__half)) * num_elem_per_byte;
-
   int row = (int)blockIdx.y * blockDim.y + threadIdx.y;
   const int col = (int)blockIdx.x * blockDim.x + threadIdx.x;
   const int row_incre = blockDim.y * gridDim.y;
@@ -92,9 +91,14 @@
     const std::uint8_t* input_row = input + row * ncols;
     const __half* input_row_scale_bias = reinterpret_cast<const __half*>(
         input_row +
-        (output_columns + num_elem_per_byte - 1) / num_elem_per_byte);
+        (!scale_bias_last
+             ? 0
+             : (output_columns + num_elem_per_byte - 1) / num_elem_per_byte));
     float scale = __half2float(input_row_scale_bias[0]);
     float bias = __half2float(input_row_scale_bias[1]);
+    if constexpr (!scale_bias_last) {
+      input_row += 2 * sizeof(__half);
+    }
     output_t* output_row = output + row * output_columns;
 
     std::uint8_t quantized = input_row[col / num_elem_per_byte];
@@ -215,7 +219,8 @@ DLL_PUBLIC Tensor _single_or_half_precision_to_fusednbitrowwise_gpu(
 template <typename output_t>
 Tensor _fusednbitrowwise_to_float_gpu_t(
     const Tensor& input,
-    const int64_t bit_rate) {
+    const int64_t bit_rate,
+    const bool scale_bias_last) {
   TENSOR_ON_CUDA_GPU(input);
   TENSOR_NDIM_EQUALS(input, 2);
   CUDA_DEVICE_GUARD(input);
@@ -245,7 +250,9 @@ Tensor _fusednbitrowwise_to_float_gpu_t(
         {nrows, output_columns}, // 2 = sizeof(bfloat16)
         input.options().dtype(at::kBFloat16));
   } else {
-    TORCH_CHECK(false, "Unsupported output dtype");
+    TORCH_CHECK(
+        false,
+        "Unsupported output dtype within _fusednbitrowwise_to_float_gpu_t");
   }
 
   if (nrows == 0 || output_columns == 0) {
@@ -260,18 +267,25 @@ Tensor _fusednbitrowwise_to_float_gpu_t(
   const auto gridDim_y = cuda_calc_block_count(nrows, blockDim.y);
   const dim3 gridDim(gridDim_x, gridDim_y);
 
+#define DEQUANT_LAUNCH_NBIT(scale_bias_last)                            \
+  _fusednbitrowwise_to_float_cuda_kernel<scalar_t, scale_bias_last>     \
+      <<<gridDim, blockDim, 0, at::cuda::getCurrentCUDAStream()>>>(     \
+          bit_rate,                                                     \
+          input.data_ptr<std::uint8_t>(),                               \
+          nrows,                                                        \
+          ncols,                                                        \
+          output.data_ptr<scalar_t>())
+
   FBGEMM_DISPATCH_FLOATING_TYPES(
       output.scalar_type(), "fusednbitrowwise_to_float_cuda_kernel", [&] {
-        _fusednbitrowwise_to_float_cuda_kernel<scalar_t>
-            <<<gridDim, blockDim, 0, at::cuda::getCurrentCUDAStream()>>>(
-                bit_rate,
-                input.data_ptr<std::uint8_t>(),
-                nrows,
-                ncols,
-                output.data_ptr<scalar_t>());
+        if (scale_bias_last) {
+          DEQUANT_LAUNCH_NBIT(true);
+        } else {
+          DEQUANT_LAUNCH_NBIT(false);
+        }
         C10_CUDA_KERNEL_LAUNCH_CHECK();
       });
-
+#undef DEQUANT_LAUNCH_NBIT
 
   return output;
 }
@@ -286,7 +300,8 @@ Tensor _fusednbitrowwise_to_float_gpu_t(
 DLL_PUBLIC at::Tensor _fusednbitrowwise_to_float_gpu(
     const at::Tensor& input,
     const int64_t bit_rate) {
-  return _fusednbitrowwise_to_float_gpu_t<float>(input, bit_rate);
+  return _fusednbitrowwise_to_float_gpu_t<float>(
+      input, bit_rate, true /* scale_bias_last */);
 }
 
 /// @ingroup quantize-ops-cuda
@@ -301,7 +316,8 @@ DLL_PUBLIC at::Tensor _fusednbitrowwise_to_float_gpu(
 DLL_PUBLIC at::Tensor _fusednbitrowwise_to_half_gpu(
     const at::Tensor& input,
     const int64_t bit_rate) {
-  return _fusednbitrowwise_to_float_gpu_t<at::Half>(input, bit_rate);
+  return _fusednbitrowwise_to_float_gpu_t<at::Half>(
+      input, bit_rate, true /* scale_bias_last */);
 }
 
 /// @ingroup quantize-ops-cuda
@@ -321,19 +337,23 @@ DLL_PUBLIC at::Tensor _fusednbitrowwise_to_half_gpu(
 DLL_PUBLIC at::Tensor _fusednbitrowwise_to_single_or_half_precision_gpu(
     const at::Tensor& input,
     const int64_t bit_rate,
-    const int64_t output_dtype) {
+    const int64_t output_dtype,
+    const bool scale_bias_last) {
   Tensor output;
 
   SparseType output_sparse_dtype = static_cast<SparseType>(output_dtype);
   switch (output_sparse_dtype) {
     case SparseType::FP32:
-      output = _fusednbitrowwise_to_float_gpu_t<float>(input, bit_rate);
+      output = _fusednbitrowwise_to_float_gpu_t<float>(
+          input, bit_rate, scale_bias_last);
       break;
     case SparseType::FP16:
-      output = _fusednbitrowwise_to_float_gpu_t<at::Half>(input, bit_rate);
+      output = _fusednbitrowwise_to_float_gpu_t<at::Half>(
+          input, bit_rate, scale_bias_last);
       break;
     case SparseType::BF16:
-      output = _fusednbitrowwise_to_float_gpu_t<at::BFloat16>(input, bit_rate);
+      output = _fusednbitrowwise_to_float_gpu_t<at::BFloat16>(
+          input, bit_rate, scale_bias_last);
       break;
     default:
       TORCH_CHECK(false);
diff --git a/fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp b/fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp
index f10fd25d6a..529df15f3f 100644
--- a/fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp
+++ b/fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp
@@ -150,7 +150,9 @@ Tensor _fusednbitrowwise_to_float_cpu(
   return output;
 }
 
-Tensor _fusednbitrowwise_sbfront_to_float_cpu(
+// Both float16 and bfloat16 are of same type uint16_t
+template <typename output_t>
+Tensor _fusednbitrowwise_sbfront_to_float_or_half_cpu(
     const Tensor& input,
     const int64_t bit_rate) {
   TENSOR_ON_CPU(input);
@@ -165,15 +167,37 @@
       (ncols - 2 * sizeof(at::Half)) * num_elem_per_byte;
 
   Tensor output;
-  output = at::empty(
-      {nrows, output_columns}, // 4 = sizeof(float)
-      input.options().dtype(at::kFloat));
+  if (std::is_same<output_t, float>::value) {
+    output = at::empty(
+        {nrows, output_columns}, // 4 = sizeof(float)
+        input.options().dtype(at::kFloat));
+  } else if (std::is_same<output_t, at::Half>::value) {
+    output = at::empty(
+        {nrows, output_columns}, // 2 = sizeof(half)
+        input.options().dtype(at::kHalf));
+  } else if (std::is_same<output_t, at::BFloat16>::value) {
+    output = at::empty(
+        {nrows, output_columns}, // 2 = sizeof(half)
+        input.options().dtype(at::kBFloat16));
+  } else {
+    TORCH_CHECK(
+        false,
+        "Unsupported output dtype for _fusednbitrowwise_sbfront_to_float_or_half_cpu");
+  }
 
-  float* output_data = static_cast<float*>(
+  using output_ty = std::conditional_t<
+      std::is_same<output_t, float>::value,
+      float,
+      fbgemm::float16>;
+  output_ty* output_data = static_cast<output_ty*>(
       output.data_ptr()); // output.data_ptr(); -> Yields
                           // unresolved data_ptr symbol.
-  fbgemm::FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef<float>(
+  constexpr bool is_uint16_t_of_type_bf16 =
+      std::is_same<output_t, at::BFloat16>::value;
+  fbgemm::FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef<
+      output_ty,
+      is_uint16_t_of_type_bf16>(
       bit_rate,
       input.data_ptr<uint8_t>(),
       nrows,
       ncols,
@@ -311,7 +335,7 @@ Tensor fusednbitrowwise_to_float_cpu(
 
 /// @ingroup quantize-data-cpu
 /// @brief Dequantize int4/int2 rows with scale and bias stored in the front
-/// into float32.
+/// into float32/float16/Bfloat16.
 /// @param input Tensor of int4/int2 rows with scale and bias stored in the
 /// front.
 /// @param bit_rate Bit rate of each element. Should be 4 or 2.
@@ -323,8 +347,25 @@ Tensor fusednbitrowwise_to_float_cpu(
 /// purpose because its kernel is reference implementation and not optimized.
 Tensor fusednbitrowwise_sbfront_to_float_cpu(
     const Tensor& input,
-    const int64_t bit_rate) {
-  return _fusednbitrowwise_sbfront_to_float_cpu(input, bit_rate);
+    const int64_t bit_rate,
+    const int64_t output_dtype) {
+  SparseType output_sparse_dtype = static_cast<SparseType>(output_dtype);
+  switch (output_sparse_dtype) {
+    case SparseType::FP32:
+      return _fusednbitrowwise_sbfront_to_float_or_half_cpu<float>(
+          input, bit_rate);
+      break;
+    case SparseType::FP16:
+      return _fusednbitrowwise_sbfront_to_float_or_half_cpu<at::Half>(
+          input, bit_rate);
+      break;
+    case SparseType::BF16:
+      return _fusednbitrowwise_sbfront_to_float_or_half_cpu<at::BFloat16>(
+          input, bit_rate);
+      break;
+    default:
+      TORCH_CHECK(false);
+  }
 }
 
 /// @ingroup quantize-data-cpu
@@ -340,7 +381,8 @@ Tensor fusednbitrowwise_to_half_cpu(
 Tensor fusednbitrowwise_to_float_or_half_cpu(
     const Tensor& input,
     const int64_t bit_rate,
-    const int64_t output_dtype) {
+    const int64_t output_dtype,
+    [[maybe_unused]] const bool scale_bias_last) {
   Tensor output;
 
   SparseType output_sparse_dtype = static_cast<SparseType>(output_dtype);
@@ -520,11 +562,11 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.def(
       "FusedNBitRowwiseQuantizedSBHalfToFloat(Tensor input, int bit_rate) -> Tensor");
   m.def(
-      "FusedNBitRowwiseQuantizedSBHalfFrontToFloat(Tensor input, int bit_rate) -> Tensor");
+      "FusedNBitRowwiseQuantizedSBHalfFrontToFloatOrHalf(Tensor input, int bit_rate, int output_dtype) -> Tensor");
   m.def(
       "FusedNBitRowwiseQuantizedSBHalfToHalf(Tensor input, int bit_rate) -> Tensor");
   m.def(
-      "FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf(Tensor input, int bit_rate, int output_dtype=0) -> Tensor");
+      "FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf(Tensor input, int bit_rate, int output_dtype=0, bool scale_bias_last=True) -> Tensor");
   m.def(
       "FloatToHFP8Quantized(Tensor input, int ebits, int exponent_bias, float max_pos) -> Tensor");
   m.def(
@@ -542,7 +584,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
 
 TORCH_LIBRARY_IMPL(fbgemm, QuantizedCPU, m) {
   DISPATCH_TO_QUANTIZED_CPU(
-      "FusedNBitRowwiseQuantizedSBHalfFrontToFloat",
+      "FusedNBitRowwiseQuantizedSBHalfFrontToFloatOrHalf",
       fbgemm_gpu::fusednbitrowwise_sbfront_to_float_cpu);
 }
 
diff --git a/fbgemm_gpu/src/quantize_ops/quantize_ops_meta.cpp b/fbgemm_gpu/src/quantize_ops/quantize_ops_meta.cpp
index f66affea0e..8fc3e2c0cb 100644
--- a/fbgemm_gpu/src/quantize_ops/quantize_ops_meta.cpp
+++ b/fbgemm_gpu/src/quantize_ops/quantize_ops_meta.cpp
@@ -72,7 +72,8 @@ Tensor FloatToFP8RowwiseQuantized_meta(const Tensor& input, bool forward) {
 Tensor fusednbitrowwise_to_float_or_half_meta(
     const Tensor& input,
     const int64_t bit_rate,
-    const int64_t output_dtype) {
+    const int64_t output_dtype,
+    [[maybe_unused]] const bool scale_bias_last) {
   const at::SymIntArrayRef input_sizes = input.sym_sizes();
   const at::SymInt nrows = input_sizes[0];
   // Here we want the number of bytes in a row
diff --git a/fbgemm_gpu/test/tbe/inference/common.py b/fbgemm_gpu/test/tbe/inference/common.py
index b50b1d03f1..8c04441e47 100644
--- a/fbgemm_gpu/test/tbe/inference/common.py
+++ b/fbgemm_gpu/test/tbe/inference/common.py
@@ -351,8 +351,10 @@ def execute_nbit_forward_(  # noqa C901
         f = torch.cat(fs, dim=0).view(-1, D)
 
         if fc2.dtype == torch.quint4x2:
-            fc2_float = torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfFrontToFloat(
-                fc2.cpu(), bit_rate=4
+            fc2_float = (
+                torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfFrontToFloatOrHalf(
+                    fc2.cpu(), bit_rate=4, output_dtype=0
+                )
             )
         else:
             fc2_float = fc2.float()
diff --git a/fbgemm_gpu/test/tbe/inference/failures_dict_fast.json b/fbgemm_gpu/test/tbe/inference/failures_dict_fast.json
index 231eb24f7e..930c0c2a0d 100644
--- a/fbgemm_gpu/test/tbe/inference/failures_dict_fast.json
+++ b/fbgemm_gpu/test/tbe/inference/failures_dict_fast.json
@@ -7,7 +7,8 @@
     "fbgemm::FloatToHFP8Quantized": {},
     "fbgemm::Fused8BitRowwiseQuantizedToFloat": {},
     "fbgemm::Fused8BitRowwiseQuantizedToFloatOrHalf": {},
-    "fbgemm::FusedNBitRowwiseQuantizedSBHalfFrontToFloat": {},
+    "fbgemm::FusedNBitRowwiseQuantizedSBHalfFrontToFloatOrHalf": {},
+    "fbgemm::FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf": {},
     "fbgemm::HFP8QuantizedToFloat": {},
     "fbgemm::asynchronous_complete_cumsum": {},
     "fbgemm::bounds_check_indices": {},
@@ -44,9 +45,17 @@
         "comment": "",
         "status": "xsuccess"
       },
+      "NBitFowardTest.test_faketensor__test_nbit_forward_cpu_gpu_dequantize_parity": {
+        "comment": "this operator outputs torch.quint4x2 tensors which is not compatible with generate_opcheck_tests",
+        "status": "xfail"
+      },
       "NBitFowardTest.test_faketensor__test_nbit_forward_cpu_seq_int4": {
         "comment": "this operator outputs torch.quint4x2 tensors which is not compatible with generate_opcheck_tests",
         "status": "xfail"
+      },
+      "NBitFowardTest.test_schema__test_nbit_forward_cpu_gpu_dequantize_parity": {
+        "comment": "this operator outputs torch.quint4x2 tensors which is not compatible with generate_opcheck_tests",
+        "status": "xfail"
       }
     },
     "fbgemm::int_nbit_split_embedding_uvm_caching_codegen_lookup_function": {
diff --git a/fbgemm_gpu/test/tbe/inference/nbit_forward_test.py b/fbgemm_gpu/test/tbe/inference/nbit_forward_test.py
index 5da2cf680a..efa5934898 100644
--- a/fbgemm_gpu/test/tbe/inference/nbit_forward_test.py
+++ b/fbgemm_gpu/test/tbe/inference/nbit_forward_test.py
@@ -64,6 +64,18 @@
     "test_faketensor__test_nbit_forward_cpu_gpu_dequantize_parity": [
         unittest.skip("Operator not implemented for Meta tensors"),
     ],
+    "test_schema__test_nbit_forward_cpu_gpu_dequantize_parity": [
+        unittest.skip("Operator not implemented for Meta tensors"),
+    ],
+    "test_autograd_registration__test_nbit_forward_cpu_gpu_dequantize_parity": [
+        unittest.skip("Operator not implemented for Meta tensors"),
+    ],
+    "test_aot_dispatch_static__test_nbit_forward_cpu_gpu_dequantize_parity": [
+        unittest.skip("Operator not implemented for Meta tensors"),
+    ],
+    "test_aot_dispatch_dynamic__test_nbit_forward_cpu_gpu_dequantize_parity": [
+        unittest.skip("Operator not implemented for Meta tensors"),
+    ],
     "test_faketensor__test_nbit_forward_cpu_seq_int4": {
         unittest.skip(
             "Operator outputs int4 tensors which do not support opcheck tests"
@@ -755,10 +767,11 @@ def test_nbit_forward_cpu_seq_int4(
         nbit_weights_ty=st.sampled_from(
             [
                 SparseType.INT8,
+                SparseType.INT4,
             ]
         ),
         pooling_mode=st.sampled_from([PoolingMode.NONE]),
-        output_dtype=st.sampled_from([SparseType.BF16, SparseType.FP16]),
+        output_dtype=st.sampled_from([SparseType.FP16, SparseType.BF16]),
         D=st.sampled_from([32, 256, 384, 512, 1024]),
         B=st.integers(min_value=8, max_value=32),
         T=st.integers(min_value=10, max_value=20),
@@ -827,13 +840,21 @@ def test_nbit_forward_cpu_gpu_dequantize_parity(
             (weights, scale_shift) = split_weights[t]
             (ref_weights, ref_scale_shift) = ref_split_weights[t]
             self.assertEqual(weights.size(), ref_weights.size())
-            element_size = SparseType.INT8.bit_rate() / 8.0
+            element_size = (
+                SparseType.INT8.bit_rate()
+                if nbit_weights_ty == SparseType.INT8
+                else SparseType.INT4.bit_rate()
+            ) / 8.0
             rand_tensor = torch.rand(
                 ref_weights.shape[0], int(ref_weights.shape[1] / element_size)
             )
             rand_weights, rand_scale_shift = quantize_embs(
                rand_tensor,
-                SparseType.INT8,
+                (
+                    SparseType.INT8
+                    if nbit_weights_ty == SparseType.INT8
+                    else SparseType.INT4
+                ),
             )
             ref_weights.copy_(rand_weights)
             weights.copy_(ref_weights)
@@ -861,14 +882,35 @@ def test_nbit_forward_cpu_gpu_dequantize_parity(
         quant_cc_output = quant_cc(indices.int(), offsets.int())
         dequant_cc_output = dequant_cc(indices.int(), offsets.int())
         cuda_device = torch.device("cuda")
-        dequant_output_from_quant_cc = (
-            torch.ops.fbgemm.Fused8BitRowwiseQuantizedToFloatOrHalf(
-                quant_cc_output.to(cuda_device),
-                output_dtype.as_int(),
-                quant_padding_float_type=False,
-                scale_bias_last=False,
+        if nbit_weights_ty == SparseType.INT8:
+            dequant_output_from_quant_cc = (
+                torch.ops.fbgemm.Fused8BitRowwiseQuantizedToFloatOrHalf(
+                    quant_cc_output.to(cuda_device),
+                    output_dtype.as_int(),
+                    quant_padding_float_type=False,
+                    scale_bias_last=False,
+                )
             )
-        )
+        elif nbit_weights_ty == SparseType.INT4:
+            tensor_gpu = torch.zeros(
+                (quant_cc_output.shape[0], int((quant_cc_output.shape[1] + 1) / 2)),
+                dtype=torch.uint8,
+                device=cuda_device,
+            )
+            tensor_gpu.untyped_storage().copy_(quant_cc_output.untyped_storage())
+            dequant_output_from_quant_cc = (
+                torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf(
+                    tensor_gpu,
+                    bit_rate=4,
+                    output_dtype=output_dtype.as_int(),
+                    scale_bias_last=False,
+                )
+            )
+        else:
+            raise ValueError(
+                "Unsupported nbit_weights_ty in test_nbit_forward_cpu_gpu_dequantize_parity"
+            )
+
         torch.testing.assert_close(
             dequant_cc_output.cpu(),
             dequant_output_from_quant_cc.cpu(),
diff --git a/include/fbgemm/QuantUtils.h b/include/fbgemm/QuantUtils.h
index 86d22595fe..4e5a4fd2ef 100644
--- a/include/fbgemm/QuantUtils.h
+++ b/include/fbgemm/QuantUtils.h
@@ -300,7 +300,8 @@ FBGEMM_API void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf(
     const uint8_t* input,
     size_t input_rows,
     int input_columns,
-    OutputType* output);
+    OutputType* output,
+    bool scale_bias_last = true);
 
 /**
  * Convert float or half inputs to rowwise quantized (8-bit) outputs.
@@ -360,7 +361,7 @@ FBGEMM_API void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatRef(
  * Same as FusedNBitRowwiseQuantizedSBHalfToFloat but unoptimized.
  * This should not be called directly except in testing.
  */
-template <typename OutputType>
+template <typename OutputType, bool is_uint16_t_of_type_bf16 = false>
 FBGEMM_API void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef(
     int bit_rate,
     const uint8_t* input,
diff --git a/src/QuantUtils.cc b/src/QuantUtils.cc
index e6a53253f0..555a5cfc2b 100644
--- a/src/QuantUtils.cc
+++ b/src/QuantUtils.cc
@@ -723,7 +723,7 @@ void FloatOrHalfToFused8BitRowwiseQuantizedSBFloat(
   }
 }
 
-template <typename OutputType>
+template <typename OutputType, bool is_uint16_t_of_type_bf16>
 void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef(
     int bit_rate,
     const uint8_t* input,
@@ -733,7 +733,7 @@ void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef(
     bool scale_bias_last) {
   static_assert(
       std::is_same<OutputType, float>() || std::is_same<OutputType, float16>(),
-      "Only float and float16 types are allowed.");
+      "Only float, float16 or bfloat16 types are allowed.");
   int num_elem_per_byte = 8 / bit_rate;
   const int64_t output_columns =
       static_cast<int64_t>(input_columns - 2 * sizeof(float16)) *
@@ -760,7 +760,11 @@ void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef(
       if (std::is_same<OutputType, float>()) {
         output_row[col] = output_value;
       } else {
-        output_row[col] = cpu_float2half_rn(output_value);
+        if constexpr (is_uint16_t_of_type_bf16) {
+          output_row[col] = cpu_float2bfloat16(output_value);
+        } else {
+          output_row[col] = cpu_float2half_rn(output_value);
+        }
       }
     }
   }
@@ -772,7 +776,8 @@ void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf(
     const uint8_t* input,
     size_t input_rows,
     int input_columns,
-    OutputType* output) {
+    OutputType* output,
+    [[maybe_unused]] bool scale_bias_last) {
   if (cpuinfo_initialize() && fbgemmHasAvx2Support()) {
 #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
     switch (bit_rate) {
@@ -857,7 +862,15 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf(
           int input_columns,                                          \
           std::uint8_t* output);                                      \
   template FBGEMM_API void                                            \
-      FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef<type>(          \
+      FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef<type, true>(    \
+          int bit_rate,                                               \
+          const uint8_t* input,                                       \
+          size_t input_rows,                                          \
+          int input_columns,                                          \
+          type* output,                                               \
+          bool scale_bias_last);                                      \
+  template FBGEMM_API void                                            \
+      FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef<type, false>(   \
           int bit_rate,                                               \
           const uint8_t* input,                                       \
           size_t input_rows,                                          \
          int input_columns,                                           \
@@ -869,7 +882,8 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf(
           const uint8_t* input,                                       \
           size_t input_rows,                                          \
           int input_columns,                                          \
-          type* output);                                              \
+          type* output,                                               \
+          bool scale_bias_last);                                      \
   template FBGEMM_API void                                            \
       FloatOrHalfToFused8BitRowwiseQuantizedSBFloatRef<type>(         \
           const type* input,                                          \
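
For reference, here is a minimal usage sketch (not part of the patch) of the extended `FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf` op on CPU. It assumes an fbgemm_gpu build that already includes this change; the `FloatToFusedNBitRowwiseQuantizedSBHalf` quantizer and the `SparseType` import path are pre-existing fbgemm_gpu APIs, not introduced by this diff.

# Minimal sketch: round-trip a float matrix through the fused 4-bit rowwise
# format and dequantize it with the new output_dtype / scale_bias_last arguments.
import torch
import fbgemm_gpu  # noqa: F401  # loads and registers the torch.ops.fbgemm operators
from fbgemm_gpu.split_embedding_configs import SparseType

weights = torch.rand(8, 64)

# Quantize to fused INT4 rows; this kernel stores the half-precision
# scale/bias at the end of each row (the scale_bias_last layout).
quantized = torch.ops.fbgemm.FloatToFusedNBitRowwiseQuantizedSBHalf(weights, 4)

# Dequantize, selecting the output dtype and scale/bias position explicitly.
dequantized = torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf(
    quantized,
    bit_rate=4,
    output_dtype=SparseType.FP16.as_int(),
    scale_bias_last=True,  # matches the quantizer's layout above
)
assert dequantized.dtype == torch.half
assert dequantized.shape == weights.shape

Passing scale_bias_last=False instead corresponds to rows whose scale/bias pair is stored at the front, which is the layout exercised by the new INT4 branch in test_nbit_forward_cpu_gpu_dequantize_parity.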