[CustomDevice] add blas_axpby api for gradient_accumulator (#44584)
ronny1996 authored Jul 26, 2022
1 parent 356ff43 commit 0d51fcf
Showing 10 changed files with 245 additions and 13 deletions.
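
Summary: in imperative mode, gradient accumulation on a custom place previously threw a PermissionDenied or Unimplemented error. This change adds a BlasAXPBY (y = alpha * x + beta * y) entry point to the custom-device runtime: a blas_axpby slot in C_DeviceInterface, a DeviceInterface::BlasAXPBY virtual with an unimplemented default, and typed Device::BlasAXPBY<T> wrappers. TensorAddFunctor's CustomPlace overload now calls it with alpha = beta = 1, i.e. a plain elementwise add.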
22 changes: 13 additions & 9 deletions paddle/fluid/imperative/gradient_accumulator.cc
@@ -39,6 +39,9 @@
#ifdef PADDLE_WITH_MLU
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/backends/device_manager.h"
#endif

namespace paddle {
namespace imperative {
@@ -189,10 +192,19 @@ class TensorAddFunctor
place));
}
void operator()(const platform::CustomPlace& place) const {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
platform::CustomDeviceContext* ctx =
dynamic_cast<platform::CustomDeviceContext*>(
platform::DeviceContextPool::Instance().Get(place));
phi::stream::Stream stream(place, ctx->stream());
auto device = phi::DeviceManager::GetDeviceWithPlace(place);
device->BlasAXPBY<T>(stream, static_cast<size_t>(numel_), 1., x_, 1., y_);
#else
PADDLE_THROW(platform::errors::PermissionDenied(
"Gradient accumulation on place (%s) "
"is not supported in imperative mode",
place));
#endif
}

private:
@@ -351,15 +363,7 @@ void TensorAdd(const VarType& src, VarType* dst) {
return;
}
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
if (platform::is_custom_place(place)) {
PADDLE_THROW(platform::errors::Unimplemented(
"Gradient accumulation of data type (%s) on place (%s) is not "
"supported in imperative mode",
framework::DataTypeToString(data_type),
place));
}
#endif

#ifdef PADDLE_WITH_XPU
if (platform::is_xpu_place(place)) {
if (data_type == framework::DataTypeTrait<float>::DataType()) {
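With alpha = beta = 1, axpby reduces to exactly the elementwise add that gradient accumulation needs, which is why the Unimplemented throw removed in the hunk above is no longer necessary. For reference, the semantics a backend must provide amount to a single loop; a naive host-side sketch for scalar element types (illustration only, not part of this commit):

  #include <cstddef>  // size_t

  // y = alpha * x + beta * y, elementwise over numel values (scalar T).
  template <typename T>
  void NaiveAxpby(size_t numel, float alpha, const T* x, float beta, T* y) {
    for (size_t i = 0; i < numel; ++i) {
      y[i] = static_cast<T>(alpha * static_cast<float>(x[i]) +
                            beta * static_cast<float>(y[i]));
    }
  }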
2 changes: 1 addition & 1 deletion paddle/phi/backends/CMakeLists.txt
@@ -51,7 +51,7 @@ if(WITH_CUSTOM_DEVICE)
cc_test(
custom_device_test
SRCS custom/custom_device_test.cc
DEPS phi_backends phi_device_context)
DEPS phi_backends phi_device_context gradient_accumulator)
cc_test(
capi_test
SRCS custom/capi_test.cc
46 changes: 46 additions & 0 deletions paddle/phi/backends/custom/custom_device.cc
@@ -14,6 +14,8 @@

#include "paddle/fluid/platform/device/custom/enforce_custom.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/common/data_type.h"

#include "paddle/phi/backends/callback_manager.h"
#include "paddle/phi/backends/device_base.h"
#include "paddle/phi/backends/device_guard.h"
@@ -608,6 +610,27 @@ class CustomDevice : public DeviceInterface {
#undef return_result
}

C_DataType ToCDatatType(paddle::experimental::DataType data_type) {
#define return_result(in, ret) \
case in: \
return C_DataType::ret
switch (data_type) {
return_result(paddle::experimental::DataType::FLOAT64, FLOAT64);
return_result(paddle::experimental::DataType::FLOAT32, FLOAT32);
return_result(paddle::experimental::DataType::FLOAT16, FLOAT16);
return_result(paddle::experimental::DataType::INT64, INT64);
return_result(paddle::experimental::DataType::INT32, INT32);
return_result(paddle::experimental::DataType::INT16, INT16);
return_result(paddle::experimental::DataType::INT8, INT8);
default: {
PADDLE_THROW(phi::errors::Unavailable(
"DataType is not supported on %s.", Type()));
return C_DataType::UNDEFINED;
}
}
#undef return_result
}

void CCLGetUniqueId(ccl::CCLRootId* unique_id) override {
CHECK_PTR(pimpl_->xccl_get_unique_id_size);
CHECK_PTR(pimpl_->xccl_get_unique_id);
@@ -771,6 +794,27 @@ class CustomDevice : public DeviceInterface {
reinterpret_cast<C_Stream>(stream.raw_stream())));
}

void BlasAXPBY(size_t dev_id,
const stream::Stream& stream,
paddle::experimental::DataType dtype,
size_t numel,
float alpha,
void* x,
float beta,
void* y) override {
CHECK_PTR(pimpl_->blas_axpby);
const auto device = &devices_pool[dev_id];
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(
pimpl_->blas_axpby(device,
reinterpret_cast<C_Stream>(stream.raw_stream()),
ToCDatatType(dtype),
numel,
alpha,
x,
beta,
y));
}

private:
inline int PlaceToIdNoCheck(const Place& place) {
int dev_id = place.GetDeviceId();
@@ -877,6 +921,8 @@ bool ValidCustomCustomRuntimeParams(const CustomRuntimeParams* params) {
CHECK_INTERFACE(xccl_group_end, false);
CHECK_INTERFACE(xccl_send, false);
CHECK_INTERFACE(xccl_recv, false);

CHECK_INTERFACE(blas_axpby, false);
return true;
#undef CHECK_INTERFACE
}
48 changes: 48 additions & 0 deletions paddle/phi/backends/custom/custom_device_test.cc
@@ -18,6 +18,8 @@

#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/imperative/gradient_accumulator.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/backends/custom/fake_cpu_device.h"
#include "paddle/phi/backends/device_manager.h"
@@ -237,6 +239,51 @@ void TestCustomCCL(const paddle::platform::Place& place) {
stream);
}

void TestBlasAPI(const paddle::platform::Place& place) {
std::cout << "TestBlasAPI on " << place << std::endl;
if (paddle::platform::is_custom_place(place) == false) {
return;
}
auto device = phi::DeviceManager::GetDeviceWithPlace(place);
phi::stream::Stream stream(place, nullptr);
device->BlasAXPBY<float>(stream, 0, 1., nullptr, 1., nullptr);

paddle::framework::Variable var1;
paddle::framework::Variable var2;
std::vector<float> src_data(10, 1.0);
std::vector<float> dst_data(10, 0.0);
std::vector<float> result;
paddle::platform::CPUPlace src_place;
for (unsigned int i = 0; i < 10; i++) {
result.emplace_back(src_data[i] + dst_data[i]);
}

std::vector<int64_t> dims = {2, 5};
auto* src = var1.GetMutable<paddle::framework::LoDTensor>();
auto* dst = var2.GetMutable<paddle::framework::LoDTensor>();
src->Resize(phi::make_ddim(dims));
dst->Resize(phi::make_ddim(dims));
auto* src_mutable = src->mutable_data<float>(place);
auto* dst_mutable = dst->mutable_data<float>(place);

paddle::memory::Copy(place,
src_mutable,
src_place,
src_data.data(),
sizeof(float) * src_data.size());

paddle::memory::Copy(place,
dst_mutable,
src_place,
dst_data.data(),
sizeof(float) * dst_data.size());

paddle::imperative::TensorAdd<paddle::framework::Variable>(var1, &var2);
paddle::framework::LoDTensor rlt;
paddle::platform::CPUPlace rlt_place;
paddle::framework::TensorCopySync(*dst, rlt_place, &rlt);
}

TEST(CustomDevice, Tensor) {
InitDevice();
auto dev_types = phi::DeviceManager::GetAllDeviceTypes();
@@ -251,6 +298,7 @@ TEST(CustomDevice, Tensor) {
TestTensorShareDataWith(place);
TestTensorUtils(place);
TestCustomCCL(place);
TestBlasAPI(place);
}
}

13 changes: 13 additions & 0 deletions paddle/phi/backends/custom/fake_cpu_device.h
@@ -210,6 +210,17 @@ C_Status XcclRecv(void *recv_buf,
return C_SUCCESS;
}

C_Status BlasAXPBY(const C_Device device,
C_Stream stream,
C_DataType dtype,
size_t numel,
float alpha,
void *x,
float beta,
void *y) {
return C_SUCCESS;
}

#define DEVICE_TYPE "FakeCPU"
#define SUB_DEVICE_TYPE "V100"

@@ -278,4 +289,6 @@ void InitFakeCPUDevice(CustomRuntimeParams *params) {
params->interface->xccl_reduce_scatter = XcclReduceScatter;
params->interface->xccl_send = XcclSend;
params->interface->xccl_recv = XcclRecv;

params->interface->blas_axpby = BlasAXPBY;
}
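
The FakeCPU stub above simply returns C_SUCCESS; registering it gives TestBlasAPI a blas_axpby hook to exercise through the plugin interface.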
12 changes: 12 additions & 0 deletions paddle/phi/backends/device_base.cc
@@ -355,6 +355,18 @@ void DeviceInterface::CCLRecv(void* recvbuf,
INTERFACE_UNIMPLEMENT;
}

// blas
void DeviceInterface::BlasAXPBY(size_t dev_id,
const stream::Stream& stream,
paddle::experimental::DataType dtype,
size_t numel,
float alpha,
void* x,
float beta,
void* y) {
INTERFACE_UNIMPLEMENT;
}

#undef INTERFACE_UNIMPLEMENT

} // namespace phi
10 changes: 10 additions & 0 deletions paddle/phi/backends/device_base.h
@@ -225,6 +225,16 @@ class DeviceInterface { // Driver / Runtime
const ccl::CCLComm& ccl_comm,
const stream::Stream& stream);

// blas
virtual void BlasAXPBY(size_t dev_id,
const stream::Stream& stream,
paddle::experimental::DataType dtype,
size_t numel,
float alpha,
void* x,
float beta,
void* y);

private:
const std::string type_;
const uint8_t priority_;
14 changes: 13 additions & 1 deletion paddle/phi/backends/device_ext.h
@@ -635,7 +635,19 @@ struct C_DeviceInterface {
// other api //
///////////////

void* reserved_other_api[8];
/**
* @brief y = alpha * x + beta * y
*
*/
C_Status (*blas_axpby)(const C_Device device,
C_Stream stream,
C_DataType dtype,
size_t numel,
float alpha,
void* x,
float beta,
void* y);
void* reserved_other_api[7];
};

struct CustomRuntimeVersion {
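On the plugin side, an implementation must match the blas_axpby signature declared above. A minimal sketch for a synchronous, CPU-like device, handling FLOAT32 only (a real backend would dispatch on dtype and enqueue the work on stream; the C_FAILED return code is an assumption about how such a plugin reports unsupported types):

  C_Status BlasAXPBY(const C_Device device, C_Stream stream, C_DataType dtype,
                     size_t numel, float alpha, void *x, float beta, void *y) {
    if (dtype != FLOAT32) return C_FAILED;  // assumed error status
    float *px = (float *)x;
    float *py = (float *)y;
    // Synchronous host loop: y = alpha * x + beta * y.
    for (size_t i = 0; i < numel; ++i) {
      py[i] = alpha * px[i] + beta * py[i];
    }
    return C_SUCCESS;
  }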
75 changes: 75 additions & 0 deletions paddle/phi/backends/device_manager.cc
@@ -14,6 +14,7 @@

#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/backends/device_manager.h"
#include "paddle/phi/common/complex.h"

#if !defined(_WIN32)
#include <dirent.h>
@@ -135,6 +136,80 @@ void Device::MemorySet(void* ptr, uint8_t value, size_t size) {
impl_->MemorySet(dev_id_, ptr, value, size);
}

template <typename T>
void Device::BlasAXPBY(const stream::Stream& stream,
size_t numel,
float alpha,
const T* x,
float beta,
T* y) {
impl_->BlasAXPBY(dev_id_,
stream,
paddle::experimental::CppTypeToDataType<T>::Type(),
numel,
alpha,
reinterpret_cast<void*>(const_cast<T*>(x)),
beta,
reinterpret_cast<void*>(y));
}

template void Device::BlasAXPBY<paddle::float16>(const stream::Stream& stream,
size_t numel,
float alpha,
const paddle::float16* x,
float beta,
paddle::float16* y);
template void Device::BlasAXPBY<float>(const stream::Stream& stream,
size_t numel,
float alpha,
const float* x,
float beta,
float* y);
template void Device::BlasAXPBY<double>(const stream::Stream& stream,
size_t numel,
float alpha,
const double* x,
float beta,
double* y);
template void Device::BlasAXPBY<int8_t>(const stream::Stream& stream,
size_t numel,
float alpha,
const int8_t* x,
float beta,
int8_t* y);
template void Device::BlasAXPBY<int16_t>(const stream::Stream& stream,
size_t numel,
float alpha,
const int16_t* x,
float beta,
int16_t* y);
template void Device::BlasAXPBY<int32_t>(const stream::Stream& stream,
size_t numel,
float alpha,
const int32_t* x,
float beta,
int32_t* y);
template void Device::BlasAXPBY<int64_t>(const stream::Stream& stream,
size_t numel,
float alpha,
const int64_t* x,
float beta,
int64_t* y);
template void Device::BlasAXPBY<phi::dtype::complex<float>>(
const stream::Stream& stream,
size_t numel,
float alpha,
const phi::dtype::complex<float>* x,
float beta,
phi::dtype::complex<float>* y);
template void Device::BlasAXPBY<phi::dtype::complex<double>>(
const stream::Stream& stream,
size_t numel,
float alpha,
const phi::dtype::complex<double>* x,
float beta,
phi::dtype::complex<double>* y);

std::string Device::Type() { return impl_->Type(); }

static phi::RWLock _global_device_manager_rw_lock;
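From the framework side, usage mirrors the TensorAddFunctor change and the TestBlasAPI test above. A minimal sketch, assuming a registered custom runtime named "FakeCPU" (hypothetical) and device-resident float buffers x and y of numel elements:

  auto place = paddle::platform::CustomPlace("FakeCPU", 0);  // hypothetical device type
  auto device = phi::DeviceManager::GetDeviceWithPlace(place);
  phi::stream::Stream stream(place, /*raw_stream=*/nullptr);
  // Accumulate x into y: y = 1 * x + 1 * y.
  device->BlasAXPBY<float>(stream, numel, 1.f, x, 1.f, y);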
(1 more changed file not shown in this view)
