diff --git a/gnes/encoder/text/transformer.py b/gnes/encoder/text/transformer.py index 512fc7d2..5d69354d 100644 --- a/gnes/encoder/text/transformer.py +++ b/gnes/encoder/text/transformer.py @@ -50,7 +50,7 @@ def post_init(self): (RobertaModel, RobertaTokenizer, 'roberta-base')]}[self.model_name] def load_model_tokenizer(x): - return model_class.from_pretrained(x), tokenizer_class.from_pretrained(x) + return model_class.from_pretrained(x).eval(), tokenizer_class.from_pretrained(x) try: self.model, self.tokenizer = load_model_tokenizer(self.work_dir)