Skip to content

Commit

Permalink
Overload implementation for dpnp.full()
Browse files Browse the repository at this point in the history
    - Adds an overload implementation for dpnp.full
    - Removes the  `like` kwarg from dpnp.empty() overload
    - Unit test cases
  • Loading branch information
khaled authored and Diptorup Deb committed Apr 14, 2023
1 parent bbf978e commit 751da72
Show file tree
Hide file tree
Showing 11 changed files with 629 additions and 150 deletions.
122 changes: 98 additions & 24 deletions numba_dpex/core/runtime/_dpexrt_python.c
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ static NRT_ExternalAllocator *
NRT_ExternalAllocator_new_for_usm(DPCTLSyclQueueRef qref, size_t usm_type);
static void *DPEXRTQueue_CreateFromFilterString(const char *device);
static MemInfoDtorInfo *MemInfoDtorInfo_new(NRT_MemInfo *mi, PyObject *owner);
static NRT_MemInfo *DPEXRT_MemInfo_fill(NRT_MemInfo *mi,
size_t itemsize,
bool dest_is_float,
bool value_is_float,
int64_t value,
const char *device);
static NRT_MemInfo *NRT_MemInfo_new_from_usmndarray(PyObject *ndarrobj,
void *data,
npy_intp nitems,
Expand Down Expand Up @@ -510,25 +516,47 @@ DPEXRT_MemInfo_alloc(npy_intp size, size_t usm_type, const char *device)
* This function takes an allocated memory as NRT_MemInfo and fills it with
* the value specified by `value`.
*
* @param mi An NRT_MemInfo object, should be found from memory
* allocation.
* @param itemsize The itemsize, the size of each item in the array.
* @param is_float Flag to specify if the data being float or not.
* @param value The value to be used to fill an array.
* @param device The device on which the memory was allocated.
* @return NRT_MemInfo* A new NRT_MemInfo object, NULL if no NRT_MemInfo
* object could be created.
* @param mi An NRT_MemInfo object, should be found from memory
* allocation.
* @param itemsize The itemsize, the size of each item in the array.
* @param dest_is_float True if the destination array's dtype is float.
* @param value_is_float True if the value to be filled is float.
* @param value The value to be used to fill an array.
* @param device The device on which the memory was allocated.
* @return NRT_MemInfo* A new NRT_MemInfo object, NULL if no NRT_MemInfo
* object could be created.
*/
static NRT_MemInfo *DPEXRT_MemInfo_fill(NRT_MemInfo *mi,
size_t itemsize,
bool is_float,
uint8_t value,
bool dest_is_float,
bool value_is_float,
int64_t value,
const char *device)
{
DPCTLSyclQueueRef qref = NULL;
DPCTLSyclEventRef eref = NULL;
size_t count = 0, size = 0, exp = 0;

/**
* @brief A union for bit conversion from the input int64_t value
* to a uintX_t bit-pattern with appropriate type conversion when the
* input value represents a float.
*/
typedef union
{
float f_; /**< The float to be represented. */
double d_;
int8_t i8_;
int16_t i16_;
int32_t i32_;
int64_t i64_;
uint8_t ui8_;
uint16_t ui16_;
uint32_t ui32_; /**< The bit representation. */
uint64_t ui64_; /**< The bit representation. */
} bitcaster_t;

bitcaster_t bc;
size = mi->size;
while (itemsize >>= 1)
exp++;
Expand All @@ -552,40 +580,86 @@ static NRT_MemInfo *DPEXRT_MemInfo_fill(NRT_MemInfo *mi,
switch (exp) {
case 3:
{
uint64_t value_assign = (uint64_t)value;
if (is_float) {
double const_val = (double)value;
if (dest_is_float && value_is_float) {
double *p = (double *)(&value);
bc.d_ = *p;
}
else if (dest_is_float && !value_is_float) {
// To stop warning: dereferencing type-punned pointer
// will break strict-aliasing rules [-Wstrict-aliasing]
double *p = &const_val;
value_assign = *((uint64_t *)(p));
double cd = (double)value;
bc.d_ = *((double *)(&cd));
}
else if (!dest_is_float && value_is_float) {
double *p = (double *)&value;
bc.i64_ = *p;
}
else {
bc.i64_ = value;
}
if (!(eref = DPCTLQueue_Fill64(qref, mi->data, value_assign, count)))

if (!(eref = DPCTLQueue_Fill64(qref, mi->data, bc.ui64_, count)))
goto error;
break;
}
case 2:
{
uint32_t value_assign = (uint32_t)value;
if (is_float) {
float const_val = (float)value;
if (dest_is_float && value_is_float) {
double *p = (double *)(&value);
bc.f_ = *p;
}
else if (dest_is_float && !value_is_float) {
// To stop warning: dereferencing type-punned pointer
// will break strict-aliasing rules [-Wstrict-aliasing]
float *p = &const_val;
value_assign = *((uint32_t *)(p));
float cf = (float)value;
bc.f_ = *((float *)(&cf));
}
else if (!dest_is_float && value_is_float) {
double *p = (double *)&value;
bc.i32_ = *p;
}
if (!(eref = DPCTLQueue_Fill32(qref, mi->data, value_assign, count)))
else {
bc.i32_ = (int32_t)value;
}

if (!(eref = DPCTLQueue_Fill32(qref, mi->data, bc.ui32_, count)))
goto error;
break;
}
case 1:
if (!(eref = DPCTLQueue_Fill16(qref, mi->data, value, count)))
{
if (dest_is_float)
goto error;

if (value_is_float) {
double *p = (double *)&value;
bc.i16_ = *p;
}
else {
bc.i16_ = (int16_t)value;
}

if (!(eref = DPCTLQueue_Fill16(qref, mi->data, bc.ui16_, count)))
goto error;
break;
}
case 0:
if (!(eref = DPCTLQueue_Fill8(qref, mi->data, value, count)))
{
if (dest_is_float)
goto error;

if (value_is_float) {
double *p = (double *)&value;
bc.i8_ = *p;
}
else {
bc.i8_ = (int8_t)value;
}

if (!(eref = DPCTLQueue_Fill8(qref, mi->data, bc.ui8_, count)))
goto error;
break;
}
default:
goto error;
}
Expand Down
45 changes: 38 additions & 7 deletions numba_dpex/core/runtime/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,35 @@ def wrap(self, builder, *args, **kwargs):

@_check_null_result
def meminfo_alloc(self, builder, size, usm_type, device):
"""A wrapped caller for meminfo_alloc_unchecked() with null check."""
"""
Wrapper to call :func:`~context.DpexRTContext.meminfo_alloc_unchecked`
with null checking of the returned value.
"""
return self.meminfo_alloc_unchecked(builder, size, usm_type, device)

@_check_null_result
def meminfo_fill(self, builder, meminfo, itemsize, is_float, value, device):
"""A wrapped caller for meminfo_fill_unchecked() with null check."""
def meminfo_fill(
self,
builder,
meminfo,
itemsize,
dest_is_float,
value_is_float,
value,
device,
):
"""
Wrapper to call :func:`~context.DpexRTContext.meminfo_fill_unchecked`
with null checking of the returned value.
"""
return self.meminfo_fill_unchecked(
builder, meminfo, itemsize, is_float, value, device
builder,
meminfo,
itemsize,
dest_is_float,
value_is_float,
value,
device,
)

def meminfo_alloc_unchecked(self, builder, size, usm_type, device):
Expand Down Expand Up @@ -71,7 +92,14 @@ def meminfo_alloc_unchecked(self, builder, size, usm_type, device):
return ret

def meminfo_fill_unchecked(
self, builder, meminfo, itemsize, is_float, value, device
self,
builder,
meminfo,
itemsize,
dest_is_float,
value_is_float,
value,
device,
):
"""Fills an allocated `MemInfo` with the value specified.
Expand All @@ -96,12 +124,15 @@ def meminfo_fill_unchecked(
b = llvmir.IntType(1)
fnty = llvmir.FunctionType(
cgutils.voidptr_t,
[cgutils.voidptr_t, u64, b, cgutils.int8_t, cgutils.voidptr_t],
[cgutils.voidptr_t, u64, b, b, u64, cgutils.voidptr_t],
)
fn = cgutils.get_or_insert_function(mod, fnty, "DPEXRT_MemInfo_fill")
fn.return_value.add_attribute("noalias")

ret = builder.call(fn, [meminfo, itemsize, is_float, value, device])
ret = builder.call(
fn,
[meminfo, itemsize, dest_is_float, value_is_float, value, device],
)

return ret

Expand Down
2 changes: 1 addition & 1 deletion numba_dpex/core/types/usm_ndarray_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(

if not dtype:
dummy_tensor = dpctl.tensor.empty(
shape=1, order=layout, usm_type=usm_type, sycl_queue=self.queue
1, order=layout, usm_type=usm_type, sycl_queue=self.queue
)
# convert dpnp type to numba/numpy type
_dtype = dummy_tensor.dtype
Expand Down
Loading

0 comments on commit 751da72

Please sign in to comment.