class PCAEncoder(BaseNumericEncoder):
    """Numeric encoder that projects vectors onto learned principal components.

    The projection is fitted incrementally with scikit-learn's
    ``IncrementalPCA``, one batch of training vectors per ``partial_fit`` call.
    """

    # NOTE(review): presumably consumed by the ``@batching`` decorator as the
    # chunk size for ``train``/``encode`` -- confirm against the helper.
    batch_size = 2048

    def __init__(self, output_dim: int, *args, **kwargs):
        """Store the target dimensionality; the model starts out unfitted.

        :param output_dim: number of principal components to keep, i.e. the
            dimensionality of the encoded vectors.
        """
        super().__init__(*args, **kwargs)
        self.output_dim = output_dim
        # Both stay ``None`` until the first successful ``train`` call.
        self.pca_components = None
        self.mean = None

    def post_init(self):
        """Create the underlying ``IncrementalPCA`` model."""
        # Imported here rather than at module level, so scikit-learn is only
        # loaded when this method runs.
        from sklearn.decomposition import IncrementalPCA
        self.pca = IncrementalPCA(n_components=self.output_dim)

    @batching
    def train(self, vecs: np.ndarray, *args, **kwargs) -> None:
        """Update the PCA model with one batch of training vectors.

        :param vecs: 2-D array of shape ``(num_samples, num_dim)``.
        :raises ValueError: if the batch holds fewer than ``output_dim``
            samples and the model has not been fitted on any earlier batch.
        :raises AssertionError: if ``output_dim`` is not strictly smaller
            than the data dimension.
        """
        num_samples, num_dim = vecs.shape
        if self.output_dim > num_samples:
            # A batch smaller than the number of components cannot be fed to
            # the model. If an earlier batch already fitted it, skip this one
            # quietly. ``self.mean`` is ``None`` until the first fit, so it
            # must be compared against ``None`` instead of reading ``.size``
            # (which raised AttributeError on an unfitted encoder).
            if self.mean is not None:
                return
            raise ValueError('training PCA requires at least %d points, but %d was given' % (self.output_dim, num_samples))

        assert self.output_dim < num_dim, 'PCA output dimension should < data dimension, received (%d, %d)' % (
            self.output_dim, num_dim)

        self.pca.partial_fit(vecs)

        # Cache the projection matrix transposed so that ``encode`` can
        # right-multiply centred vectors by it directly.
        self.pca_components = np.transpose(self.pca.components_)
        self.mean = self.pca.mean_

    @train_required
    @batching
    def encode(self, vecs: np.ndarray, *args, **kwargs) -> np.ndarray:
        """Centre ``vecs`` with the learned mean and project them.

        :param vecs: 2-D array whose second dimension matches the training data.
        :return: the matrix product of the centred vectors and the cached
            (transposed) principal components.
        """
        return np.matmul(vecs - self.mean, self.pca_components)