
Commit

fix mask token id assignment tokenizer
Plutone11011 committed Feb 25, 2023
1 parent 91ecd8d commit 07115f8
Showing 2 changed files with 6 additions and 9 deletions.
6 changes: 2 additions & 4 deletions keras_nlp/models/deberta_v3/deberta_v3_tokenizer.py
@@ -92,15 +92,13 @@ def __init__(self, proto, **kwargs):
self.cls_token_id = self.token_to_id(cls_token)
self.sep_token_id = self.token_to_id(sep_token)
self.pad_token_id = self.token_to_id(pad_token)
self.mask_token_id = self.token_to_id(mask_token)

# If the mask token is not in the vocabulary, add it to the end of the
# vocabulary.
if mask_token in super().get_vocabulary():
self.mask_token_id = self.token_to_id(mask_token)
self.mask_token_id = super().token_to_id(mask_token)
else:
self.mask_token_id = super().vocabulary_size()

def vocabulary_size(self):
sentence_piece_size = super().vocabulary_size()
if sentence_piece_size == self.mask_token_id:
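The net effect of the hunk above is that the mask id now depends on whether the SentencePiece model actually contains the [MASK] piece: if it does, the id is read from the proto; if it does not, the tokenizer reserves the next id just past the SentencePiece vocabulary. A minimal standalone sketch of that rule, using the sentencepiece package directly rather than DebertaV3Tokenizer (the helper name and its arguments are illustrative, not library API):

import sentencepiece as spm

def resolve_mask_token_id(proto: bytes, mask_token: str = "[MASK]") -> int:
    """Real SentencePiece id if the piece was trained, else one past the end."""
    sp = spm.SentencePieceProcessor(model_proto=proto)
    vocabulary = [sp.id_to_piece(i) for i in range(sp.vocab_size())]
    if mask_token in vocabulary:
        # The piece exists in the trained model, so its id is authoritative.
        return sp.piece_to_id(mask_token)
    # Not in the model: reserve the id just past the SentencePiece vocabulary,
    # which is the case the vocabulary_size() override below the hunk checks for.
    return sp.vocab_size()

The explicit membership check matters because SentencePiece maps unknown pieces to the unk id, so calling piece_to_id on a [MASK] piece that was never trained would not raise; it would silently return the wrong id.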
9 changes: 4 additions & 5 deletions keras_nlp/models/deberta_v3/deberta_v3_tokenizer_test.py
@@ -34,7 +34,7 @@ def setUp(self):
sentencepiece.SentencePieceTrainer.train(
sentence_iterator=vocab_data.as_numpy_iterator(),
model_writer=bytes_io,
vocab_size=12,
vocab_size=10,
model_type="WORD",
pad_id=0,
bos_id=1,
@@ -44,7 +44,6 @@ def setUp(self):
bos_piece="[CLS]",
eos_piece="[SEP]",
unk_piece="[UNK]",
user_defined_symbols="[MASK]",
)
self.proto = bytes_io.getvalue()

@@ -53,15 +52,15 @@ def setUp(self):
def test_tokenize(self):
input_data = "the quick brown fox"
output = self.tokenizer(input_data)
self.assertAllEqual(output, [5, 10, 6, 8])
self.assertAllEqual(output, [4, 9, 5, 7])

def test_tokenize_batch(self):
input_data = tf.constant(["the quick brown fox", "the earth is round"])
output = self.tokenizer(input_data)
self.assertAllEqual(output, [[5, 10, 6, 8], [5, 7, 9, 11]])
self.assertAllEqual(output, [[4, 9, 5, 7], [4, 6, 8, 3]])

def test_detokenize(self):
input_data = tf.constant([[5, 10, 6, 8]])
input_data = tf.constant([[4, 9, 5, 7]])
output = self.tokenizer.detokenize(input_data)
self.assertEqual(output, tf.constant(["the quick brown fox"]))

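Because the test proto no longer registers [MASK] as a user_defined_symbol and trains only 10 pieces, the expected word ids shift down by one, and the lowest-frequency word ("round") no longer fits in the vocabulary and falls back to the unk id 3. A rough recreation of that setup outside the test harness; the training sentences and the trainer arguments hidden behind the collapsed lines (eos_id, unk_id, pad_piece) are assumptions inferred from the assertions above, not visible in this diff:

import io

import sentencepiece as spm

bytes_io = io.BytesIO()
sentences = ["the quick brown fox", "the earth is round"]  # assumed to match the test's vocab_data

spm.SentencePieceTrainer.train(
    sentence_iterator=iter(sentences),
    model_writer=bytes_io,
    vocab_size=10,
    model_type="WORD",
    pad_id=0,
    bos_id=1,
    eos_id=2,  # assumed
    unk_id=3,  # assumed
    pad_piece="[PAD]",  # assumed
    bos_piece="[CLS]",
    eos_piece="[SEP]",
    unk_piece="[UNK]",
)

sp = spm.SentencePieceProcessor(model_proto=bytes_io.getvalue())
vocab = [sp.id_to_piece(i) for i in range(sp.vocab_size())]
print("[MASK]" in vocab)  # False, so DebertaV3Tokenizer falls back to id sp.vocab_size() == 10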
