Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add NJT/TD support for EC #2596

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

TroyGarden
Copy link
Contributor

Summary:

Documents

Context

  • Continued from previous D66465376, which adds NJT/TD support for EBC, this diff is for EC
  • As depicted above, we are extending TorchRec input data type from KJT (KeyedJaggedTensor) to TD (TensorDict)
  • Basically we can support TensorDict in both eager mode and distributed (sharded) mode: Input (Union[KJT, TD]) ==> EC ==> Output (KT)
  • In eager mode, we directly call td_to_kjt in the forward function to convert TD to KJT.
  • In distributed mode, we do the conversion inside the ShardedEmbeddingCollection, specifically in the input_dist, where the input sparse features are prepared (permuted) for the KJTAllToAll communication.
  • In the KJT scenario, the input KJT would be permuted (and partially duplicated in some cases), followed by the KJTAllToAll communication.
    While in the TD scenario, the input TD will directly be converted to the permuted KJT ready for the following KJTAllToAll communication.

Verification - input with TensorDict

(Pdb) local_model
DistributedModelParallel(
  (_dmp_wrapped_module): DistributedDataParallel(
    (module): TestSequenceSparseNN(
      (dense): TestDenseArch(
        (linear): Linear(in_features=16, out_features=8, bias=True)
      )
      (sparse): TestSequenceSparseArch(
        (ec): ShardedEmbeddingCollection(
          (lookups): 
           GroupedEmbeddingsLookup(
              (_emb_modules): ModuleList(
                (0): BatchedDenseEmbedding(
                  (_emb_module): DenseTableBatchedEmbeddingBagsCodegen()
                )
              )
            )
           (_input_dists): 
           RwSparseFeaturesDist(
              (_dist): KJTAllToAll()
            )
           (_output_dists): 
           RwSequenceEmbeddingDist(
              (_dist): SequenceEmbeddingsAllToAll()
            )
          (embeddings): ModuleDict(
            (table_0): Module()
            (table_1): Module()
            (table_2): Module()
            (table_3): Module()
            (table_4): Module()
            (table_5): Module()
          )
        )
      )
      (over): TestSequenceOverArch(
        (linear): Linear(in_features=1928, out_features=16, bias=True)
      )
    )
  )
)
  • TD input
