
Improving GPT2 language model prediction #207

Merged 11 commits into 1.5.1 on Mar 28, 2022
Conversation

Contributor

@sliu126 sliu126 commented Mar 9, 2022

Overview

This PR adds a new beam-search-based method for GPT2 character prediction. The language model predicts multiple wordpieces and marginalizes over the word-level predictions to produce character-level predictions.
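The marginalization step can be sketched as follows. This is a minimal standalone illustration with made-up hypotheses and scores, not the actual BciPy implementation; the function name and beam contents are hypothetical:

```python
from collections import defaultdict
from typing import Dict, List, Tuple

def marginalize_to_chars(hypotheses: List[Tuple[str, float]],
                         prefix: str) -> Dict[str, float]:
    """Collapse word-level hypotheses into a next-character distribution.

    Each hypothesis is (text, probability). The probability of a next
    character is the normalized sum over all hypotheses that extend the
    typed prefix with that character.
    """
    char_probs = defaultdict(float)
    for text, prob in hypotheses:
        if text.startswith(prefix) and len(text) > len(prefix):
            char_probs[text[len(prefix)]] += prob
    total = sum(char_probs.values())
    return {ch: p / total for ch, p in char_probs.items()}

# Hypothetical beam after the user has typed "TH"
beam = [("THE", 0.5), ("THIS", 0.2), ("THAT", 0.2), ("THOSE", 0.1)]
next_char = marginalize_to_chars(beam, "TH")
```

Here "E" inherits the mass of "THE", while "I", "A", and "O" inherit the mass of the other continuations.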

Ticket

https://www.pivotaltracker.com/story/show/181349709

Contributions

  • Fixed the multiple spaces bug
  • Improved the overall (rank-based) performance on the ALS phrase dataset
  • All letters in the alphabet (except for the backspace character) now have nonzero probability after interpolating with a unigram language model
  • Smoothed the language model distribution with an exponential rescaling factor
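The exponential rescaling factor mentioned above can be sketched as follows — a standalone illustration, assuming a power-then-renormalize scheme; the PR implements this as a helper on the model class:

```python
from typing import Dict

def rescale(lm: Dict[str, float], coeff: float) -> Dict[str, float]:
    """Flatten (coeff < 1) or sharpen (coeff > 1) a distribution by
    raising each probability to the power `coeff`, then renormalizing."""
    powered = {sym: prob ** coeff for sym, prob in lm.items()}
    total = sum(powered.values())
    return {sym: p / total for sym, p in powered.items()}

# A peaked distribution becomes flatter with coeff < 1
smoothed = rescale({"A": 0.9, "B": 0.1}, 0.5)
```

With a coefficient below 1, the dominant symbol gives up probability mass to the rest of the alphabet, which smooths the GPT2 output.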

Test

  • Ran all existing unit tests.

@sliu126 sliu126 requested review from lawhead and tab-cmd March 9, 2022 17:53
Collaborator

@lawhead lawhead left a comment

I would like to see some additional unit tests added to demonstrate the improvements. For instance:

  • Fixed the multiple spaces bug: you could test that the probability for the space character returned after a space has been typed is some small value (or at least smaller than the probability before, e.g. "THE" vs. "THE_").
  • All letters in the alphabet (except for the backspace character) now have nonzero probability after interpolating with a unigram language model; call predict for a word that previously returned 0 values and assert that all values are greater than 0.
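A test along the lines of the second suggestion might look like this. The model class here is a stand-in fixture with the same predict() shape; in the real suite it would be the actual GPT2 language model:

```python
import unittest

class FakeLanguageModel:
    """Hypothetical stand-in mirroring the predict() return shape."""
    def predict(self, evidence):
        # A real test would call GPT2LanguageModel.predict instead
        return [("E", 0.4), ("T", 0.3), ("A", 0.2), ("Q", 0.1)]

class TestNonzeroProbabilities(unittest.TestCase):
    def test_all_letters_nonzero(self):
        model = FakeLanguageModel()
        # A word that previously produced zero-probability letters
        predictions = model.predict(list("STRENGTHS"))
        for symbol, prob in predictions:
            self.assertGreater(prob, 0.0, f"{symbol} has zero probability")
```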

'X': 0.0008, 'Z': 0.0005, 'Q': 0.0002, BACKSPACE_CHAR: 0.0}

# A uniform language model
self.uniform_lm = dict(zip(self.symbol_set, equally_probable(self.symbol_set, {BACKSPACE_CHAR: 0.0})))
Collaborator

I don't see where you're actually using self.uniform_lm anywhere. Maybe this is for debugging and can be removed? I also don't see any uses of the is_start_of_word attribute.

Contributor Author

Yes, I have removed unused attributes.


def predict(self, evidence: List[str]) -> List[Tuple]:
def predict(self, evidence: List[str], beam_width: int = 20, search_depth: int = 2) -> List[Tuple]:
Collaborator

These additional variables violate the API. They should be class attributes.

Contributor Author

Okay. I changed them to class attributes.
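The constructor change being discussed might look roughly like this — a sketch only, with the attribute names taken from the review comment, not the full class:

```python
class GPT2LanguageModel:
    """Sketch of the constructor change (not the complete class)."""

    def __init__(self, beam_width: int = 20, search_depth: int = 2):
        # Search parameters become class attributes so that
        # predict(evidence) keeps its original API signature.
        self.beam_width = beam_width
        self.search_depth = search_depth

model = GPT2LanguageModel()
```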


def __model_infer(self, text: str) -> List[float]:
def __rescale(self, lm: Dict[str, float], coeff: float):
Collaborator

Could this be a utility function in this module? It seems like it might be more generally useful.

Contributor Author

I agree. Changed it to a static method.

@@ -81,84 +182,122 @@ def __get_char_predictions(self, word_prefix: str) -> List[tuple]:

return char_prob_tuples

def __build_vocab(self) -> Dict[int, str]:
def __interpolate_language_models(self, lm1: Dict[str, float], lm2: Dict[str, float], coeff: float) -> List[Tuple]:
Collaborator

This seems like it could be a utility function as well.

Contributor Author

Also changed it to a static method.


# sort the new candidates based on likelihood and populate the beam
ordered_candidates = sorted(new_candidates, key=lambda x: x[1], reverse=True)
beam = ordered_candidates[:beam_width]
Collaborator

It might be useful to factor out the beam search algorithm from the specifics of how we're using it. This should be generalizable.

Contributor Author

@sliu126 sliu126 Mar 16, 2022

Sure, I moved the beam search code to another method, self.__beam_search().
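A generalized beam search, decoupled from the model specifics as the review suggests, could look something like this. This is an illustrative sketch with a hypothetical `expand` callback, not the PR's `__beam_search`:

```python
from typing import Callable, List, Tuple

def beam_search(initial: str,
                expand: Callable[[str], List[Tuple[str, float]]],
                beam_width: int,
                search_depth: int) -> List[Tuple[str, float]]:
    """Generic beam search: `expand` maps a hypothesis to scored extensions."""
    beam: List[Tuple[str, float]] = [(initial, 1.0)]
    for _ in range(search_depth):
        new_candidates = []
        for text, score in beam:
            for extension, prob in expand(text):
                new_candidates.append((text + extension, score * prob))
        # sort the new candidates by likelihood and keep the top beam_width
        beam = sorted(new_candidates, key=lambda x: x[1], reverse=True)[:beam_width]
    return beam

# Toy expander: two possible next characters with fixed probabilities
result = beam_search("", lambda t: [("A", 0.6), ("B", 0.4)],
                     beam_width=2, search_depth=2)
```

The language-model-specific part (scoring extensions with GPT2) then lives entirely in the callback, which makes the search reusable.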

Contributor Author

sliu126 commented Mar 16, 2022

@lawhead, I added two unit tests that address your comments.

Collaborator

@lawhead lawhead left a comment

Thanks for making the suggested changes. The predictions seem better than before, but I'm still surprised by how high the SPACE character is ranked when no letters have been typed (model.predict([])) and after a SPACE has just been typed (model.predict(list("THE_"))). This may be worth looking into further and writing some unit tests against.
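One way to pin down the expected SPACE behavior in a test, sketched with hypothetical distributions (using "_" for the space symbol, as in the "THE_" example above):

```python
def space_prob(predictions):
    """Extract the probability assigned to the space character."""
    return dict(predictions).get("_", 0.0)

# Hypothetical distributions illustrating the relationship under test:
# after "THE" a space is plausible; after "THE_" it should be far less so.
after_the = [("_", 0.30), ("R", 0.25), ("N", 0.20), ("S", 0.25)]
after_the_space = [("_", 0.02), ("Q", 0.18), ("A", 0.40), ("B", 0.40)]

assert space_prob(after_the_space) < space_prob(after_the)
```

A real test would replace the hardcoded lists with `model.predict(list("THE"))` and `model.predict(list("THE_"))` and assert the same inequality.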


# Hard coding a unigram language model trained on ALS phrase dataset
# for smoothing purpose
self.unigram_lm = {'E': 0.0998, 'T': 0.096, 'O': 0.0946, 'I': 0.0835,
Contributor

Would this fix our symbol set? This may be alright given how we want to use it in the short term, but we should validate the same set was passed as you define here.

Contributor Author

This serves as a quick way to introduce a unigram language model for interpolation purposes. It uses the same symbol set as the default, alphabet(), so I think this is alright for now. If the symbol set were a parameter that could change, we would need to train a unigram model on that symbol set on the fly, which may be undesirable. For now I have added a check to make sure the unigram symbol set matches the one passed in.
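The symbol-set validation described here can be as simple as the following standalone sketch (the function name is hypothetical; in the PR the check presumably lives in the model's constructor):

```python
def validate_symbol_set(passed_symbols, unigram_lm):
    """Raise if the configured symbol set differs from the keys of the
    hardcoded unigram language model."""
    if set(passed_symbols) != set(unigram_lm.keys()):
        raise ValueError(
            "Symbol set does not match the hardcoded unigram language model")

# Matching sets pass silently
validate_symbol_set(["A", "B"], {"A": 0.6, "B": 0.4})
```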

'D': 0.0358, 'Y': 0.0324, 'W': 0.0288, 'M': 0.0266,
'G': 0.0221, 'C': 0.018, 'K': 0.016, 'P': 0.0145,
'F': 0.0117, 'B': 0.0113, 'V': 0.0091, 'J': 0.0016,
'X': 0.0008, 'Z': 0.0005, 'Q': 0.0002, BACKSPACE_CHAR: 0.0}
Contributor

I'm not sure what the impact of setting a zero probability on BACKSPACE_CHAR would be?

Contributor Author

Currently the language model sets zero probability on BACKSPACE_CHAR, but the copy phrase code would add a nonzero probability to it (I think it is 0.05). I think we can leave it this way for now.

self.lm_path = lm_path or "gpt2"

# Hard coding a unigram language model trained on ALS phrase dataset
Contributor

If this is specific to the population or phrases used, it would be better to load these weights or smoothing parameters in some other way. This overfits the model to the current experiment phrases. How would this change if we used a new dataset? What is the training procedure?

Contributor Author

I think we can think of this as the same as loading a pretrained GPT2 language model -- we can certainly put the weights in some other files and load from there, but we still need to train offline and update those files, which is not much different from what we are doing now. For now I think it's okay to keep it as it is.


# interpolate with unigram language model to smooth the probability distribution returned
# by GPT2 language model
next_char_pred = GPT2LanguageModel.interpolate_language_models(dict(next_char_pred), self.unigram_lm, 0.8)
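The linear interpolation being called here can be sketched as a standalone version of the static method (the exact signature and sorting behavior in the PR may differ):

```python
from typing import Dict, List, Tuple

def interpolate_language_models(lm1: Dict[str, float],
                                lm2: Dict[str, float],
                                coeff: float) -> List[Tuple[str, float]]:
    """Mix two distributions as coeff * lm1 + (1 - coeff) * lm2,
    returned sorted by probability in descending order."""
    combined = {
        sym: coeff * lm1.get(sym, 0.0) + (1 - coeff) * lm2.get(sym, 0.0)
        for sym in set(lm1) | set(lm2)
    }
    return sorted(combined.items(), key=lambda kv: kv[1], reverse=True)

# With coeff=0.8, a symbol that lm1 zeros out still receives
# 20% of the unigram model's mass, keeping it nonzero.
mixed = interpolate_language_models({"A": 1.0, "Z": 0.0},
                                    {"A": 0.9, "Z": 0.1}, 0.8)
```

This is what guarantees the "all letters have nonzero probability" contribution: any symbol with nonzero unigram probability survives interpolation.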
Contributor

bring hardcoded coefficients up to init for easier configuration

Contributor Author

Yes, I have done it.

@tab-cmd tab-cmd self-requested a review March 28, 2022 19:53
@sliu126 sliu126 merged commit 2f99c0d into 1.5.1 Mar 28, 2022
@sliu126 sliu126 deleted the gpt2_beam_search branch March 28, 2022 22:32