
Fixed the llama model #769

Merged: 6 commits merged into pytorch:main on Sep 3, 2024

Conversation

@yiliu30 (Contributor) commented Aug 28, 2024

In training mode (model.setup_caches(..., training=True)), where input_pos is None, freqs_cis is overridden by L208:

if input_pos is None:
    mask = None
    freqs_cis = self.freqs_cis[:idx.shape[1]]
elif not self.linear_causal_mask:
    mask = self.causal_mask[None, None, input_pos]
elif len(input_pos) > 1 and self.linear_causal_mask:  # prefill for linear causal mask
    mask = torch.tril(torch.ones(len(input_pos), self.max_seq_length, dtype=torch.bool, device=input_pos.device)).unsqueeze(0).unsqueeze(0)
else:  # decode_one_token for linear causal mask
    self.causal_mask[0, 0, 0, input_pos] = 1
    mask = self.causal_mask
freqs_cis = self.freqs_cis[input_pos]

This PR attempts to fix this issue and adds some tests for the Llama model.
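
For reference, a rough sketch of the control flow this PR moves toward (pieced together from the diff quoted later in this thread, not the verbatim patch):

if input_pos is None:
    # training / no KV cache: positions are simply 0..seq_len-1
    mask = None
    input_pos = torch.arange(0, idx.shape[1], device=idx.device)
elif not self.linear_causal_mask:
    mask = self.causal_mask[None, None, input_pos]
elif len(input_pos) > 1 and self.linear_causal_mask:  # prefill for linear causal mask
    mask = torch.tril(torch.ones(len(input_pos), self.max_seq_length, dtype=torch.bool, device=input_pos.device)).unsqueeze(0).unsqueeze(0)
else:  # decode_one_token for linear causal mask
    self.causal_mask[0, 0, 0, input_pos] = 1
    mask = self.causal_mask
freqs_cis = self.freqs_cis[input_pos]  # now valid for every branch, nothing is overridden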

Test

pytest -sv ./test/test_ao_models.py

pytorch-bot commented Aug 28, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/769

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 64480e7 with merge base 05224a9:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Aug 28, 2024
@msaroufim msaroufim requested a review from HDCharles August 28, 2024 03:47
Signed-off-by: yiliu30 <[email protected]>
@yiliu30 yiliu30 mentioned this pull request Aug 28, 2024
@HDCharles (Contributor)

This is relevant only to the training path that @gau-nernst added recently. He would be the one to ask; I'll add him as a reviewer.

@HDCharles HDCharles requested a review from gau-nernst August 30, 2024 03:18
@gau-nernst (Collaborator) commented Sep 2, 2024

Sorry for the delay! From what I see, this PR fixes the situation when the model is used for inference and input_pos is None, i.e.:

model = Transformer.from_name(...)
model.setup_caches(training=False)
model(input_ids)

However, it seems like the model is not used this way anywhere in torchao? Or do you have something else planned?

Otherwise, the change looks fine to me.

Note: from what I know, here are the two places where this Llama model is used (call patterns sketched below):

  • For inference in torchao/_models/llama/generate.py, where input_pos is always passed explicitly.
  • For training in benchmarks/quantized_training/pretrain_llama2.py (KV cache is not initialized), where input_pos is None; input_pos is not needed since there is no KV cache.
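
For concreteness, the two call patterns look roughly like this (a sketch only; keyword names follow the snippets elsewhere in this thread, and the sizes are illustrative, not the real configs):

# Inference (generate.py): KV cache set up, positions passed explicitly
model.setup_caches(max_batch_size=1, max_seq_length=2048)
input_pos = torch.arange(input_ids.shape[1], device=input_ids.device)
logits = model(input_ids, input_pos)

# Training (pretrain_llama2.py): no KV cache, no input_pos
model.setup_caches(max_batch_size=1, max_seq_length=2048, training=True)
logits = model(input_ids)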

@yiliu30 (Contributor, Author) commented Sep 3, 2024

Hi @gau-nernst, thanks for providing such a detailed background. I'm using this Llama model for auto-round. The scenario is similar to the training case you mentioned above (no KV cache and no input_pos). More details can be found here:
https://github.com/yiliu30/torchao-fork/blob/21686f1c87b2961ee0245740e2dcaa6e7fbc4f3a/torchao/_models/llama/generate.py#L245-L267

If input_pos is None, I believe we should select self.freqs_cis[:idx.shape[1]] for freqs_cis instead of the full self.freqs_cis; this is what this PR fixes. I also added some UTs, though we may need more in the future to ensure that modifications to this model don't break the generation benchmark.
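
As a quick illustration of why the slice matters (the buffer shape below is made up for illustration, not the real model config):

import torch

max_seq_length = 2048
freqs_cis = torch.randn(max_seq_length, 32, 2)   # precomputed rotary table, one row per position

idx = torch.randint(0, 32000, (1, 16))           # a 16-token prompt, no input_pos
per_token = freqs_cis[: idx.shape[1]]            # (16, 32, 2): one entry per token in the prompt
assert per_token.shape[0] == idx.shape[1]        # the full table would instead span all 2048 positions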

@gau-nernst (Collaborator)
self.freqs_cis[:idx.shape[1]] is already done in the latest main I think. Also evident in this PR diff.

if input_pos is None:
    mask = None
    freqs_cis = self.freqs_cis[:idx.shape[1]]

Given your use case, which also does not use KV cache, I don't think having input_pos = torch.arange(0, idx.shape[1], device=idx.device) is necessary?

I agree that having tests is good to prevent regressions. Since you are already working on this PR, can you also add a short docstring/comment in Llama's forward about how to use this model, i.e. for inference (w/ KV cache + input_pos) and training (no KV cache + no input_pos)?
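
Something along these lines could serve as the requested comment (the wording is only a suggestion; the signature is copied from the diff quoted below):

def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
    """
    Args:
        idx: token indices, shape (batch_size, seq_len).
        input_pos: positions of idx within the sequence.
            - Inference (setup_caches(..., training=False)): pass input_pos explicitly
              so the KV cache is written at the right slots.
            - Training / no KV cache (setup_caches(..., training=True)): leave input_pos
              as None; positions 0..seq_len-1 are assumed.
    """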

Comment on lines 23 to 24
random_model.setup_caches(max_batch_size=batch_size, max_seq_length=seq_len)
out = random_model(input_ids, input_pos)
@gau-nernst (Collaborator) commented:

This will test the "KV cache + input_pos" case. If this is not required, perhaps we can test the inference and training cases separately, i.e.:

# inference case
random_model.setup_caches(training=False)
random_model(input_ids, input_pos)  # input_pos is not None

# training case
random_model.setup_caches(training=True)
random_model(input_ids)  # no input_pos
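
A fuller sketch of what such a test could look like, in pytest style (the tiny ModelArgs fields and the exact setup_caches keywords are assumptions, taken from the snippets in this thread):

import pytest
import torch
from torchao._models.llama.model import ModelArgs, Transformer  # assumed import path

def _tiny_model():
    # Deliberately small config so the test runs fast; field values are illustrative.
    config = ModelArgs(n_layer=2, n_head=4, dim=64, vocab_size=128)
    return Transformer(config)

@pytest.mark.parametrize("training", [False, True])
def test_llama_forward(training):
    model = _tiny_model()
    batch_size, seq_len = 2, 16
    model.setup_caches(max_batch_size=batch_size, max_seq_length=seq_len, training=training)
    input_ids = torch.randint(0, 128, (batch_size, seq_len))
    if training:
        out = model(input_ids)                         # no KV cache, no input_pos
    else:
        out = model(input_ids, torch.arange(seq_len))  # KV cache + explicit positions
    assert out.shape == (batch_size, seq_len, 128)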

@yiliu30 (Contributor, Author) commented Sep 3, 2024

> self.freqs_cis[:idx.shape[1]] is already done in the latest main I think. Also evident in this PR diff.
>
> if input_pos is None:
>     mask = None
>     freqs_cis = self.freqs_cis[:idx.shape[1]]

Yeah, but L208 updates freqs_cis again:

if input_pos is None:
    mask = None
    freqs_cis = self.freqs_cis[:idx.shape[1]]
elif not self.linear_causal_mask:
    mask = self.causal_mask[None, None, input_pos]
elif len(input_pos) > 1 and self.linear_causal_mask:  # prefill for linear causal mask
    mask = torch.tril(torch.ones(len(input_pos), self.max_seq_length, dtype=torch.bool, device=input_pos.device)).unsqueeze(0).unsqueeze(0)
else:  # decode_one_token for linear causal mask
    self.causal_mask[0, 0, 0, input_pos] = 1
    mask = self.causal_mask
freqs_cis = self.freqs_cis[input_pos]

@@ -197,15 +197,17 @@ def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:

if input_pos is None:
    mask = None
    input_pos = torch.arange(0, idx.shape[1], device=idx.device)
@gau-nernst (Collaborator) commented:

Thank you for your reply. It makes sense to separate the logic for creating mask, input_pos, and freqs_cis like you did here. However, is creating input_pos here (for the training case) now not needed?

@yiliu30 (Contributor, Author) commented Sep 3, 2024

I've also tested the inference case where input_pos is None. Is input_pos a required argument for model.forward in inference mode?

@gau-nernst (Collaborator) commented Sep 3, 2024

> inference case where input_pos is None

This is what I mentioned earlier. I think right now there is no code that uses inference w/o input_pos, and your auto-round PR doesn't seem to need it either. I think it's fine to support this case (though we will have a tiny inefficiency: an unneeded input_pos is created during training, which is probably insignificant). Perhaps others will have other comments.

Can you update the PR description to describe the problem this PR fixes more clearly? i.e. when input_pos is None (during training), freqs_cis is overridden by line xxx. I will approve the PR once you add tests for training mode (model.setup_caches(training=True)) and a short docstring/comment in model.forward().

@yiliu30 (Contributor, Author)

Got it, thanks for the clarification! I updated the PR description, UTs and docstring. Please review it again.

@jerryzh168 jerryzh168 merged commit 8c18489 into pytorch:main Sep 3, 2024
17 checks passed
jerryzh168 pushed a commit to jerryzh168/ao that referenced this pull request Sep 4, 2024
* fixed input_pos is None

Signed-off-by: yiliu30 <[email protected]>

* add test

Signed-off-by: yiliu30 <[email protected]>

* update the test

Signed-off-by: yiliu30 <[email protected]>

* update the docstring

Signed-off-by: yiliu30 <[email protected]>

* update the docstring

Signed-off-by: yiliu30 <[email protected]>

---------

Signed-off-by: yiliu30 <[email protected]>
yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024