Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
Merge pull request #235 from gnes-ai/fix_vlad
Browse files Browse the repository at this point in the history
fix(encoder): fix vlad encoder and add unittest for it
  • Loading branch information
mergify[bot] authored Sep 9, 2019
2 parents d1eb573 + 0eccfdc commit 8e7a161
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 20 deletions.
59 changes: 39 additions & 20 deletions gnes/encoder/numeric/vlad.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
# limitations under the License.


import copy

import numpy as np

from ..base import BaseNumericEncoder
Expand All @@ -25,51 +23,72 @@
class VladEncoder(BaseNumericEncoder):
    """Encode groups of local descriptors as VLAD vectors.

    A k-means codebook of ``num_clusters`` centroids is learned by
    :meth:`train`. :meth:`encode` assigns every frame of a chunk to its
    nearest centroid, accumulates the residuals ``frame - centroid`` per
    centroid, flattens the result and L2-normalises it.
    """

    # NOTE(review): presumably consumed by the ``batching`` decorator -- confirm.
    batch_size = 2048

    def __init__(self, num_clusters: int,
                 using_faiss_pred: bool = True,
                 *args, **kwargs):
        """
        :param num_clusters: number of k-means centroids in the codebook
        :param using_faiss_pred: when True, nearest-centroid lookup goes
            through a faiss ``IndexFlatL2``; otherwise it is done in NumPy
        """
        super().__init__(*args, **kwargs)
        self.num_clusters = num_clusters
        self.using_faiss_pred = using_faiss_pred
        self.centroids = None
        self.index_flat = None

    def kmeans_train(self, vecs):
        """Fit the codebook on ``vecs`` (2-D: samples x dim) with faiss k-means."""
        import faiss

        # faiss operates on C-contiguous float32 buffers.
        vecs = np.ascontiguousarray(vecs, dtype=np.float32)
        kmeans = faiss.Kmeans(vecs.shape[1], self.num_clusters, niter=5, verbose=False)
        kmeans.train(vecs)
        self.centroids = kmeans.centroids
        if self.using_faiss_pred:
            self.faiss_index()

    def faiss_index(self):
        """(Re)build the flat L2 faiss index over the current centroids.

        When no codebook exists yet the index is simply reset, so that an
        untrained encoder can still be copied or unpickled.
        """
        if self.centroids is None:
            self.index_flat = None
            return

        import faiss

        centroids = np.ascontiguousarray(self.centroids, dtype=np.float32)
        self.index_flat = faiss.IndexFlatL2(centroids.shape[1])
        self.index_flat.add(centroids)

    def kmeans_pred(self, vecs):
        """Return, for each row of ``vecs``, the index of its nearest centroid.

        :param vecs: 2-D array (samples x dim)
        :return: 1-D ``int64`` array with one centroid index per row
        """
        if self.using_faiss_pred:
            # The index is not part of the pickled state; rebuild it on demand.
            if getattr(self, 'index_flat', None) is None:
                self.faiss_index()
            queries = np.ascontiguousarray(vecs, dtype=np.float32)
            _, pred = self.index_flat.search(queries, 1)
            return np.reshape(pred, [-1]).astype(np.int64)

        # (n, 1, d) - (k, d) broadcasts to (n, k, d): squared L2 to every centroid.
        diff = vecs[:, np.newaxis, :] - self.centroids
        dist = np.sum(np.square(diff), axis=-1)
        return np.argmin(dist, axis=-1).astype(np.int64)

    @batching
    def train(self, vecs: np.ndarray, *args, **kwargs):
        """Learn the codebook from ``vecs``; all leading axes are flattened."""
        vecs = vecs.reshape([-1, vecs.shape[-1]])
        assert len(vecs) > self.num_clusters, 'number of data should be larger than number of clusters'
        self.kmeans_train(vecs)

    @train_required
    @batching
    def encode(self, vecs: np.ndarray, *args, **kwargs) -> np.ndarray:
        """Encode every chunk of ``vecs`` into one VLAD vector.

        :param vecs: iterable of chunks, each a 2-D array (frames x dim)
        :return: ``float32`` array with one row of length
            ``num_centroids * dim`` per chunk
        """
        output = []
        for chunk in vecs:
            res = np.zeros(self.centroids.shape)
            if len(chunk):
                for center_index, frame in zip(self.kmeans_pred(chunk), chunk):
                    res[center_index] += frame - self.centroids[center_index]
            res = res.reshape([-1])
            norm = np.sqrt(np.sum(np.square(res)))
            # An all-zero residual has no direction; leave it as zeros
            # instead of producing NaN through 0 / 0.
            if norm > 0:
                res = res / norm
            output.append(res)

        return np.array(output, dtype=np.float32)

    def _copy_from(self, x: 'VladEncoder') -> None:
        """Take over the configuration and codebook of ``x``."""
        self.num_clusters = x.num_clusters
        self.centroids = x.centroids
        self.using_faiss_pred = x.using_faiss_pred
        self.index_flat = None
        if self.using_faiss_pred:
            self.faiss_index()

    def __setstate__(self, state):
        """Restore the state and rebuild the faiss index, which is never pickled."""
        super().__setstate__(state)
        # ``index_flat`` was dropped in __getstate__; make sure it exists again.
        self.index_flat = None
        if self.using_faiss_pred:
            self.faiss_index()

    def __getstate__(self):
        """Return the picklable state, leaving out the faiss index."""
        state = super().__getstate__()
        # ``pop`` tolerates a parent state that does not carry the key.
        state.pop('index_flat', None)
        return state
29 changes: 29 additions & 0 deletions tests/test_vlad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import os
import unittest
import numpy as np
from gnes.encoder.numeric.vlad import VladEncoder


class TestVladEncoder(unittest.TestCase):
    """Smoke tests for training, encoding and dump/load of ``VladEncoder``."""

    def setUp(self):
        # A seeded generator keeps the fixtures reproducible between runs.
        rng = np.random.RandomState(0)
        self.mock_train_data = rng.random_sample([1, 200, 128]).astype(np.float32)
        self.mock_eval_data = rng.random_sample([2, 2, 128]).astype(np.float32)
        self.dump_path = os.path.join(os.path.dirname(__file__), 'vlad.bin')

    def tearDown(self):
        if os.path.exists(self.dump_path):
            os.remove(self.dump_path)

    def test_vlad_train(self):
        model = VladEncoder(20)
        model.train(self.mock_train_data)
        self.assertEqual(model.centroids.shape, (20, 128))
        v = model.encode(self.mock_eval_data)
        self.assertEqual(v.shape, (2, 2560))

    def test_vlad_numpy_pred(self):
        # Same shapes are expected when centroid lookup bypasses faiss.
        model = VladEncoder(20, using_faiss_pred=False)
        model.train(self.mock_train_data)
        self.assertEqual(model.centroids.shape, (20, 128))
        v = model.encode(self.mock_eval_data)
        self.assertEqual(v.shape, (2, 2560))

    def test_vlad_dump_load(self):
        model = VladEncoder(20)
        model.train(self.mock_train_data)
        expected = model.encode(self.mock_eval_data)
        model.dump(self.dump_path)
        model_new = VladEncoder.load(self.dump_path)
        self.assertEqual(model_new.centroids.shape, (20, 128))
        # The faiss index is excluded from the dumped state, so check that
        # the restored encoder can still encode and agrees with the original.
        restored = model_new.encode(self.mock_eval_data)
        self.assertEqual(restored.shape, (2, 2560))
        np.testing.assert_allclose(restored, expected)

0 comments on commit 8e7a161

Please sign in to comment.