Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pybind11 type caster for sycl::half #1655

Merged
merged 3 commits into from
Apr 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions dpctl/apis/include/dpctl4pybind11.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,53 @@ struct type_caster<sycl::kernel_bundle<sycl::bundle_state::executable>>
DPCTL_TYPE_CASTER(sycl::kernel_bundle<sycl::bundle_state::executable>,
_("dpctl.program.SyclProgram"));
};

/* This type caster associates
* ``sycl::half`` C++ class with Python :class:`float` for the purposes
* of generation of Python bindings by pybind11.
*/
template <> struct type_caster<sycl::half>
{
public:
bool load(handle src, bool convert)
{
double py_value;

if (!src) {
return false;
}

PyObject *source = src.ptr();

if (convert || PyFloat_Check(source)) {
py_value = PyFloat_AsDouble(source);
}
else {
return false;
}

bool py_err = (py_value == double(-1)) && PyErr_Occurred();

if (py_err) {
PyErr_Clear();
if (convert && (PyNumber_Check(source) != 0)) {
auto tmp = reinterpret_steal<object>(PyNumber_Float(source));
return load(tmp, false);
}
return false;
}
value = static_cast<sycl::half>(py_value);
return true;
}

static handle cast(sycl::half src, return_value_policy, handle)
{
return PyFloat_FromDouble(static_cast<double>(src));
}

PYBIND11_TYPE_CASTER(sycl::half, _("float"));
};

} // namespace detail
} // namespace pybind11

Expand Down
10 changes: 1 addition & 9 deletions dpctl/tensor/libtensor/source/full_ctor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
#include "utils/type_utils.hpp"

#include "full_ctor.hpp"
#include "unboxing_helper.hpp"

namespace py = pybind11;
namespace td_ns = dpctl::tensor::type_dispatch;
Expand Down Expand Up @@ -79,14 +78,7 @@ sycl::event full_contig_impl(sycl::queue &exec_q,
char *dst_p,
const std::vector<sycl::event> &depends)
{
dstTy fill_v;

PythonObjectUnboxer<dstTy> unboxer{};
try {
fill_v = unboxer(py_value);
} catch (const py::error_already_set &e) {
throw;
}
dstTy fill_v = py::cast<dstTy>(py_value);

using dpctl::tensor::kernels::constructors::full_contig_impl;

Expand Down
23 changes: 4 additions & 19 deletions dpctl/tensor/libtensor/source/linear_sequences.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
#include "utils/type_utils.hpp"

#include "linear_sequences.hpp"
#include "unboxing_helper.hpp"

namespace py = pybind11;
namespace td_ns = dpctl::tensor::type_dispatch;
Expand Down Expand Up @@ -86,16 +85,8 @@ sycl::event lin_space_step_impl(sycl::queue &exec_q,
char *array_data,
const std::vector<sycl::event> &depends)
{
Ty start_v;
Ty step_v;

const auto &unboxer = PythonObjectUnboxer<Ty>{};
try {
start_v = unboxer(start);
step_v = unboxer(step);
} catch (const py::error_already_set &e) {
throw;
}
Ty start_v = py::cast<Ty>(start);
Ty step_v = py::cast<Ty>(step);

using dpctl::tensor::kernels::constructors::lin_space_step_impl;

Expand Down Expand Up @@ -143,14 +134,8 @@ sycl::event lin_space_affine_impl(sycl::queue &exec_q,
char *array_data,
const std::vector<sycl::event> &depends)
{
Ty start_v, end_v;
const auto &unboxer = PythonObjectUnboxer<Ty>{};
try {
start_v = unboxer(start);
end_v = unboxer(end);
} catch (const py::error_already_set &e) {
throw;
}
Ty start_v = py::cast<Ty>(start);
Ty end_v = py::cast<Ty>(end);

using dpctl::tensor::kernels::constructors::lin_space_affine_impl;

Expand Down
53 changes: 0 additions & 53 deletions dpctl/tensor/libtensor/source/unboxing_helper.hpp

This file was deleted.