Improving GPT2 language model prediction #207
Conversation
I would like to see some additional unit tests added to demonstrate the improvements. For instance:
- Fixed the multiple-spaces bug: you could test that the probability returned for the space character after a space has been typed is some small value (or at least smaller than the probability before the space was typed; ex. "THE" vs. "THE_").
- All letters in the alphabet (except for the backspace character) now have nonzero probability after interpolating with a unigram language model: call predict() for a word that previously returned 0 values and assert that all values are greater than 0.
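A sketch of what those two tests might look like. Since constructing the real GPT2 model is heavyweight, a stub stands in for GPT2LanguageModel here; the predict() return shape (a list of symbol/probability tuples) follows the PR, and everything else in the stub is an assumption:

```python
from string import ascii_uppercase

SPACE_CHAR = "_"  # assumed space symbol

class StubLM:
    """Stand-in with the same predict() return shape as GPT2LanguageModel."""
    def predict(self, evidence):
        typed = "".join(evidence)
        # after a space, the space symbol should get a tiny probability
        space_p = 0.001 if typed.endswith(SPACE_CHAR) else 0.3
        letter_p = (1.0 - space_p) / len(ascii_uppercase)
        probs = {c: letter_p for c in ascii_uppercase}
        probs[SPACE_CHAR] = space_p
        return list(probs.items())

model = StubLM()

# test 1: P(space) drops once a space has been typed ("THE" vs "THE_")
before = dict(model.predict(list("THE")))[SPACE_CHAR]
after = dict(model.predict(list("THE_")))[SPACE_CHAR]
assert after < before

# test 2: after unigram smoothing, every letter has nonzero probability
preds = dict(model.predict(list("THE_")))
assert all(preds[c] > 0.0 for c in ascii_uppercase)
```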
bcipy/language/model/gpt2.py
Outdated
                          'X': 0.0008, 'Z': 0.0005, 'Q': 0.0002, BACKSPACE_CHAR: 0.0}

        # A uniform language model
        self.uniform_lm = dict(zip(self.symbol_set, equally_probable(self.symbol_set, {BACKSPACE_CHAR: 0.0})))
I don't see where you're actually using self.uniform_lm anywhere. Maybe this is for debugging and can be removed? I also don't see any uses of the is_start_of_word attribute.
Yes, I have removed unused attributes.
bcipy/language/model/gpt2.py
Outdated
-    def predict(self, evidence: List[str]) -> List[Tuple]:
+    def predict(self, evidence: List[str], beam_width: int = 20, search_depth: int = 2) -> List[Tuple]:
These additional parameters violate the API. They should be class attributes.
Okay. I changed them to class attributes.
bcipy/language/model/gpt2.py
Outdated
-    def __model_infer(self, text: str) -> List[float]:
+    def __rescale(self, lm: Dict[str, float], coeff: float):
Could this be a utility function in this module? It seems like it might be more generally useful.
I agree. Changed it to a static method.
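For reference, a minimal sketch of what the rescale utility might look like as a module-level or static function; the signature follows the diff above, but the body is an assumption:

```python
from typing import Dict

def rescale(lm: Dict[str, float], coeff: float) -> Dict[str, float]:
    """Scale every probability in a distribution by coeff, e.g. to weight
    one model's contribution before mixing it with another."""
    return {symbol: prob * coeff for symbol, prob in lm.items()}

scaled = rescale({"A": 0.5, "B": 0.5}, 0.8)
```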
bcipy/language/model/gpt2.py
Outdated
@@ -81,84 +182,122 @@ def __get_char_predictions(self, word_prefix: str) -> List[tuple]:

         return char_prob_tuples

-    def __build_vocab(self) -> Dict[int, str]:
+    def __interpolate_language_models(self, lm1: Dict[str, float], lm2: Dict[str, float], coeff: float) -> List[Tuple]:
This seems like it could be a utility function as well.
Also changed it to a static method.
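A hedged sketch of the interpolation utility; the signature follows the diff above, and the linear-mixture body and sort order are assumptions based on the discussion:

```python
from typing import Dict, List, Tuple

def interpolate_language_models(lm1: Dict[str, float],
                                lm2: Dict[str, float],
                                coeff: float) -> List[Tuple[str, float]]:
    """Mix two distributions over the same symbol set as
    coeff * lm1 + (1 - coeff) * lm2, sorted most-probable first."""
    combined = {s: coeff * lm1[s] + (1.0 - coeff) * lm2[s] for s in lm1}
    return sorted(combined.items(), key=lambda kv: kv[1], reverse=True)

mixed = interpolate_language_models({"A": 1.0, "B": 0.0},
                                    {"A": 0.0, "B": 1.0}, 0.8)
```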
bcipy/language/model/gpt2.py
Outdated
            # sort the new candidates based on likelihood and populate the beam
            ordered_candidates = sorted(new_candidates, key=lambda x: x[1], reverse=True)
            beam = ordered_candidates[:beam_width]
It might be useful to factor out the beam search algorithm from the specifics of how we're using it. This should be generalizable.
Sure, I moved the beam search code to another method, self.__beam_search().
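A generalized version of the factored-out search might look like the following; the method name __beam_search comes from the comment above, but this standalone signature and the expand callback are assumptions:

```python
from typing import Callable, List, Tuple

def beam_search(start: str,
                expand: Callable[[str], List[Tuple[str, float]]],
                beam_width: int,
                search_depth: int) -> List[Tuple[str, float]]:
    """Generic beam search: expand maps a hypothesis to (extension, log-prob
    increment) candidates; keep the top beam_width hypotheses at each depth."""
    beam = [(start, 0.0)]
    for _ in range(search_depth):
        new_candidates = [(hyp + ext, score + delta)
                          for hyp, score in beam
                          for ext, delta in expand(hyp)]
        # sort the new candidates based on likelihood and populate the beam
        beam = sorted(new_candidates, key=lambda x: x[1], reverse=True)[:beam_width]
    return beam

# toy expansion: every hypothesis can grow by "A" (more likely) or "B"
beam = beam_search("", lambda hyp: [("A", -1.0), ("B", -2.0)],
                   beam_width=2, search_depth=2)
```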
@lawhead, I added two unit tests that address your comments.
Thanks for making the suggested changes. The predictions seem better than before, but I'm still surprised by how high the SPACE character is ranked when no letters have been typed (model.predict([])) and after a SPACE has just been typed (model.predict(list("THE_"))). This may be worth looking into further and writing some unit tests against.
        # Hard coding a unigram language model trained on ALS phrase dataset
        # for smoothing purpose
        self.unigram_lm = {'E': 0.0998, 'T': 0.096, 'O': 0.0946, 'I': 0.0835,
Would this fix our symbol set? That may be alright given how we want to use it in the short term, but we should validate that the symbol set passed in matches the one you define here.
This serves as a quick way to introduce a unigram language model for interpolation purposes. It uses the same symbol set as the default symbol set, alphabet(). I think for now this should be alright. If the symbol set were a parameter that could change, we would probably need to train a unigram model on that symbol set on the fly, which may be undesirable. For now I have added a check to make sure that the unigram symbol set is the same as the one passed in.
                          'D': 0.0358, 'Y': 0.0324, 'W': 0.0288, 'M': 0.0266,
                          'G': 0.0221, 'C': 0.018, 'K': 0.016, 'P': 0.0145,
                          'F': 0.0117, 'B': 0.0113, 'V': 0.0091, 'J': 0.0016,
                          'X': 0.0008, 'Z': 0.0005, 'Q': 0.0002, BACKSPACE_CHAR: 0.0}
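The symbol-set consistency check discussed above could be sketched like this; validate_symbol_set, the placeholder weights, and the backspace constant are all hypothetical stand-ins for the real bcipy definitions:

```python
from string import ascii_uppercase
from typing import Dict, List

BACKSPACE_CHAR = "<"  # placeholder; the real constant lives in bcipy

# placeholder unigram weights standing in for the hard-coded ALS-phrase model
unigram_lm: Dict[str, float] = {c: 1.0 / 26 for c in ascii_uppercase}
unigram_lm[BACKSPACE_CHAR] = 0.0

def validate_symbol_set(symbol_set: List[str], lm: Dict[str, float]) -> None:
    """Raise if the configured symbol set differs from the unigram LM's keys."""
    if set(symbol_set) != set(lm):
        raise ValueError("symbol set does not match the unigram language model")

# the default alphabet() symbol set matches, so this passes silently
validate_symbol_set(list(ascii_uppercase) + [BACKSPACE_CHAR], unigram_lm)
```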
I'm not sure about the impact of setting a zero probability on BACKSPACE_CHAR?
Currently the language model sets zero probability on BACKSPACE_CHAR, but the copy phrase code would add a nonzero probability to it (I think it is 0.05). I think we can leave it this way for now.
        self.lm_path = lm_path or "gpt2"

        # Hard coding a unigram language model trained on ALS phrase dataset
If this is specific to the population or phrases used, it would be better to load these weights or smoothing parameters in some other way. This overfits the model to the current experiment phrases. How would this change if we used a new dataset? What is the training procedure?
I think we can treat this the same as loading a pretrained GPT2 language model -- we could certainly put the weights in some other file and load them from there, but we would still need to train offline and update that file, which is not much different from what we are doing now. For now I think it's okay to keep it as it is.
bcipy/language/model/gpt2.py
Outdated
        # interpolate with unigram language model to smooth the probability distribution returned
        # by GPT2 language model
        next_char_pred = GPT2LanguageModel.interpolate_language_models(dict(next_char_pred), self.unigram_lm, 0.8)
Bring the hardcoded coefficients up to __init__ for easier configuration.
Yes, I have done it.
Overview
A new beam-search based method for GPT2 character prediction. The language model can predict multiple wordpieces and marginalize the word-level predictions to make character-level predictions.
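The marginalization step described here can be sketched as follows: sum the probability of each beam hypothesis onto its first character, then normalize. The function name and input shape are assumptions, not the PR's actual implementation:

```python
from collections import defaultdict
from typing import Dict, List, Tuple

def marginalize(hypotheses: List[Tuple[str, float]]) -> Dict[str, float]:
    """Collapse word-level (continuation, probability) hypotheses into a
    normalized distribution over the next character."""
    char_probs: Dict[str, float] = defaultdict(float)
    for text, prob in hypotheses:
        if text:
            char_probs[text[0]] += prob
    total = sum(char_probs.values())
    return {char: prob / total for char, prob in char_probs.items()}

# three toy wordpiece continuations from a beam
next_char = marginalize([("AN", 0.5), ("AT", 0.3), ("BE", 0.2)])
```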
Ticket
https://www.pivotaltracker.com/story/show/181349709
Contributions
Test