Skip to content

Commit

Permalink
Use current device resource instead of explicit managed memory resource.
Browse files Browse the repository at this point in the history
  • Loading branch information
csadorf committed Feb 4, 2025
1 parent 74a2b6c commit be4586d
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions cpp/include/raft/cluster/detail/kmeans_balanced.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -968,8 +968,7 @@ void build_hierarchical(const raft::resources& handle,
IdxT n_mesoclusters = std::min(n_clusters, static_cast<IdxT>(std::sqrt(n_clusters) + 0.5));
RAFT_LOG_DEBUG("build_hierarchical: n_mesoclusters: %u", n_mesoclusters);

// TODO: Remove the explicit managed memory- we shouldn't be creating this on the user's behalf.
rmm::mr::managed_memory_resource managed_memory;
rmm::device_async_resource_ref managed_memory = rmm::mr::get_current_device_resource();
rmm::device_async_resource_ref device_memory = resource::get_large_workspace_resource(handle);
auto [max_minibatch_size, mem_per_row] =
calc_minibatch_size<MathT>(n_clusters, n_rows, dim, params.metric, std::is_same_v<T, MathT>);
Expand Down Expand Up @@ -999,8 +998,8 @@ void build_hierarchical(const raft::resources& handle,
CounterT;

// build coarse clusters (mesoclusters)
rmm::device_uvector<LabelT> mesocluster_labels_buf(n_rows, stream, &managed_memory);
rmm::device_uvector<CounterT> mesocluster_sizes_buf(n_mesoclusters, stream, &managed_memory);
rmm::device_uvector<LabelT> mesocluster_labels_buf(n_rows, stream, managed_memory);
rmm::device_uvector<CounterT> mesocluster_sizes_buf(n_mesoclusters, stream, managed_memory);
{
rmm::device_uvector<MathT> mesocluster_centers_buf(n_mesoclusters * dim, stream, device_memory);
build_clusters(handle,
Expand Down Expand Up @@ -1056,7 +1055,7 @@ void build_hierarchical(const raft::resources& handle,
fine_clusters_nums_max,
cluster_centers,
mapping_op,
&managed_memory,
managed_memory,
device_memory);
RAFT_EXPECTS(n_clusters_done == n_clusters, "Didn't process all clusters.");

Expand Down

0 comments on commit be4586d

Please sign in to comment.