Skip to content

Commit

Permalink
dpnp.full() complete and test case added
Browse files Browse the repository at this point in the history
Remove 'like' from dpnp.empty() overload

ty_x --> ty_x1

Bitcasts are done through unions
  • Loading branch information
khaled committed Apr 13, 2023
1 parent bbf978e commit d2e8141
Show file tree
Hide file tree
Showing 11 changed files with 394 additions and 88 deletions.
77 changes: 65 additions & 12 deletions numba_dpex/core/runtime/_dpexrt_python.c
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,18 @@
#include "_queuestruct.h"
#include "numba/_arraystruct.h"

typedef union
{
float f_;
uint32_t i_;
} float_uint32_t;

typedef union
{
double d_;
uint64_t i_;
} double_uint64_t;

// forward declarations
static struct PyUSMArrayObject *PyUSMNdArray_ARRAYOBJ(PyObject *obj);
static npy_intp product_of_shape(npy_intp *shape, npy_intp ndim);
Expand All @@ -34,6 +46,12 @@ static NRT_ExternalAllocator *
NRT_ExternalAllocator_new_for_usm(DPCTLSyclQueueRef qref, size_t usm_type);
static void *DPEXRTQueue_CreateFromFilterString(const char *device);
static MemInfoDtorInfo *MemInfoDtorInfo_new(NRT_MemInfo *mi, PyObject *owner);
static NRT_MemInfo *DPEXRT_MemInfo_fill(NRT_MemInfo *mi,
size_t itemsize,
bool dest_is_float,
bool value_is_float,
int64_t value,
const char *device);
static NRT_MemInfo *NRT_MemInfo_new_from_usmndarray(PyObject *ndarrobj,
void *data,
npy_intp nitems,
Expand Down Expand Up @@ -521,8 +539,9 @@ DPEXRT_MemInfo_alloc(npy_intp size, size_t usm_type, const char *device)
*/
static NRT_MemInfo *DPEXRT_MemInfo_fill(NRT_MemInfo *mi,
size_t itemsize,
bool is_float,
uint8_t value,
bool dest_is_float,
bool value_is_float,
int64_t value,
const char *device)
{
DPCTLSyclQueueRef qref = NULL;
Expand Down Expand Up @@ -553,12 +572,17 @@ static NRT_MemInfo *DPEXRT_MemInfo_fill(NRT_MemInfo *mi,
case 3:
{
uint64_t value_assign = (uint64_t)value;
if (is_float) {
double const_val = (double)value;
if (dest_is_float && !value_is_float) {
double_uint64_t du;
// To stop warning: dereferencing type-punned pointer
// will break strict-aliasing rules [-Wstrict-aliasing]
double *p = &const_val;
value_assign = *((uint64_t *)(p));
double cd = (double)value;
du.d_ = *((double *)(&cd));
value_assign = du.i_;
}
else if (!dest_is_float && value_is_float) {
double *p = (double *)&value;
value_assign = *p;
}
if (!(eref = DPCTLQueue_Fill64(qref, mi->data, value_assign, count)))
goto error;
Expand All @@ -567,25 +591,54 @@ static NRT_MemInfo *DPEXRT_MemInfo_fill(NRT_MemInfo *mi,
case 2:
{
uint32_t value_assign = (uint32_t)value;
if (is_float) {
float const_val = (float)value;
if (dest_is_float && value_is_float) {
float_uint32_t fu;
double *p = (double *)(&value);
fu.f_ = *p;
value_assign = fu.i_;
}
else if (dest_is_float && !value_is_float) {
float_uint32_t fu;
// To stop warning: dereferencing type-punned pointer
// will break strict-aliasing rules [-Wstrict-aliasing]
float *p = &const_val;
value_assign = *((uint32_t *)(p));
float cf = (float)value;
fu.f_ = *((float *)(&cf));
value_assign = fu.i_;
}
else if (!dest_is_float && value_is_float) {
double *p = (double *)&value;
value_assign = *p;
}
if (!(eref = DPCTLQueue_Fill32(qref, mi->data, value_assign, count)))
goto error;
break;
}
case 1:
if (!(eref = DPCTLQueue_Fill16(qref, mi->data, value, count)))
{
if (dest_is_float)
goto error;
uint16_t value_assign = (uint16_t)value;
if (value_is_float) {
double *p = (double *)&value;
value_assign = *p;
}
if (!(eref = DPCTLQueue_Fill16(qref, mi->data, value_assign, count)))
goto error;
break;
}
case 0:
if (!(eref = DPCTLQueue_Fill8(qref, mi->data, value, count)))
{
if (dest_is_float)
goto error;
uint8_t value_assign = (uint8_t)value;
if (value_is_float) {
double *p = (double *)&value;
value_assign = *p;
}
if (!(eref = DPCTLQueue_Fill8(qref, mi->data, value_assign, count)))
goto error;
break;
}
default:
goto error;
}
Expand Down
35 changes: 30 additions & 5 deletions numba_dpex/core/runtime/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,25 @@ def meminfo_alloc(self, builder, size, usm_type, device):
return self.meminfo_alloc_unchecked(builder, size, usm_type, device)

@_check_null_result
def meminfo_fill(self, builder, meminfo, itemsize, is_float, value, device):
def meminfo_fill(
self,
builder,
meminfo,
itemsize,
dest_is_float,
value_is_float,
value,
device,
):
"""A wrapped caller for meminfo_fill_unchecked() with null check."""
return self.meminfo_fill_unchecked(
builder, meminfo, itemsize, is_float, value, device
builder,
meminfo,
itemsize,
dest_is_float,
value_is_float,
value,
device,
)

def meminfo_alloc_unchecked(self, builder, size, usm_type, device):
Expand Down Expand Up @@ -71,7 +86,14 @@ def meminfo_alloc_unchecked(self, builder, size, usm_type, device):
return ret

def meminfo_fill_unchecked(
self, builder, meminfo, itemsize, is_float, value, device
self,
builder,
meminfo,
itemsize,
dest_is_float,
value_is_float,
value,
device,
):
"""Fills an allocated `MemInfo` with the value specified.
Expand All @@ -96,12 +118,15 @@ def meminfo_fill_unchecked(
b = llvmir.IntType(1)
fnty = llvmir.FunctionType(
cgutils.voidptr_t,
[cgutils.voidptr_t, u64, b, cgutils.int8_t, cgutils.voidptr_t],
[cgutils.voidptr_t, u64, b, b, cgutils.intp_t, cgutils.voidptr_t],
)
fn = cgutils.get_or_insert_function(mod, fnty, "DPEXRT_MemInfo_fill")
fn.return_value.add_attribute("noalias")

ret = builder.call(fn, [meminfo, itemsize, is_float, value, device])
ret = builder.call(
fn,
[meminfo, itemsize, dest_is_float, value_is_float, value, device],
)

return ret

Expand Down
2 changes: 1 addition & 1 deletion numba_dpex/core/types/usm_ndarray_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(

if not dtype:
dummy_tensor = dpctl.tensor.empty(
shape=1, order=layout, usm_type=usm_type, sycl_queue=self.queue
1, order=layout, usm_type=usm_type, sycl_queue=self.queue
)
# convert dpnp type to numba/numpy type
_dtype = dummy_tensor.dtype
Expand Down
Loading

0 comments on commit d2e8141

Please sign in to comment.