Skip to content

Commit

Permalink
Improve type casters (#837)
Browse files Browse the repository at this point in the history
* Fixed docstring, added wait method to Device class

* Untangled input type overloads for functions to get default type category for queue/device

* Improve performance of default constructors for usm_memory and ums_ndarray
  • Loading branch information
oleksandr-pavlyk authored May 18, 2022
2 parents 50be964 + ba50f98 commit fbe078e
Show file tree
Hide file tree
Showing 3 changed files with 178 additions and 71 deletions.
207 changes: 151 additions & 56 deletions dpctl/apis/include/dpctl4pybind11.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,23 +39,60 @@ namespace pybind11
namespace detail
{

#define DPCTL_TYPE_CASTER(type, py_name) \
protected: \
std::unique_ptr<type> value; \
\
public: \
static constexpr auto name = py_name; \
template < \
typename T_, \
::pybind11::detail::enable_if_t< \
std::is_same<type, ::pybind11::detail::remove_cv_t<T_>>::value, \
int> = 0> \
static ::pybind11::handle cast(T_ *src, \
::pybind11::return_value_policy policy, \
::pybind11::handle parent) \
{ \
if (!src) \
return ::pybind11::none().release(); \
if (policy == ::pybind11::return_value_policy::take_ownership) { \
auto h = cast(std::move(*src), policy, parent); \
delete src; \
return h; \
} \
return cast(*src, policy, parent); \
} \
operator type *() \
{ \
return value.get(); \
} /* NOLINT(bugprone-macro-parentheses) */ \
operator type &() \
{ \
return *value; \
} /* NOLINT(bugprone-macro-parentheses) */ \
operator type &&() && \
{ \
return std::move(*value); \
} /* NOLINT(bugprone-macro-parentheses) */ \
template <typename T_> \
using cast_op_type = ::pybind11::detail::movable_cast_op_type<T_>

/* This type caster associates ``sycl::queue`` C++ class with
* :class:`dpctl.SyclQueue` for the purposes of generation of
* Python bindings by pybind11.
*/
template <> struct type_caster<sycl::queue>
{
public:
PYBIND11_TYPE_CASTER(sycl::queue, _("dpctl.SyclQueue"));

bool load(handle src, bool)
{
PyObject *source = src.ptr();
if (PyObject_TypeCheck(source, &PySyclQueueType)) {
DPCTLSyclQueueRef QRef = SyclQueue_GetQueueRef(
reinterpret_cast<PySyclQueueObject *>(source));
sycl::queue *q = reinterpret_cast<sycl::queue *>(QRef);
value = *q;
value = std::make_unique<sycl::queue>(
*(reinterpret_cast<sycl::queue *>(QRef)));
return true;
}
else {
Expand All @@ -69,6 +106,8 @@ template <> struct type_caster<sycl::queue>
auto tmp = SyclQueue_Make(reinterpret_cast<DPCTLSyclQueueRef>(&src));
return handle(reinterpret_cast<PyObject *>(tmp));
}

DPCTL_TYPE_CASTER(sycl::queue, _("dpctl.SyclQueue"));
};

/* This type caster associates ``sycl::device`` C++ class with
Expand All @@ -78,20 +117,14 @@ template <> struct type_caster<sycl::queue>
template <> struct type_caster<sycl::device>
{
public:
PYBIND11_TYPE_CASTER(sycl::device, _("dpctl.SyclDevice"));

bool load(handle src, bool)
{
PyObject *source = src.ptr();
if (PyObject_TypeCheck(source, &PySyclDeviceType)) {
DPCTLSyclDeviceRef DRef = SyclDevice_GetDeviceRef(
reinterpret_cast<PySyclDeviceObject *>(source));
sycl::device *d = reinterpret_cast<sycl::device *>(DRef);
value = *d;
return true;
}
else if (source == Py_None) {
value = sycl::device{};
value = std::make_unique<sycl::device>(
*(reinterpret_cast<sycl::device *>(DRef)));
return true;
}
else {
Expand All @@ -105,6 +138,8 @@ template <> struct type_caster<sycl::device>
auto tmp = SyclDevice_Make(reinterpret_cast<DPCTLSyclDeviceRef>(&src));
return handle(reinterpret_cast<PyObject *>(tmp));
}

DPCTL_TYPE_CASTER(sycl::device, _("dpctl.SyclDevice"));
};

/* This type caster associates ``sycl::context`` C++ class with
Expand All @@ -114,16 +149,14 @@ template <> struct type_caster<sycl::device>
template <> struct type_caster<sycl::context>
{
public:
PYBIND11_TYPE_CASTER(sycl::context, _("dpctl.SyclContext"));

bool load(handle src, bool)
{
PyObject *source = src.ptr();
if (PyObject_TypeCheck(source, &PySyclContextType)) {
DPCTLSyclContextRef CRef = SyclContext_GetContextRef(
reinterpret_cast<PySyclContextObject *>(source));
sycl::context *ctx = reinterpret_cast<sycl::context *>(CRef);
value = *ctx;
value = std::make_unique<sycl::context>(
*(reinterpret_cast<sycl::context *>(CRef)));
return true;
}
else {
Expand All @@ -138,6 +171,8 @@ template <> struct type_caster<sycl::context>
SyclContext_Make(reinterpret_cast<DPCTLSyclContextRef>(&src));
return handle(reinterpret_cast<PyObject *>(tmp));
}

DPCTL_TYPE_CASTER(sycl::context, _("dpctl.SyclContext"));
};

/* This type caster associates ``sycl::event`` C++ class with
Expand All @@ -147,16 +182,14 @@ template <> struct type_caster<sycl::context>
template <> struct type_caster<sycl::event>
{
public:
PYBIND11_TYPE_CASTER(sycl::event, _("dpctl.SyclEvent"));

bool load(handle src, bool)
{
PyObject *source = src.ptr();
if (PyObject_TypeCheck(source, &PySyclEventType)) {
DPCTLSyclEventRef ERef = SyclEvent_GetEventRef(
reinterpret_cast<PySyclEventObject *>(source));
sycl::event *ev = reinterpret_cast<sycl::event *>(ERef);
value = *ev;
value = std::make_unique<sycl::event>(
*(reinterpret_cast<sycl::event *>(ERef)));
return true;
}
else {
Expand All @@ -170,12 +203,102 @@ template <> struct type_caster<sycl::event>
auto tmp = SyclEvent_Make(reinterpret_cast<DPCTLSyclEventRef>(&src));
return handle(reinterpret_cast<PyObject *>(tmp));
}

DPCTL_TYPE_CASTER(sycl::event, _("dpctl.SyclEvent"));
};
} // namespace detail
} // namespace pybind11

namespace dpctl
{

namespace detail
{

struct dpctl_api
{
public:
static dpctl_api &get()
{
static dpctl_api api;
return api;
}

py::object sycl_queue_()
{
return *sycl_queue;
}
py::object default_usm_memory_()
{
return *default_usm_memory;
}
py::object default_usm_ndarray_()
{
return *default_usm_ndarray;
}
py::object as_usm_memory_()
{
return *as_usm_memory;
}

private:
struct Deleter
{
void operator()(py::object *p) const
{
bool guard = (Py_IsInitialized() && !_Py_IsFinalizing());

if (guard) {
delete p;
}
}
};

std::shared_ptr<py::object> sycl_queue;
std::shared_ptr<py::object> default_usm_memory;
std::shared_ptr<py::object> default_usm_ndarray;
std::shared_ptr<py::object> as_usm_memory;

dpctl_api() : sycl_queue{}, default_usm_memory{}, default_usm_ndarray{}
{
import_dpctl();

sycl::queue q_;
py::object py_sycl_queue = py::cast(q_);
sycl_queue = std::shared_ptr<py::object>(new py::object{py_sycl_queue},
Deleter{});

py::module_ mod_memory = py::module_::import("dpctl.memory");
py::object py_as_usm_memory = mod_memory.attr("as_usm_memory");
as_usm_memory = std::shared_ptr<py::object>(
new py::object{py_as_usm_memory}, Deleter{});

auto mem_kl = mod_memory.attr("MemoryUSMHost");
py::object py_default_usm_memory =
mem_kl(1, py::arg("queue") = py_sycl_queue);
default_usm_memory = std::shared_ptr<py::object>(
new py::object{py_default_usm_memory}, Deleter{});

py::module_ mod_usmarray =
py::module_::import("dpctl.tensor._usmarray");
auto tensor_kl = mod_usmarray.attr("usm_ndarray");

py::object py_default_usm_ndarray =
tensor_kl(py::tuple(), py::arg("dtype") = py::str("u1"),
py::arg("buffer") = py_default_usm_memory);

default_usm_ndarray = std::shared_ptr<py::object>(
new py::object{py_default_usm_ndarray}, Deleter{});
}

public:
dpctl_api(dpctl_api const &) = delete;
void operator=(dpctl_api const &) = delete;
~dpctl_api(){};
};

} // namespace detail

namespace memory
{

Expand Down Expand Up @@ -232,7 +355,9 @@ class usm_memory : public py::object
}
// END_TOKEN

usm_memory() : py::object(default_constructed(), stolen_t{})
usm_memory()
: py::object(::dpctl::detail::dpctl_api::get().default_usm_memory_(),
borrowed_t{})
{
if (!m_ptr)
throw py::error_already_set();
Expand Down Expand Up @@ -267,26 +392,12 @@ class usm_memory : public py::object
"cannot create a usm_memory from a nullptr");
return nullptr;
}
py::module_ m = py::module_::import("dpctl.memory");
auto convertor = m.attr("as_usm_memory");

py::object res;
try {
res = convertor(py::handle(o));
} catch (const py::error_already_set &e) {
return nullptr;
}
return res.ptr();
}
auto convertor = ::dpctl::detail::dpctl_api::get().as_usm_memory_();

static PyObject *default_constructed()
{
py::module_ m = py::module_::import("dpctl.memory");
auto kl = m.attr("MemoryUSMDevice");
py::object res;
try {
// allocate 1 byte
res = kl(1);
res = convertor(py::handle(o));
} catch (const py::error_already_set &e) {
return nullptr;
}
Expand All @@ -295,10 +406,7 @@ class usm_memory : public py::object
};

} // end namespace memory
} // end namespace dpctl

namespace dpctl
{
namespace tensor
{
class usm_ndarray : public py::object
Expand Down Expand Up @@ -349,7 +457,9 @@ class usm_ndarray : public py::object
}
// END_TOKEN

usm_ndarray() : py::object(default_constructed(), stolen_t{})
usm_ndarray()
: py::object(::dpctl::detail::dpctl_api::get().default_usm_ndarray_(),
borrowed_t{})
{
if (!m_ptr)
throw py::error_already_set();
Expand Down Expand Up @@ -481,21 +591,6 @@ class usm_ndarray : public py::object

return UsmNDArray_GetElementSize(raw_ar);
}

private:
static PyObject *default_constructed()
{
py::module_ m = py::module_::import("dpctl.tensor");
auto kl = m.attr("usm_ndarray");
py::object res;
try {
// allocate 1 byte
res = kl(py::make_tuple(), py::arg("dtype") = "u1");
} catch (const py::error_already_set &e) {
return nullptr;
}
return res.ptr();
}
};

} // end namespace tensor
Expand Down
10 changes: 8 additions & 2 deletions dpctl/tensor/_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ class Device:
This is a wrapper around :class:`dpctl.SyclQueue` with custom
formatting. The class does not have public constructor,
but a class method to construct it from device= keyword
in Array-API functions.
but a class method `create_device` to construct it from device= keyword
argument in Array-API functions.
Instance can be queried for ``sycl_queue``, ``sycl_context``,
or ``sycl_device``.
Expand Down Expand Up @@ -111,6 +111,12 @@ def __repr__(self):
# This is a sub-device
return repr(self.sycl_queue)

def wait(self):
"""
Call ``wait`` method of the underlying ``sycl_queue``.
"""
self.sycl_queue_.wait()


def normalize_queue_device(sycl_queue=None, device=None):
"""
Expand Down
Loading

0 comments on commit fbe078e

Please sign in to comment.