Skip entire header for llama3 decode #1656
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1656
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure as of commit e7bbac6 with merge base 57ab583: the following job has failed.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
some comments and questions
elif token_ids[idx] == self.end_header_id:
    # Mask out end header id and "\n\n" after it, then reset mask
    not_header[idx] = False
    mask = True
So after we see "end_header_id", mask is set to True, and will never be False again, is that right? If so, can we just exit the while loop? Or do we need to keep checking because of sample packing, for example?
not_header = [True] * len(token_ids)
mask = True
idx = 0
while idx < len(token_ids):
This feels like it would be more efficient with numpy arrays. What sucks is going list -> array -> list
But you could do something like:
indices_start = get_index_start_header(token_ids, start_header_id)
indices_end = get_index_end_header(token_ids, end_header_id)
for idx_start, idx_end in zip(indices_start, indices_end):
    mask[idx_start : idx_end + 1] = False
That way the Python for-loop is over the number of header ids, not over every token index.
PS: not saying it needs to be done, just thinking out loud. I usually try to avoid Python loops.
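For what it's worth, here is a minimal runnable sketch of that idea with numpy (hypothetical helper, not the PR's implementation; it only drops the start_header_id..end_header_id span, not the trailing "\n\n"):

import numpy as np

def mask_header_tokens(token_ids, start_header_id, end_header_id):
    # Build a boolean keep-mask, then clear each header span with one slice
    # assignment so the Python loop runs once per header, not once per token.
    ids = np.asarray(token_ids)
    keep = np.ones(len(ids), dtype=bool)
    starts = np.flatnonzero(ids == start_header_id)
    ends = np.flatnonzero(ids == end_header_id)
    # Assumes well-formed headers: every start_header_id has a matching
    # end_header_id that comes after it.
    for start, end in zip(starts, ends):
        keep[start : end + 1] = False
    return ids[keep].tolist()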
    truncate_at_eos=truncate_at_eos,
    skip_special_tokens=skip_special_tokens,
)
if skip_special_tokens:
Will this test pass?
text = decode(encode(text), skip_special_tokens=True)
Probably not well formulated, but I want to check whether every encode/decode round trip ends up adding "\n\n".
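Something like the following could make that concrete (a sketch only; the add_bos/add_eos keywords are assumed from the base tokenizer's encode signature):

def test_encode_decode_round_trip(tokenizer):
    # Hypothetical test: a plain string should survive encode -> decode
    # unchanged, i.e. no "\n\n" or role text should be injected.
    text = "hello world"
    token_ids = tokenizer.encode(text, add_bos=False, add_eos=False)
    assert tokenizer.decode(token_ids, skip_special_tokens=True) == text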
idx += 1

return self.tt_model.decode(
    [token_ids[i] for i, m in enumerate(not_header) if m],
I suggested defining token_ids inside the if condition. Then you can have a single return and remove the if/else and the duplicated code.
e.g., instead of:
if something:
    return self.tt_model.decode(token_ids_a)
else:
    return self.tt_model.decode(token_ids_b)
do:
if something:
    token_ids = token_ids_a
return self.tt_model.decode(token_ids)
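Applied to this code path, that could look roughly like the following (keyword arguments copied from the surrounding diff; the exact tt_model.decode signature is assumed):

if skip_special_tokens:
    # Drop header tokens once, then fall through to a single decode call.
    token_ids = [token_ids[i] for i, m in enumerate(not_header) if m]
return self.tt_model.decode(
    token_ids,
    truncate_at_eos=truncate_at_eos,
    skip_special_tokens=skip_special_tokens,
)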
    truncate_at_eos=truncate_at_eos,
    skip_special_tokens=skip_special_tokens,
)
if skip_special_tokens:
I think a comment here is necessary to explain why you need this extra logic for "skip_special_tokens" if decode already has this flag. In other words: why do we skip special tokens in two different places? Is there a more elegant way to solve this, like adding the special tokens to the tt_model directly?
@KaiserLeave has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
    skip_special_tokens=False,
)
if skip_special_tokens:
So I understand that the goal here is just to fix skip_special_tokens for the Llama3 tokenizer decode, but it seems to me like we are doing something very unexpected with skip_special_tokens: (a) we have it defined on the base tokenizer and it's basically a no-op now, and (b) we are now inconsistent on whether this needs to be defined on the ModelTokenizer or the BaseTokenizer. If it is a function of tokenize_messages on the ModelTokenizer more so than the BaseTokenizer, maybe we should update the Protocol along with the other callsites?
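For illustration, updating the Protocol could look something like this (a sketch under the assumption that the flag moves to the model-level tokenizer; names and signatures here are not torchtune's actual definitions):

from typing import Any, List, Protocol

class ModelTokenizer(Protocol):
    # Hypothetical: expose skip_special_tokens where header/role handling
    # actually lives, rather than on the base tokenizer.
    def decode(
        self,
        token_ids: List[int],
        skip_special_tokens: bool = True,
        **kwargs: Any,
    ) -> str:
        ...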
# We will remove special tokens manually via regex on the decoded string.
# This is because removing all special tokens does not remove the role and
# whitespace added from the special tokens, i.e., the "user" and "\n\n" in
# "<|start_header_id|>user<|end_header_id|>\n\n"
I would maybe move this comment up to where you define self._special_token_regex and self._special_token_header_regex.
This is where it actually happens, so it makes more sense to keep it here? No strong opinions.
Yeah I feel the same, fine to keep it here then
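For reference, the regex-based cleanup described above might look roughly like this (patterns are illustrative and may differ from the actual ones defined on the tokenizer):

import re

# Hypothetical patterns: strip whole headers (including the role name and the
# trailing "\n\n") first, then any remaining standalone special tokens.
_special_token_header_regex = re.compile(
    r"<\|start_header_id\|>.*?<\|end_header_id\|>\n\n"
)
_special_token_regex = re.compile(r"<\|[^|]+\|>")

def _strip_special_tokens(decoded: str) -> str:
    decoded = _special_token_header_regex.sub("", decoded)
    return _special_token_regex.sub("", decoded)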
Looks like the PR is out of sync with the internal diff? Otherwise it looks good to me, stamping so you're unblocked
@KaiserLeave has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Context
What is the purpose of this PR? Is it to add a new feature, fix a bug, update tests and/or documentation, or something else?

With skip_special_tokens=True, the llama3 tokenizer still included the role name because it is part of the header but is not a special token. This PR filters out all tokens between start_header_id and end_header_id. Bonus: also expose skip_special_tokens in Phi3MiniTokenizer for consistency.

Before:
After:
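To make the intended behavior concrete, a rough usage sketch (the tokenizer path and the example strings are illustrative assumptions, not output captured from this PR):

from torchtune.data import Message
from torchtune.models.llama3 import llama3_tokenizer

tokenizer = llama3_tokenizer("/path/to/tokenizer.model")  # placeholder path
tokens, _ = tokenizer.tokenize_messages(
    [Message(role="user", content="hello world")]
)

# Previously, skip_special_tokens=True still leaked the header's role text and
# trailing whitespace (something like "user\n\nhello world..."); with this
# change the whole header span is dropped.
print(tokenizer.decode(tokens, skip_special_tokens=True))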
Test plan
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these, just ask and we will happily help. We also have a contributing page for some guidance on contributing.
- run pre-commit hooks and linters (make sure you've first installed them via pre-commit install)
- run unit tests via pytest tests
- run integration tests via pytest tests -m integration_test
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example and a tutorial example.