Skip to content

Commit

Permalink
Changes to _set_functions_sync.py for test suite to pass
Browse files: browse the repository at this point in the history
Changes to the synchronous implementation of the set functions so that
the test suite passes with it, just as it does with the asynchronous
implementation.
  • Loading branch information
oleksandr-pavlyk committed Jan 8, 2024
1 parent 94a2ebf commit de20631
Showing 1 changed file with 40 additions and 15 deletions.
55 changes: 40 additions & 15 deletions dpctl/tensor/_set_functions_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,14 @@ def unique_values(x: dpt.usm_ndarray) -> dpt.usm_ndarray:
if x.ndim == 1:
fx = x
else:
fx = dpt.reshape(x, (x.size,), order="C", copy=False)
fx = dpt.reshape(x, (x.size,), order="C")
if fx.size == 0:
return fx
s = dpt.sort(fx)
unique_mask = dpt.empty(fx.shape, dtype="?", sycl_queue=exec_q)
dpt.not_equal(s[:-1], s[1:], out=unique_mask[1:])
unique_mask[0] = True
cumsum = dpt.empty(s.shape, dtype=dpt.int64)
cumsum = dpt.empty(s.shape, dtype=dpt.int64, sycl_queue=exec_q)
n_uniques = mask_positions(unique_mask, cumsum, sycl_queue=exec_q)
if n_uniques == fx.size:
return s
Expand Down Expand Up @@ -127,13 +129,15 @@ def unique_counts(x: dpt.usm_ndarray) -> UniqueCountsResult:
if x.ndim == 1:
fx = x
else:
fx = dpt.reshape(x, (x.size,), order="C", copy=False)
s = dpt.sort(x)
fx = dpt.reshape(x, (x.size,), order="C")
ind_dt = default_device_index_type(exec_q)
if fx.size == 0:
return UniqueCountsResult(fx, dpt.empty_like(fx, dtype=ind_dt))
s = dpt.sort(fx)
unique_mask = dpt.empty(s.shape, dtype="?", sycl_queue=exec_q)
dpt.not_equal(s[:-1], s[1:], out=unique_mask[1:])
unique_mask[0] = True
ind_dt = default_device_index_type(exec_q)
cumsum = dpt.empty(unique_mask.shape, dtype=dpt.int64)
cumsum = dpt.empty(unique_mask.shape, dtype=dpt.int64, sycl_queue=exec_q)
# synchronizing call
n_uniques = mask_positions(unique_mask, cumsum, sycl_queue=exec_q)
if n_uniques == fx.size:
Expand Down Expand Up @@ -195,18 +199,20 @@ def unique_inverse(x):
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
array_api_dev = x.device
exec_q = array_api_dev.sycl_queue
ind_dt = default_device_index_type(exec_q)
if x.ndim == 1:
fx = x
else:
fx = dpt.reshape(x, (x.size,), order="C", copy=False)
ind_dt = default_device_index_type(exec_q)
fx = dpt.reshape(x, (x.size,), order="C")
sorting_ids = dpt.argsort(fx)
unsorting_ids = dpt.argsort(sorting_ids)
if fx.size == 0:
return UniqueInverseResult(fx, dpt.reshape(unsorting_ids, x.shape))
s = fx[sorting_ids]
unique_mask = dpt.empty(fx.shape, dtype="?", sycl_queue=exec_q)
unique_mask[0] = True
dpt.not_equal(s[:-1], s[1:], out=unique_mask[1:])
cumsum = dpt.empty(unique_mask.shape, dtype=dpt.int64)
cumsum = dpt.empty(unique_mask.shape, dtype=dpt.int64, sycl_queue=exec_q)
# synchronizing call
n_uniques = mask_positions(unique_mask, cumsum, sycl_queue=exec_q)
if n_uniques == fx.size:
Expand Down Expand Up @@ -251,7 +257,9 @@ def unique_inverse(x):
ht_ev, _ = _full_usm_ndarray(fill_value=i, dst=_dst, sycl_queue=exec_q)
ht_ev.wait()
pos = pos_next
return UniqueInverseResult(unique_vals, inv[unsorting_ids])
return UniqueInverseResult(
unique_vals, dpt.reshape(inv[unsorting_ids], x.shape)
)


def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult:
Expand Down Expand Up @@ -289,22 +297,39 @@ def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult:
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
array_api_dev = x.device
exec_q = array_api_dev.sycl_queue
ind_dt = default_device_index_type(exec_q)
if x.ndim == 1:
fx = x
else:
fx = dpt.reshape(x, (x.size,), order="C", copy=False)
ind_dt = default_device_index_type(exec_q)
fx = dpt.reshape(x, (x.size,), order="C")
sorting_ids = dpt.argsort(fx)
unsorting_ids = dpt.argsort(sorting_ids)
if fx.size == 0:
# original array contains no data
# so it can be safely returned as values
return UniqueAllResult(
fx,
sorting_ids,
dpt.reshape(unsorting_ids, x.shape),
dpt.empty_like(fx, dtype=ind_dt),
)
s = fx[sorting_ids]
unique_mask = dpt.empty(fx.shape, dtype="?", sycl_queue=exec_q)
dpt.not_equal(s[:-1], s[1:], out=unique_mask[1:])
unique_mask[0] = True
cumsum = dpt.empty(unique_mask.shape, dtype=dpt.int64)
cumsum = dpt.empty(unique_mask.shape, dtype=dpt.int64, sycl_queue=exec_q)
# synchronizing call
n_uniques = mask_positions(unique_mask, cumsum, sycl_queue=exec_q)
if n_uniques == fx.size:
return UniqueInverseResult(s, unsorting_ids)
_counts = dpt.ones(
n_uniques, dtype=ind_dt, usm_type=x.usm_type, sycl_queue=exec_q
)
return UniqueAllResult(
s,
sorting_ids,
dpt.reshape(unsorting_ids, x.shape),
_counts,
)
unique_vals = dpt.empty(
n_uniques, dtype=x.dtype, usm_type=x.usm_type, sycl_queue=exec_q
)
Expand Down Expand Up @@ -346,6 +371,6 @@ def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult:
return UniqueAllResult(
unique_vals,
sorting_ids[cum_unique_counts[:-1]],
inv[unsorting_ids],
dpt.reshape(inv[unsorting_ids], x.shape),
_counts,
)

0 comments on commit de20631

Please sign in to comment.