Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make torch TP composable with torchao #2436

Merged
merged 1 commit into from
Dec 11, 2024

Conversation

kwen2501
Copy link
Contributor

@kwen2501 kwen2501 commented Dec 10, 2024

Motivation

The TP styles in PyTorch trunk performs a scatter from rank 0 to distribute a full tensor into DTensors.
While this is safer for training, it is unnecessary for inference where all ranks' weight already come from the same checkpoint. Therefore, we choose to do a local sharding instead of scatter.

Modifications

This PR customizes the TP styles impls to use _shard_tensor instead of distribute_tensor.

Worked configs:

  • TP + int8wo
  • TP + fp8wo
$ export ENABLE_INTRA_NODE_COMM=1
$ python3 -m sglang.bench_one_batch --model meta-llama/Meta-Llama-3-8B --batch-size 1 --input 128 --output 8 --json-model-override-args '{"architectures": ["TorchNativeLlamaForCausalLM"]}' --enable-torch-compile --torchao-config int8wo --tp 4

Output:

Benchmark ...
Prefill. latency: 0.05945 s, throughput:   2153.25 token/s
Decode.  latency: 0.00321 s, throughput:    311.98 token/s
Decode.  latency: 0.00299 s, throughput:    334.26 token/s
Decode.  latency: 0.00298 s, throughput:    335.68 token/s
Decode.  latency: 0.00297 s, throughput:    337.05 token/s
Decode.  latency: 0.00296 s, throughput:    337.33 token/s
Decode.  median latency: 0.00297 s, median throughput:    337.05 token/s
Total. latency:  0.080 s, throughput:   1690.04 token/s

Checklist

  • Format your code according to the Contributor Guide.
  • Add unit tests as outlined in the Contributor Guide.
  • Update documentation as needed, including docstrings or example tutorials.

Customize parallel style to perform local sharding instead of scatter

worked configs:
TP + int8wo, TP + fp8wo
@merrymercy merrymercy merged commit ece7249 into sgl-project:main Dec 11, 2024
14 of 15 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants