Skip to content

Commit

Permalink
dpnp.full() complete and test case added
Browse files Browse the repository at this point in the history
Remove 'like' from dpnp.empty() overload

ty_x --> ty_x1

Bitcasts are done through unions

Addressed all review comments
  • Loading branch information
khaled committed Apr 13, 2023
1 parent bbf978e commit 54ae6ee
Show file tree
Hide file tree
Showing 11 changed files with 620 additions and 149 deletions.
112 changes: 90 additions & 22 deletions numba_dpex/core/runtime/_dpexrt_python.c
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,26 @@
#include "_queuestruct.h"
#include "numba/_arraystruct.h"

/**
* @brief A union exposing the bit representation of a float.
*
* DPEXRT_MemInfo_fill() stores a value through `f_` and reads it back
* through `i_` to obtain the 32-bit pattern that is handed to
* DPCTLQueue_Fill32().
* NOTE(review): assumes float is 32 bits wide on the target -- confirm.
*/
typedef union
{
float f_; /**< The float whose bits are wanted. */
uint32_t i_; /**< The same storage viewed as a 32-bit unsigned integer. */
} float_uint32_t;

/**
* @brief A union exposing the bit representation of a double.
*
* DPEXRT_MemInfo_fill() stores a value through `d_` and reads it back
* through `i_` to obtain the 64-bit pattern that is handed to
* DPCTLQueue_Fill64().
* NOTE(review): assumes double is 64 bits wide on the target -- confirm.
*/
typedef union
{
double d_; /**< The double whose bits are wanted. */
uint64_t i_; /**< The same storage viewed as a 64-bit unsigned integer. */
} double_uint64_t;

