Skip to content

Commit

Permalink
Organize array datamodel tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ZzEeKkAa committed Jan 23, 2024
1 parent 6c233ba commit e298fb3
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 81 deletions.
54 changes: 0 additions & 54 deletions numba_dpex/tests/core/types/DpnpNdArray/test_models.py

This file was deleted.

27 changes: 0 additions & 27 deletions numba_dpex/tests/core/types/USMNdArray/test_models.py

This file was deleted.

62 changes: 62 additions & 0 deletions numba_dpex/tests/core/types/test_array_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

import pytest
from numba import types
from numba.core.datamodel import default_manager, models
from numba.core.registry import cpu_target

from numba_dpex.core.datamodel.models import (
USMArrayDeviceModel,
USMArrayHostModel,
dpex_data_model_manager,
)
from numba_dpex.core.types.dpnp_ndarray_type import DpnpNdArray, USMNdArray


@pytest.fixture(
params=[
array(ndim=ndim, dtype=types.float32, layout="C")
for array in [DpnpNdArray, USMNdArray]
for ndim in range(4)
]
)
def nd_array(request):
return request.param


def test_model_for_array(nd_array):
"""Test the datamodel for DpnpNdArray that is registered with numba's
default datamodel manager and numba_dpex's kernel data model manager.
"""
device_model = dpex_data_model_manager.lookup(nd_array)
assert isinstance(device_model, USMArrayDeviceModel)
host_model = default_manager.lookup(nd_array)
assert isinstance(host_model, USMArrayHostModel)


def test_dpnp_usm_ndarray_model():
"""Test for dpctl.tensor.usm_ndarray model.
It is a subclass of models.StructModel and models.ArrayModel.
"""

assert issubclass(USMArrayHostModel, models.StructModel)
assert issubclass(USMArrayDeviceModel, models.StructModel)


def test_flattened_member_count(nd_array):
"""Test that the number of flattened member count matches the number of
flattened args generated by the CpuTarget's ArgPacker.
"""

cputargetctx = cpu_target.target_context
dpex_dmm = cputargetctx.data_model_manager

argty_tuple = tuple([nd_array])
datamodel = dpex_dmm.lookup(nd_array)
num_flattened_args = datamodel.flattened_field_count
ap = cputargetctx.get_arg_packer(argty_tuple)

assert num_flattened_args == len(ap._be_args)

0 comments on commit e298fb3

Please sign in to comment.