Skip to content

Commit

Permalink
Fix unit test segfault on OCL CPU device, rename test file.
Browse files Browse the repository at this point in the history
  • Loading branch information
Diptorup Deb committed Jan 27, 2024
1 parent 5a3a087 commit f23f358
Showing 1 changed file with 17 additions and 26 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# SPDX-FileCopyrightText: 2023 - 2024 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

import dpnp
import pytest
from numba.core.errors import TypingError
Expand All @@ -11,37 +15,26 @@
no_bool=True, no_float16=True, no_none=True, no_complex=True
)

list_of_cmp_exchg_funcs = [
"compare_exchange_weak",
"compare_exchange_strong",
]


@pytest.fixture(params=list_of_cmp_exchg_funcs)
def cmp_exchg_fn(request):
@pytest.fixture(params=["store", "exchange"])
def store_exchange_fn(request):
return request.param


@pytest.fixture(params=list_of_supported_dtypes)
def input_arrays(request):
# The size of input and out arrays to be used
N = 10
a = dpnp.zeros(2 * N, dtype=request.param)
b = dpnp.arange(N, dtype=request.param)
return a, b


def test_load_store_fn(input_arrays):
def test_load_store_fn():
"""A test for load/store atomic functions."""

@dpex_exp.kernel
def _kernel(a, b):
i = dpex.get_global_id(0)
a_ref = AtomicRef(a, index=i)
b_ref = AtomicRef(b, index=i)
a_ref.store(b_ref.load())
val = b_ref.load()
a_ref.store(val)

a, b = input_arrays
N = 10
a = dpnp.zeros(2 * N, dtype=dpnp.float32)
b = dpnp.arange(N, dtype=dpnp.float32)

dpex_exp.call_kernel(_kernel, dpex.Range(b.size), a, b)
# Verify that `b[i]` loaded and stored into a[i] by kernel
Expand All @@ -55,7 +48,7 @@ def _kernel(a, b):
assert a[i] == a[i + b.size]


def test_exchange_fn(input_arrays):
def test_exchange_fn():
"""A test for exchange atomic function."""

@dpex_exp.kernel
Expand All @@ -64,7 +57,10 @@ def _kernel(a, b):
v = AtomicRef(a, index=i)
b[i] = v.exchange(b[i])

a_orig, b_orig = input_arrays
N = 10
a_orig = dpnp.zeros(2 * N, dtype=dpnp.float32)
b_orig = dpnp.arange(N, dtype=dpnp.float32)

a_copy = dpnp.copy(a_orig)
b_copy = dpnp.copy(b_orig)

Expand All @@ -79,11 +75,6 @@ def _kernel(a, b):
assert b_copy[i] == a_orig[i]


@pytest.fixture(params=["store", "exchange"])
def store_exchange_fn(request):
return request.param


def test_store_exchange_diff_types(store_exchange_fn):
"""A negative test that verifies that a TypingError is raised if
AtomicRef type and value are of different types.
Expand Down

0 comments on commit f23f358

Please sign in to comment.