Skip to content

Commit

Permalink
Functional dpnp.arange for int types
Browse files Browse the repository at this point in the history
  • Loading branch information
chudur-budur committed Oct 26, 2023
1 parent 3e2faf1 commit 8fd6dbd
Show file tree
Hide file tree
Showing 2 changed files with 196 additions and 37 deletions.
19 changes: 18 additions & 1 deletion numba_dpex/dpnp_iface/_intrinsic.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,13 +361,25 @@ def alloc_empty_arrayobj(context, builder, sig, queue_ref, args, is_like=False):
Returns: The LLVM IR value that stores the empty array
"""

print("alloc_empty_arrayobj: sig =", sig)
print("alloc_empty_arrayobj: args =", args)

arrtype, shape = (
_parse_empty_like_args(context, builder, sig, args)
if is_like
else _parse_empty_args(context, builder, sig, args)
)
print(
"alloc_empty_arrayobj(): arrtype =",
arrtype,
"type(arrtype) =",
type(arrtype),
)
print(
"alloc_empty_arrayobj(): shape =", shape, ", type(shape) =", type(shape)
)
ary = _empty_nd_impl(context, builder, arrtype, shape, queue_ref)

print("alloc_empty_arrayobj(): ary =", ary, ", type(ary) =", type(ary))
return ary


Expand Down Expand Up @@ -473,6 +485,8 @@ def impl_dpnp_empty(
ty_retty_ref,
)

print("--- impl_dpnp_empty()")

sycl_queue_arg_pos = -2

def codegen(context, builder, sig, args):
Expand All @@ -486,6 +500,9 @@ def codegen(context, builder, sig, args):
sycl_queue_arg=sycl_queue_arg,
)

print("impl_dpnp_empty(): sig =", sig, type(sig))
print("impl_dpnp_empty(): args =", args, type(args))

ary = alloc_empty_arrayobj(
context, builder, sig, qref_payload.queue_ref, args
)
Expand Down
214 changes: 178 additions & 36 deletions numba_dpex/dpnp_iface/array_sequence_ops.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,32 @@
import math
from collections import namedtuple

import dpctl.tensor as dpt
import dpnp
import numba
import numpy as np
from dpctl.tensor._ctors import _coerce_and_infer_dt
from llvmlite import ir as llvmir
from numba import errors, types
from numba.core import cgutils
from numba.core.types.scalars import Complex, Float, Integer
from numba.core.types.misc import UnicodeType
from numba.core.types.scalars import Complex, Float, Integer, IntegerLiteral
from numba.core.typing.templates import Signature
from numba.extending import intrinsic, overload

import numba_dpex.utils as utils
from numba_dpex.core.runtime import context as dpexrt
from numba_dpex.core.types import DpnpNdArray
from numba_dpex.dpnp_iface._intrinsic import _get_queue_ref
from numba_dpex.dpnp_iface._intrinsic import (
_ArgTyAndValue,
_empty_nd_impl,
_get_queue_ref,
alloc_empty_arrayobj,
)
from numba_dpex.dpnp_iface.arrayobj import (
_parse_device_filter_string,
_parse_dim,
_parse_dtype,
_parse_usm_type,
)

Expand All @@ -24,21 +35,153 @@
)


def _parse_dtype(a):
if isinstance(a.dtype, Complex):
v_type = a.dtype
w_type = dpnp.float64 if a.dtype.bitwidth == 128 else dpnp.float32
elif isinstance(a.dtype, Float):
v_type = w_type = a.dtype
elif isinstance(a.dtype, Integer):
v_type = w_type = (
dpnp.float32 if a.dtype.bitwidth == 32 else dpnp.float64
)
# elif a.queue.sycl_device.has_aspect_fp64:
# v_type = w_type = dpnp.float64
def _parse_dtype_from_range(start, stop, step):
    """Infer the Numba dtype of a ``dpnp.arange`` result from its bounds.

    The widest ``bitwidth`` among the three arguments selects the precision,
    and the "highest" numeric kind present (complex, then float, then
    integer) selects the family.

    Args:
        start: Numba type of the range start, or a plain Python scalar.
        stop: Numba type of the range stop, or a plain Python scalar.
        step: Numba type of the range step, or a plain Python scalar.

    Returns:
        The Numba type obtained from ``numba.from_dtype`` for the selected
        ``dpnp`` dtype.

    Raises:
        errors.NumbaValueError: If none of the arguments is a Numba
            ``Complex``, ``Float`` or ``Integer`` type.
    """
    # The overload replaces omitted ``start``/``step`` with the plain Python
    # values 0 and 1, which carry no ``bitwidth`` attribute. Map any such
    # value to its Numba type first so the checks below see uniform input.
    operands = tuple(
        arg if isinstance(arg, types.Type) else numba.typeof(arg)
        for arg in (start, stop, step)
    )
    max_bw = max(operand.bitwidth for operand in operands)

    if any(isinstance(operand, Complex) for operand in operands):
        if max_bw == 128:
            return numba.from_dtype(dpnp.complex128)
        return numba.from_dtype(dpnp.complex64)

    if any(isinstance(operand, Float) for operand in operands):
        if max_bw == 64:
            return numba.from_dtype(dpnp.float64)
        if max_bw == 32:
            return numba.from_dtype(dpnp.float32)
        if max_bw == 16:
            return numba.from_dtype(dpnp.float16)
        return numba.from_dtype(dpnp.float)

    if any(isinstance(operand, Integer) for operand in operands):
        if max_bw == 64:
            return numba.from_dtype(dpnp.int64)
        if max_bw == 32:
            return numba.from_dtype(dpnp.int32)
        return numba.from_dtype(dpnp.int)

    msg = "Type couldn't be inferred from (start, stop, step)."
    raise errors.NumbaValueError(msg)


def _codegen_int_arange_length(context, builder, arg_types, arg_values):
    """Emit IR computing the element count of an integer ``arange``.

    Computes ``ceil((stop - start) / step)`` clamped to zero, using signed
    integer arithmetic on ``intp`` values.

    Args:
        context: Numba target context used for casts and constants.
        builder: llvmlite ``IRBuilder`` positioned where the IR is emitted.
        arg_types: Numba types of ``(start, stop, step)``.
        arg_values: LLVM values of ``(start, stop, step)``.

    Returns:
        An ``intp`` LLVM value holding the number of elements.
    """
    intp = types.intp
    start, stop, step = (
        context.cast(builder, value, ty, intp)
        for ty, value in zip(arg_types, arg_values)
    )
    zero = context.get_constant(intp, 0)
    one = context.get_constant(intp, 1)

    step_is_zero = builder.icmp_signed("==", step, zero)
    step_is_positive = builder.icmp_signed(">", step, zero)
    # Never divide by zero: a zero step yields an empty array instead.
    safe_step = builder.select(step_is_zero, one, step)

    # ``sdiv`` truncates toward zero, so bias the numerator by
    # ``step - sign(step)`` to obtain a ceiling division.
    bias = builder.select(
        step_is_positive,
        builder.sub(safe_step, one),
        builder.add(safe_step, one),
    )
    span = builder.sub(stop, start)
    count = builder.sdiv(builder.add(span, bias), safe_step)

    # A range that runs away from ``stop`` produces a negative count.
    count = builder.select(builder.icmp_signed("<", count, zero), zero, count)
    return builder.select(step_is_zero, zero, count)


@intrinsic
def impl_dpnp_arange(
    ty_context,
    ty_start,
    ty_stop,
    ty_step,
    ty_dtype,
    ty_device,
    ty_usm_type,
    ty_sycl_queue,
    ty_ret_ty,
):
    """Intrinsic that allocates a 1-D array and fills it with a sequence.

    The generated code allocates an empty array of the return type and then
    calls the external function
    ``NUMBA_DPEX_SYCL_KERNEL_populate_arystruct_sequence`` with pointers to
    ``start`` and ``step``, the array struct, and the SYCL queue reference.

    Args:
        ty_context: Typing context supplied by ``@intrinsic``.
        ty_start: Numba type of the range start.
        ty_stop: Numba type of the range stop.
        ty_step: Numba type of the range step.
        ty_dtype: Numba type of the requested dtype argument.
        ty_device: Numba type of the device argument.
        ty_usm_type: Numba type of the USM type argument.
        ty_sycl_queue: Numba type of the SYCL queue argument; it must stay
            second to last because ``codegen`` reads it at index ``-2``.
        ty_ret_ty: Type reference whose ``instance_type`` is the return type.

    Returns:
        A ``(signature, codegen)`` pair as required by ``@intrinsic``.
    """
    ty_retty_ = ty_ret_ty.instance_type
    signature = ty_retty_(
        ty_start,
        ty_stop,
        ty_step,
        ty_dtype,
        ty_device,
        ty_usm_type,
        ty_sycl_queue,
        ty_ret_ty,
    )

    sycl_queue_arg_pos = -2

    def codegen(context, builder, sig, args):
        start_ir, stop_ir, step_ir, queue_ir = (
            args[0],
            args[1],
            args[2],
            args[sycl_queue_arg_pos],
        )
        range_types = (sig.args[0], sig.args[1], sig.args[2])
        queue_arg_type = sig.args[sycl_queue_arg_pos]

        u64 = llvmir.IntType(64)
        b = llvmir.IntType(1)
        mod = builder.module

        sycl_queue_arg = _ArgTyAndValue(queue_arg_type, queue_ir)
        qref_payload = _get_queue_ref(
            context=context,
            builder=builder,
            returned_sycl_queue_ty=sig.return_type.queue,
            sycl_queue_arg=sycl_queue_arg,
        )

        # ``start`` and ``step`` are handed to the runtime by address, so
        # spill them into stack slots allocated in the entry block.
        with builder.goto_entry_block():
            start_ptr = cgutils.alloca_once(builder, start_ir.type)
            step_ptr = cgutils.alloca_once(builder, step_ir.type)

        builder.store(start_ir, start_ptr)
        builder.store(step_ir, step_ptr)

        start_vptr = builder.bitcast(start_ptr, cgutils.voidptr_t)
        step_vptr = builder.bitcast(step_ptr, cgutils.voidptr_t)

        if all(isinstance(ty, types.Integer) for ty in range_types):
            # The element count must account for ``step``; ``stop - start``
            # alone over-allocates whenever ``step`` is not 1.
            length = _codegen_int_arange_length(
                context, builder, range_types, (start_ir, stop_ir, step_ir)
            )
        else:
            # NOTE(review): non-integer ranges keep the previous
            # ``stop - start`` expression; a step-aware length for floating
            # point arguments still needs to be implemented.
            length = builder.sub(stop_ir, start_ir)

        ary = _empty_nd_impl(
            context,
            builder,
            sig.return_type,
            [length],
            qref_payload.queue_ref,
        )
        arrystruct_vptr = builder.bitcast(ary._getpointer(), cgutils.voidptr_t)

        ndim = context.get_constant(types.intp, 1)
        is_c_contiguous = context.get_constant(types.boolean, 1)

        fnty = llvmir.FunctionType(
            utils.LLVMTypes.int64_ptr_t,
            [
                cgutils.voidptr_t,
                cgutils.voidptr_t,
                cgutils.voidptr_t,
                u64,
                b,
                cgutils.voidptr_t,
            ],
        )
        fn = cgutils.get_or_insert_function(
            mod, fnty, "NUMBA_DPEX_SYCL_KERNEL_populate_arystruct_sequence"
        )
        builder.call(
            fn,
            [
                start_vptr,
                step_vptr,
                arrystruct_vptr,
                ndim,
                is_c_contiguous,
                qref_payload.queue_ref,
            ],
        )

        return ary._getvalue()

    return signature, codegen


@overload(dpnp.arange, prefer_literal=True)
Expand All @@ -65,7 +208,15 @@ def ol_dpnp_arange(
start = 0
if step is None:
step = 1
_dtype = _parse_dtype(dtype) if dtype is not None else type(start)
print("start =", start, ", type(start) =", type(start))
print("stop =", stop, ", type(stop) =", type(stop))
print("step =", step, ", type(step) =", type(step))
print("-*-")
_dtype = (
_parse_dtype(dtype)
if dtype is not None
else _parse_dtype_from_range(start, stop, step)
)
_device = _parse_device_filter_string(device) if device else None
_usm_type = _parse_usm_type(usm_type) if usm_type else "device"

Expand Down Expand Up @@ -98,25 +249,16 @@ def impl(
usm_type="device",
sycl_queue=None,
):
print("start =", start, ", type(start) =", type(start))
print("stop =", stop, ", type(stop) =", type(stop))
print("step =", step, ", type(step) =", type(step))
print(
"dtype =", dtype
) # , ", type(dtype) =", type(dtype) if dtype is not None else "Null")
print(
"device =", device
) # , ", type(device) =", type(device) if device is not None else "Null")
print(
"usm_type =", usm_type
) # , ", type(usm_type) =", type(usm_type) if usm_type is not None else "Null")
print(
"sycl_queue =", sycl_queue
) # , ", type(sycl_queue) =", type(sycl_queue) if sycl_queue is not None else "Null")
print("###")

v = dpnp.empty(10)
return v
return impl_dpnp_arange(
start,
stop,
step,
_dtype,
_device,
_usm_type,
sycl_queue,
ret_ty,
)

return impl
else:
Expand Down

0 comments on commit 8fd6dbd

Please sign in to comment.