(Pdb) local_input
ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433,
         0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056],
        [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146,
         0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671],
        [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315,
         0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678],
        [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320,
         0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617]],
       device='cuda:0'), idlist_features=TensorDict(
    fields={
        feature_0: NestedTensor(shape=torch.Size([4, j5]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_1: NestedTensor(shape=torch.Size([4, j6]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_2: NestedTensor(shape=torch.Size([4, j7]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_3: NestedTensor(shape=torch.Size([4, j8]), device=cuda:0, dtype=torch.int64, is_shared=True)},
    batch_size=torch.Size([]),
    device=cuda:0,
    is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895], device='cuda:0'))
  • unsharded model
(Pdb) global_model
TestSequenceSparseNN(
  (dense): TestDenseArch(
    (linear): Linear(in_features=16, out_features=8, bias=True)
  )
  (sparse): TestSequenceSparseArch(
    (ec): EmbeddingCollection(
      (embeddings): ModuleDict(
        (table_0): Embedding(11, 16)
        (table_1): Embedding(22, 16)
        (table_2): Embedding(33, 16)
        (table_3): Embedding(44, 16)
        (table_4): Embedding(11, 16)
        (table_5): Embedding(22, 16)
      )
    )
  )
  (over): TestSequenceOverArch(
    (linear): Linear(in_features=1928, out_features=16, bias=True)
  )
)
  • TD input
(Pdb) global_input
ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433,
         0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056],
        [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146,
         0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671],
        [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315,
         0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678],
        [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320,
         0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617],
        [0.6807, 0.7970, 0.1164, 0.8487, 0.7730, 0.1654, 0.5599, 0.5923, 0.3909,
         0.4720, 0.9423, 0.7868, 0.3710, 0.6075, 0.6849, 0.1366],
        [0.0246, 0.5967, 0.2838, 0.8114, 0.3761, 0.3963, 0.7792, 0.9119, 0.4026,
         0.4769, 0.1477, 0.0923, 0.0723, 0.4416, 0.4560, 0.9548],
        [0.8666, 0.6254, 0.9162, 0.1954, 0.8466, 0.6498, 0.3412, 0.2098, 0.9786,
         0.3349, 0.7625, 0.3615, 0.8880, 0.0751, 0.8417, 0.5380],
        [0.2857, 0.6871, 0.6694, 0.8206, 0.5142, 0.5641, 0.3780, 0.9441, 0.0964,
         0.2007, 0.1148, 0.8054, 0.1520, 0.3742, 0.6364, 0.9797]],
       device='cuda:0'), idlist_features=TensorDict(
    fields={
        feature_0: NestedTensor(shape=torch.Size([8, j1]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_1: NestedTensor(shape=torch.Size([8, j2]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_2: NestedTensor(shape=torch.Size([8, j3]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_3: NestedTensor(shape=torch.Size([8, j4]), device=cuda:0, dtype=torch.int64, is_shared=True)},
    batch_size=torch.Size([]),
    device=cuda:0,
    is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895, 0.3132, 0.2133, 0.4997, 0.0055],
       device='cuda:0'))

Differential Revision: D66521351

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Nov 26, 2024
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D66521351

TroyGarden added a commit to TroyGarden/torchrec that referenced this pull request Nov 26, 2024
Summary:

# Documents
* [TorchRec NJT Work Items](https://fburl.com/gdoc/gcqq6luv)
* [KJT <> TensorDict](https://docs.google.com/document/d/1zqJL5AESnoKeIt5VZ6K1289fh_1QcSwu76yo0nB4Ecw/edit?tab=t.0#heading=h.bn9zwvg79)
 {F1949248817} 

# Context
* Continued from previous D66465376, which adds NJT/TD support for EBC, this diff is for EC
* As depicted above, we are extending TorchRec input data type from KJT (KeyedJaggedTensor) to TD (TensorDict)
* Basically we can support TensorDict in both **eager mode** and **distributed (sharded) mode**: `Input (Union[KJT, TD]) ==> EC ==> Output (KT)`
* In eager mode, we directly call `td_to_kjt` in the forward function to convert TD to KJT.
* In distributed mode, we do the conversion inside the `ShardedEmbeddingCollection`, specifically in the `input_dist`, where the input sparse features are prepared (permuted) for the `KJTAllToAll` communication.
* In the KJT scenario, the input KJT would be permuted (and partially duplicated in some cases), followed by the `KJTAllToAll` communication. 
While in the TD scenario, the input TD will directly be converted to the permuted KJT ready for the following `KJTAllToAll` communication.
NOTE: This diff re-used a number of existing test framework/cases with minimal critical changes in the `EmbeddingCollection` and `shardedEmbeddingCollection`. Please see the follwoing verifications for NJT/TD correctness.

# Verification - input with TensorDict
* breakpoint at [sharding_single_rank_test](https://fburl.com/code/x74s13fd)
* sharded model
```
(Pdb) local_model
DistributedModelParallel(
  (_dmp_wrapped_module): DistributedDataParallel(
    (module): TestSequenceSparseNN(
      (dense): TestDenseArch(
        (linear): Linear(in_features=16, out_features=8, bias=True)
      )
      (sparse): TestSequenceSparseArch(
        (ec): ShardedEmbeddingCollection(
          (lookups): 
           GroupedEmbeddingsLookup(
              (_emb_modules): ModuleList(
                (0): BatchedDenseEmbedding(
                  (_emb_module): DenseTableBatchedEmbeddingBagsCodegen()
                )
              )
            )
           (_input_dists): 
           RwSparseFeaturesDist(
              (_dist): KJTAllToAll()
            )
           (_output_dists): 
           RwSequenceEmbeddingDist(
              (_dist): SequenceEmbeddingsAllToAll()
            )
          (embeddings): ModuleDict(
            (table_0): Module()
            (table_1): Module()
            (table_2): Module()
            (table_3): Module()
            (table_4): Module()
            (table_5): Module()
          )
        )
      )
      (over): TestSequenceOverArch(
        (linear): Linear(in_features=1928, out_features=16, bias=True)
      )
    )
  )
)
```
* TD input
```
(Pdb) local_input
ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433,
         0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056],
        [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146,
         0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671],
        [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315,
         0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678],
        [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320,
         0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617]],
       device='cuda:0'), idlist_features=TensorDict(
    fields={
        feature_0: NestedTensor(shape=torch.Size([4, j5]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_1: NestedTensor(shape=torch.Size([4, j6]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_2: NestedTensor(shape=torch.Size([4, j7]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_3: NestedTensor(shape=torch.Size([4, j8]), device=cuda:0, dtype=torch.int64, is_shared=True)},
    batch_size=torch.Size([]),
    device=cuda:0,
    is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895], device='cuda:0'))
```
* unsharded model
```
(Pdb) global_model
TestSequenceSparseNN(
  (dense): TestDenseArch(
    (linear): Linear(in_features=16, out_features=8, bias=True)
  )
  (sparse): TestSequenceSparseArch(
    (ec): EmbeddingCollection(
      (embeddings): ModuleDict(
        (table_0): Embedding(11, 16)
        (table_1): Embedding(22, 16)
        (table_2): Embedding(33, 16)
        (table_3): Embedding(44, 16)
        (table_4): Embedding(11, 16)
        (table_5): Embedding(22, 16)
      )
    )
  )
  (over): TestSequenceOverArch(
    (linear): Linear(in_features=1928, out_features=16, bias=True)
  )
)
```
* TD input
```
(Pdb) global_input
ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433,
         0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056],
        [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146,
         0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671],
        [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315,
         0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678],
        [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320,
         0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617],
        [0.6807, 0.7970, 0.1164, 0.8487, 0.7730, 0.1654, 0.5599, 0.5923, 0.3909,
         0.4720, 0.9423, 0.7868, 0.3710, 0.6075, 0.6849, 0.1366],
        [0.0246, 0.5967, 0.2838, 0.8114, 0.3761, 0.3963, 0.7792, 0.9119, 0.4026,
         0.4769, 0.1477, 0.0923, 0.0723, 0.4416, 0.4560, 0.9548],
        [0.8666, 0.6254, 0.9162, 0.1954, 0.8466, 0.6498, 0.3412, 0.2098, 0.9786,
         0.3349, 0.7625, 0.3615, 0.8880, 0.0751, 0.8417, 0.5380],
        [0.2857, 0.6871, 0.6694, 0.8206, 0.5142, 0.5641, 0.3780, 0.9441, 0.0964,
         0.2007, 0.1148, 0.8054, 0.1520, 0.3742, 0.6364, 0.9797]],
       device='cuda:0'), idlist_features=TensorDict(
    fields={
        feature_0: NestedTensor(shape=torch.Size([8, j1]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_1: NestedTensor(shape=torch.Size([8, j2]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_2: NestedTensor(shape=torch.Size([8, j3]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_3: NestedTensor(shape=torch.Size([8, j4]), device=cuda:0, dtype=torch.int64, is_shared=True)},
    batch_size=torch.Size([]),
    device=cuda:0,
    is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895, 0.3132, 0.2133, 0.4997, 0.0055],
       device='cuda:0'))
```

Differential Revision: D66521351
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D66521351

TroyGarden added a commit to TroyGarden/torchrec that referenced this pull request Nov 27, 2024
Summary:

# Documents
* [TorchRec NJT Work Items](https://fburl.com/gdoc/gcqq6luv)
* [KJT <> TensorDict](https://docs.google.com/document/d/1zqJL5AESnoKeIt5VZ6K1289fh_1QcSwu76yo0nB4Ecw/edit?tab=t.0#heading=h.bn9zwvg79)
 {F1949248817} 

# Context
* Continued from previous D66465376, which adds NJT/TD support for EBC, this diff is for EC
* As depicted above, we are extending TorchRec input data type from KJT (KeyedJaggedTensor) to TD (TensorDict)
* Basically we can support TensorDict in both **eager mode** and **distributed (sharded) mode**: `Input (Union[KJT, TD]) ==> EC ==> Output (KT)`
* In eager mode, we directly call `td_to_kjt` in the forward function to convert TD to KJT.
* In distributed mode, we do the conversion inside the `ShardedEmbeddingCollection`, specifically in the `input_dist`, where the input sparse features are prepared (permuted) for the `KJTAllToAll` communication.
* In the KJT scenario, the input KJT would be permuted (and partially duplicated in some cases), followed by the `KJTAllToAll` communication. 
While in the TD scenario, the input TD will directly be converted to the permuted KJT ready for the following `KJTAllToAll` communication.
NOTE: This diff re-used a number of existing test framework/cases with minimal critical changes in the `EmbeddingCollection` and `shardedEmbeddingCollection`. Please see the follwoing verifications for NJT/TD correctness.

# Verification - input with TensorDict
* breakpoint at [sharding_single_rank_test](https://fburl.com/code/x74s13fd)
* sharded model
```
(Pdb) local_model
DistributedModelParallel(
  (_dmp_wrapped_module): DistributedDataParallel(
    (module): TestSequenceSparseNN(
      (dense): TestDenseArch(
        (linear): Linear(in_features=16, out_features=8, bias=True)
      )
      (sparse): TestSequenceSparseArch(
        (ec): ShardedEmbeddingCollection(
          (lookups): 
           GroupedEmbeddingsLookup(
              (_emb_modules): ModuleList(
                (0): BatchedDenseEmbedding(
                  (_emb_module): DenseTableBatchedEmbeddingBagsCodegen()
                )
              )
            )
           (_input_dists): 
           RwSparseFeaturesDist(
              (_dist): KJTAllToAll()
            )
           (_output_dists): 
           RwSequenceEmbeddingDist(
              (_dist): SequenceEmbeddingsAllToAll()
            )
          (embeddings): ModuleDict(
            (table_0): Module()
            (table_1): Module()
            (table_2): Module()
            (table_3): Module()
            (table_4): Module()
            (table_5): Module()
          )
        )
      )
      (over): TestSequenceOverArch(
        (linear): Linear(in_features=1928, out_features=16, bias=True)
      )
    )
  )
)
```
* TD input
```
(Pdb) local_input
ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433,
         0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056],
        [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146,
         0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671],
        [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315,
         0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678],
        [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320,
         0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617]],
       device='cuda:0'), idlist_features=TensorDict(
    fields={
        feature_0: NestedTensor(shape=torch.Size([4, j5]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_1: NestedTensor(shape=torch.Size([4, j6]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_2: NestedTensor(shape=torch.Size([4, j7]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_3: NestedTensor(shape=torch.Size([4, j8]), device=cuda:0, dtype=torch.int64, is_shared=True)},
    batch_size=torch.Size([]),
    device=cuda:0,
    is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895], device='cuda:0'))
```
* unsharded model
```
(Pdb) global_model
TestSequenceSparseNN(
  (dense): TestDenseArch(
    (linear): Linear(in_features=16, out_features=8, bias=True)
  )
  (sparse): TestSequenceSparseArch(
    (ec): EmbeddingCollection(
      (embeddings): ModuleDict(
        (table_0): Embedding(11, 16)
        (table_1): Embedding(22, 16)
        (table_2): Embedding(33, 16)
        (table_3): Embedding(44, 16)
        (table_4): Embedding(11, 16)
        (table_5): Embedding(22, 16)
      )
    )
  )
  (over): TestSequenceOverArch(
    (linear): Linear(in_features=1928, out_features=16, bias=True)
  )
)
```
* TD input
```
(Pdb) global_input
ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433,
         0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056],
        [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146,
         0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671],
        [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315,
         0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678],
        [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320,
         0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617],
        [0.6807, 0.7970, 0.1164, 0.8487, 0.7730, 0.1654, 0.5599, 0.5923, 0.3909,
         0.4720, 0.9423, 0.7868, 0.3710, 0.6075, 0.6849, 0.1366],
        [0.0246, 0.5967, 0.2838, 0.8114, 0.3761, 0.3963, 0.7792, 0.9119, 0.4026,
         0.4769, 0.1477, 0.0923, 0.0723, 0.4416, 0.4560, 0.9548],
        [0.8666, 0.6254, 0.9162, 0.1954, 0.8466, 0.6498, 0.3412, 0.2098, 0.9786,
         0.3349, 0.7625, 0.3615, 0.8880, 0.0751, 0.8417, 0.5380],
        [0.2857, 0.6871, 0.6694, 0.8206, 0.5142, 0.5641, 0.3780, 0.9441, 0.0964,
         0.2007, 0.1148, 0.8054, 0.1520, 0.3742, 0.6364, 0.9797]],
       device='cuda:0'), idlist_features=TensorDict(
    fields={
        feature_0: NestedTensor(shape=torch.Size([8, j1]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_1: NestedTensor(shape=torch.Size([8, j2]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_2: NestedTensor(shape=torch.Size([8, j3]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_3: NestedTensor(shape=torch.Size([8, j4]), device=cuda:0, dtype=torch.int64, is_shared=True)},
    batch_size=torch.Size([]),
    device=cuda:0,
    is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895, 0.3132, 0.2133, 0.4997, 0.0055],
       device='cuda:0'))
```

Differential Revision: D66521351
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D66521351

TroyGarden added a commit to TroyGarden/torchrec that referenced this pull request Dec 4, 2024
Summary:

# Documents
* [TorchRec NJT Work Items](https://fburl.com/gdoc/gcqq6luv)
* [KJT <> TensorDict](https://docs.google.com/document/d/1zqJL5AESnoKeIt5VZ6K1289fh_1QcSwu76yo0nB4Ecw/edit?tab=t.0#heading=h.bn9zwvg79)
 {F1949248817} 

# Context
* Continued from previous D66465376, which adds NJT/TD support for EBC, this diff is for EC
* As depicted above, we are extending TorchRec input data type from KJT (KeyedJaggedTensor) to TD (TensorDict)
* Basically we can support TensorDict in both **eager mode** and **distributed (sharded) mode**: `Input (Union[KJT, TD]) ==> EC ==> Output (KT)`
* In eager mode, we directly call `td_to_kjt` in the forward function to convert TD to KJT.
* In distributed mode, we do the conversion inside the `ShardedEmbeddingCollection`, specifically in the `input_dist`, where the input sparse features are prepared (permuted) for the `KJTAllToAll` communication.
* In the KJT scenario, the input KJT would be permuted (and partially duplicated in some cases), followed by the `KJTAllToAll` communication. 
While in the TD scenario, the input TD will directly be converted to the permuted KJT ready for the following `KJTAllToAll` communication.
NOTE: This diff re-used a number of existing test framework/cases with minimal critical changes in the `EmbeddingCollection` and `shardedEmbeddingCollection`. Please see the follwoing verifications for NJT/TD correctness.

# Verification - input with TensorDict
* breakpoint at [sharding_single_rank_test](https://fburl.com/code/x74s13fd)
* sharded model
```
(Pdb) local_model
DistributedModelParallel(
  (_dmp_wrapped_module): DistributedDataParallel(
    (module): TestSequenceSparseNN(
      (dense): TestDenseArch(
        (linear): Linear(in_features=16, out_features=8, bias=True)
      )
      (sparse): TestSequenceSparseArch(
        (ec): ShardedEmbeddingCollection(
          (lookups): 
           GroupedEmbeddingsLookup(
              (_emb_modules): ModuleList(
                (0): BatchedDenseEmbedding(
                  (_emb_module): DenseTableBatchedEmbeddingBagsCodegen()
                )
              )
            )
           (_input_dists): 
           RwSparseFeaturesDist(
              (_dist): KJTAllToAll()
            )
           (_output_dists): 
           RwSequenceEmbeddingDist(
              (_dist): SequenceEmbeddingsAllToAll()
            )
          (embeddings): ModuleDict(
            (table_0): Module()
            (table_1): Module()
            (table_2): Module()
            (table_3): Module()
            (table_4): Module()
            (table_5): Module()
          )
        )
      )
      (over): TestSequenceOverArch(
        (linear): Linear(in_features=1928, out_features=16, bias=True)
      )
    )
  )
)
```
* TD input
```
(Pdb) local_input
ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433,
         0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056],
        [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146,
         0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671],
        [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315,
         0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678],
        [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320,
         0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617]],
       device='cuda:0'), idlist_features=TensorDict(
    fields={
        feature_0: NestedTensor(shape=torch.Size([4, j5]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_1: NestedTensor(shape=torch.Size([4, j6]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_2: NestedTensor(shape=torch.Size([4, j7]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_3: NestedTensor(shape=torch.Size([4, j8]), device=cuda:0, dtype=torch.int64, is_shared=True)},
    batch_size=torch.Size([]),
    device=cuda:0,
    is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895], device='cuda:0'))
```
* unsharded model
```
(Pdb) global_model
TestSequenceSparseNN(
  (dense): TestDenseArch(
    (linear): Linear(in_features=16, out_features=8, bias=True)
  )
  (sparse): TestSequenceSparseArch(
    (ec): EmbeddingCollection(
      (embeddings): ModuleDict(
        (table_0): Embedding(11, 16)
        (table_1): Embedding(22, 16)
        (table_2): Embedding(33, 16)
        (table_3): Embedding(44, 16)
        (table_4): Embedding(11, 16)
        (table_5): Embedding(22, 16)
      )
    )
  )
  (over): TestSequenceOverArch(
    (linear): Linear(in_features=1928, out_features=16, bias=True)
  )
)
```
* TD input
```
(Pdb) global_input
ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433,
         0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056],
        [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146,
         0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671],
        [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315,
         0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678],
        [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320,
         0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617],
        [0.6807, 0.7970, 0.1164, 0.8487, 0.7730, 0.1654, 0.5599, 0.5923, 0.3909,
         0.4720, 0.9423, 0.7868, 0.3710, 0.6075, 0.6849, 0.1366],
        [0.0246, 0.5967, 0.2838, 0.8114, 0.3761, 0.3963, 0.7792, 0.9119, 0.4026,
         0.4769, 0.1477, 0.0923, 0.0723, 0.4416, 0.4560, 0.9548],
        [0.8666, 0.6254, 0.9162, 0.1954, 0.8466, 0.6498, 0.3412, 0.2098, 0.9786,
         0.3349, 0.7625, 0.3615, 0.8880, 0.0751, 0.8417, 0.5380],
        [0.2857, 0.6871, 0.6694, 0.8206, 0.5142, 0.5641, 0.3780, 0.9441, 0.0964,
         0.2007, 0.1148, 0.8054, 0.1520, 0.3742, 0.6364, 0.9797]],
       device='cuda:0'), idlist_features=TensorDict(
    fields={
        feature_0: NestedTensor(shape=torch.Size([8, j1]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_1: NestedTensor(shape=torch.Size([8, j2]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_2: NestedTensor(shape=torch.Size([8, j3]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_3: NestedTensor(shape=torch.Size([8, j4]), device=cuda:0, dtype=torch.int64, is_shared=True)},
    batch_size=torch.Size([]),
    device=cuda:0,
    is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895, 0.3132, 0.2133, 0.4997, 0.0055],
       device='cuda:0'))
```

Differential Revision: D66521351
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D66521351

TroyGarden added a commit to TroyGarden/torchrec that referenced this pull request Dec 15, 2024
Summary:

# Documents
* [TorchRec NJT Work Items](https://fburl.com/gdoc/gcqq6luv)
* [KJT <> TensorDict](https://docs.google.com/document/d/1zqJL5AESnoKeIt5VZ6K1289fh_1QcSwu76yo0nB4Ecw/edit?tab=t.0#heading=h.bn9zwvg79)
 {F1949248817} 

# Context
* Continued from previous D66465376, which adds NJT/TD support for EBC, this diff is for EC
* As depicted above, we are extending TorchRec input data type from KJT (KeyedJaggedTensor) to TD (TensorDict)
* Basically we can support TensorDict in both **eager mode** and **distributed (sharded) mode**: `Input (Union[KJT, TD]) ==> EC ==> Output (KT)`
* In eager mode, we directly call `td_to_kjt` in the forward function to convert TD to KJT.
* In distributed mode, we do the conversion inside the `ShardedEmbeddingCollection`, specifically in the `input_dist`, where the input sparse features are prepared (permuted) for the `KJTAllToAll` communication.
* In the KJT scenario, the input KJT would be permuted (and partially duplicated in some cases), followed by the `KJTAllToAll` communication. 
While in the TD scenario, the input TD will directly be converted to the permuted KJT ready for the following `KJTAllToAll` communication.
NOTE: This diff re-used a number of existing test framework/cases with minimal critical changes in the `EmbeddingCollection` and `shardedEmbeddingCollection`. Please see the follwoing verifications for NJT/TD correctness.

# Verification - input with TensorDict
* breakpoint at [sharding_single_rank_test](https://fburl.com/code/x74s13fd)
* sharded model
```
(Pdb) local_model
DistributedModelParallel(
  (_dmp_wrapped_module): DistributedDataParallel(
    (module): TestSequenceSparseNN(
      (dense): TestDenseArch(
        (linear): Linear(in_features=16, out_features=8, bias=True)
      )
      (sparse): TestSequenceSparseArch(
        (ec): ShardedEmbeddingCollection(
          (lookups): 
           GroupedEmbeddingsLookup(
              (_emb_modules): ModuleList(
                (0): BatchedDenseEmbedding(
                  (_emb_module): DenseTableBatchedEmbeddingBagsCodegen()
                )
              )
            )
           (_input_dists): 
           RwSparseFeaturesDist(
              (_dist): KJTAllToAll()
            )
           (_output_dists): 
           RwSequenceEmbeddingDist(
              (_dist): SequenceEmbeddingsAllToAll()
            )
          (embeddings): ModuleDict(
            (table_0): Module()
            (table_1): Module()
            (table_2): Module()
            (table_3): Module()
            (table_4): Module()
            (table_5): Module()
          )
        )
      )
      (over): TestSequenceOverArch(
        (linear): Linear(in_features=1928, out_features=16, bias=True)
      )
    )
  )
)
```
* TD input
```
(Pdb) local_input
ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433,
         0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056],
        [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146,
         0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671],
        [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315,
         0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678],
        [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320,
         0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617]],
       device='cuda:0'), idlist_features=TensorDict(
    fields={
        feature_0: NestedTensor(shape=torch.Size([4, j5]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_1: NestedTensor(shape=torch.Size([4, j6]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_2: NestedTensor(shape=torch.Size([4, j7]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_3: NestedTensor(shape=torch.Size([4, j8]), device=cuda:0, dtype=torch.int64, is_shared=True)},
    batch_size=torch.Size([]),
    device=cuda:0,
    is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895], device='cuda:0'))
```
* unsharded model
```
(Pdb) global_model
TestSequenceSparseNN(
  (dense): TestDenseArch(
    (linear): Linear(in_features=16, out_features=8, bias=True)
  )
  (sparse): TestSequenceSparseArch(
    (ec): EmbeddingCollection(
      (embeddings): ModuleDict(
        (table_0): Embedding(11, 16)
        (table_1): Embedding(22, 16)
        (table_2): Embedding(33, 16)
        (table_3): Embedding(44, 16)
        (table_4): Embedding(11, 16)
        (table_5): Embedding(22, 16)
      )
    )
  )
  (over): TestSequenceOverArch(
    (linear): Linear(in_features=1928, out_features=16, bias=True)
  )
)
```
* TD input
```
(Pdb) global_input
ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433,
         0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056],
        [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146,
         0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671],
        [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315,
         0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678],
        [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320,
         0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617],
        [0.6807, 0.7970, 0.1164, 0.8487, 0.7730, 0.1654, 0.5599, 0.5923, 0.3909,
         0.4720, 0.9423, 0.7868, 0.3710, 0.6075, 0.6849, 0.1366],
        [0.0246, 0.5967, 0.2838, 0.8114, 0.3761, 0.3963, 0.7792, 0.9119, 0.4026,
         0.4769, 0.1477, 0.0923, 0.0723, 0.4416, 0.4560, 0.9548],
        [0.8666, 0.6254, 0.9162, 0.1954, 0.8466, 0.6498, 0.3412, 0.2098, 0.9786,
         0.3349, 0.7625, 0.3615, 0.8880, 0.0751, 0.8417, 0.5380],
        [0.2857, 0.6871, 0.6694, 0.8206, 0.5142, 0.5641, 0.3780, 0.9441, 0.0964,
         0.2007, 0.1148, 0.8054, 0.1520, 0.3742, 0.6364, 0.9797]],
       device='cuda:0'), idlist_features=TensorDict(
    fields={
        feature_0: NestedTensor(shape=torch.Size([8, j1]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_1: NestedTensor(shape=torch.Size([8, j2]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_2: NestedTensor(shape=torch.Size([8, j3]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_3: NestedTensor(shape=torch.Size([8, j4]), device=cuda:0, dtype=torch.int64, is_shared=True)},
    batch_size=torch.Size([]),
    device=cuda:0,
    is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895, 0.3132, 0.2133, 0.4997, 0.0055],
       device='cuda:0'))
```

Differential Revision: D66521351
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D66521351

TroyGarden added a commit to TroyGarden/torchrec that referenced this pull request Dec 30, 2024
Summary:

# Documents
* [TorchRec NJT Work Items](https://fburl.com/gdoc/gcqq6luv)
* [KJT <> TensorDict](https://docs.google.com/document/d/1zqJL5AESnoKeIt5VZ6K1289fh_1QcSwu76yo0nB4Ecw/edit?tab=t.0#heading=h.bn9zwvg79)
 {F1949248817} 

# Context
* Continued from previous D66465376, which adds NJT/TD support for EBC, this diff is for EC
* As depicted above, we are extending TorchRec input data type from KJT (KeyedJaggedTensor) to TD (TensorDict)
* Basically we can support TensorDict in both **eager mode** and **distributed (sharded) mode**: `Input (Union[KJT, TD]) ==> EC ==> Output (KT)`
* In eager mode, we directly call `td_to_kjt` in the forward function to convert TD to KJT.
* In distributed mode, we do the conversion inside the `ShardedEmbeddingCollection`, specifically in the `input_dist`, where the input sparse features are prepared (permuted) for the `KJTAllToAll` communication.
* In the KJT scenario, the input KJT would be permuted (and partially duplicated in some cases), followed by the `KJTAllToAll` communication. 
While in the TD scenario, the input TD will directly be converted to the permuted KJT ready for the following `KJTAllToAll` communication.
NOTE: This diff re-used a number of existing test framework/cases with minimal critical changes in the `EmbeddingCollection` and `shardedEmbeddingCollection`. Please see the follwoing verifications for NJT/TD correctness.

# Verification - input with TensorDict
* breakpoint at [sharding_single_rank_test](https://fburl.com/code/x74s13fd)
* sharded model
```
(Pdb) local_model
DistributedModelParallel(
  (_dmp_wrapped_module): DistributedDataParallel(
    (module): TestSequenceSparseNN(
      (dense): TestDenseArch(
        (linear): Linear(in_features=16, out_features=8, bias=True)
      )
      (sparse): TestSequenceSparseArch(
        (ec): ShardedEmbeddingCollection(
          (lookups): 
           GroupedEmbeddingsLookup(
              (_emb_modules): ModuleList(
                (0): BatchedDenseEmbedding(
                  (_emb_module): DenseTableBatchedEmbeddingBagsCodegen()
                )
              )
            )
           (_input_dists): 
           RwSparseFeaturesDist(
              (_dist): KJTAllToAll()
            )
           (_output_dists): 
           RwSequenceEmbeddingDist(
              (_dist): SequenceEmbeddingsAllToAll()
            )
          (embeddings): ModuleDict(
            (table_0): Module()
            (table_1): Module()
            (table_2): Module()
            (table_3): Module()
            (table_4): Module()
            (table_5): Module()
          )
        )
      )
      (over): TestSequenceOverArch(
        (linear): Linear(in_features=1928, out_features=16, bias=True)
      )
    )
  )
)
```
* TD input
```
(Pdb) local_input
ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433,
         0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056],
        [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146,
         0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671],
        [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315,
         0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678],
        [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320,
         0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617]],
       device='cuda:0'), idlist_features=TensorDict(
    fields={
        feature_0: NestedTensor(shape=torch.Size([4, j5]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_1: NestedTensor(shape=torch.Size([4, j6]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_2: NestedTensor(shape=torch.Size([4, j7]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_3: NestedTensor(shape=torch.Size([4, j8]), device=cuda:0, dtype=torch.int64, is_shared=True)},
    batch_size=torch.Size([]),
    device=cuda:0,
    is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895], device='cuda:0'))
```
* unsharded model
```
(Pdb) global_model
TestSequenceSparseNN(
  (dense): TestDenseArch(
    (linear): Linear(in_features=16, out_features=8, bias=True)
  )
  (sparse): TestSequenceSparseArch(
    (ec): EmbeddingCollection(
      (embeddings): ModuleDict(
        (table_0): Embedding(11, 16)
        (table_1): Embedding(22, 16)
        (table_2): Embedding(33, 16)
        (table_3): Embedding(44, 16)
        (table_4): Embedding(11, 16)
        (table_5): Embedding(22, 16)
      )
    )
  )
  (over): TestSequenceOverArch(
    (linear): Linear(in_features=1928, out_features=16, bias=True)
  )
)
```
* TD input
```
(Pdb) global_input
ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433,
         0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056],
        [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146,
         0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671],
        [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315,
         0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678],
        [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320,
         0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617],
        [0.6807, 0.7970, 0.1164, 0.8487, 0.7730, 0.1654, 0.5599, 0.5923, 0.3909,
         0.4720, 0.9423, 0.7868, 0.3710, 0.6075, 0.6849, 0.1366],
        [0.0246, 0.5967, 0.2838, 0.8114, 0.3761, 0.3963, 0.7792, 0.9119, 0.4026,
         0.4769, 0.1477, 0.0923, 0.0723, 0.4416, 0.4560, 0.9548],
        [0.8666, 0.6254, 0.9162, 0.1954, 0.8466, 0.6498, 0.3412, 0.2098, 0.9786,
         0.3349, 0.7625, 0.3615, 0.8880, 0.0751, 0.8417, 0.5380],
        [0.2857, 0.6871, 0.6694, 0.8206, 0.5142, 0.5641, 0.3780, 0.9441, 0.0964,
         0.2007, 0.1148, 0.8054, 0.1520, 0.3742, 0.6364, 0.9797]],
       device='cuda:0'), idlist_features=TensorDict(
    fields={
        feature_0: NestedTensor(shape=torch.Size([8, j1]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_1: NestedTensor(shape=torch.Size([8, j2]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_2: NestedTensor(shape=torch.Size([8, j3]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_3: NestedTensor(shape=torch.Size([8, j4]), device=cuda:0, dtype=torch.int64, is_shared=True)},
    batch_size=torch.Size([]),
    device=cuda:0,
    is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895, 0.3132, 0.2133, 0.4997, 0.0055],
       device='cuda:0'))
```

Differential Revision: D66521351
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D66521351

TroyGarden added a commit to TroyGarden/torchrec that referenced this pull request Dec 31, 2024
Summary:

# Documents
* [TorchRec NJT Work Items](https://fburl.com/gdoc/gcqq6luv)
* [KJT <> TensorDict](https://docs.google.com/document/d/1zqJL5AESnoKeIt5VZ6K1289fh_1QcSwu76yo0nB4Ecw/edit?tab=t.0#heading=h.bn9zwvg79)
 {F1949248817} 

# Context
* Continued from previous D66465376, which adds NJT/TD support for EBC, this diff is for EC
* As depicted above, we are extending TorchRec input data type from KJT (KeyedJaggedTensor) to TD (TensorDict)
* Basically we can support TensorDict in both **eager mode** and **distributed (sharded) mode**: `Input (Union[KJT, TD]) ==> EC ==> Output (KT)`
* In eager mode, we directly call `td_to_kjt` in the forward function to convert TD to KJT.
* In distributed mode, we do the conversion inside the `ShardedEmbeddingCollection`, specifically in the `input_dist`, where the input sparse features are prepared (permuted) for the `KJTAllToAll` communication.
* In the KJT scenario, the input KJT would be permuted (and partially duplicated in some cases), followed by the `KJTAllToAll` communication. 
While in the TD scenario, the input TD will directly be converted to the permuted KJT ready for the following `KJTAllToAll` communication.
NOTE: This diff re-used a number of existing test framework/cases with minimal critical changes in the `EmbeddingCollection` and `shardedEmbeddingCollection`. Please see the follwoing verifications for NJT/TD correctness.

# Verification - input with TensorDict
* breakpoint at [sharding_single_rank_test](https://fburl.com/code/x74s13fd)
* sharded model
```
(Pdb) local_model
DistributedModelParallel(
  (_dmp_wrapped_module): DistributedDataParallel(
    (module): TestSequenceSparseNN(
      (dense): TestDenseArch(
        (linear): Linear(in_features=16, out_features=8, bias=True)
      )
      (sparse): TestSequenceSparseArch(
        (ec): ShardedEmbeddingCollection(
          (lookups): 
           GroupedEmbeddingsLookup(
              (_emb_modules): ModuleList(
                (0): BatchedDenseEmbedding(
                  (_emb_module): DenseTableBatchedEmbeddingBagsCodegen()
                )
              )
            )
           (_input_dists): 
           RwSparseFeaturesDist(
              (_dist): KJTAllToAll()
            )
           (_output_dists): 
           RwSequenceEmbeddingDist(
              (_dist): SequenceEmbeddingsAllToAll()
            )
          (embeddings): ModuleDict(
            (table_0): Module()
            (table_1): Module()
            (table_2): Module()
            (table_3): Module()
            (table_4): Module()
            (table_5): Module()
          )
        )
      )
      (over): TestSequenceOverArch(
        (linear): Linear(in_features=1928, out_features=16, bias=True)
      )
    )
  )
)
```
* TD input
```
(Pdb) local_input
ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433,
         0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056],
        [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146,
         0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671],
        [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315,
         0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678],
        [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320,
         0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617]],
       device='cuda:0'), idlist_features=TensorDict(
    fields={
        feature_0: NestedTensor(shape=torch.Size([4, j5]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_1: NestedTensor(shape=torch.Size([4, j6]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_2: NestedTensor(shape=torch.Size([4, j7]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_3: NestedTensor(shape=torch.Size([4, j8]), device=cuda:0, dtype=torch.int64, is_shared=True)},
    batch_size=torch.Size([]),
    device=cuda:0,
    is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895], device='cuda:0'))
```
* unsharded model
```
(Pdb) global_model
TestSequenceSparseNN(
  (dense): TestDenseArch(
    (linear): Linear(in_features=16, out_features=8, bias=True)
  )
  (sparse): TestSequenceSparseArch(
    (ec): EmbeddingCollection(
      (embeddings): ModuleDict(
        (table_0): Embedding(11, 16)
        (table_1): Embedding(22, 16)
        (table_2): Embedding(33, 16)
        (table_3): Embedding(44, 16)
        (table_4): Embedding(11, 16)
        (table_5): Embedding(22, 16)
      )
    )
  )
  (over): TestSequenceOverArch(
    (linear): Linear(in_features=1928, out_features=16, bias=True)
  )
)
```
* TD input
```
(Pdb) global_input
ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433,
         0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056],
        [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146,
         0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671],
        [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315,
         0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678],
        [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320,
         0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617],
        [0.6807, 0.7970, 0.1164, 0.8487, 0.7730, 0.1654, 0.5599, 0.5923, 0.3909,
         0.4720, 0.9423, 0.7868, 0.3710, 0.6075, 0.6849, 0.1366],
        [0.0246, 0.5967, 0.2838, 0.8114, 0.3761, 0.3963, 0.7792, 0.9119, 0.4026,
         0.4769, 0.1477, 0.0923, 0.0723, 0.4416, 0.4560, 0.9548],
        [0.8666, 0.6254, 0.9162, 0.1954, 0.8466, 0.6498, 0.3412, 0.2098, 0.9786,
         0.3349, 0.7625, 0.3615, 0.8880, 0.0751, 0.8417, 0.5380],
        [0.2857, 0.6871, 0.6694, 0.8206, 0.5142, 0.5641, 0.3780, 0.9441, 0.0964,
         0.2007, 0.1148, 0.8054, 0.1520, 0.3742, 0.6364, 0.9797]],
       device='cuda:0'), idlist_features=TensorDict(
    fields={
        feature_0: NestedTensor(shape=torch.Size([8, j1]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_1: NestedTensor(shape=torch.Size([8, j2]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_2: NestedTensor(shape=torch.Size([8, j3]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_3: NestedTensor(shape=torch.Size([8, j4]), device=cuda:0, dtype=torch.int64, is_shared=True)},
    batch_size=torch.Size([]),
    device=cuda:0,
    is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895, 0.3132, 0.2133, 0.4997, 0.0055],
       device='cuda:0'))
```

Differential Revision: D66521351
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D66521351

TroyGarden added a commit to TroyGarden/torchrec that referenced this pull request Jan 5, 2025
Summary:

# Documents
* [TorchRec NJT Work Items](https://fburl.com/gdoc/gcqq6luv)
* [KJT <> TensorDict](https://docs.google.com/document/d/1zqJL5AESnoKeIt5VZ6K1289fh_1QcSwu76yo0nB4Ecw/edit?tab=t.0#heading=h.bn9zwvg79)
 {F1949248817} 

# Context
* Continued from previous D66465376, which adds NJT/TD support for EBC, this diff is for EC
* As depicted above, we are extending TorchRec input data type from KJT (KeyedJaggedTensor) to TD (TensorDict)
* Basically we can support TensorDict in both **eager mode** and **distributed (sharded) mode**: `Input (Union[KJT, TD]) ==> EC ==> Output (KT)`
* In eager mode, we directly call `td_to_kjt` in the forward function to convert TD to KJT.
* In distributed mode, we do the conversion inside the `ShardedEmbeddingCollection`, specifically in the `input_dist`, where the input sparse features are prepared (permuted) for the `KJTAllToAll` communication.
* In the KJT scenario, the input KJT would be permuted (and partially duplicated in some cases), followed by the `KJTAllToAll` communication. 
While in the TD scenario, the input TD will directly be converted to the permuted KJT ready for the following `KJTAllToAll` communication.
NOTE: This diff re-used a number of existing test framework/cases with minimal critical changes in the `EmbeddingCollection` and `shardedEmbeddingCollection`. Please see the follwoing verifications for NJT/TD correctness.

# Verification - input with TensorDict
* breakpoint at [sharding_single_rank_test](https://fburl.com/code/x74s13fd)
* sharded model
```
(Pdb) local_model
DistributedModelParallel(
  (_dmp_wrapped_module): DistributedDataParallel(
    (module): TestSequenceSparseNN(
      (dense): TestDenseArch(
        (linear): Linear(in_features=16, out_features=8, bias=True)
      )
      (sparse): TestSequenceSparseArch(
        (ec): ShardedEmbeddingCollection(
          (lookups): 
           GroupedEmbeddingsLookup(
              (_emb_modules): ModuleList(
                (0): BatchedDenseEmbedding(
                  (_emb_module): DenseTableBatchedEmbeddingBagsCodegen()
                )
              )
            )
           (_input_dists): 
           RwSparseFeaturesDist(
              (_dist): KJTAllToAll()
            )
           (_output_dists): 
           RwSequenceEmbeddingDist(
              (_dist): SequenceEmbeddingsAllToAll()
            )
          (embeddings): ModuleDict(
            (table_0): Module()
            (table_1): Module()
            (table_2): Module()
            (table_3): Module()
            (table_4): Module()
            (table_5): Module()
          )
        )
      )
      (over): TestSequenceOverArch(
        (linear): Linear(in_features=1928, out_features=16, bias=True)
      )
    )
  )
)
```
* TD input
```
(Pdb) local_input
ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433,
         0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056],
        [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146,
         0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671],
        [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315,
         0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678],
        [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320,
         0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617]],
       device='cuda:0'), idlist_features=TensorDict(
    fields={
        feature_0: NestedTensor(shape=torch.Size([4, j5]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_1: NestedTensor(shape=torch.Size([4, j6]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_2: NestedTensor(shape=torch.Size([4, j7]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_3: NestedTensor(shape=torch.Size([4, j8]), device=cuda:0, dtype=torch.int64, is_shared=True)},
    batch_size=torch.Size([]),
    device=cuda:0,
    is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895], device='cuda:0'))
```
* unsharded model
```
(Pdb) global_model
TestSequenceSparseNN(
  (dense): TestDenseArch(
    (linear): Linear(in_features=16, out_features=8, bias=True)
  )
  (sparse): TestSequenceSparseArch(
    (ec): EmbeddingCollection(
      (embeddings): ModuleDict(
        (table_0): Embedding(11, 16)
        (table_1): Embedding(22, 16)
        (table_2): Embedding(33, 16)
        (table_3): Embedding(44, 16)
        (table_4): Embedding(11, 16)
        (table_5): Embedding(22, 16)
      )
    )
  )
  (over): TestSequenceOverArch(
    (linear): Linear(in_features=1928, out_features=16, bias=True)
  )
)
```
* TD input
```
(Pdb) global_input
ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433,
         0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056],
        [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146,
         0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671],
        [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315,
         0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678],
        [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320,
         0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617],
        [0.6807, 0.7970, 0.1164, 0.8487, 0.7730, 0.1654, 0.5599, 0.5923, 0.3909,
         0.4720, 0.9423, 0.7868, 0.3710, 0.6075, 0.6849, 0.1366],
        [0.0246, 0.5967, 0.2838, 0.8114, 0.3761, 0.3963, 0.7792, 0.9119, 0.4026,
         0.4769, 0.1477, 0.0923, 0.0723, 0.4416, 0.4560, 0.9548],
        [0.8666, 0.6254, 0.9162, 0.1954, 0.8466, 0.6498, 0.3412, 0.2098, 0.9786,
         0.3349, 0.7625, 0.3615, 0.8880, 0.0751, 0.8417, 0.5380],
        [0.2857, 0.6871, 0.6694, 0.8206, 0.5142, 0.5641, 0.3780, 0.9441, 0.0964,
         0.2007, 0.1148, 0.8054, 0.1520, 0.3742, 0.6364, 0.9797]],
       device='cuda:0'), idlist_features=TensorDict(
    fields={
        feature_0: NestedTensor(shape=torch.Size([8, j1]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_1: NestedTensor(shape=torch.Size([8, j2]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_2: NestedTensor(shape=torch.Size([8, j3]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_3: NestedTensor(shape=torch.Size([8, j4]), device=cuda:0, dtype=torch.int64, is_shared=True)},
    batch_size=torch.Size([]),
    device=cuda:0,
    is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895, 0.3132, 0.2133, 0.4997, 0.0055],
       device='cuda:0'))
```

Differential Revision: D66521351
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D66521351

TroyGarden added a commit to TroyGarden/torchrec that referenced this pull request Jan 7, 2025
Summary:

# Documents
* [TorchRec NJT Work Items](https://fburl.com/gdoc/gcqq6luv)
* [KJT <> TensorDict](https://docs.google.com/document/d/1zqJL5AESnoKeIt5VZ6K1289fh_1QcSwu76yo0nB4Ecw/edit?tab=t.0#heading=h.bn9zwvg79)
 {F1949248817} 

# Context
* Continued from previous D66465376, which adds NJT/TD support for EBC, this diff is for EC
* As depicted above, we are extending TorchRec input data type from KJT (KeyedJaggedTensor) to TD (TensorDict)
* Basically we can support TensorDict in both **eager mode** and **distributed (sharded) mode**: `Input (Union[KJT, TD]) ==> EC ==> Output (KT)`
* In eager mode, we directly call `td_to_kjt` in the forward function to convert TD to KJT.
* In distributed mode, we do the conversion inside the `ShardedEmbeddingCollection`, specifically in the `input_dist`, where the input sparse features are prepared (permuted) for the `KJTAllToAll` communication.
* In the KJT scenario, the input KJT would be permuted (and partially duplicated in some cases), followed by the `KJTAllToAll` communication. 
While in the TD scenario, the input TD will directly be converted to the permuted KJT ready for the following `KJTAllToAll` communication.
NOTE: This diff re-used a number of existing test framework/cases with minimal critical changes in the `EmbeddingCollection` and `shardedEmbeddingCollection`. Please see the follwoing verifications for NJT/TD correctness.

# Verification - input with TensorDict
* breakpoint at [sharding_single_rank_test](https://fburl.com/code/x74s13fd)
* sharded model
```
(Pdb) local_model
DistributedModelParallel(
  (_dmp_wrapped_module): DistributedDataParallel(
    (module): TestSequenceSparseNN(
      (dense): TestDenseArch(
        (linear): Linear(in_features=16, out_features=8, bias=True)
      )
      (sparse): TestSequenceSparseArch(
        (ec): ShardedEmbeddingCollection(
          (lookups): 
           GroupedEmbeddingsLookup(
              (_emb_modules): ModuleList(
                (0): BatchedDenseEmbedding(
                  (_emb_module): DenseTableBatchedEmbeddingBagsCodegen()
                )
              )
            )
           (_input_dists): 
           RwSparseFeaturesDist(
              (_dist): KJTAllToAll()
            )
           (_output_dists): 
           RwSequenceEmbeddingDist(
              (_dist): SequenceEmbeddingsAllToAll()
            )
          (embeddings): ModuleDict(
            (table_0): Module()
            (table_1): Module()
            (table_2): Module()
            (table_3): Module()
            (table_4): Module()
            (table_5): Module()
          )
        )
      )
      (over): TestSequenceOverArch(
        (linear): Linear(in_features=1928, out_features=16, bias=True)
      )
    )
  )
)
```
* TD input
```
(Pdb) local_input
ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433,
         0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056],
        [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146,
         0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671],
        [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315,
         0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678],
        [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320,
         0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617]],
       device='cuda:0'), idlist_features=TensorDict(
    fields={
        feature_0: NestedTensor(shape=torch.Size([4, j5]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_1: NestedTensor(shape=torch.Size([4, j6]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_2: NestedTensor(shape=torch.Size([4, j7]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_3: NestedTensor(shape=torch.Size([4, j8]), device=cuda:0, dtype=torch.int64, is_shared=True)},
    batch_size=torch.Size([]),
    device=cuda:0,
    is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895], device='cuda:0'))
```
* unsharded model
```
(Pdb) global_model
TestSequenceSparseNN(
  (dense): TestDenseArch(
    (linear): Linear(in_features=16, out_features=8, bias=True)
  )
  (sparse): TestSequenceSparseArch(
    (ec): EmbeddingCollection(
      (embeddings): ModuleDict(
        (table_0): Embedding(11, 16)
        (table_1): Embedding(22, 16)
        (table_2): Embedding(33, 16)
        (table_3): Embedding(44, 16)
        (table_4): Embedding(11, 16)
        (table_5): Embedding(22, 16)
      )
    )
  )
  (over): TestSequenceOverArch(
    (linear): Linear(in_features=1928, out_features=16, bias=True)
  )
)
```
* TD input
```
(Pdb) global_input
ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433,
         0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056],
        [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146,
         0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671],
        [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315,
         0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678],
        [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320,
         0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617],
        [0.6807, 0.7970, 0.1164, 0.8487, 0.7730, 0.1654, 0.5599, 0.5923, 0.3909,
         0.4720, 0.9423, 0.7868, 0.3710, 0.6075, 0.6849, 0.1366],
        [0.0246, 0.5967, 0.2838, 0.8114, 0.3761, 0.3963, 0.7792, 0.9119, 0.4026,
         0.4769, 0.1477, 0.0923, 0.0723, 0.4416, 0.4560, 0.9548],
        [0.8666, 0.6254, 0.9162, 0.1954, 0.8466, 0.6498, 0.3412, 0.2098, 0.9786,
         0.3349, 0.7625, 0.3615, 0.8880, 0.0751, 0.8417, 0.5380],
        [0.2857, 0.6871, 0.6694, 0.8206, 0.5142, 0.5641, 0.3780, 0.9441, 0.0964,
         0.2007, 0.1148, 0.8054, 0.1520, 0.3742, 0.6364, 0.9797]],
       device='cuda:0'), idlist_features=TensorDict(
    fields={
        feature_0: NestedTensor(shape=torch.Size([8, j1]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_1: NestedTensor(shape=torch.Size([8, j2]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_2: NestedTensor(shape=torch.Size([8, j3]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_3: NestedTensor(shape=torch.Size([8, j4]), device=cuda:0, dtype=torch.int64, is_shared=True)},
    batch_size=torch.Size([]),
    device=cuda:0,
    is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895, 0.3132, 0.2133, 0.4997, 0.0055],
       device='cuda:0'))
```

Reviewed By: dstaay-fb

Differential Revision: D66521351
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D66521351

Summary:

# Documents
* [TorchRec NJT Work Items](https://fburl.com/gdoc/gcqq6luv)
* [KJT <> TensorDict](https://docs.google.com/document/d/1zqJL5AESnoKeIt5VZ6K1289fh_1QcSwu76yo0nB4Ecw/edit?tab=t.0#heading=h.bn9zwvg79)
 {F1949248817}

# Context
* As depicted above, we are extending TorchRec input data type from KJT (KeyedJaggedTensor) to TD (TensorDict)
* Basically we can support TensorDict in both **eager mode** and **distributed (sharded) mode**: `Input (Union[KJT, TD]) ==> EBC ==> Output (KT)`
* In eager mode, we directly call `td_to_kjt` in the forward function to convert TD to KJT.
* In distributed mode, we do the conversion inside the `ShardedEmbeddingBagCollection`, specifically in the `input_dist`, where the input sparse features are prepared (permuted) for the `KJTAllToAll` communication.
* In the KJT scenario, the input KJT would be permuted (and partially duplicated in some cases), followed by the `KJTAllToAll` communication.
While in the TD scenario, the input TD will directly be converted to the permuted KJT ready for the following `KJTAllToAll` communication.
* ref: D63436011

# Details
* `td_to_kjt` implemented in python, which has cpu perf regression. But it's not on the training critical path so it has a minimal impact on the overall training QPS (see test plan benchmark results)
* Currently only support EBC use case
WARNING: `TensorDict` does **NOT** support weighted jagged tensor, **Nor** variable batch_size neither.
NOTE: All the following comparisons are between the **`KJT.permute`** in the KJT input scenario and the **`TD-KJT conversion`** in the TD input scenario.
* Both `KJT.permute` and `TD-KJT conversion` are correctly marked in the `TrainPipelineBase` traces
`TD-KJT conversion` has more real executions in CPU, but the heavy-lifting computation is in GPU, which is delayed/blocked by the backward pass of the previous batch. GPU runtime has a small difference ~10%.
 {F1949366822}
* For the `Copy-Batch-To-GPU` part, TD has more fragmented `HtoD` comms while KJT has a single contiguous `HtoD` comm
Runtime-wise they are similar ~10%
 {F1949374305}
* In the most commonly used `TrainPipelineSparseDist`, where the `Copy-Batch-To-GPU` and the cpu runtime are not on the critical path, we do observe very similar training QPS in the pipeline benchmark ~1%
{F1949390271}
```
  TrainPipelineSparseDist             | Runtime (P90): 26.737 s | Memory (P90): 34.801 GB (TD)
  TrainPipelineSparseDist             | Runtime (P90): 26.539 s | Memory (P90): 34.765 GB (KJT)
```
* increased data size, GPU runtime is 4x
{F1949386106}

# Conclusion
1. [Enablement] With this approach (replacing the `KJT permute` with `TD-KJT conversion`), the EBC can now take `TensorDict` as the module input in both single-GPU and multi-GPU (sharded) scenarios, tested with TrainPipelineBase, TrainPipelineSparseDist, TrainPipelineSemiSync, and TrainPipelinePrefetch.
2. [Performance] The TD host-to-device data transfer might not necessarily be a concern/blocker for the most commonly used train pipeline (TrainPipelineSparseDist).
2. [Feature Support] In order to become production-ready, the TensorDict needs to (1) integrate the `KJT.weights` data, and (2) to support the variable batch size, which are almost used in all the production models.
3. [Improvement] There are two major operations we can improve: (1) move TensorDict from host to device, and (2) convert TD to KJT. Currently they are both in the vanilla state. Since we are not sure how the real traces would be like with production models, we can't tell if these improvements are needed/helpful.

Reviewed By: dstaay-fb

Differential Revision: D65103519
Summary:

# Documents
* [TorchRec NJT Work Items](https://fburl.com/gdoc/gcqq6luv)
* [KJT <> TensorDict](https://docs.google.com/document/d/1zqJL5AESnoKeIt5VZ6K1289fh_1QcSwu76yo0nB4Ecw/edit?tab=t.0#heading=h.bn9zwvg79)
 {F1949248817} 

# Context
* Continued from previous D66465376, which adds NJT/TD support for EBC, this diff is for EC
* As depicted above, we are extending TorchRec input data type from KJT (KeyedJaggedTensor) to TD (TensorDict)
* Basically we can support TensorDict in both **eager mode** and **distributed (sharded) mode**: `Input (Union[KJT, TD]) ==> EC ==> Output (KT)`
* In eager mode, we directly call `td_to_kjt` in the forward function to convert TD to KJT.
* In distributed mode, we do the conversion inside the `ShardedEmbeddingCollection`, specifically in the `input_dist`, where the input sparse features are prepared (permuted) for the `KJTAllToAll` communication.
* In the KJT scenario, the input KJT would be permuted (and partially duplicated in some cases), followed by the `KJTAllToAll` communication. 
While in the TD scenario, the input TD will directly be converted to the permuted KJT ready for the following `KJTAllToAll` communication.
NOTE: This diff re-used a number of existing test framework/cases with minimal critical changes in the `EmbeddingCollection` and `shardedEmbeddingCollection`. Please see the follwoing verifications for NJT/TD correctness.

# Verification - input with TensorDict
* breakpoint at [sharding_single_rank_test](https://fburl.com/code/x74s13fd)
* sharded model
```
(Pdb) local_model
DistributedModelParallel(
  (_dmp_wrapped_module): DistributedDataParallel(
    (module): TestSequenceSparseNN(
      (dense): TestDenseArch(
        (linear): Linear(in_features=16, out_features=8, bias=True)
      )
      (sparse): TestSequenceSparseArch(
        (ec): ShardedEmbeddingCollection(
          (lookups): 
           GroupedEmbeddingsLookup(
              (_emb_modules): ModuleList(
                (0): BatchedDenseEmbedding(
                  (_emb_module): DenseTableBatchedEmbeddingBagsCodegen()
                )
              )
            )
           (_input_dists): 
           RwSparseFeaturesDist(
              (_dist): KJTAllToAll()
            )
           (_output_dists): 
           RwSequenceEmbeddingDist(
              (_dist): SequenceEmbeddingsAllToAll()
            )
          (embeddings): ModuleDict(
            (table_0): Module()
            (table_1): Module()
            (table_2): Module()
            (table_3): Module()
            (table_4): Module()
            (table_5): Module()
          )
        )
      )
      (over): TestSequenceOverArch(
        (linear): Linear(in_features=1928, out_features=16, bias=True)
      )
    )
  )
)
```
* TD input
```
(Pdb) local_input
ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433,
         0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056],
        [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146,
         0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671],
        [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315,
         0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678],
        [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320,
         0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617]],
       device='cuda:0'), idlist_features=TensorDict(
    fields={
        feature_0: NestedTensor(shape=torch.Size([4, j5]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_1: NestedTensor(shape=torch.Size([4, j6]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_2: NestedTensor(shape=torch.Size([4, j7]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_3: NestedTensor(shape=torch.Size([4, j8]), device=cuda:0, dtype=torch.int64, is_shared=True)},
    batch_size=torch.Size([]),
    device=cuda:0,
    is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895], device='cuda:0'))
```
* unsharded model
```
(Pdb) global_model
TestSequenceSparseNN(
  (dense): TestDenseArch(
    (linear): Linear(in_features=16, out_features=8, bias=True)
  )
  (sparse): TestSequenceSparseArch(
    (ec): EmbeddingCollection(
      (embeddings): ModuleDict(
        (table_0): Embedding(11, 16)
        (table_1): Embedding(22, 16)
        (table_2): Embedding(33, 16)
        (table_3): Embedding(44, 16)
        (table_4): Embedding(11, 16)
        (table_5): Embedding(22, 16)
      )
    )
  )
  (over): TestSequenceOverArch(
    (linear): Linear(in_features=1928, out_features=16, bias=True)
  )
)
```
* TD input
```
(Pdb) global_input
ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433,
         0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056],
        [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146,
         0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671],
        [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315,
         0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678],
        [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320,
         0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617],
        [0.6807, 0.7970, 0.1164, 0.8487, 0.7730, 0.1654, 0.5599, 0.5923, 0.3909,
         0.4720, 0.9423, 0.7868, 0.3710, 0.6075, 0.6849, 0.1366],
        [0.0246, 0.5967, 0.2838, 0.8114, 0.3761, 0.3963, 0.7792, 0.9119, 0.4026,
         0.4769, 0.1477, 0.0923, 0.0723, 0.4416, 0.4560, 0.9548],
        [0.8666, 0.6254, 0.9162, 0.1954, 0.8466, 0.6498, 0.3412, 0.2098, 0.9786,
         0.3349, 0.7625, 0.3615, 0.8880, 0.0751, 0.8417, 0.5380],
        [0.2857, 0.6871, 0.6694, 0.8206, 0.5142, 0.5641, 0.3780, 0.9441, 0.0964,
         0.2007, 0.1148, 0.8054, 0.1520, 0.3742, 0.6364, 0.9797]],
       device='cuda:0'), idlist_features=TensorDict(
    fields={
        feature_0: NestedTensor(shape=torch.Size([8, j1]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_1: NestedTensor(shape=torch.Size([8, j2]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_2: NestedTensor(shape=torch.Size([8, j3]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_3: NestedTensor(shape=torch.Size([8, j4]), device=cuda:0, dtype=torch.int64, is_shared=True)},
    batch_size=torch.Size([]),
    device=cuda:0,
    is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895, 0.3132, 0.2133, 0.4997, 0.0055],
       device='cuda:0'))
```

Reviewed By: dstaay-fb

Differential Revision: D66521351
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D66521351

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. fb-exported
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants