Skip to content

Commit

Permalink
working through updates for cuvs
Browse files Browse the repository at this point in the history
  • Loading branch information
divyegala committed Jan 30, 2025
1 parent cbae315 commit 485d042
Showing 1 changed file with 27 additions and 25 deletions.
52 changes: 27 additions & 25 deletions cpp/include/raft/spectral/detail/matrix_wrappers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -334,11 +334,11 @@ struct laplacian_matrix_t : sparse_matrix_t<index_type, value_type, nnz_type> {
laplacian_matrix_t(resources const& raft_handle,
sparse_matrix_t<index_type, value_type, nnz_type> const& csr_m)
: sparse_matrix_t<index_type, value_type, nnz_type>(raft_handle,
csr_m.row_offsets_,
csr_m.col_indices_,
csr_m.values_,
csr_m.nrows_,
csr_m.nnz_),
csr_m.row_offsets_,
csr_m.col_indices_,
csr_m.values_,
csr_m.nrows_,
csr_m.nnz_),
diagonal_(raft_handle, csr_m.nrows_)
{
vector_t<value_type> ones{raft_handle, (size_t)csr_m.nrows_};
Expand Down Expand Up @@ -382,7 +382,8 @@ struct laplacian_matrix_t : sparse_matrix_t<index_type, value_type, nnz_type> {

// Apply adjacency matrix
//
sparse_matrix_t<index_type, value_type, nnz_type>::mv(-alpha, x, 1, y, alg, transpose, symmetric);
sparse_matrix_t<index_type, value_type, nnz_type>::mv(
-alpha, x, 1, y, alg, transpose, symmetric);
}

vector_t<value_type> diagonal_;
Expand Down Expand Up @@ -427,36 +428,37 @@ struct modularity_matrix_t : laplacian_matrix_t<index_type, value_type, nnz_type

// y = A*x
//
sparse_matrix_t<index_type, value_type, nnz_type>::mv(alpha, x, 0, y, alg, transpose, symmetric);
sparse_matrix_t<index_type, value_type, nnz_type>::mv(
alpha, x, 0, y, alg, transpose, symmetric);
value_type dot_res;

// gamma = d'*x
//
// Cublas::dot(this->n, D.raw(), 1, x, 1, &dot_res);
// TODO: Call from public API when ready
RAFT_CUBLAS_TRY(
raft::linalg::detail::cublasdot(cublas_h,
n,
laplacian_matrix_t<index_type, value_type, nnz_type>::diagonal_.raw(),
1,
x,
1,
&dot_res,
stream));
RAFT_CUBLAS_TRY(raft::linalg::detail::cublasdot(
cublas_h,
n,
laplacian_matrix_t<index_type, value_type, nnz_type>::diagonal_.raw(),
1,
x,
1,
&dot_res,
stream));

// y = y -(gamma/edge_sum)*d
//
value_type gamma_ = -dot_res / edge_sum_;
// TODO: Call from public API when ready
RAFT_CUBLAS_TRY(
raft::linalg::detail::cublasaxpy(cublas_h,
n,
&gamma_,
laplacian_matrix_t<index_type, value_type, nnz_type>::diagonal_.raw(),
1,
y,
1,
stream));
RAFT_CUBLAS_TRY(raft::linalg::detail::cublasaxpy(
cublas_h,
n,
&gamma_,
laplacian_matrix_t<index_type, value_type, nnz_type>::diagonal_.raw(),
1,
y,
1,
stream));
}

value_type edge_sum_;
Expand Down

0 comments on commit 485d042

Please sign in to comment.