Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
style: minor fix on the styling
Browse files Browse the repository at this point in the history
  • Loading branch information
Han Xiao authored Sep 5, 2019
1 parent 57cc95f commit 2fd8dab
Showing 1 changed file with 7 additions and 10 deletions.
17 changes: 7 additions & 10 deletions gnes/encoder/numeric/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ def __init__(self, dim_per_byte: int, cluster_per_byte: int = 255,
self.upper_bound = upper_bound
self.lower_bound = lower_bound
self.partition_method = partition_method
self.centroids = None
self._get_centroids()
self.centroids = self._get_centroids()

def _get_centroids(self):
"""
Expand All @@ -52,7 +51,7 @@ def _get_centroids(self):
if self.upper_bound < self.lower_bound:
raise ValueError("upper bound is smaller than lower bound")

self.centroids = []
centroids = []
num_sample_per_dim = np.ceil(pow(self.num_clusters, 1 / self.dim_per_byte)).astype(np.uint8)
if self.partition_method == 'average':
axis_point = np.linspace(self.lower_bound, self.upper_bound, num=num_sample_per_dim+1,
Expand All @@ -65,16 +64,13 @@ def _get_centroids(self):
raise NotImplementedError

for item in product(*coordinates):
self.centroids.append(list(item))
self.centroids = self.centroids[:self.num_clusters]
centroids.append(list(item))
return centroids[:self.num_clusters]

@batching
def encode(self, vecs: np.ndarray, *args, **kwargs) -> np.ndarray:
self._check_bound(vecs)
num_bytes = self._get_num_bytes(vecs)
max_value, min_value = self._get_max_min_value(vecs)

self._check_bound(max_value, min_value)

x = np.reshape(vecs, [vecs.shape[0], num_bytes, 1, self.dim_per_byte])
x = np.sum(np.square(x - self.centroids), -1)
# start from 1
Expand All @@ -93,7 +89,8 @@ def _get_num_bytes(self, vecs: np.ndarray):
def _get_max_min_value(vecs):
return np.amax(vecs, axis=None), np.amin(vecs, axis=None)

def _check_bound(self, max_value, min_value):
def _check_bound(self, vecs):
max_value, min_value = self._get_max_min_value(vecs)
if self.upper_bound < max_value:
raise Warning("upper bound (=%.3f) is smaller than max value of input data (=%.3f), you should choose"
"a bigger value for upper bound" % (self.upper_bound, max_value))
Expand Down

0 comments on commit 2fd8dab

Please sign in to comment.