Refactor THCNumerics and add common math functions for at::Half (#10301)
**Summary**: This PR is a follow-up to mruberry's pytorch/pytorch#9318. It does the following:

- Specializes std common math functions for the `at::Half` type.
- Creates `CUDANumerics.cuh` to hold the necessary parts of `THCNumerics.cuh`.
- Updates `THCNumerics.cuh` with new usage and comments to demonstrate best practice for developers, paving the way for its deprecation.
- Removes legacy/redundant code paths.
- Removes unused CUDA HALF macros (see separate PR pytorch/pytorch#10147).

**Comments**: `CUDANumerics.cuh` contains mathematical functions that are either not in the std namespace or are specialized for compilation with CUDA NVCC or CUDA NVRTC. It is derived from the legacy `THCNumerics.cuh`. The rationale for which functions were kept and which were removed:

- All arithmetic can now be done in ATen using a binary CUDA kernel or CUDA tensor pointwise apply (see pytorch/pytorch#8919 and `CUDAApplyUtils`). `at::Half` comparisons rely on implicit conversion to float.
- Functions that are C/C++ standard compliant have been specialized for user-defined types; for instance, the std namespace has been opened up for `at::Half`, providing math function definitions for `at::Half`. See `Half-inl.h`, and the sketch below this summary.
- Some standard-compliant functions are specialized here for performance reasons. For instance, `powi` is used for `pow` calculation on integral types, and `abs`, `isinf`, and `isnan` are specialized to save one API call versus going through std. This is subject to change, depending on whether we really care about saving one API call.
- Numeric limits such as `max`/`min` are removed since they call standard defines, and numeric limits for `at::Half` are present in `Half-inl.h`. HIP has a known issue with `std::numeric_limits` (see ROCm/HIP#374): AlexVlx mentions the issue can be avoided by calling `std::numeric_limits` from `__device__` code. Since we launch lambdas in device contexts, I don't see why `std::numeric_limits` wouldn't compile on HIP when called from within a kernel, unless I am not aware of the real reason `max`/`min` were in THCNumerics in the first place. (I haven't tried a HIP build.)

Reference PRs that were handy while refactoring TH into ATen:

- pytorch/pytorch#6786
- pytorch/pytorch#5475
- pytorch/pytorch#9401
- pytorch/pytorch#8689
- pytorch/pytorch#8919

Pull Request resolved: pytorch/pytorch#10301
Differential Revision: D9204758
Pulled By: soumith
fbshipit-source-id: 09f489c1656458c02367b6cd31c3eeeca5acdc8a
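To make the second rationale bullet concrete, here is a minimal sketch of what opening the std namespace for `at::Half` enables: one generic template serving float, double, and `at::Half` alike. The `softplus` helper is hypothetical (not part of this diff), and it assumes `Half-inl.h` provides the std overloads for `at::Half`.

#include <cmath>
#include "ATen/ATen.h"

// Hypothetical helper, not part of the PR: with the std math
// specializations for at::Half in Half-inl.h, the same template
// compiles for float, double, and at::Half.
template <typename T>
T softplus(T x) {
  // std::exp / std::log resolve to the at::Half overloads when T == at::Half.
  return std::log(T(1) + std::exp(x));
}

Usage: `softplus(at::Half(1.0f))` compiles unchanged alongside `softplus(1.0f)`, with the intermediate math done in float.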
1 parent 2bdc0f8 · commit 7a19aac · 12 changed files with 339 additions and 383 deletions.
ATen/cuda/NumericLimits.cuh (new file, 75 lines added):
#pragma once

#include <cuda.h>
#include <limits.h>
#include <float.h> // FLT_MAX / DBL_MAX, used below

// NumericLimits.cuh is a holder for numeric limits definitions of commonly used
// types. This header is very specific to ROCm HIP and may be removed in the future.
// This header is derived from the legacy THCNumerics.cuh.

namespace at {

template <typename T>
struct numeric_limits {
};

// WARNING: the following at::numeric_limits definitions are there only to support
// HIP compilation for the moment. Use std::numeric_limits if you are not
// compiling for ROCm.
// from @colesbury: "The functions on numeric_limits aren't marked with
// __device__ which is why they don't work with ROCm. CUDA allows them
// because they're constexpr."
template <>
struct numeric_limits<uint8_t> {
  static inline __host__ __device__ uint8_t lowest() { return 0; }
  static inline __host__ __device__ uint8_t max() { return UINT8_MAX; }
};

template <>
struct numeric_limits<int8_t> {
  static inline __host__ __device__ int8_t lowest() { return INT8_MIN; }
  static inline __host__ __device__ int8_t max() { return INT8_MAX; }
};

template <>
struct numeric_limits<int16_t> {
  static inline __host__ __device__ int16_t lowest() { return INT16_MIN; }
  static inline __host__ __device__ int16_t max() { return INT16_MAX; }
};

template <>
struct numeric_limits<int32_t> {
  static inline __host__ __device__ int32_t lowest() { return INT32_MIN; }
  static inline __host__ __device__ int32_t max() { return INT32_MAX; }
};

template <>
struct numeric_limits<int64_t> {
#ifdef _MSC_VER
  static inline __host__ __device__ int64_t lowest() { return _I64_MIN; }
  static inline __host__ __device__ int64_t max() { return _I64_MAX; }
#else
  static inline __host__ __device__ int64_t lowest() { return INT64_MIN; }
  static inline __host__ __device__ int64_t max() { return INT64_MAX; }
#endif
};

template <>
struct numeric_limits<at::Half> {
  // bit patterns of the lowest/highest finite half-precision values
  static inline __host__ __device__ at::Half lowest() { return at::Half(0xFBFF, at::Half::from_bits); }
  static inline __host__ __device__ at::Half max() { return at::Half(0x7BFF, at::Half::from_bits); }
};

template <>
struct numeric_limits<float> {
  static inline __host__ __device__ float lowest() { return -FLT_MAX; }
  static inline __host__ __device__ float max() { return FLT_MAX; }
};

template <>
struct numeric_limits<double> {
  static inline __host__ __device__ double lowest() { return -DBL_MAX; }
  static inline __host__ __device__ double max() { return DBL_MAX; }
};

} // namespace at
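As a usage sketch (hypothetical kernel, not part of the diff), the point of these definitions is that device code can query limits uniformly under both CUDA and ROCm:

#include "ATen/cuda/NumericLimits.cuh"

// Hypothetical example: fill a buffer with the lowest representable value,
// e.g. to seed a later max-reduction. at::numeric_limits works here under
// HIP, where std::numeric_limits members lack __device__ markings.
template <typename T>
__global__ void fill_lowest(T* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = at::numeric_limits<T>::lowest();
  }
}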
New catch-based CUDA test for the `at::Half` device math API (new file, 90 lines added):
#define CATCH_CONFIG_MAIN
#include "catch.hpp"

#include "ATen/ATen.h"
#include "ATen/cuda/NumericLimits.cuh"
#include "cuda.h"
#include "cuda_fp16.h"
#include "cuda_runtime.h"

#include <assert.h>

using namespace at;

__device__ void test(){

  // test half construction and implicit conversions in device
  assert(Half(3) == Half(3.0f));
  assert(static_cast<Half>(3.0f) == Half(3.0f));
  // there is no float <=> __half implicit conversion
  assert(static_cast<Half>(3.0f) == 3.0f);

  __half a = __float2half(3.0f);
  __half b = __float2half(2.0f);
  __half c = a - Half(b);
  assert(static_cast<Half>(c) == Half(1.0));

  // Assert that the common math functions applied to half types give
  // nearly the same results as the equivalent functions on float.
  // The purpose of these asserts is to test the device-side half API
  // for the common mathematical functions.
  // Note: when calling std math functions from device code, don't use
  // the std namespace; use "::" instead so that the function gets
  // resolved from nvcc's math_functions.hpp.

  float threshold = 0.00001;
  assert(::abs(::lgamma(Half(10.0)) - ::lgamma(10.0f)) <= threshold);
  assert(::abs(::exp(Half(1.0)) - ::exp(1.0f)) <= threshold);
  assert(::abs(::log(Half(1.0)) - ::log(1.0f)) <= threshold);
  assert(::abs(::log10(Half(1000.0)) - ::log10(1000.0f)) <= threshold);
  assert(::abs(::log1p(Half(0.0)) - ::log1p(0.0f)) <= threshold);
  assert(::abs(::log2(Half(1000.0)) - ::log2(1000.0f)) <= threshold);
  assert(::abs(::expm1(Half(1.0)) - ::expm1(1.0f)) <= threshold);
  assert(::abs(::cos(Half(0.0)) - ::cos(0.0f)) <= threshold);
  assert(::abs(::sin(Half(0.0)) - ::sin(0.0f)) <= threshold);
  assert(::abs(::sqrt(Half(100.0)) - ::sqrt(100.0f)) <= threshold);
  assert(::abs(::ceil(Half(2.4)) - ::ceil(2.4f)) <= threshold);
  assert(::abs(::floor(Half(2.7)) - ::floor(2.7f)) <= threshold);
  assert(::abs(::trunc(Half(2.7)) - ::trunc(2.7f)) <= threshold);
  assert(::abs(::acos(Half(-1.0)) - ::acos(-1.0f)) <= threshold);
  assert(::abs(::cosh(Half(1.0)) - ::cosh(1.0f)) <= threshold);
  assert(::abs(::acosh(Half(1.0)) - ::acosh(1.0f)) <= threshold);
  assert(::abs(::asin(Half(1.0)) - ::asin(1.0f)) <= threshold);
  assert(::abs(::sinh(Half(1.0)) - ::sinh(1.0f)) <= threshold);
  assert(::abs(::asinh(Half(1.0)) - ::asinh(1.0f)) <= threshold);
  assert(::abs(::tan(Half(0.0)) - ::tan(0.0f)) <= threshold);
  assert(::abs(::atan(Half(1.0)) - ::atan(1.0f)) <= threshold);
  assert(::abs(::tanh(Half(1.0)) - ::tanh(1.0f)) <= threshold);
  assert(::abs(::erf(Half(10.0)) - ::erf(10.0f)) <= threshold);
  assert(::abs(::erfc(Half(10.0)) - ::erfc(10.0f)) <= threshold);
  assert(::abs(::abs(Half(-3.0)) - ::abs(-3.0f)) <= threshold);
  assert(::abs(::round(Half(2.3)) - ::round(2.3f)) <= threshold);
  assert(::abs(::pow(Half(2.0), Half(10.0)) - ::pow(2.0f, 10.0f)) <= threshold);
  assert(::abs(::atan2(Half(7.0), Half(0.0)) - ::atan2(7.0f, 0.0f)) <= threshold);
  // note: can't use namespace on isnan and isinf in device code
#ifdef _MSC_VER
  // Windows requires this explicit conversion. The reason is unclear;
  // a related issue with clang: https://reviews.llvm.org/D37906
  assert(::abs(::isnan((float)Half(0.0)) - ::isnan(0.0f)) <= threshold);
  assert(::abs(::isinf((float)Half(0.0)) - ::isinf(0.0f)) <= threshold);
#else
  assert(::abs(::isnan(Half(0.0)) - ::isnan(0.0f)) <= threshold);
  assert(::abs(::isinf(Half(0.0)) - ::isinf(0.0f)) <= threshold);
#endif
}

__global__ void kernel(){
  test();
}

void launch_function(){
  kernel<<<1,1>>>();
}

TEST_CASE( "half common math functions tests in device", "[cuda]" ) {
  launch_function();
  cudaError_t err = cudaDeviceSynchronize();
  REQUIRE(err == cudaSuccess);
}
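Finally, to illustrate the summary's `powi` rationale (a dedicated integral `pow` kept for performance instead of routing integers through `::pow`), here is a sketch of what such a specialization typically looks like; the actual CUDANumerics.cuh implementation may differ.

// Sketch of pow for integral types by binary exponentiation; assumes a
// non-negative exponent. This illustrates the kind of specialization the
// summary refers to, not necessarily the exact CUDANumerics.cuh code.
template <typename T>
static inline __host__ __device__ T powi(T a, T b) {
  T result = 1;
  while (b) {
    if (b & 1) {
      result *= a;  // fold in the current bit of the exponent
    }
    b >>= 1;        // move to the next exponent bit
    a *= a;         // square the base
  }
  return result;
}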