Skip to content

Commit

Permalink
Add numba integration tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ZzEeKkAa committed Jan 4, 2024
1 parent 2a66424 commit d95df65
Showing 1 changed file with 40 additions and 0 deletions.
40 changes: 40 additions & 0 deletions numba_dpex/tests/test_numba.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# SPDX-FileCopyrightText: 2024 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

"""
This module contains tests to ensure that numba.njit works with numpy after
importing numba_dpex. Aka lazy testing if we break numba's default behavior.
"""

import numba as nb
import numpy as np

import numba_dpex


@nb.njit
def add_1(a):
return a + 1


def add_py(a, b):
return np.add(a, b)


add_jit = nb.njit(add_py)


def test_add1():
a = np.asarray([1j])
assert np.array_equal(a, np.asarray([1 + 1j]))


def test_add_py():
a = np.ones((10,), dtype=np.complex128)
assert np.array_equal(add_py(a, 1.5), np.full((10,), 2.5, dtype=a.dtype))


def test_add_jit():
a = np.ones((10,), dtype=np.complex128)
assert np.array_equal(add_jit(a, 1.5), np.full((10,), 2.5, dtype=a.dtype))

0 comments on commit d95df65

Please sign in to comment.