-
Notifications
You must be signed in to change notification settings - Fork 464
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add NJT/TD support for EC #2596
base: main
Are you sure you want to change the base?
Conversation
This pull request was exported from Phabricator. Differential Revision: D66521351 |
6134492
to
9db8d75
Compare
Summary: # Documents * [TorchRec NJT Work Items](https://fburl.com/gdoc/gcqq6luv) * [KJT <> TensorDict](https://docs.google.com/document/d/1zqJL5AESnoKeIt5VZ6K1289fh_1QcSwu76yo0nB4Ecw/edit?tab=t.0#heading=h.bn9zwvg79) {F1949248817} # Context * Continued from previous D66465376, which adds NJT/TD support for EBC, this diff is for EC * As depicted above, we are extending TorchRec input data type from KJT (KeyedJaggedTensor) to TD (TensorDict) * Basically we can support TensorDict in both **eager mode** and **distributed (sharded) mode**: `Input (Union[KJT, TD]) ==> EC ==> Output (KT)` * In eager mode, we directly call `td_to_kjt` in the forward function to convert TD to KJT. * In distributed mode, we do the conversion inside the `ShardedEmbeddingCollection`, specifically in the `input_dist`, where the input sparse features are prepared (permuted) for the `KJTAllToAll` communication. * In the KJT scenario, the input KJT would be permuted (and partially duplicated in some cases), followed by the `KJTAllToAll` communication. While in the TD scenario, the input TD will directly be converted to the permuted KJT ready for the following `KJTAllToAll` communication. NOTE: This diff re-used a number of existing test framework/cases with minimal critical changes in the `EmbeddingCollection` and `shardedEmbeddingCollection`. Please see the follwoing verifications for NJT/TD correctness. # Verification - input with TensorDict * breakpoint at [sharding_single_rank_test](https://fburl.com/code/x74s13fd) * sharded model ``` (Pdb) local_model DistributedModelParallel( (_dmp_wrapped_module): DistributedDataParallel( (module): TestSequenceSparseNN( (dense): TestDenseArch( (linear): Linear(in_features=16, out_features=8, bias=True) ) (sparse): TestSequenceSparseArch( (ec): ShardedEmbeddingCollection( (lookups): GroupedEmbeddingsLookup( (_emb_modules): ModuleList( (0): BatchedDenseEmbedding( (_emb_module): DenseTableBatchedEmbeddingBagsCodegen() ) ) ) (_input_dists): RwSparseFeaturesDist( (_dist): KJTAllToAll() ) (_output_dists): RwSequenceEmbeddingDist( (_dist): SequenceEmbeddingsAllToAll() ) (embeddings): ModuleDict( (table_0): Module() (table_1): Module() (table_2): Module() (table_3): Module() (table_4): Module() (table_5): Module() ) ) ) (over): TestSequenceOverArch( (linear): Linear(in_features=1928, out_features=16, bias=True) ) ) ) ) ``` * TD input ``` (Pdb) local_input ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433, 0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056], [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146, 0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671], [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315, 0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678], [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320, 0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617]], device='cuda:0'), idlist_features=TensorDict( fields={ feature_0: NestedTensor(shape=torch.Size([4, j5]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_1: NestedTensor(shape=torch.Size([4, j6]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_2: NestedTensor(shape=torch.Size([4, j7]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_3: NestedTensor(shape=torch.Size([4, j8]), device=cuda:0, dtype=torch.int64, is_shared=True)}, batch_size=torch.Size([]), device=cuda:0, is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895], device='cuda:0')) ``` * unsharded model ``` (Pdb) global_model TestSequenceSparseNN( (dense): TestDenseArch( (linear): Linear(in_features=16, out_features=8, bias=True) ) (sparse): TestSequenceSparseArch( (ec): EmbeddingCollection( (embeddings): ModuleDict( (table_0): Embedding(11, 16) (table_1): Embedding(22, 16) (table_2): Embedding(33, 16) (table_3): Embedding(44, 16) (table_4): Embedding(11, 16) (table_5): Embedding(22, 16) ) ) ) (over): TestSequenceOverArch( (linear): Linear(in_features=1928, out_features=16, bias=True) ) ) ``` * TD input ``` (Pdb) global_input ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433, 0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056], [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146, 0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671], [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315, 0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678], [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320, 0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617], [0.6807, 0.7970, 0.1164, 0.8487, 0.7730, 0.1654, 0.5599, 0.5923, 0.3909, 0.4720, 0.9423, 0.7868, 0.3710, 0.6075, 0.6849, 0.1366], [0.0246, 0.5967, 0.2838, 0.8114, 0.3761, 0.3963, 0.7792, 0.9119, 0.4026, 0.4769, 0.1477, 0.0923, 0.0723, 0.4416, 0.4560, 0.9548], [0.8666, 0.6254, 0.9162, 0.1954, 0.8466, 0.6498, 0.3412, 0.2098, 0.9786, 0.3349, 0.7625, 0.3615, 0.8880, 0.0751, 0.8417, 0.5380], [0.2857, 0.6871, 0.6694, 0.8206, 0.5142, 0.5641, 0.3780, 0.9441, 0.0964, 0.2007, 0.1148, 0.8054, 0.1520, 0.3742, 0.6364, 0.9797]], device='cuda:0'), idlist_features=TensorDict( fields={ feature_0: NestedTensor(shape=torch.Size([8, j1]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_1: NestedTensor(shape=torch.Size([8, j2]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_2: NestedTensor(shape=torch.Size([8, j3]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_3: NestedTensor(shape=torch.Size([8, j4]), device=cuda:0, dtype=torch.int64, is_shared=True)}, batch_size=torch.Size([]), device=cuda:0, is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895, 0.3132, 0.2133, 0.4997, 0.0055], device='cuda:0')) ``` Differential Revision: D66521351
This pull request was exported from Phabricator. Differential Revision: D66521351 |
Summary: # Documents * [TorchRec NJT Work Items](https://fburl.com/gdoc/gcqq6luv) * [KJT <> TensorDict](https://docs.google.com/document/d/1zqJL5AESnoKeIt5VZ6K1289fh_1QcSwu76yo0nB4Ecw/edit?tab=t.0#heading=h.bn9zwvg79) {F1949248817} # Context * Continued from previous D66465376, which adds NJT/TD support for EBC, this diff is for EC * As depicted above, we are extending TorchRec input data type from KJT (KeyedJaggedTensor) to TD (TensorDict) * Basically we can support TensorDict in both **eager mode** and **distributed (sharded) mode**: `Input (Union[KJT, TD]) ==> EC ==> Output (KT)` * In eager mode, we directly call `td_to_kjt` in the forward function to convert TD to KJT. * In distributed mode, we do the conversion inside the `ShardedEmbeddingCollection`, specifically in the `input_dist`, where the input sparse features are prepared (permuted) for the `KJTAllToAll` communication. * In the KJT scenario, the input KJT would be permuted (and partially duplicated in some cases), followed by the `KJTAllToAll` communication. While in the TD scenario, the input TD will directly be converted to the permuted KJT ready for the following `KJTAllToAll` communication. NOTE: This diff re-used a number of existing test framework/cases with minimal critical changes in the `EmbeddingCollection` and `shardedEmbeddingCollection`. Please see the follwoing verifications for NJT/TD correctness. # Verification - input with TensorDict * breakpoint at [sharding_single_rank_test](https://fburl.com/code/x74s13fd) * sharded model ``` (Pdb) local_model DistributedModelParallel( (_dmp_wrapped_module): DistributedDataParallel( (module): TestSequenceSparseNN( (dense): TestDenseArch( (linear): Linear(in_features=16, out_features=8, bias=True) ) (sparse): TestSequenceSparseArch( (ec): ShardedEmbeddingCollection( (lookups): GroupedEmbeddingsLookup( (_emb_modules): ModuleList( (0): BatchedDenseEmbedding( (_emb_module): DenseTableBatchedEmbeddingBagsCodegen() ) ) ) (_input_dists): RwSparseFeaturesDist( (_dist): KJTAllToAll() ) (_output_dists): RwSequenceEmbeddingDist( (_dist): SequenceEmbeddingsAllToAll() ) (embeddings): ModuleDict( (table_0): Module() (table_1): Module() (table_2): Module() (table_3): Module() (table_4): Module() (table_5): Module() ) ) ) (over): TestSequenceOverArch( (linear): Linear(in_features=1928, out_features=16, bias=True) ) ) ) ) ``` * TD input ``` (Pdb) local_input ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433, 0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056], [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146, 0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671], [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315, 0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678], [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320, 0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617]], device='cuda:0'), idlist_features=TensorDict( fields={ feature_0: NestedTensor(shape=torch.Size([4, j5]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_1: NestedTensor(shape=torch.Size([4, j6]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_2: NestedTensor(shape=torch.Size([4, j7]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_3: NestedTensor(shape=torch.Size([4, j8]), device=cuda:0, dtype=torch.int64, is_shared=True)}, batch_size=torch.Size([]), device=cuda:0, is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895], device='cuda:0')) ``` * unsharded model ``` (Pdb) global_model TestSequenceSparseNN( (dense): TestDenseArch( (linear): Linear(in_features=16, out_features=8, bias=True) ) (sparse): TestSequenceSparseArch( (ec): EmbeddingCollection( (embeddings): ModuleDict( (table_0): Embedding(11, 16) (table_1): Embedding(22, 16) (table_2): Embedding(33, 16) (table_3): Embedding(44, 16) (table_4): Embedding(11, 16) (table_5): Embedding(22, 16) ) ) ) (over): TestSequenceOverArch( (linear): Linear(in_features=1928, out_features=16, bias=True) ) ) ``` * TD input ``` (Pdb) global_input ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433, 0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056], [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146, 0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671], [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315, 0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678], [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320, 0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617], [0.6807, 0.7970, 0.1164, 0.8487, 0.7730, 0.1654, 0.5599, 0.5923, 0.3909, 0.4720, 0.9423, 0.7868, 0.3710, 0.6075, 0.6849, 0.1366], [0.0246, 0.5967, 0.2838, 0.8114, 0.3761, 0.3963, 0.7792, 0.9119, 0.4026, 0.4769, 0.1477, 0.0923, 0.0723, 0.4416, 0.4560, 0.9548], [0.8666, 0.6254, 0.9162, 0.1954, 0.8466, 0.6498, 0.3412, 0.2098, 0.9786, 0.3349, 0.7625, 0.3615, 0.8880, 0.0751, 0.8417, 0.5380], [0.2857, 0.6871, 0.6694, 0.8206, 0.5142, 0.5641, 0.3780, 0.9441, 0.0964, 0.2007, 0.1148, 0.8054, 0.1520, 0.3742, 0.6364, 0.9797]], device='cuda:0'), idlist_features=TensorDict( fields={ feature_0: NestedTensor(shape=torch.Size([8, j1]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_1: NestedTensor(shape=torch.Size([8, j2]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_2: NestedTensor(shape=torch.Size([8, j3]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_3: NestedTensor(shape=torch.Size([8, j4]), device=cuda:0, dtype=torch.int64, is_shared=True)}, batch_size=torch.Size([]), device=cuda:0, is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895, 0.3132, 0.2133, 0.4997, 0.0055], device='cuda:0')) ``` Differential Revision: D66521351
9db8d75
to
79f5c5e
Compare
This pull request was exported from Phabricator. Differential Revision: D66521351 |
Summary: # Documents * [TorchRec NJT Work Items](https://fburl.com/gdoc/gcqq6luv) * [KJT <> TensorDict](https://docs.google.com/document/d/1zqJL5AESnoKeIt5VZ6K1289fh_1QcSwu76yo0nB4Ecw/edit?tab=t.0#heading=h.bn9zwvg79) {F1949248817} # Context * Continued from previous D66465376, which adds NJT/TD support for EBC, this diff is for EC * As depicted above, we are extending TorchRec input data type from KJT (KeyedJaggedTensor) to TD (TensorDict) * Basically we can support TensorDict in both **eager mode** and **distributed (sharded) mode**: `Input (Union[KJT, TD]) ==> EC ==> Output (KT)` * In eager mode, we directly call `td_to_kjt` in the forward function to convert TD to KJT. * In distributed mode, we do the conversion inside the `ShardedEmbeddingCollection`, specifically in the `input_dist`, where the input sparse features are prepared (permuted) for the `KJTAllToAll` communication. * In the KJT scenario, the input KJT would be permuted (and partially duplicated in some cases), followed by the `KJTAllToAll` communication. While in the TD scenario, the input TD will directly be converted to the permuted KJT ready for the following `KJTAllToAll` communication. NOTE: This diff re-used a number of existing test framework/cases with minimal critical changes in the `EmbeddingCollection` and `shardedEmbeddingCollection`. Please see the follwoing verifications for NJT/TD correctness. # Verification - input with TensorDict * breakpoint at [sharding_single_rank_test](https://fburl.com/code/x74s13fd) * sharded model ``` (Pdb) local_model DistributedModelParallel( (_dmp_wrapped_module): DistributedDataParallel( (module): TestSequenceSparseNN( (dense): TestDenseArch( (linear): Linear(in_features=16, out_features=8, bias=True) ) (sparse): TestSequenceSparseArch( (ec): ShardedEmbeddingCollection( (lookups): GroupedEmbeddingsLookup( (_emb_modules): ModuleList( (0): BatchedDenseEmbedding( (_emb_module): DenseTableBatchedEmbeddingBagsCodegen() ) ) ) (_input_dists): RwSparseFeaturesDist( (_dist): KJTAllToAll() ) (_output_dists): RwSequenceEmbeddingDist( (_dist): SequenceEmbeddingsAllToAll() ) (embeddings): ModuleDict( (table_0): Module() (table_1): Module() (table_2): Module() (table_3): Module() (table_4): Module() (table_5): Module() ) ) ) (over): TestSequenceOverArch( (linear): Linear(in_features=1928, out_features=16, bias=True) ) ) ) ) ``` * TD input ``` (Pdb) local_input ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433, 0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056], [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146, 0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671], [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315, 0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678], [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320, 0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617]], device='cuda:0'), idlist_features=TensorDict( fields={ feature_0: NestedTensor(shape=torch.Size([4, j5]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_1: NestedTensor(shape=torch.Size([4, j6]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_2: NestedTensor(shape=torch.Size([4, j7]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_3: NestedTensor(shape=torch.Size([4, j8]), device=cuda:0, dtype=torch.int64, is_shared=True)}, batch_size=torch.Size([]), device=cuda:0, is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895], device='cuda:0')) ``` * unsharded model ``` (Pdb) global_model TestSequenceSparseNN( (dense): TestDenseArch( (linear): Linear(in_features=16, out_features=8, bias=True) ) (sparse): TestSequenceSparseArch( (ec): EmbeddingCollection( (embeddings): ModuleDict( (table_0): Embedding(11, 16) (table_1): Embedding(22, 16) (table_2): Embedding(33, 16) (table_3): Embedding(44, 16) (table_4): Embedding(11, 16) (table_5): Embedding(22, 16) ) ) ) (over): TestSequenceOverArch( (linear): Linear(in_features=1928, out_features=16, bias=True) ) ) ``` * TD input ``` (Pdb) global_input ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433, 0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056], [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146, 0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671], [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315, 0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678], [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320, 0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617], [0.6807, 0.7970, 0.1164, 0.8487, 0.7730, 0.1654, 0.5599, 0.5923, 0.3909, 0.4720, 0.9423, 0.7868, 0.3710, 0.6075, 0.6849, 0.1366], [0.0246, 0.5967, 0.2838, 0.8114, 0.3761, 0.3963, 0.7792, 0.9119, 0.4026, 0.4769, 0.1477, 0.0923, 0.0723, 0.4416, 0.4560, 0.9548], [0.8666, 0.6254, 0.9162, 0.1954, 0.8466, 0.6498, 0.3412, 0.2098, 0.9786, 0.3349, 0.7625, 0.3615, 0.8880, 0.0751, 0.8417, 0.5380], [0.2857, 0.6871, 0.6694, 0.8206, 0.5142, 0.5641, 0.3780, 0.9441, 0.0964, 0.2007, 0.1148, 0.8054, 0.1520, 0.3742, 0.6364, 0.9797]], device='cuda:0'), idlist_features=TensorDict( fields={ feature_0: NestedTensor(shape=torch.Size([8, j1]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_1: NestedTensor(shape=torch.Size([8, j2]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_2: NestedTensor(shape=torch.Size([8, j3]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_3: NestedTensor(shape=torch.Size([8, j4]), device=cuda:0, dtype=torch.int64, is_shared=True)}, batch_size=torch.Size([]), device=cuda:0, is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895, 0.3132, 0.2133, 0.4997, 0.0055], device='cuda:0')) ``` Differential Revision: D66521351
79f5c5e
to
24b17d8
Compare
This pull request was exported from Phabricator. Differential Revision: D66521351 |
24b17d8
to
1ad89fe
Compare
Summary: # Documents * [TorchRec NJT Work Items](https://fburl.com/gdoc/gcqq6luv) * [KJT <> TensorDict](https://docs.google.com/document/d/1zqJL5AESnoKeIt5VZ6K1289fh_1QcSwu76yo0nB4Ecw/edit?tab=t.0#heading=h.bn9zwvg79) {F1949248817} # Context * Continued from previous D66465376, which adds NJT/TD support for EBC, this diff is for EC * As depicted above, we are extending TorchRec input data type from KJT (KeyedJaggedTensor) to TD (TensorDict) * Basically we can support TensorDict in both **eager mode** and **distributed (sharded) mode**: `Input (Union[KJT, TD]) ==> EC ==> Output (KT)` * In eager mode, we directly call `td_to_kjt` in the forward function to convert TD to KJT. * In distributed mode, we do the conversion inside the `ShardedEmbeddingCollection`, specifically in the `input_dist`, where the input sparse features are prepared (permuted) for the `KJTAllToAll` communication. * In the KJT scenario, the input KJT would be permuted (and partially duplicated in some cases), followed by the `KJTAllToAll` communication. While in the TD scenario, the input TD will directly be converted to the permuted KJT ready for the following `KJTAllToAll` communication. NOTE: This diff re-used a number of existing test framework/cases with minimal critical changes in the `EmbeddingCollection` and `shardedEmbeddingCollection`. Please see the follwoing verifications for NJT/TD correctness. # Verification - input with TensorDict * breakpoint at [sharding_single_rank_test](https://fburl.com/code/x74s13fd) * sharded model ``` (Pdb) local_model DistributedModelParallel( (_dmp_wrapped_module): DistributedDataParallel( (module): TestSequenceSparseNN( (dense): TestDenseArch( (linear): Linear(in_features=16, out_features=8, bias=True) ) (sparse): TestSequenceSparseArch( (ec): ShardedEmbeddingCollection( (lookups): GroupedEmbeddingsLookup( (_emb_modules): ModuleList( (0): BatchedDenseEmbedding( (_emb_module): DenseTableBatchedEmbeddingBagsCodegen() ) ) ) (_input_dists): RwSparseFeaturesDist( (_dist): KJTAllToAll() ) (_output_dists): RwSequenceEmbeddingDist( (_dist): SequenceEmbeddingsAllToAll() ) (embeddings): ModuleDict( (table_0): Module() (table_1): Module() (table_2): Module() (table_3): Module() (table_4): Module() (table_5): Module() ) ) ) (over): TestSequenceOverArch( (linear): Linear(in_features=1928, out_features=16, bias=True) ) ) ) ) ``` * TD input ``` (Pdb) local_input ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433, 0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056], [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146, 0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671], [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315, 0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678], [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320, 0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617]], device='cuda:0'), idlist_features=TensorDict( fields={ feature_0: NestedTensor(shape=torch.Size([4, j5]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_1: NestedTensor(shape=torch.Size([4, j6]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_2: NestedTensor(shape=torch.Size([4, j7]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_3: NestedTensor(shape=torch.Size([4, j8]), device=cuda:0, dtype=torch.int64, is_shared=True)}, batch_size=torch.Size([]), device=cuda:0, is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895], device='cuda:0')) ``` * unsharded model ``` (Pdb) global_model TestSequenceSparseNN( (dense): TestDenseArch( (linear): Linear(in_features=16, out_features=8, bias=True) ) (sparse): TestSequenceSparseArch( (ec): EmbeddingCollection( (embeddings): ModuleDict( (table_0): Embedding(11, 16) (table_1): Embedding(22, 16) (table_2): Embedding(33, 16) (table_3): Embedding(44, 16) (table_4): Embedding(11, 16) (table_5): Embedding(22, 16) ) ) ) (over): TestSequenceOverArch( (linear): Linear(in_features=1928, out_features=16, bias=True) ) ) ``` * TD input ``` (Pdb) global_input ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433, 0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056], [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146, 0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671], [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315, 0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678], [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320, 0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617], [0.6807, 0.7970, 0.1164, 0.8487, 0.7730, 0.1654, 0.5599, 0.5923, 0.3909, 0.4720, 0.9423, 0.7868, 0.3710, 0.6075, 0.6849, 0.1366], [0.0246, 0.5967, 0.2838, 0.8114, 0.3761, 0.3963, 0.7792, 0.9119, 0.4026, 0.4769, 0.1477, 0.0923, 0.0723, 0.4416, 0.4560, 0.9548], [0.8666, 0.6254, 0.9162, 0.1954, 0.8466, 0.6498, 0.3412, 0.2098, 0.9786, 0.3349, 0.7625, 0.3615, 0.8880, 0.0751, 0.8417, 0.5380], [0.2857, 0.6871, 0.6694, 0.8206, 0.5142, 0.5641, 0.3780, 0.9441, 0.0964, 0.2007, 0.1148, 0.8054, 0.1520, 0.3742, 0.6364, 0.9797]], device='cuda:0'), idlist_features=TensorDict( fields={ feature_0: NestedTensor(shape=torch.Size([8, j1]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_1: NestedTensor(shape=torch.Size([8, j2]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_2: NestedTensor(shape=torch.Size([8, j3]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_3: NestedTensor(shape=torch.Size([8, j4]), device=cuda:0, dtype=torch.int64, is_shared=True)}, batch_size=torch.Size([]), device=cuda:0, is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895, 0.3132, 0.2133, 0.4997, 0.0055], device='cuda:0')) ``` Differential Revision: D66521351
This pull request was exported from Phabricator. Differential Revision: D66521351 |
1ad89fe
to
f22736d
Compare
Summary: # Documents * [TorchRec NJT Work Items](https://fburl.com/gdoc/gcqq6luv) * [KJT <> TensorDict](https://docs.google.com/document/d/1zqJL5AESnoKeIt5VZ6K1289fh_1QcSwu76yo0nB4Ecw/edit?tab=t.0#heading=h.bn9zwvg79) {F1949248817} # Context * Continued from previous D66465376, which adds NJT/TD support for EBC, this diff is for EC * As depicted above, we are extending TorchRec input data type from KJT (KeyedJaggedTensor) to TD (TensorDict) * Basically we can support TensorDict in both **eager mode** and **distributed (sharded) mode**: `Input (Union[KJT, TD]) ==> EC ==> Output (KT)` * In eager mode, we directly call `td_to_kjt` in the forward function to convert TD to KJT. * In distributed mode, we do the conversion inside the `ShardedEmbeddingCollection`, specifically in the `input_dist`, where the input sparse features are prepared (permuted) for the `KJTAllToAll` communication. * In the KJT scenario, the input KJT would be permuted (and partially duplicated in some cases), followed by the `KJTAllToAll` communication. While in the TD scenario, the input TD will directly be converted to the permuted KJT ready for the following `KJTAllToAll` communication. NOTE: This diff re-used a number of existing test framework/cases with minimal critical changes in the `EmbeddingCollection` and `shardedEmbeddingCollection`. Please see the follwoing verifications for NJT/TD correctness. # Verification - input with TensorDict * breakpoint at [sharding_single_rank_test](https://fburl.com/code/x74s13fd) * sharded model ``` (Pdb) local_model DistributedModelParallel( (_dmp_wrapped_module): DistributedDataParallel( (module): TestSequenceSparseNN( (dense): TestDenseArch( (linear): Linear(in_features=16, out_features=8, bias=True) ) (sparse): TestSequenceSparseArch( (ec): ShardedEmbeddingCollection( (lookups): GroupedEmbeddingsLookup( (_emb_modules): ModuleList( (0): BatchedDenseEmbedding( (_emb_module): DenseTableBatchedEmbeddingBagsCodegen() ) ) ) (_input_dists): RwSparseFeaturesDist( (_dist): KJTAllToAll() ) (_output_dists): RwSequenceEmbeddingDist( (_dist): SequenceEmbeddingsAllToAll() ) (embeddings): ModuleDict( (table_0): Module() (table_1): Module() (table_2): Module() (table_3): Module() (table_4): Module() (table_5): Module() ) ) ) (over): TestSequenceOverArch( (linear): Linear(in_features=1928, out_features=16, bias=True) ) ) ) ) ``` * TD input ``` (Pdb) local_input ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433, 0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056], [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146, 0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671], [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315, 0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678], [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320, 0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617]], device='cuda:0'), idlist_features=TensorDict( fields={ feature_0: NestedTensor(shape=torch.Size([4, j5]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_1: NestedTensor(shape=torch.Size([4, j6]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_2: NestedTensor(shape=torch.Size([4, j7]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_3: NestedTensor(shape=torch.Size([4, j8]), device=cuda:0, dtype=torch.int64, is_shared=True)}, batch_size=torch.Size([]), device=cuda:0, is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895], device='cuda:0')) ``` * unsharded model ``` (Pdb) global_model TestSequenceSparseNN( (dense): TestDenseArch( (linear): Linear(in_features=16, out_features=8, bias=True) ) (sparse): TestSequenceSparseArch( (ec): EmbeddingCollection( (embeddings): ModuleDict( (table_0): Embedding(11, 16) (table_1): Embedding(22, 16) (table_2): Embedding(33, 16) (table_3): Embedding(44, 16) (table_4): Embedding(11, 16) (table_5): Embedding(22, 16) ) ) ) (over): TestSequenceOverArch( (linear): Linear(in_features=1928, out_features=16, bias=True) ) ) ``` * TD input ``` (Pdb) global_input ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433, 0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056], [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146, 0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671], [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315, 0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678], [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320, 0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617], [0.6807, 0.7970, 0.1164, 0.8487, 0.7730, 0.1654, 0.5599, 0.5923, 0.3909, 0.4720, 0.9423, 0.7868, 0.3710, 0.6075, 0.6849, 0.1366], [0.0246, 0.5967, 0.2838, 0.8114, 0.3761, 0.3963, 0.7792, 0.9119, 0.4026, 0.4769, 0.1477, 0.0923, 0.0723, 0.4416, 0.4560, 0.9548], [0.8666, 0.6254, 0.9162, 0.1954, 0.8466, 0.6498, 0.3412, 0.2098, 0.9786, 0.3349, 0.7625, 0.3615, 0.8880, 0.0751, 0.8417, 0.5380], [0.2857, 0.6871, 0.6694, 0.8206, 0.5142, 0.5641, 0.3780, 0.9441, 0.0964, 0.2007, 0.1148, 0.8054, 0.1520, 0.3742, 0.6364, 0.9797]], device='cuda:0'), idlist_features=TensorDict( fields={ feature_0: NestedTensor(shape=torch.Size([8, j1]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_1: NestedTensor(shape=torch.Size([8, j2]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_2: NestedTensor(shape=torch.Size([8, j3]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_3: NestedTensor(shape=torch.Size([8, j4]), device=cuda:0, dtype=torch.int64, is_shared=True)}, batch_size=torch.Size([]), device=cuda:0, is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895, 0.3132, 0.2133, 0.4997, 0.0055], device='cuda:0')) ``` Differential Revision: D66521351
This pull request was exported from Phabricator. Differential Revision: D66521351 |
f22736d
to
5fb3864
Compare
Summary: # Documents * [TorchRec NJT Work Items](https://fburl.com/gdoc/gcqq6luv) * [KJT <> TensorDict](https://docs.google.com/document/d/1zqJL5AESnoKeIt5VZ6K1289fh_1QcSwu76yo0nB4Ecw/edit?tab=t.0#heading=h.bn9zwvg79) {F1949248817} # Context * Continued from previous D66465376, which adds NJT/TD support for EBC, this diff is for EC * As depicted above, we are extending TorchRec input data type from KJT (KeyedJaggedTensor) to TD (TensorDict) * Basically we can support TensorDict in both **eager mode** and **distributed (sharded) mode**: `Input (Union[KJT, TD]) ==> EC ==> Output (KT)` * In eager mode, we directly call `td_to_kjt` in the forward function to convert TD to KJT. * In distributed mode, we do the conversion inside the `ShardedEmbeddingCollection`, specifically in the `input_dist`, where the input sparse features are prepared (permuted) for the `KJTAllToAll` communication. * In the KJT scenario, the input KJT would be permuted (and partially duplicated in some cases), followed by the `KJTAllToAll` communication. While in the TD scenario, the input TD will directly be converted to the permuted KJT ready for the following `KJTAllToAll` communication. NOTE: This diff re-used a number of existing test framework/cases with minimal critical changes in the `EmbeddingCollection` and `shardedEmbeddingCollection`. Please see the follwoing verifications for NJT/TD correctness. # Verification - input with TensorDict * breakpoint at [sharding_single_rank_test](https://fburl.com/code/x74s13fd) * sharded model ``` (Pdb) local_model DistributedModelParallel( (_dmp_wrapped_module): DistributedDataParallel( (module): TestSequenceSparseNN( (dense): TestDenseArch( (linear): Linear(in_features=16, out_features=8, bias=True) ) (sparse): TestSequenceSparseArch( (ec): ShardedEmbeddingCollection( (lookups): GroupedEmbeddingsLookup( (_emb_modules): ModuleList( (0): BatchedDenseEmbedding( (_emb_module): DenseTableBatchedEmbeddingBagsCodegen() ) ) ) (_input_dists): RwSparseFeaturesDist( (_dist): KJTAllToAll() ) (_output_dists): RwSequenceEmbeddingDist( (_dist): SequenceEmbeddingsAllToAll() ) (embeddings): ModuleDict( (table_0): Module() (table_1): Module() (table_2): Module() (table_3): Module() (table_4): Module() (table_5): Module() ) ) ) (over): TestSequenceOverArch( (linear): Linear(in_features=1928, out_features=16, bias=True) ) ) ) ) ``` * TD input ``` (Pdb) local_input ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433, 0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056], [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146, 0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671], [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315, 0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678], [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320, 0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617]], device='cuda:0'), idlist_features=TensorDict( fields={ feature_0: NestedTensor(shape=torch.Size([4, j5]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_1: NestedTensor(shape=torch.Size([4, j6]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_2: NestedTensor(shape=torch.Size([4, j7]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_3: NestedTensor(shape=torch.Size([4, j8]), device=cuda:0, dtype=torch.int64, is_shared=True)}, batch_size=torch.Size([]), device=cuda:0, is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895], device='cuda:0')) ``` * unsharded model ``` (Pdb) global_model TestSequenceSparseNN( (dense): TestDenseArch( (linear): Linear(in_features=16, out_features=8, bias=True) ) (sparse): TestSequenceSparseArch( (ec): EmbeddingCollection( (embeddings): ModuleDict( (table_0): Embedding(11, 16) (table_1): Embedding(22, 16) (table_2): Embedding(33, 16) (table_3): Embedding(44, 16) (table_4): Embedding(11, 16) (table_5): Embedding(22, 16) ) ) ) (over): TestSequenceOverArch( (linear): Linear(in_features=1928, out_features=16, bias=True) ) ) ``` * TD input ``` (Pdb) global_input ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433, 0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056], [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146, 0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671], [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315, 0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678], [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320, 0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617], [0.6807, 0.7970, 0.1164, 0.8487, 0.7730, 0.1654, 0.5599, 0.5923, 0.3909, 0.4720, 0.9423, 0.7868, 0.3710, 0.6075, 0.6849, 0.1366], [0.0246, 0.5967, 0.2838, 0.8114, 0.3761, 0.3963, 0.7792, 0.9119, 0.4026, 0.4769, 0.1477, 0.0923, 0.0723, 0.4416, 0.4560, 0.9548], [0.8666, 0.6254, 0.9162, 0.1954, 0.8466, 0.6498, 0.3412, 0.2098, 0.9786, 0.3349, 0.7625, 0.3615, 0.8880, 0.0751, 0.8417, 0.5380], [0.2857, 0.6871, 0.6694, 0.8206, 0.5142, 0.5641, 0.3780, 0.9441, 0.0964, 0.2007, 0.1148, 0.8054, 0.1520, 0.3742, 0.6364, 0.9797]], device='cuda:0'), idlist_features=TensorDict( fields={ feature_0: NestedTensor(shape=torch.Size([8, j1]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_1: NestedTensor(shape=torch.Size([8, j2]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_2: NestedTensor(shape=torch.Size([8, j3]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_3: NestedTensor(shape=torch.Size([8, j4]), device=cuda:0, dtype=torch.int64, is_shared=True)}, batch_size=torch.Size([]), device=cuda:0, is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895, 0.3132, 0.2133, 0.4997, 0.0055], device='cuda:0')) ``` Differential Revision: D66521351
This pull request was exported from Phabricator. Differential Revision: D66521351 |
Summary: # Documents * [TorchRec NJT Work Items](https://fburl.com/gdoc/gcqq6luv) * [KJT <> TensorDict](https://docs.google.com/document/d/1zqJL5AESnoKeIt5VZ6K1289fh_1QcSwu76yo0nB4Ecw/edit?tab=t.0#heading=h.bn9zwvg79) {F1949248817} # Context * Continued from previous D66465376, which adds NJT/TD support for EBC, this diff is for EC * As depicted above, we are extending TorchRec input data type from KJT (KeyedJaggedTensor) to TD (TensorDict) * Basically we can support TensorDict in both **eager mode** and **distributed (sharded) mode**: `Input (Union[KJT, TD]) ==> EC ==> Output (KT)` * In eager mode, we directly call `td_to_kjt` in the forward function to convert TD to KJT. * In distributed mode, we do the conversion inside the `ShardedEmbeddingCollection`, specifically in the `input_dist`, where the input sparse features are prepared (permuted) for the `KJTAllToAll` communication. * In the KJT scenario, the input KJT would be permuted (and partially duplicated in some cases), followed by the `KJTAllToAll` communication. While in the TD scenario, the input TD will directly be converted to the permuted KJT ready for the following `KJTAllToAll` communication. NOTE: This diff re-used a number of existing test framework/cases with minimal critical changes in the `EmbeddingCollection` and `shardedEmbeddingCollection`. Please see the follwoing verifications for NJT/TD correctness. # Verification - input with TensorDict * breakpoint at [sharding_single_rank_test](https://fburl.com/code/x74s13fd) * sharded model ``` (Pdb) local_model DistributedModelParallel( (_dmp_wrapped_module): DistributedDataParallel( (module): TestSequenceSparseNN( (dense): TestDenseArch( (linear): Linear(in_features=16, out_features=8, bias=True) ) (sparse): TestSequenceSparseArch( (ec): ShardedEmbeddingCollection( (lookups): GroupedEmbeddingsLookup( (_emb_modules): ModuleList( (0): BatchedDenseEmbedding( (_emb_module): DenseTableBatchedEmbeddingBagsCodegen() ) ) ) (_input_dists): RwSparseFeaturesDist( (_dist): KJTAllToAll() ) (_output_dists): RwSequenceEmbeddingDist( (_dist): SequenceEmbeddingsAllToAll() ) (embeddings): ModuleDict( (table_0): Module() (table_1): Module() (table_2): Module() (table_3): Module() (table_4): Module() (table_5): Module() ) ) ) (over): TestSequenceOverArch( (linear): Linear(in_features=1928, out_features=16, bias=True) ) ) ) ) ``` * TD input ``` (Pdb) local_input ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433, 0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056], [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146, 0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671], [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315, 0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678], [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320, 0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617]], device='cuda:0'), idlist_features=TensorDict( fields={ feature_0: NestedTensor(shape=torch.Size([4, j5]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_1: NestedTensor(shape=torch.Size([4, j6]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_2: NestedTensor(shape=torch.Size([4, j7]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_3: NestedTensor(shape=torch.Size([4, j8]), device=cuda:0, dtype=torch.int64, is_shared=True)}, batch_size=torch.Size([]), device=cuda:0, is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895], device='cuda:0')) ``` * unsharded model ``` (Pdb) global_model TestSequenceSparseNN( (dense): TestDenseArch( (linear): Linear(in_features=16, out_features=8, bias=True) ) (sparse): TestSequenceSparseArch( (ec): EmbeddingCollection( (embeddings): ModuleDict( (table_0): Embedding(11, 16) (table_1): Embedding(22, 16) (table_2): Embedding(33, 16) (table_3): Embedding(44, 16) (table_4): Embedding(11, 16) (table_5): Embedding(22, 16) ) ) ) (over): TestSequenceOverArch( (linear): Linear(in_features=1928, out_features=16, bias=True) ) ) ``` * TD input ``` (Pdb) global_input ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433, 0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056], [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146, 0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671], [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315, 0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678], [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320, 0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617], [0.6807, 0.7970, 0.1164, 0.8487, 0.7730, 0.1654, 0.5599, 0.5923, 0.3909, 0.4720, 0.9423, 0.7868, 0.3710, 0.6075, 0.6849, 0.1366], [0.0246, 0.5967, 0.2838, 0.8114, 0.3761, 0.3963, 0.7792, 0.9119, 0.4026, 0.4769, 0.1477, 0.0923, 0.0723, 0.4416, 0.4560, 0.9548], [0.8666, 0.6254, 0.9162, 0.1954, 0.8466, 0.6498, 0.3412, 0.2098, 0.9786, 0.3349, 0.7625, 0.3615, 0.8880, 0.0751, 0.8417, 0.5380], [0.2857, 0.6871, 0.6694, 0.8206, 0.5142, 0.5641, 0.3780, 0.9441, 0.0964, 0.2007, 0.1148, 0.8054, 0.1520, 0.3742, 0.6364, 0.9797]], device='cuda:0'), idlist_features=TensorDict( fields={ feature_0: NestedTensor(shape=torch.Size([8, j1]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_1: NestedTensor(shape=torch.Size([8, j2]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_2: NestedTensor(shape=torch.Size([8, j3]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_3: NestedTensor(shape=torch.Size([8, j4]), device=cuda:0, dtype=torch.int64, is_shared=True)}, batch_size=torch.Size([]), device=cuda:0, is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895, 0.3132, 0.2133, 0.4997, 0.0055], device='cuda:0')) ``` Differential Revision: D66521351
5fb3864
to
5d16527
Compare
This pull request was exported from Phabricator. Differential Revision: D66521351 |
Summary: # Documents * [TorchRec NJT Work Items](https://fburl.com/gdoc/gcqq6luv) * [KJT <> TensorDict](https://docs.google.com/document/d/1zqJL5AESnoKeIt5VZ6K1289fh_1QcSwu76yo0nB4Ecw/edit?tab=t.0#heading=h.bn9zwvg79) {F1949248817} # Context * Continued from previous D66465376, which adds NJT/TD support for EBC, this diff is for EC * As depicted above, we are extending TorchRec input data type from KJT (KeyedJaggedTensor) to TD (TensorDict) * Basically we can support TensorDict in both **eager mode** and **distributed (sharded) mode**: `Input (Union[KJT, TD]) ==> EC ==> Output (KT)` * In eager mode, we directly call `td_to_kjt` in the forward function to convert TD to KJT. * In distributed mode, we do the conversion inside the `ShardedEmbeddingCollection`, specifically in the `input_dist`, where the input sparse features are prepared (permuted) for the `KJTAllToAll` communication. * In the KJT scenario, the input KJT would be permuted (and partially duplicated in some cases), followed by the `KJTAllToAll` communication. While in the TD scenario, the input TD will directly be converted to the permuted KJT ready for the following `KJTAllToAll` communication. NOTE: This diff re-used a number of existing test framework/cases with minimal critical changes in the `EmbeddingCollection` and `shardedEmbeddingCollection`. Please see the follwoing verifications for NJT/TD correctness. # Verification - input with TensorDict * breakpoint at [sharding_single_rank_test](https://fburl.com/code/x74s13fd) * sharded model ``` (Pdb) local_model DistributedModelParallel( (_dmp_wrapped_module): DistributedDataParallel( (module): TestSequenceSparseNN( (dense): TestDenseArch( (linear): Linear(in_features=16, out_features=8, bias=True) ) (sparse): TestSequenceSparseArch( (ec): ShardedEmbeddingCollection( (lookups): GroupedEmbeddingsLookup( (_emb_modules): ModuleList( (0): BatchedDenseEmbedding( (_emb_module): DenseTableBatchedEmbeddingBagsCodegen() ) ) ) (_input_dists): RwSparseFeaturesDist( (_dist): KJTAllToAll() ) (_output_dists): RwSequenceEmbeddingDist( (_dist): SequenceEmbeddingsAllToAll() ) (embeddings): ModuleDict( (table_0): Module() (table_1): Module() (table_2): Module() (table_3): Module() (table_4): Module() (table_5): Module() ) ) ) (over): TestSequenceOverArch( (linear): Linear(in_features=1928, out_features=16, bias=True) ) ) ) ) ``` * TD input ``` (Pdb) local_input ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433, 0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056], [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146, 0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671], [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315, 0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678], [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320, 0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617]], device='cuda:0'), idlist_features=TensorDict( fields={ feature_0: NestedTensor(shape=torch.Size([4, j5]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_1: NestedTensor(shape=torch.Size([4, j6]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_2: NestedTensor(shape=torch.Size([4, j7]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_3: NestedTensor(shape=torch.Size([4, j8]), device=cuda:0, dtype=torch.int64, is_shared=True)}, batch_size=torch.Size([]), device=cuda:0, is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895], device='cuda:0')) ``` * unsharded model ``` (Pdb) global_model TestSequenceSparseNN( (dense): TestDenseArch( (linear): Linear(in_features=16, out_features=8, bias=True) ) (sparse): TestSequenceSparseArch( (ec): EmbeddingCollection( (embeddings): ModuleDict( (table_0): Embedding(11, 16) (table_1): Embedding(22, 16) (table_2): Embedding(33, 16) (table_3): Embedding(44, 16) (table_4): Embedding(11, 16) (table_5): Embedding(22, 16) ) ) ) (over): TestSequenceOverArch( (linear): Linear(in_features=1928, out_features=16, bias=True) ) ) ``` * TD input ``` (Pdb) global_input ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433, 0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056], [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146, 0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671], [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315, 0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678], [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320, 0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617], [0.6807, 0.7970, 0.1164, 0.8487, 0.7730, 0.1654, 0.5599, 0.5923, 0.3909, 0.4720, 0.9423, 0.7868, 0.3710, 0.6075, 0.6849, 0.1366], [0.0246, 0.5967, 0.2838, 0.8114, 0.3761, 0.3963, 0.7792, 0.9119, 0.4026, 0.4769, 0.1477, 0.0923, 0.0723, 0.4416, 0.4560, 0.9548], [0.8666, 0.6254, 0.9162, 0.1954, 0.8466, 0.6498, 0.3412, 0.2098, 0.9786, 0.3349, 0.7625, 0.3615, 0.8880, 0.0751, 0.8417, 0.5380], [0.2857, 0.6871, 0.6694, 0.8206, 0.5142, 0.5641, 0.3780, 0.9441, 0.0964, 0.2007, 0.1148, 0.8054, 0.1520, 0.3742, 0.6364, 0.9797]], device='cuda:0'), idlist_features=TensorDict( fields={ feature_0: NestedTensor(shape=torch.Size([8, j1]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_1: NestedTensor(shape=torch.Size([8, j2]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_2: NestedTensor(shape=torch.Size([8, j3]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_3: NestedTensor(shape=torch.Size([8, j4]), device=cuda:0, dtype=torch.int64, is_shared=True)}, batch_size=torch.Size([]), device=cuda:0, is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895, 0.3132, 0.2133, 0.4997, 0.0055], device='cuda:0')) ``` Reviewed By: dstaay-fb Differential Revision: D66521351
5d16527
to
329a7b6
Compare
This pull request was exported from Phabricator. Differential Revision: D66521351 |
Summary: # Documents * [TorchRec NJT Work Items](https://fburl.com/gdoc/gcqq6luv) * [KJT <> TensorDict](https://docs.google.com/document/d/1zqJL5AESnoKeIt5VZ6K1289fh_1QcSwu76yo0nB4Ecw/edit?tab=t.0#heading=h.bn9zwvg79) {F1949248817} # Context * As depicted above, we are extending TorchRec input data type from KJT (KeyedJaggedTensor) to TD (TensorDict) * Basically we can support TensorDict in both **eager mode** and **distributed (sharded) mode**: `Input (Union[KJT, TD]) ==> EBC ==> Output (KT)` * In eager mode, we directly call `td_to_kjt` in the forward function to convert TD to KJT. * In distributed mode, we do the conversion inside the `ShardedEmbeddingBagCollection`, specifically in the `input_dist`, where the input sparse features are prepared (permuted) for the `KJTAllToAll` communication. * In the KJT scenario, the input KJT would be permuted (and partially duplicated in some cases), followed by the `KJTAllToAll` communication. While in the TD scenario, the input TD will directly be converted to the permuted KJT ready for the following `KJTAllToAll` communication. * ref: D63436011 # Details * `td_to_kjt` implemented in python, which has cpu perf regression. But it's not on the training critical path so it has a minimal impact on the overall training QPS (see test plan benchmark results) * Currently only support EBC use case WARNING: `TensorDict` does **NOT** support weighted jagged tensor, **Nor** variable batch_size neither. NOTE: All the following comparisons are between the **`KJT.permute`** in the KJT input scenario and the **`TD-KJT conversion`** in the TD input scenario. * Both `KJT.permute` and `TD-KJT conversion` are correctly marked in the `TrainPipelineBase` traces `TD-KJT conversion` has more real executions in CPU, but the heavy-lifting computation is in GPU, which is delayed/blocked by the backward pass of the previous batch. GPU runtime has a small difference ~10%. {F1949366822} * For the `Copy-Batch-To-GPU` part, TD has more fragmented `HtoD` comms while KJT has a single contiguous `HtoD` comm Runtime-wise they are similar ~10% {F1949374305} * In the most commonly used `TrainPipelineSparseDist`, where the `Copy-Batch-To-GPU` and the cpu runtime are not on the critical path, we do observe very similar training QPS in the pipeline benchmark ~1% {F1949390271} ``` TrainPipelineSparseDist | Runtime (P90): 26.737 s | Memory (P90): 34.801 GB (TD) TrainPipelineSparseDist | Runtime (P90): 26.539 s | Memory (P90): 34.765 GB (KJT) ``` * increased data size, GPU runtime is 4x {F1949386106} # Conclusion 1. [Enablement] With this approach (replacing the `KJT permute` with `TD-KJT conversion`), the EBC can now take `TensorDict` as the module input in both single-GPU and multi-GPU (sharded) scenarios, tested with TrainPipelineBase, TrainPipelineSparseDist, TrainPipelineSemiSync, and TrainPipelinePrefetch. 2. [Performance] The TD host-to-device data transfer might not necessarily be a concern/blocker for the most commonly used train pipeline (TrainPipelineSparseDist). 2. [Feature Support] In order to become production-ready, the TensorDict needs to (1) integrate the `KJT.weights` data, and (2) to support the variable batch size, which are almost used in all the production models. 3. [Improvement] There are two major operations we can improve: (1) move TensorDict from host to device, and (2) convert TD to KJT. Currently they are both in the vanilla state. Since we are not sure how the real traces would be like with production models, we can't tell if these improvements are needed/helpful. Reviewed By: dstaay-fb Differential Revision: D65103519
Summary: # Documents * [TorchRec NJT Work Items](https://fburl.com/gdoc/gcqq6luv) * [KJT <> TensorDict](https://docs.google.com/document/d/1zqJL5AESnoKeIt5VZ6K1289fh_1QcSwu76yo0nB4Ecw/edit?tab=t.0#heading=h.bn9zwvg79) {F1949248817} # Context * Continued from previous D66465376, which adds NJT/TD support for EBC, this diff is for EC * As depicted above, we are extending TorchRec input data type from KJT (KeyedJaggedTensor) to TD (TensorDict) * Basically we can support TensorDict in both **eager mode** and **distributed (sharded) mode**: `Input (Union[KJT, TD]) ==> EC ==> Output (KT)` * In eager mode, we directly call `td_to_kjt` in the forward function to convert TD to KJT. * In distributed mode, we do the conversion inside the `ShardedEmbeddingCollection`, specifically in the `input_dist`, where the input sparse features are prepared (permuted) for the `KJTAllToAll` communication. * In the KJT scenario, the input KJT would be permuted (and partially duplicated in some cases), followed by the `KJTAllToAll` communication. While in the TD scenario, the input TD will directly be converted to the permuted KJT ready for the following `KJTAllToAll` communication. NOTE: This diff re-used a number of existing test framework/cases with minimal critical changes in the `EmbeddingCollection` and `shardedEmbeddingCollection`. Please see the follwoing verifications for NJT/TD correctness. # Verification - input with TensorDict * breakpoint at [sharding_single_rank_test](https://fburl.com/code/x74s13fd) * sharded model ``` (Pdb) local_model DistributedModelParallel( (_dmp_wrapped_module): DistributedDataParallel( (module): TestSequenceSparseNN( (dense): TestDenseArch( (linear): Linear(in_features=16, out_features=8, bias=True) ) (sparse): TestSequenceSparseArch( (ec): ShardedEmbeddingCollection( (lookups): GroupedEmbeddingsLookup( (_emb_modules): ModuleList( (0): BatchedDenseEmbedding( (_emb_module): DenseTableBatchedEmbeddingBagsCodegen() ) ) ) (_input_dists): RwSparseFeaturesDist( (_dist): KJTAllToAll() ) (_output_dists): RwSequenceEmbeddingDist( (_dist): SequenceEmbeddingsAllToAll() ) (embeddings): ModuleDict( (table_0): Module() (table_1): Module() (table_2): Module() (table_3): Module() (table_4): Module() (table_5): Module() ) ) ) (over): TestSequenceOverArch( (linear): Linear(in_features=1928, out_features=16, bias=True) ) ) ) ) ``` * TD input ``` (Pdb) local_input ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433, 0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056], [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146, 0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671], [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315, 0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678], [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320, 0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617]], device='cuda:0'), idlist_features=TensorDict( fields={ feature_0: NestedTensor(shape=torch.Size([4, j5]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_1: NestedTensor(shape=torch.Size([4, j6]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_2: NestedTensor(shape=torch.Size([4, j7]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_3: NestedTensor(shape=torch.Size([4, j8]), device=cuda:0, dtype=torch.int64, is_shared=True)}, batch_size=torch.Size([]), device=cuda:0, is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895], device='cuda:0')) ``` * unsharded model ``` (Pdb) global_model TestSequenceSparseNN( (dense): TestDenseArch( (linear): Linear(in_features=16, out_features=8, bias=True) ) (sparse): TestSequenceSparseArch( (ec): EmbeddingCollection( (embeddings): ModuleDict( (table_0): Embedding(11, 16) (table_1): Embedding(22, 16) (table_2): Embedding(33, 16) (table_3): Embedding(44, 16) (table_4): Embedding(11, 16) (table_5): Embedding(22, 16) ) ) ) (over): TestSequenceOverArch( (linear): Linear(in_features=1928, out_features=16, bias=True) ) ) ``` * TD input ``` (Pdb) global_input ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433, 0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056], [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146, 0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671], [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315, 0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678], [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320, 0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617], [0.6807, 0.7970, 0.1164, 0.8487, 0.7730, 0.1654, 0.5599, 0.5923, 0.3909, 0.4720, 0.9423, 0.7868, 0.3710, 0.6075, 0.6849, 0.1366], [0.0246, 0.5967, 0.2838, 0.8114, 0.3761, 0.3963, 0.7792, 0.9119, 0.4026, 0.4769, 0.1477, 0.0923, 0.0723, 0.4416, 0.4560, 0.9548], [0.8666, 0.6254, 0.9162, 0.1954, 0.8466, 0.6498, 0.3412, 0.2098, 0.9786, 0.3349, 0.7625, 0.3615, 0.8880, 0.0751, 0.8417, 0.5380], [0.2857, 0.6871, 0.6694, 0.8206, 0.5142, 0.5641, 0.3780, 0.9441, 0.0964, 0.2007, 0.1148, 0.8054, 0.1520, 0.3742, 0.6364, 0.9797]], device='cuda:0'), idlist_features=TensorDict( fields={ feature_0: NestedTensor(shape=torch.Size([8, j1]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_1: NestedTensor(shape=torch.Size([8, j2]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_2: NestedTensor(shape=torch.Size([8, j3]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_3: NestedTensor(shape=torch.Size([8, j4]), device=cuda:0, dtype=torch.int64, is_shared=True)}, batch_size=torch.Size([]), device=cuda:0, is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895, 0.3132, 0.2133, 0.4997, 0.0055], device='cuda:0')) ``` Reviewed By: dstaay-fb Differential Revision: D66521351
329a7b6
to
8991b5b
Compare
This pull request was exported from Phabricator. Differential Revision: D66521351 |
Summary:
Documents
{F1949248817}
Context
Input (Union[KJT, TD]) ==> EC ==> Output (KT)
td_to_kjt
in the forward function to convert TD to KJT.ShardedEmbeddingCollection
, specifically in theinput_dist
, where the input sparse features are prepared (permuted) for theKJTAllToAll
communication.KJTAllToAll
communication.While in the TD scenario, the input TD will directly be converted to the permuted KJT ready for the following
KJTAllToAll
communication.Verification - input with TensorDict
Differential Revision: D66521351