hi there,
in the ParallelEmbedding layer defined in inference/model.py, in the following code:
if world_size > 1:
    mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
    x = x - self.vocab_start_idx
    x[mask] = 0
if I understand it correctly, on each rank the local embedding matrix handles the portion of token ids in the range [self.vocab_start_idx, self.vocab_end_idx), with the right endpoint excluded.
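(for reference, I'm assuming each rank's shard boundaries are derived roughly like the sketch below; the names part_vocab_size and rank, and the toy sizes, are my assumptions, not lines copied from inference/model.py)

# assumed sketch of how each rank's vocab shard might be set up (illustrative values only)
vocab_size, world_size, rank = 800, 8, 1

part_vocab_size = vocab_size // world_size           # rows of the embedding held locally
vocab_start_idx = rank * part_vocab_size             # first token id owned by this rank
vocab_end_idx = vocab_start_idx + part_vocab_size    # one past the last owned id (exclusive)
print(vocab_start_idx, vocab_end_idx)                # prints: 100 200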
the line x = x - self.vocab_start_idx shifts the token ids so they line up with the starting index assigned to the current rank and can be looked up in the local embedding matrix (whose rows start at index 0). for instance, if the local rank handles token ids [100, 199], the token id 100 in the original input tensor x becomes 0 after the shift.
the next line then sets the out-of-range token ids (captured by mask, which was computed before the shift) to 0 as well.
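to make this concrete, here is a minimal sketch of what I expect those lines to do on a toy tensor (the shard boundaries 100/200 and the sample token ids are my own assumptions):

import torch

# assumed shard boundaries for the current rank
vocab_start_idx, vocab_end_idx = 100, 200

# token 100 is the first id owned by this rank; token 5 belongs to another rank
x = torch.tensor([100, 150, 5])

mask = (x < vocab_start_idx) | (x >= vocab_end_idx)   # tensor([False, False,  True])
x = x - vocab_start_idx                               # tensor([  0,  50, -95])
x[mask] = 0                                           # tensor([  0,  50,   0])

# the valid token 100 and the out-of-range token 5 both end up as local index 0
print(x)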
what's confusing to me is how I would now distinguish whether a token id = 0 in the shifted tensor x came from 1) an originally out-of-range token id (for the current rank) or 2) an originally valid token id equal to self.vocab_start_idx. it seems to me that these two lines of code have shrunk the range of token ids that the current rank can handle unambiguously to (self.vocab_start_idx, self.vocab_end_idx), with both endpoints excluded.
I haven't run the code yet, so this confusion might be due to an oversight on my part. in the given code world_size = 1 is fixed, so this snippet never gets executed anyway. I'd appreciate some clarification or guidance on tests.