Skip to content

Commit

Permalink
Merge pull request #229 from benfred/knn_pickle
Browse files Browse the repository at this point in the history
 Add pickle support for nearest neighbours models
  • Loading branch information
benfred authored Jul 13, 2019
2 parents 32cc638 + c352d08 commit f14c90f
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 3 deletions.
10 changes: 7 additions & 3 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,13 @@ before_install:
fi
fi
install:
- travis_wait travis_retry $PIP uninstall numpy -y
- travis_wait travis_retry $PIP install -r requirements.txt --ignore-installed flake8 isort cpplint nmslib faiss annoy
- travis_retry $PIP install -e .
# Multi-line script entry: numpy is uninstalled, the requirements and tooling
# are installed, and nmslib is installed in a separate step that only runs when
# the first character of $PYTHON is "3" (presumably the Python major version --
# confirm against where PYTHON is set).
- |
travis_wait travis_retry $PIP uninstall numpy -y
travis_wait travis_retry $PIP install -r requirements.txt --ignore-installed flake8 isort cpplint faiss annoy
if [ "${PYTHON:0:1}" = "3" ]; then
travis_wait travis_retry $PIP install nmslib
fi
travis_retry $PIP install -e .
script:
- flake8
Expand Down
13 changes: 13 additions & 0 deletions implicit/nearest_neighbours.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,19 @@ def similar_items(self, itemid, N=10):

return sorted(list(nonzeros(self.similarity, itemid)), key=lambda x: -x[1])[:N]

def __getstate__(self):
    """Return the instance state to pickle, leaving out the scorer.

    The scorer isn't picklable, so it is removed from a shallow copy of
    ``__dict__``. The live instance keeps its own ``scorer`` attribute.

    Returns:
        dict: a copy of ``self.__dict__`` without the ``'scorer'`` key.
    """
    state = self.__dict__.copy()
    # pop() with a default rather than `del`: an instance that never had a
    # scorer assigned should still pickle instead of raising KeyError.
    state.pop('scorer', None)
    return state

def __setstate__(self, state):
    """Restore pickled attributes, then recreate the scorer.

    ``state`` is merged into ``__dict__``. The scorer is then set to a new
    ``NearestNeighboursScorer`` built from ``self.similarity``, or to
    ``None`` when ``similarity`` is ``None``.
    """
    self.__dict__.update(state)

    similarity = self.similarity
    if similarity is None:
        # Nothing to score against, so there is no scorer to rebuild.
        self.scorer = None
        return

    self.scorer = NearestNeighboursScorer(similarity)

def save(self, filename):
m = self.similarity
numpy.savez(filename, data=m.data, indptr=m.indptr, indices=m.indices, shape=m.shape,
Expand Down
18 changes: 18 additions & 0 deletions tests/approximate_als_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@
class AnnoyALSTest(unittest.TestCase, TestRecommenderBaseMixin):
    # Test case combining unittest.TestCase with TestRecommenderBaseMixin
    # for the AnnoyAlternatingLeastSquares model.

    def _get_model(self):
        # Model instance under test: two factors, no regularization, CPU only.
        model = AnnoyAlternatingLeastSquares(
            factors=2, regularization=0, use_gpu=False)
        return model

    def test_pickle(self):
        # Deliberately a no-op: pickle isn't supported on annoy indices.
        return None

except ImportError:
pass

Expand All @@ -25,6 +30,11 @@ class NMSLibALSTest(unittest.TestCase, TestRecommenderBaseMixin):
def _get_model(self):
    # Model instance under test: two factors, no regularization, CPU only,
    # with the nmslib index parameters passed through as a dict.
    index_params = {'post': 2}
    return NMSLibAlternatingLeastSquares(
        factors=2, regularization=0, index_params=index_params, use_gpu=False)

def test_pickle(self):
# pickle isn't supported on nmslib indices
pass

except ImportError:
pass

Expand All @@ -36,6 +46,10 @@ def _get_model(self):
return FaissAlternatingLeastSquares(nlist=1, nprobe=1, factors=2, regularization=0,
use_gpu=False)

def test_pickle(self):
# pickle isn't supported on faiss indices
pass

if HAS_CUDA:
class FaissALSGPUTest(unittest.TestCase, TestRecommenderBaseMixin):
__regularization = 0
Expand Down Expand Up @@ -69,6 +83,10 @@ def test_large_recommend(self):
recs = model.recommend(0, plays.T.tocsr(), N=1050)
self.assertEqual(recs[0][0], 0)

def test_pickle(self):
# pickle isn't supported on faiss indices
pass

except ImportError:
pass

Expand Down
10 changes: 10 additions & 0 deletions tests/recommender_base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from __future__ import print_function

import pickle

import numpy as np
from scipy.sparse import csr_matrix

Expand Down Expand Up @@ -157,6 +159,14 @@ def test_rank_items(self):
wrong_item_list = selected_items + wrong_pos_items
model.rank_items(userid, user_items, wrong_item_list)

def test_pickle(self):
    """Fit a model on a checkerboard matrix and round-trip it through pickle."""
    item_users = self.get_checker_board(50)
    model = self._get_model()
    model.fit(item_users, show_progress=False)

    pickled = pickle.dumps(model)
    restored = pickle.loads(pickled)

    # Don't stop at "no exception was raised": make sure the round trip
    # actually hands back an object of the model's type.
    self.assertIsInstance(restored, type(model))

def get_checker_board(self, X):
""" Returns a 'checkerboard' matrix: where every even userid has liked
every even itemid and every odd userid has liked every odd itemid.
Expand Down

0 comments on commit f14c90f

Please sign in to comment.