Skip to content

Commit

Permalink
changes to test_atomic_op to make it CFD compliant
Browse files Browse the repository at this point in the history
  • Loading branch information
adarshyoga authored and Diptorup Deb committed Apr 21, 2023
1 parent 0915170 commit aba4385
Showing 1 changed file with 12 additions and 17 deletions.
29 changes: 12 additions & 17 deletions numba_dpex/tests/kernel_tests/test_atomic_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

import dpctl
import numpy as np
import dpnp as np
import pytest

import numba_dpex as dpex
Expand Down Expand Up @@ -38,8 +38,11 @@ def fdtype(request):

@pytest.fixture(params=list_of_i_dtypes + list_of_f_dtypes)
def input_arrays(request):
a = np.array([0], request.param)
return a, request.param
def _inpute_arrays(filter_str):
a = np.array([0], request.param, device=filter_str)
return a, request.param

return _inpute_arrays


list_of_op = [
Expand Down Expand Up @@ -72,11 +75,9 @@ def f(a):
@pytest.mark.parametrize("filter_str", filter_strings)
@skip_no_atomic_support
def test_kernel_atomic_simple(filter_str, input_arrays, kernel_result_pair):
a, dtype = input_arrays
a, dtype = input_arrays(filter_str)
kernel, expected = kernel_result_pair
device = dpctl.SyclDevice(filter_str)
with dpctl.device_context(device):
kernel[global_size, dpex.DEFAULT_LOCAL_SIZE](a)
kernel[dpex.Range(global_size)](a)
assert a[0] == expected


Expand Down Expand Up @@ -114,15 +115,11 @@ def f(a):
@pytest.mark.parametrize("filter_str", filter_strings)
@skip_no_atomic_support
def test_kernel_atomic_local(filter_str, input_arrays, return_list_of_op):
a, dtype = input_arrays
a, dtype = input_arrays(filter_str)
op_type, expected = return_list_of_op
f = get_func_local(op_type, dtype)
kernel = dpex.kernel(f)
device = dpctl.SyclDevice(filter_str)
with dpctl.device_context(device):
gs = (N,)
ls = (N,)
kernel[gs, ls](a)
kernel[dpex.Range(N), dpex.Range(N)](a)
assert a[0] == expected


Expand Down Expand Up @@ -161,10 +158,8 @@ def test_kernel_atomic_multi_dim(
op_type, expected = return_list_of_op
dim = return_list_of_dim
kernel = get_kernel_multi_dim(op_type, len(dim))
a = np.zeros(dim, return_dtype)
device = dpctl.SyclDevice(filter_str)
with dpctl.device_context(device):
kernel[global_size, dpex.DEFAULT_LOCAL_SIZE](a)
a = np.zeros(dim, dtype=return_dtype, device=filter_str)
kernel[dpex.Range(global_size)](a)
assert a[0] == expected


Expand Down

0 comments on commit aba4385

Please sign in to comment.