Skip to content

Commit

Permalink
work around IntelPython/numba-dpex#906 by moving bugged instruction p…
Browse files Browse the repository at this point in the history
…atterns to `dpex.func` functions
  • Loading branch information
fcharras committed Feb 27, 2023
1 parent 33ef2dc commit 563d59e
Show file tree
Hide file tree
Showing 9 changed files with 548 additions and 135 deletions.
374 changes: 286 additions & 88 deletions sklearn_numba_dpex/common/kernels.py

Large diffs are not rendered by default.

Empty file.
90 changes: 69 additions & 21 deletions sklearn_numba_dpex/kmeans/drivers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
make_argmin_reduction_1d_kernel,
make_broadcast_division_1d_2d_axis0_kernel,
make_broadcast_ops_1d_2d_axis1_kernel,
make_check_all_equal_kernel,
make_half_l2_norm_2d_axis0_kernel,
make_initialize_to_zeros_kernel,
make_sum_reduction_2d_kernel,
Expand Down Expand Up @@ -153,6 +154,10 @@ def lloyd(
dtype=compute_dtype,
)

check_all_equal_kernel = make_check_all_equal_kernel(
(n_samples,), max_work_group_size
)

# Allocate the necessary memory in the device global memory
new_centroids_t = dpt.empty_like(centroids_t, device=device)
centroids_half_l2_norm = dpt.empty(n_clusters, dtype=compute_dtype, device=device)
Expand All @@ -163,7 +168,18 @@ def lloyd(
sq_dist_to_nearest_centroid = per_sample_inertia = dpt.empty(
n_samples, dtype=compute_dtype, device=device
)
assignments_idx = dpt.empty(n_samples, dtype=np.uint32, device=device)

assignments_idx_new = dpt.empty(n_samples, dtype=np.uint32, device=device)
check_label_convergence = tol == 0
# See the main loop for a more elaborate note about checking "strict convergence"
if check_label_convergence:
assignments_idx = dpt.empty(n_samples, dtype=np.uint32, device=device)
# allocation of one scalar where we store the result of array comparisons
check_equal_result = dpt.empty(1, dtype=np.uint32, device=device)
label_convergence = False
else:
assignments_idx = assignments_idx_new

new_centroids_t_private_copies = dpt.empty(
(n_centroids_private_copies, n_features, n_clusters),
dtype=compute_dtype,
Expand Down Expand Up @@ -203,7 +219,7 @@ def lloyd(
centroids_t,
centroids_half_l2_norm,
# OUT:
assignments_idx,
assignments_idx_new,
new_centroids_t_private_copies,
cluster_sizes_private_copies,
)
Expand All @@ -226,7 +242,7 @@ def lloyd(
X_t,
sample_weight,
new_centroids_t,
assignments_idx,
assignments_idx_new,
# OUT:
per_sample_inertia,
)
Expand All @@ -249,7 +265,7 @@ def lloyd(
centroids_t,
centroids_half_l2_norm,
# OUT:
assignments_idx,
assignments_idx_new,
)

# if verbose is True and if sample_weight is uniform, distances to
Expand All @@ -262,7 +278,7 @@ def lloyd(
X_t,
dpt.ones_like(sample_weight),
centroids_t,
assignments_idx,
assignments_idx_new,
# OUT:
sq_dist_to_nearest_centroid,
)
Expand All @@ -273,7 +289,7 @@ def lloyd(
sample_weight,
new_centroids_t,
cluster_sizes,
assignments_idx,
assignments_idx_new,
empty_clusters_list,
sq_dist_to_nearest_centroid,
per_sample_inertia,
Expand All @@ -283,17 +299,6 @@ def lloyd(
# Change `new_centroids_t` inplace
broadcast_division_kernel(new_centroids_t, cluster_sizes)

compute_centroid_shifts_kernel(
centroids_t,
new_centroids_t,
# OUT:
centroid_shifts,
)

centroid_shifts_sum, *_ = reduce_centroid_shifts_kernel(centroid_shifts)
# Use numpy type to work around https://github.com/IntelPython/dpnp/issues/1238
centroid_shifts_sum = compute_dtype(centroid_shifts_sum)

# ???: unlike sklearn, sklearn_intelex checks that pseudo_inertia decreases
# and keep an additional copy of centroids that is updated only if the
# value of the pseudo_inertia is smaller than all the past values.
Expand All @@ -313,14 +318,57 @@ def lloyd(
# refers. For this reason, this strategy is not compatible with sklearn
# unit tests, that consider new_centroids_t (the array after the update)
# to be the best centroids at each iteration.

centroids_t, new_centroids_t = (new_centroids_t, centroids_t)

# ???: Also, if two successive assignations have been computed equal,
# scikit-learn decides that the algorithm has converged, and call it a strict
# convergence. However, isn't there a risk that we miss out additional
# iterations that could help compute better centroids ? Indeed, even if the
# assignments have converged, the centroids haven't yet.

# Moreover, the cost of checking equality of assignments at each iterations
# might not be negligible.

# Following this reasoning, unlike scikit-learn, we choose to enforce this
# behavior if and only if `tol == 0`, because in this case, it is easy to see
# that lloyd can indeed fail to stop due to numerical errors. Moreover, this is
# enough to pass scikit-learn unit tests. When `tol > 0`, we rely on the user
# setting an appropriate tolerance threshold.

# TODO: open an issue at `scikit-learn` and propose to adopt this behavior
# instead ?

if check_label_convergence:
assignments_idx, assignments_idx_new = (
assignments_idx_new,
assignments_idx,
)
check_all_equal_kernel(
assignments_idx_new, assignments_idx, check_equal_result
)
are_assignments_equal, *_ = check_equal_result
if are_assignments_equal:
label_convergence = True
break

compute_centroid_shifts_kernel(
centroids_t,
new_centroids_t,
# OUT:
centroid_shifts,
)

centroid_shifts_sum, *_ = reduce_centroid_shifts_kernel(centroid_shifts)
# Use numpy type to work around https://github.com/IntelPython/dpnp/issues/1238
centroid_shifts_sum = compute_dtype(centroid_shifts_sum)

n_iteration += 1

if verbose:
converged_at = n_iteration - 1
if centroid_shifts_sum == 0: # NB: possible if tol = 0
if (check_label_convergence and label_convergence) or (
centroid_shifts_sum == 0
): # NB: possible if tol = 0
print(f"Converged at iteration {converged_at}: strict convergence.")

elif centroid_shifts_sum <= tol:
Expand Down Expand Up @@ -478,7 +526,7 @@ def prepare_data_for_lloyd(X_t, init, tol, sample_weight, copy_x):
)

# At the time of writing this code, dpnp does not support functions (like `==`
# operator) that would help computing `sample_weight_is_uniform` in a simpler
# operator) that would help computing `X_mean_is_zeroed` in a simpler
# manner.
# TODO: if dpnp support extends to relevant features, use it instead ?
sum_of_squares_kernel = make_sum_reduction_2d_kernel(
Expand Down Expand Up @@ -539,7 +587,7 @@ def prepare_data_for_lloyd(X_t, init, tol, sample_weight, copy_x):

variance = variance_kernel(dpt.reshape(X_t, -1))
# Use numpy type to work around https://github.com/IntelPython/dpnp/issues/1238
tol = (dpt.asnumpy(variance)[0] / n_features) * tol
tol = (dpt.asnumpy(variance)[0] / (n_features * n_samples)) * tol

# check if sample_weight is uniform
# At the time of writing this code, dpnp does not support functions (like `==`
Expand Down
14 changes: 12 additions & 2 deletions sklearn_numba_dpex/kmeans/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,9 @@ def kmeans_single(self, X, sample_weight, centers_init_t):
# XXX: having a C-contiguous centroid array is expected in sklearn in some
# unit test and by the cython engine.
assignments_idx = dpt.asnumpy(assignments_idx).astype(np.int32)
best_centroids_t = np.asfortranarray(dpt.asnumpy(best_centroids_t))
best_centroids_t = np.asfortranarray(dpt.asnumpy(best_centroids_t)).astype(
self.estimator._output_dtype
)

# ???: rather that returning whatever dtype the driver returns (which might
# depends on device support for float64), shouldn't we cast to a dtype that
Expand Down Expand Up @@ -269,7 +271,9 @@ def get_euclidean_distances(self, X):
)
euclidean_distances = get_euclidean_distances(X.T, cluster_centers)
if self._is_in_testing_mode:
euclidean_distances = dpt.asnumpy(euclidean_distances)
euclidean_distances = dpt.asnumpy(euclidean_distances).astype(
self.estimator._output_dtype
)
return euclidean_distances

def _validate_data(self, X, reset=True):
Expand All @@ -291,6 +295,12 @@ def _validate_data(self, X, reset=True):
else:
accepted_dtypes = [np.float32]

if self._is_in_testing_mode and reset:
if (X_dtype := X.dtype) not in accepted_dtypes:
self.estimator._output_dtype = np.float64
else:
self.estimator._output_dtype = X_dtype

with _validate_with_array_api(device):
try:
X = self.estimator._validate_data(
Expand Down
35 changes: 28 additions & 7 deletions sklearn_numba_dpex/kmeans/kernels/compute_euclidean_distances.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,18 +118,39 @@ def compute_distances(

dpex.barrier(dpex.CLK_LOCAL_MEM_FENCE)

if sample_idx < n_samples:
for i in range(window_n_centroids):
centroid_idx = first_centroid_idx + i
if centroid_idx < n_clusters:
euclidean_distances_t[first_centroid_idx + i, sample_idx] = (
math.sqrt(sq_distances[i])
)
_save_distance(
sample_idx,
first_centroid_idx,
euclidean_distances_t,
# OUT
sq_distances
)

first_centroid_idx += window_n_centroids

dpex.barrier(dpex.CLK_LOCAL_MEM_FENCE)

# HACK 906: see sklearn_numba_dpex.patches.tests.test_patches.test_hack_906
@dpex.func
# fmt: off
def _save_distance(
sample_idx, # PARAM
first_centroid_idx, # PARAM
euclidean_distances_t, # IN
sq_distances # OUT
):
# fmt: on
if sample_idx >= n_samples:
return

for i in range(window_n_centroids):
centroid_idx = first_centroid_idx + i

if centroid_idx < n_clusters:
euclidean_distances_t[centroid_idx, sample_idx] = (
math.sqrt(sq_distances[i])
)

n_windows_for_sample = math.ceil(n_samples / window_n_centroids)

global_size = (
Expand Down
17 changes: 13 additions & 4 deletions sklearn_numba_dpex/kmeans/kernels/compute_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,11 +161,20 @@ def assignment(

dpex.barrier(dpex.CLK_LOCAL_MEM_FENCE)

# No update step, only store min_idx in the output array
if sample_idx >= n_samples:
return
_setitem_if(
sample_idx < n_samples,
sample_idx,
min_idx,
# OUT
assignments_idx,
)

assignments_idx[sample_idx] = min_idx
# HACK 906: see sklearn_numba_dpex.patches.tests.test_patches.test_hack_906
@dpex.func
def _setitem_if(condition, index, value, array):
if condition:
array[index] = value
return condition

n_windows_for_sample = math.ceil(n_samples / window_n_centroids)

Expand Down
45 changes: 32 additions & 13 deletions sklearn_numba_dpex/kmeans/kernels/kmeans_plusplus.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,24 +216,43 @@ def kmeansplusplus_single_step(

dpex.barrier(dpex.CLK_LOCAL_MEM_FENCE)

if sample_idx < n_samples:
sample_weight_ = sample_weight[sample_idx]
closest_dist_sq_ = closest_dist_sq[sample_idx]
for i in range(window_n_candidates):
candidate_idx = first_candidate_idx + i
if candidate_idx < n_candidates:
sq_distance_i = min(
sq_distances[i] * sample_weight_,
closest_dist_sq_
)
sq_distances_t[first_candidate_idx + i, sample_idx] = (
sq_distance_i
)
_save_sq_distances(
sample_idx,
first_candidate_idx,
sq_distances,
sample_weight,
closest_dist_sq,
# OUT
sq_distances_t
)

first_candidate_idx += window_n_candidates

dpex.barrier(dpex.CLK_LOCAL_MEM_FENCE)

# HACK 906: see sklearn_numba_dpex.patches.tests.test_patches.test_hack_906
@dpex.func
# fmt: off
def _save_sq_distances(
sample_idx, # PARAM
first_candidate_idx, # PARAM
sq_distances, # IN
sample_weight, # IN
closest_dist_sq, # IN
sq_distances_t, # OUT
):
# fmt: on
if sample_idx >= n_samples:
return

sample_weight_ = sample_weight[sample_idx]
closest_dist_sq_ = closest_dist_sq[sample_idx]
for i in range(window_n_candidates):
candidate_idx = first_candidate_idx + i
if candidate_idx < n_candidates:
sq_distance_i = min(sq_distances[i] * sample_weight_, closest_dist_sq_)
sq_distances_t[first_candidate_idx + i, sample_idx] = sq_distance_i

n_windows_for_samples = math.ceil(n_samples / window_n_candidates)

global_size = (
Expand Down
28 changes: 28 additions & 0 deletions sklearn_numba_dpex/kmeans/kernels/lloyd_single_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,33 @@ def fused_lloyd_single_step(
# End of outer loop. By now min_idx and min_sample_pseudo_inertia
# contains the expected values.

_update_result_data(
sample_idx,
min_idx,
sub_group_idx,
X_t,
sample_weight,
# OUT
assignments_idx,
cluster_sizes_private_copies,
new_centroids_t_private_copies,
)

# HACK 906: see sklearn_numba_dpex.patches.tests.test_patches.test_hack_906
@dpex.func
# fmt: off
def _update_result_data(
sample_idx, # PARAM
min_idx, # PARAM
sub_group_idx, # PARAM
X_t, # IN
sample_weight, # IN
assignments_idx, # OUT
cluster_sizes_private_copies, # OUT
new_centroids_t_private_copies, # OUT
):
# fmt: on

# NB: this check can't be moved at the top at the kernel, because if a work item
# exits early with a `return` it will never reach the barriers, thus causing a
# deadlock. Early returns are only possible when there are no barriers within
Expand Down Expand Up @@ -384,6 +411,7 @@ def fused_lloyd_single_step(
(privatization_idx, feature_idx, min_idx),
X_t[feature_idx, sample_idx] * weight,
)

return (
n_centroids_private_copies,
fused_lloyd_single_step[global_size, work_group_shape],
Expand Down
Loading

0 comments on commit 563d59e

Please sign in to comment.