diff --git a/gnes/encoder/base.py b/gnes/encoder/base.py index cff40538..cbd0edfa 100644 --- a/gnes/encoder/base.py +++ b/gnes/encoder/base.py @@ -49,6 +49,7 @@ def encode(self, text: List[str], *args, **kwargs) -> Union[Tuple, np.ndarray]: class BaseNumericEncoder(BaseEncoder): + """Note that all NumericEncoder can not be used as the first encoder of the pipeline""" def encode(self, data: np.ndarray, *args, **kwargs) -> np.ndarray: pass diff --git a/tests/test_flair_encoder.py b/tests/test_flair_encoder.py index 6a4aef86..798c2eb6 100644 --- a/tests/test_flair_encoder.py +++ b/tests/test_flair_encoder.py @@ -17,15 +17,14 @@ def setUp(self): if line: self.test_str.append(line) - self.flair_encoder = FlairEncoder( - model_name=os.environ.get('FLAIR_CI_MODEL'), - pooling_strategy="REDUCE_MEAN") + self.flair_encoder = FlairEncoder(model_name=os.environ.get('FLAIR_CI_MODEL')) @unittest.SkipTest def test_encoding(self): - vec = self.flair_encoder.encode(self.test_str) - self.assertEqual(vec.shape[0], len(self.test_str)) - self.assertEqual(vec.shape[1], 512) + vec = self.flair_encoder.encode(self.test_str[:2]) + print(vec.shape) + self.assertEqual(vec.shape[0], 2) + self.assertEqual(vec.shape[1], 4196) @unittest.SkipTest def test_dump_load(self):