diff --git a/implicit/evaluation.pyx b/implicit/evaluation.pyx index 08aa23ed..51eedb22 100644 --- a/implicit/evaluation.pyx +++ b/implicit/evaluation.pyx @@ -225,7 +225,7 @@ def mean_average_precision_at_k(model, train_user_items, test_user_items, int K= @cython.wraparound(False) @cython.nonecheck(False) def ALS_recommend_all( - model, users_items, int k=10, int threads=1, show_progress=True, recalculate_user=False): + model, users_items, int k=10, int threads=1, show_progress=True, recalculate_user=False, filter_already_liked_items=False): if not isinstance(users_items, csr_matrix): users_items = users_items.tocsr() @@ -251,8 +251,10 @@ def ALS_recommend_all( model._user_factor(u, users_items, recalculate_user) for u in range(u_low, u_high, 1) - ]) - A[:u_len] = users_factors.dot(factors_items) + ]).astype(np.float32) + users_factors.dot(factors_items, out=A[:u_len]) + if filter_already_liked_items: + A[users_items[u_low:u_high].nonzero()] = 0 for u in prange(u_len, nogil=True, num_threads=threads, schedule='dynamic'): fargsort_c(A_mv_p, u, batch * u_b + u, items_c, k, B_mv_p) progress.update(u_len)