fix failed test_kjt_bucketize_before_all2all_cpu #2689

Open · wants to merge 1 commit into base: main
Conversation

TroyGarden (Contributor)
Summary:

# context

* found a test failure in an OSS [test run](https://github.com/pytorch/torchrec/actions/runs/12816026713/job/35736016089): P1714445461
* the issue is that a recent change (D65912888) incorrectly calls the `_fx_wrap_tensor_to_device_dtype` function:
```
        block_bucketize_pos=(
            _fx_wrap_tensor_to_device_dtype(block_bucketize_row_pos, kjt.lengths())
            if block_bucketize_row_pos is not None
            else None
        ),
```
where `block_bucketize_row_pos: List[torch.Tensor]`, but the function only accepts a single `torch.Tensor`:

```
@torch.fx.wrap
def _fx_wrap_tensor_to_device_dtype(
    t: torch.Tensor, tensor_device_dtype: torch.Tensor
) -> torch.Tensor:
    return t.to(device=tensor_device_dtype.device, dtype=tensor_device_dtype.dtype)
```
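To illustrate the mismatch, here is a minimal, self-contained sketch (the tensor values below are hypothetical stand-ins, not values from the failing test): calling the wrapper on the whole list fails because a Python list has no `.to` method, while mapping it over each element works.
```
import torch

# hypothetical stand-ins for block_bucketize_row_pos and kjt.lengths()
block_bucketize_row_pos = [torch.tensor([0, 2, 4]), torch.tensor([0, 3, 6])]
lengths = torch.tensor([2, 2], dtype=torch.int64)

def _fx_wrap_tensor_to_device_dtype(t, tensor_device_dtype):
    return t.to(device=tensor_device_dtype.device, dtype=tensor_device_dtype.dtype)

# passing the list directly raises AttributeError: 'list' object has no attribute 'to'
# _fx_wrap_tensor_to_device_dtype(block_bucketize_row_pos, lengths)

# applying the function element-wise works as intended
moved = [_fx_wrap_tensor_to_device_dtype(t, lengths) for t in block_bucketize_row_pos]
```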
* the fix is straightforward: apply a list comprehension over the function (shown here with the conditional completed to match the original call site)
```
        block_bucketize_pos=(
            [
                _fx_wrap_tensor_to_device_dtype(pos, kjt.lengths())  # <---- pay attention here: kjt.lengths()
                for pos in block_bucketize_row_pos
            ]
            if block_bucketize_row_pos is not None
            else None
        ),
```
* according to the previous comments, `block_bucketize_pos`'s dtype should be the same as `kjt._lengths`; however, that triggers the following error
  [error screenshot: F1974430883]
* according to the operator implementation ([codepointer](https://fburl.com/code/9gyyl8h4)), `block_bucketize_pos` should instead have the same dtype as `kjt._values`:
  in the operator, `lengths` has the type name `offset_t`, while `values` has the type name `index_t`, the same as `block_bucketize_pos`.
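Putting the two observations together, a minimal sketch of the corrected call site (assuming `kjt.values()` returns the tensor carrying the `index_t` dtype, as the codepointer indicates) passes `kjt.values()` instead of `kjt.lengths()`:
```
        block_bucketize_pos=(
            [
                # match the dtype of kjt._values (index_t), not kjt._lengths (offset_t)
                _fx_wrap_tensor_to_device_dtype(pos, kjt.values())
                for pos in block_bucketize_row_pos
            ]
            if block_bucketize_row_pos is not None
            else None
        ),
```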

Reviewed By: dstaay-fb

Differential Revision: D68358894

@facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D68358894

