Skip to content

Commit

Permalink
Remove local arrays in reduction remainder kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
ZzEeKkAa committed Apr 3, 2024
1 parent 5cbd2d4 commit 62fe934
Showing 1 changed file with 6 additions and 18 deletions.
24 changes: 6 additions & 18 deletions numba_dpex/core/parfors/kernel_templates/reduction_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,10 +278,13 @@ def _generate_kernel_stub_as_string(self):
)

for redvar in self._redvars:
rtyp = str(self._typemap[redvar])
legal_redvar = self._redvars_dict[redvar]
gufunc_txt += " "
gufunc_txt += legal_redvar + " = "
gufunc_txt += f"{self._parfor_reddict[redvar].init_val}\n"
gufunc_txt += (
f"dpnp.{rtyp}({self._parfor_reddict[redvar].init_val})\n"
)

gufunc_txt += (
" "
Expand All @@ -290,32 +293,17 @@ def _generate_kernel_stub_as_string(self):
+ f"{self._global_size_var_name[0]} + j\n"
)

for redvar in self._redvars:
rtyp = str(self._typemap[redvar])
redvar = self._redvars_dict[redvar]
gufunc_txt += (
" "
+ f"local_sums_{redvar} = "
+ f"dpex.local.array(1, dpnp.{rtyp})\n"
)

gufunc_txt += " " + self._sentinel_name + " = 0\n"

for i, redvar in enumerate(self._redvars):
legal_redvar = self._redvars_dict[redvar]
gufunc_txt += (
" " + f"local_sums_{legal_redvar}[0] = {legal_redvar}\n"
)

for i, redvar in enumerate(self._redvars):
legal_redvar = self._redvars_dict[redvar]
redop = self._parfor_reddict[redvar].redop
if redop == operator.iadd:
gufunc_txt += f" {self._final_sum_var_name[i]}[0] += \
local_sums_{legal_redvar}[0]\n"
{legal_redvar}\n"
elif redop == operator.imul:
gufunc_txt += f" {self._final_sum_var_name[i]}[0] *= \
local_sums_{legal_redvar}[0]\n"
{legal_redvar}\n"
else:
raise NotImplementedError

Expand Down

0 comments on commit 62fe934

Please sign in to comment.