Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
refactor(encoder): update the init func for flair
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhxiao committed Sep 29, 2019
1 parent 1b85375 commit e588c94
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions gnes/encoder/text/flair.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.


from typing import List
from typing import List, Tuple

import numpy as np

Expand All @@ -25,16 +25,22 @@
class FlairEncoder(BaseTextEncoder):
is_trained = True

def __init__(self, pooling_strategy: str = 'mean', *args, **kwargs):
def __init__(self,
word_embedding: str = 'glove',
flair_embeddings: Tuple[str] = ('news-forward', 'news-backward'),
pooling_strategy: str = 'mean', *args, **kwargs):
super().__init__(*args, **kwargs)

self.word_embedding = word_embedding
self.flair_embeddings = flair_embeddings
self.pooling_strategy = pooling_strategy

def post_init(self):
from flair.embeddings import DocumentPoolEmbeddings, WordEmbeddings, FlairEmbeddings
self._flair = DocumentPoolEmbeddings(
[WordEmbeddings('glove'),
FlairEmbeddings('news-forward'),
FlairEmbeddings('news-backward')],
[WordEmbeddings(self.word_embedding),
FlairEmbeddings(self.flair_embeddings[0]),
FlairEmbeddings(self.flair_embeddings[1])],
pooling=self.pooling_strategy)

@batching
Expand Down

0 comments on commit e588c94

Please sign in to comment.