Skip to content

Commit

Permalink
Merge branch 'gumbel_api' of github.com:PureNatural/Paddle into gumbe…
Browse files Browse the repository at this point in the history
…l_api
  • Loading branch information
PureNatural committed Oct 14, 2022
2 parents e7108f0 + 017f66c commit 069cadf
Show file tree
Hide file tree
Showing 189 changed files with 6,282 additions and 1,961 deletions.
1 change: 1 addition & 0 deletions AUTHORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ This is an incomplete list of authors of [Paddle](https://github.com/PaddlePaddl
| xushaoyong | Shao-Yong Xu |
| Yancey1989 | Xu Yan |
| zhaopu7 | Pu Zhao |
| zhiqiu | Qiu-Liang Chen |
| zhouxiao-coder | Xiao Zhou |
| Zrachel | Rui-Qing Zhang |
| jeng1220 | Bai-Cheng(Ryan) Jeng (NVIDIA) |
Expand Down
4 changes: 1 addition & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ find_package(CUDA QUIET)
find_package(MKL CONFIG QUIET)
option(WITH_ONEMKL "Compile PaddlePaddle with oneMKL" OFF)
option(WITH_GPU "Compile PaddlePaddle with NVIDIA GPU" ${CUDA_FOUND})
option(WITH_MPI "Compile PaddlePaddle with MPI" OFF)
option(WITH_TENSORRT "Compile PaddlePaddle with NVIDIA TensorRT" OFF)
option(WITH_XPU "Compile PaddlePaddle with BAIDU KUNLUN XPU" OFF)
option(WITH_XPU_KP "Compile PaddlePaddle with BAIDU XPU compiler " OFF)
Expand Down Expand Up @@ -485,9 +486,6 @@ if(WITH_DISTRIBUTE)
ON
CACHE STRING "Enable GLOO when compiling WITH_DISTRIBUTE=ON." FORCE)
endif()
set(WITH_MPI
ON
CACHE STRING "Enable MPI when compiling WITH_DISTRIBUTE=ON." FORCE)
if(WITH_ASCEND_CL AND NOT WITH_ARM_BRPC)
# disable WITH_PSCORE for NPU before include third_party
message(
Expand Down
4 changes: 3 additions & 1 deletion paddle/fluid/distributed/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ set(DISTRIBUTE_COMPILE_FLAGS
if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
set(DISTRIBUTE_COMPILE_FLAGS "${DISTRIBUTE_COMPILE_FLAGS} -faligned-new")
endif()

if(LINUX)
add_subdirectory(rpc)
endif()
add_subdirectory(common)
add_subdirectory(ps)
add_subdirectory(test)
Expand Down
13 changes: 13 additions & 0 deletions paddle/fluid/distributed/rpc/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
set(PADDLE_RPC_SRCS python_rpc_handler.cc rpc_agent.cc)

set_source_files_properties(
python_rpc_handler.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(rpc_agent.cc PROPERTIES COMPILE_FLAGS
${DISTRIBUTE_COMPILE_FLAGS})

set(PADDLE_RPC_DEPS brpc protobuf glog pybind)
proto_library(paddle_rpc_proto SRCS rpc.proto)
cc_library(
paddle_rpc
SRCS ${PADDLE_RPC_SRCS}
DEPS ${PADDLE_RPC_DEPS} paddle_rpc_proto)
57 changes: 57 additions & 0 deletions paddle/fluid/distributed/rpc/future_wrapper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <pybind11/pybind11.h>

#include <cassert>
#include <future>
#include <string>

#include "paddle/fluid/distributed/rpc/python_rpc_handler.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/macros.h"

namespace py = pybind11;
namespace paddle {
namespace distributed {
class FutureWrapper {
public:
FutureWrapper() {}
explicit FutureWrapper(std::future<std::string> fut) : fut_(std::move(fut)) {}
py::object wait() {
// GIL must be released, otherwise fut_.get() blocking will cause the
// service to fail to process RPC requests, leading to deadlock
PADDLE_ENFORCE_EQ(
PyGILState_Check(),
false,
platform::errors::Fatal(
"GIL must be released before fut.wait(), otherwise fut_.get() "
"blocking will cause the service to fail to "
"process RPC requests, leading to deadlock"));
auto s = fut_.get();
py::gil_scoped_acquire ag;
std::shared_ptr<PythonRpcHandler> python_handler =
PythonRpcHandler::GetInstance();
py::object obj = python_handler->Deserialize(py::bytes(s));
return obj;
}

private:
DISABLE_COPY_AND_ASSIGN(FutureWrapper);
std::future<std::string> fut_;
};
} // namespace distributed
} // namespace paddle
67 changes: 67 additions & 0 deletions paddle/fluid/distributed/rpc/python_rpc_handler.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/distributed/rpc/python_rpc_handler.h"

namespace paddle {
namespace distributed {
constexpr auto kInternalModule = "paddle.distributed.rpc.internal";

py::object getFunction(const py::object& module, const char* name) {
py::object fn = module.attr(name);
return fn;
}

PythonRpcHandler::PythonRpcHandler() {
py::gil_scoped_acquire ag;
// import python module
py::object rpc_internal = py::module::import(kInternalModule);
py_run_function_ = getFunction(rpc_internal, "_run_py_func");
py_serialize_ = getFunction(rpc_internal, "_serialize");
py_deserialize_ = getFunction(rpc_internal, "_deserialize");
}

py::object PythonRpcHandler::RunPythonFunc(const py::object& python_func) {
py::gil_scoped_acquire ag;
return py_run_function_(python_func);
}

std::string PythonRpcHandler::Serialize(const py::object& obj) {
py::gil_scoped_acquire ag;
py::object res = py_serialize_(obj);
return res.cast<std::string>();
}

py::object PythonRpcHandler::Deserialize(const std::string& obj) {
py::gil_scoped_acquire ag;
return py_deserialize_(py::bytes(obj));
}

std::shared_ptr<PythonRpcHandler> PythonRpcHandler::python_rpc_handler_ =
nullptr;
std::mutex PythonRpcHandler::lock_;

std::shared_ptr<PythonRpcHandler> PythonRpcHandler::GetInstance() {
if (python_rpc_handler_ == nullptr) {
std::lock_guard<std::mutex> guard(lock_);
if (python_rpc_handler_ == nullptr) {
python_rpc_handler_ = std::make_shared<PythonRpcHandler>();
return python_rpc_handler_;
}
}
return python_rpc_handler_;
}

} // namespace distributed
} // namespace paddle
62 changes: 62 additions & 0 deletions paddle/fluid/distributed/rpc/python_rpc_handler.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <pybind11/pybind11.h>

#include <memory>
#include <mutex>
#include <string>

#include "paddle/fluid/platform/macros.h"

namespace py = pybind11;

namespace paddle {
namespace distributed {

class PYBIND11_EXPORT PythonRpcHandler {
public:
PythonRpcHandler();
~PythonRpcHandler() = default;
static std::shared_ptr<PythonRpcHandler> GetInstance();
// Run a pickled Python function and return the result py::object
py::object RunPythonFunc(const py::object& python_func);

// Serialized a py::object into a string
std::string Serialize(const py::object& obj);

// Deserialize a string into a py::object
py::object Deserialize(const std::string& obj);

private:
DISABLE_COPY_AND_ASSIGN(PythonRpcHandler);

static std::shared_ptr<PythonRpcHandler> python_rpc_handler_;
// Ref to `paddle.distributed.rpc.internal.run_py_func`.
py::object py_run_function_;

// Ref to `paddle.distributed.rpc.internal.serialize`.
py::object py_serialize_;

// Ref to `paddle.distributed.rpc.internal.deserialize`.
py::object py_deserialize_;

// Lock to protect initialization.
static std::mutex lock_;
};

} // namespace distributed
} // namespace paddle
33 changes: 33 additions & 0 deletions paddle/fluid/distributed/rpc/rpc.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.


syntax="proto2";
package paddle.distributed;

option cc_generic_services = true;
option cc_enable_arenas = true;

message RpcRequest {
required bytes message = 1;
};

message RpcResponse {
required bytes message = 1;
};

service RpcBaseService {
rpc Send(RpcRequest) returns (RpcResponse);
rpc InvokeRpc(RpcRequest) returns (RpcResponse);
};
Loading

0 comments on commit 069cadf

Please sign in to comment.