diff --git a/demo/visualize_result.py b/demo/visualize_result.py index 493b70202..4207a01bc 100644 --- a/demo/visualize_result.py +++ b/demo/visualize_result.py @@ -124,7 +124,7 @@ def get_parser(): # compute cosine distance distmat = 1 - torch.mm(q_feat, g_feat.t()) - distmat = distmat.numpy() + distmat = distmat.cpu().numpy() logger.info("Computing APs for all query images ...") cmc, all_ap, all_inp = evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids)