Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Static method SyclQueue._create_from_context_and_device change #579

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ per-file-ignores =
dpctl/program/_program.pyx: E999, E225, E226, E227
dpctl/tensor/_usmarray.pyx: E999, E225, E226, E227
dpctl/tensor/numpy_usm_shared.py: F821
dpctl/tests/_cython_api.pyx: E999, E225, E227, E402
examples/cython/sycl_buffer/_buffer_example.pyx: E999, E225, E402
examples/cython/sycl_direct_linkage/_buffer_example.pyx: E999, E225, E402
examples/cython/usm_memory/blackscholes.pyx: E999, E225, E226, E402
2 changes: 1 addition & 1 deletion .github/workflows/generate-coverage.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ jobs:
source /opt/intel/oneapi/setvars.sh
python setup.py develop --coverage=True
python -c "import dpctl; print(dpctl.__version__); dpctl.lsplatform()"
pytest -q -ra --disable-warnings --cov dpctl --cov-report term-missing --pyargs dpctl -vv
pytest -q -ra --disable-warnings --cov-config pyproject.toml --cov dpctl --cov-report term-missing --pyargs dpctl -vv

- name: Install coverall dependencies
shell: bash -l {0}
Expand Down
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ include dpctl/memory/_memory_api.h
include dpctl/tensor/_usmarray.h
include dpctl/tensor/_usmarray_api.h
include dpctl/tests/input_files/*
include dpctl/tests/*.pyx
1 change: 1 addition & 0 deletions conda-recipe/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ requirements:

test:
requires:
- cython
- pytest
- pytest-cov

Expand Down
2 changes: 1 addition & 1 deletion dpctl/_sycl_queue.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ cdef public api class SyclQueue (_SyclQueue) [
cdef SyclQueue _create(DPCTLSyclQueueRef qref)
@staticmethod
cdef SyclQueue _create_from_context_and_device(
SyclContext ctx, SyclDevice dev
SyclContext ctx, SyclDevice dev, int props=*
)
cdef cpp_bool equals(self, SyclQueue q)
cpdef SyclContext get_sycl_context(self)
Expand Down
15 changes: 13 additions & 2 deletions dpctl/_sycl_queue.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -537,13 +537,24 @@ cdef class SyclQueue(_SyclQueue):

@staticmethod
cdef SyclQueue _create_from_context_and_device(
SyclContext ctx, SyclDevice dev
SyclContext ctx, SyclDevice dev, int props=0
):
"""
Static factory method to create :class:`dpctl.SyclQueue` instance
from given :class:`dpctl.SyclContext`, :class:`dpctl.SyclDevice`
and optional integer `props` encoding the queue properties.
"""
cdef _SyclQueue ret = _SyclQueue.__new__(_SyclQueue)
cdef DPCTLSyclContextRef cref = ctx.get_context_ref()
cdef DPCTLSyclDeviceRef dref = dev.get_device_ref()
cdef DPCTLSyclQueueRef qref = DPCTLQueue_Create(cref, dref, NULL, 0)
cdef DPCTLSyclQueueRef qref = NULL

qref = DPCTLQueue_Create(
cref,
dref,
<error_handler_callback *>&default_async_error_handler,
props
)
if qref is NULL:
raise SyclQueueCreationError("Queue creation failed.")
ret._queue_ref = qref
Expand Down
18 changes: 18 additions & 0 deletions dpctl/tests/_cython_api.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# cython: language=c++
# cython: language_level=3

cimport dpctl as c_dpctl

import dpctl


def call_create_from_context_and_devices():
cdef c_dpctl.SyclQueue q
d = dpctl.SyclDevice()
ctx = dpctl.SyclContext(d)
# calling static method
q = c_dpctl.SyclQueue._create_from_context_and_device(
<c_dpctl.SyclContext> ctx,
<c_dpctl.SyclDevice> d
)
return q
12 changes: 12 additions & 0 deletions dpctl/tests/setup_cython_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import setuptools

import dpctl

ext = setuptools.Extension(
"_cython_api",
["_cython_api.pyx"],
include_dirs=[dpctl.get_include()],
language="c++",
)

setuptools.setup(name="_cython_api", version="0.0.0", ext_modules=[ext])
46 changes: 44 additions & 2 deletions dpctl/tests/test_sycl_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,15 +433,15 @@ def test_queue_submit_barrier(valid_filter):


def test_queue__repr__():
q1 = dpctl.SyclQueue()
q1 = dpctl.SyclQueue(property=0)
r1 = q1.__repr__()
q2 = dpctl.SyclQueue(property="in_order")
r2 = q2.__repr__()
q3 = dpctl.SyclQueue(property="enable_profiling")
r3 = q3.__repr__()
q4 = dpctl.SyclQueue(property="default")
r4 = q4.__repr__()
q5 = dpctl.SyclQueue(property=["in_order", "enable_profiling"])
q5 = dpctl.SyclQueue(property=["in_order", "enable_profiling", 0])
r5 = q5.__repr__()
assert type(r1) is str
assert type(r2) is str
Expand Down Expand Up @@ -552,3 +552,45 @@ def test_queue_memops():
q.prefetch(list(), 512)
with pytest.raises(TypeError):
q.mem_advise(list(), 512, 0)


@pytest.fixture(scope="session")
def dpctl_cython_extension(tmp_path_factory):
import os.path
import shutil
import subprocess
import sys
import sysconfig

curr_dir = os.path.dirname(__file__)
dr = tmp_path_factory.mktemp("_cython_api")
for fn in ["_cython_api.pyx", "setup_cython_api.py"]:
shutil.copy(
src=os.path.join(curr_dir, fn),
dst=dr,
follow_symlinks=False,
)
res = subprocess.run(
[sys.executable, "setup_cython_api.py", "build_ext", "--inplace"],
cwd=dr,
)
if res.returncode == 0:
import glob
from importlib.util import module_from_spec, spec_from_file_location

sfx = sysconfig.get_config_vars()["EXT_SUFFIX"]
pth = glob.glob(os.path.join(dr, "_cython_api*" + sfx))
if not pth:
pytest.skip("Cython extension was not built")
spec = spec_from_file_location("_cython_api", pth[0])
builder_module = module_from_spec(spec)
spec.loader.exec_module(builder_module)
return builder_module
else:
pytest.skip("Cython extension could not be built")


def test_cython_api(dpctl_cython_extension):
q = dpctl_cython_extension.call_create_from_context_and_devices()
d = dpctl.SyclDevice()
assert q.sycl_device == d
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ omit = [
"dpctl/_version.py",
]

[tool.coverage.report]
omit = [
"dpctl/tests/*",
"dpctl/_version.py",
]

[tool.pytest.ini.options]
minversion = "6.0"
norecursedirs= [
Expand Down