fix import of fast tokenizers. add new supported fast tokenizers
tholor committed Oct 23, 2020
1 parent 22d98e6 commit ed0243b
Showing 1 changed file with 13 additions and 10 deletions.
23 changes: 13 additions & 10 deletions farm/modeling/tokenization.py
@@ -24,14 +24,22 @@

 import numpy as np
 from transformers.tokenization_albert import AlbertTokenizer
+from transformers.tokenization_albert_fast import AlbertTokenizerFast
 from transformers.tokenization_bert import BertTokenizer, BertTokenizerFast, load_vocab
+from transformers.tokenization_bert_fast import BertTokenizerFast
 from transformers.tokenization_distilbert import DistilBertTokenizer, DistilBertTokenizerFast
+from transformers.tokenization_distilbert_fast import DistilBertTokenizerFast
 from transformers.tokenization_electra import ElectraTokenizer, ElectraTokenizerFast
+from transformers.tokenization_electra_fast import ElectraTokenizerFast
 from transformers.tokenization_roberta import RobertaTokenizer
+from transformers.tokenization_roberta_fast import RobertaTokenizerFast
 from transformers.tokenization_utils import PreTrainedTokenizer
 from transformers.tokenization_xlm_roberta import XLMRobertaTokenizer
+from transformers.tokenization_xlm_roberta_fast import XLMRobertaTokenizerFast
 from transformers.tokenization_xlnet import XLNetTokenizer
+from transformers.tokenization_xlnet_fast import XLNetTokenizerFast
 from transformers.tokenization_camembert import CamembertTokenizer
+from transformers.tokenization_camembert_fast import CamembertTokenizerFast
 from transformers import DPRContextEncoderTokenizer, DPRQuestionEncoderTokenizer
 from transformers import DPRContextEncoderTokenizerFast, DPRQuestionEncoderTokenizerFast
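These import changes track the tokenizer refactoring in transformers v3.4, which moved each Rust-backed fast tokenizer out of the shared tokenization_<model> module into its own tokenization_<model>_fast module. A minimal sketch of the resulting layout, assuming transformers 3.4/3.5 (in v4 these paths moved again, under transformers.models.*):

# Slow (pure-Python) and fast (Rust-backed) tokenizers now live in separate modules.
from transformers.tokenization_roberta import RobertaTokenizer
from transformers.tokenization_roberta_fast import RobertaTokenizerFast

slow = RobertaTokenizer.from_pretrained("roberta-base")
fast = RobertaTokenizerFast.from_pretrained("roberta-base")
assert not slow.is_fast and fast.is_fast  # is_fast distinguishes the two backends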

@@ -111,20 +119,17 @@ def load(cls, pretrained_model_name_or_path, tokenizer_class=None, use_fast=False
         ret = None
         if tokenizer_class == "AlbertTokenizer":
             if use_fast:
-                logger.error('AlbertTokenizerFast is not supported! Using AlbertTokenizer instead.')
-                ret = AlbertTokenizer.from_pretrained(pretrained_model_name_or_path, keep_accents=True, **kwargs)
+                ret = AlbertTokenizerFast.from_pretrained(pretrained_model_name_or_path, keep_accents=True, **kwargs)
             else:
                 ret = AlbertTokenizer.from_pretrained(pretrained_model_name_or_path, keep_accents=True, **kwargs)
         elif tokenizer_class == "XLMRobertaTokenizer":
             if use_fast:
-                logger.error('XLMRobertaTokenizerFast is not supported! Using XLMRobertaTokenizer instead.')
-                ret = XLMRobertaTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
+                ret = XLMRobertaTokenizerFast.from_pretrained(pretrained_model_name_or_path, **kwargs)
             else:
                 ret = XLMRobertaTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
         elif "RobertaTokenizer" in tokenizer_class:  # because it also might be a fast tokenizer we use "in"
             if use_fast:
-                logger.error('RobertaTokenizerFast is not supported! Using RobertaTokenizer instead.')
-                ret = RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
+                ret = RobertaTokenizerFast.from_pretrained(pretrained_model_name_or_path, **kwargs)
             else:
                 ret = RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
         elif "DistilBertTokenizer" in tokenizer_class:  # because it also might be a fast tokenizer we use "in"
@@ -139,8 +144,7 @@ def load(cls, pretrained_model_name_or_path, tokenizer_class=None, use_fast=False
                 ret = BertTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
         elif tokenizer_class == "XLNetTokenizer":
             if use_fast:
-                logger.error('XLNetTokenizerFast is not supported! Using XLNetTokenizer instead.')
-                ret = XLNetTokenizer.from_pretrained(pretrained_model_name_or_path, keep_accents=True, **kwargs)
+                ret = XLNetTokenizerFast.from_pretrained(pretrained_model_name_or_path, keep_accents=True, **kwargs)
             else:
                 ret = XLNetTokenizer.from_pretrained(pretrained_model_name_or_path, keep_accents=True, **kwargs)
         elif "ElectraTokenizer" in tokenizer_class:  # because it also might be a fast tokenizer we use "in"
@@ -156,8 +160,7 @@ def load(cls, pretrained_model_name_or_path, tokenizer_class=None, use_fast=False
                 ret = EmbeddingTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
         elif tokenizer_class == "CamembertTokenizer":
             if use_fast:
-                logger.error('CamembertTokenizerFast is not supported! Using CamembertTokenizer instead.')
-                ret = CamembertTokenizer._from_pretrained(pretrained_model_name_or_path, **kwargs)
+                ret = CamembertTokenizerFast._from_pretrained(pretrained_model_name_or_path, **kwargs)
             else:
                 ret = CamembertTokenizer._from_pretrained(pretrained_model_name_or_path, **kwargs)
         elif tokenizer_class == "DPRQuestionEncoderTokenizer" or tokenizer_class == "DPRQuestionEncoderTokenizerFast":
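Taken together, passing use_fast=True to FARM's Tokenizer.load now returns the Rust-backed tokenizer instead of logging an error and silently falling back to the slow implementation. A minimal usage sketch (the model name is only an example, and tokenizer_class is assumed to be inferable from it, as elsewhere in FARM; per the hunk headers above, use_fast defaults to False):

from farm.modeling.tokenization import Tokenizer

tokenizer = Tokenizer.load(
    pretrained_model_name_or_path="roberta-base",
    use_fast=True,
)
print(type(tokenizer).__name__)  # expected: RobertaTokenizerFast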
