Remove double type support from fbgemm_cuda_utils.cuh
Summary: Remove double (FP64) support from fbgemm_cuda_utils.cuh (the Vec4T<double> and SharedMemory<double> specializations and the double load/store overloads), and switch the permute_pooled_embedding ops from AT_DISPATCH_FLOATING_TYPES_AND2 to FBGEMM_DISPATCH_FLOAT_HALF_AND_BFLOAT16.

Differential Revision: D53831943
sryap authored and facebook-github-bot committed Feb 16, 2024
1 parent d00b5d5 commit feaf651
Showing 3 changed files with 7 additions and 250 deletions.
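
The bulk of the deletions below is the Vec4T<double> specialization plus the double load/store overloads on the remaining Vec4T variants. As a minimal caller sketch (not part of this commit; it assumes the fma_ and store members of Vec4T<float> mirror the ones visible in the deleted double specialization), kernels that want 4-wide vectorized accumulation now go through Vec4T<float> regardless of the embedding element type:

#include "fbgemm_gpu/fbgemm_cuda_utils.cuh"

using fbgemm_gpu::Vec4T; // namespace assumed from fbgemm_cuda_utils.cuh

// Hypothetical kernel, for illustration only: scale-and-accumulate four
// elements per thread. emb_t may be float or at::Half (converting
// constructors shown in this diff); double is no longer a valid element type.
template <typename emb_t>
__global__ void axpy_vec4_kernel(
    const emb_t* __restrict__ in,
    float* __restrict__ out,
    float alpha,
    int num_vec4) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < num_vec4) {
    Vec4T<float> acc(out + 4 * i); // load 4 floats
    Vec4T<float> v(in + 4 * i);    // converting load from emb_t
    acc.fma_(v, alpha);            // acc <- acc + v * alpha (float overload assumed)
    acc.store(out + 4 * i);        // store 4 floats back
  }
}
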
239 changes: 1 addition & 238 deletions fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh
@@ -117,8 +117,7 @@ struct Half4 {
}
};

// Customized 4-element vector data types (with element type Half, float, or
// double).
// Customized 4-element vector data types (with element type Half or float).
template <typename T>
struct Vec4T {};

@@ -136,10 +135,6 @@ struct Vec4T<float> {
load(p);
}

DEVICE_INLINE Vec4T(const double* p) {
load(p);
}

DEVICE_INLINE Vec4T(const at::Half* p) {
load(p);
}
@@ -152,13 +147,6 @@ struct Vec4T<float> {
acc = *((const float4*)p);
}

DEVICE_INLINE void load(const double* p) {
acc.x = p[0];
acc.y = p[1];
acc.z = p[2];
acc.w = p[3];
}

DEVICE_INLINE void load(const at::Half* p) {
#ifdef USE_ROCM
union U {
@@ -239,13 +227,6 @@ struct Vec4T<float> {
p[3] = acc.w;
}

DEVICE_INLINE void store(double* p) const {
p[0] = acc.x;
p[1] = acc.y;
p[2] = acc.z;
p[3] = acc.w;
}

DEVICE_INLINE void store(uint8_t* p) const {
CUDA_KERNEL_ASSERT(false);
}
@@ -309,10 +290,6 @@ struct Vec4T<at::Half> {
load(p);
}

DEVICE_INLINE Vec4T(const double* p) {
load(p);
}

DEVICE_INLINE void load(const at::Half* p) {
#ifdef USE_ROCM
union U {
@@ -363,13 +340,6 @@ struct Vec4T<at::Half> {
acc = *((const float4*)p);
}

DEVICE_INLINE void load(const double* p) {
acc.x = p[0];
acc.y = p[1];
acc.z = p[2];
acc.w = p[3];
}

DEVICE_INLINE void load(const uint8_t* p) {
CUDA_KERNEL_ASSERT(false);
}
@@ -400,13 +370,6 @@ struct Vec4T<at::Half> {
*((float4*)p) = acc;
}

DEVICE_INLINE void store(double* p) const {
p[0] = acc.x;
p[1] = acc.y;
p[2] = acc.z;
p[3] = acc.w;
}

DEVICE_INLINE void store(uint8_t* p) const {
CUDA_KERNEL_ASSERT(false);
}
@@ -516,10 +479,6 @@ struct Vec4T<at::BFloat16> {
load(p);
}

DEVICE_INLINE Vec4T(const double* p) {
load(p);
}

DEVICE_INLINE void load(const at::BFloat16* p) {
acc.x = p[0];
acc.y = p[1];
@@ -570,13 +529,6 @@ struct Vec4T<at::BFloat16> {
acc = *((const float4*)p);
}

DEVICE_INLINE void load(const double* p) {
acc.x = p[0];
acc.y = p[1];
acc.z = p[2];
acc.w = p[3];
}

DEVICE_INLINE void load(const uint8_t* p) {
CUDA_KERNEL_ASSERT(false);
}
@@ -607,13 +559,6 @@ struct Vec4T<at::BFloat16> {
*((float4*)p) = acc;
}

DEVICE_INLINE void store(double* p) const {
p[0] = acc.x;
p[1] = acc.y;
p[2] = acc.z;
p[3] = acc.w;
}

DEVICE_INLINE void store(uint8_t* p) const {
CUDA_KERNEL_ASSERT(false);
}
@@ -681,164 +626,6 @@ struct Vec4T<at::BFloat16> {
}
};

template <>
struct Vec4T<double> {
double4 acc;
DEVICE_INLINE Vec4T() {
acc.x = 0;
acc.y = 0;
acc.z = 0;
acc.w = 0;
}

DEVICE_INLINE Vec4T(const at::Half* p) {
load(p);
}

DEVICE_INLINE Vec4T(const at::BFloat16* p) {
load(p);
}

DEVICE_INLINE Vec4T(const float* p) {
load(p);
}

DEVICE_INLINE Vec4T(const double* p) {
load(p);
}

DEVICE_INLINE void load(const at::Half* p) {
#ifdef USE_ROCM
union U {
half2 h[2];
uint2 ui;
} tmp_out;

// uint2 = 2 uints = 8 bytes
tmp_out.ui = *reinterpret_cast<uint2 const*>(p);

float2 a = __half22float2(tmp_out.h[0]);
float2 b = __half22float2(tmp_out.h[1]);

acc.x = a.x;
acc.y = a.y;
acc.z = b.x;
acc.w = b.y;
#else
Half4 out;
#if CUDA_VERSION >= 9000
asm("ld.global.v2.u32 {%0, %1}, [%2];"
: "=r"(__HALF2_TO_UI(out.a)), "=r"(__HALF2_TO_UI(out.b))
: "l"(p));
#else
asm("ld.global.v2.u32 {%0, %1}, [%2];"
: "=r"(out.a.x), "=r"(out.b.x)
: "l"(p));
#endif

float2 a = __half22float2(out.a);
float2 b = __half22float2(out.b);

acc.x = a.x;
acc.y = a.y;
acc.z = b.x;
acc.w = b.y;
#endif
}

DEVICE_INLINE void load(const at::BFloat16* p) {
acc.x = p[0];
acc.y = p[1];
acc.z = p[2];
acc.w = p[3];
}

DEVICE_INLINE void load(const float* p) {
acc.x = p[0];
acc.y = p[1];
acc.z = p[2];
acc.w = p[3];
}

DEVICE_INLINE void load(const uint8_t* p) {
CUDA_KERNEL_ASSERT(false);
}

DEVICE_INLINE void load(const double* p) {
acc = *((const double4*)p);
}

DEVICE_INLINE void store(double* p) const {
*((double4*)p) = acc;
}

DEVICE_INLINE void store(float* p) const {
float4* f4 = (float4*)p;
f4->x = acc.x;
f4->y = acc.y;
f4->z = acc.z;
f4->w = acc.w;
}

DEVICE_INLINE void store(at::Half* p) const {
float2 a;
a.x = acc.x;
a.y = acc.y;

float2 b;
b.x = acc.z;
b.y = acc.w;

Half4 out;
out.a = __float22half2_rn(a);
out.b = __float22half2_rn(b);
out.store(p);
}

DEVICE_INLINE void store(at::BFloat16* p) const {
p[0] = acc.x;
p[1] = acc.y;
p[2] = acc.z;
p[3] = acc.w;
}

DEVICE_INLINE static void copy(const double* src, double* dst) {
*((double4*)dst) = *((const double4*)src);
}

// this <- this + a * b
DEVICE_INLINE void fma_(const Vec4T<double>& a, const double b) {
acc.x = __fma_rn(a.acc.x, b, acc.x);
acc.y = __fma_rn(a.acc.y, b, acc.y);
acc.z = __fma_rn(a.acc.z, b, acc.z);
acc.w = __fma_rn(a.acc.w, b, acc.w);
}

// this <- this + a
DEVICE_INLINE void add_(const Vec4T<double>& a) {
acc.x += a.acc.x;
acc.y += a.acc.y;
acc.z += a.acc.z;
acc.w += a.acc.w;
}

// this <- this element-wise mul a
DEVICE_INLINE void element_wise_mul_(const Vec4T<double>& a) {
acc.x *= a.acc.x;
acc.y *= a.acc.y;
acc.z *= a.acc.z;
acc.w *= a.acc.w;
}

// this <- this * scale
DEVICE_INLINE void mul_(float scale) {
acc.x *= scale;
acc.y *= scale;
acc.z *= scale;
acc.w *= scale;
}
};

template <typename scalar_t>
DEVICE_INLINE Vec4T<scalar_t> vec4_acc(
const Vec4T<scalar_t>& lhs,
@@ -1202,14 +989,6 @@ DEVICE_INLINE void nearest_rounding_vector(
output[3] = lrintf((value.acc.w - qparams.y) * inv_scale);
}

template <>
DEVICE_INLINE void nearest_rounding_vector(
uint8_t* output,
const Vec4T<double>& value,
const float2 qparams) {
CUDA_KERNEL_ASSERT(false);
}

template <typename dst_t, typename src_t>
DEVICE_INLINE void quantize_store(
dst_t* output,
@@ -1561,14 +1340,6 @@ struct SharedMemory<float> {
}
};

template <>
struct SharedMemory<double> {
__device__ double* getPointer() {
extern __shared__ double s_double_t[];
return s_double_t;
}
};

template <>
struct SharedMemory<Vec4T<at::acc_type<float, true>>> {
__device__ Vec4T<at::acc_type<float, true>>* getPointer() {
@@ -1577,14 +1348,6 @@ struct SharedMemory<Vec4T<at::acc_type<float, true>>> {
}
};

template <>
struct SharedMemory<Vec4T<at::acc_type<double, true>>> {
__device__ Vec4T<at::acc_type<double, true>>* getPointer() {
extern __shared__ Vec4T<at::acc_type<double, true>> s_acc_double_vec_t[];
return s_acc_double_vec_t;
}
};

// Return if the address is aligned to the type (mainly for Vec4T).
template <class T>
DEVICE_INLINE bool is_aligned(const void* ptr) {
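
This file's diff also drops the SharedMemory<double> and SharedMemory<Vec4T<at::acc_type<double, true>>> specializations; the float versions shown in the surrounding context stay. A rough usage sketch (not from this commit; the kernel and launch parameters are illustrative): SharedMemory<T>::getPointer() gives templated kernels a typed view of dynamic shared memory, and after this change instantiating it with double no longer compiles.

#include "fbgemm_gpu/fbgemm_cuda_utils.cuh"

using fbgemm_gpu::SharedMemory; // namespace assumed from fbgemm_cuda_utils.cuh

// Hypothetical block-sum kernel: buf points at dynamic shared memory sized at
// launch time; blockDim.x is assumed to be a power of two.
template <typename scalar_t>
__global__ void block_sum_kernel(const scalar_t* in, scalar_t* out, int n) {
  SharedMemory<scalar_t> smem;
  scalar_t* buf = smem.getPointer();
  const int tid = threadIdx.x;
  const int idx = blockIdx.x * blockDim.x + tid;
  buf[tid] = (idx < n) ? in[idx] : scalar_t(0);
  __syncthreads();
  // Tree reduction within the block, entirely in shared memory.
  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
    if (tid < s) {
      buf[tid] += buf[tid + s];
    }
    __syncthreads();
  }
  if (tid == 0) {
    out[blockIdx.x] = buf[0];
  }
}

// block_sum_kernel<float><<<grid, block, block.x * sizeof(float)>>>(d_in, d_out, n)
// still works; block_sum_kernel<double> would now fail to find SharedMemory<double>.
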
@@ -14,6 +14,7 @@
#include <cuda_runtime.h>
#include "fbgemm_gpu/ops_utils.h"

#include "fbgemm_gpu/dispatch_macros.h"
#include "fbgemm_gpu/fbgemm_cuda_utils.cuh"
#include "fbgemm_gpu/layout_transform_ops.cuh"
#include "fbgemm_gpu/permute_pooled_embedding_ops.h"
@@ -105,12 +106,8 @@ Tensor permute_pooled_embs_gpu_impl(
std::min(static_cast<int32_t>(B), max_grid_dim_y),
(B + max_grid_dim_y - 1) / max_grid_dim_y);

AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
pooled_embs_contiguous.scalar_type(),
"permute_pooled_embeddings",
[&] {
FBGEMM_DISPATCH_FLOAT_HALF_AND_BFLOAT16(
pooled_embs_contiguous.scalar_type(), "permute_pooled_embeddings", [&] {
permute_pooled_embs_kernel<scalar_t>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
pooled_embs_contiguous.data_ptr<scalar_t>(),
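
Both permute_pooled_embedding call sites move from AT_DISPATCH_FLOATING_TYPES_AND2 (which also generates a double instantiation of the kernel) to FBGEMM_DISPATCH_FLOAT_HALF_AND_BFLOAT16 from the newly included fbgemm_gpu/dispatch_macros.h. That header is not part of this diff, so the expansion below is only a hedged sketch in terms of the stock ATen dispatch helpers; the real definition may differ.

#include <ATen/Dispatch.h>

// Hypothetical expansion, for illustration only; see fbgemm_gpu/dispatch_macros.h
// for the actual macro. It dispatches on float, Half, and BFloat16 and, unlike
// AT_DISPATCH_FLOATING_TYPES_AND2, emits no double instantiation.
#define FBGEMM_DISPATCH_FLOAT_HALF_AND_BFLOAT16(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                            \
      TYPE,                                                      \
      NAME,                                                      \
      AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)       \
      AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)        \
      AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__))

The call sites read the same as before, with scalar_t bound inside the lambda; a double-typed pooled_embs tensor would now presumably hit the dispatcher's "not implemented" error instead of instantiating a double kernel.
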
@@ -13,6 +13,7 @@
#include <cuda.h>
#include <cuda_runtime.h>

#include "fbgemm_gpu/dispatch_macros.h"
#include "fbgemm_gpu/fbgemm_cuda_utils.cuh"
#include "fbgemm_gpu/layout_transform_ops.cuh"
#include "fbgemm_gpu/permute_pooled_embedding_ops_split.h"
@@ -104,12 +105,8 @@ Tensor permute_pooled_embs_split_gpu_impl(
std::min(static_cast<int32_t>(B), max_grid_dim_y),
(B + max_grid_dim_y - 1) / max_grid_dim_y);

AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
pooled_embs_contiguous.scalar_type(),
"permute_pooled_embeddings",
[&] {
FBGEMM_DISPATCH_FLOAT_HALF_AND_BFLOAT16(
pooled_embs_contiguous.scalar_type(), "permute_pooled_embeddings", [&] {
permute_pooled_embs_kernel<scalar_t>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
pooled_embs_contiguous.data_ptr<scalar_t>(),
