diff --git a/fbgemm_gpu/include/fbgemm_gpu/embedding_common.h b/fbgemm_gpu/include/fbgemm_gpu/embedding_common.h index bf3d547e52..01ad64bd1f 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/embedding_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/embedding_common.h @@ -95,8 +95,10 @@ div_round_up(uint32_t a, uint32_t b) { return ((a + b - 1) / b); } -C10_HOST_DEVICE C10_ALWAYS_INLINE int32_t -unpadded_row_size_in_bytes(int32_t dim, fbgemm_gpu::SparseType weight_ty) { +C10_HOST_DEVICE C10_ALWAYS_INLINE int32_t unpadded_row_size_in_bytes( + int32_t dim, + fbgemm_gpu::SparseType weight_ty, + const int32_t scale_bias_bytes = 4) { if (weight_ty == fbgemm_gpu::SparseType::FP32) { return dim * 4; } @@ -107,13 +109,13 @@ unpadded_row_size_in_bytes(int32_t dim, fbgemm_gpu::SparseType weight_ty) { return dim; } if (weight_ty == fbgemm_gpu::SparseType::INT8) { - return dim + 4; + return dim + scale_bias_bytes; } if (weight_ty == fbgemm_gpu::SparseType::INT4) { - return dim / 2 + 4; + return dim / 2 + scale_bias_bytes; } if (weight_ty == fbgemm_gpu::SparseType::INT2) { - return dim / 4 + 4; + return dim / 4 + scale_bias_bytes; } return 0; } @@ -121,9 +123,10 @@ unpadded_row_size_in_bytes(int32_t dim, fbgemm_gpu::SparseType weight_ty) { C10_HOST_DEVICE C10_ALWAYS_INLINE int32_t padded_row_size_in_bytes( int32_t dim, fbgemm_gpu::SparseType weight_ty, - int32_t row_alignment) { - auto r = unpadded_row_size_in_bytes(dim, weight_ty); - return round_up(r, row_alignment); + const int32_t row_alignment, + const int32_t scale_bias_bytes = 4) { + auto r = unpadded_row_size_in_bytes(dim, weight_ty, scale_bias_bytes); + return static_cast(round_up(r, row_alignment)); } } // namespace nbit