-
Notifications
You must be signed in to change notification settings - Fork 5.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'gumbel_api' of github.com:PureNatural/Paddle into gumbe…
…l_api
- Loading branch information
Showing
189 changed files
with
6,282 additions
and
1,961 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
set(PADDLE_RPC_SRCS python_rpc_handler.cc rpc_agent.cc) | ||
|
||
set_source_files_properties( | ||
python_rpc_handler.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) | ||
set_source_files_properties(rpc_agent.cc PROPERTIES COMPILE_FLAGS | ||
${DISTRIBUTE_COMPILE_FLAGS}) | ||
|
||
set(PADDLE_RPC_DEPS brpc protobuf glog pybind) | ||
proto_library(paddle_rpc_proto SRCS rpc.proto) | ||
cc_library( | ||
paddle_rpc | ||
SRCS ${PADDLE_RPC_SRCS} | ||
DEPS ${PADDLE_RPC_DEPS} paddle_rpc_proto) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#pragma once | ||
|
||
#include <pybind11/pybind11.h> | ||
|
||
#include <cassert> | ||
#include <future> | ||
#include <string> | ||
|
||
#include "paddle/fluid/distributed/rpc/python_rpc_handler.h" | ||
#include "paddle/fluid/platform/enforce.h" | ||
#include "paddle/fluid/platform/macros.h" | ||
|
||
namespace py = pybind11; | ||
namespace paddle { | ||
namespace distributed { | ||
class FutureWrapper { | ||
public: | ||
FutureWrapper() {} | ||
explicit FutureWrapper(std::future<std::string> fut) : fut_(std::move(fut)) {} | ||
py::object wait() { | ||
// GIL must be released, otherwise fut_.get() blocking will cause the | ||
// service to fail to process RPC requests, leading to deadlock | ||
PADDLE_ENFORCE_EQ( | ||
PyGILState_Check(), | ||
false, | ||
platform::errors::Fatal( | ||
"GIL must be released before fut.wait(), otherwise fut_.get() " | ||
"blocking will cause the service to fail to " | ||
"process RPC requests, leading to deadlock")); | ||
auto s = fut_.get(); | ||
py::gil_scoped_acquire ag; | ||
std::shared_ptr<PythonRpcHandler> python_handler = | ||
PythonRpcHandler::GetInstance(); | ||
py::object obj = python_handler->Deserialize(py::bytes(s)); | ||
return obj; | ||
} | ||
|
||
private: | ||
DISABLE_COPY_AND_ASSIGN(FutureWrapper); | ||
std::future<std::string> fut_; | ||
}; | ||
} // namespace distributed | ||
} // namespace paddle |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/fluid/distributed/rpc/python_rpc_handler.h" | ||
|
||
namespace paddle { | ||
namespace distributed { | ||
constexpr auto kInternalModule = "paddle.distributed.rpc.internal"; | ||
|
||
py::object getFunction(const py::object& module, const char* name) { | ||
py::object fn = module.attr(name); | ||
return fn; | ||
} | ||
|
||
PythonRpcHandler::PythonRpcHandler() { | ||
py::gil_scoped_acquire ag; | ||
// import python module | ||
py::object rpc_internal = py::module::import(kInternalModule); | ||
py_run_function_ = getFunction(rpc_internal, "_run_py_func"); | ||
py_serialize_ = getFunction(rpc_internal, "_serialize"); | ||
py_deserialize_ = getFunction(rpc_internal, "_deserialize"); | ||
} | ||
|
||
py::object PythonRpcHandler::RunPythonFunc(const py::object& python_func) { | ||
py::gil_scoped_acquire ag; | ||
return py_run_function_(python_func); | ||
} | ||
|
||
std::string PythonRpcHandler::Serialize(const py::object& obj) { | ||
py::gil_scoped_acquire ag; | ||
py::object res = py_serialize_(obj); | ||
return res.cast<std::string>(); | ||
} | ||
|
||
py::object PythonRpcHandler::Deserialize(const std::string& obj) { | ||
py::gil_scoped_acquire ag; | ||
return py_deserialize_(py::bytes(obj)); | ||
} | ||
|
||
std::shared_ptr<PythonRpcHandler> PythonRpcHandler::python_rpc_handler_ = | ||
nullptr; | ||
std::mutex PythonRpcHandler::lock_; | ||
|
||
std::shared_ptr<PythonRpcHandler> PythonRpcHandler::GetInstance() { | ||
if (python_rpc_handler_ == nullptr) { | ||
std::lock_guard<std::mutex> guard(lock_); | ||
if (python_rpc_handler_ == nullptr) { | ||
python_rpc_handler_ = std::make_shared<PythonRpcHandler>(); | ||
return python_rpc_handler_; | ||
} | ||
} | ||
return python_rpc_handler_; | ||
} | ||
|
||
} // namespace distributed | ||
} // namespace paddle |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#pragma once | ||
|
||
#include <pybind11/pybind11.h> | ||
|
||
#include <memory> | ||
#include <mutex> | ||
#include <string> | ||
|
||
#include "paddle/fluid/platform/macros.h" | ||
|
||
namespace py = pybind11; | ||
|
||
namespace paddle { | ||
namespace distributed { | ||
|
||
class PYBIND11_EXPORT PythonRpcHandler { | ||
public: | ||
PythonRpcHandler(); | ||
~PythonRpcHandler() = default; | ||
static std::shared_ptr<PythonRpcHandler> GetInstance(); | ||
// Run a pickled Python function and return the result py::object | ||
py::object RunPythonFunc(const py::object& python_func); | ||
|
||
// Serialized a py::object into a string | ||
std::string Serialize(const py::object& obj); | ||
|
||
// Deserialize a string into a py::object | ||
py::object Deserialize(const std::string& obj); | ||
|
||
private: | ||
DISABLE_COPY_AND_ASSIGN(PythonRpcHandler); | ||
|
||
static std::shared_ptr<PythonRpcHandler> python_rpc_handler_; | ||
// Ref to `paddle.distributed.rpc.internal.run_py_func`. | ||
py::object py_run_function_; | ||
|
||
// Ref to `paddle.distributed.rpc.internal.serialize`. | ||
py::object py_serialize_; | ||
|
||
// Ref to `paddle.distributed.rpc.internal.deserialize`. | ||
py::object py_deserialize_; | ||
|
||
// Lock to protect initialization. | ||
static std::mutex lock_; | ||
}; | ||
|
||
} // namespace distributed | ||
} // namespace paddle |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
|
||
syntax="proto2"; | ||
package paddle.distributed; | ||
|
||
option cc_generic_services = true; | ||
option cc_enable_arenas = true; | ||
|
||
message RpcRequest { | ||
required bytes message = 1; | ||
}; | ||
|
||
message RpcResponse { | ||
required bytes message = 1; | ||
}; | ||
|
||
service RpcBaseService { | ||
rpc Send(RpcRequest) returns (RpcResponse); | ||
rpc InvokeRpc(RpcRequest) returns (RpcResponse); | ||
}; |
Oops, something went wrong.