diff --git a/gnes/encoder/numeric/quantizer.py b/gnes/encoder/numeric/quantizer.py index ebfeb395..3be5b8f1 100644 --- a/gnes/encoder/numeric/quantizer.py +++ b/gnes/encoder/numeric/quantizer.py @@ -36,8 +36,7 @@ def __init__(self, dim_per_byte: int, cluster_per_byte: int = 255, self.upper_bound = upper_bound self.lower_bound = lower_bound self.partition_method = partition_method - self.centroids = None - self._get_centroids() + self.centroids = self._get_centroids() def _get_centroids(self): """ @@ -52,7 +51,7 @@ def _get_centroids(self): if self.upper_bound < self.lower_bound: raise ValueError("upper bound is smaller than lower bound") - self.centroids = [] + centroids = [] num_sample_per_dim = np.ceil(pow(self.num_clusters, 1 / self.dim_per_byte)).astype(np.uint8) if self.partition_method == 'average': axis_point = np.linspace(self.lower_bound, self.upper_bound, num=num_sample_per_dim+1, @@ -65,16 +64,13 @@ def _get_centroids(self): raise NotImplementedError for item in product(*coordinates): - self.centroids.append(list(item)) - self.centroids = self.centroids[:self.num_clusters] + centroids.append(list(item)) + return centroids[:self.num_clusters] @batching def encode(self, vecs: np.ndarray, *args, **kwargs) -> np.ndarray: + self._check_bound(vecs) num_bytes = self._get_num_bytes(vecs) - max_value, min_value = self._get_max_min_value(vecs) - - self._check_bound(max_value, min_value) - x = np.reshape(vecs, [vecs.shape[0], num_bytes, 1, self.dim_per_byte]) x = np.sum(np.square(x - self.centroids), -1) # start from 1 @@ -93,7 +89,8 @@ def _get_num_bytes(self, vecs: np.ndarray): def _get_max_min_value(vecs): return np.amax(vecs, axis=None), np.amin(vecs, axis=None) - def _check_bound(self, max_value, min_value): + def _check_bound(self, vecs): + max_value, min_value = self._get_max_min_value(vecs) if self.upper_bound < max_value: raise Warning("upper bound (=%.3f) is smaller than max value of input data (=%.3f), you should choose" "a bigger value for upper bound" % (self.upper_bound, max_value))