Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
feat(encoder): add quantizer
Browse files (browse the repository at this point in the history)
  • Loading branch information
jemmyshin committed Sep 5, 2019
1 parent bbf4283 commit 57cc95f
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 9 deletions.
8 changes: 4 additions & 4 deletions gnes/encoder/numeric/quantizer.py
Original file line number Diff line number Diff line change
@@ -59,7 +59,7 @@ def _get_centroids(self):
endpoint=False, retstep=False, dtype=None)[1:]
coordinates = np.tile(axis_point, (self.dim_per_byte, 1))
elif self.partition_method == 'random':
coordinates = np.random.randint(self.lower_bound, self.upper_bound,
coordinates = np.random.uniform(self.lower_bound, self.upper_bound,
size=[self.dim_per_byte, num_sample_per_dim])
else:
raise NotImplementedError
@@ -95,12 +95,12 @@ def _get_max_min_value(vecs):

def _check_bound(self, max_value, min_value):
if self.upper_bound < max_value:
self.logger.warning("upper bound (=%.3f) is smaller than max value of input data (=%.3f), you should choose"
raise Warning("upper bound (=%.3f) is smaller than max value of input data (=%.3f), you should choose"
"a bigger value for upper bound" % (self.upper_bound, max_value))
if self.lower_bound > min_value:
self.logger.warning("lower bound (=%.3f) is bigger than min value of input data (=%.3f), you should choose"
raise Warning("lower bound (=%.3f) is bigger than min value of input data (=%.3f), you should choose"
"a smaller value for lower bound" % (self.lower_bound, min_value))
if (self.upper_bound-self.lower_bound) >= 10*(max_value - min_value):
self.logger.warning("(upper bound - lower_bound) (=%.3f) is 10 times larger than (max value - min value) "
raise Warning("(upper bound - lower_bound) (=%.3f) is 10 times larger than (max value - min value) "
"(=%.3f) of data, maybe you should choose a suitable bound" %
((self.upper_bound-self.lower_bound), (max_value - min_value)))
15 changes: 12 additions & 3 deletions tests/test_quantizer_encoder.py
Original file line number Diff line number Diff line change
@@ -7,13 +7,22 @@

class TestQuantizerEncoder(unittest.TestCase):
def setUp(self):
self.vecs = np.random.randint(-150, 150, size=[1000, 160]).astype('float32')
dirname = os.path.dirname(__file__)
self.vanilla_quantizer_yaml = os.path.join(dirname, 'yaml', 'quantizer_encoder.yml')

def test_vanilla_quantizer(self):
encoder = BaseNumericEncoder.load_yaml(self.vanilla_quantizer_yaml)
encoder.train()
out = encoder.encode(self.vecs)
print(out.shape)

vecs_1 = np.random.uniform(-150, 150, size=[1000, 160]).astype('float32')
out = encoder.encode(vecs_1)
self.assertEqual(len(out.shape), 2)
self.assertEqual(out.shape[0], 1000)
self.assertEqual(out.shape[1], 16)

vecs_2 = np.random.uniform(-1, 1, size=[1000, 160]).astype('float32')
self.assertRaises(Warning, encoder.encode, vecs_2)

vecs_3 = np.random.uniform(-1, 1000, size=[1000, 160]).astype('float32')
self.assertRaises(Warning, encoder.encode, vecs_3)

4 changes: 2 additions & 2 deletions tests/yaml/quantizer_encoder.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
!QuantizerEncoder
parameters:
upper_bound: 1000000
lower_bound: -100
upper_bound: 500
lower_bound: -200
partition_method: 'random'
cluster_per_byte: 255
dim_per_byte: 10

0 comments on commit 57cc95f

Please sign in to comment.