Add support to dlpack arrays as kernel and dpjit arguments #1088

Open
ZzEeKkAa opened this issue Jul 7, 2023 · 0 comments
Labels: enhancement (New feature or request)

ZzEeKkAa commented Jul 7, 2023

The following example fails because numba-dpex cannot determine a Numba type for a JAX array passed to a `dpjit` function.

Command to run:

```shell
LD_LIBRARY_PATH=$LD_LIBRARY_PATH:<python environment folder>/lib/python3.10/site-packages/jaxlib \
PJRT_NAMES_AND_LIBRARY_PATHS='xpu:<path to libitex_xla_extension with jax support>/libitex_xla_extension.so' \
TF_CPP_MIN_LOG_LEVEL=0 \
ONEAPI_DEVICE_SELECTOR=ext_oneapi_level_zero:gpu \
python example.py
```

`example.py`:

```python
import jax.numpy as jnp
import numba as nb
from numba import prange
from numba_dpex import dpjit


@dpjit
def _sum_nomask(nops, w):
    # Parallel reduction with prange; only positive entries of w are accumulated.
    tot = nb.float32(1.0)

    for i in prange(nops):
        if w[i] > 0:
            tot += w[i]

    return tot


if __name__ == "__main__":
    arr = jnp.arange(10, dtype=jnp.float32)
    res = jnp.zeros(2, dtype=jnp.float32)

    # JAX arrays are immutable, so write the result with .at[...].set()
    # instead of item assignment (res[1] = ...).
    # This call currently fails: dpjit cannot type the JAX array argument.
    res = res.at[1].set(_sum_nomask(10, arr))
    print("jax:", res)
```

The output looks like this:

```
/Projects/users.yevhenii/examples/jax/libitex_xla_extension.so' TF_CPP_MIN_LOG_LEVEL=0 ONEAPI_DEVICE_SELECTOR=ext_oneapi_level_zero:gpu python numba_dpex_jax.py
dpex: [45.  0.]
2023-07-06 22:08:29.180337: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:169] XLA service 0x55904aa17ba0 initialized for platform Interpreter (this does not guarantee that XLA will be used). Devices:
2023-07-06 22:08:29.180371: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:177]   StreamExecutor device (0): Interpreter, <undefined>
2023-07-06 22:08:29.186010: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.cc:215] TfrtCpuClient created.
2023-07-06 22:08:29.186791: I external/org_tensorflow/tensorflow/compiler/xla/stream_executor/tpu/tpu_initializer_helper.cc:266] Libtpu path is: libtpu.so
2023-07-06 22:08:29.186938: I external/org_tensorflow/tensorflow/compiler/xla/stream_executor/tpu/tpu_platform_interface.cc:73] No TPU platform found.
2023-07-06 22:08:29.241398: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_api.cc:85] GetPjrtApi was found for xpu at /home/yevhenii/Projects/users.yevhenii/examples/jax/libitex_xla_extension.so
2023-07-06 22:08:29.241443: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_api.cc:58] PJRT_Api is set for device type xpu
2023-07-06 22:08:29.242084: I itex/core/devices/gpu/itex_gpu_runtime.cc:129] Selected platform: Intel(R) Level-Zero
2023-07-06 22:08:29.242429: I itex/core/devices/gpu/itex_gpu_runtime.cc:154] number of sub-devices is zero, expose root device.
2023-07-06 22:08:29.248729: I itex/core/compiler/xla/service/service.cc:176] XLA service 0x55904c85d4c0 initialized for platform sycl (this does not guarantee that XLA will be used). Devices:
2023-07-06 22:08:29.248766: I itex/core/compiler/xla/service/service.cc:184]   StreamExecutor device (0): <undefined>, <undefined>
2023-07-06 22:08:29.250558: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.cc:83] PjRtCApiClient created.
Traceback (most recent call last):
  File "/home/yevhenii/Projects/users.yevhenii/examples/numba-dpex-jax/numba_dpex_jax.py", line 49, in <module>
    res[1] = _sum_nomask(10, arr)
  File "/home/yevhenii/.local/share/virtualenvs/numba-dpex-x1V09ZPr/lib/python3.10/site-packages/numba/core/dispatcher.py", line 468, in _compile_for_args
    error_rewrite(e, 'typing')
  File "/home/yevhenii/.local/share/virtualenvs/numba-dpex-x1V09ZPr/lib/python3.10/site-packages/numba/core/dispatcher.py", line 409, in error_rewrite
    raise e.with_traceback(None)
numba.core.errors.TypingError: Failed in dpex_dpjit_nopython mode pipeline (step: nopython frontend)
non-precise type pyobject
During: typing of argument at /home/yevhenii/Projects/users.yevhenii/examples/numba-dpex-jax/numba_dpex_jax.py (15)

File "numba_dpex_jax.py", line 15:

@dpjit
^

This error may have been caused by the following argument(s):
- argument 1: Cannot determine Numba type of <class 'jaxlib.xla_extension.Array'>

2023-07-06 22:08:30.365912: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.cc:218] TfrtCpuClient destroyed.
```

`dpjit` functions and kernels should accept any array argument that implements the `__dlpack__()` method. Such an array can be converted to a dpnp array with `dpnp.from_dlpack(arr)`, and the resulting dpnp array can then be passed to `dpjit` functions.
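A minimal sketch of the proposed argument handling, under several assumptions: `coerce_dlpack_args` and `ForeignArray` are hypothetical names, not part of numba-dpex; NumPy's `np.from_dlpack` stands in for `dpnp.from_dlpack` so the sketch runs on a CPU-only machine (it needs a NumPy release that provides `np.from_dlpack`); and the device filter uses the DLPack device type codes from `dlpack.h` (`kDLCPU = 1`, `kDLOneAPI = 14`). The idea is to convert any non-native argument that exposes `__dlpack__` before typing, so dispatch only ever sees native arrays.

```python
import numpy as np

# DLPack device type codes from dlpack.h (only the two used here).
KDL_CPU = 1
KDL_ONEAPI = 14


def coerce_dlpack_args(args, from_dlpack, native_types, allowed_devices=(KDL_ONEAPI,)):
    """Hypothetical helper: convert DLPack-capable arguments before dispatch.

    Arguments that are already native arrays, or that do not implement the
    DLPack protocol (e.g. plain ints), are passed through unchanged.
    """
    converted = []
    for arg in args:
        is_dlpack = hasattr(arg, "__dlpack__") and hasattr(arg, "__dlpack_device__")
        if isinstance(arg, native_types) or not is_dlpack:
            converted.append(arg)
            continue
        # Reject arrays living on a device the target cannot use.
        device_type, _device_id = arg.__dlpack_device__()
        if device_type not in allowed_devices:
            raise TypeError(
                f"{type(arg).__name__} is on DLPack device type {int(device_type)}, "
                f"expected one of {tuple(allowed_devices)}"
            )
        # Import through DLPack; in the real feature this would be dpnp.from_dlpack.
        converted.append(from_dlpack(arg))
    return tuple(converted)


class ForeignArray:
    """Hypothetical stand-in for a non-native array type (such as a JAX array)
    that only exposes the DLPack protocol, delegating to a wrapped NumPy array."""

    def __init__(self, data):
        self._data = data

    def __dlpack__(self, *args, **kwargs):
        return self._data.__dlpack__(*args, **kwargs)

    def __dlpack_device__(self):
        return self._data.__dlpack_device__()


host = np.arange(10, dtype=np.float32)
# CPU is allowed here only because np.from_dlpack is the stand-in converter.
args = coerce_dlpack_args(
    (10, ForeignArray(host)), np.from_dlpack, (np.ndarray,), allowed_devices=(KDL_CPU,)
)
print(type(args[1]).__name__, args[1].sum())
```

Checking `__dlpack_device__` before converting keeps host-resident arrays from being silently handed to code that expects SYCL device memory; whether `dpnp.from_dlpack` itself copies or rejects such arrays may depend on the dpnp/dpctl version, so the default here conservatively accepts only oneAPI devices.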

ZzEeKkAa self-assigned this Jul 7, 2023
diptorupd added this to the 0.22 milestone Dec 19, 2023
ZzEeKkAa changed the title from "Dlpack arrays are not supported as dpjit function arguments" to "Add support to dlpack arrays as kernel and dpjit arguments" Dec 20, 2023
diptorupd added the enhancement (New feature or request) label and removed the feature label Jan 22, 2024