// forward declarations
static struct PyUSMArrayObject *PyUSMNdArray_ARRAYOBJ(PyObject *obj);
static npy_intp product_of_shape(npy_intp *shape, npy_intp ndim);
Expand All @@ -34,6 +54,12 @@ static NRT_ExternalAllocator *
NRT_ExternalAllocator_new_for_usm(DPCTLSyclQueueRef qref, size_t usm_type);
static void *DPEXRTQueue_CreateFromFilterString(const char *device);
static MemInfoDtorInfo *MemInfoDtorInfo_new(NRT_MemInfo *mi, PyObject *owner);
static NRT_MemInfo *DPEXRT_MemInfo_fill(NRT_MemInfo *mi,
size_t itemsize,
bool dest_is_float,
bool value_is_float,
int64_t value,
const char *device);
static NRT_MemInfo *NRT_MemInfo_new_from_usmndarray(PyObject *ndarrobj,
void *data,
npy_intp nitems,
Expand Down Expand Up @@ -510,19 +536,21 @@ DPEXRT_MemInfo_alloc(npy_intp size, size_t usm_type, const char *device)
* This function takes an allocated memory as NRT_MemInfo and fills it with
* the value specified by `value`.
*
* @param mi An NRT_MemInfo object, should be found from memory
* allocation.
* @param itemsize The itemsize, the size of each item in the array.
* @param is_float Flag to specify if the data being float or not.
* @param value The value to be used to fill an array.
* @param device The device on which the memory was allocated.
* @return NRT_MemInfo* A new NRT_MemInfo object, NULL if no NRT_MemInfo
* object could be created.
* @param mi An NRT_MemInfo object obtained from a prior memory
* allocation.
* @param itemsize The itemsize, the size of each item in the array.
* @param dest_is_float True if the destination array's dtype is float.
* @param value_is_float True if the value to be filled is float.
* @param value The value to be used to fill an array.
* @param device The device on which the memory was allocated.
* @return NRT_MemInfo* A new NRT_MemInfo object, NULL if no NRT_MemInfo
* object could be created.
*/
static NRT_MemInfo *DPEXRT_MemInfo_fill(NRT_MemInfo *mi,
size_t itemsize,
bool is_float,
uint8_t value,
bool dest_is_float,
bool value_is_float,
int64_t value,
const char *device)
{
DPCTLSyclQueueRef qref = NULL;
Expand Down Expand Up @@ -552,40 +580,80 @@ static NRT_MemInfo *DPEXRT_MemInfo_fill(NRT_MemInfo *mi,
switch (exp) {
case 3:
{
uint64_t value_assign = (uint64_t)value;
if (is_float) {
double const_val = (double)value;
int64_t value_assign = (int64_t)value;
if (dest_is_float && value_is_float) {
double_uint64_t du;
double *p = (double *)(&value);
du.d_ = *p;
value_assign = du.i_;
}
else if (dest_is_float && !value_is_float) {
double_uint64_t du;
// To stop warning: dereferencing type-punned pointer
// will break strict-aliasing rules [-Wstrict-aliasing]
double *p = &const_val;
value_assign = *((uint64_t *)(p));
double cd = (double)value;
du.d_ = *((double *)(&cd));
value_assign = du.i_;
}
else if (!dest_is_float && value_is_float) {
double *p = (double *)&value;
value_assign = *p;
}
if (!(eref = DPCTLQueue_Fill64(qref, mi->data, value_assign, count)))
goto error;
break;
}
case 2:
{
uint32_t value_assign = (uint32_t)value;
if (is_float) {
float const_val = (float)value;
int32_t value_assign = (int32_t)value;
if (dest_is_float && value_is_float) {
float_uint32_t fu;
double *p = (double *)(&value);
fu.f_ = *p;
value_assign = fu.i_;
}
else if (dest_is_float && !value_is_float) {
float_uint32_t fu;
// To stop warning: dereferencing type-punned pointer
// will break strict-aliasing rules [-Wstrict-aliasing]
float *p = &const_val;
value_assign = *((uint32_t *)(p));
float cf = (float)value;
fu.f_ = *((float *)(&cf));
value_assign = fu.i_;
}
else if (!dest_is_float && value_is_float) {
double *p = (double *)&value;
value_assign = *p;
}
if (!(eref = DPCTLQueue_Fill32(qref, mi->data, value_assign, count)))
goto error;
break;
}
case 1:
if (!(eref = DPCTLQueue_Fill16(qref, mi->data, value, count)))
{
if (dest_is_float)
goto error;
int16_t value_assign = (int16_t)value;
if (value_is_float) {
double *p = (double *)&value;
value_assign = *p;
}
if (!(eref = DPCTLQueue_Fill16(qref, mi->data, value_assign, count)))
goto error;
break;
}
case 0:
if (!(eref = DPCTLQueue_Fill8(qref, mi->data, value, count)))
{
if (dest_is_float)
goto error;
int8_t value_assign = (int8_t)value;
if (value_is_float) {
double *p = (double *)&value;
value_assign = *p;
}
if (!(eref = DPCTLQueue_Fill8(qref, mi->data, value_assign, count)))
goto error;
break;
}
default:
goto error;
}
Expand Down
47 changes: 40 additions & 7 deletions numba_dpex/core/runtime/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,37 @@ def wrap(self, builder, *args, **kwargs):

@_check_null_result
def meminfo_alloc(self, builder, size, usm_type, device):
"""A wrapped caller for meminfo_alloc_unchecked() with null check."""
"""
A wrapped caller for :func:`~context.DpexRTContext.meminfo_alloc_unchecked`
with null check. Please refer to that function for the details on how the
null check is done.
"""
return self.meminfo_alloc_unchecked(builder, size, usm_type, device)

@_check_null_result
def meminfo_fill(self, builder, meminfo, itemsize, is_float, value, device):
"""A wrapped caller for meminfo_fill_unchecked() with null check."""
def meminfo_fill(
self,
builder,
meminfo,
itemsize,
dest_is_float,
value_is_float,
value,
device,
):
"""
A wrapped caller for :func:`~context.DpexRTContext.meminfo_fill_unchecked`
with null check. Please refer to that function for the details on how the
null check is done.
"""
return self.meminfo_fill_unchecked(
builder, meminfo, itemsize, is_float, value, device
builder,
meminfo,
itemsize,
dest_is_float,
value_is_float,
value,
device,
)

def meminfo_alloc_unchecked(self, builder, size, usm_type, device):
Expand Down Expand Up @@ -71,7 +94,14 @@ def meminfo_alloc_unchecked(self, builder, size, usm_type, device):
return ret

def meminfo_fill_unchecked(
self, builder, meminfo, itemsize, is_float, value, device
self,
builder,
meminfo,
itemsize,
dest_is_float,
value_is_float,
value,
device,
):
"""Fills an allocated `MemInfo` with the value specified.
Expand All @@ -96,12 +126,15 @@ def meminfo_fill_unchecked(
b = llvmir.IntType(1)
fnty = llvmir.FunctionType(
cgutils.voidptr_t,
[cgutils.voidptr_t, u64, b, cgutils.int8_t, cgutils.voidptr_t],
[cgutils.voidptr_t, u64, b, b, cgutils.intp_t, cgutils.voidptr_t],
)
fn = cgutils.get_or_insert_function(mod, fnty, "DPEXRT_MemInfo_fill")
fn.return_value.add_attribute("noalias")

ret = builder.call(fn, [meminfo, itemsize, is_float, value, device])
ret = builder.call(
fn,
[meminfo, itemsize, dest_is_float, value_is_float, value, device],
)

return ret

Expand Down
2 changes: 1 addition & 1 deletion numba_dpex/core/types/usm_ndarray_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(

if not dtype:
dummy_tensor = dpctl.tensor.empty(
shape=1, order=layout, usm_type=usm_type, sycl_queue=self.queue
1, order=layout, usm_type=usm_type, sycl_queue=self.queue
)
# convert dpnp type to numba/numpy type
_dtype = dummy_tensor.dtype
Expand Down
Loading

0 comments on commit 54ae6ee

Please sign in to comment.