-
Notifications
You must be signed in to change notification settings - Fork 33
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
62 additions
and
81 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
# SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import pytest | ||
from numba import types | ||
from numba.core.datamodel import default_manager, models | ||
from numba.core.registry import cpu_target | ||
|
||
from numba_dpex.core.datamodel.models import ( | ||
USMArrayDeviceModel, | ||
USMArrayHostModel, | ||
dpex_data_model_manager, | ||
) | ||
from numba_dpex.core.types.dpnp_ndarray_type import DpnpNdArray, USMNdArray | ||
|
||
|
||
@pytest.fixture( | ||
params=[ | ||
array(ndim=ndim, dtype=types.float32, layout="C") | ||
for array in [DpnpNdArray, USMNdArray] | ||
for ndim in range(4) | ||
] | ||
) | ||
def nd_array(request): | ||
return request.param | ||
|
||
|
||
def test_model_for_array(nd_array): | ||
"""Test the datamodel for DpnpNdArray that is registered with numba's | ||
default datamodel manager and numba_dpex's kernel data model manager. | ||
""" | ||
device_model = dpex_data_model_manager.lookup(nd_array) | ||
assert isinstance(device_model, USMArrayDeviceModel) | ||
host_model = default_manager.lookup(nd_array) | ||
assert isinstance(host_model, USMArrayHostModel) | ||
|
||
|
||
def test_dpnp_usm_ndarray_model(): | ||
"""Test for dpctl.tensor.usm_ndarray model. | ||
It is a subclass of models.StructModel and models.ArrayModel. | ||
""" | ||
|
||
assert issubclass(USMArrayHostModel, models.StructModel) | ||
assert issubclass(USMArrayDeviceModel, models.StructModel) | ||
|
||
|
||
def test_flattened_member_count(nd_array): | ||
"""Test that the number of flattened member count matches the number of | ||
flattened args generated by the CpuTarget's ArgPacker. | ||
""" | ||
|
||
cputargetctx = cpu_target.target_context | ||
dpex_dmm = cputargetctx.data_model_manager | ||
|
||
argty_tuple = tuple([nd_array]) | ||
datamodel = dpex_dmm.lookup(nd_array) | ||
num_flattened_args = datamodel.flattened_field_count | ||
ap = cputargetctx.get_arg_packer(argty_tuple) | ||
|
||
assert num_flattened_args == len(ap._be_args) |