Confusion in ParallelEmbedding layer in model.py #298

Open
baihuaxie opened this issue Jan 17, 2025 · 0 comments

baihuaxie commented Jan 17, 2025

Hi there,

In the `ParallelEmbedding` layer defined in `inference/model.py`, consider the following code:

```python
if world_size > 1:
    mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
    x = x - self.vocab_start_idx
    x[mask] = 0
```

If I understand correctly, on each rank the local embedding matrix handles the slice of token ids in the half-open range `[self.vocab_start_idx, self.vocab_end_idx)`, i.e. the right endpoint is excluded.
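To make the ownership concrete, here is a minimal sketch of how such a split could be computed, assuming the vocabulary is divided evenly so that each rank owns `vocab_size // world_size` consecutive ids. The helpers `shard_range` and `owner_rank` are hypothetical names for illustration, not functions from `model.py`.

```python
def shard_range(vocab_size: int, world_size: int, rank: int) -> tuple[int, int]:
    """Return the half-open range [start, end) of token ids owned by `rank`.

    Hypothetical helper: assumes an even split of the vocabulary.
    """
    assert vocab_size % world_size == 0, "vocab_size must be divisible by world_size"
    part = vocab_size // world_size
    start = rank * part
    return start, start + part


def owner_rank(token_id: int, vocab_size: int, world_size: int) -> int:
    """Return the rank whose shard contains `token_id` (same even-split assumption)."""
    return token_id // (vocab_size // world_size)


# Example: a 400-token vocabulary split across 4 ranks.
for r in range(4):
    print(r, shard_range(400, 4, r))
# 0 (0, 100)
# 1 (100, 200)
# 2 (200, 300)
# 3 (300, 400)
print(owner_rank(100, 400, 4))  # 1: id 100 is the first row of rank 1's shard
print(owner_rank(199, 400, 4))  # 1: id 199 is the last row of rank 1's shard
```

Under this split every token id falls in exactly one shard, so the start index of one rank is the (excluded) end index of the previous one.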

The line `x = x - self.vocab_start_idx` shifts the token ids by the starting index assigned to the current rank, so that they can be looked up in the local embedding matrix, whose rows are indexed from 0. For instance, if the local rank handles token ids 100 through 199 (i.e. `[100, 200)`), token id 100 in the original input tensor `x` becomes 0 after the shift.

The next line, `x[mask] = 0`, then sets every out-of-range token id to 0 as well.
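The collision can be reproduced without any GPUs. The sketch below mimics the three quoted lines on plain Python lists (standing in for the tensor `x`) for a rank that owns `[100, 200)`; `shift_and_mask` is a hypothetical name used only for illustration.

```python
def shift_and_mask(ids, start, end):
    """Mimic the three quoted lines on a list of token ids.

    Returns the shifted ids (out-of-range entries set to 0) and the boolean
    mask, which is computed on the ORIGINAL ids, before the shift.
    """
    mask = [(t < start) or (t >= end) for t in ids]           # mask = (x < start) | (x >= end)
    shifted = [t - start for t in ids]                        # x = x - start
    shifted = [0 if m else s for s, m in zip(shifted, mask)]  # x[mask] = 0
    return shifted, mask


ids = [5, 100, 150, 199, 250]
shifted, mask = shift_and_mask(ids, start=100, end=200)
print(shifted)  # [0, 0, 50, 99, 0]
print(mask)     # [True, False, False, False, True]
```

Looking at `shifted` alone, ids 5, 100, and 250 all map to 0 and are indistinguishable; only `mask` still records which of them were out of range.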

What confuses me is this: after these lines, how can I tell whether a 0 in `x` came from (1) a token id that was originally out of range for the current rank, or (2) a valid token id equal to `self.vocab_start_idx`? It seems to me that these two lines shrink the range of token ids the current rank can correctly handle to the open interval `(self.vocab_start_idx, self.vocab_end_idx)`, excluding both endpoints.
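For comparison, the common vocab-parallel pattern (used, for example, by Megatron-LM's `VocabParallelEmbedding`) does not rely on the shifted id alone: the mask computed before the shift is kept, the rows looked up for masked positions are zeroed after the embedding lookup, and the per-rank outputs are summed across ranks with an all-reduce. Assuming `model.py` reuses `mask` the same way (worth confirming against the rest of its `forward`), id `vocab_start_idx` still gets its real embedding from the owning rank, while every other rank contributes zeros. The single-process simulation below checks this and also works as a test that does not need `world_size > 1`; `full_lookup` and `sharded_lookup` are hypothetical names.

```python
import random


def full_lookup(ids, table):
    """Reference: plain embedding lookup with the unsharded table."""
    return [list(table[t]) for t in ids]


def sharded_lookup(ids, table, world_size):
    """Simulate a vocab-parallel embedding on one process.

    Each simulated rank holds a contiguous slice of `table`. Assumption: the
    mask is reused after the lookup to zero out rows for out-of-range ids,
    and the per-rank outputs are summed (standing in for dist.all_reduce).
    """
    vocab_size, dim = len(table), len(table[0])
    assert vocab_size % world_size == 0
    part = vocab_size // world_size
    out = [[0.0] * dim for _ in ids]
    for rank in range(world_size):
        start, end = rank * part, (rank + 1) * part
        local = table[start:end]                          # this rank's weight shard
        mask = [t < start or t >= end for t in ids]       # computed BEFORE the shift
        local_ids = [0 if m else t - start for t, m in zip(ids, mask)]
        y = [list(local[i]) for i in local_ids]           # local lookup
        for row, m in zip(y, mask):
            if m:
                row[:] = [0.0] * dim                      # y[mask] = 0
        for acc, row in zip(out, y):                      # sum across ranks
            for j in range(dim):
                acc[j] += row[j]
    return out


random.seed(0)
table = [[random.random() for _ in range(3)] for _ in range(400)]
ids = [0, 99, 100, 199, 200, 399]  # includes every shard boundary
print(sharded_lookup(ids, table, 4) == full_lookup(ids, table))  # True
```

In this pattern the 0 written into `x` for out-of-range ids is only a placeholder that keeps the local lookup in bounds; its result is discarded through `mask`, so the effective range per rank stays `[vocab_start_idx, vocab_end_idx)`.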

I haven't run the code yet, so this confusion may be due to an oversight on my part. In the given code `world_size = 1` is fixed, so this snippet never executes anyway. I'd appreciate any clarification, or guidance on how to test it.
