Skip to content

Commit

Permalink
Unit test checking if FlagEnum values are lowered as constants in LLV…
Browse files Browse the repository at this point in the history
…M IR.
  • Loading branch information
Diptorup Deb committed Dec 8, 2023
1 parent 9ce9927 commit ec3e962
Showing 1 changed file with 54 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# SPDX-FileCopyrightText: 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

import re

import dpctl
from numba.core import types

import numba_dpex.experimental as exp_dpex
from numba_dpex import DpctlSyclQueue, DpnpNdArray, int64
from numba_dpex.experimental.flag_enum import FlagEnum


class MockFlags(FlagEnum):
FLAG1 = 1
FLAG2 = 2


def test_compilation_as_literal_constant():
"""Tests if FlagEnum objects are treaded as scalar constants inside
numba-dpex generated code.
The test case compiles the kernel `pass_flags_to_func` that includes a
call to the device_func `bitwise_or_flags`. The `bitwise_or_flags` function
is passed two FlagEnum arguments. The test case evaluates the generated
LLVM IR for `pass_flags_to_func` to see if the call to `bitwise_or_flags`
has the scalar arguments `i64 1` and `i64 2`.
"""

@exp_dpex.device_func
def bitwise_or_flags(flag1, flag2):
return flag1 | flag2

def pass_flags_to_func(a):
f1 = MockFlags.FLAG1
f2 = MockFlags.FLAG2
a[0] = bitwise_or_flags(f1, f2)

queue_ty = DpctlSyclQueue(dpctl.SyclQueue())
i64arr_ty = DpnpNdArray(ndim=1, dtype=int64, layout="C", queue=queue_ty)
kernel_sig = types.void(i64arr_ty)

disp = exp_dpex.kernel(pass_flags_to_func)
disp.compile(kernel_sig)
kcres = disp.overloads[kernel_sig.args]
llvm_ir_mod = kcres.library._final_module.__str__()

pattern = re.compile(
r"call spir_func i32 @\_Z.*bitwise\_or"
r"\_flags.*\(i64\* nonnull %.*, i64 1, i64 2\)"
)

assert re.search(pattern, llvm_ir_mod) is not None

0 comments on commit ec3e962

Please sign in to comment.