[CustomDevice] add blas_axpby api for gradient_accumulator #44584

Merged 1 commit on Jul 26, 2022
22 changes: 13 additions & 9 deletions paddle/fluid/imperative/gradient_accumulator.cc
@@ -39,6 +39,9 @@
#ifdef PADDLE_WITH_MLU
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/backends/device_manager.h"
#endif

namespace paddle {
namespace imperative {
@@ -189,10 +192,19 @@ class TensorAddFunctor
place));
}
void operator()(const platform::CustomPlace& place) const {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
platform::CustomDeviceContext* ctx =
dynamic_cast<platform::CustomDeviceContext*>(
platform::DeviceContextPool::Instance().Get(place));
phi::stream::Stream stream(place, ctx->stream());
auto device = phi::DeviceManager::GetDeviceWithPlace(place);
device->BlasAXPBY<T>(stream, static_cast<size_t>(numel_), 1., x_, 1., y_);
#else
PADDLE_THROW(platform::errors::PermissionDenied(
"Gradient accumulation on place (%s) "
"is not supported in imperative mode",
place));
#endif
}

private:
@@ -351,15 +363,7 @@ void TensorAdd(const VarType& src, VarType* dst) {
return;
}
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
if (platform::is_custom_place(place)) {
PADDLE_THROW(platform::errors::Unimplemented(
"Gradient accumulation of data type (%s) on place (%s) is not "
"supported in imperative mode",
framework::DataTypeToString(data_type),
place));
}
#endif

#ifdef PADDLE_WITH_XPU
if (platform::is_xpu_place(place)) {
if (data_type == framework::DataTypeTrait<float>::DataType()) {
2 changes: 1 addition & 1 deletion paddle/phi/backends/CMakeLists.txt
@@ -51,7 +51,7 @@ if(WITH_CUSTOM_DEVICE)
cc_test(
custom_device_test
SRCS custom/custom_device_test.cc
DEPS phi_backends phi_device_context)
DEPS phi_backends phi_device_context gradient_accumulator)
cc_test(
capi_test
SRCS custom/capi_test.cc
46 changes: 46 additions & 0 deletions paddle/phi/backends/custom/custom_device.cc
@@ -14,6 +14,8 @@

#include "paddle/fluid/platform/device/custom/enforce_custom.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/common/data_type.h"

#include "paddle/phi/backends/callback_manager.h"
#include "paddle/phi/backends/device_base.h"
#include "paddle/phi/backends/device_guard.h"
@@ -608,6 +610,27 @@ class CustomDevice : public DeviceInterface {
#undef return_result
}

C_DataType ToCDatatType(paddle::experimental::DataType data_type) {
#define return_result(in, ret) \
case in: \
return C_DataType::ret
switch (data_type) {
return_result(paddle::experimental::DataType::FLOAT64, FLOAT64);
return_result(paddle::experimental::DataType::FLOAT32, FLOAT32);
return_result(paddle::experimental::DataType::FLOAT16, FLOAT16);
return_result(paddle::experimental::DataType::INT64, INT64);
return_result(paddle::experimental::DataType::INT32, INT32);
return_result(paddle::experimental::DataType::INT16, INT16);
return_result(paddle::experimental::DataType::INT8, INT8);
default: {
PADDLE_THROW(phi::errors::Unavailable(
"DataType is not supported on %s.", Type()));
return C_DataType::UNDEFINED;
}
}
#undef return_result
}

void CCLGetUniqueId(ccl::CCLRootId* unique_id) override {
CHECK_PTR(pimpl_->xccl_get_unique_id_size);
CHECK_PTR(pimpl_->xccl_get_unique_id);
@@ -771,6 +794,27 @@ class CustomDevice : public DeviceInterface {
reinterpret_cast<C_Stream>(stream.raw_stream())));
}

void BlasAXPBY(size_t dev_id,
const stream::Stream& stream,
paddle::experimental::DataType dtype,
size_t numel,
float alpha,
void* x,
float beta,
void* y) override {
CHECK_PTR(pimpl_->blas_axpby);
const auto device = &devices_pool[dev_id];
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(
pimpl_->blas_axpby(device,
reinterpret_cast<C_Stream>(stream.raw_stream()),
ToCDatatType(dtype),
numel,
alpha,
x,
beta,
y));
}

private:
inline int PlaceToIdNoCheck(const Place& place) {
int dev_id = place.GetDeviceId();
@@ -877,6 +921,8 @@ bool ValidCustomCustomRuntimeParams(const CustomRuntimeParams* params) {
CHECK_INTERFACE(xccl_group_end, false);
CHECK_INTERFACE(xccl_send, false);
CHECK_INTERFACE(xccl_recv, false);

CHECK_INTERFACE(blas_axpby, false);
return true;
#undef CHECK_INTERFACE
}
48 changes: 48 additions & 0 deletions paddle/phi/backends/custom/custom_device_test.cc
@@ -18,6 +18,8 @@

#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/imperative/gradient_accumulator.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/backends/custom/fake_cpu_device.h"
#include "paddle/phi/backends/device_manager.h"
@@ -237,6 +239,51 @@ void TestCustomCCL(const paddle::platform::Place& place) {
stream);
}

void TestBlasAPI(const paddle::platform::Place& place) {
std::cout << "TestBlasAPI on " << place << std::endl;
if (paddle::platform::is_custom_place(place) == false) {
return;
}
auto device = phi::DeviceManager::GetDeviceWithPlace(place);
phi::stream::Stream stream(place, nullptr);
device->BlasAXPBY<float>(stream, 0, 1., nullptr, 1., nullptr);

paddle::framework::Variable var1;
paddle::framework::Variable var2;
std::vector<float> src_data(10, 1.0);
std::vector<float> dst_data(10, 0.0);
std::vector<float> result;
paddle::platform::CPUPlace src_place;
for (unsigned int i = 0; i < 10; i++) {
result.emplace_back(src_data[i] + dst_data[i]);
}

std::vector<int64_t> dims = {2, 5};
auto* src = var1.GetMutable<paddle::framework::LoDTensor>();
auto* dst = var2.GetMutable<paddle::framework::LoDTensor>();
src->Resize(phi::make_ddim(dims));
dst->Resize(phi::make_ddim(dims));
auto* src_mutable = src->mutable_data<float>(place);
auto* dst_mutable = dst->mutable_data<float>(place);

paddle::memory::Copy(place,
src_mutable,
src_place,
src_data.data(),
sizeof(float) * src_data.size());

paddle::memory::Copy(place,
dst_mutable,
src_place,
dst_data.data(),
sizeof(float) * dst_data.size());

paddle::imperative::TensorAdd<paddle::framework::Variable>(var1, &var2);
paddle::framework::LoDTensor rlt;
paddle::platform::CPUPlace rlt_place;
paddle::framework::TensorCopySync(*dst, rlt_place, &rlt);
}

TEST(CustomDevice, Tensor) {
InitDevice();
auto dev_types = phi::DeviceManager::GetAllDeviceTypes();
@@ -251,6 +298,7 @@ TEST(CustomDevice, Tensor) {
TestTensorShareDataWith(place);
TestTensorUtils(place);
TestCustomCCL(place);
TestBlasAPI(place);
}
}

13 changes: 13 additions & 0 deletions paddle/phi/backends/custom/fake_cpu_device.h
@@ -210,6 +210,17 @@ C_Status XcclRecv(void *recv_buf,
return C_SUCCESS;
}

C_Status BlasAXPBY(const C_Device device,
C_Stream stream,
C_DataType dtype,
size_t numel,
float alpha,
void *x,
float beta,
void *y) {
return C_SUCCESS;
}

#define DEVICE_TYPE "FakeCPU"
#define SUB_DEVICE_TYPE "V100"

@@ -278,4 +289,6 @@ void InitFakeCPUDevice(CustomRuntimeParams *params) {
params->interface->xccl_reduce_scatter = XcclReduceScatter;
params->interface->xccl_send = XcclSend;
params->interface->xccl_recv = XcclRecv;

params->interface->blas_axpby = BlasAXPBY;
}
12 changes: 12 additions & 0 deletions paddle/phi/backends/device_base.cc
@@ -355,6 +355,18 @@ void DeviceInterface::CCLRecv(void* recvbuf,
INTERFACE_UNIMPLEMENT;
}

// blas
void DeviceInterface::BlasAXPBY(size_t dev_id,
const stream::Stream& stream,
paddle::experimental::DataType dtype,
size_t numel,
float alpha,
void* x,
float beta,
void* y) {
INTERFACE_UNIMPLEMENT;
}

#undef INTERFACE_UNIMPLEMENT

} // namespace phi
10 changes: 10 additions & 0 deletions paddle/phi/backends/device_base.h
@@ -225,6 +225,16 @@ class DeviceInterface { // Driver / Runtime
const ccl::CCLComm& ccl_comm,
const stream::Stream& stream);

// blas
virtual void BlasAXPBY(size_t dev_id,
const stream::Stream& stream,
paddle::experimental::DataType dtype,
size_t numel,
float alpha,
void* x,
float beta,
void* y);

private:
const std::string type_;
const uint8_t priority_;
14 changes: 13 additions & 1 deletion paddle/phi/backends/device_ext.h
@@ -635,7 +635,19 @@ struct C_DeviceInterface {
// other api //
///////////////

void* reserved_other_api[8];
/**
* @brief y = alpha * x + beta * y
*
*/
C_Status (*blas_axpby)(const C_Device device,
C_Stream stream,
C_DataType dtype,
size_t numel,
float alpha,
void* x,
float beta,
void* y);
void* reserved_other_api[7];
};

struct CustomRuntimeVersion {
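For plugin authors, the new blas_axpby entry takes one of the reserved_other_api slots and is expected to compute y = alpha * x + beta * y on the device. The following is a minimal, illustrative sketch (not part of this PR) of how a custom-runtime plugin might implement the hook for FLOAT32 on a synchronous, CPU-like backend; the function name PluginBlasAXPBY and the failure code for unsupported dtypes are assumptions.

C_Status PluginBlasAXPBY(const C_Device device,
                         C_Stream stream,
                         C_DataType dtype,
                         size_t numel,
                         float alpha,
                         void* x,
                         float beta,
                         void* y) {
  // Only FLOAT32 is handled in this sketch; a real plugin would switch on
  // dtype and launch the matching device kernel on `stream`.
  if (dtype != C_DataType::FLOAT32) {
    return C_FAILED;  // assumed error code for unsupported dtypes
  }
  auto* px = static_cast<float*>(x);
  auto* py = static_cast<float*>(y);
  for (size_t i = 0; i < numel; ++i) {
    py[i] = alpha * px[i] + beta * py[i];  // y = alpha * x + beta * y
  }
  return C_SUCCESS;
}

The hook is then registered the same way the fake_cpu_device.h change above does it: params->interface->blas_axpby = PluginBlasAXPBY;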
75 changes: 75 additions & 0 deletions paddle/phi/backends/device_manager.cc
@@ -14,6 +14,7 @@

#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/backends/device_manager.h"
#include "paddle/phi/common/complex.h"

#if !defined(_WIN32)
#include <dirent.h>
@@ -135,6 +136,80 @@ void Device::MemorySet(void* ptr, uint8_t value, size_t size) {
impl_->MemorySet(dev_id_, ptr, value, size);
}

template <typename T>
void Device::BlasAXPBY(const stream::Stream& stream,
size_t numel,
float alpha,
const T* x,
float beta,
T* y) {
impl_->BlasAXPBY(dev_id_,
stream,
paddle::experimental::CppTypeToDataType<T>::Type(),
numel,
alpha,
reinterpret_cast<void*>(const_cast<T*>(x)),
beta,
reinterpret_cast<void*>(y));
}

template void Device::BlasAXPBY<paddle::float16>(const stream::Stream& stream,
size_t numel,
float alpha,
const paddle::float16* x,
float beta,
paddle::float16* y);
template void Device::BlasAXPBY<float>(const stream::Stream& stream,
size_t numel,
float alpha,
const float* x,
float beta,
float* y);
template void Device::BlasAXPBY<double>(const stream::Stream& stream,
size_t numel,
float alpha,
const double* x,
float beta,
double* y);
template void Device::BlasAXPBY<int8_t>(const stream::Stream& stream,
size_t numel,
float alpha,
const int8_t* x,
float beta,
int8_t* y);
template void Device::BlasAXPBY<int16_t>(const stream::Stream& stream,
size_t numel,
float alpha,
const int16_t* x,
float beta,
int16_t* y);
template void Device::BlasAXPBY<int32_t>(const stream::Stream& stream,
size_t numel,
float alpha,
const int32_t* x,
float beta,
int32_t* y);
template void Device::BlasAXPBY<int64_t>(const stream::Stream& stream,
size_t numel,
float alpha,
const int64_t* x,
float beta,
int64_t* y);
template void Device::BlasAXPBY<phi::dtype::complex<float>>(
const stream::Stream& stream,
size_t numel,
float alpha,
const phi::dtype::complex<float>* x,
float beta,
phi::dtype::complex<float>* y);
template void Device::BlasAXPBY<phi::dtype::complex<double>>(
const stream::Stream& stream,
size_t numel,
float alpha,
const phi::dtype::complex<double>* x,
float beta,
phi::dtype::complex<double>* y);

std::string Device::Type() { return impl_->Type(); }

static phi::RWLock _global_device_manager_rw_lock;
Expand Down
Loading