diff --git a/.github/workflows/experiment-runner.yml b/.github/workflows/experiment-runner.yml new file mode 100644 index 00000000000..5ccb8ad28ff --- /dev/null +++ b/.github/workflows/experiment-runner.yml @@ -0,0 +1,30 @@ +name: Experiment Runner + +on: + workflow_dispatch: + inputs: + script: + description: "Experiment Runner Script" + default: "configs/sharegpt_config.yaml" + +concurrency: + group: experiment-runner-${{ github.ref }} + cancel-in-progress: true + +jobs: + experiment-runner-1-gpu: + if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + runs-on: 1-gpu-runner + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: Install dependencies + run: | + bash scripts/ci_install_dependency.sh + + - name: Test experiment runner + timeout-minutes: 120 + run: | + cd test/srt + python3 experiment_runner.py --config ${{ inputs.script }} diff --git a/.github/workflows/pr-test-rust.yml b/.github/workflows/pr-test-rust.yml index 92bd986a077..b9e8c5bcb6b 100644 --- a/.github/workflows/pr-test-rust.yml +++ b/.github/workflows/pr-test-rust.yml @@ -42,7 +42,7 @@ jobs: e2e-rust: if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' - runs-on: 1-gpu-runner + runs-on: 2-gpu-runner steps: - name: Checkout code uses: actions/checkout@v3 @@ -57,7 +57,7 @@ jobs: cd rust pip install setuptools-rust wheel build python3 -m build - pip install dist/*.whl + pip install --force-reinstall dist/*.whl - name: Run e2e test run: | cd rust/py_test diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index 59f0006e128..49c6ec88327 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -105,6 +105,12 @@ jobs: cd test/srt python3 test_update_weights_from_distributed.py + - name: Evaluate MoE EP accuracy (TP=2) + timeout-minutes: 10 + run: | + cd test/srt + python3 test_moe_ep.py + performance-test-1-gpu-part-1: if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' runs-on: 1-gpu-runner diff --git a/.github/workflows/release-pypi-router.yml b/.github/workflows/release-pypi-router.yml index f762aad6ce2..fbbb1f0243b 100644 --- a/.github/workflows/release-pypi-router.yml +++ b/.github/workflows/release-pypi-router.yml @@ -69,9 +69,10 @@ jobs: with: path: sglang-repo - - name: Move rust folder to root and delete sglang-repo + - name: Move rust folder to root, copy the license file, and delete sglang-repo run: | mv sglang-repo/rust/* . + mv sglang-repo/LICENSE . rm -rf sglang-repo ls -alt diff --git a/README.md b/README.md index 45967ee5838..bc8734936cd 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,7 @@ [**Join Bi-Weekly Development Meeting**](https://docs.google.com/document/d/1xEow4eIM152xNcRxqZz9VEcOiTQo8-CEuuQ5qTmkt-E/edit?usp=sharing) | [**Slides**](https://github.com/sgl-project/sgl-learning-materials?tab=readme-ov-file#slides) | ## News +- [2024/12] 🔥 SGLang v0.4: Zero-Overhead Batch Scheduler, Cache-Aware Load Balancer, Faster Structured Outputs ([blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/)). - [2024/10] 🔥 The First SGLang Online Meetup ([slides](https://github.com/sgl-project/sgl-learning-materials?tab=readme-ov-file#the-first-sglang-online-meetup)). - [2024/09] SGLang v0.3 Release: 7x Faster DeepSeek MLA, 1.5x Faster torch.compile, Multi-Image/Video LLaVA-OneVision ([blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/)). - [2024/07] Faster Llama3 Serving with SGLang Runtime (vs. 
TensorRT-LLM, vLLM) ([blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/)). @@ -47,13 +48,13 @@ The core features include: - [Frontend: Structured Generation Language (SGLang)](https://sgl-project.github.io/frontend/frontend.html) ## Benchmark And Performance -Learn more in our release blogs: [v0.2 blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/), [v0.3 blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/) +Learn more in our release blogs: [v0.2 blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/), [v0.3 blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/), [v0.4 blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/) ## Roadmap [Development Roadmap (2024 Q4)](https://github.com/sgl-project/sglang/issues/1487) ## Adoption and Sponsorship -The project is supported by (alphabetically): AMD, Baseten, Etched, Hyperbolic, Jam & Tea Studios, LinkedIn, NVIDIA, RunPod, Stanford, UC Berkeley, xAI and 01.AI. +The project is supported by (alphabetically): AMD, Baseten, Etched, Hyperbolic, Jam & Tea Studios, LinkedIn, Meituan, NVIDIA, RunPod, Stanford, UC Berkeley, xAI and 01.AI. ## Acknowledgment and Citation We learned from the design and reused code from the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), and [LMQL](https://github.com/eth-sri/lmql). diff --git a/benchmark/kernels/fused_moe_triton/README.md b/benchmark/kernels/fused_moe_triton/README.md index ba29ede5099..2a3e37f6874 100644 --- a/benchmark/kernels/fused_moe_triton/README.md +++ b/benchmark/kernels/fused_moe_triton/README.md @@ -10,7 +10,7 @@ Example usage: ```bash # Tune Qwen2-57B with FP8 and TP=4 python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \ - --model Qwen/Qwen2-57B-A14B-Instruct-FP8 \ + --model Qwen/Qwen2-57B-A14B-Instruct \ --tp-size 4 \ --dtype fp8_w8a8 \ --tune @@ -34,7 +34,7 @@ python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_tri # Compare with FP8 mode for Qwen2-57B python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py \ - --model Qwen/Qwen2-57B-A14B-Instruct-FP8 \ + --model Qwen/Qwen2-57B-A14B-Instruct \ --use-fp8 # Compare with custom TP size @@ -43,3 +43,7 @@ python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_tri ``` The benchmark results will be saved as plots and data files in the specified output directory (default: `./configs/benchmark_ops/vllm_sglang_fused_moe/`). + +- `benchmark_torch_compile_fused_moe.py`: A tool for benchmarking the performance of the fused MoE kernel with `torch.compile` and original fused MoE kernel. + +Usage is the same as `benchmark_vllm_vs_sglang_fused_moe_triton.py`. 
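Editor's note on the newly documented `benchmark_torch_compile_fused_moe.py`: as background on what such a comparison involves, the sketch below times an eager callable against its `torch.compile` counterpart with `triton.testing.do_bench`. It is only an assumed, minimal illustration — `toy_moe` is a placeholder MLP, not SGLang's fused MoE kernel — so consult the script itself for the real benchmark.

```python
# Minimal sketch of an eager vs. torch.compile timing comparison.
# `toy_moe` is a placeholder for a fused-MoE-style forward pass; it is NOT
# the kernel benchmarked by benchmark_torch_compile_fused_moe.py.
import torch
import torch.nn.functional as F
import triton


def toy_moe(x: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor) -> torch.Tensor:
    # SiLU MLP as a stand-in for a single expert's computation
    return F.silu(x @ w1) @ w2


def bench_ms(fn) -> float:
    # do_bench handles warmup/repeats and returns a runtime estimate in milliseconds
    return triton.testing.do_bench(fn)


if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(64, 4096, device="cuda", dtype=torch.float16)
    w1 = torch.randn(4096, 14336, device="cuda", dtype=torch.float16)
    w2 = torch.randn(14336, 4096, device="cuda", dtype=torch.float16)

    eager_ms = bench_ms(lambda: toy_moe(x, w1, w2))
    compiled = torch.compile(toy_moe, dynamic=False)
    compiled_ms = bench_ms(lambda: compiled(x, w1, w2))
    print(f"eager: {eager_ms:.3f} ms | torch.compile: {compiled_ms:.3f} ms")
```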
diff --git a/benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py b/benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py index 1f54f9f9f49..1bd6eec1645 100644 --- a/benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py +++ b/benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py @@ -6,6 +6,7 @@ from transformers import AutoConfig from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe as fused_moe_triton +from sglang.srt.model_executor.cuda_graph_runner import set_torch_compile_config def get_model_config(model_name: str, tp_size: int): @@ -64,7 +65,7 @@ def fused_topk_native( return topk_weights, topk_ids -@torch.compile +@torch.compile(dynamic=False) def fused_moe_torch( x, w1, @@ -88,7 +89,8 @@ def fused_moe_torch( w13_weights = w1[topk_ids] w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2) w2_weights = w2[topk_ids] - x1 = F.gelu(torch.einsum("ti,taoi -> tao", x, w1_weights)) + x1 = torch.einsum("ti,taoi -> tao", x, w1_weights) + x1 = F.silu(x1) x3 = torch.einsum("ti, taoi -> tao", x, w3_weights) expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights) return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype)) @@ -174,6 +176,7 @@ def benchmark(batch_size, provider, model_config, use_fp8=False): print(f"benchmark {provider} with batch_size={batch_size}") torch.set_default_device("cuda") torch.cuda.manual_seed_all(0) + set_torch_compile_config() num_tokens = batch_size num_experts = model_config["num_experts"] diff --git a/docker/Dockerfile.dev b/docker/Dockerfile.dev index 8bd05289979..d79dabeb5fc 100644 --- a/docker/Dockerfile.dev +++ b/docker/Dockerfile.dev @@ -50,11 +50,11 @@ RUN curl -L https://github.com/clangd/clangd/releases/download/18.1.3/clangd-lin && rm -rf clangd_18.1.3 clangd.zip # Install CMake -RUN curl -L https://cmake.org/download/#:~:text=cmake%2D3.31.1%2Dlinux%2Dx86_64.tar.gz -o cmake.tar.gz \ - && tar -xzf cmake.tar.gz \ +RUN wget https://github.com/Kitware/CMake/releases/download/v3.31.1/cmake-3.31.1-linux-x86_64.tar.gz \ + && tar -xzf cmake-3.31.1-linux-x86_64.tar.gz \ && cp -r cmake-3.31.1-linux-x86_64/bin/* /usr/local/bin/ \ && cp -r cmake-3.31.1-linux-x86_64/share/* /usr/local/share/ \ - && rm -rf cmake-3.31.1-linux-x86_64 cmake.tar.gz + && rm -rf cmake-3.31.1-linux-x86_64 cmake-3.31.1-linux-x86_64.tar.gz # Add yank script COPY --chown=root:root <<-"EOF" /usr/local/bin/yank diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index 7df5e5fcf23..2c9af6e7b0d 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -1,5 +1,5 @@ # Usage (to build SGLang ROCm docker image): -# docker build --build-arg SGL_BRANCH=v0.3.6.post3 -t v0.3.6.post3-rocm620 -f Dockerfile.rocm . +# docker build --build-arg SGL_BRANCH=v0.4.0.post1 -t v0.4.0.post1-rocm620 -f Dockerfile.rocm . # default base image ARG BASE_IMAGE="rocm/vllm-dev:20241022" @@ -33,6 +33,7 @@ RUN python -m pip cache purge # Performance environment variable. 
ENV HIP_FORCE_DEV_KERNARG=1 +ENV SGLANG_SET_CPU_AFFINITY=1 ENV SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1 ENV NCCL_MIN_NCHANNELS=112 diff --git a/docs/Makefile b/docs/Makefile index 50f77a30c09..13d81f4f847 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -19,7 +19,7 @@ compile: echo "Executing $$nb"; \ jupyter nbconvert --to notebook --execute --inplace "$$nb" \ --ExecutePreprocessor.timeout=600 \ - --ExecutePreprocessor.kernel_name=python3; \ + --ExecutePreprocessor.kernel_name=python3 || exit 1; \ fi; \ done diff --git a/docs/backend/native_api.ipynb b/docs/backend/native_api.ipynb index 7207259ea3c..26758f7f975 100644 --- a/docs/backend/native_api.ipynb +++ b/docs/backend/native_api.ipynb @@ -220,19 +220,19 @@ "metadata": {}, "outputs": [], "source": [ - "# failed update with different parameter size\n", + "# failed update with different parameter size or wrong name\n", "\n", "url = \"http://localhost:30010/update_weights_from_disk\"\n", - "data = {\"model_path\": \"meta-llama/Llama-3.2-3B\"}\n", + "data = {\"model_path\": \"meta-llama/Llama-3.2-1B-wrong\"}\n", "\n", "response = requests.post(url, json=data)\n", "response_json = response.json()\n", "print_highlight(response_json)\n", "assert response_json[\"success\"] is False\n", "assert response_json[\"message\"] == (\n", - " \"Failed to update weights: The size of tensor a (2048) must match \"\n", - " \"the size of tensor b (3072) at non-singleton dimension 1.\\n\"\n", - " \"Rolling back to original weights.\"\n", + " \"Failed to get weights iterator: \"\n", + " \"meta-llama/Llama-3.2-1B-wrong\"\n", + " \" (repository not found).\"\n", ")" ] }, diff --git a/docs/developer/setup_github_runner.md b/docs/developer/setup_github_runner.md index 3c73c0da0af..c82094f6de3 100644 --- a/docs/developer/setup_github_runner.md +++ b/docs/developer/setup_github_runner.md @@ -11,9 +11,9 @@ docker pull nvidia/cuda:12.1.1-devel-ubuntu22.04 # Nvidia docker run --shm-size 128g -it -v /tmp/huggingface:/hf_home --gpus all nvidia/cuda:12.1.1-devel-ubuntu22.04 /bin/bash # AMD -docker run --rm --device=/dev/kfd --device=/dev/dri --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.3.6.post3-rocm620 /bin/bash +docker run --rm --device=/dev/kfd --device=/dev/dri --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.4.0.post1-rocm620 /bin/bash # AMD just the last 2 GPUs -docker run --rm --device=/dev/kfd --device=/dev/dri/renderD176 --device=/dev/dri/renderD184 --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.3.6.post3-rocm620 /bin/bash +docker run --rm --device=/dev/kfd --device=/dev/dri/renderD176 --device=/dev/dri/renderD184 --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.4.0.post1-rocm620 /bin/bash ``` ### Step 2: Configure the runner by `config.sh` diff --git a/docs/index.rst b/docs/index.rst index 873999d2587..8c6c018c4ce 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -39,6 +39,13 @@ The core features include: frontend/choices_methods.md +.. toctree:: + :maxdepth: 1 + :caption: SGLang Router + + router/router.md + + .. toctree:: :maxdepth: 1 :caption: References diff --git a/docs/references/contributor_guide.md b/docs/references/contributor_guide.md index a9b25163d12..550f267ab1a 100644 --- a/docs/references/contributor_guide.md +++ b/docs/references/contributor_guide.md @@ -1,5 +1,9 @@ # Contributor Guide +# Build SGLang + +See [Install SGLang, Method 2: From Source section](../start/install.md). 
+
 ## Format Your Code
 Use these commands to format your code and pass CI linting tests.
diff --git a/docs/references/supported_models.md b/docs/references/supported_models.md
index dbf4f71a021..bf1044f8498 100644
--- a/docs/references/supported_models.md
+++ b/docs/references/supported_models.md
@@ -80,3 +80,30 @@ To port a model from vLLM to SGLang, you can compare these two files [SGLang Lla
 - Remove `Sample`.
 - Change `forward()` functions, and add `forward_batch`.
 - Add `EntryClass` at the end.
+
+### Registering an external model implementation
+
+In addition to the methods described above, you can also register your new model with the `ModelRegistry` before launching the server. This approach is useful if you want to integrate your model without modifying the source code.
+
+Here is how you can do it:
+
+```python
+from sglang.srt.models.registry import ModelRegistry
+from sglang.srt.server import launch_server
+
+# for a single model, you can add it to the registry
+ModelRegistry.models[model_name] = model_class
+
+# for multiple models, you can imitate the import_model_classes() function in sglang/srt/models/registry.py
+from functools import lru_cache
+
+@lru_cache()
+def import_new_model_classes():
+    model_arch_name_to_cls = {}
+    ...
+    return model_arch_name_to_cls
+
+ModelRegistry.models.update(import_new_model_classes())
+
+launch_server(server_args)
+```
diff --git a/docs/router/router.md b/docs/router/router.md
new file mode 100644
index 00000000000..11ea8c59065
--- /dev/null
+++ b/docs/router/router.md
@@ -0,0 +1,110 @@
+# Router for Data Parallelism
+
+Given multiple GPUs running multiple SGLang Runtimes, the SGLang Router distributes requests across the Runtimes with its cache-aware load-balancing algorithm.
+
+The router is an independent Python package, and it can be used as a drop-in replacement for the OpenAI API.
+
+## Installation
+
+```bash
+pip install sglang-router
+```
+
+Detailed usage of the router can be found in [launch_router](https://github.com/sgl-project/sglang/blob/main/rust/py_src/sglang_router/launch_router.py) and [launch_server](https://github.com/sgl-project/sglang/blob/main/rust/py_src/sglang/launch_server.py). You can also run the following commands to see the usage of the router.
+
+```bash
+python -m sglang_router.launch_server --help
+python -m sglang_router.launch_router --help
+```
+
+The router supports two working modes:
+
+1. Co-launch Router and Runtimes
+2. Launch Runtimes and Router separately
+
+## Co-launch Router and Runtimes
+
+This is a drop-in replacement for the existing `--dp-size` argument of SGLang Runtime. Under the hood, it uses multiple processes to launch several workers, waits for them to be ready, and then connects the router to all of them.
+
+```bash
+python -m sglang_router.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --dp-size 1
+```
+
+After the server is ready, you can send requests to the router in the same way as you would send them to a single worker.
+
+```python
+import requests
+
+url = "http://localhost:30000/generate"
+data = {"text": "What is the capital of France?"}
+
+response = requests.post(url, json=data)
+print(response.json())
+```
+
+## Launch Runtimes and Router Separately
+
+This is useful for multi-node DP. First, launch workers on multiple nodes, then launch a router on the main node, and connect the router to all workers.
+
+```bash
+python -m sglang_router.launch_router --worker-urls http://worker_url_1 http://worker_url_2
+```
+
+## Strategies
+
+### Cache-Aware Load-Balancing Router
+
+The native router combines two strategies to optimize both cache utilization and request distribution:
+
+1. Cache-Aware Routing (Approximate Tree)
+2. Load-Balancing Routing (Shortest Queue with Balance Thresholds)
+
+The router dynamically switches between these strategies based on load conditions:
+
+- Uses load balancing when the system is imbalanced
+- Uses cache-aware routing when the system is balanced
+
+A system is considered imbalanced if both of the following conditions are met:
+
+1. (max_load - min_load) > balance_abs_threshold
+2. max_load > balance_rel_threshold * min_load
+
+***Cache-Aware Routing (Approximate Tree)***
+
+When the system is balanced, the router maintains an approximate radix tree for each worker based on request history, eliminating the need for direct cache state queries on each worker. The tree stores raw text characters instead of token IDs to avoid tokenization overhead.
+
+Process:
+
+1. For each request, find the worker with the highest prefix match.
+
+   - If the match rate > cache_threshold, route the request to the worker with the highest match (it likely has the relevant data cached)
+   - If the match rate ≤ cache_threshold, route the request to the worker with the smallest tree size (the most available cache capacity)
+
+2. Background maintenance: Periodically evict the least recently used leaf nodes from the approximate trees to prevent memory overflow.
+
+***Load-Balancing (Shortest Queue)***
+
+For imbalanced systems, this strategy tracks the pending request count per worker and routes new requests to the least busy worker, which keeps the load evenly distributed across workers.
+
+## Configuration Parameters
+
+1. `cache_threshold`: (float, 0.0 to 1.0, default: 0.5)
+   - Minimum prefix match ratio to use highest-match routing.
+   - Below this threshold, the request is routed to the worker with the most available cache space.
+
+2. `balance_abs_threshold`: (integer, default: 32)
+   - Absolute difference threshold for load imbalance detection.
+   - The system is potentially imbalanced if (max_load - min_load) > abs_threshold.
+
+3. `balance_rel_threshold`: (float, default: 1.0001)
+   - Relative ratio threshold for load imbalance detection.
+   - The system is potentially imbalanced if max_load > min_load * rel_threshold.
+   - Used in conjunction with `balance_abs_threshold` to determine the final imbalance state.
+
+4. `eviction_interval`: (integer, default: 60)
+   - Interval in seconds between LRU eviction cycles for the approximate trees.
+   - A background thread periodically evicts the least recently used nodes to bound the tree size.
+
+5. `max_tree_size`: (integer, default: 16777216)
+   - Maximum number of nodes in each approximate tree.
+   - When this limit is exceeded, LRU leaf nodes are evicted during the next eviction cycle.
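Editor's note: to make the routing policy described in `docs/router/router.md` concrete, here is a minimal, self-contained Python sketch of the per-request decision. It is illustrative only — the actual router is implemented in Rust under `rust/`, and the plain set of cached prefixes below merely stands in for the per-worker approximate radix tree.

```python
# Illustrative sketch of the cache-aware routing decision described above.
# NOT the real router implementation; a set of cached prefixes stands in
# for the approximate radix tree, and `pending` stands in for worker load.
from dataclasses import dataclass, field


@dataclass
class Worker:
    url: str
    pending: int = 0                                    # outstanding requests (load)
    cached_prefixes: set = field(default_factory=set)   # stand-in for the radix tree


def prefix_match_rate(text: str, worker: Worker) -> float:
    """Fraction of the request text covered by the longest cached prefix."""
    best = max((len(p) for p in worker.cached_prefixes if text.startswith(p)), default=0)
    return best / max(len(text), 1)


def route(text: str, workers: list, cache_threshold: float = 0.5,
          balance_abs_threshold: int = 32, balance_rel_threshold: float = 1.0001) -> Worker:
    loads = [w.pending for w in workers]
    max_load, min_load = max(loads), min(loads)
    imbalanced = (max_load - min_load) > balance_abs_threshold and \
                 max_load > balance_rel_threshold * min_load

    if imbalanced:
        # Load-balancing mode: shortest queue wins.
        return min(workers, key=lambda w: w.pending)

    # Cache-aware mode: prefer the worker already holding the longest prefix.
    best_worker = max(workers, key=lambda w: prefix_match_rate(text, w))
    if prefix_match_rate(text, best_worker) > cache_threshold:
        return best_worker
    # Weak match everywhere: fall back to the worker with the most spare cache
    # capacity, approximated here by the smallest tree (fewest cached prefixes).
    return min(workers, key=lambda w: len(w.cached_prefixes))
```

In the real router, each worker's tree is built from request history and a background thread evicts LRU leaves every `eviction_interval` seconds, as described in the configuration section of the document above.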
diff --git a/docs/start/install.md b/docs/start/install.md index 3f6a816412a..e9d3abc8e78 100644 --- a/docs/start/install.md +++ b/docs/start/install.md @@ -13,7 +13,7 @@ Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/ ## Method 2: From source ``` # Use the last release branch -git clone -b v0.3.6.post3 https://github.com/sgl-project/sglang.git +git clone -b v0.4.0.post1 https://github.com/sgl-project/sglang.git cd sglang pip install --upgrade pip @@ -26,7 +26,7 @@ Note: To AMD ROCm system with Instinct/MI GPUs, do following instead: ``` # Use the last release branch -git clone -b v0.3.6.post3 https://github.com/sgl-project/sglang.git +git clone -b v0.4.0.post1 https://github.com/sgl-project/sglang.git cd sglang pip install --upgrade pip @@ -51,7 +51,7 @@ docker run --gpus all \ Note: To AMD ROCm system with Instinct/MI GPUs, it is recommended to use `docker/Dockerfile.rocm` to build images, example and usage as below: ```bash -docker build --build-arg SGL_BRANCH=v0.3.6.post3 -t v0.3.6.post3-rocm620 -f Dockerfile.rocm . +docker build --build-arg SGL_BRANCH=v0.4.0.post1 -t v0.4.0.post1-rocm620 -f Dockerfile.rocm . alias drun='docker run -it --rm --network=host --device=/dev/kfd --device=/dev/dri --ipc=host \ --shm-size 16G --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \ @@ -60,11 +60,11 @@ alias drun='docker run -it --rm --network=host --device=/dev/kfd --device=/dev/d drun -p 30000:30000 \ -v ~/.cache/huggingface:/root/.cache/huggingface \ --env "HF_TOKEN=" \ - v0.3.6.post3-rocm620 \ + v0.4.0.post1-rocm620 \ python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --host 0.0.0.0 --port 30000 # Till flashinfer backend available, --attention-backend triton --sampling-backend pytorch are set by default -drun v0.3.6.post3-rocm620 python3 -m sglang.bench_one_batch --batch-size 32 --input 1024 --output 128 --model amd/Meta-Llama-3.1-8B-Instruct-FP8-KV --tp 8 --quantization fp8 +drun v0.4.0.post1-rocm620 python3 -m sglang.bench_one_batch --batch-size 32 --input 1024 --output 128 --model amd/Meta-Llama-3.1-8B-Instruct-FP8-KV --tp 8 --quantization fp8 ``` ## Method 4: Using docker compose diff --git a/examples/frontend_language/usage/rag_using_parea/trace_and_evaluate_rag_using_parea.ipynb b/examples/frontend_language/usage/rag_using_parea/trace_and_evaluate_rag_using_parea.ipynb index a54f92bd6aa..3c1b2a6c400 100644 --- a/examples/frontend_language/usage/rag_using_parea/trace_and_evaluate_rag_using_parea.ipynb +++ b/examples/frontend_language/usage/rag_using_parea/trace_and_evaluate_rag_using_parea.ipynb @@ -177,18 +177,7 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'The World Health Organization formally declared an end to the COVID-19 global health emergency'" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "@trace\n", "def rag_pipeline(question: str) -> str:\n", @@ -307,18 +296,7 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'The World Health Organization formally declared an end to the COVID-19 global health emergency in May 2023.'" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "@trace\n", "def rag_pipeline(question: str) -> str:\n", @@ -355,15 +333,7 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - 
"outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Requirement already satisfied: nest-asyncio in /Users/joschkabraun/miniconda3/envs/sglang/lib/python3.10/site-packages (1.6.0)\r\n" - ] - } - ], + "outputs": [], "source": [ "!pip install nest-asyncio\n", "import nest_asyncio\n", @@ -382,45 +352,7 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Run name set to: sneak-weal, since a name was not provided.\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████| 100/100 [00:27<00:00, 3.63it/s]\n", - "Waiting for evaluations to finish: 100%|██████████| 19/19 [00:10<00:00, 1.89it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Experiment RAG Run sneak-weal stats:\n", - "{\n", - " \"latency\": \"2.69\",\n", - " \"input_tokens\": \"61.26\",\n", - " \"output_tokens\": \"75.88\",\n", - " \"total_tokens\": \"137.14\",\n", - " \"cost\": \"0.00\",\n", - " \"answer_context_faithfulness_statement_level\": \"0.26\",\n", - " \"answer_matches_target_llm_grader\": \"0.22\",\n", - " \"context_query_relevancy\": \"0.27\",\n", - " \"percent_target_supported_by_context\": \"0.40\"\n", - "}\n", - "\n", - "\n", - "View experiment & traces at: https://app.parea.ai/experiments/RAG/30f0244a-d56c-44ff-bdfb-8f47626304b6\n", - "\n" - ] - } - ], + "outputs": [], "source": [ "e = p.experiment(\n", " \"RAG\",\n", diff --git a/examples/runtime/engine/offline_batch_inference.py b/examples/runtime/engine/offline_batch_inference.py index 7404c7e4e7f..724051eab53 100644 --- a/examples/runtime/engine/offline_batch_inference.py +++ b/examples/runtime/engine/offline_batch_inference.py @@ -1,7 +1,13 @@ +import argparse +import dataclasses + import sglang as sgl +from sglang.srt.server_args import ServerArgs -def main(): +def main( + server_args: ServerArgs, +): # Sample prompts. prompts = [ "Hello, my name is", @@ -13,7 +19,7 @@ def main(): sampling_params = {"temperature": 0.8, "top_p": 0.95} # Create an LLM. - llm = sgl.Engine(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct") + llm = sgl.Engine(**dataclasses.asdict(server_args)) outputs = llm.generate(prompts, sampling_params) # Print the outputs. @@ -25,4 +31,8 @@ def main(): # The __main__ condition is necessary here because we use "spawn" to create subprocesses # Spawn starts a fresh program every time, if there is no __main__, it will run into infinite loop to keep spawning processes from sgl.Engine if __name__ == "__main__": - main() + parser = argparse.ArgumentParser() + ServerArgs.add_cli_args(parser) + args = parser.parse_args() + server_args = ServerArgs.from_cli_args(args) + main(server_args) diff --git a/python/pyproject.toml b/python/pyproject.toml index 1ecfc4fa50a..8e935528e21 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "sglang" -version = "0.3.6.post3" +version = "0.4.0.post1" description = "SGLang is yet another fast serving framework for large language models and vision language models." 
readme = "README.md" requires-python = ">=3.8" @@ -13,7 +13,7 @@ classifiers = [ "Programming Language :: Python :: 3", "License :: OSI Approved :: Apache Software License", ] -dependencies = ["requests", "tqdm", "numpy", "IPython"] +dependencies = ["requests", "tqdm", "numpy", "IPython", "setproctitle"] [project.optional-dependencies] runtime_common = ["aiohttp", "decord", "fastapi", @@ -22,8 +22,8 @@ runtime_common = ["aiohttp", "decord", "fastapi", "packaging", "pillow", "prometheus-client>=0.20.0", "psutil", "pydantic", "python-multipart", "pyzmq>=25.1.2", "torchao", "uvicorn", "uvloop", - "xgrammar>=0.1.4"] -srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1", "cuda-python", "flashinfer>=0.1.6"] + "xgrammar>=0.1.6"] +srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1", "cuda-python", "flashinfer>=0.1.6"] # HIP (Heterogeneous-computing Interface for Portability) for AMD # => base docker rocm/vllm-dev:20241022, not from public vllm whl diff --git a/python/sglang/__init__.py b/python/sglang/__init__.py index 40fbf17bc9d..de9134857a6 100644 --- a/python/sglang/__init__.py +++ b/python/sglang/__init__.py @@ -66,7 +66,7 @@ __all__ += ["__version__"] -# SGL Backends +# SGLang Backends from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint from sglang.utils import LazyImport diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index 3eca72de4aa..96e8677bb60 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -163,7 +163,6 @@ async def async_request_openai_completions( "max_tokens": request_func_input.output_len, "stream": not args.disable_stream, "ignore_eos": not args.disable_ignore_eos, - "lora_path": request_func_input.lora_name, **request_func_input.extra_request_body, } headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} @@ -322,6 +321,8 @@ async def async_request_sglang_generate( }, "stream": not args.disable_stream, "lora_path": request_func_input.lora_name, + "return_logprob": args.return_logprob, + "logprob_start_len": -1, **request_func_input.extra_request_body, } headers = {} @@ -912,7 +913,7 @@ async def limited_request_func(request_func_input, pbar): prompt=test_prompt, api_url=api_url, prompt_len=test_prompt_len, - output_len=test_output_len, + output_len=min(test_output_len, 32), lora_name=lora_name, extra_request_body=extra_request_body, ) @@ -1414,6 +1415,11 @@ def set_ulimit(target_soft_limit=65535): action="store_true", help="Disable ignoring EOS.", ) + parser.add_argument( + "--return-logprob", + action="store_true", + help="Return logprob.", + ) parser.add_argument( "--extra-request-body", metavar='{"key1": "value1", "key2": "value2"}', diff --git a/python/sglang/srt/constrained/outlines_backend.py b/python/sglang/srt/constrained/outlines_backend.py index 26c476a0599..4820d473959 100644 --- a/python/sglang/srt/constrained/outlines_backend.py +++ b/python/sglang/srt/constrained/outlines_backend.py @@ -42,6 +42,7 @@ def __init__( self.guide = guide self.jump_forward_map = jump_forward_map self.state = 0 + self.finished = False def accept_token(self, token: int): self.state = self.guide.get_next_state(self.state, token) @@ -84,6 +85,10 @@ def allocate_vocab_mask( ) -> torch.Tensor: return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device) + @staticmethod + def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor: + return vocab_mask + def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None: tokens = 
torch.tensor( self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64 diff --git a/python/sglang/srt/constrained/xgrammar_backend.py b/python/sglang/srt/constrained/xgrammar_backend.py index 1bcc51c6468..ee8e8eb07f4 100644 --- a/python/sglang/srt/constrained/xgrammar_backend.py +++ b/python/sglang/srt/constrained/xgrammar_backend.py @@ -45,6 +45,7 @@ def __init__( self.matcher = matcher self.vocab_size = vocab_size self.ctx = ctx + self.finished = False def accept_token(self, token: int): assert self.matcher.accept_token(token) @@ -85,12 +86,11 @@ def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None: self.matcher.fill_next_token_bitmask(vocab_mask, idx) @staticmethod - def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None: - if vocab_mask.device.type != logits.device.type: - # vocab_mask must then be on the same device as logits - # when applying the token bitmask, so we check and move if needed - vocab_mask = vocab_mask.to(logits.device) + def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor: + return vocab_mask.to(device, non_blocking=True) + @staticmethod + def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None: apply_token_bitmask_inplace(logits, vocab_mask) def copy(self): diff --git a/python/sglang/srt/layers/attention/__init__.py b/python/sglang/srt/layers/attention/__init__.py index f5d573f5f7b..a70e9537bfe 100644 --- a/python/sglang/srt/layers/attention/__init__.py +++ b/python/sglang/srt/layers/attention/__init__.py @@ -52,12 +52,13 @@ def forward( v: torch.Tensor, layer: RadixAttention, forward_batch: ForwardBatch, + save_kv_cache: bool = True, ): """Run forward on an attention layer.""" if forward_batch.forward_mode.is_decode(): - return self.forward_decode(q, k, v, layer, forward_batch) + return self.forward_decode(q, k, v, layer, forward_batch, save_kv_cache) else: - return self.forward_extend(q, k, v, layer, forward_batch) + return self.forward_extend(q, k, v, layer, forward_batch, save_kv_cache) def forward_decode( self, @@ -66,6 +67,7 @@ def forward_decode( v: torch.Tensor, layer: RadixAttention, forward_batch: ForwardBatch, + save_kv_cache: bool = True, ): """Run a forward for decode.""" raise NotImplementedError() @@ -77,6 +79,7 @@ def forward_extend( v: torch.Tensor, layer: RadixAttention, forward_batch: ForwardBatch, + save_kv_cache: bool = True, ): """Run a forward for extend.""" raise NotImplementedError() diff --git a/python/sglang/srt/layers/attention/double_sparsity_backend.py b/python/sglang/srt/layers/attention/double_sparsity_backend.py index 73c32df8f6e..856aa984c38 100644 --- a/python/sglang/srt/layers/attention/double_sparsity_backend.py +++ b/python/sglang/srt/layers/attention/double_sparsity_backend.py @@ -165,7 +165,13 @@ def get_cuda_graph_seq_len_fill_value(self): return 1 def forward_extend( - self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch + self, + q, + k, + v, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, ): # TODO: reuse the buffer across layers if layer.qk_head_dim != layer.v_head_dim: @@ -181,9 +187,10 @@ def forward_extend( .expand(k.shape[0], -1, -1), ) - forward_batch.token_to_kv_pool.set_kv_buffer( - layer, forward_batch.out_cache_loc, k, v, k_label - ) + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, forward_batch.out_cache_loc, k, v, k_label + ) ( start_loc, @@ -212,7 +219,13 @@ def forward_extend( return o def forward_decode( - self, q, k, v, layer: RadixAttention, 
forward_batch: ForwardBatch + self, + q, + k, + v, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, ): # During torch.compile, there is a bug in rotary_emb that causes the # output value to have a 3D tensor shape. This reshapes the output correctly. @@ -242,9 +255,10 @@ def forward_decode( .expand(k.shape[0], -1, -1), ) - forward_batch.token_to_kv_pool.set_kv_buffer( - layer, forward_batch.out_cache_loc, k, v, k_label - ) + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, forward_batch.out_cache_loc, k, v, k_label + ) # NOTE(Andy) shouldn't be used when max_len_in_batch < heavy_token_num # and set a minimum value for sparse_decode diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index 258659efa2a..536358fbc94 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -221,7 +221,13 @@ def get_cuda_graph_seq_len_fill_value(self): return 0 def forward_extend( - self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch + self, + q, + k, + v, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, ): prefill_wrapper_paged = self.prefill_wrappers_paged[ self._get_wrapper_idx(layer) @@ -237,7 +243,8 @@ def forward_extend( if not use_ragged: if k is not None: assert v is not None - forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) o = prefill_wrapper_paged.forward( q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), @@ -270,12 +277,19 @@ def forward_extend( o, _ = merge_state(o1, s1, o2, s2) - forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) return o.view(-1, layer.tp_q_head_num * layer.head_dim) def forward_decode( - self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch + self, + q, + k, + v, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, ): decode_wrapper = self.forward_metadata[0][self._get_wrapper_idx(layer)] cache_loc = ( @@ -286,7 +300,8 @@ def forward_decode( if k is not None: assert v is not None - forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) o = decode_wrapper.forward( q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), @@ -663,6 +678,7 @@ def call_begin_forward( self.num_qo_heads, self.num_kv_heads, self.head_dim, + q_data_type=self.q_data_type, ) # cached part @@ -676,6 +692,7 @@ def call_begin_forward( self.num_kv_heads, self.head_dim, 1, + q_data_type=self.q_data_type, ) diff --git a/python/sglang/srt/layers/attention/torch_native_backend.py b/python/sglang/srt/layers/attention/torch_native_backend.py index 4ccad2216f7..5e7e0e66e22 100644 --- a/python/sglang/srt/layers/attention/torch_native_backend.py +++ b/python/sglang/srt/layers/attention/torch_native_backend.py @@ -216,16 +216,23 @@ def _run_sdpa_forward_decode( return output def forward_extend( - self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch + self, + q, + k, + v, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, ): if layer.qk_head_dim != layer.v_head_dim: o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim)) else: o = 
torch.empty_like(q) - forward_batch.token_to_kv_pool.set_kv_buffer( - layer, forward_batch.out_cache_loc, k, v - ) + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, forward_batch.out_cache_loc, k, v + ) use_gqa = layer.tp_q_head_num != layer.tp_k_head_num @@ -249,7 +256,13 @@ def forward_extend( return o def forward_decode( - self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch + self, + q, + k, + v, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, ): # During torch.compile, there is a bug in rotary_emb that causes the # output value to have a 3D tensor shape. This reshapes the output correctly. @@ -260,9 +273,10 @@ def forward_decode( else: o = torch.empty_like(q) - forward_batch.token_to_kv_pool.set_kv_buffer( - layer, forward_batch.out_cache_loc, k, v - ) + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, forward_batch.out_cache_loc, k, v + ) use_gqa = layer.tp_q_head_num != layer.tp_k_head_num diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index b9597b3ea41..1a539ebd75c 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -35,10 +35,8 @@ def __init__(self, model_runner: ModelRunner): model_runner.model_config.num_attention_heads // model_runner.tp_size ) - if global_server_args_dict.get("triton_attention_reduce_in_fp32", False): - self.reduce_dtype = torch.float32 - else: - self.reduce_dtype = torch.float16 + self.num_kv_splits = model_runner.server_args.triton_attention_num_kv_splits + self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1] self.forward_metadata = None @@ -50,23 +48,23 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): """Init auxiliary variables for triton attention backend.""" if forward_batch.forward_mode.is_decode(): - start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32) - start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0) - - total_num_tokens = forward_batch.seq_lens_sum attn_logits = torch.empty( - (self.num_head, total_num_tokens), - dtype=self.reduce_dtype, + ( + forward_batch.batch_size, + self.num_head, + self.num_kv_splits, + self.v_head_dim + 1, + ), + dtype=torch.float32, device=self.device, ) - max_seq_len = torch.max(forward_batch.seq_lens).item() max_extend_len = None else: - start_loc = attn_logits = max_seq_len = None + attn_logits = None max_extend_len = torch.max(forward_batch.extend_seq_lens).item() - self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len + self.forward_metadata = attn_logits, max_extend_len def init_cuda_graph_state(self, max_bs: int): self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len @@ -75,11 +73,8 @@ def init_cuda_graph_state(self, max_bs: int): (max_bs,), dtype=torch.int32, device=self.device ) self.cuda_graph_attn_logits = torch.empty( - ( - self.num_head, - self.cuda_graph_max_total_num_tokens, - ), - dtype=self.reduce_dtype, + (max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1), + dtype=torch.float32, device="cuda", ) @@ -92,9 +87,7 @@ def init_forward_metadata_capture_cuda_graph( ): # NOTE: encoder_lens expected to be zeros or None self.forward_metadata = ( - self.cuda_graph_start_loc, self.cuda_graph_attn_logits, - self.cuda_graph_max_seq_len, None, ) @@ -114,7 +107,13 @@ def get_cuda_graph_seq_len_fill_value(self): return 1 def forward_extend( - self, q, 
k, v, layer: RadixAttention, forward_batch: ForwardBatch + self, + q, + k, + v, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, ): # TODO: reuse the buffer across layers if layer.qk_head_dim != layer.v_head_dim: @@ -122,11 +121,12 @@ def forward_extend( else: o = torch.empty_like(q) - forward_batch.token_to_kv_pool.set_kv_buffer( - layer, forward_batch.out_cache_loc, k, v - ) + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, forward_batch.out_cache_loc, k, v + ) - start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata + _, max_extend_len = self.forward_metadata self.extend_attention_fwd( q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), k.contiguous(), @@ -146,7 +146,13 @@ def forward_extend( return o def forward_decode( - self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch + self, + q, + k, + v, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, ): # During torch.compile, there is a bug in rotary_emb that causes the # output value to have a 3D tensor shape. This reshapes the output correctly. @@ -158,11 +164,12 @@ def forward_decode( else: o = torch.empty_like(q) - start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata + attn_logits, _ = self.forward_metadata - forward_batch.token_to_kv_pool.set_kv_buffer( - layer, forward_batch.out_cache_loc, k, v - ) + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, forward_batch.out_cache_loc, k, v + ) self.decode_attention_fwd( q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), @@ -171,10 +178,9 @@ def forward_decode( o.view(-1, layer.tp_q_head_num, layer.v_head_dim), forward_batch.req_to_token_pool.req_to_token, forward_batch.req_pool_indices, - start_loc, forward_batch.seq_lens, attn_logits, - max_seq_len, + self.num_kv_splits, layer.scaling, layer.logit_cap, ) diff --git a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py index 56d38693f4f..d2e856ca605 100644 --- a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py @@ -17,8 +17,11 @@ """ # Adapted from -# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py -# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py +# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py +# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py + +import logging + import triton import triton.language as tl @@ -26,6 +29,13 @@ is_hip_ = is_hip() +logger = logging.getLogger(__name__) + +# TODO: Remove this when triton>=3.2.0. This issue will not affect performance and accuracy. +logger.warn( + "The following error message 'operation scheduled before its operands' can be ignored." 
+) + @triton.jit def tanh(x): @@ -37,10 +47,10 @@ def tanh(x): def _fwd_kernel_stage1( Q, K_Buffer, + V_Buffer, sm_scale, Req_to_tokens, B_req_idx, - B_Start_Loc, B_Seqlen, Att_Out, stride_req_to_tokens_b, @@ -48,152 +58,136 @@ def _fwd_kernel_stage1( stride_qh, stride_buf_kbs, stride_buf_kh, - att_stride_h, + stride_buf_vbs, + stride_buf_vh, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, kv_group_num: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + BLOCK_DV: tl.constexpr, BLOCK_N: tl.constexpr, - SPLIT_K: tl.constexpr, + NUM_KV_SPLITS: tl.constexpr, logit_cap: tl.constexpr, Lk: tl.constexpr, + Lv: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) - split_k_id = tl.program_id(2) + split_kv_id = tl.program_id(2) - reduce_dtype = Att_Out.dtype.element_ty cur_kv_head = cur_head // kv_group_num offs_d = tl.arange(0, BLOCK_DMODEL) + offs_dv = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lk + mask_dv = offs_dv < Lv cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) cur_batch_req_idx = tl.load(B_req_idx + cur_batch) off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d - q = tl.load(Q + off_q).to(reduce_dtype) - - kv_len_per_split = tl.cdiv(cur_batch_seq_len, SPLIT_K) - split_k_start = kv_len_per_split * split_k_id - split_k_end = tl.minimum(split_k_start + kv_len_per_split, cur_batch_seq_len) - - for start_n in range(split_k_start, split_k_end, BLOCK_N): - offs_n = start_n + tl.arange(0, BLOCK_N) - k_loc = tl.load( - Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n, - mask=offs_n < split_k_end, - other=0, - ) - offs_buf_k = ( - k_loc[:, None] * stride_buf_kbs - + cur_kv_head * stride_buf_kh - + offs_d[None, :] - ) - k = tl.load( - K_Buffer + offs_buf_k, - mask=(offs_n[:, None] < split_k_end) & (offs_d[None, :] < Lk), - other=0.0, - ).to(reduce_dtype) - att_value = tl.sum(q[None, :] * k, 1) - att_value *= sm_scale - - if logit_cap > 0: - att_value = logit_cap * tanh(att_value / logit_cap) + q = tl.load(Q + off_q, mask=mask_d, other=0.0) - off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n) - tl.store(Att_Out + off_o, att_value, mask=offs_n < split_k_end) + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + e_max = -float("inf") + e_sum = 0.0 + acc = tl.zeros([BLOCK_DV], dtype=tl.float32) + + if split_kv_end > split_kv_start: + for start_n in range(split_kv_start, split_kv_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + kv_loc = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n, + mask=offs_n < split_kv_end, + other=0, + ) + offs_buf_k = ( + kv_loc[:, None] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + + offs_d[None, :] + ) + k = tl.load( + K_Buffer + offs_buf_k, + mask=(offs_n[:, None] < split_kv_end) & (mask_d[None, :]), + other=0.0, + ) + qk = tl.sum(q[None, :] * k, 1) + qk *= sm_scale -@triton.jit -def _fwd_kernel_stage2( - logits, - V_Buffer, - Out, - Req_to_tokens, - B_req_idx, - B_Start_Loc, - B_Seqlen, - stride_logic_h, - stride_buf_vbs, - stride_buf_vh, - stride_obs, - stride_oh, - stride_req_to_token_b, - kv_group_num: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - Lv: tl.constexpr, -): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) + if logit_cap > 0: + qk = logit_cap * tanh(qk / logit_cap) - cur_kv_head = cur_head // 
kv_group_num + qk = tl.where(offs_n < split_kv_end, qk, float("-inf")) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_start_loc = tl.load(B_Start_Loc + cur_batch) - cur_batch_req_idx = tl.load(B_req_idx + cur_batch) + offs_buf_v = ( + kv_loc[:, None] * stride_buf_vbs + + cur_kv_head * stride_buf_vh + + offs_dv[None, :] + ) + v = tl.load( + V_Buffer + offs_buf_v, + mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]), + other=0.0, + ) - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) + n_e_max = tl.maximum(tl.max(qk, 0), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max) + acc *= re_scale + acc += tl.sum(p[:, None] * v, 0) - offs_buf_v = cur_kv_head * stride_buf_vh + offs_d[None, :] - v_ptrs = V_Buffer + offs_buf_v + e_sum = e_sum * re_scale + tl.sum(p, 0) + e_max = n_e_max - e_max = float("-inf") - e_sum = 0.0 - acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) - - for start_n in range(0, cur_batch_seq_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - v_index = tl.load( - Req_to_tokens - + cur_batch_req_idx * stride_req_to_token_b - + (start_n + offs_n), - mask=(start_n + offs_n) < cur_batch_seq_len, - other=0, + offs_mid_o = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + + offs_dv ) - qk = tl.load( - logits - + cur_head * stride_logic_h - + (cur_batch_start_loc + start_n + offs_n), - mask=start_n + offs_n < cur_batch_seq_len, - other=float("-inf"), + tl.store( + Att_Out + offs_mid_o, + acc / e_sum, + mask=(mask_dv), ) - n_e_max = tl.maximum(tl.max(qk, 0), e_max) - old_scale = tl.exp(e_max - n_e_max) - p = tl.exp(qk - n_e_max) - e_sum = e_sum * old_scale + tl.sum(p, 0) - v = tl.load( - v_ptrs + v_index[:, None] * stride_buf_vbs, mask=(offs_d[None, :] < Lv) + offs_mid_o_1 = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + + Lv ) - acc = acc * old_scale + tl.sum(p[:, None] * v, 0) - e_max = n_e_max - acc = acc / e_sum - off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d - out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=(offs_d < Lv)) + tl.store( + Att_Out + offs_mid_o_1, + e_max + tl.log(e_sum), + ) def _decode_att_m_fwd( q, k_buffer, + v_buffer, att_out, Req_to_tokens, B_req_idx, - B_Start_Loc, B_Seqlen, - max_len_in_batch, + num_kv_splits, sm_scale, logit_cap, ): - BLOCK = 32 - SPLIT_K = 8 + BLOCK = 64 + NUM_KV_SPLITS = num_kv_splits Lk = k_buffer.shape[-1] + Lv = v_buffer.shape[-1] batch, head_num = B_req_idx.shape[0], q.shape[1] - grid = (batch, head_num, SPLIT_K) + grid = (batch, head_num, NUM_KV_SPLITS) kv_group_num = q.shape[1] // k_buffer.shape[1] if kv_group_num == 1: @@ -202,14 +196,15 @@ def _decode_att_m_fwd( num_warps = 2 BLOCK_DMODEL = triton.next_power_of_2(Lk) + BLOCK_DV = triton.next_power_of_2(Lv) _fwd_kernel_stage1[grid]( q, k_buffer, + v_buffer, sm_scale, Req_to_tokens, B_req_idx, - B_Start_Loc, B_Seqlen, att_out, Req_to_tokens.stride(0), @@ -217,56 +212,20 @@ def _decode_att_m_fwd( q.stride(1), k_buffer.stride(0), k_buffer.stride(1), + v_buffer.stride(0), + v_buffer.stride(1), att_out.stride(0), + att_out.stride(1), + att_out.stride(2), kv_group_num=kv_group_num, BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_DV=BLOCK_DV, BLOCK_N=BLOCK, - SPLIT_K=SPLIT_K, + NUM_KV_SPLITS=NUM_KV_SPLITS, logit_cap=logit_cap, num_warps=num_warps, - num_stages=1, + num_stages=2, Lk=Lk, - ) - - -def _decode_softmax_reducev_fwd( - logits, - v_buffer, - o, - req_to_tokens, - b_req_idx, - b_start_loc, - b_seq_len, -): - BLOCK = 
64 - batch, head = b_seq_len.shape[0], logits.shape[0] - grid = (batch, head, 1) - kv_group_num = logits.shape[0] // v_buffer.shape[1] - - num_warps = 1 - - Lv = v_buffer.shape[-1] - BLOCK_DMODEL = triton.next_power_of_2(Lv) - - _fwd_kernel_stage2[grid]( - logits, - v_buffer, - o, - req_to_tokens, - b_req_idx, - b_start_loc, - b_seq_len, - logits.stride(0), - v_buffer.stride(0), - v_buffer.stride(1), - o.stride(0), - o.stride(1), - req_to_tokens.stride(0), - kv_group_num=kv_group_num, - BLOCK_DMODEL=BLOCK_DMODEL, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=3, Lv=Lv, ) @@ -275,10 +234,10 @@ def _decode_softmax_reducev_fwd( def _fwd_grouped_kernel_stage1( Q, K_Buffer, + V_Buffer, sm_scale, Req_to_tokens, B_req_idx, - B_Start_Loc, B_Seqlen, Att_Out, stride_req_to_tokens_b, @@ -286,23 +245,27 @@ def _fwd_grouped_kernel_stage1( stride_qh, stride_buf_kbs, stride_buf_kh, - att_stride_h, + stride_buf_vbs, + stride_buf_vh, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, kv_group_num: tl.constexpr, q_head_num: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_DPE: tl.constexpr, + BLOCK_DV: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_H: tl.constexpr, - SPLIT_K: tl.constexpr, + NUM_KV_SPLITS: tl.constexpr, logit_cap: tl.constexpr, Lk: tl.constexpr, + Lv: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head_id = tl.program_id(1) cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H) - split_k_id = tl.program_id(2) - - reduce_dtype = Att_Out.dtype.element_ty + split_kv_id = tl.program_id(2) if BLOCK_H < kv_group_num: VALID_BLOCK_H: tl.constexpr = BLOCK_H @@ -313,171 +276,135 @@ def _fwd_grouped_kernel_stage1( mask_h = mask_h & (cur_head < q_head_num) offs_d = tl.arange(0, BLOCK_DMODEL) + offs_dv = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lk + mask_dv = offs_dv < Lv cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) cur_batch_req_idx = tl.load(B_req_idx + cur_batch) offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :] - q = tl.load( - Q + offs_q, mask=(mask_h[:, None]) & (offs_d[None, :] < Lk), other=0.0 - ).to(reduce_dtype) + q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0) if BLOCK_DPE > 0: offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE) + mask_dpe = offs_dpe < Lk off_qpe = ( cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :] ) - qpe = tl.load(Q + off_qpe, mask=mask_h[:, None], other=0.0).to(reduce_dtype) - - kv_len_per_split = tl.cdiv(cur_batch_seq_len, SPLIT_K) - split_k_start = kv_len_per_split * split_k_id - split_k_end = tl.minimum(split_k_start + kv_len_per_split, cur_batch_seq_len) - - for start_n in range(split_k_start, split_k_end, BLOCK_N): - offs_n = start_n + tl.arange(0, BLOCK_N) - k_loc = tl.load( - Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n, - mask=offs_n < split_k_end, - other=0, + qpe = tl.load( + Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0 ) - offs_buf_k = ( - k_loc[None, :] * stride_buf_kbs - + cur_kv_head * stride_buf_kh - + offs_d[:, None] - ) - k = tl.load( - K_Buffer + offs_buf_k, - mask=(offs_n[None, :] < split_k_end) & (offs_d[:, None] < Lk), - other=0.0, - ).to(reduce_dtype) - qk = tl.dot(q, k) - if BLOCK_DPE > 0: - offs_buf_kpe = ( - k_loc[None, :] * stride_buf_kbs + + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + 
e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf") + e_sum = tl.zeros([BLOCK_H], dtype=tl.float32) + acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32) + + if split_kv_end > split_kv_start: + for start_n in range(split_kv_start, split_kv_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + kv_loc = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n, + mask=offs_n < split_kv_end, + other=0, + ) + offs_buf_k = ( + kv_loc[None, :] * stride_buf_kbs + cur_kv_head * stride_buf_kh - + offs_dpe[:, None] + + offs_d[:, None] ) - kpe = tl.load( - K_Buffer + offs_buf_kpe, - mask=offs_n[None, :] < split_k_end, + k = tl.load( + K_Buffer + offs_buf_k, + mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]), other=0.0, - ).to(reduce_dtype) - qk += tl.dot(qpe, kpe) - qk *= sm_scale - - if logit_cap > 0: - qk = logit_cap * tanh(qk / logit_cap) - - offs_o = cur_head[:, None] * att_stride_h + ( - cur_batch_in_all_start_index + offs_n[None, :] - ) - - tl.store( - Att_Out + offs_o, - qk, - mask=mask_h[:, None] & (offs_n[None, :] < split_k_end), - ) - - -@triton.jit -def _fwd_grouped_kernel_stage2( - logits, - V_Buffer, - Out, - Req_to_tokens, - B_req_idx, - B_Start_Loc, - B_Seqlen, - stride_logic_h, - stride_buf_vbs, - stride_buf_vh, - stride_obs, - stride_oh, - stride_req_to_token_b, - kv_group_num: tl.constexpr, - q_head_num: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_H: tl.constexpr, - Lv: tl.constexpr, -): - cur_batch = tl.program_id(0) - cur_head_id = tl.program_id(1) - cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H) - - if BLOCK_H < kv_group_num: - VALID_BLOCK_H: tl.constexpr = BLOCK_H - else: - VALID_BLOCK_H: tl.constexpr = kv_group_num - cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H) - mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H - mask_h = mask_h & (cur_head < q_head_num) + ) + qk = tl.dot(q, k.to(q.dtype)) + if BLOCK_DPE > 0: + offs_buf_kpe = ( + kv_loc[None, :] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + + offs_dpe[:, None] + ) + kpe = tl.load( + K_Buffer + offs_buf_kpe, + mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]), + other=0.0, + ) + qk += tl.dot(qpe, kpe.to(qpe.dtype)) + qk *= sm_scale + + if logit_cap > 0: + qk = logit_cap * tanh(qk / logit_cap) + + qk = tl.where( + mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf") + ) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_start_loc = tl.load(B_Start_Loc + cur_batch) - cur_batch_req_idx = tl.load(B_req_idx + cur_batch) + offs_buf_v = ( + kv_loc[:, None] * stride_buf_vbs + + cur_kv_head * stride_buf_vh + + offs_dv[None, :] + ) + v = tl.load( + V_Buffer + offs_buf_v, + mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]), + other=0.0, + ) - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) + n_e_max = tl.maximum(tl.max(qk, 1), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max[:, None]) + acc *= re_scale[:, None] + acc += tl.dot(p.to(v.dtype), v) - offs_buf_v = cur_kv_head * stride_buf_vh + offs_d[None, :] - v_ptrs = V_Buffer + offs_buf_v + e_sum = e_sum * re_scale + tl.sum(p, 1) + e_max = n_e_max - e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf") - e_sum = tl.zeros([BLOCK_H], dtype=tl.float32) - acc = tl.zeros([BLOCK_H, BLOCK_DMODEL], dtype=tl.float32) - - for start_n in range(0, cur_batch_seq_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - v_index = tl.load( - Req_to_tokens - + cur_batch_req_idx * 
stride_req_to_token_b - + (start_n + offs_n), - mask=(start_n + offs_n) < cur_batch_seq_len, - other=0, + offs_mid_o = ( + cur_batch * stride_mid_ob + + cur_head[:, None] * stride_mid_oh + + split_kv_id * stride_mid_os + + offs_dv[None, :] ) - offs_qk = cur_head[:, None] * stride_logic_h + ( - cur_batch_start_loc + start_n + offs_n[None, :] + tl.store( + Att_Out + offs_mid_o, + acc / e_sum[:, None], + mask=(mask_h[:, None]) & (mask_dv[None, :]), ) - qk = tl.load( - logits + offs_qk, - mask=mask_h[:, None] & (start_n + offs_n[None, :] < cur_batch_seq_len), - other=float("-inf"), + offs_mid_o_1 = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + + Lv ) - n_e_max = tl.maximum(tl.max(qk, 1), e_max) - old_scale = tl.exp(e_max - n_e_max) - p = tl.exp(qk - n_e_max[:, None]) - e_sum = e_sum * old_scale + tl.sum(p, 1) - v = tl.load( - v_ptrs + v_index[:, None] * stride_buf_vbs, mask=(offs_d[None, :] < Lv) + tl.store( + Att_Out + offs_mid_o_1, + e_max + tl.log(e_sum), + mask=mask_h, ) - p = p.to(v.dtype) - acc = acc * old_scale[:, None] + tl.dot(p, v) - e_max = n_e_max - - acc = acc / e_sum[:, None] - off_o = cur_batch * stride_obs + cur_head[:, None] * stride_oh + offs_d[None, :] - out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=(mask_h[:, None]) & (offs_d[None, :] < Lv)) def _decode_grouped_att_m_fwd( q, k_buffer, + v_buffer, att_out, Req_to_tokens, B_req_idx, - B_Start_Loc, B_Seqlen, - max_len_in_batch, + num_kv_splits, sm_scale, logit_cap, ): - BLOCK = 64 + BLOCK = 32 Lk = k_buffer.shape[-1] + Lv = v_buffer.shape[-1] if Lk == 576: BLOCK_DMODEL = 512 @@ -488,20 +415,19 @@ def _decode_grouped_att_m_fwd( else: BLOCK_DMODEL = triton.next_power_of_2(Lk) BLOCK_DPE = 0 + BLOCK_DV = triton.next_power_of_2(Lv) batch, head_num = B_req_idx.shape[0], q.shape[1] kv_group_num = q.shape[1] // k_buffer.shape[1] - BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num))) - SPLIT_K = 8 + BLOCK_H = 16 + NUM_KV_SPLITS = num_kv_splits grid = ( batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), - SPLIT_K, + NUM_KV_SPLITS, ) - num_warps = 4 - extra_kargs = {} if is_hip_: # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html @@ -511,10 +437,10 @@ def _decode_grouped_att_m_fwd( _fwd_grouped_kernel_stage1[grid]( q, k_buffer, + v_buffer, sm_scale, Req_to_tokens, B_req_idx, - B_Start_Loc, B_Seqlen, att_out, Req_to_tokens.stride(0), @@ -522,41 +448,88 @@ def _decode_grouped_att_m_fwd( q.stride(1), k_buffer.stride(0), k_buffer.stride(1), + v_buffer.stride(0), + v_buffer.stride(1), att_out.stride(0), + att_out.stride(1), + att_out.stride(2), kv_group_num=kv_group_num, q_head_num=head_num, BLOCK_DMODEL=BLOCK_DMODEL, BLOCK_DPE=BLOCK_DPE, + BLOCK_DV=BLOCK_DV, BLOCK_N=BLOCK, BLOCK_H=BLOCK_H, - SPLIT_K=SPLIT_K, + NUM_KV_SPLITS=NUM_KV_SPLITS, logit_cap=logit_cap, - num_warps=num_warps, - num_stages=1, + num_warps=4, + num_stages=2, Lk=Lk, + Lv=Lv, **extra_kargs, ) -def _decode_grouped_softmax_reducev_fwd( - logits, - v_buffer, - o, - req_to_tokens, - b_req_idx, - b_start_loc, - b_seq_len, +@triton.jit +def _fwd_kernel_stage2( + Mid_O, + O, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_obs, + stride_oh, + NUM_KV_SPLITS: tl.constexpr, + BLOCK_DV: tl.constexpr, + Lv: tl.constexpr, ): - BLOCK = 128 - batch, head_num = b_seq_len.shape[0], logits.shape[0] - kv_group_num = logits.shape[0] // v_buffer.shape[1] - BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num))) - grid = (batch, 
triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1) + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) - num_warps = 8 + offs_d = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lv + + e_sum = 0.0 + e_max = -float("inf") + acc = tl.zeros([BLOCK_DV], dtype=tl.float32) + + offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d + offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv + + for split_kv_id in range(0, NUM_KV_SPLITS): + tv = tl.load( + Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0 + ) + tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os) + n_e_max = tl.maximum(tlogic, e_max) + + old_scale = tl.exp(e_max - n_e_max) + acc *= old_scale + exp_logic = tl.exp(tlogic - n_e_max) + acc += exp_logic * tv + e_sum = e_sum * old_scale + exp_logic + e_max = n_e_max + + tl.store( + O + cur_batch * stride_obs + cur_head * stride_oh + offs_d, + acc / e_sum, + mask=mask_d, + ) + + +def _decode_softmax_reducev_fwd( + logits, + q, + o, + v_buffer, + num_kv_splits, +): + batch, head_num = q.shape[0], q.shape[1] Lv = v_buffer.shape[-1] - BLOCK_DMODEL = triton.next_power_of_2(Lv) + BLOCK_DV = triton.next_power_of_2(Lv) + + NUM_KV_SPLITS = num_kv_splits extra_kargs = {} if is_hip_: @@ -564,28 +537,20 @@ def _decode_grouped_softmax_reducev_fwd( # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2} - _fwd_grouped_kernel_stage2[grid]( + grid = (batch, head_num) + _fwd_kernel_stage2[grid]( logits, - v_buffer, o, - req_to_tokens, - b_req_idx, - b_start_loc, - b_seq_len, logits.stride(0), - v_buffer.stride(0), - v_buffer.stride(1), + logits.stride(1), + logits.stride(2), o.stride(0), o.stride(1), - req_to_tokens.stride(0), - kv_group_num=kv_group_num, - q_head_num=head_num, - BLOCK_DMODEL=BLOCK_DMODEL, - BLOCK_N=BLOCK, - BLOCK_H=BLOCK_H, + NUM_KV_SPLITS=NUM_KV_SPLITS, + BLOCK_DV=BLOCK_DV, Lv=Lv, - num_warps=num_warps, - num_stages=1, + num_warps=4, + num_stages=2, **extra_kargs, ) @@ -597,34 +562,25 @@ def decode_attention_fwd_normal( o, req_to_token, b_req_idx, - b_start_loc, b_seq_len, attn_logits, - max_len_in_batch, + num_kv_splits, sm_scale, logit_cap=0.0, ): _decode_att_m_fwd( q, k_buffer, + v_buffer, attn_logits, req_to_token, b_req_idx, - b_start_loc, b_seq_len, - max_len_in_batch, + num_kv_splits, sm_scale, logit_cap, ) - _decode_softmax_reducev_fwd( - attn_logits, - v_buffer, - o, - req_to_token, - b_req_idx, - b_start_loc, - b_seq_len, - ) + _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, num_kv_splits) def decode_attention_fwd_grouped( @@ -634,34 +590,25 @@ def decode_attention_fwd_grouped( o, req_to_token, b_req_idx, - b_start_loc, b_seq_len, attn_logits, - max_len_in_batch, + num_kv_splits, sm_scale, logit_cap=0.0, ): _decode_grouped_att_m_fwd( q, k_buffer, + v_buffer, attn_logits, req_to_token, b_req_idx, - b_start_loc, b_seq_len, - max_len_in_batch, + num_kv_splits, sm_scale, logit_cap, ) - _decode_grouped_softmax_reducev_fwd( - attn_logits, - v_buffer, - o, - req_to_token, - b_req_idx, - b_start_loc, - b_seq_len, - ) + _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, num_kv_splits) def decode_attention_fwd( @@ -671,13 +618,13 @@ def decode_attention_fwd( o, req_to_token, b_req_idx, - b_start_loc, b_seq_len, attn_logits, - max_len_in_batch, + num_kv_splits, sm_scale, logit_cap=0.0, ): + assert num_kv_splits == attn_logits.shape[2] kv_group_num = q.shape[1] // v_buffer.shape[1] if kv_group_num == 
1: @@ -689,10 +636,9 @@ def decode_attention_fwd( o, req_to_token, b_req_idx, - b_start_loc, b_seq_len, attn_logits, - max_len_in_batch, + num_kv_splits, sm_scale, logit_cap, ) @@ -705,10 +651,9 @@ def decode_attention_fwd( o, req_to_token, b_req_idx, - b_start_loc, b_seq_len, attn_logits, - max_len_in_batch, + num_kv_splits, sm_scale, logit_cap, ) diff --git a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py index 56cc439c31e..b7afd62e723 100644 --- a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py @@ -284,6 +284,9 @@ def extend_attention_fwd( elif Lq == 288: BLOCK_DMODEL = 256 BLOCK_DPE = 32 + elif Lq == 192: + BLOCK_DMODEL = 128 + BLOCK_DPE = 64 else: BLOCK_DMODEL = triton.next_power_of_2(Lq) BLOCK_DPE = 0 diff --git a/python/sglang/srt/layers/ep_moe/__init__.py b/python/sglang/srt/layers/ep_moe/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/python/sglang/srt/layers/ep_moe/kernels.py b/python/sglang/srt/layers/ep_moe/kernels.py new file mode 100644 index 00000000000..e0486891aa7 --- /dev/null +++ b/python/sglang/srt/layers/ep_moe/kernels.py @@ -0,0 +1,349 @@ +import logging +from typing import Optional + +import torch +import triton +import triton.language as tl + +logger = logging.getLogger(__name__) + + +@triton.jit +def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks): + expert = tl.program_id(0) + low = 0 + high = num_toks - 1 + target_location = -1 + while low <= high: + mid = (low + high) // 2 + + if tl.load(reorder_topk_ids + mid) > expert: + high = mid - 1 + else: + low = mid + 1 + target_location = mid + tl.store(seg_indptr + expert + 1, target_location + 1) + + +@triton.jit +def compute_src2dst_triton_kernel( + reorder_ids, src2dst, num_toks, BLOCK_SIZE: tl.constexpr +): + pid = tl.program_id(axis=0) + dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = dst_id < num_toks + src_id = tl.load(reorder_ids + dst_id, mask=mask) + tl.store(src2dst + src_id, dst_id, mask=mask) + + +def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int): + reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True) + seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64) + src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32) + + compute_seg_indptr_triton_kernel[(num_experts,)]( + reorder_topk_ids, seg_indptr, topk_ids.numel() + ) + + BLOCK_SIZE = 512 + grid = (triton.cdiv(topk_ids.numel(), BLOCK_SIZE),) + compute_src2dst_triton_kernel[grid]( + reorder_ids, src2dst, topk_ids.numel(), BLOCK_SIZE + ) + return reorder_topk_ids, src2dst, seg_indptr + + +@triton.jit +def pre_reorder_triton_kernel( + input_ptr, + gateup_input_ptr, + src2dst_ptr, + topk_ids_ptr, + a1_scales_ptr, + start_expert_id, + end_expert_id, + topk, + hidden_size, + BLOCK_SIZE: tl.constexpr, +): + OutDtype = gateup_input_ptr.dtype.element_ty + + src_idx = tl.program_id(0) + src2dst_ptr = src2dst_ptr + src_idx * topk + topk_ids_ptr = topk_ids_ptr + src_idx * topk + + src_ptr = input_ptr + src_idx * hidden_size + for idx in range(topk): + expert_id = tl.load(topk_ids_ptr + idx) + if expert_id >= start_expert_id and expert_id <= end_expert_id: + if a1_scales_ptr is not None: + scale = 1.0 / tl.load(a1_scales_ptr + expert_id - start_expert_id) + else: + scale = 1.0 + + dst_idx = tl.load(src2dst_ptr + idx) + dst_ptr = 
gateup_input_ptr + dst_idx * hidden_size + for start_offset in tl.range(0, hidden_size, BLOCK_SIZE): + offset = start_offset + tl.arange(0, BLOCK_SIZE) + mask = offset < hidden_size + in_data = tl.load(src_ptr + offset, mask=mask).to(tl.float32) + out_data = (in_data * scale).to(OutDtype) + tl.store(dst_ptr + offset, out_data, mask=mask) + + +@triton.jit +def silu_and_mul_triton_kernel( + gateup_output, + down_input, + hidden_size, + reorder_topk_ids, + scales, + start_expert_id, + end_expert_id, + BLOCK_SIZE: tl.constexpr, +): + InDtype = gateup_output.dtype.element_ty + OutDtype = down_input.dtype.element_ty + + half_hidden_size = hidden_size // 2 + + pid = tl.program_id(0) + expert_id = tl.load(reorder_topk_ids + pid) + if expert_id >= start_expert_id and expert_id <= end_expert_id: + gateup_output_ptr = gateup_output + pid * hidden_size + gate_output_ptr = gateup_output_ptr + up_output_ptr = gateup_output_ptr + half_hidden_size + down_input_ptr = down_input + pid * half_hidden_size + + if scales is not None: + scale = tl.load(scales + expert_id - start_expert_id) + scale = (1 / scale).to(InDtype) + else: + scale = 1 + + for start_offset in tl.range(0, half_hidden_size, BLOCK_SIZE): + offset = start_offset + tl.arange(0, BLOCK_SIZE) + mask = offset < half_hidden_size + + gate_output = tl.load(gate_output_ptr + offset, mask=mask).to(tl.float32) + up_output = tl.load(up_output_ptr + offset, mask=mask) + + # silu & mul & quantize + gate_output = gate_output * tl.sigmoid(gate_output) + gate_output = gate_output.to(InDtype) + + silu_mul_output = gate_output * up_output * scale + silu_mul_output = silu_mul_output.to(OutDtype) + tl.store(down_input_ptr + offset, silu_mul_output, mask=mask) + + +@triton.jit +def post_reorder_triton_kernel( + down_output_ptr, + output_ptr, + src2dst_ptr, + topk_ids_ptr, + topk_weights_ptr, + start_expert_id, + end_expert_id, + topk, + hidden_size, + BLOCK_SIZE: tl.constexpr, +): + InDtype = down_output_ptr.dtype.element_ty + + src_idx = tl.program_id(0) + src2dst_ptr = src2dst_ptr + src_idx * topk + topk_ids_ptr = topk_ids_ptr + src_idx * topk + topk_weights_ptr = topk_weights_ptr + src_idx * topk + + computed = False + store_ptr = output_ptr + src_idx * hidden_size + for start_offset in tl.range(0, hidden_size, BLOCK_SIZE): + offset = start_offset + tl.arange(0, BLOCK_SIZE) + mask = offset < hidden_size + + sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype) + for idx in range(topk): + expert_id = tl.load(topk_ids_ptr + idx) + if expert_id >= start_expert_id and expert_id <= end_expert_id: + computed = True + dst_idx = tl.load(src2dst_ptr + idx) + weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype) + load_ptr = down_output_ptr + dst_idx * hidden_size + in_data = tl.load(load_ptr + offset, mask=mask) + sum_vec += in_data * weigh_scale + tl.store(store_ptr + offset, sum_vec, mask=mask) + + if computed == False: + for start_offset in tl.range(0, hidden_size, BLOCK_SIZE): + offset = start_offset + tl.arange(0, BLOCK_SIZE) + mask = offset < hidden_size + tl.store( + store_ptr + offset, tl.zeros([BLOCK_SIZE], dtype=InDtype), mask=mask + ) + + +@triton.jit +def compute_m_range( + pid, + batch_size, + seg_indptr, + weight_indices, + m_num_tiles_indptr, + BLOCK_SIZE_M: tl.constexpr, +): + idx = 0 + for bs in range(batch_size): + tiles = tl.load(m_num_tiles_indptr + bs) + if pid >= tiles: + idx = bs + + idx_start = tl.load(m_num_tiles_indptr + idx) + + m_range_start = tl.load(seg_indptr + idx) + (pid - idx_start) * BLOCK_SIZE_M + m_range_end = min(tl.load(seg_indptr + 
idx + 1), m_range_start + BLOCK_SIZE_M) + expert_id = tl.load(weight_indices + idx) + return m_range_start, m_range_end, expert_id + + +@triton.jit +def grouped_gemm_triton_kernel( + a, + b, + c, + batch_size, + N, + K, + seg_indptr, + weight_indices, + m_num_tiles_indptr, + use_fp8_w8a8, + scale_a, + scale_b, + a_stride_0: tl.constexpr, + b_stride_0: tl.constexpr, + b_stride_1: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + c_dtype = c.dtype.element_ty + + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + total_m_block = tl.load(m_num_tiles_indptr + batch_size) + if pid_m >= total_m_block: + return + + m_range_start, m_range_end, expert_id = compute_m_range( + pid_m, batch_size, seg_indptr, weight_indices, m_num_tiles_indptr, BLOCK_SIZE_M + ) + if m_range_end - m_range_start == 0: + return + + n_range_start = pid_n * BLOCK_SIZE_N + n_range_end = min(n_range_start + BLOCK_SIZE_N, N) + + offs_am = tl.arange(0, BLOCK_SIZE_M) + offs_bn = tl.arange(0, BLOCK_SIZE_N) + + offs_am = tl.where(offs_am < m_range_end - m_range_start, offs_am, 0) + offs_bn = tl.where(offs_bn < n_range_end - n_range_start, offs_bn, 0) + offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M) + offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + a_ptr = a + (m_range_start + offs_am[:, None]) * a_stride_0 + offs_k[None, :] + b_ptr = b + ( + (expert_id * b_stride_0) + + (n_range_start + offs_bn[:, None]) * b_stride_1 + + offs_k[None, :] + ) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a_tile = tl.load( + a_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0 + ) + b_tile = tl.load( + b_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0 + ) + accumulator = tl.dot(a_tile, b_tile.T, accumulator) + a_ptr += BLOCK_SIZE_K + b_ptr += BLOCK_SIZE_K + + if use_fp8_w8a8: + scale_a_value = tl.load(scale_a + expert_id) + scale_b_value = tl.load(scale_b + expert_id) + accumulator *= scale_a_value * scale_b_value + c_tile = accumulator.to(c_dtype) + + offs_cm = m_range_start + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_range_start + tl.arange(0, BLOCK_SIZE_N) + c_ptr = c + offs_cm[:, None] * N + offs_cn[None, :] + c_mask = (offs_cm[:, None] < m_range_end) & (offs_cn[None, :] < n_range_end) + tl.store(c_ptr, c_tile, mask=c_mask) + + +@triton.jit +def compute_m_num_tiles_indptr( + m_num_tiles_indptr, seg_indptr, batch_size: tl.constexpr, BLOCK_SIZE_M: tl.constexpr +): + for bs in range(batch_size): + m = tl.load(seg_indptr + bs + 1) - tl.load(seg_indptr + bs) + cur_num_tiles = tl.cdiv(m, BLOCK_SIZE_M) + pre_num_tiles = tl.load(m_num_tiles_indptr + bs) + tl.store(m_num_tiles_indptr + bs + 1, pre_num_tiles + cur_num_tiles) + + +def grouped_gemm_triton( + a: torch.Tensor, + b: torch.Tensor, + c: torch.Tensor, + batch_size: int, + weight_column_major: bool, + seg_indptr: Optional[torch.Tensor] = None, + weight_indices: Optional[torch.Tensor] = None, + use_fp8_w8a8: bool = False, + scale_a: torch.Tensor = None, + scale_b: torch.Tensor = None, +): + assert weight_column_major == True # TODO: more + if use_fp8_w8a8: + assert scale_a is not None and scale_b is not None + + config = { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + } + + m_num_tiles_indptr = torch.zeros(batch_size + 1, device=a.device, dtype=torch.int64) + compute_m_num_tiles_indptr[(1,)]( + 
m_num_tiles_indptr, seg_indptr, batch_size, config["BLOCK_SIZE_M"] + ) + + grid = lambda META: ( + triton.cdiv(a.size(0), META["BLOCK_SIZE_M"]) + batch_size, + triton.cdiv(b.size(1), META["BLOCK_SIZE_N"]), + ) + + grouped_gemm_triton_kernel[grid]( + a, + b, + c, + batch_size, + b.size(1), + b.size(2), + seg_indptr, + weight_indices, + m_num_tiles_indptr, + use_fp8_w8a8, + scale_a, + scale_b, + a.stride(0), + b.stride(0), + b.stride(1), + **config, + ) + return c diff --git a/python/sglang/srt/layers/ep_moe/layer.py b/python/sglang/srt/layers/ep_moe/layer.py new file mode 100644 index 00000000000..eca119845a7 --- /dev/null +++ b/python/sglang/srt/layers/ep_moe/layer.py @@ -0,0 +1,661 @@ +import logging +from typing import Callable, List, Optional, Tuple + +import torch +from torch.nn import Module +from vllm import _custom_ops as ops +from vllm.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod + +from sglang.srt.layers.custom_op_util import register_custom_op +from sglang.srt.layers.ep_moe.kernels import ( + grouped_gemm_triton, + post_reorder_triton_kernel, + pre_reorder_triton_kernel, + run_moe_ep_preproess, + silu_and_mul_triton_kernel, +) +from sglang.srt.layers.fused_moe_triton.fused_moe import fused_topk, grouped_topk +from sglang.srt.layers.fused_moe_triton.layer import FusedMoEMethodBase +from sglang.srt.layers.quantization.base_config import ( + QuantizationConfig, + QuantizeMethodBase, +) +from sglang.srt.utils import is_hip, set_weight_attrs + +logger = logging.getLogger(__name__) + + +class GroupedGemmRunner(torch.nn.Module): + flashinfer_gemm_warpper = None + + def __init__(self, device, use_flashinfer: bool = False): + super().__init__() + self.device = device + self.use_flashinfer = use_flashinfer + if self.use_flashinfer and GroupedGemmRunner.flashinfer_gemm_warpper is None: + GroupedGemmRunner._init_flashinfer_wrapper(device) + + @classmethod + def _init_flashinfer_wrapper(cls, device): + from flashinfer import SegmentGEMMWrapper + + workspace_buffer = torch.empty( + 128 * 1024 * 1024, dtype=torch.int8, device=device + ) + cls.flashinfer_gemm_warpper = SegmentGEMMWrapper(workspace_buffer) + + # c = a * b + def forward( + self, + a: torch.Tensor, + b: torch.Tensor, + c: torch.Tensor, + batch_size: int, + weight_column_major: bool, + seg_indptr: Optional[torch.Tensor] = None, + weight_indices: Optional[torch.Tensor] = None, + use_fp8_w8a8: bool = False, + scale_a: torch.Tensor = None, + scale_b: torch.Tensor = None, + ): + if self.use_flashinfer: + # TODO: flashinfer + assert False + assert GroupedGemmRunner.flashinfer_gemm_warpper is not None + c = GroupedGemmRunner.flashinfer_gemm_warpper.run( + x=a, + weights=b, + batch_size=batch_size, + weight_column_major=weight_column_major, + seg_indptr=seg_indptr, + weight_indices=weight_indices, + ) + else: + assert weight_column_major == True + c = grouped_gemm_triton( + a, + b, + c, + batch_size, + weight_column_major, + seg_indptr, + weight_indices, + use_fp8_w8a8, + scale_a, + scale_b, + ) + return c + + +class EPMoE(torch.nn.Module): + """ + MoE Expert Parallel Impl + + + """ + + def __init__( + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: 
Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + prefix: str = "", + ): + super().__init__() + + if params_dtype is None: + params_dtype = torch.get_default_dtype() + + self.tp_size = ( + tp_size if tp_size is not None else get_tensor_model_parallel_world_size() + ) + self.tp_rank = get_tensor_model_parallel_rank() + + self.num_experts = num_experts + assert self.num_experts % self.tp_size == 0 + self.num_experts_per_partition = self.num_experts // self.tp_size + self.start_expert_id = self.tp_rank * self.num_experts_per_partition + self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1 + + self.top_k = top_k + self.intermediate_size = intermediate_size + self.renormalize = renormalize + self.use_grouped_topk = use_grouped_topk + if self.use_grouped_topk: + assert num_expert_group is not None and topk_group is not None + self.num_expert_group = num_expert_group + self.topk_group = topk_group + + if quant_config is None: + self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod() + self.use_fp8_w8a8 = False + self.activation_scheme = None + else: + self.quant_method: Optional[QuantizeMethodBase] = Fp8EPMoEMethod( + quant_config + ) + self.use_fp8_w8a8 = True + self.fp8_dtype = torch.float8_e4m3fn + self.activation_scheme = quant_config.activation_scheme + + self.quant_method.create_weights( + layer=self, + num_experts_per_partition=self.num_experts_per_partition, + hidden_size=hidden_size, + intermediate_size=self.intermediate_size, + params_dtype=params_dtype, + weight_loader=self.weight_loader, + ) + + self.grouped_gemm_runner = None + + def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): + assert self.quant_method is not None + + if self.grouped_gemm_runner is None: + self.grouped_gemm_runner = GroupedGemmRunner( + hidden_states.device, use_flashinfer=False # TODO: use flashinfer + ) + + topk_weights, topk_ids = self.select_experts( + hidden_states, + router_logits, + self.top_k, + self.renormalize, + self.topk_group, + self.num_expert_group, + ) + + reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess( + topk_ids, self.num_experts + ) + + gateup_input = torch.empty( + (int(hidden_states.shape[0] * self.top_k), hidden_states.shape[1]), + device=hidden_states.device, + dtype=self.fp8_dtype if self.use_fp8_w8a8 else hidden_states.dtype, + ) + if self.activation_scheme == "dynamic": + max_value = ( + torch.max(hidden_states) + .repeat(self.num_experts_per_partition) + .to(torch.float32) + ) + self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max + + # PreReorder + pre_reorder_triton_kernel[(hidden_states.shape[0],)]( + hidden_states, + gateup_input, + src2dst, + topk_ids, + self.w13_input_scale, + self.start_expert_id, + self.end_expert_id, + self.top_k, + hidden_states.shape[1], + BLOCK_SIZE=512, + ) + + seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2] + weight_indices_cur_rank = torch.arange( + 0, + self.num_experts_per_partition, + device=hidden_states.device, + dtype=torch.int64, + ) + # GroupGemm-0 + gateup_output = torch.empty( + gateup_input.shape[0], + self.w13_weight.shape[1], + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + gateup_output = self.grouped_gemm_runner( + a=gateup_input, + b=self.w13_weight, + c=gateup_output, + batch_size=self.num_experts_per_partition, + weight_column_major=True, + seg_indptr=seg_indptr_cur_rank, + weight_indices=weight_indices_cur_rank, + 
use_fp8_w8a8=self.use_fp8_w8a8, + scale_a=self.w13_input_scale, + scale_b=self.w13_weight_scale, + ) + + # Act + down_input = torch.empty( + gateup_output.shape[0], + gateup_output.shape[1] // 2, + device=gateup_output.device, + dtype=self.fp8_dtype if self.use_fp8_w8a8 else hidden_states.dtype, + ) + if self.w2_input_scale is None: + self.w2_input_scale = torch.ones( + self.num_experts_per_partition, + dtype=torch.float32, + device=hidden_states.device, + ) + silu_and_mul_triton_kernel[(gateup_output.shape[0],)]( + gateup_output, + down_input, + gateup_output.shape[1], + reorder_topk_ids, + self.w2_input_scale, + self.start_expert_id, + self.end_expert_id, + BLOCK_SIZE=512, + ) + + # GroupGemm-1 + down_output = torch.empty( + down_input.shape[0], + self.w2_weight.shape[1], + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + down_output = self.grouped_gemm_runner( + a=down_input, + b=self.w2_weight, + c=down_output, + batch_size=self.num_experts_per_partition, + weight_column_major=True, + seg_indptr=seg_indptr_cur_rank, + weight_indices=weight_indices_cur_rank, + use_fp8_w8a8=self.use_fp8_w8a8, + scale_a=self.w2_input_scale, + scale_b=self.w2_weight_scale, + ) + + # PostReorder + output = torch.empty_like(hidden_states) + post_reorder_triton_kernel[(hidden_states.size(0),)]( + down_output, + output, + src2dst, + topk_ids, + topk_weights, + self.start_expert_id, + self.end_expert_id, + self.top_k, + hidden_states.size(1), + BLOCK_SIZE=512, + ) + return output + + def select_experts( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + ): + if self.use_grouped_topk: + assert topk_group is not None + assert num_expert_group is not None + topk_weights, topk_ids = grouped_topk( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + num_expert_group=num_expert_group, + topk_group=topk_group, + ) + else: + topk_weights, topk_ids = fused_topk( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + ) + return topk_weights, topk_ids.to(torch.int32) + + @classmethod + def make_expert_params_mapping( + cls, + ckpt_gate_proj_name: str, + ckpt_down_proj_name: str, + ckpt_up_proj_name: str, + num_experts: int, + ) -> List[Tuple[str, str, int, str]]: + + return [ + # (param_name, weight_name, expert_id, shard_id) + ( + ( + "experts.w13_" + if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name] + else "experts.w2_" + ), + f"experts.{expert_id}.{weight_name}.", + expert_id, + shard_id, + ) + for expert_id in range(num_experts) + for shard_id, weight_name in [ + ("w1", ckpt_gate_proj_name), + ("w2", ckpt_down_proj_name), + ("w3", ckpt_up_proj_name), + ] + ] + + def weight_loader( + self, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + shard_id: str, + expert_id: int, + ) -> None: + if expert_id < self.start_expert_id or expert_id > self.end_expert_id: + return + expert_id = expert_id - self.start_expert_id + + if shard_id not in ("w1", "w2", "w3"): + raise ValueError( + f"shard_id must be ['w1','w2','w3'] but " f"got {shard_id}." + ) + + # Special case for fp8 scales. 
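+        # Scale tensors are dispatched to _load_fp8_scale, which stores
+        # per-expert input/weight scales for the w1/w2/w3 shards.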
+ if "scale" in weight_name: + self._load_fp8_scale( + param.data, loaded_weight, weight_name, shard_id, expert_id + ) + return + + expert_data = param.data[expert_id] + if shard_id == "w2": + param.data[expert_id] = loaded_weight + elif shard_id == "w1": + param.data[expert_id][: self.intermediate_size, :] = loaded_weight + elif shard_id == "w3": + param.data[expert_id][self.intermediate_size :, :] = loaded_weight + else: + raise ValueError(f"Expected shard_id w1,w2 or w3 but got {shard_id}") + + def _load_fp8_scale( + self, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + shard_id: str, + expert_id: int, + ) -> None: + param_data = param.data + + # Input scales can be loaded directly and should be equal. + if "input_scale" in weight_name: + if ( + param_data[expert_id] != 1 + and (param_data[expert_id] - loaded_weight).abs() > 1e-5 + ): + raise ValueError( + "input_scales of w1 and w3 of a layer " + f"must be equal. But got {param_data[expert_id]} " + f"vs. {loaded_weight}" + ) + param_data[expert_id] = loaded_weight + # Weight scales + elif "weight_scale" in weight_name: + # If we are in merged column case (gate_up_proj) + if shard_id in ("w1", "w3"): + # We have to keep the weight scales of w1 and w3 because + # we need to re-quantize w1/w3 weights after weight loading. + idx = 0 if shard_id == "w1" else 1 + param_data[expert_id][idx] = loaded_weight + # If we are in the row parallel case (down_proj) + else: + param_data[expert_id] = loaded_weight + + +@register_custom_op("sglang_unquantized_ep_moe") +class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp): + def create_weights( + self, + layer: torch.nn.Module, + num_experts_per_partition: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + # Fused gate_up_proj (column parallel) + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts_per_partition, + 2 * intermediate_size, + hidden_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + # down_proj (row parallel) + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts_per_partition, + hidden_size, + intermediate_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # scale + ones_tensor = torch.ones(num_experts_per_partition, dtype=torch.float32) + w13_input_scale = torch.nn.Parameter( + ones_tensor, + requires_grad=False, + ) + layer.register_parameter("w13_input_scale", w13_input_scale) + set_weight_attrs(w13_input_scale, extra_weight_attrs) + + w2_input_scale = torch.nn.Parameter( + ones_tensor, + requires_grad=False, + ) + layer.register_parameter("w2_input_scale", w2_input_scale) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + + w13_weight_scale = torch.nn.Parameter( + ones_tensor, + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + + w2_weight_scale = torch.nn.Parameter( + ones_tensor, + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + topk_group: Optional[int] = None, + num_expert_group: 
Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + ) -> torch.Tensor: + raise NotImplementedError + + +class Fp8EPMoEMethod(Fp8MoEMethod): + """MoE method for FP8. + Supports loading FP8 checkpoints with static weight scale and + dynamic/static activation scale. + + Args: + quant_config: The quantization config. + """ + + def __init__(self, quant_config: Fp8Config): + self.quant_config = quant_config + + def create_weights( + self, + layer: Module, + num_experts_per_partition: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + + if self.quant_config.is_checkpoint_fp8_serialized: + params_dtype = torch.float8_e4m3fn + + # WEIGHTS + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts_per_partition, + 2 * intermediate_size, + hidden_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts_per_partition, + hidden_size, + intermediate_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + w13_weight_scale = torch.nn.Parameter( + torch.ones(num_experts_per_partition, 2, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts_per_partition, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # Add the quantization method used (per tensor/grouped/channel) + # to ensure the weight scales are loaded in properly + extra_weight_attrs.update({"quant_method": "tensor"}) + # If loading fp8 checkpoint, pass the weight loaders. + # If loading an fp16 checkpoint, do not (we will quantize in + # process_weights_after_loading() + if self.quant_config.is_checkpoint_fp8_serialized: + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # INPUT_SCALES + if self.quant_config.activation_scheme == "static": + if not self.quant_config.is_checkpoint_fp8_serialized: + raise ValueError( + "Found static activation scheme for checkpoint that " + "was not serialized fp8." + ) + + w13_input_scale = torch.nn.Parameter( + torch.ones(num_experts_per_partition, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w13_input_scale", w13_input_scale) + set_weight_attrs(w13_input_scale, extra_weight_attrs) + + w2_input_scale = torch.nn.Parameter( + torch.ones(num_experts_per_partition, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w2_input_scale", w2_input_scale) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + + else: + layer.w13_input_scale = None + layer.w2_input_scale = None + + def process_weights_after_loading(self, layer: Module) -> None: + + # If checkpoint is fp16, quantize in place. 
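+        # Each expert's w13/w2 weight is quantized to fp8 with its own
+        # per-tensor scale via ops.scaled_fp8_quant.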
+ if not self.quant_config.is_checkpoint_fp8_serialized: + # If rocm, use float8_e4m3fnuz as dtype + fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn + w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype) + w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype) + + layer.w13_weight_scale = torch.nn.Parameter( + torch.ones( + layer.num_experts_per_partition, + dtype=torch.float32, + device=w13_weight.device, + ), + requires_grad=False, + ) + + for expert in range(layer.num_experts_per_partition): + w13_weight[expert, :, :], layer.w13_weight_scale[expert] = ( + ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :]) + ) + w2_weight[expert, :, :], layer.w2_weight_scale[expert] = ( + ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :]) + ) + layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) + layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) + return + + # If checkpoint is fp8, we need to handle that the + # MoE kernels require single activation scale and single weight + # scale for w13 per expert. + else: + if self.quant_config.activation_scheme == "static": + if layer.w13_input_scale is None or layer.w2_input_scale is None: + raise ValueError( + "QuantConfig has static quantization, but found " + "activation scales are None." + ) + return + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + ) -> torch.Tensor: + raise NotImplementedError diff --git a/python/sglang/srt/layers/fused_moe_patch.py b/python/sglang/srt/layers/fused_moe_patch.py index 400ca03c434..baca2581150 100644 --- a/python/sglang/srt/layers/fused_moe_patch.py +++ b/python/sglang/srt/layers/fused_moe_patch.py @@ -105,20 +105,29 @@ def fused_moe_forward_native( num_expert_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, ) -> torch.Tensor: - assert custom_routing_function is None - topk_weights, topk_ids = select_experts_native( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - ) + + if use_grouped_topk: + assert num_expert_group is not None and topk_group is not None + topk_weights, topk_ids = grouped_topk( + x, + router_logits, + top_k, + renormalize, + num_expert_group, + topk_group, + ) + elif custom_routing_function is None: + topk_weights, topk_ids = fused_topk_native(x, router_logits, top_k, renormalize) + else: + topk_weights, topk_ids = custom_routing_function( + x, router_logits, top_k, renormalize + ) + w13_weights = layer.w13_weight[topk_ids] w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2) w2_weights = layer.w2_weight[topk_ids] - x1 = F.silu(torch.einsum("ti,taoi -> tao", x, w1_weights)) + x1 = torch.einsum("ti,taoi -> tao", x, w1_weights) + x1 = F.silu(x1) x3 = torch.einsum("ti, taoi -> tao", x, w3_weights) expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights) return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype)) diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json b/python/sglang/srt/layers/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json new file mode 100644 index 
00000000000..283ffd8ff1d --- /dev/null +++ b/python/sglang/srt/layers/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json b/python/sglang/srt/layers/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json new file mode 100644 index 00000000000..8a18afe7d6d --- /dev/null +++ b/python/sglang/srt/layers/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + 
"BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/fused_moe_triton/fused_moe.py index 4f92512b2d5..e6ce9cb4d39 100644 --- a/python/sglang/srt/layers/fused_moe_triton/fused_moe.py +++ b/python/sglang/srt/layers/fused_moe_triton/fused_moe.py @@ -16,6 +16,7 @@ from sglang.srt.utils import direct_register_custom_op, get_device_name logger = logging.getLogger(__name__) +padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0 @triton.jit @@ -58,6 +59,7 @@ def fused_moe_kernel( compute_type: tl.constexpr, use_fp8_w8a8: tl.constexpr, use_int8_w8a16: tl.constexpr, + even_Ks: tl.constexpr, ): """ Implements the fused computation for a Mixture of Experts (MOE) using @@ -143,12 +145,21 @@ def fused_moe_kernel( for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): # Load the next block of A and B, generate a mask by checking the # K dimension. - a = tl.load( - a_ptrs, - mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), - other=0.0, - ) - b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + if even_Ks: + a = tl.load( + a_ptrs, + mask=token_mask[:, None], + other=0.0, + ) + b = tl.load(b_ptrs) + else: + a = tl.load( + a_ptrs, + mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + # We accumulate along the K dimension. 
if use_int8_w8a16: accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) @@ -254,7 +265,9 @@ def invoke_fused_moe_kernel( assert topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 + padded_size = 0 if use_fp8_w8a8: + padded_size = padding_size A, A_scale = ops.scaled_fp8_quant(A, A_scale) assert B_scale is not None elif use_int8_w8a16: @@ -268,6 +281,12 @@ def invoke_fused_moe_kernel( * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]), ) + K = B.shape[2] - padded_size + if K % config["BLOCK_SIZE_K"] == 0: + even_Ks = True + else: + even_Ks = False + fused_moe_kernel[grid]( A, B, @@ -279,7 +298,7 @@ def invoke_fused_moe_kernel( expert_ids, num_tokens_post_padded, B.shape[1], - B.shape[2], + B.shape[2] - padded_size, sorted_token_ids.shape[0], topk_ids.numel(), A.stride(0), @@ -296,6 +315,7 @@ def invoke_fused_moe_kernel( compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + even_Ks=even_Ks, **config, ) @@ -351,20 +371,39 @@ def get_default_config( dtype: Optional[str], is_marlin: bool, ) -> Dict[str, int]: - config = { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - } - # A heuristic: fused marlin works faster with this config for small M - if M <= E or (is_marlin and M <= 32): + if dtype == "fp8_w8a8": config = { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4, } + if M <= E: + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4, + } + else: + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + } + # A heuristic: fused marlin works faster with this config for small M + if M <= E or (is_marlin and M <= 32): + config = { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + } return config @@ -645,8 +684,12 @@ def fused_experts_impl( a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, ): + padded_size = padding_size + if not use_fp8_w8a8: + padded_size = 0 + # Check constraints. - assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert hidden_states.shape[1] == w1.shape[2] - padded_size, "Hidden size mismatch" assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" @@ -668,7 +711,7 @@ def fused_experts_impl( get_config_func = functools.partial( try_get_optimal_moe_config, w1.shape, - w2.shape, + (w2.shape[0], w2.shape[1], w2.shape[2] - padded_size), topk_ids.shape[1], config_dtype, ) diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 274c4c311ec..915cb47d271 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -39,10 +39,12 @@ class LogitsProcessorOutput: # The logprobs of input tokens. shape: [#token, vocab_size] input_token_logprobs: torch.Tensor = None - # The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] of Tuple(logprob, token_id) - input_top_logprobs: List = None - # The logprob and id of the top-k tokens in output positions. 
shape [#seq, #token, k] of Tuple(logprob, token_id) - output_top_logprobs: List = None + # The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] + input_top_logprobs_val: List = None + input_top_logprobs_idx: List = None + # The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k] + output_top_logprobs_val: List = None + output_top_logprobs_idx: List = None @dataclasses.dataclass @@ -125,12 +127,15 @@ def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata indices = ret.indices.tolist() if logits_metadata.forward_mode.is_decode(): - output_top_logprobs = [] + output_top_logprobs_val = [] + output_top_logprobs_idx = [] for i, k in enumerate(logits_metadata.top_logprobs_nums): - output_top_logprobs.append(list(zip(values[i][:k], indices[i][:k]))) - return None, output_top_logprobs + output_top_logprobs_val.append(values[i][:k]) + output_top_logprobs_idx.append(indices[i][:k]) + return None, None, output_top_logprobs_val, output_top_logprobs_idx else: - input_top_logprobs, output_top_logprobs = [], [] + input_top_logprobs_val, input_top_logprobs_idx = [], [] + output_top_logprobs_val, output_top_logprobs_idx = [], [] pt = 0 for k, pruned_len in zip( @@ -138,27 +143,36 @@ def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata logits_metadata.extend_logprob_pruned_lens_cpu, ): if pruned_len <= 0: - input_top_logprobs.append([]) - output_top_logprobs.append([]) + input_top_logprobs_val.append([]) + input_top_logprobs_idx.append([]) + output_top_logprobs_val.append([]) + output_top_logprobs_idx.append([]) continue - input_top_logprobs.append( - [ - list(zip(values[pt + j][:k], indices[pt + j][:k])) - for j in range(pruned_len - 1) - ] + input_top_logprobs_val.append( + [values[pt + j][:k] for j in range(pruned_len - 1)] ) - output_top_logprobs.append( + input_top_logprobs_idx.append( + [indices[pt + j][:k] for j in range(pruned_len - 1)] + ) + output_top_logprobs_val.append( + list( + values[pt + pruned_len - 1][:k], + ) + ) + output_top_logprobs_idx.append( list( - zip( - values[pt + pruned_len - 1][:k], - indices[pt + pruned_len - 1][:k], - ) + indices[pt + pruned_len - 1][:k], ) ) pt += pruned_len - return input_top_logprobs, output_top_logprobs + return ( + input_top_logprobs_val, + input_top_logprobs_idx, + output_top_logprobs_val, + output_top_logprobs_idx, + ) def forward( self, @@ -193,29 +207,22 @@ def forward( if not logits_metadata.return_logprob: return LogitsProcessorOutput( next_token_logits=last_logits, - next_token_logprobs=None, - normalized_prompt_logprobs=None, - input_token_logprobs=None, - input_top_logprobs=None, - output_top_logprobs=None, ) else: last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1) if logits_metadata.forward_mode.is_decode(): if logits_metadata.return_top_logprob: - output_top_logprobs = self.get_top_logprobs( - last_logprobs, logits_metadata - )[1] + output_top_logprobs_val, output_top_logprobs_idx = ( + self.get_top_logprobs(last_logprobs, logits_metadata)[2:4] + ) else: - output_top_logprobs = None + output_top_logprobs_val = output_top_logprobs_idx = None return LogitsProcessorOutput( next_token_logits=last_logits, next_token_logprobs=last_logprobs, - normalized_prompt_logprobs=None, - input_token_logprobs=None, - input_top_logprobs=None, - output_top_logprobs=output_top_logprobs, + output_top_logprobs_val=output_top_logprobs_val, + output_top_logprobs_idx=output_top_logprobs_idx, ) else: # Slice the requested tokens to compute 
logprob @@ -246,11 +253,16 @@ def forward( # Get the logprob of top-k tokens if logits_metadata.return_top_logprob: - input_top_logprobs, output_top_logprobs = self.get_top_logprobs( - all_logprobs, logits_metadata - ) + ( + input_top_logprobs_val, + input_top_logprobs_idx, + output_top_logprobs_val, + output_top_logprobs_idx, + ) = self.get_top_logprobs(all_logprobs, logits_metadata) else: - input_top_logprobs = output_top_logprobs = None + input_top_logprobs_val = input_top_logprobs_idx = ( + output_top_logprobs_val + ) = output_top_logprobs_idx = None # Compute the normalized logprobs for the requested tokens. # Note that we pad a zero at the end for easy batching. @@ -273,8 +285,10 @@ def forward( next_token_logprobs=last_logprobs, normalized_prompt_logprobs=normalized_prompt_logprobs, input_token_logprobs=input_token_logprobs, - input_top_logprobs=input_top_logprobs, - output_top_logprobs=output_top_logprobs, + input_top_logprobs_val=input_top_logprobs_val, + input_top_logprobs_idx=input_top_logprobs_idx, + output_top_logprobs_val=output_top_logprobs_val, + output_top_logprobs_idx=output_top_logprobs_idx, ) def _get_logits( diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py index a1bacdce036..48b733fdb78 100644 --- a/python/sglang/srt/layers/quantization/__init__.py +++ b/python/sglang/srt/layers/quantization/__init__.py @@ -13,7 +13,6 @@ from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config -from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod from vllm.model_executor.layers.quantization.gguf import GGUFConfig from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig @@ -23,6 +22,7 @@ from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.quantization.fp8 import Fp8Config QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { "aqlm": AQLMConfig, @@ -53,60 +53,16 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: return QUANTIZATION_METHODS[quantization] -def fp8_moe_apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, -) -> torch.Tensor: - """Enhanced apply method for FP8 MoE.""" - from sglang.srt.layers.fused_moe_triton import FusedMoE - from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts - - # Expert selection - topk_weights, topk_ids = FusedMoE.select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function, - ) - - # Expert fusion with FP8 quantization - return fused_experts( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - use_fp8_w8a8=True, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=layer.w13_input_scale, - 
a2_scale=layer.w2_input_scale, - ) - - def fp8_get_quant_method(self, layer, prefix): """Enhanced get_quant_method for FP8 config.""" from vllm.model_executor.layers.linear import LinearBase - from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod from vllm.model_executor.layers.quantization.utils.quant_utils import ( is_layer_skipped, ) from sglang.srt.layers.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.linear import UnquantizedLinearMethod + from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod, Fp8MoEMethod if isinstance(layer, LinearBase): if is_layer_skipped(prefix, self.ignored_layers): @@ -117,10 +73,43 @@ def fp8_get_quant_method(self, layer, prefix): return None +def gptq_get_quant_method(self, layer, prefix): + from vllm.model_executor.layers.linear import LinearBase + from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQMarlinLinearMethod, + GPTQMarlinMoEMethod, + ) + + from sglang.srt.layers.fused_moe_triton.layer import FusedMoE + + if isinstance(layer, LinearBase): + return GPTQMarlinLinearMethod(self) + elif isinstance(layer, FusedMoE): + return GPTQMarlinMoEMethod(self) + return None + + +def awq_get_quant_method(self, layer, prefix): + from vllm.model_executor.layers.linear import LinearBase + from vllm.model_executor.layers.quantization.awq_marlin import ( + AWQMarlinLinearMethod, + AWQMoEMethod, + ) + + from sglang.srt.layers.fused_moe_triton.layer import FusedMoE + + if isinstance(layer, LinearBase): + return AWQMarlinLinearMethod(self) + elif isinstance(layer, FusedMoE): + return AWQMoEMethod(self) + return None + + def apply_monkey_patches(): """Apply all monkey patches in one place.""" - setattr(Fp8MoEMethod, "apply", fp8_moe_apply) setattr(Fp8Config, "get_quant_method", fp8_get_quant_method) + setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method) + setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method) # Apply patches when module is imported diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py new file mode 100644 index 00000000000..c5a254b547e --- /dev/null +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -0,0 +1,607 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py + +import logging +import os +from typing import Any, Callable, Dict, List, Optional + +import torch +import torch.nn.functional as F +from torch.nn import Module +from torch.nn.parameter import Parameter +from vllm import _custom_ops as ops +from vllm.model_executor.layers.linear import LinearBase +from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( + apply_fp8_marlin_linear, + prepare_fp8_layer_for_marlin, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + all_close_1d, + apply_fp8_linear, + convert_to_channelwise, + cutlass_fp8_supported, + per_tensor_dequantize, + requantize_with_max_scale, +) +from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter + +from sglang.srt.layers.fused_moe_triton.fused_moe import padding_size +from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod +from sglang.srt.layers.quantization.base_config import ( + QuantizationConfig, + QuantizeMethodBase, +) +from 
sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz +from sglang.srt.utils import ( + get_bool_env_var, + is_hip, + print_warning_once, + set_weight_attrs, +) + +ACTIVATION_SCHEMES = ["static", "dynamic"] + +logger = logging.getLogger(__name__) + + +class Fp8Config(QuantizationConfig): + """Config class for FP8.""" + + def __init__( + self, + is_checkpoint_fp8_serialized: bool = False, + activation_scheme: str = "dynamic", + ignored_layers: Optional[List[str]] = None, + ) -> None: + self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized + if is_checkpoint_fp8_serialized: + logger.warning( + "Detected fp8 checkpoint. Please note that the " + "format is experimental and subject to change." + ) + if activation_scheme not in ACTIVATION_SCHEMES: + raise ValueError(f"Unsupported activation scheme {activation_scheme}") + self.activation_scheme = activation_scheme + self.ignored_layers = ignored_layers or [] + + @classmethod + def get_name(cls) -> str: + return "fp8" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "Fp8Config": + quant_method = cls.get_from_keys(config, ["quant_method"]) + is_checkpoint_fp8_serialized = "fp8" in quant_method + activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) + ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None) + return cls( + is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, + activation_scheme=activation_scheme, + ignored_layers=ignored_layers, + ) + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: + from vllm.attention.layer import Attention # Avoid circular import + + from sglang.srt.layers.fused_moe_triton import FusedMoE + + if isinstance(layer, LinearBase): + if is_layer_skipped(prefix, self.ignored_layers): + return UnquantizedLinearMethod() + return Fp8LinearMethod(self) + elif isinstance(layer, FusedMoE): + return Fp8MoEMethod(self) + elif isinstance(layer, Attention): + return Fp8KVCacheMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class Fp8LinearMethod(LinearMethodBase): + """Linear method for FP8. + Supports loading FP8 checkpoints with static weight scale and + dynamic/static activation scale. + + Also supports loading quantized FP16/BF16 model checkpoints with dynamic + activation scaling. The weight scaling factor will be initialized after + the model weights are loaded. + + Limitations: + 1. Only support per-tensor quantization due to torch._scaled_mm support. + 2. Only support float8_e4m3fn data type due to the limitation of + torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856) + + Args: + quant_config: The quantization config. 
+ """ + + def __init__(self, quant_config: Fp8Config): + self.quant_config = quant_config + self.cutlass_fp8_supported = cutlass_fp8_supported() + + # For GPUs that lack FP8 hardware support, we can leverage the Marlin + # kernel for fast weight-only FP8 quantization + self.use_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN") + # Disable marlin for ROCm + if is_hip(): + self.use_marlin = False + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del input_size, output_size + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + + layer.logical_widths = output_partition_sizes + + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.orig_dtype = params_dtype + + # WEIGHT + weight_dtype = ( + torch.float8_e4m3fn + if self.quant_config.is_checkpoint_fp8_serialized + else params_dtype + ) + + weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, input_size_per_partition, dtype=weight_dtype + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + # If checkpoint is serialized fp8, load them. + # Otherwise, wait until process_weights_after_loading. + if self.quant_config.is_checkpoint_fp8_serialized: + # WEIGHT SCALE + scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) + + scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("weight_scale", scale) + + # INPUT ACTIVATION SCALE + if self.quant_config.activation_scheme == "static": + scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) + + scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("input_scale", scale) + else: + layer.register_parameter("input_scale", None) + + def process_weights_after_loading(self, layer: Module) -> None: + layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False) + # If checkpoint not serialized fp8, quantize the weights. + if not self.quant_config.is_checkpoint_fp8_serialized: + qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None) + + # If using marlin (w8a16), kernel uses channelwise weights, + # so extend the weight scales to be channelwise. + if self.use_marlin: + assert weight_scale.numel() == 1 + weight_scale = convert_to_channelwise( + weight_scale.expand(len(layer.logical_widths)), layer.logical_widths + ) + + # Update the layer with the new values. + layer.weight = Parameter(qweight.t(), requires_grad=False) + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + layer.input_scale = None + + # If checkpoint is fp8, handle that there are N scales for N + # shards in a fused module + else: + layer.weight_scale = torch.nn.Parameter( + layer.weight_scale.data, requires_grad=False + ) + if self.quant_config.activation_scheme == "static": + layer.input_scale = torch.nn.Parameter( + layer.input_scale.data, requires_grad=False + ) + # If using marlin (w8a16), kernel uses channelwise weights, + # so extend the weight scales to be channelwise. 
+ if self.use_marlin: + weight = layer.weight + weight_scale = convert_to_channelwise( + layer.weight_scale, layer.logical_widths + ) + + # If using w8a8, torch._scaled_mm needs per tensor, so + # requantize the logical shards as a single weight. + else: + # Dequant -> Quant with max scale so we can run per tensor. + weight = layer.weight + weight_scale = layer.weight_scale + + # If ROCm, normalize the weights and scales to e4m3fnuz + if is_hip(): + weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( + weight=weight, + weight_scale=weight_scale, + input_scale=layer.input_scale, + ) + if input_scale is not None: + layer.input_scale = Parameter(input_scale, requires_grad=False) + + weight_scale, weight = requantize_with_max_scale( + weight=weight, + weight_scale=weight_scale, + logical_widths=layer.logical_widths, + ) + + # Update layer with new values. + layer.weight = Parameter(weight.t(), requires_grad=False) + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + if self.quant_config.activation_scheme == "static": + layer.input_scale = Parameter( + layer.input_scale.max(), requires_grad=False + ) + + if self.use_marlin: + prepare_fp8_layer_for_marlin(layer) + # Activations not quantized for marlin. + del layer.input_scale + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + if self.use_marlin: + return apply_fp8_marlin_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + workspace=layer.workspace, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias, + ) + + return apply_fp8_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + input_scale=layer.input_scale, + bias=bias, + cutlass_fp8_supported=self.cutlass_fp8_supported, + use_per_token_if_dynamic=False, + ) + + +class Fp8MoEMethod: + """MoE method for FP8. + Supports loading FP8 checkpoints with static weight scale and + dynamic/static activation scale. + + Also supports loading quantized FP16/BF16 model checkpoints with dynamic + activation scaling. The weight scaling factor will be initialized after + the model weights are loaded. + + Args: + quant_config: The quantization config. 
+ """ + + def __new__(cls, *args, **kwargs): + from sglang.srt.layers.fused_moe_triton import FusedMoEMethodBase + + if not hasattr(cls, "_initialized"): + original_init = cls.__init__ + new_cls = type( + cls.__name__, + (FusedMoEMethodBase,), + { + "__init__": original_init, + **{k: v for k, v in cls.__dict__.items() if k != "__dict__"}, + }, + ) + obj = super(new_cls, new_cls).__new__(new_cls) + obj.__init__(*args, **kwargs) + return obj + return super().__new__(cls) + + def __init__(self, quant_config): + self.quant_config = quant_config + + def create_weights( + self, + layer: Module, + num_experts: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + from sglang.srt.layers.fused_moe_triton import FusedMoeWeightScaleSupported + + if self.quant_config.is_checkpoint_fp8_serialized: + params_dtype = torch.float8_e4m3fn + + # WEIGHTS + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, hidden_size, intermediate_size, dtype=params_dtype + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. + w13_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # Add the quantization method used (per tensor/grouped/channel) + # to ensure the weight scales are loaded in properly + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) + # If loading fp8 checkpoint, pass the weight loaders. + # If loading an fp16 checkpoint, do not (we will quantize in + # process_weights_after_loading() + if self.quant_config.is_checkpoint_fp8_serialized: + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # INPUT_SCALES + if self.quant_config.activation_scheme == "static": + if not self.quant_config.is_checkpoint_fp8_serialized: + raise ValueError( + "Found static activation scheme for checkpoint that " + "was not serialized fp8." + ) + + w13_input_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w13_input_scale", w13_input_scale) + set_weight_attrs(w13_input_scale, extra_weight_attrs) + + w2_input_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w2_input_scale", w2_input_scale) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + + else: + layer.w13_input_scale = None + layer.w2_input_scale = None + + def process_weights_after_loading(self, layer: Module) -> None: + + # If checkpoint is fp16 or bfloat16, quantize in place. 
+ if not self.quant_config.is_checkpoint_fp8_serialized: + # If ROCm, use float8_e4m3fnuz instead (MI300x HW) + fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn + w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype) + w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype) + + # Re-initialize w13_scale because we directly quantize + # merged w13 weights and generate a single scaling factor. + layer.w13_weight_scale = torch.nn.Parameter( + torch.ones( + layer.num_experts, dtype=torch.float32, device=w13_weight.device + ), + requires_grad=False, + ) + for expert in range(layer.num_experts): + w13_weight[expert, :, :], layer.w13_weight_scale[expert] = ( + ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :]) + ) + w2_weight[expert, :, :], layer.w2_weight_scale[expert] = ( + ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :]) + ) + layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) + layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) + + # If ROCm, apply weight padding (min. Mem channel contention) only if set + if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))): + layer.w13_weight = torch.nn.Parameter( + F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0), + requires_grad=False, + ) + torch.cuda.empty_cache() + layer.w2_weight = torch.nn.Parameter( + F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0), + requires_grad=False, + ) + torch.cuda.empty_cache() + return + + # If checkpoint is fp8, we need to handle that the + # MoE kernels require single activation scale and single weight + # scale for w13 per expert. + else: + # Fp8 moe kernels require a single activation scale. + # We take the max of all the scales in case they differ. + if self.quant_config.activation_scheme == "static": + if layer.w13_input_scale is None or layer.w2_input_scale is None: + raise ValueError( + "QuantConfig has static quantization, but found " + "activation scales are None." + ) + if not all_close_1d(layer.w13_input_scale) or not all_close_1d( + layer.w2_input_scale + ): + print_warning_once( + "Found input_scales that are not equal for " + "fp8 MoE layer. Using the maximum across experts " + "for each layer. " + ) + layer.w13_input_scale = torch.nn.Parameter( + layer.w13_input_scale.max(), requires_grad=False + ) + layer.w2_input_scale = torch.nn.Parameter( + layer.w2_input_scale.max(), requires_grad=False + ) + + # If ROCm, normalize the weights and scales to e4m3fnuz + if is_hip(): + # Normalize the weights and scales + w13_weight, w13_weight_scale, w13_input_scale = ( + normalize_e4m3fn_to_e4m3fnuz( + layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale + ) + ) + w2_weight, w2_weight_scale, w2_input_scale = ( + normalize_e4m3fn_to_e4m3fnuz( + layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale + ) + ) + # Reset the parameter + layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter( + w13_weight_scale, requires_grad=False + ) + if w13_input_scale is not None: + layer.w13_input_scale = torch.nn.Parameter( + w13_input_scale, requires_grad=False + ) + layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) + layer.w2_weight_scale = torch.nn.Parameter( + w2_weight_scale, requires_grad=False + ) + if w2_input_scale is not None: + layer.w2_input_scale = torch.nn.Parameter( + w2_input_scale, requires_grad=False + ) + + # Fp8 moe kernel needs single weight scale for w13 per expert. 
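The block that follows takes the maximum of the per-shard w1/w3 scales, dequantizes each shard, and requantizes it with that shared scale. A toy version of the same step in plain PyTorch, with made-up shapes and the same float8 assumption as above; the real code goes through `per_tensor_dequantize` and `ops.scaled_fp8_quant`:

```python
import torch

finfo = torch.finfo(torch.float8_e4m3fn)

def quant(w, scale):
    return (w / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)

# Two shards (w1, w3) quantized with their own per-tensor scales ...
w1, w3 = torch.randn(16, 8), torch.randn(16, 8)
s1, s3 = w1.abs().max() / finfo.max, w3.abs().max() / finfo.max
q1, q3 = quant(w1, s1), quant(w3, s3)

# ... are dequantized and requantized with the shared max scale, so the fused
# w13 weight needs only one scale per expert.
s_max = torch.maximum(s1, s3)
q1_requant = quant(q1.float() * s1, s_max)
q3_requant = quant(q3.float() * s3, s_max)
```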
+ # We take the max then dequant and requant each expert. + assert layer.w13_weight_scale is not None + shard_size = layer.intermediate_size_per_partition + max_w13_scales = layer.w13_weight_scale.max(dim=1).values + for expert_id in range(layer.num_experts): + start = 0 + for shard_id in range(2): + dq_weight = per_tensor_dequantize( + layer.w13_weight[expert_id][start : start + shard_size, :], + layer.w13_weight_scale[expert_id][shard_id], + ) + layer.w13_weight[expert_id][start : start + shard_size, :], _ = ( + ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id]) + ) + start += shard_size + + layer.w13_weight_scale = torch.nn.Parameter( + max_w13_scales, requires_grad=False + ) + + # If ROCm, apply weight padding (min. Mem channel contention) only if set + if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))): + layer.w13_weight = torch.nn.Parameter( + F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0), + requires_grad=False, + ) + torch.cuda.empty_cache() + layer.w2_weight = torch.nn.Parameter( + F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0), + requires_grad=False, + ) + torch.cuda.empty_cache() + return + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + ) -> torch.Tensor: + from sglang.srt.layers.fused_moe_triton import FusedMoE + from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts + + # Expert selection + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + ) + + # Expert fusion with FP8 quantization + return fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + use_fp8_w8a8=True, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + ) + + +class Fp8KVCacheMethod(BaseKVCacheMethod): + """ + Supports loading kv-cache scaling factors from FP8 checkpoints. + """ + + def __init__(self, quant_config: Fp8Config): + super().__init__(quant_config) diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py new file mode 100644 index 00000000000..3ba381a373f --- /dev/null +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -0,0 +1,27 @@ +from typing import Optional, Tuple + +import torch + + +def normalize_e4m3fn_to_e4m3fnuz( + weight: torch.Tensor, + weight_scale: torch.Tensor, + input_scale: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + assert weight.dtype == torch.float8_e4m3fn + # The bits pattern 10000000(-128) represents zero in e4m3fn + # but NaN in e4m3fnuz. So here we set it to 0. + # https://onnx.ai/onnx/technical/float8.html + weight_as_int8 = weight.view(torch.int8) + ROCM_FP8_NAN_AS_INT = -128 + weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0 + weight = weight_as_int8.view(torch.float8_e4m3fnuz) + + # For the same bits representation, e4m3fnuz value is half of + # the e4m3fn value, so we should double the scaling factor to + # get the same dequantized value. 
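The scale doubling in `normalize_e4m3fn_to_e4m3fnuz` follows from the exponent biases: e4m3fn uses bias 7 while e4m3fnuz uses bias 8, so the same bit pattern decodes to half the value. A quick arithmetic check, covering normal numbers only and ignoring subnormals and NaN:

```python
def decode_fp8(bits: int, bias: int) -> float:
    """Decode a normal FP8 bit pattern with a 4-bit exponent and 3-bit mantissa."""
    sign = -1.0 if bits & 0x80 else 1.0
    exponent = (bits >> 3) & 0xF
    mantissa = bits & 0x7
    return sign * (1 + mantissa / 8) * 2.0 ** (exponent - bias)

bits = 0b0_1000_001                       # sign=0, exponent=8, mantissa=1
e4m3fn = decode_fp8(bits, bias=7)         # 2.25
e4m3fnuz = decode_fp8(bits, bias=8)       # 1.125
assert e4m3fnuz == e4m3fn / 2             # hence weight_scale is doubled
```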
+ # https://onnx.ai/onnx/technical/float8.html + weight_scale = weight_scale * 2.0 + if input_scale is not None: + input_scale = input_scale * 2.0 + return weight, weight_scale, input_scale diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index 5d8c6470178..1df29ec68a9 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -48,11 +48,13 @@ def __init__( self.sliding_window_size = sliding_window_size or -1 self.is_cross_attention = is_cross_attention - def forward(self, q, k, v, forward_batch: ForwardBatch): + def forward(self, q, k, v, forward_batch: ForwardBatch, save_kv_cache=True): if k is not None: # For cross-layer sharing, kv can be None assert v is not None k = k.view(-1, self.tp_k_head_num, self.qk_head_dim) v = v.view(-1, self.tp_v_head_num, self.v_head_dim) - return forward_batch.attn_backend.forward(q, k, v, self, forward_batch) + return forward_batch.attn_backend.forward( + q, k, v, self, forward_batch, save_kv_cache + ) diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index d7db6036ca9..b0dfda3e882 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -111,5 +111,7 @@ def top_k_top_p_min_p_sampling_from_probs_torch( probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0 probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0]) sampled_index = torch.multinomial(probs_sort, num_samples=1) + # int32 range is enough to represent the token ids + probs_idx = probs_idx.to(torch.int32) batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1) return batch_next_token_ids diff --git a/python/sglang/srt/layers/torchao_utils.py b/python/sglang/srt/layers/torchao_utils.py index 9395cdf271b..910309da973 100644 --- a/python/sglang/srt/layers/torchao_utils.py +++ b/python/sglang/srt/layers/torchao_utils.py @@ -2,23 +2,24 @@ Common utilities for torchao. """ -from typing import Dict, Set - import torch -def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str): - """Quantize a Tensor with torchao quantization specified by torchao_config +def apply_torchao_config_to_model( + model: torch.nn.Module, torchao_config: str, filter_fn=None +): + """Quantize a modelwith torchao quantization specified by torchao_config Args: - `param`: weight parameter of the linear module - `torchao_config`: type of quantization and their arguments we want to use to - quantize the Tensor, e.g. int4wo-128 means int4 weight only quantization with group_size + `model`: a model to be quantized based on torchao_config + `torchao_config` (str): type of quantization and their arguments we want to use to + quantize the model, e.g. 
int4wo-128 means int4 weight only quantization with group_size 128 """ # Lazy import to suppress some warnings from torchao.quantization import ( float8_dynamic_activation_float8_weight, + float8_weight_only, int4_weight_only, int8_dynamic_activation_int8_weight, int8_weight_only, @@ -26,12 +27,17 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str): ) from torchao.quantization.observer import PerRow, PerTensor - dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False) - dummy_linear.weight = param - if "int8wo" in torchao_config: - quantize_(dummy_linear, int8_weight_only()) + if filter_fn is None: + + def filter_fn(module, fqn): + return "proj" in fqn + + if torchao_config == "" or torchao_config is None: + return model + elif "int8wo" in torchao_config: + quantize_(model, int8_weight_only(), filter_fn=filter_fn) elif "int8dq" in torchao_config: - quantize_(dummy_linear, int8_dynamic_activation_int8_weight()) + quantize_(model, int8_dynamic_activation_int8_weight(), filter_fn=filter_fn) elif "int4wo" in torchao_config: group_size = int(torchao_config.split("-")[-1]) assert group_size in [ @@ -40,13 +46,11 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str): 128, 256, ], f"int4wo groupsize needs to be one of [32, 64, 128, 256] but got {group_size}" - quantize_(dummy_linear, int4_weight_only(group_size=group_size)) + quantize_(model, int4_weight_only(group_size=group_size), filter_fn=filter_fn) elif "fp8wo" in torchao_config: - from torchao.quantization import float8_weight_only - # this requires newer hardware # [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89 - quantize_(dummy_linear, float8_weight_only()) + quantize_(model, float8_weight_only(), filter_fn=filter_fn) elif "fp8dq" in torchao_config: granularity = torchao_config.split("-")[-1] GRANULARITY_MAP = { @@ -57,39 +61,13 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str): granularity in GRANULARITY_MAP ), f"Supported granularity are: {GRANULARITY_MAP.keys()}, got {granularity}" quantize_( - dummy_linear, + model, float8_dynamic_activation_float8_weight( granularity=GRANULARITY_MAP[granularity] ), + filter_fn=filter_fn, ) else: raise ValueError(f"Unexpected config: {torchao_config}") - return dummy_linear.weight - - -def apply_torchao_config_( - self: torch.nn.Module, - params_dict: Dict[str, torch.Tensor], - param_suffixes: Set[str], -) -> None: - """A util function used for quantizing the weight parameters after they are loaded if - self.torchao_config is specified - - Args: - `self`: the model we want to quantize - `params_dict`: dictionary mapping from param_name to the parameter Tensor - `param_suffixes`: a set of suffixes, we'll quantize the Tensor matching these suffixes - - Returns: - None, the `params_dict` is modified inplace and the weights of `self` model are quantized - """ - if self.torchao_config: - for param_suffix in param_suffixes: - for name in params_dict: - param = params_dict[name] - if param_suffix in name and param.ndim == 2: - params_dict[name] = torchao_quantize_param_data( - param, self.torchao_config - ) - self.load_state_dict(params_dict, assign=True) + return model diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index e74ba5026c1..bc9e4a53b5c 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -17,9 +17,10 @@ import logging import signal from 
collections import OrderedDict -from typing import List, Union +from typing import Dict, List, Union import psutil +import setproctitle import zmq from sglang.srt.hf_transformers_utils import get_tokenizer @@ -75,17 +76,25 @@ def __init__( self.decode_status = LimitedCapacityDict() - def trim_eos(self, output: Union[str, List[int]], finished_reason, no_stop_trim): - if no_stop_trim: + def trim_matched_stop( + self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool + ): + if no_stop_trim or not finished_reason: + return output + + matched = finished_reason.get("matched", None) + if not matched: return output - # Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit - if isinstance(finished_reason, FINISH_MATCHED_STR) and isinstance(output, str): - pos = output.find(finished_reason.matched) + # TODO(lmzheng): handle the case where multiple stop strs are hit + + # Trim stop str. + if isinstance(matched, str) and isinstance(output, str): + pos = output.find(matched) return output[:pos] if pos != -1 else output - if isinstance(finished_reason, FINISH_MATCHED_TOKEN) and isinstance( - output, list - ): + + # Trim stop token. + if isinstance(matched, int) and isinstance(output, list): assert len(output) > 0 return output[:-1] return output @@ -124,9 +133,9 @@ def event_loop(self): s.decode_ids = recv_obj.decode_ids[i] read_ids.append( - self.trim_eos( + self.trim_matched_stop( s.decode_ids[s.surr_offset :], - recv_obj.finished_reason[i], + recv_obj.finished_reasons[i], recv_obj.no_stop_trim[i], ) ) @@ -149,7 +158,7 @@ def event_loop(self): for i in range(bs): s = self.decode_status[recv_obj.rids[i]] new_text = read_texts[i][len(surr_texts[i]) :] - if recv_obj.finished_reason[i] is None: + if recv_obj.finished_reasons[i] is None: # Streaming chunk: update the decode status if len(new_text) > 0 and not new_text.endswith("�"): s.decoded_text = s.decoded_text + new_text @@ -160,9 +169,9 @@ def event_loop(self): new_text = find_printable_text(new_text) output_strs.append( - self.trim_eos( + self.trim_matched_stop( s.decoded_text + new_text, - recv_obj.finished_reason[i], + recv_obj.finished_reasons[i], recv_obj.no_stop_trim[i], ) ) @@ -170,9 +179,20 @@ def event_loop(self): self.send_to_tokenizer.send_pyobj( BatchStrOut( rids=recv_obj.rids, + finished_reasons=recv_obj.finished_reasons, output_strs=output_strs, - meta_info=recv_obj.meta_info, - finished_reason=recv_obj.finished_reason, + prompt_tokens=recv_obj.prompt_tokens, + completion_tokens=recv_obj.completion_tokens, + cached_tokens=recv_obj.cached_tokens, + input_token_logprobs_val=recv_obj.input_token_logprobs_val, + input_token_logprobs_idx=recv_obj.input_token_logprobs_idx, + output_token_logprobs_val=recv_obj.output_token_logprobs_val, + output_token_logprobs_idx=recv_obj.output_token_logprobs_idx, + input_top_logprobs_val=recv_obj.input_top_logprobs_val, + input_top_logprobs_idx=recv_obj.input_top_logprobs_idx, + output_top_logprobs_val=recv_obj.output_top_logprobs_val, + output_top_logprobs_idx=recv_obj.output_top_logprobs_idx, + normalized_prompt_logprob=recv_obj.normalized_prompt_logprob, ) ) @@ -194,6 +214,7 @@ def run_detokenizer_process( server_args: ServerArgs, port_args: PortArgs, ): + setproctitle.setproctitle("sglang::detokenizer") configure_logger(server_args) parent_process = psutil.Process().parent() diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 27bf5a4bdb1..c5884b5f0f6 100644 --- a/python/sglang/srt/managers/io_struct.py +++ 
b/python/sglang/srt/managers/io_struct.py @@ -308,6 +308,9 @@ class TokenizedEmbeddingReqInput: class BatchTokenIDOut: # The request id rids: List[str] + # The finish reason + finished_reasons: List[BaseFinishReason] + # For incremental decoding # The version id to sync decode status with in detokenizer_manager vids: List[int] decoded_texts: List[str] @@ -315,35 +318,61 @@ class BatchTokenIDOut: read_offsets: List[int] # Only used when `--skip-tokenizer-init` output_ids: Optional[List[int]] + # Detokenization configs skip_special_tokens: List[bool] spaces_between_special_tokens: List[bool] - meta_info: List[Dict] - finished_reason: List[BaseFinishReason] no_stop_trim: List[bool] + # Token counts + prompt_tokens: List[int] + completion_tokens: List[int] + cached_tokens: List[int] + # Logprobs + input_token_logprobs_val: List[float] + input_token_logprobs_idx: List[int] + output_token_logprobs_val: List[float] + output_token_logprobs_idx: List[int] + input_top_logprobs_val: List[List] + input_top_logprobs_idx: List[List] + output_top_logprobs_val: List[List] + output_top_logprobs_idx: List[List] + normalized_prompt_logprob: List[float] @dataclass class BatchStrOut: # The request id rids: List[str] + # The finish reason + finished_reasons: List[dict] # The output decoded strings output_strs: List[str] - # The meta info - meta_info: List[Dict] - # The finish reason - finished_reason: List[BaseFinishReason] + + # Token counts + prompt_tokens: List[int] + completion_tokens: List[int] + cached_tokens: List[int] + # Logprobs + input_token_logprobs_val: List[float] + input_token_logprobs_idx: List[int] + output_token_logprobs_val: List[float] + output_token_logprobs_idx: List[int] + input_top_logprobs_val: List[List] + input_top_logprobs_idx: List[List] + output_top_logprobs_val: List[List] + output_top_logprobs_idx: List[List] + normalized_prompt_logprob: List[float] @dataclass class BatchEmbeddingOut: # The request id rids: List[str] + # The finish reason + finished_reasons: List[BaseFinishReason] # The output embedding embeddings: List[List[float]] - # The meta info - meta_info: List[Dict] - # The finish reason - finished_reason: List[BaseFinishReason] + # Token counts + prompt_tokens: List[int] @dataclass diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 301ef4bb74b..5068c79361c 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -58,6 +58,7 @@ "torchao_config": ServerArgs.torchao_config, "enable_nan_detection": ServerArgs.enable_nan_detection, "enable_dp_attention": ServerArgs.enable_dp_attention, + "enable_ep_moe": ServerArgs.enable_ep_moe, } @@ -200,6 +201,9 @@ def __init__( origin_input_text: str, origin_input_ids: Tuple[int], sampling_params: SamplingParams, + return_logprob: bool = False, + top_logprobs_num: int = 0, + stream: bool = False, origin_input_ids_unpadded: Optional[Tuple[int]] = None, lora_path: Optional[str] = None, input_embeds: Optional[List[List[float]]] = None, @@ -217,10 +221,11 @@ def __init__( self.output_ids = [] # Each decode stage's output ids self.fill_ids = None # fill_ids = origin_input_ids + output_ids self.session_id = session_id + self.input_embeds = input_embeds + # Sampling info self.sampling_params = sampling_params self.lora_path = lora_path - self.input_embeds = input_embeds # Memory pool info self.req_pool_idx = None @@ -228,8 +233,8 @@ def __init__( # Check finish self.tokenizer = None self.finished_reason = None - self.stream = False 
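With the io_struct change above, `BatchTokenIDOut`/`BatchStrOut`/`BatchEmbeddingOut` no longer carry a per-request `meta_info` dict; every field is a flat list indexed by the request's position in the batch. A trimmed-down, hypothetical stand-in (field subset and values are illustrative) showing how a consumer reassembles per-request metadata from the parallel lists, roughly what the tokenizer manager does when it prepares the API response:

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class MiniBatchStrOut:            # hypothetical subset of BatchStrOut's fields
    rids: List[str]
    finished_reasons: List[Optional[dict]]
    output_strs: List[str]
    prompt_tokens: List[int]
    completion_tokens: List[int]

batch = MiniBatchStrOut(
    rids=["req-0", "req-1"],
    finished_reasons=[None, {"type": "stop", "matched": "</s>"}],
    output_strs=["Hello", "Hi there"],
    prompt_tokens=[12, 7],
    completion_tokens=[1, 3],
)

# Rebuild one meta_info dict per request from the position-aligned lists.
meta_infos = [
    {
        "id": batch.rids[i],
        "finish_reason": batch.finished_reasons[i],
        "prompt_tokens": batch.prompt_tokens[i],
        "completion_tokens": batch.completion_tokens[i],
    }
    for i in range(len(batch.rids))
]
```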
self.to_abort = False + self.stream = stream # For incremental decoding # ----- | --------- read_ids -------| @@ -241,13 +246,9 @@ def __init__( # 2: read_offset # 3: last token self.vid = 0 # version id to sync decode status with in detokenizer_manager - self.decoded_text = "" self.surr_offset = None # Surrounding offset to defeat the cleanup algorithm self.read_offset = None - - # The number of decoded tokens for token usage report. Note that - # this does not include the jump forward tokens. - self.completion_tokens_wo_jump_forward = 0 + self.decoded_text = "" # For multimodal inputs self.image_inputs: Optional[ImageInputs] = None @@ -256,22 +257,34 @@ def __init__( self.prefix_indices = [] self.extend_input_len = 0 self.last_node = None + + # Chunked prefill self.is_being_chunked = 0 # For retraction self.is_retracted = False # Logprobs (arguments) - self.return_logprob = False + self.return_logprob = return_logprob self.logprob_start_len = 0 - self.top_logprobs_num = 0 + self.top_logprobs_num = top_logprobs_num # Logprobs (return value) self.normalized_prompt_logprob = None - self.input_token_logprobs = None - self.input_top_logprobs = None - self.output_token_logprobs = [] - self.output_top_logprobs = [] + self.input_token_logprobs_val = None + self.input_token_logprobs_idx = None + self.input_top_logprobs_val = None + self.input_top_logprobs_idx = None + + if return_logprob: + self.output_token_logprobs_val = [] + self.output_token_logprobs_idx = [] + self.output_top_logprobs_val = [] + self.output_top_logprobs_idx = [] + else: + self.output_token_logprobs_val = self.output_token_logprobs_idx = ( + self.output_top_logprobs_val + ) = self.output_top_logprobs_idx = None # Logprobs (internal values) # The tokens is prefilled but need to be considered as decode tokens @@ -295,8 +308,8 @@ def extend_image_inputs(self, image_inputs): else: self.image_inputs.merge(image_inputs) - # whether request reached finished condition def finished(self) -> bool: + # Whether request reached finished condition return self.finished_reason is not None def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None): @@ -454,8 +467,10 @@ def jump_forward_and_retokenize(self, jump_forward_str, next_state): k = k + 1 else: break - self.output_token_logprobs = self.output_token_logprobs[:k] - self.output_top_logprobs = self.output_top_logprobs[:k] + self.output_token_logprobs_val = self.output_token_logprobs_val[:k] + self.output_token_logprobs_idx = self.output_token_logprobs_idx[:k] + self.output_top_logprobs_val = self.output_top_logprobs_val[:k] + self.output_top_logprobs_idx = self.output_top_logprobs_idx[:k] self.logprob_start_len = prompt_tokens + k self.last_update_decode_tokens = len(self.output_ids) - k diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 3714f19b633..4ece8786878 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -25,6 +25,7 @@ from typing import List, Optional import psutil +import setproctitle import torch import zmq @@ -114,9 +115,6 @@ def __init__( self.skip_tokenizer_init = server_args.skip_tokenizer_init self.enable_metrics = server_args.enable_metrics - # Session info - self.sessions = {} - # Init inter-process communication context = zmq.Context(2) @@ -259,6 +257,10 @@ def __init__( self.num_generated_tokens = 0 self.last_decode_stats_tic = time.time() self.stream_interval = server_args.stream_interval + self.current_stream = 
torch.get_device_module(self.device).current_stream() + + # Session info + self.sessions = {} # Init chunked prefill self.chunked_prefill_size = server_args.chunked_prefill_size @@ -356,6 +358,7 @@ def __init__( ) def watchdog_thread(self): + """A watch dog thread that will try to kill the server itself if one batch takes too long.""" self.watchdog_last_forward_ct = 0 self.watchdog_last_time = time.time() @@ -433,61 +436,6 @@ def event_loop_overlap(self): self.last_batch = batch - def prepare_dp_attn_batch(self, local_batch: ScheduleBatch): - # Check if other DP workers have running batches - if local_batch is None: - num_tokens = 0 - elif local_batch.forward_mode.is_decode(): - num_tokens = local_batch.batch_size() - else: - num_tokens = local_batch.extend_num_tokens - - local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64) - global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64) - torch.distributed.all_gather_into_tensor( - global_num_tokens, - local_num_tokens, - group=self.tp_cpu_group, - ) - - if local_batch is None and global_num_tokens.max().item() > 0: - local_batch = self.get_idle_batch() - - if local_batch is not None: - local_batch.global_num_tokens = global_num_tokens.tolist() - - # Check forward mode for cuda graph - if not self.server_args.disable_cuda_graph: - forward_mode_state = torch.tensor( - ( - 1 - if local_batch.forward_mode.is_decode() - or local_batch.forward_mode.is_idle() - else 0 - ), - dtype=torch.int32, - ) - torch.distributed.all_reduce( - forward_mode_state, - op=torch.distributed.ReduceOp.MIN, - group=self.tp_cpu_group, - ) - local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1 - - return local_batch - - def get_idle_batch(self): - idle_batch = ScheduleBatch.init_new( - [], - self.req_to_token_pool, - self.token_to_kv_pool, - self.tree_cache, - self.model_config, - self.enable_overlap, - ) - idle_batch.prepare_for_idle() - return idle_batch - def recv_requests(self): if self.tp_rank == 0 or self.server_args.enable_dp_attention: recv_reqs = [] @@ -567,6 +515,9 @@ def handle_generate_request( recv_req.input_text, recv_req.input_ids, recv_req.sampling_params, + return_logprob=recv_req.return_logprob, + top_logprobs_num=recv_req.top_logprobs_num, + stream=recv_req.stream, lora_path=recv_req.lora_path, input_embeds=recv_req.input_embeds, ) @@ -610,9 +561,6 @@ def handle_generate_request( return # Copy more attributes - req.return_logprob = recv_req.return_logprob - req.top_logprobs_num = recv_req.top_logprobs_num - req.stream = recv_req.stream req.logprob_start_len = recv_req.logprob_start_len if req.logprob_start_len == -1: @@ -993,10 +941,11 @@ def process_batch_result(self, batch: ScheduleBatch, result): self.process_batch_result_prefill(batch, result) elif batch.forward_mode.is_dummy_first(): batch.next_batch_sampling_info.update_regex_vocab_mask() - torch.cuda.current_stream().synchronize() + self.current_stream.synchronize() batch.next_batch_sampling_info.sampling_info_done.set() def process_batch_result_prefill(self, batch: ScheduleBatch, result): + skip_stream_req = None if self.is_generation: logits_output, next_token_ids, bid = result @@ -1033,7 +982,6 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result): continue if req.is_being_chunked <= 0: - req.completion_tokens_wo_jump_forward += 1 req.output_ids.append(next_token_id) req.check_finished() @@ -1049,13 +997,18 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result): if req.grammar is not None: 
req.grammar.accept_token(next_token_id) + req.grammar.finished = req.finished() else: # being chunked reqs' prefill is not finished req.is_being_chunked -= 1 + # There is only at most one request being currently chunked. + # Because this request does not finish prefill, + # we don't want to stream the request currently being chunked. + skip_stream_req = req if batch.next_batch_sampling_info: batch.next_batch_sampling_info.update_regex_vocab_mask() - torch.cuda.current_stream().synchronize() + self.current_stream.synchronize() batch.next_batch_sampling_info.sampling_info_done.set() else: # embedding or reward model @@ -1081,7 +1034,7 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result): # being chunked reqs' prefill is not finished req.is_being_chunked -= 1 - self.stream_output(batch.reqs) + self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req) def process_batch_result_decode(self, batch: ScheduleBatch, result): logits_output, next_token_ids, bid = result @@ -1111,7 +1064,6 @@ def process_batch_result_decode(self, batch: ScheduleBatch, result): self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1]) continue - req.completion_tokens_wo_jump_forward += 1 req.output_ids.append(next_token_id) req.check_finished() @@ -1119,21 +1071,26 @@ def process_batch_result_decode(self, batch: ScheduleBatch, result): self.tree_cache.cache_finished_req(req) if req.return_logprob: - req.output_token_logprobs.append( - (next_token_logprobs[i], next_token_id) - ) + req.output_token_logprobs_val.append(next_token_logprobs[i]) + req.output_token_logprobs_idx.append(next_token_id) if req.top_logprobs_num > 0: - req.output_top_logprobs.append(logits_output.output_top_logprobs[i]) + req.output_top_logprobs_val.append( + logits_output.output_top_logprobs_val[i] + ) + req.output_top_logprobs_idx.append( + logits_output.output_top_logprobs_idx[i] + ) if req.grammar is not None: req.grammar.accept_token(next_token_id) + req.grammar.finished = req.finished() if batch.next_batch_sampling_info: batch.next_batch_sampling_info.update_regex_vocab_mask() - torch.cuda.current_stream().synchronize() + self.current_stream.synchronize() batch.next_batch_sampling_info.sampling_info_done.set() - self.stream_output(batch.reqs) + self.stream_output(batch.reqs, batch.return_logprob) self.token_to_kv_pool.free_group_end() @@ -1153,9 +1110,8 @@ def add_logprob_return_values( output: LogitsProcessorOutput, ): """Attach logprobs to the return values.""" - req.output_token_logprobs.append( - (output.next_token_logprobs[i], next_token_ids[i]) - ) + req.output_token_logprobs_val.append(output.next_token_logprobs[i]) + req.output_token_logprobs_idx.append(next_token_ids[i]) # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored. num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len @@ -1163,170 +1119,250 @@ def add_logprob_return_values( if req.normalized_prompt_logprob is None: req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i] - if req.input_token_logprobs is None: - input_token_logprobs = output.input_token_logprobs[ + if req.input_token_logprobs_val is None: + input_token_logprobs_val = output.input_token_logprobs[ pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens ] - input_token_ids = req.fill_ids[ + + input_token_logprobs_idx = req.fill_ids[ len(req.fill_ids) - num_input_logprobs + 1 : len(req.fill_ids) - req.last_update_decode_tokens ] - # Clip the padded hash values from image tokens. 
# Otherwise, it will lead to detokenization errors. - input_token_ids = [ + input_token_logprobs_idx = [ x if x < self.model_config.vocab_size - 1 else 0 - for x in input_token_ids + for x in input_token_logprobs_idx ] - req.input_token_logprobs = list(zip(input_token_logprobs, input_token_ids)) - if ( req.logprob_start_len == 0 ): # The first token does not have logprob, pad it. - req.input_token_logprobs = [ - (None, req.fill_ids[0]) - ] + req.input_token_logprobs + input_token_logprobs_val = [None] + input_token_logprobs_val + input_token_logprobs_idx = [req.fill_ids[0]] + input_token_logprobs_idx + + req.input_token_logprobs_val = input_token_logprobs_val + req.input_token_logprobs_idx = input_token_logprobs_idx if req.last_update_decode_tokens != 0: # Some decode tokens are re-computed in an extend batch - req.output_token_logprobs.extend( - list( - zip( - output.input_token_logprobs[ - pt - + num_input_logprobs - - 1 - - req.last_update_decode_tokens : pt - + num_input_logprobs - - 1 - ], - req.fill_ids[ - len(req.fill_ids) - - req.last_update_decode_tokens : len(req.fill_ids) - ], - ) - ) + req.output_token_logprobs_val.extend( + output.input_token_logprobs[ + pt + + num_input_logprobs + - 1 + - req.last_update_decode_tokens : pt + + num_input_logprobs + - 1 + ], + ) + req.output_token_logprobs_idx.extend( + req.fill_ids[ + len(req.fill_ids) + - req.last_update_decode_tokens : len(req.fill_ids) + ] ) if req.top_logprobs_num > 0: - if req.input_top_logprobs is None: - req.input_top_logprobs = output.input_top_logprobs[i] + if req.input_top_logprobs_val is None: + req.input_top_logprobs_val = output.input_top_logprobs_val[i] + req.input_top_logprobs_idx = output.input_top_logprobs_idx[i] if req.logprob_start_len == 0: - req.input_top_logprobs = [None] + req.input_top_logprobs + req.input_top_logprobs_val = [None] + req.input_top_logprobs_val + req.input_top_logprobs_idx = [None] + req.input_top_logprobs_idx if req.last_update_decode_tokens != 0: - req.output_top_logprobs.extend( - output.input_top_logprobs[i][-req.last_update_decode_tokens :] + req.output_top_logprobs_val.extend( + output.input_top_logprobs_val[i][-req.last_update_decode_tokens :] ) - req.output_top_logprobs.append(output.output_top_logprobs[i]) + req.output_top_logprobs_idx.extend( + output.input_top_logprobs_idx[i][-req.last_update_decode_tokens :] + ) + req.output_top_logprobs_val.append(output.output_top_logprobs_val[i]) + req.output_top_logprobs_idx.append(output.output_top_logprobs_idx[i]) return num_input_logprobs - def stream_output(self, reqs: List[Req]): + def stream_output( + self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None + ): """Stream the output to detokenizer.""" - output_rids = [] - output_meta_info: List[dict] = [] - output_finished_reason: List[BaseFinishReason] = [] + rids = [] + finished_reasons: List[BaseFinishReason] = [] + if self.is_generation: - output_vids = [] + vids = [] decoded_texts = [] - output_read_ids = [] - output_read_offsets = [] + decode_ids_list = [] + read_offsets = [] output_ids = [] - output_skip_special_tokens = [] - output_spaces_between_special_tokens = [] - output_no_stop_trim = [] - else: # embedding or reward model - output_embeddings = [] - - is_stream_iter = self.forward_ct_decode % self.stream_interval == 0 + skip_special_tokens = [] + spaces_between_special_tokens = [] + no_stop_trim = [] + prompt_tokens = [] + completion_tokens = [] + cached_tokens = [] + + if return_logprob: + input_token_logprobs_val = [] + input_token_logprobs_idx = [] 
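Throughout this refactor, logprobs travel as parallel `*_val` / `*_idx` lists rather than lists of `(logprob, token_id)` tuples. A minimal sketch of how such lists come out of a top-k over the log-softmax, with illustrative shapes and values rather than the actual LogitsProcessor code:

```python
import torch

logits = torch.randn(2, 8)                    # [batch, vocab]
logprobs = torch.log_softmax(logits, dim=-1)

top_vals, top_idx = logprobs.topk(k=3, dim=-1)

# Values and token ids are shipped as separate, same-shaped lists, matching
# the *_val / *_idx fields used by the new output dataclasses.
output_top_logprobs_val = top_vals.tolist()
output_top_logprobs_idx = top_idx.tolist()
```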
+ output_token_logprobs_val = [] + output_token_logprobs_idx = [] + input_top_logprobs_val = [] + input_top_logprobs_idx = [] + output_top_logprobs_val = [] + output_top_logprobs_idx = [] + normalized_prompt_logprob = [] + else: + input_token_logprobs_val = input_token_logprobs_idx = ( + output_token_logprobs_val + ) = output_token_logprobs_idx = input_top_logprobs_val = ( + input_top_logprobs_idx + ) = output_top_logprobs_val = output_top_logprobs_idx = ( + normalized_prompt_logprob + ) = None + + for req in reqs: + if req is skip_req: + continue - for req in reqs: - # TODO(lianmin): revisit this for overlap + retract + stream - if req.finished() or ( - req.stream and (is_stream_iter or len(req.output_ids) == 1) - ): - output_rids.append(req.rid) - output_finished_reason.append(req.finished_reason) - if self.is_generation: - output_vids.append(req.vid) + # TODO(lianmin): revisit this for overlap + retract + stream + if ( + req.finished() + # If stream, follow the given stream_interval + or (req.stream and len(req.output_ids) % self.stream_interval == 0) + # If not stream, we still want to output some tokens to get the benefit of incremental decoding. + or (not req.stream and len(req.output_ids) % 50 == 0) + ): + rids.append(req.rid) + finished_reasons.append( + req.finished_reason.to_json() if req.finished_reason else None + ) + vids.append(req.vid) decoded_texts.append(req.decoded_text) - read_ids, read_offset = req.init_incremental_detokenize() - output_read_ids.append(read_ids) - output_read_offsets.append(read_offset) + decode_ids, read_offset = req.init_incremental_detokenize() + decode_ids_list.append(decode_ids) + read_offsets.append(read_offset) if self.skip_tokenizer_init: output_ids.append(req.output_ids) - output_skip_special_tokens.append( - req.sampling_params.skip_special_tokens - ) - output_spaces_between_special_tokens.append( + skip_special_tokens.append(req.sampling_params.skip_special_tokens) + spaces_between_special_tokens.append( req.sampling_params.spaces_between_special_tokens ) - output_no_stop_trim.append(req.sampling_params.no_stop_trim) - - meta_info = { - "prompt_tokens": len(req.origin_input_ids), - "completion_tokens": len(req.output_ids), - "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward, - "cached_tokens": req.cached_tokens, - "finish_reason": ( - req.finished_reason.to_json() - if req.finished_reason is not None - else None - ), - } - if req.return_logprob: - ( - meta_info["input_token_logprobs"], - meta_info["output_token_logprobs"], - meta_info["input_top_logprobs"], - meta_info["output_top_logprobs"], - meta_info["normalized_prompt_logprob"], - ) = ( - req.input_token_logprobs, - req.output_token_logprobs, - req.input_top_logprobs, - req.output_top_logprobs, - req.normalized_prompt_logprob, - ) - output_meta_info.append(meta_info) - else: # embedding or reward model - output_embeddings.append(req.embedding) - meta_info = { - "prompt_tokens": len(req.origin_input_ids), - } - output_meta_info.append(meta_info) - - # Send to detokenizer - if output_rids: - if self.is_generation: + no_stop_trim.append(req.sampling_params.no_stop_trim) + + prompt_tokens.append(len(req.origin_input_ids)) + completion_tokens.append(len(req.output_ids)) + cached_tokens.append(req.cached_tokens) + + if return_logprob: + input_token_logprobs_val.append(req.input_token_logprobs_val) + input_token_logprobs_idx.append(req.input_token_logprobs_idx) + output_token_logprobs_val.append(req.output_token_logprobs_val) + 
output_token_logprobs_idx.append(req.output_token_logprobs_idx) + input_top_logprobs_val.append(req.input_top_logprobs_val) + input_top_logprobs_idx.append(req.input_top_logprobs_idx) + output_top_logprobs_val.append(req.output_top_logprobs_val) + output_top_logprobs_idx.append(req.output_top_logprobs_idx) + normalized_prompt_logprob.append(req.normalized_prompt_logprob) + + # Send to detokenizer + if rids: self.send_to_detokenizer.send_pyobj( BatchTokenIDOut( - output_rids, - output_vids, + rids, + finished_reasons, + vids, decoded_texts, - output_read_ids, - output_read_offsets, + decode_ids_list, + read_offsets, output_ids, - output_skip_special_tokens, - output_spaces_between_special_tokens, - output_meta_info, - output_finished_reason, - output_no_stop_trim, + skip_special_tokens, + spaces_between_special_tokens, + no_stop_trim, + prompt_tokens, + completion_tokens, + cached_tokens, + input_token_logprobs_val, + input_token_logprobs_idx, + output_token_logprobs_val, + output_token_logprobs_idx, + input_top_logprobs_val, + input_top_logprobs_idx, + output_top_logprobs_val, + output_top_logprobs_idx, + normalized_prompt_logprob, ) ) - else: # embedding or reward model - self.send_to_detokenizer.send_pyobj( - BatchEmbeddingOut( - output_rids, - output_embeddings, - output_meta_info, - output_finished_reason, - ) + else: # embedding or reward model + embeddings = [] + prompt_tokens = [] + for req in reqs: + assert req.finished() + rids.append(req.rid) + finished_reasons.append(req.finished_reason.to_json()) + embeddings.append(req.embedding) + prompt_tokens.append(len(req.origin_input_ids)) + self.send_to_detokenizer.send_pyobj( + BatchEmbeddingOut(rids, finished_reasons, embeddings, prompt_tokens) + ) + + def prepare_dp_attn_batch(self, local_batch: ScheduleBatch): + # Check if other DP workers have running batches + if local_batch is None: + num_tokens = 0 + elif local_batch.forward_mode.is_decode(): + num_tokens = local_batch.batch_size() + else: + num_tokens = local_batch.extend_num_tokens + + local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64) + global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64) + torch.distributed.all_gather_into_tensor( + global_num_tokens, + local_num_tokens, + group=self.tp_cpu_group, + ) + + if local_batch is None and global_num_tokens.max().item() > 0: + local_batch = self.get_idle_batch() + + if local_batch is not None: + local_batch.global_num_tokens = global_num_tokens.tolist() + + # Check forward mode for cuda graph + if not self.server_args.disable_cuda_graph: + forward_mode_state = torch.tensor( + ( + 1 + if local_batch.forward_mode.is_decode() + or local_batch.forward_mode.is_idle() + else 0 + ), + dtype=torch.int32, ) + torch.distributed.all_reduce( + forward_mode_state, + op=torch.distributed.ReduceOp.MIN, + group=self.tp_cpu_group, + ) + local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1 + + return local_batch + + def get_idle_batch(self): + idle_batch = ScheduleBatch.init_new( + [], + self.req_to_token_pool, + self.token_to_kv_pool, + self.tree_cache, + self.model_config, + self.enable_overlap, + ) + idle_batch.prepare_for_idle() + return idle_batch def move_ready_grammar_requests(self): """Move requests whose grammar objects are ready from grammar_queue to waiting_queue.""" @@ -1469,9 +1505,7 @@ def run_scheduler_process( dp_rank: Optional[int], pipe_writer, ): - # set cpu affinity to this gpu process - if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"): - set_gpu_proc_affinity(server_args.tp_size, 
server_args.nnodes, gpu_id) + setproctitle.setproctitle("sglang::scheduler") # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var if dp_rank is None and "SGLANG_DP_RANK" in os.environ: @@ -1482,6 +1516,10 @@ def run_scheduler_process( else: configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}") + # set cpu affinity to this gpu process + if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"): + set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id) + suppress_other_loggers() parent_process = psutil.Process().parent() diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 56e01528add..29b98df2efa 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -22,7 +22,7 @@ import sys import time import uuid -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import fastapi import uvloop @@ -76,6 +76,7 @@ class ReqState: out_list: List finished: bool event: asyncio.Event + obj: Any # For metrics created_time: float @@ -283,7 +284,7 @@ async def _wait_one_response( ): """Wait for the response of one request.""" event = asyncio.Event() - state = ReqState([], False, event, created_time=created_time) + state = ReqState([], False, event, obj, created_time=created_time) self.rid_to_state[obj.rid] = state while True: @@ -295,15 +296,7 @@ async def _wait_one_response( raise ValueError(f"Abort request {obj.rid}") continue - if isinstance(obj, GenerateReqInput): - out = self.convert_logprob_style( - state.out_list[-1], - obj.return_logprob, - obj.top_logprobs_num, - obj.return_text_in_logprobs, - ) - else: # isinstance(obj, (EmbeddingReqInput,)) - out = state.out_list[-1] + out = state.out_list[-1] state.out_list = [] if state.finished: @@ -315,7 +308,13 @@ async def _wait_one_response( break state.event.clear() - yield out + + if obj.stream: + yield out + else: + if request is not None and await request.is_disconnected(): + self.abort_request(obj.rid) + raise ValueError(f"Abort request {obj.rid}") async def _handle_batch_request( self, @@ -573,7 +572,7 @@ def create_handle_loop(self): async def sigterm_watchdog(self): while not self.gracefully_exit: - await asyncio.sleep(60) + await asyncio.sleep(5) # drain requests while True: @@ -609,29 +608,55 @@ async def handle_loop(self): if state is None: continue - recv_obj.meta_info[i]["id"] = rid + meta_info = { + "id": rid, + "finish_reason": recv_obj.finished_reasons[i], + "prompt_tokens": recv_obj.prompt_tokens[i], + } + + if getattr(state.obj, "return_logprob", False): + self.convert_logprob_style( + meta_info, + state.obj.top_logprobs_num, + state.obj.return_text_in_logprobs, + recv_obj, + i, + ) + if isinstance(recv_obj, BatchStrOut): out_dict = { "text": recv_obj.output_strs[i], - "meta_info": recv_obj.meta_info[i], + "meta_info": { + **meta_info, + "completion_tokens": recv_obj.completion_tokens[i], + "cached_tokens": recv_obj.cached_tokens[i], + }, } elif isinstance(recv_obj, BatchTokenIDOut): out_dict = { "token_ids": recv_obj.output_ids[i], - "meta_info": recv_obj.meta_info[i], + "meta_info": { + **meta_info, + "completion_tokens": recv_obj.completion_tokens[i], + "cached_tokens": recv_obj.cached_tokens[i], + }, } else: assert isinstance(recv_obj, BatchEmbeddingOut) out_dict = { "embedding": recv_obj.embeddings[i], - "meta_info": recv_obj.meta_info[i], + "meta_info": meta_info, } state.out_list.append(out_dict) - 
state.finished = recv_obj.finished_reason[i] is not None + state.finished = recv_obj.finished_reasons[i] is not None state.event.set() if self.enable_metrics: - completion_tokens = recv_obj.meta_info[i]["completion_tokens"] + completion_tokens = ( + recv_obj.completion_tokens[i] + if recv_obj.completion_tokens + else 0 + ) if state.first_token_time is None: state.first_token_time = time.time() @@ -647,7 +672,7 @@ async def handle_loop(self): if state.finished: self.metrics_collector.inc_prompt_tokens( - recv_obj.meta_info[i]["prompt_tokens"] + recv_obj.prompt_tokens[i] ) self.metrics_collector.inc_generation_tokens( completion_tokens @@ -696,57 +721,73 @@ async def handle_loop(self): def convert_logprob_style( self, - ret: dict, - return_logprob: bool, + meta_info: dict, top_logprobs_num: int, return_text_in_logprobs: bool, + recv_obj: BatchStrOut, + recv_obj_index: int, ): - if return_logprob: - ret["meta_info"]["input_token_logprobs"] = self.detokenize_logprob_tokens( - ret["meta_info"]["input_token_logprobs"], return_text_in_logprobs + meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens( + recv_obj.input_token_logprobs_val[recv_obj_index], + recv_obj.input_token_logprobs_idx[recv_obj_index], + return_text_in_logprobs, + ) + meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens( + recv_obj.output_token_logprobs_val[recv_obj_index], + recv_obj.output_token_logprobs_idx[recv_obj_index], + return_text_in_logprobs, + ) + meta_info["normalized_prompt_logprob"] = recv_obj.normalized_prompt_logprob[ + recv_obj_index + ] + + if top_logprobs_num > 0: + meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens( + recv_obj.input_top_logprobs_val[recv_obj_index], + recv_obj.input_top_logprobs_idx[recv_obj_index], + return_text_in_logprobs, ) - ret["meta_info"]["output_token_logprobs"] = self.detokenize_logprob_tokens( - ret["meta_info"]["output_token_logprobs"], return_text_in_logprobs + meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens( + recv_obj.output_top_logprobs_val[recv_obj_index], + recv_obj.output_top_logprobs_idx[recv_obj_index], + return_text_in_logprobs, ) - if top_logprobs_num > 0: - ret["meta_info"]["input_top_logprobs"] = ( - self.detokenize_top_logprobs_tokens( - ret["meta_info"]["input_top_logprobs"], - return_text_in_logprobs, - ) - ) - ret["meta_info"]["output_top_logprobs"] = ( - self.detokenize_top_logprobs_tokens( - ret["meta_info"]["output_top_logprobs"], return_text_in_logprobs - ) - ) - return ret - def detokenize_logprob_tokens( - self, token_logprobs: List[Tuple[float, int]], decode_to_text: bool + self, + token_logprobs_val: List[float], + token_logprobs_idx: List[int], + decode_to_text: bool, ): - # TODO(lianmin): This should run on DetokenizerManager if not decode_to_text: - return [(logprob, token_id, None) for logprob, token_id in token_logprobs] - - assert self.tokenizer is not None - token_ids = [tid for _, tid in token_logprobs] - token_texts = self.tokenizer.batch_decode(token_ids) - return [ - (logprob, token_id, token_text) - for (logprob, token_id), token_text in zip(token_logprobs, token_texts) - ] + return [ + (logprob, token_id, None) + for logprob, token_id in zip(token_logprobs_val, token_logprobs_idx) + ] + else: + assert self.tokenizer is not None + token_texts = self.tokenizer.batch_decode(token_logprobs_idx) + return list(zip(token_logprobs_val, token_logprobs_idx, token_texts)) - def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool): + def 
detokenize_top_logprobs_tokens( + self, + token_logprobs_val: List[float], + token_logprobs_idx: List[int], + decode_to_text: bool, + ): # TODO: The current implementation only batches the detokenization for top-k tokens per single position. # We should batch all top-k tokens in all positions. - for i, token_top_logprobs in enumerate(top_logprobs): - if token_top_logprobs: - top_logprobs[i] = self.detokenize_logprob_tokens( - token_top_logprobs, decode_to_text + ret = [] + for i in range(len(token_logprobs_val)): + if token_logprobs_val[i]: + ret.append( + self.detokenize_logprob_tokens( + token_logprobs_val[i], token_logprobs_idx[i], decode_to_text + ) ) - return top_logprobs + else: + ret.append(None) + return ret class SignalHandler: diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py index e4e20ad8f64..a9db1878391 100644 --- a/python/sglang/srt/managers/tp_worker_overlap_thread.py +++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py @@ -32,12 +32,13 @@ from sglang.srt.managers.schedule_batch import ModelWorkerBatch from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.server_args import ServerArgs +from sglang.srt.utils import get_compiler_backend from sglang.utils import get_exception_traceback logger = logging.getLogger(__name__) -@torch.compile(dynamic=True) +@torch.compile(dynamic=True, backend=get_compiler_backend()) def resolve_future_token_ids(input_ids, future_token_ids_map): input_ids[:] = torch.where( input_ids < 0, @@ -73,12 +74,13 @@ def __init__( # Launch threads self.input_queue = Queue() self.output_queue = Queue() - self.forward_stream = torch.cuda.Stream() + self.forward_stream = torch.get_device_module(self.device).Stream() self.forward_thread = threading.Thread( target=self.forward_thread_func, ) self.forward_thread.start() self.parent_process = psutil.Process().parent() + self.scheduler_stream = torch.get_device_module(self.device).current_stream() def get_worker_info(self): return self.worker.get_worker_info() @@ -97,7 +99,7 @@ def get_memory_pool(self): def forward_thread_func(self): try: - with torch.cuda.stream(self.forward_stream): + with torch.get_device_module(self.device).stream(self.forward_stream): self.forward_thread_func_() except Exception: traceback = get_exception_traceback() @@ -122,7 +124,7 @@ def forward_thread_func_(self): # Create event self.launch_done = threading.Event() - copy_done = torch.cuda.Event() + copy_done = torch.get_device_module(self.device).Event() # Resolve future tokens in the input input_ids = model_worker_batch.input_ids @@ -190,7 +192,7 @@ def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch): ) # A cuda stream sync here to avoid the cuda illegal memory access error. 
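# (Illustrative sketch, not part of the patch.) The overlap worker above now
# resolves its stream/event API through torch.get_device_module(self.device)
# instead of hard-coding torch.cuda, and the change just below syncs the cached
# scheduler_stream rather than torch.cuda.current_stream(). A minimal version
# of that pattern, assuming a PyTorch build where the chosen backend exists:
import torch

def make_streams(device: str = "cuda"):
    dev = torch.get_device_module(device)    # e.g. torch.cuda or torch.xpu
    forward_stream = dev.Stream()            # side stream for the forward thread
    scheduler_stream = dev.current_stream()  # stream the scheduler launches on
    return forward_stream, scheduler_stream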
- torch.cuda.current_stream().synchronize() + self.scheduler_stream.synchronize() # Push a new batch to the queue self.input_queue.put((model_worker_batch, self.future_token_ids_ct)) diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index b028309c745..646e71749d8 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -27,6 +27,7 @@ import torch from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.utils import get_compiler_backend logger = logging.getLogger(__name__) @@ -129,6 +130,9 @@ def alloc(self, need_size: int): return select_index.to(self.device, non_blocking=True) def free(self, free_index: torch.Tensor): + if free_index.numel() == 0: + return + if self.is_not_in_free_group: self.free_slots = torch.concat((self.free_slots, free_index.cpu())) else: @@ -234,7 +238,7 @@ def set_kv_buffer( # This compiled version is slower in the unit test # python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size -@torch.compile(dynamic=True) +@torch.compile(dynamic=True, backend=get_compiler_backend()) def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype): dst_1[loc] = src_1.to(dtype).view(store_dtype) dst_2[loc] = src_2.to(dtype).view(store_dtype) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 84f6825c3fd..77efba89212 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -36,7 +36,7 @@ from sglang.srt.model_executor.model_runner import ModelRunner -def _to_torch(model: torch.nn.Module, reverse: bool = False): +def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int): for sub in model._modules.values(): if isinstance(sub, CustomOp): if reverse: @@ -45,24 +45,30 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False): else: # NOTE: Temporarily workaround MoE if "FusedMoE" in sub.__class__.__name__: - sub._forward_method = fused_moe_forward_native + if batch_size == 1: + # The performance of torch.compile on this layer is not always good when bs > 1, + # so we decide to only use torch.compile when bs =1 + sub._forward_method = fused_moe_forward_native else: sub._forward_method = sub.forward_native setattr(sub, "is_torch_compile", True) if isinstance(sub, torch.nn.Module): - _to_torch(sub, reverse) + _to_torch(sub, reverse, batch_size) @contextmanager def patch_model( - model: torch.nn.Module, enable_compile: bool, tp_group: "GroupCoordinator" + model: torch.nn.Module, + enable_compile: bool, + batch_size: int, + tp_group: "GroupCoordinator", ): """Patch the model to make it compatible with with torch.compile""" backup_ca_comm = None try: if enable_compile: - _to_torch(model) + _to_torch(model, reverse=False, batch_size=batch_size) monkey_patch_vllm_all_gather() backup_ca_comm = tp_group.ca_comm # Use custom-allreduce here. @@ -70,13 +76,15 @@ def patch_model( # even with ENABLE_INTRA_NODE_COMM=1. 
# tp_group.ca_comm = None yield torch.compile( - torch.no_grad()(model.forward), mode="max-autotune-no-cudagraphs" + torch.no_grad()(model.forward), + mode="max-autotune-no-cudagraphs", + dynamic=False, ) else: yield model.forward finally: if enable_compile: - _to_torch(model, reverse=True) + _to_torch(model, reverse=True, batch_size=batch_size) monkey_patch_vllm_all_gather(reverse=True) tp_group.ca_comm = backup_ca_comm @@ -122,6 +130,20 @@ def __init__(self, model_runner: "ModelRunner"): self.capture_bs = list(range(1, 32)) + [64, 128] else: self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)] + + if max(self.capture_bs) > model_runner.req_to_token_pool.size: + # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests + # is very samll. We add more values here to make sure we capture the maximum bs. + self.capture_bs = list( + sorted( + set( + self.capture_bs + + [model_runner.req_to_token_pool.size - 1] + + [model_runner.req_to_token_pool.size] + ) + ) + ) + self.capture_bs = [ bs for bs in self.capture_bs @@ -237,6 +259,7 @@ def capture(self): with patch_model( self.model_runner.model, bs in self.compile_bs, + bs, self.model_runner.tp_group, ) as forward: ( @@ -377,9 +400,14 @@ def replay(self, forward_batch: ForwardBatch): forward_mode=ForwardMode.DECODE, top_logprobs_nums=forward_batch.top_logprobs_nums, ) - logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs( + ( + logits_output.output_top_logprobs_val, + logits_output.output_top_logprobs_idx, + ) = LogitsProcessor.get_top_logprobs( next_token_logprobs, logits_metadata - )[1] + )[ + 2:4 + ] else: logits_output = LogitsProcessorOutput( next_token_logits=next_token_logits, diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 657e0c2ca5d..688cdd9ec4b 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -16,6 +16,7 @@ import gc import json import logging +import time from typing import Optional import torch @@ -26,7 +27,6 @@ initialize_model_parallel, set_custom_all_reduce, ) -from vllm.distributed.parallel_state import in_the_same_node_as from sglang.srt.configs.device_config import DeviceConfig from sglang.srt.configs.load_config import LoadConfig @@ -37,6 +37,7 @@ from sglang.srt.layers.attention.triton_backend import TritonAttnBackend from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.sampler import Sampler +from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model from sglang.srt.lora.lora_manager import LoRAManager from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.mem_cache.memory_pool import ( @@ -111,11 +112,17 @@ def __init__( if self.is_multimodal: self.mem_fraction_static *= 0.95 + logger.info( + f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static} " + f"because this is a multimodal model." 
+ ) + if self.model_config.hf_config.architectures == [ "MllamaForConditionalGeneration" ]: logger.info("Automatically turn off --chunked-prefill-size for mllama.") server_args.chunked_prefill_size = -1 + # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically if self.model_config.hf_config.architectures == [ "Qwen2VLForConditionalGeneration" @@ -129,7 +136,7 @@ def __init__( # Global vars if server_args.show_time_cost: enable_show_time_cost() - if server_args.disable_disk_cache: + if server_args.disable_outlines_disk_cache: from outlines.caching import disable_cache disable_cache() @@ -143,6 +150,7 @@ def __init__( "torchao_config": server_args.torchao_config, "enable_nan_detection": server_args.enable_nan_detection, "enable_dp_attention": server_args.enable_dp_attention, + "enable_ep_moe": server_args.enable_ep_moe, } ) @@ -163,6 +171,10 @@ def __init__( else: self.torch_tp_applied = False + apply_torchao_config_to_model( + self.model, global_server_args_dict["torchao_config"] + ) + # Init memory pool and attention backends if server_args.lora_paths is not None: self.init_lora_manager() @@ -623,8 +635,10 @@ def init_cuda_graphs(self): if self.server_args.disable_cuda_graph: return + tic = time.time() logger.info("Capture cuda graph begin. This can take up to several minutes.") self.cuda_graph_runner = CudaGraphRunner(self) + logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s") def apply_torch_tp(self): logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.") diff --git a/python/sglang/srt/model_parallel.py b/python/sglang/srt/model_parallel.py index 6817bce0235..778347b8ef3 100644 --- a/python/sglang/srt/model_parallel.py +++ b/python/sglang/srt/model_parallel.py @@ -54,11 +54,7 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me )._prepare_output_fn( output_layouts, use_local_output, mod, outputs, device_mesh ) - # wait for the output to be ready - if isinstance(outputs, AsyncCollectiveTensor): - return outputs.wait() - else: - return outputs + return torch.distributed._functional_collectives.wait_tensor(outputs) def tensor_parallel( diff --git a/python/sglang/srt/models/commandr.py b/python/sglang/srt/models/commandr.py index a758e4f56b1..83ac3d8671b 100644 --- a/python/sglang/srt/models/commandr.py +++ b/python/sglang/srt/models/commandr.py @@ -62,10 +62,10 @@ from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader -from sglang.srt.utils import set_weight_attrs +from sglang.srt.utils import get_compiler_backend, set_weight_attrs -@torch.compile +@torch.compile(backend=get_compiler_backend()) def layer_norm_func(hidden_states, weight, variance_epsilon): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 424f86aec28..63cea92c289 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -21,6 +21,7 @@ import torch from torch import nn from transformers import PretrainedConfig +from vllm import _custom_ops as ops from vllm.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -30,6 +31,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from sglang.srt.layers.activation import 
SiluAndMul +from sglang.srt.layers.ep_moe.layer import EPMoE from sglang.srt.layers.fused_moe_triton import FusedMoE from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( @@ -112,12 +114,12 @@ def __init__( "Only silu is supported for now." ) - self.experts = FusedMoE( + MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE + self.experts = MoEImpl( num_experts=config.n_routed_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, - reduce_results=False, renormalize=config.norm_topk_prob, quant_config=quant_config, use_grouped_topk=True, @@ -453,7 +455,7 @@ def __init__( mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) self.scaling = self.scaling * mscale * mscale - self.attn = RadixAttention( + self.attn_mqa = RadixAttention( self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim, self.scaling, @@ -462,6 +464,15 @@ def __init__( v_head_dim=self.kv_lora_rank, ) + self.attn_mha = RadixAttention( + self.num_local_heads, + self.qk_nope_head_dim + self.qk_rope_head_dim, + self.scaling, + num_kv_heads=self.num_local_heads, + layer_id=layer_id, + v_head_dim=self.v_head_dim, + ) + self.w_kc = None self.w_vc = None self.w_scale = None @@ -471,6 +482,63 @@ def forward( positions: torch.Tensor, hidden_states: torch.Tensor, forward_batch: ForwardBatch, + ) -> torch.Tensor: + # Use normal computation for prefill and use weight absorption for extend/decode + if ( + forward_batch.forward_mode.is_extend() + and forward_batch.extend_prefix_lens.sum() == 0 + ): + return self.forward_normal(positions, hidden_states, forward_batch) + else: + return self.forward_absorb(positions, hidden_states, forward_batch) + + def forward_normal( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + if self.q_lora_rank is not None: + q = self.q_a_proj(hidden_states)[0] + q = self.q_a_layernorm(q) + q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim) + else: + q = self.q_proj(hidden_states)[0].view( + -1, self.num_local_heads, self.qk_head_dim + ) + _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0] + kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + latent_cache = latent_cache.unsqueeze(1) + kv_a = self.kv_a_layernorm(kv_a.contiguous()) + kv = self.kv_b_proj(kv_a)[0] + kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope = kv[..., : self.qk_nope_head_dim] + v = kv[..., self.qk_nope_head_dim :] + k_pe = latent_cache[:, :, self.kv_lora_rank :] + q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) + q[..., self.qk_nope_head_dim :] = q_pe + k = torch.empty_like(q) + k[..., : self.qk_nope_head_dim] = k_nope + k[..., self.qk_nope_head_dim :] = k_pe + + latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1) + latent_cache[:, :, self.kv_lora_rank :] = k_pe + + # Save latent cache + forward_batch.token_to_kv_pool.set_kv_buffer( + self.attn_mha, forward_batch.out_cache_loc, latent_cache, None + ) + attn_output = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False) + attn_output = attn_output.reshape(-1, self.num_local_heads * self.v_head_dim) + output, _ = self.o_proj(attn_output) + return output + + def forward_absorb( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, ) -> torch.Tensor: q_len = 
hidden_states.shape[0] q_input = hidden_states.new_empty( @@ -508,7 +576,7 @@ def forward( q_input[..., self.kv_lora_rank :] = q_pe k_input[..., self.kv_lora_rank :] = k_pe - attn_output = self.attn(q_input, k_input, v_input, forward_batch) + attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch) attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank) if self.w_vc.dtype == torch.float8_e4m3fn: @@ -767,7 +835,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = FusedMoE.make_expert_params_mapping( + MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE + expert_params_mapping = MoEImpl.make_expert_params_mapping( ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", @@ -828,14 +897,25 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if not global_server_args_dict["disable_mla"]: for layer_id in range(self.config.num_hidden_layers): self_attn = self.model.layers[layer_id].self_attn - w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten( + if hasattr(self_attn.kv_b_proj, "qweight"): + # AWQ compatible + w = ops.awq_dequantize( + self_attn.kv_b_proj.qweight, + self_attn.kv_b_proj.scales, + self_attn.kv_b_proj.qzeros, + 0, + 0, + 0, + ).T + else: + w = self_attn.kv_b_proj.weight + w_kc, w_vc = w.unflatten( 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim) ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1) self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2) self_attn.w_vc = w_vc.contiguous().transpose(1, 2) if hasattr(self_attn.kv_b_proj, "weight_scale"): self_attn.w_scale = self_attn.kv_b_proj.weight_scale - del self_attn.kv_b_proj EntryClass = DeepseekV2ForCausalLM diff --git a/python/sglang/srt/models/gemma2.py b/python/sglang/srt/models/gemma2.py index dbca7268803..0c0e6155d35 100644 --- a/python/sglang/srt/models/gemma2.py +++ b/python/sglang/srt/models/gemma2.py @@ -355,6 +355,40 @@ def forward( input_ids, hidden_states, self.model.embed_tokens, forward_batch ) + def get_hidden_dim(self, module_name): + # return input_dim, output_dim + if module_name in ["q_proj", "qkv_proj"]: + return ( + self.config.hidden_size, + self.config.head_dim * self.config.num_attention_heads, + ) + elif module_name in ["o_proj"]: + return ( + self.config.head_dim * self.config.num_attention_heads, + self.config.hidden_size, + ) + elif module_name in ["kv_proj"]: + return ( + self.config.hidden_size, + self.config.head_dim * self.config.num_key_value_heads, + ) + elif module_name == "gate_up_proj": + return self.config.hidden_size, self.config.intermediate_size + elif module_name == "down_proj": + return self.config.intermediate_size, self.config.hidden_size + else: + raise NotImplementedError() + + def get_module_name(self, name): + params_mapping = { + "q_proj": "qkv_proj", + "k_proj": "qkv_proj", + "v_proj": "qkv_proj", + "gate_proj": "gate_up_proj", + "up_proj": "gate_up_proj", + } + return params_mapping.get(name, name) + def get_attention_sliding_window_size(self): return get_attention_sliding_window_size(self.config) diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py index 956f73b1482..2b52e2b1bcc 100644 --- a/python/sglang/srt/models/grok.py +++ b/python/sglang/srt/models/grok.py @@ -35,12 +35,10 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from 
sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.torchao_utils import apply_torchao_config_ from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.loader import DefaultModelLoader from sglang.srt.model_loader.weight_utils import default_weight_loader @@ -290,7 +288,6 @@ def __init__( super().__init__() self.config = config self.quant_config = quant_config - self.torchao_config = global_server_args_dict["torchao_config"] self.model = Grok1Model(config, quant_config=quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) @@ -374,8 +371,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ) weight_loader(param, loaded_weight) - apply_torchao_config_(self, params_dict, set(["proj.weight"])) - class Grok1ModelForCausalLM(Grok1ForCausalLM): """An alias for backward-compatbility.""" diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index 61409a9eaeb..e3e44ea6ffc 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -36,12 +36,10 @@ from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.torchao_utils import apply_torchao_config_ from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import make_layers @@ -304,7 +302,6 @@ def __init__( super().__init__() self.config = config self.quant_config = quant_config - self.torchao_config = global_server_args_dict["torchao_config"] self.model = LlamaModel(config, quant_config=quant_config) # Llama 3.2 1B Insturct set tie_word_embeddings to True # Llama 3.1 8B Insturct set tie_word_embeddings to False @@ -424,8 +421,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - apply_torchao_config_(self, params_dict, set(["proj.weight"])) - def get_weights_by_name( self, name: str, truncate_size: int = 100, tp_size: int = 1 ) -> Optional[torch.Tensor]: diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py index b222387a776..f3fad226091 100644 --- a/python/sglang/srt/models/mixtral.py +++ b/python/sglang/srt/models/mixtral.py @@ -21,9 +21,13 @@ import torch from torch import nn from transformers import MixtralConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import ( + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) from vllm.model_executor.layers.rotary_embedding import get_rope +from sglang.srt.layers.ep_moe.layer import EPMoE from sglang.srt.layers.fused_moe_triton import FusedMoE from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( @@ -34,7 +38,6 @@ from sglang.srt.layers.logits_processor 
import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.torchao_utils import apply_torchao_config_ from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, @@ -65,6 +68,7 @@ def __init__( prefix: str = "", ): super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() self.hidden_size = hidden_size # Gate always runs at half / full precision for now. @@ -76,14 +80,13 @@ def __init__( quant_config=None, prefix=f"{prefix}.gate", ) - - self.experts = FusedMoE( + MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE + self.experts = MoEImpl( num_experts=num_experts, top_k=top_k, hidden_size=hidden_size, intermediate_size=intermediate_size, params_dtype=params_dtype, - reduce_results=True, renormalize=True, quant_config=quant_config, tp_size=tp_size, @@ -97,6 +100,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) final_hidden_states = self.experts(hidden_states, router_logits) + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states.view(orig_shape) @@ -295,7 +300,6 @@ def __init__( super().__init__() self.config = config self.quant_config = quant_config - self.torchao_config = global_server_args_dict["torchao_config"] self.model = MixtralModel(config, quant_config=quant_config, prefix="model") self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) @@ -322,7 +326,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = FusedMoE.make_expert_params_mapping( + MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE + expert_params_mapping = MoEImpl.make_expert_params_mapping( ckpt_gate_proj_name="w1", ckpt_down_proj_name="w2", ckpt_up_proj_name="w3", @@ -339,7 +344,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue param = params_dict[name] @@ -353,6 +360,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue name = name.replace(weight_name, param_name) + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader( @@ -365,7 +376,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): break else: # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue # Skip loading kv_scale from ckpts towards new design. 
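# (Illustrative sketch, not part of the patch.) Both this file and
# deepseek_v2.py above now choose the MoE layer from the --enable-ep-moe flag:
# EPMoE shards experts across ranks (ep_size == tp_size), while FusedMoE keeps
# every expert on every rank and uses the fused Triton kernel. A condensed
# version of that selection, assuming an sglang install that provides both
# classes; the kwargs mirror the ones shown in the diff:
from sglang.srt.layers.ep_moe.layer import EPMoE
from sglang.srt.layers.fused_moe_triton import FusedMoE
from sglang.srt.managers.schedule_batch import global_server_args_dict

def build_moe(num_experts, top_k, hidden_size, intermediate_size, quant_config=None):
    MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
    return MoEImpl(
        num_experts=num_experts,
        top_k=top_k,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        renormalize=True,
        quant_config=quant_config,
    )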
if name.endswith(".kv_scale") and name not in params_dict: @@ -379,7 +392,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ) weight_loader(param, loaded_weight) - apply_torchao_config_(self, params_dict, set(["proj.weight"])) - EntryClass = MixtralForCausalLM diff --git a/python/sglang/srt/models/phi3_small.py b/python/sglang/srt/models/phi3_small.py index 6340330774d..1e70c7d7874 100644 --- a/python/sglang/srt/models/phi3_small.py +++ b/python/sglang/srt/models/phi3_small.py @@ -17,13 +17,11 @@ from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.torchao_utils import apply_torchao_config_ from sglang.srt.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import make_layers @@ -348,7 +346,6 @@ def __init__( quant_config=quant_config, prefix="model", ) - self.torchao_config = global_server_args_dict["torchao_config"] self.vocab_size = config.vocab_size self.mup_width_multiplier = config.mup_width_multiplier self.lm_head = ParallelLMHead( @@ -441,7 +438,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - apply_torchao_config_(self, params_dict, set(["proj.weight"])) - EntryClass = Phi3SmallForCausalLM diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 0094cb8c3e2..62cd3281d03 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -40,12 +40,10 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.torchao_utils import apply_torchao_config_ from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader @@ -352,7 +350,6 @@ def __init__( super().__init__() self.config = config self.quant_config = quant_config - self.torchao_config = global_server_args_dict["torchao_config"] self.model = Qwen2MoeModel(config, quant_config) self.lm_head = ParallelLMHead( config.vocab_size, config.hidden_size, quant_config=quant_config @@ -445,7 +442,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ) weight_loader(param, loaded_weight) - apply_torchao_config_(self, params_dict, set(["proj.weight"])) - EntryClass = Qwen2MoeForCausalLM diff --git a/python/sglang/srt/models/torch_native_llama.py b/python/sglang/srt/models/torch_native_llama.py index 25e555484a7..7a55d50457a 100644 --- a/python/sglang/srt/models/torch_native_llama.py +++ b/python/sglang/srt/models/torch_native_llama.py @@ -58,12 +58,10 @@ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput from sglang.srt.layers.quantization.base_config import QuantizationConfig from 
sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.torchao_utils import apply_torchao_config_ from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader @@ -392,7 +390,6 @@ def __init__( super().__init__() self.config = config self.quant_config = quant_config - self.torchao_config = global_server_args_dict["torchao_config"] self.supports_torch_tp = True self.model = LlamaModel(config, quant_config=quant_config) if self.config.tie_word_embeddings: @@ -503,8 +500,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - apply_torchao_config_(self, params_dict, set(["proj.weight"])) - class TorchNativePhi3ForCausalLM(TorchNativeLlamaForCausalLM): pass diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index 1624fd255f9..a64a84a62dc 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -158,22 +158,23 @@ def update_regex_vocab_mask(self): return # find a grammar from the list - grammar = next(grammar for grammar in self.grammars if grammar) + first_grammar = next(grammar for grammar in self.grammars if grammar) # maybe we can reuse the existing mask? - self.vocab_mask = grammar.allocate_vocab_mask( + self.vocab_mask = first_grammar.allocate_vocab_mask( vocab_size=self.vocab_size, batch_size=len(self.temperatures), device=self.device, ) - self.apply_mask = type(grammar).apply_vocab_mask # force to use static method + self.apply_mask = first_grammar.apply_vocab_mask # force to use static method + # Apply the mask for i, grammar in enumerate(self.grammars): - if grammar is not None: - try: - grammar.fill_vocab_mask(self.vocab_mask, i) - except RuntimeError: - continue + if grammar and not grammar.finished: + grammar.fill_vocab_mask(self.vocab_mask, i) + + # Move the mask to the device if needed + self.vocab_mask = first_grammar.move_vocab_mask(self.vocab_mask, self.device) def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor): self.penalizer_orchestrator.filter(unfinished_indices, new_indices) diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index fc8ac150b3f..29bc44eb524 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -329,7 +329,7 @@ async def encode_request(obj: EmbeddingReqInput, request: Request): ) -@app.api_route("/encode", methods=["POST", "PUT"]) +@app.api_route("/classify", methods=["POST", "PUT"]) @time_func_latency async def classify_request(obj: EmbeddingReqInput, request: Request): """Handle a reward model request. Now the arguments and return values are the same as embedding models.""" @@ -462,8 +462,8 @@ def launch_engine( if server_args.node_rank >= 1: # For other nodes, they do not need to run tokenizer or detokenizer, # so they can just wait here. 
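# (Descriptive note, not part of the patch.) The change just below replaces the
# busy-wait with proc.join(): `while True: pass` pins a CPU core on every
# non-zero-rank node, whereas joining the scheduler processes blocks cheaply
# and lets the launcher exit as soon as the schedulers do.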
- while True: - pass + for proc in scheduler_procs: + proc.join() else: # Launch the data parallel controller reader, writer = mp.Pipe(duplex=False) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 37ad6cfc530..fe12d961d3a 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -20,9 +20,12 @@ import tempfile from typing import List, Optional +import torch + from sglang.srt.hf_transformers_utils import check_gguf_file from sglang.srt.utils import ( get_amdgpu_memory_capacity, + get_hpu_memory_capacity, get_nvgpu_memory_capacity, is_flashinfer_available, is_hip, @@ -91,6 +94,8 @@ class ServerArgs: # Data parallelism dp_size: int = 1 load_balance_method: str = "round_robin" + # Expert parallelism + ep_size: int = 1 # Multi-node distributed serving dist_init_addr: Optional[str] = None @@ -122,12 +127,13 @@ class ServerArgs: disable_jump_forward: bool = False disable_cuda_graph: bool = False disable_cuda_graph_padding: bool = False - disable_disk_cache: bool = False + disable_outlines_disk_cache: bool = False disable_custom_all_reduce: bool = False disable_mla: bool = False disable_overlap_schedule: bool = False enable_mixed_chunk: bool = False enable_dp_attention: bool = False + enable_ep_moe: bool = False enable_torch_compile: bool = False torch_compile_max_bs: int = 32 cuda_graph_max_bs: Optional[int] = None @@ -135,6 +141,7 @@ class ServerArgs: enable_nan_detection: bool = False enable_p2p_check: bool = False triton_attention_reduce_in_fp32: bool = False + triton_attention_num_kv_splits: int = 8 num_continuous_decode_steps: int = 1 delete_ckpt_after_loading: bool = False @@ -151,15 +158,20 @@ def __post_init__(self): if is_hip(): gpu_mem = get_amdgpu_memory_capacity() - else: + elif torch.cuda.is_available(): gpu_mem = get_nvgpu_memory_capacity() + elif self.device == "hpu": + gpu_mem = get_hpu_memory_capacity() + else: + # GPU memory is not known yet or no GPU is available. + gpu_mem = None # Set mem fraction static, which depends on the tensor parallelism size if self.mem_fraction_static is None: if self.tp_size >= 16: self.mem_fraction_static = 0.79 elif self.tp_size >= 8: - self.mem_fraction_static = 0.82 + self.mem_fraction_static = 0.81 elif self.tp_size >= 4: self.mem_fraction_static = 0.85 elif self.tp_size >= 2: @@ -169,19 +181,27 @@ def __post_init__(self): # Set chunked prefill size, which depends on the gpu memory capacity if self.chunked_prefill_size is None: - if gpu_mem < 25_000: + if gpu_mem is not None and gpu_mem < 25_000: self.chunked_prefill_size = 2048 else: self.chunked_prefill_size = 8192 # Set cuda graph max batch size if self.cuda_graph_max_bs is None: - if gpu_mem < 25_000: - self.cuda_graph_max_bs = 8 + # Based on detailed statistics, when serving TP1/TP2 models on lower-end GPUs with HBM<25G, you can either disable cuda graph or set `cuda_graph_max_bs` to a very small value to reduce the memory overhead of creating cuda graphs, with almost no impact on performance. However, when serving models with TP4 or TP8, we need to enable cuda graph to maintain high performance. In this case, we can set `cuda_graph_max_bs` to 80 (half of the default value 160) to reduce the memory overhead of creating cuda graphs. Looking at the logs from TP4 serving of qwen2-72b, a value of 80 is sufficient and can reduce the memory overhead of creating cuda graphs on lower-end GPUs compared to the original 160, avoiding OOM issues. 
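# (Descriptive summary, not part of the patch.) The heuristic described above
# reduces to:
#   HBM < 25 GB and tp_size < 4   -> cuda_graph_max_bs = 8
#   HBM < 25 GB and tp_size >= 4  -> cuda_graph_max_bs = 80
#   otherwise                     -> cuda_graph_max_bs = 160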
+ if gpu_mem is not None and gpu_mem < 25_000: + if self.tp_size < 4: + self.cuda_graph_max_bs = 8 + else: + self.cuda_graph_max_bs = 80 else: self.cuda_graph_max_bs = 160 # Choose kernel backends + if self.device == "hpu": + self.attention_backend = "torch_native" + self.sampling_backend = "pytorch" + if self.attention_backend is None: self.attention_backend = ( "flashinfer" if is_flashinfer_available() else "triton" @@ -192,7 +212,7 @@ def __post_init__(self): ) if self.attention_backend == "torch_native": - logger.info( + logger.warning( "Cuda graph is disabled because of using torch native attention backend" ) self.disable_cuda_graph = True @@ -204,12 +224,18 @@ def __post_init__(self): self.cuda_graph_max_bs = min(self.cuda_graph_max_bs, 96) self.schedule_conservativeness = self.schedule_conservativeness * 0.3 self.disable_overlap_schedule = True - logger.info( + logger.warning( f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. " f"The CUDA graph max batch size is adjusted to {self.cuda_graph_max_bs}. " f"The schedule conservativeness is adjusted to {self.schedule_conservativeness}. " "Data parallel size is adjusted to be the same as tensor parallel size. " - "Overlap schedule is disabled." + "Overlap scheduler is disabled." + ) + # Expert parallelism + if self.enable_ep_moe: + self.ep_size = self.tp_size + logger.info( + f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]." ) # GGUF @@ -521,6 +547,14 @@ def add_cli_args(parser: argparse.ArgumentParser): "shortest_queue", ], ) + # Expert parallelism + parser.add_argument( + "--expert-parallel-size", + "--ep-size", + type=int, + default=ServerArgs.ep_size, + help="The expert parallelism size.", + ) # Multi-node distributed serving parser.add_argument( @@ -642,9 +676,9 @@ def add_cli_args(parser: argparse.ArgumentParser): help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.", ) parser.add_argument( - "--disable-disk-cache", + "--disable-outlines-disk-cache", action="store_true", - help="Disable disk cache to avoid possible crashes related to file system or high concurrency.", + help="Disable disk cache of outlines to avoid possible crashes related to file system or high concurrency.", ) parser.add_argument( "--disable-custom-all-reduce", @@ -676,6 +710,11 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently only DeepSeek-V2 is supported.", ) + parser.add_argument( + "--enable-ep-moe", + action="store_true", + help="Enabling expert parallelism for moe. The ep size is equal to the tp size.", + ) parser.add_argument( "--enable-torch-compile", action="store_true", @@ -715,6 +754,12 @@ def add_cli_args(parser: argparse.ArgumentParser): help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16." "This only affects Triton attention kernels.", ) + parser.add_argument( + "--triton-attention-num-kv-splits", + type=int, + default=ServerArgs.triton_attention_num_kv_splits, + help="The number of KV splits in flash decoding Triton kernel. Larger value is better in longer context scenarios. 
The default value is 8.", + ) parser.add_argument( "--num-continuous-decode-steps", type=int, @@ -745,11 +790,17 @@ def add_cli_args(parser: argparse.ArgumentParser): action=DeprecatedAction, help="'--disable-flashinfer-sampling' is deprecated. Please use '--sampling-backend pytroch' instead.", ) + parser.add_argument( + "--disable-disk-cache", + action=DeprecatedAction, + help="'--disable-disk-cache' is deprecated. Please use '--disable-outlines-disk-cache' instead.", + ) @classmethod def from_cli_args(cls, args: argparse.Namespace): args.tp_size = args.tensor_parallel_size args.dp_size = args.data_parallel_size + args.ep_size = args.expert_parallel_size attrs = [attr.name for attr in dataclasses.fields(cls)] return cls(**{attr: getattr(args, attr) for attr in attrs}) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index c19d521a066..5c310136a21 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -201,6 +201,18 @@ def get_available_gpu_memory(device, gpu_id, distributed=False): total_gpu_memory = torch.xpu.get_device_properties(gpu_id).total_memory free_gpu_memory = total_gpu_memory - used_memory + elif device == "hpu": + num_gpus = torch.hpu.device_count() + assert gpu_id < num_gpus + + if torch.hpu.current_device() != gpu_id: + print( + f"WARNING: current device is not {gpu_id}, but {torch.hpu.current_device()}, ", + "which may cause useless memory allocation for torch HPU context.", + ) + + free_gpu_memory, total_gpu_memory = torch.hpu.mem_get_info() + if distributed: tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to( torch.device(device, gpu_id) @@ -879,7 +891,7 @@ def get_amdgpu_memory_capacity(): # Run rocm-smi and capture the output result = subprocess.run( [ - "rocminfo | grep 'gfx94' -A 100 | grep 'Pool 1' -A 5 | grep 'Size:' | awk '{print $2}'" + "rocminfo | grep 'gfx' -A 100 | grep 'Pool 1' -A 5 | grep 'Size:' | awk '{print $2}'" ], stdout=subprocess.PIPE, stderr=subprocess.PIPE, @@ -939,6 +951,37 @@ def get_nvgpu_memory_capacity(): ) +def get_hpu_memory_capacity(): + try: + # Run hl-smi and capture the output + result = subprocess.run( + ["hl-smi --query | grep 'Total'"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=True, + text=True, + ) + + if result.returncode != 0: + raise RuntimeError(f"hl-smi error: {result.stderr.strip()}") + + # Parse the output to extract memory values in MiB + memory_values = [ + float(mem.split(" ")[-2]) for mem in result.stdout.strip().split("\n") + ] + + if not memory_values: + raise ValueError("No GPU memory values found.") + + # Return the minimum memory value + return min(memory_values) + + except FileNotFoundError: + raise RuntimeError( + "hl-smi not found. Ensure Habana drivers are installed and accessible." + ) + + # Copy from pytorch and OpenRLHF to allow creating multiple main groups. 
# https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py # https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/utils/distributed_util.py @@ -1062,6 +1105,13 @@ def get_device_capability(device_id: int = 0) -> Tuple[int, int]: return major, minor +def get_compiler_backend() -> str: + if hasattr(torch, "hpu") and torch.hpu.is_available(): + return "hpu_backend" + + return "inductor" + + sglang_lib = Library("sglang", "FRAGMENT") # noqa diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index f97fc12355a..32c6e08b69f 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -568,6 +568,7 @@ def run_bench_serving( disable_tqdm=False, disable_stream=disable_stream, disable_ignore_eos=False, + return_logprob=False, lora_name=None, extra_request_body=None, profile=None, @@ -719,13 +720,13 @@ def run_and_check_memory_leak( # Clean up everything kill_process_tree(process.pid) - kill_process_tree(process.pid) stdout.close() stderr.close() if os.path.exists(STDOUT_FILENAME): os.remove(STDOUT_FILENAME) if os.path.exists(STDERR_FILENAME): os.remove(STDERR_FILENAME) + kill_process_tree(process.pid) t.join() # Assert success @@ -733,7 +734,7 @@ def run_and_check_memory_leak( has_leak = False has_abort = False for line in output_lines: - if "The server is fired" in line: + if "Uvicorn running" in line: has_new_server = True if "leak" in line: has_leak = True diff --git a/python/sglang/utils.py b/python/sglang/utils.py index c1bf62ef983..5689b097d23 100644 --- a/python/sglang/utils.py +++ b/python/sglang/utils.py @@ -1,4 +1,4 @@ -"""Common utilities.""" +"""Common utilities""" import base64 import gc diff --git a/python/sglang/version.py b/python/sglang/version.py index 4f6e29b6637..a21caf9d324 100644 --- a/python/sglang/version.py +++ b/python/sglang/version.py @@ -1 +1 @@ -__version__ = "0.3.6.post3" +__version__ = "0.4.0.post1" diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 37c2733fdc0..dc9c46a7146 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. 
-version = 3 +version = 4 [[package]] name = "actix-codec" @@ -851,6 +851,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" dependencies = [ "futures-core", + "futures-sink", ] [[package]] @@ -1986,6 +1987,7 @@ dependencies = [ "base64 0.22.1", "bytes", "encoding_rs", + "futures-channel", "futures-core", "futures-util", "h2 0.4.6", @@ -2219,6 +2221,7 @@ dependencies = [ "serde", "serde_json", "tokenizers", + "tokio", ] [[package]] @@ -2475,9 +2478,9 @@ dependencies = [ [[package]] name = "tokio" -version = "1.41.0" +version = "1.42.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "145f3413504347a2be84393cc8a7d2fb4d863b375909ea59f2158261aa258bbb" +checksum = "5cec9b21b0450273377fc97bd4c33a8acffc8c996c987a7c5b319a0083707551" dependencies = [ "backtrace", "bytes", diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 5ac77665bcc..d20a381ee7b 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -19,7 +19,7 @@ serde = { version = "1.0", features = ["derive"] } clap = { version = "4.4", features = ["derive"] } bytes = "1.8.0" rand = "0.8.5" -reqwest = { version = "0.12.8", features = ["stream"] } +reqwest = { version = "0.12.8", features = ["stream", "blocking"] } futures-util = "0.3" serde_json = "1.0" pyo3 = { version = "0.22.5", features = ["extension-module"] } @@ -29,6 +29,7 @@ http = "1.1.0" env_logger = "0.11.5" log = "0.4.22" chrono = "0.4.38" +tokio = "1.42.0" [profile.release] lto = "thin" diff --git a/rust/py_src/sglang_router/launch_server.py b/rust/py_src/sglang_router/launch_server.py index ec86e8b2adb..6ee19241542 100644 --- a/rust/py_src/sglang_router/launch_server.py +++ b/rust/py_src/sglang_router/launch_server.py @@ -10,12 +10,12 @@ from typing import List import requests +from setproctitle import setproctitle from sglang_router.launch_router import RouterArgs, launch_router from sglang.srt.server import launch_server from sglang.srt.server_args import ServerArgs from sglang.srt.utils import is_port_available -from sglang.utils import get_exception_traceback def setup_logger(): @@ -34,10 +34,41 @@ def setup_logger(): return logger +logger = setup_logger() + + # Create new process group def run_server(server_args, dp_rank): - os.setpgrp() # Create new process group - + """ + Note: + + 1. Without os.setpgrp(), all processes share the same PGID. When you press Ctrl+C, the terminal sends SIGINT to all processes in the group simultaneously. + This can cause leaf processes to terminate first, which messes up the cleaning order and produces orphaned processes. + + Terminal (PGID=100) + └── Main Python Process (PGID=100) + └── Server Process 1 (PGID=100) + └── Scheduler 1 + └── Detokenizer 1 + └── Server Process 2 (PGID=100) + └── Scheduler 2 + └── Detokenizer 2 + + 2. With os.setpgrp(), the main Python process and its children are in a separate group. 
Now: + + Terminal (PGID=100) + └── Main Python Process (PGID=200) + └── Server Process 1 (PGID=300) + └── Scheduler 1 + └── Detokenizer 1 + └── Server Process 2 (PGID=400) + └── Scheduler 2 + └── Detokenizer 2 + """ + # create new process group + os.setpgrp() + + setproctitle(f"sglang::server") # Set SGLANG_DP_RANK environment variable os.environ["SGLANG_DP_RANK"] = str(dp_rank) @@ -58,36 +89,6 @@ def launch_server_process( return proc -def cleanup_processes(processes: List[mp.Process]): - logger = logging.getLogger("router") - logger.info("Cleaning up processes...") - for proc in processes: - if proc.is_alive(): - try: - os.killpg(os.getpgid(proc.pid), signal.SIGTERM) - proc.join(timeout=3) - if proc.is_alive(): - logger.warning( - f"Process {proc.pid} did not terminate gracefully, force killing..." - ) - os.killpg(os.getpgid(proc.pid), signal.SIGKILL) - except ProcessLookupError: - pass - - -def setup_signal_handlers(cleanup_func): - """Setup handlers for various termination signals.""" - - def signal_handler(signum, frame): - cleanup_func() - sys.exit(1) - - signal.signal(signal.SIGTERM, signal_handler) - signal.signal(signal.SIGINT, signal_handler) - if hasattr(signal, "SIGQUIT"): - signal.signal(signal.SIGQUIT, signal_handler) - - def wait_for_server_health(host: str, port: int, timeout: int = 300) -> bool: """Wait for server to be healthy by checking /health endpoint.""" start_time = time.time() @@ -117,9 +118,14 @@ def find_available_ports(base_port: int, count: int) -> List[int]: return available_ports -def main(): - logger = setup_logger() +def cleanup_processes(processes: List[mp.Process]): + for process in processes: + logger.info(f"Terminating process {process.pid}") + process.terminate() + logger.info("All processes terminated") + +def main(): # CUDA runtime isn't fork-safe, which can lead to subtle bugs or crashes mp.set_start_method("spawn") @@ -148,52 +154,26 @@ def main(): # Start server processes server_processes = [] - try: - for i, worker_port in enumerate(worker_ports): - logger.info(f"Launching DP server process {i} on port {worker_port}") - proc = launch_server_process(server_args, worker_port, i) - server_processes.append(proc) - - # Setup cleanup handler - setup_signal_handlers(lambda: cleanup_processes(server_processes)) - - # Wait for all servers to be healthy - all_healthy = True - - for port in worker_ports: - if not wait_for_server_health(server_args.host, port): - logger.error(f"Server on port {port} failed to become healthy") - all_healthy = False - break - - if not all_healthy: - logger.error("Not all servers are healthy. Shutting down...") - cleanup_processes(server_processes) - sys.exit(1) - - logger.info("All servers are healthy. Starting router...") - - # Update router args with worker URLs - router_args.worker_urls = [ - f"http://{server_args.host}:{port}" for port in worker_ports - ] - - # Start the router - router = launch_router(router_args) - - if router is None: - logger.error("Failed to start router. 
Shutting down...") - cleanup_processes(server_processes) - sys.exit(1) - - except KeyboardInterrupt: - logger.info("Received shutdown signal...") - except Exception as e: - logger.error(f"Error occurred: {e}") - logger.error(get_exception_traceback()) - finally: - logger.info("Cleaning up processes...") - cleanup_processes(server_processes) + for i, worker_port in enumerate(worker_ports): + logger.info(f"Launching DP server process {i} on port {worker_port}") + proc = launch_server_process(server_args, worker_port, i) + server_processes.append(proc) + + signal.signal(signal.SIGINT, lambda sig, frame: cleanup_processes(server_processes)) + signal.signal( + signal.SIGTERM, lambda sig, frame: cleanup_processes(server_processes) + ) + signal.signal( + signal.SIGQUIT, lambda sig, frame: cleanup_processes(server_processes) + ) + + # Update router args with worker URLs + router_args.worker_urls = [ + f"http://{server_args.host}:{port}" for port in worker_ports + ] + + # Start the router + router = launch_router(router_args) if __name__ == "__main__": diff --git a/rust/py_test/test_launch_server.py b/rust/py_test/test_launch_server.py index f39b341df2b..2591abb5cdf 100644 --- a/rust/py_test/test_launch_server.py +++ b/rust/py_test/test_launch_server.py @@ -1,3 +1,4 @@ +import socket import subprocess import time import unittest @@ -5,7 +6,6 @@ import requests -from sglang.srt.utils import kill_process_tree from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, @@ -19,6 +19,7 @@ def popen_launch_router( base_url: str, dp_size: int, timeout: float, + policy: str = "cache_aware", ): """ Launch the router server process. @@ -28,6 +29,7 @@ def popen_launch_router( base_url: Server base URL dp_size: Data parallel size timeout: Server launch timeout + policy: Router policy, one of "cache_aware", "round_robin", "random" """ _, host, port = base_url.split(":") host = host[2:] @@ -44,12 +46,13 @@ def popen_launch_router( port, "--dp", str(dp_size), # Convert dp_size to string + "--router-eviction-interval", + "5", # frequent eviction for testing + "--router-policy", + policy, ] - # Use current environment - env = None - - process = subprocess.Popen(command, stdout=None, stderr=None, env=env) + process = subprocess.Popen(command, stdout=None, stderr=None) start_time = time.time() with requests.Session() as session: @@ -57,31 +60,143 @@ def popen_launch_router( try: response = session.get(f"{base_url}/health") if response.status_code == 200: + print(f"Router {base_url} is healthy") return process except requests.RequestException: pass time.sleep(10) - raise TimeoutError("Server failed to start within the timeout period.") + raise TimeoutError("Router failed to start within the timeout period.") + + +def find_available_port(): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +def popen_launch_server( + model: str, + base_url: str, + timeout: float, +): + _, host, port = base_url.split(":") + host = host[2:] + + command = [ + "python3", + "-m", + "sglang.launch_server", + "--model-path", + model, + "--host", + host, + "--port", + port, + "--base-gpu-id", + "1", + ] + + process = subprocess.Popen(command, stdout=None, stderr=None) + + # intentionally don't wait and defer the job to the router health check + return process -class TestEvalAccuracyMini(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.model = DEFAULT_MODEL_NAME_FOR_TEST - cls.base_url = DEFAULT_URL_FOR_TEST - 
cls.process = popen_launch_router( - cls.model, - cls.base_url, +def terminate_and_wait(process, timeout=300): + """Terminate a process and wait until it is terminated. + + Args: + process: subprocess.Popen object + timeout: maximum time to wait in seconds + + Raises: + TimeoutError: if process does not terminate within timeout + """ + if process is None: + return + + process.terminate() + start_time = time.time() + + while process.poll() is None: + print(f"Terminating process {process.pid}") + if time.time() - start_time > timeout: + raise TimeoutError( + f"Process {process.pid} failed to terminate within {timeout}s" + ) + time.sleep(1) + + print(f"Process {process.pid} is successfully terminated") + + +class TestLaunchServer(unittest.TestCase): + def setUp(self): + self.model = DEFAULT_MODEL_NAME_FOR_TEST + self.base_url = DEFAULT_URL_FOR_TEST + self.process = None + self.other_process = [] + + def tearDown(self): + print("Running tearDown...") + if self.process: + terminate_and_wait(self.process) + for process in self.other_process: + terminate_and_wait(process) + print("tearDown done") + + def test_1_mmlu(self): + print("Running test_1_mmlu...") + # DP size = 2 + self.process = popen_launch_router( + self.model, + self.base_url, + dp_size=2, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + policy="cache_aware", + ) + + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + temperature=0.1, + ) + + metrics = run_eval(args) + score = metrics["score"] + THRESHOLD = 0.65 + passed = score >= THRESHOLD + msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})" + self.assertGreaterEqual(score, THRESHOLD, msg) + + def test_2_add_and_remove_worker(self): + print("Running test_2_add_and_remove_worker...") + # DP size = 1 + self.process = popen_launch_router( + self.model, + self.base_url, dp_size=1, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + policy="round_robin", # use round robin to make sure every worker processes requests + ) + # 1. start a worker, and wait until it is healthy + port = find_available_port() + worker_url = f"http://127.0.0.1:{port}" + worker_process = popen_launch_server( + self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH ) + self.other_process.append(worker_process) - @classmethod - def tearDownClass(cls): - kill_process_tree(cls.process.pid) + # 2. use /add_worker api to add it the the router. It will be used by router after it is healthy + with requests.Session() as session: + response = session.post(f"{self.base_url}/add_worker?url={worker_url}") + print(f"status code: {response.status_code}, response: {response.text}") + self.assertEqual(response.status_code, 200) - def test_mmlu(self): + # 3. run mmlu args = SimpleNamespace( base_url=self.base_url, model=self.model, @@ -90,9 +205,26 @@ def test_mmlu(self): num_threads=32, temperature=0.1, ) - metrics = run_eval(args) - self.assertGreaterEqual(metrics["score"], 0.65) + score = metrics["score"] + THRESHOLD = 0.65 + passed = score >= THRESHOLD + msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})" + self.assertGreaterEqual(score, THRESHOLD, msg) + + # 4. use /remove_worker api to remove it from the router + with requests.Session() as session: + response = session.post(f"{self.base_url}/remove_worker?url={worker_url}") + print(f"status code: {response.status_code}, response: {response.text}") + self.assertEqual(response.status_code, 200) + + # 5. 
run mmlu again + metrics = run_eval(args) + score = metrics["score"] + THRESHOLD = 0.65 + passed = score >= THRESHOLD + msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})" + self.assertGreaterEqual(score, THRESHOLD, msg) if __name__ == "__main__": diff --git a/rust/pyproject.toml b/rust/pyproject.toml index 35b1dc7d433..d1327d9203e 100644 --- a/rust/pyproject.toml +++ b/rust/pyproject.toml @@ -4,11 +4,12 @@ build-backend = "setuptools.build_meta" [project] name = "sglang-router" -version = "0.0.10" +version = "0.0.11" description = "SGLang router is a standalone module implemented in Rust to achieve data parallelism across SGLang instances." authors = [{name = "Byron Hsu", email = "byronhsu1230@gmail.com"}] requires-python = ">=3.8" readme = "README.md" +license = { file = "LICENSE" } classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Rust", diff --git a/rust/src/router.rs b/rust/src/router.rs index e17cba874c9..615a9550ef3 100644 --- a/rust/src/router.rs +++ b/rust/src/router.rs @@ -3,22 +3,23 @@ use actix_web::http::header::{HeaderValue, CONTENT_TYPE}; use actix_web::{HttpRequest, HttpResponse}; use bytes::Bytes; use futures_util::{StreamExt, TryStreamExt}; -use log::{debug, info}; +use log::{debug, info, warn}; use std::collections::HashMap; use std::fmt::Debug; use std::sync::atomic::AtomicUsize; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, Mutex, RwLock}; use std::thread; use std::time::Duration; +use tokio; #[derive(Debug)] pub enum Router { RoundRobin { - worker_urls: Vec, + worker_urls: Arc>>, current_index: AtomicUsize, }, Random { - worker_urls: Vec, + worker_urls: Arc>>, }, CacheAware { /* @@ -81,7 +82,7 @@ pub enum Router { Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted during the next eviction cycle. */ - worker_urls: Vec, + worker_urls: Arc>>, tree: Arc>, running_queue: Arc>>, processed_queue: Arc>>, @@ -92,7 +93,7 @@ pub enum Router { }, } -#[derive(Debug)] +#[derive(Debug, Clone)] pub enum PolicyConfig { RandomConfig, RoundRobinConfig, @@ -126,12 +127,19 @@ fn get_text_from_request(body: &Bytes, route: &str) -> String { return "".to_string(); } + impl Router { - pub fn new(worker_urls: Vec, policy_config: PolicyConfig) -> Self { - match policy_config { - PolicyConfig::RandomConfig => Router::Random { worker_urls }, + pub fn new(worker_urls: Vec, policy_config: PolicyConfig) -> Result { + // Wait until all workers are healthy + Self::wait_for_healthy_workers(&worker_urls, 300, 10)?; + + // Create router based on policy... 
+ Ok(match policy_config { + PolicyConfig::RandomConfig => Router::Random { + worker_urls: Arc::new(RwLock::new(worker_urls)), + }, PolicyConfig::RoundRobinConfig => Router::RoundRobin { - worker_urls, + worker_urls: Arc::new(RwLock::new(worker_urls)), current_index: std::sync::atomic::AtomicUsize::new(0), }, PolicyConfig::CacheAwareConfig { @@ -166,7 +174,7 @@ impl Router { let locked_tree_clone = tree_clone.lock().unwrap(); // Run eviction - locked_tree_clone.evict_tenant_data(max_tree_size); + locked_tree_clone.evict_tenant_by_size(max_tree_size); // Print the process queue let locked_processed_queue = processed_queue_clone.lock().unwrap(); @@ -183,7 +191,7 @@ impl Router { } Router::CacheAware { - worker_urls, + worker_urls: Arc::new(RwLock::new(worker_urls)), tree, running_queue, processed_queue, @@ -193,7 +201,7 @@ impl Router { _eviction_thread: Some(eviction_thread), } } - } + }) } pub fn get_first(&self) -> Option { @@ -201,11 +209,64 @@ impl Router { Router::RoundRobin { worker_urls, .. } | Router::Random { worker_urls } | Router::CacheAware { worker_urls, .. } => { - if worker_urls.is_empty() { + if worker_urls.read().unwrap().is_empty() { None } else { - Some(worker_urls[0].clone()) + Some(worker_urls.read().unwrap()[0].clone()) + } + } + } + } + + fn wait_for_healthy_workers( + worker_urls: &[String], + timeout_secs: u64, + interval_secs: u64, + ) -> Result<(), String> { + let start_time = std::time::Instant::now(); + let sync_client = reqwest::blocking::Client::new(); + + loop { + if start_time.elapsed() > Duration::from_secs(timeout_secs) { + return Err(format!( + "Timeout {}s waiting for workers to become healthy", + timeout_secs + )); + } + + let mut all_healthy = true; + let mut unhealthy_workers = Vec::new(); + + for url in worker_urls { + match sync_client.get(&format!("{}/health", url)).send() { + Ok(res) => { + if !res.status().is_success() { + info!( + "Worker {} health check is pending with status: {}.", + url, + res.status() + ); + all_healthy = false; + unhealthy_workers.push((url, format!("Status: {}", res.status()))); + } + } + Err(e) => { + info!("Worker {} health check is pending with error: {}", url, e); + all_healthy = false; + unhealthy_workers.push((url, format!("Error: {}", e))); + } + } + } + + if all_healthy { + info!("All workers are healthy"); + return Ok(()); + } else { + info!("Unhealthy workers:"); + for (url, reason) in &unhealthy_workers { + info!(" {} - {}", url, reason); } + thread::sleep(Duration::from_secs(interval_secs)); } } } @@ -228,15 +289,15 @@ impl Router { .fetch_update( std::sync::atomic::Ordering::SeqCst, std::sync::atomic::Ordering::SeqCst, - |x| Some((x + 1) % worker_urls.len()), + |x| Some((x + 1) % worker_urls.read().unwrap().len()), ) .unwrap(); - worker_urls[idx].clone() + worker_urls.read().unwrap()[idx].clone() } - Router::Random { worker_urls } => { - worker_urls[rand::random::() % worker_urls.len()].clone() - } + Router::Random { worker_urls } => worker_urls.read().unwrap() + [rand::random::() % worker_urls.read().unwrap().len()] + .clone(), Router::CacheAware { worker_urls, @@ -277,7 +338,7 @@ impl Router { .iter() .min_by_key(|(_url, &count)| count) .map(|(url, _)| url.clone()) - .unwrap_or_else(|| worker_urls[0].clone()) + .unwrap_or_else(|| worker_urls.read().unwrap()[0].clone()) } else { // Use cache-aware routing when load is balanced let (matched_text, matched_worker) = tree.prefix_match(&text); @@ -333,7 +394,10 @@ impl Router { // For non-streaming requests, get response first let response = match 
res.bytes().await { Ok(body) => HttpResponse::build(status).body(body.to_vec()), - Err(_) => HttpResponse::InternalServerError().finish(), + Err(e) => { + let error_msg = format!("Failed to get response body: {}", e); + HttpResponse::InternalServerError().body(error_msg) + } }; // Then decrement running queue counter if using CacheAware @@ -379,4 +443,122 @@ impl Router { })) } } + + pub async fn add_worker(&self, worker_url: String) -> Result { + let interval_secs = 10; // check every 10 seconds + let timeout_secs = 300; // 5 minutes + + let start_time = std::time::Instant::now(); + let client = reqwest::Client::new(); + + loop { + if start_time.elapsed() > Duration::from_secs(timeout_secs) { + return Err(format!( + "Timeout {}s waiting for worker {} to become healthy", + timeout_secs, worker_url + )); + } + + match client.get(&format!("{}/health", worker_url)).send().await { + Ok(res) => { + if res.status().is_success() { + match self { + Router::RoundRobin { worker_urls, .. } + | Router::Random { worker_urls } + | Router::CacheAware { worker_urls, .. } => { + info!("Worker {} health check passed", worker_url); + let mut urls = worker_urls.write().unwrap(); + if urls.contains(&worker_url) { + return Err(format!("Worker {} already exists", worker_url)); + } + info!("Added worker: {}", worker_url); + urls.push(worker_url.clone()); + } + } + + // If cache aware, initialize the queues for the new worker + if let Router::CacheAware { + running_queue, + processed_queue, + tree, + .. + } = self + { + // Add worker to running queue with initial count of 0 + running_queue.lock().unwrap().insert(worker_url.clone(), 0); + + // Add worker to processed queue with initial count of 0 + processed_queue + .lock() + .unwrap() + .insert(worker_url.clone(), 0); + + // Add worker to tree + tree.lock().unwrap().insert(&"".to_string(), &worker_url); + } + + return Ok(format!("Successfully added worker: {}", worker_url)); + } else { + info!( + "Worker {} health check is pending with status: {}.", + worker_url, + res.status() + ); + // if the url does not have http or https prefix, warn users + if !worker_url.starts_with("http://") && !worker_url.starts_with("https://") + { + warn!("The worker url {} does not have http or https prefix. Please add the prefix to the url.", worker_url); + } + + tokio::time::sleep(Duration::from_secs(interval_secs)).await; + continue; + } + } + Err(e) => { + info!( + "Worker {} health check is pending with error: {}", + worker_url, e + ); + + // if the url does not have http or https prefix, warn users + if !worker_url.starts_with("http://") && !worker_url.starts_with("https://") { + warn!("The worker url {} does not have http or https prefix. Please add the prefix to the url.", worker_url); + } + + tokio::time::sleep(Duration::from_secs(interval_secs)).await; + continue; + } + } + } + } + + pub fn remove_worker(&self, worker_url: String) { + match self { + Router::RoundRobin { worker_urls, .. } + | Router::Random { worker_urls } + | Router::CacheAware { worker_urls, .. } => { + let mut urls = worker_urls.write().unwrap(); + let index = urls.iter().position(|url| url == &worker_url).unwrap(); + urls.remove(index); + info!("Removed worker: {}", worker_url); + } + } + + // if cache aware, remove the worker from the tree + if let Router::CacheAware { + tree, + running_queue, + processed_queue, + .. 
+ } = self + { + tree.lock().unwrap().remove_tenant(&worker_url); + running_queue.lock().unwrap().remove(&worker_url); + processed_queue.lock().unwrap().remove(&worker_url); + info!( + "Removed worker from tree and cleaned up queues: {}", + worker_url + ); + } + } } diff --git a/rust/src/server.rs b/rust/src/server.rs index 3fbe5c3e895..8a0eb1547d6 100644 --- a/rust/src/server.rs +++ b/rust/src/server.rs @@ -4,6 +4,7 @@ use actix_web::{get, post, web, App, HttpRequest, HttpResponse, HttpServer, Resp use bytes::Bytes; use env_logger::Builder; use log::{info, LevelFilter}; +use std::collections::HashMap; use std::io::Write; #[derive(Debug)] @@ -19,7 +20,10 @@ impl AppState { policy_config: PolicyConfig, ) -> Self { // Create router based on policy - let router = Router::new(worker_urls, policy_config); + let router = match Router::new(worker_urls, policy_config) { + Ok(router) => router, + Err(error) => panic!("Failed to create router: {}", error), + }; Self { router, client } } @@ -128,6 +132,38 @@ async fn v1_completions( .await } +#[post("/add_worker")] +async fn add_worker( + query: web::Query>, + data: web::Data, +) -> impl Responder { + let worker_url = match query.get("url") { + Some(url) => url.to_string(), + None => { + return HttpResponse::BadRequest() + .body("Worker URL required. Provide 'url' query parameter") + } + }; + + match data.router.add_worker(worker_url).await { + Ok(message) => HttpResponse::Ok().body(message), + Err(error) => HttpResponse::BadRequest().body(error), + } +} + +#[post("/remove_worker")] +async fn remove_worker( + query: web::Query>, + data: web::Data, +) -> impl Responder { + let worker_url = match query.get("url") { + Some(url) => url.to_string(), + None => return HttpResponse::BadRequest().finish(), + }; + data.router.remove_worker(worker_url); + HttpResponse::Ok().finish() +} + pub struct ServerConfig { pub host: String, pub port: u16, @@ -158,20 +194,20 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> { ) .init(); - info!("Starting server on {}:{}", config.host, config.port); - info!("Worker URLs: {:?}", config.worker_urls); - info!("Policy Config: {:?}", config.policy_config); - let client = reqwest::Client::builder() .build() .expect("Failed to create HTTP client"); let app_state = web::Data::new(AppState::new( - config.worker_urls, + config.worker_urls.clone(), client, - config.policy_config, + config.policy_config.clone(), )); + info!("✅ Starting router on {}:{}", config.host, config.port); + info!("✅ Serving Worker URLs: {:?}", config.worker_urls); + info!("✅ Policy Config: {:?}", config.policy_config); + HttpServer::new(move || { App::new() .app_data(app_state.clone()) @@ -183,6 +219,8 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> { .service(health) .service(health_generate) .service(get_server_info) + .service(add_worker) + .service(remove_worker) }) .bind((config.host, config.port))? 
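Note: the /add_worker and /remove_worker routes registered above take the worker address as a "url" query parameter, and add_worker blocks until the worker's /health check passes. A hedged client-side usage sketch (the router and worker addresses are illustrative):

    import requests

    ROUTER_URL = "http://127.0.0.1:30000"   # illustrative router address
    WORKER_URL = "http://127.0.0.1:31000"   # illustrative worker address

    # Attach a new DP worker; the router polls its /health endpoint before accepting it.
    resp = requests.post(f"{ROUTER_URL}/add_worker", params={"url": WORKER_URL})
    print(resp.status_code, resp.text)

    # Detach the worker; for the cache-aware policy its tree entries and queues are dropped too.
    resp = requests.post(f"{ROUTER_URL}/remove_worker", params={"url": WORKER_URL})
    print(resp.status_code, resp.text)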
.run() diff --git a/rust/src/tree.rs b/rust/src/tree.rs index 3c403676f01..e8dc8b7a0da 100644 --- a/rust/src/tree.rs +++ b/rust/src/tree.rs @@ -5,6 +5,7 @@ use log::info; use std::cmp::Reverse; use std::collections::BinaryHeap; use std::collections::HashMap; +use std::collections::VecDeque; use std::sync::Arc; use std::sync::RwLock; @@ -24,7 +25,6 @@ struct Node { #[derive(Debug)] pub struct Tree { root: NodeRef, - // TODO: Char Count per tenant pub tenant_char_count: DashMap, } @@ -405,20 +405,12 @@ impl Tree { .collect() } - pub fn evict_tenant_data(&self, max_size: usize) { + pub fn evict_tenant_by_size(&self, max_size: usize) { // Calculate used size and collect leaves let mut stack = vec![Arc::clone(&self.root)]; - let mut used_size_per_tenant: HashMap = HashMap::new(); let mut pq = BinaryHeap::new(); while let Some(curr) = stack.pop() { - for tenant in curr.tenant_last_access_time.iter() { - let size = used_size_per_tenant - .entry(tenant.key().clone()) - .or_insert(0); - *size += curr.text.read().unwrap().chars().count(); - } - for child in curr.children.iter() { stack.push(Arc::clone(child.value())); } @@ -436,65 +428,100 @@ impl Tree { } info!("Before eviction - Used size per tenant:"); - for (tenant, size) in &used_size_per_tenant { - info!("Tenant: {}, Size: {}", tenant, size); + for entry in self.tenant_char_count.iter() { + info!("Tenant: {}, Size: {}", entry.key(), entry.value()); } // Process eviction while let Some(Reverse(entry)) = pq.pop() { let EvictionEntry { tenant, node, .. } = entry; - if let Some(&used_size) = used_size_per_tenant.get(&tenant) { - if used_size <= max_size { + if let Some(used_size) = self.tenant_char_count.get(&tenant) { + if *used_size <= max_size { continue; } + } + + // Decrement when removing tenant from node + if node.tenant_last_access_time.contains_key(&tenant) { + self.tenant_char_count + .entry(tenant.clone()) + .and_modify(|count| { + if *count > 0 { + *count -= node.text.read().unwrap().chars().count(); + } + }); + } - // Update used size - if let Some(size) = used_size_per_tenant.get_mut(&tenant) { - *size -= node.text.read().unwrap().chars().count(); + // Remove tenant from node + node.tenant_last_access_time.remove(&tenant); + + // Remove empty nodes + if node.children.is_empty() && node.tenant_last_access_time.is_empty() { + if let Some(parent) = node.parent.write().unwrap().as_ref() { + let first_char = node.text.read().unwrap().chars().next().unwrap(); + parent.children.remove(&first_char); } + } - // Decrement when removing tenant from node - if node.tenant_last_access_time.contains_key(&tenant) { - self.tenant_char_count - .entry(tenant.clone()) - .and_modify(|count| { - if *count > 0 { - *count -= node.text.read().unwrap().chars().count(); - } - }); + // Add parent to queue if it becomes a leaf + if let Some(parent) = node.parent.read().unwrap().as_ref() { + if Tree::leaf_of(parent).contains(&tenant) { + if let Some(timestamp) = parent.tenant_last_access_time.get(&tenant) { + pq.push(Reverse(EvictionEntry { + timestamp: *timestamp, + tenant: tenant.clone(), + node: Arc::clone(parent), + })); + } } + }; + } + + info!("After eviction - Used size per tenant:"); + for entry in self.tenant_char_count.iter() { + info!("Tenant: {}, Size: {}", entry.key(), entry.value()); + } + } - // Remove tenant from node - node.tenant_last_access_time.remove(&tenant); + pub fn remove_tenant(&self, tenant: &str) { + // 1. 
Find all the leaves for the tenant + let mut stack = vec![Arc::clone(&self.root)]; + let mut queue = VecDeque::new(); - // Remove empty nodes - if node.children.is_empty() && node.tenant_last_access_time.is_empty() { - if let Some(parent) = node.parent.write().unwrap().as_ref() { - let first_char = node.text.read().unwrap().chars().next().unwrap(); - parent.children.remove(&first_char); - } + while let Some(curr) = stack.pop() { + for child in curr.children.iter() { + stack.push(Arc::clone(child.value())); + } + + if Tree::leaf_of(&curr).contains(&tenant.to_string()) { + queue.push_back(Arc::clone(&curr)); + } + } + + // 2. Start from the leaves and traverse up to the root, removing the tenant from each node + while let Some(curr) = queue.pop_front() { + // remove tenant from node + curr.tenant_last_access_time.remove(&tenant.to_string()); + + // remove empty nodes + if curr.children.is_empty() && curr.tenant_last_access_time.is_empty() { + if let Some(parent) = curr.parent.read().unwrap().as_ref() { + let first_char = curr.text.read().unwrap().chars().next().unwrap(); + parent.children.remove(&first_char); } + } - // Add parent to queue if it becomes a leaf - if let Some(parent) = node.parent.read().unwrap().as_ref() { - if Tree::leaf_of(parent).contains(&tenant) { - if let Some(timestamp) = parent.tenant_last_access_time.get(&tenant) { - pq.push(Reverse(EvictionEntry { - timestamp: *timestamp, - tenant: tenant.clone(), - node: Arc::clone(parent), - })); - } - } + // add parent to queue if it becomes a leaf + if let Some(parent) = curr.parent.read().unwrap().as_ref() { + if Tree::leaf_of(parent).contains(&tenant.to_string()) { + queue.push_back(Arc::clone(&parent)); } } } - info!("After eviction - Used size per tenant:"); - for (tenant, size) in &used_size_per_tenant { - info!("Tenant: {}, Size: {}", tenant, size); - } + // 3. 
Remove the tenant from the tenant_char_count map + self.tenant_char_count.remove(&tenant.to_string()); } pub fn get_tenant_char_count(&self) -> HashMap { @@ -687,7 +714,7 @@ mod tests { ); // Test eviction - tree.evict_tenant_data(3); // This should evict tenants with more than 3 chars + tree.evict_tenant_by_size(3); // This should evict tenants with more than 3 chars let post_eviction_smallest = tree.get_smallest_tenant(); println!("Smallest tenant after eviction: {}", post_eviction_smallest); @@ -768,7 +795,7 @@ mod tests { ); // Phase 4: Eviction test - tree.evict_tenant_data(10); + tree.evict_tenant_by_size(10); let computed_sizes = tree.get_used_size_per_tenant(); let maintained_counts: HashMap = tree @@ -1146,7 +1173,7 @@ mod tests { assert_eq!(sizes_before.get("tenant2").unwrap(), &10); // "hello" + "world" = 10 // Evict - should remove "hello" from tenant2 as it's the oldest - tree.evict_tenant_data(max_size); + tree.evict_tenant_by_size(max_size); tree.pretty_print(); @@ -1182,7 +1209,7 @@ mod tests { } // Perform eviction - tree.evict_tenant_data(max_size); + tree.evict_tenant_by_size(max_size); // Check sizes after eviction let sizes_after = tree.get_used_size_per_tenant(); @@ -1214,7 +1241,7 @@ mod tests { let handle = thread::spawn(move || { while start_time.elapsed() < test_duration { // Run eviction - tree.evict_tenant_data(max_size); + tree.evict_tenant_by_size(max_size); // Sleep for 5 seconds thread::sleep(Duration::from_secs(5)); @@ -1259,7 +1286,7 @@ mod tests { } // final eviction - tree.evict_tenant_data(max_size); + tree.evict_tenant_by_size(max_size); // Final size check let final_sizes = tree.get_used_size_per_tenant(); @@ -1366,4 +1393,91 @@ mod tests { assert_eq!(tree.prefix_match_tenant("hello", "tenant3"), ""); // Non-existent tenant assert_eq!(tree.prefix_match_tenant("help", "tenant3"), ""); // Non-existent tenant } + + #[test] + fn test_simple_tenant_eviction() { + let tree = Tree::new(); + + // Insert data for multiple tenants + tree.insert("hello", "tenant1"); + tree.insert("world", "tenant1"); + tree.insert("hello", "tenant2"); + tree.insert("help", "tenant2"); + + // Verify initial state + let initial_sizes = tree.get_used_size_per_tenant(); + assert_eq!(initial_sizes.get("tenant1").unwrap(), &10); // "hello" + "world" + assert_eq!(initial_sizes.get("tenant2").unwrap(), &6); // "hello" + "p" + + // Evict tenant1 + tree.remove_tenant("tenant1"); + + // Verify after eviction + let final_sizes = tree.get_used_size_per_tenant(); + assert!( + !final_sizes.contains_key("tenant1"), + "tenant1 should be completely removed" + ); + assert_eq!( + final_sizes.get("tenant2").unwrap(), + &6, + "tenant2 should be unaffected" + ); + + // Verify tenant1's data is inaccessible + assert_eq!(tree.prefix_match_tenant("hello", "tenant1"), ""); + assert_eq!(tree.prefix_match_tenant("world", "tenant1"), ""); + + // Verify tenant2's data is still accessible + assert_eq!(tree.prefix_match_tenant("hello", "tenant2"), "hello"); + assert_eq!(tree.prefix_match_tenant("help", "tenant2"), "help"); + } + + #[test] + fn test_complex_tenant_eviction() { + let tree = Tree::new(); + + // Create a more complex tree structure with shared prefixes + tree.insert("apple", "tenant1"); + tree.insert("application", "tenant1"); + tree.insert("apple", "tenant2"); + tree.insert("appetite", "tenant2"); + tree.insert("banana", "tenant1"); + tree.insert("banana", "tenant2"); + tree.insert("ball", "tenant2"); + + // Verify initial state + let initial_sizes = tree.get_used_size_per_tenant(); + 
println!("Initial sizes: {:?}", initial_sizes); + tree.pretty_print(); + + // Evict tenant1 + tree.remove_tenant("tenant1"); + + // Verify final state + let final_sizes = tree.get_used_size_per_tenant(); + println!("Final sizes: {:?}", final_sizes); + tree.pretty_print(); + + // Verify tenant1 is completely removed + assert!( + !final_sizes.contains_key("tenant1"), + "tenant1 should be completely removed" + ); + + // Verify all tenant1's data is inaccessible + assert_eq!(tree.prefix_match_tenant("apple", "tenant1"), ""); + assert_eq!(tree.prefix_match_tenant("application", "tenant1"), ""); + assert_eq!(tree.prefix_match_tenant("banana", "tenant1"), ""); + + // Verify tenant2's data is intact + assert_eq!(tree.prefix_match_tenant("apple", "tenant2"), "apple"); + assert_eq!(tree.prefix_match_tenant("appetite", "tenant2"), "appetite"); + assert_eq!(tree.prefix_match_tenant("banana", "tenant2"), "banana"); + assert_eq!(tree.prefix_match_tenant("ball", "tenant2"), "ball"); + + // Verify the tree structure is still valid for tenant2 + let tenant2_size = final_sizes.get("tenant2").unwrap(); + assert_eq!(tenant2_size, &(5 + 5 + 6 + 2)); // "apple" + "etite" + "banana" + "ll" + } } diff --git a/scripts/killall_sglang.sh b/scripts/killall_sglang.sh index fcad493c59c..3696a1c35f4 100755 --- a/scripts/killall_sglang.sh +++ b/scripts/killall_sglang.sh @@ -1,5 +1,14 @@ -# Kill all SGLang processes and free the GPU memory. +#!/bin/bash -kill -9 $(ps aux | grep 'multiprocessing.spawn' | grep -v 'grep' | awk '{print $2}') -kill -9 $(ps aux | grep 'sglang.launch_server' | grep -v 'grep' | awk '{print $2}') -kill -9 $(ps aux | grep 'sglang.bench' | grep -v 'grep' | awk '{print $2}') +# Show current GPU status +nvidia-smi + +# Clean SGLang processes +kill -9 $(ps aux | grep 'sglang::' | grep -v 'grep' | awk '{print $2}') 2>/dev/null +kill -9 $(ps aux | grep 'sglang.launch_server' | grep -v 'grep' | awk '{print $2}') 2>/dev/null +kill -9 $(ps aux | grep 'sglang.bench' | grep -v 'grep' | awk '{print $2}') 2>/dev/null + +# Clean all GPU processes if any argument is provided +if [ $# -gt 0 ]; then + kill -9 $(nvidia-smi | sed -n '/Processes:/,$p' | grep " [0-9]" | awk '{print $5}') 2>/dev/null +fi diff --git a/test/srt/configs/random_config.yaml b/test/srt/configs/random_config.yaml new file mode 100644 index 00000000000..eae8c27f41c --- /dev/null +++ b/test/srt/configs/random_config.yaml @@ -0,0 +1,25 @@ +tasks: + - name: sglang-128-4 + server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache + client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 128 --random-output 4 --request-rate 24 --num-prompt 1440 + - name: vllm-128-4 + server_cmd: python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-8B-Instruct --disable-log-requests + client_cmd: python3 -m sglang.bench_serving --backend vllm --dataset-name random --random-input 128 --random-output 4 --request-rate 24 --num-prompt 1440 + - name: sglang-2000-100 + server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache + client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 2000 --random-output 100 --request-rate 2 --num-prompt 120 + - name: vllm-2000-100 + server_cmd: python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-8B-Instruct --disable-log-requests + client_cmd: python3 -m sglang.bench_serving --backend vllm --dataset-name 
random --random-input 2000 --random-output 100 --request-rate 2 --num-prompt 120 + - name: sglang-4000-200 + server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache + client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 4000 --random-output 200 --request-rate 8 --num-prompt 480 + - name: vllm-4000-200 + server_cmd: python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-8B-Instruct --disable-log-requests + client_cmd: python3 -m sglang.bench_serving --backend vllm --dataset-name random --random-input 4000 --random-output 200 --request-rate 8 --num-prompt 480 + - name: sglang-32000-100 + server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache + client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 32000 --random-output 100 --request-rate 1 --num-prompt 60 + - name: vllm-32000-100 + server_cmd: python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-8B-Instruct --disable-log-requests + client_cmd: python3 -m sglang.bench_serving --backend vllm --dataset-name random --random-input 32000 --random-output 100 --request-rate 1 --num-prompt 60 diff --git a/test/srt/configs/random_flashinfer_vs_triton_config.yaml b/test/srt/configs/random_flashinfer_vs_triton_config.yaml new file mode 100644 index 00000000000..7f4a386ddcf --- /dev/null +++ b/test/srt/configs/random_flashinfer_vs_triton_config.yaml @@ -0,0 +1,25 @@ +tasks: + - name: sglang-128-4 + server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache + client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 128 --random-output 4 --request-rate 24 --num-prompt 1440 + - name: sglang-triton-128-4 + server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache --attention-backend triton + client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 128 --random-output 4 --request-rate 24 --num-prompt 1440 + - name: sglang-2000-100 + server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache + client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 2000 --random-output 100 --request-rate 2 --num-prompt 120 + - name: sglang-triton-2000-100 + server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache --attention-backend triton + client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 2000 --random-output 100 --request-rate 2 --num-prompt 120 + - name: sglang-4000-200 + server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache + client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 4000 --random-output 200 --request-rate 8 --num-prompt 480 + - name: sglang-triton-4000-200 + server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache --attention-backend triton + client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 4000 --random-output 200 --request-rate 8 --num-prompt 480 + - name: sglang-32000-100 + server_cmd: python3 -m sglang.launch_server --model-path 
meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache + client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 32000 --random-output 100 --request-rate 1 --num-prompt 60 + - name: sglang-triton-32000-100 + server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache --attention-backend triton + client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 32000 --random-output 100 --request-rate 1 --num-prompt 60 diff --git a/test/srt/configs/sharegpt_config.yaml b/test/srt/configs/sharegpt_config.yaml new file mode 100644 index 00000000000..a80b96c8eae --- /dev/null +++ b/test/srt/configs/sharegpt_config.yaml @@ -0,0 +1,7 @@ +tasks: + - name: sglang-benchmark + server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache + client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --request-rate 16 + - name: vllm-benchmark + server_cmd: python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-8B-Instruct --disable-log-requests + client_cmd: python3 -m sglang.bench_serving --backend vllm --dataset-name sharegpt --request-rate 16 diff --git a/test/srt/experiment_runner.py b/test/srt/experiment_runner.py new file mode 100644 index 00000000000..c4966dc77ba --- /dev/null +++ b/test/srt/experiment_runner.py @@ -0,0 +1,359 @@ +import argparse +import logging +import os +import queue +import re +import subprocess +import threading +import time +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import psutil +import requests +import yaml + + +@dataclass +class ServerConfig: + command: str + process_names: List[str] + default_port: int + + +@dataclass +class TaskConfig: + server_cmd: str + client_cmd: str + name: Optional[str] = None + server_type: Optional[str] = None + + +@dataclass +class TaskResult: + name: str + success: bool + output: str + runtime: float + timestamp: str + + +SERVER_DEFAULTS = { + "sglang": ServerConfig( + command="sglang.launch_server", + process_names=["sglang.launch_server"], + default_port=30000, + ), + "vllm": ServerConfig( + command="vllm.entrypoints.openai.api_server", + process_names=["vllm.entrypoints.openai.api_server"], + default_port=8000, + ), +} + + +def parse_key_info(output: str) -> str: + """Extract and format key information from the output""" + key_info = [] + + # Extract Args namespace + args_match = re.search(r"Namespace\(.*?\)", output, re.DOTALL) + if args_match: + key_info.append(args_match.group(0)) + + # Extract input/output token counts + token_matches = re.findall(r"#(Input|Output) tokens: \d+", output) + key_info.extend(token_matches) + + # Extract benchmark result section + result_match = re.search( + r"============ Serving Benchmark Result ============.*?={50,}", + output, + re.DOTALL, + ) + if result_match: + key_info.append(result_match.group(0)) + + return "\n\n".join(key_info) + + +def extract_port_from_command(cmd: str, server_type: str) -> int: + port_match = re.search(r"--port[= ](\d+)", cmd) + if port_match: + return int(port_match.group(1)) + return SERVER_DEFAULTS.get(server_type, ServerConfig("", [], 8000)).default_port + + +def detect_server_type(cmd: str) -> str: + for server_type, config in SERVER_DEFAULTS.items(): + if config.command in cmd: + return server_type + return "unknown" + + +def stream_output( + process: 
subprocess.Popen, prefix: str, logger: logging.Logger +) -> queue.Queue: + output_queue = queue.Queue() + + def stream_pipe(pipe, prefix): + for line in iter(pipe.readline, ""): + if prefix == "CLIENT": + output_queue.put(line.rstrip()) + logger.debug(f"{prefix} | {line.rstrip()}") + + stdout_thread = threading.Thread( + target=stream_pipe, args=(process.stdout, prefix), daemon=True + ) + stderr_thread = threading.Thread( + target=stream_pipe, args=(process.stderr, prefix), daemon=True + ) + + stdout_thread.start() + stderr_thread.start() + return output_queue, (stdout_thread, stderr_thread) + + +class ProcessManager: + def __init__(self): + self.server_process: Optional[subprocess.Popen] = None + self.client_process: Optional[subprocess.Popen] = None + self.logger = logging.getLogger(__name__) + + def start_process( + self, command: str, prefix: str + ) -> Tuple[subprocess.Popen, queue.Queue]: + process = subprocess.Popen( + command, + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + bufsize=1, + ) + + output_queue, threads = stream_output(process, prefix, self.logger) + return process, output_queue, threads + + def kill_process_tree(self, process: subprocess.Popen): + try: + parent = psutil.Process(process.pid) + children = parent.children(recursive=True) + + for child in children: + try: + child.kill() + except psutil.NoSuchProcess: + pass + + parent.kill() + gone, alive = psutil.wait_procs(children + [parent], timeout=3) + + for p in alive: + try: + p.kill() + except psutil.NoSuchProcess: + pass + + except psutil.NoSuchProcess: + pass + + def cleanup(self, process_names: List[str]): + if self.client_process: + self.kill_process_tree(self.client_process) + self.client_process = None + + if self.server_process: + self.kill_process_tree(self.server_process) + self.server_process = None + + for proc in psutil.process_iter(["pid", "name", "cmdline"]): + try: + cmdline = " ".join(proc.cmdline()) + if any(name in cmdline for name in process_names): + proc.kill() + except (psutil.NoSuchProcess, psutil.AccessDenied): + continue + + +class ExperimentRunner: + def __init__(self): + self.process_manager = ProcessManager() + self.logger = logging.getLogger(__name__) + + def wait_for_server(self, port: int, timeout: int = 300) -> bool: + start_time = time.time() + + while time.time() - start_time < timeout: + try: + response = requests.get(f"http://localhost:{port}/health") + if response.status_code == 200: + self.logger.debug(f"Server ready on port {port}") + return True + except requests.RequestException: + time.sleep(2) + return False + + def run_task(self, config: TaskConfig) -> TaskResult: + start_time = time.time() + client_output = [] + + try: + if not config.server_type: + config.server_type = detect_server_type(config.server_cmd) + + server_config = SERVER_DEFAULTS.get(config.server_type) + if not server_config: + raise ValueError(f"Unknown server type: {config.server_type}") + + port = extract_port_from_command(config.server_cmd, config.server_type) + + self.process_manager.cleanup(server_config.process_names) + + self.logger.debug(f"Starting server: {config.name}") + self.process_manager.server_process, _, server_threads = ( + self.process_manager.start_process(config.server_cmd, "SERVER") + ) + + if not self.wait_for_server(port): + raise TimeoutError("Server startup timeout") + + time.sleep(10) + + self.logger.debug("Starting client") + self.process_manager.client_process, output_queue, client_threads = ( + 
self.process_manager.start_process(config.client_cmd, "CLIENT") + ) + + returncode = self.process_manager.client_process.wait() + + while True: + try: + line = output_queue.get_nowait() + client_output.append(line) + except queue.Empty: + break + + if returncode != 0: + raise RuntimeError(f"Client failed with code {returncode}") + + # Parse and format the output + full_output = "\n".join(client_output) + formatted_output = parse_key_info(full_output) + + return TaskResult( + name=config.name, + success=True, + output=formatted_output, + runtime=time.time() - start_time, + timestamp=datetime.now().isoformat(), + ) + + except Exception as e: + return TaskResult( + name=config.name, + success=False, + output=str(e), + runtime=time.time() - start_time, + timestamp=datetime.now().isoformat(), + ) + + finally: + if config.server_type in SERVER_DEFAULTS: + self.process_manager.cleanup( + SERVER_DEFAULTS[config.server_type].process_names + ) + time.sleep(10) + + +def load_config(config_path: str) -> List[TaskConfig]: + with open(config_path, "r") as f: + config_data = yaml.safe_load(f) + + configs = [] + for idx, entry in enumerate(config_data.get("tasks", [])): + if not isinstance(entry, dict): + raise ValueError(f"Invalid entry at index {idx}") + + config = TaskConfig( + server_cmd=entry.get("server_cmd"), + client_cmd=entry.get("client_cmd"), + name=entry.get("name", f"task-{idx+1}"), + server_type=entry.get("server_type"), + ) + + if not config.server_cmd or not config.client_cmd: + raise ValueError(f"Missing commands in {config.name}") + + configs.append(config) + + return configs + + +def setup_logging(debug: bool = False): + level = logging.DEBUG if debug else logging.INFO + logging.basicConfig( + level=level, + format="%(asctime)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler(), logging.FileHandler("experiment.log")], + ) + + +def format_results(results: List[TaskResult]) -> str: + """Format experiment results in Markdown for GitHub step summary.""" + output = ["# Experiment Results\n"] + + for result in results: + output.append(f"## {result.name}") + output.append(f"**Status**: {'✅ Success' if result.success else '❌ Failed'}") + output.append(f"**Runtime**: {result.runtime:.2f} seconds") + output.append(f"**Timestamp**: {result.timestamp}") + output.append("\n**Output**:\n```") + output.append(result.output) + output.append("```\n") + + return "\n".join(output) + + +def write_in_github_step_summary(results: List[TaskResult]): + """Write formatted results to GitHub step summary.""" + if not os.environ.get("GITHUB_STEP_SUMMARY"): + logging.warning("GITHUB_STEP_SUMMARY environment variable not set") + return + + formatted_content = format_results(results) + with open(os.environ["GITHUB_STEP_SUMMARY"], "a") as f: + f.write(formatted_content) + + +def main(): + parser = argparse.ArgumentParser(description="Experiment Runner") + parser.add_argument( + "--config", type=str, required=True, help="Path to YAML config file" + ) + parser.add_argument("--debug", action="store_true", help="Enable debug output") + args = parser.parse_args() + + setup_logging(args.debug) + logger = logging.getLogger(__name__) + results = [] + + try: + configs = load_config(args.config) + runner = ExperimentRunner() + + for config in configs: + logger.info(f"Running {config.name}") + result = runner.run_task(config) + results.append(result) + + write_in_github_step_summary(results) + except Exception as e: + logger.error(f"Error: {e}") + raise + + +if __name__ == "__main__": + main() diff --git 
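Note: load_config() above expects a "tasks" list whose entries carry server_cmd, client_cmd, and an optional name/server_type, matching the YAML files added under test/srt/configs. A small parsing sketch using an inline copy of the ShareGPT config entry (requires PyYAML):

    import yaml

    sample = """
    tasks:
      - name: sglang-benchmark
        server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache
        client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --request-rate 16
    """

    for idx, entry in enumerate(yaml.safe_load(sample)["tasks"]):
        name = entry.get("name", f"task-{idx+1}")   # same default naming as load_config()
        print(name, "->", entry["server_cmd"])
        print(name, "->", entry["client_cmd"])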
a/test/srt/run_suite.py b/test/srt/run_suite.py index 89fdacfa231..7737951f824 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -15,6 +15,7 @@ "test_double_sparsity.py", "test_embedding_openai_server.py", "test_eval_accuracy_mini.py", + "test_fused_moe.py", "test_get_weights_by_name.py", "test_gguf.py", "test_input_embeddings.py", diff --git a/test/srt/test_bench_serving.py b/test/srt/test_bench_serving.py index 34a7b6c9670..b882f12f9df 100644 --- a/test/srt/test_bench_serving.py +++ b/test/srt/test_bench_serving.py @@ -125,7 +125,7 @@ def test_online_latency_default(self): if is_in_ci(): write_github_step_summary( f"### test_online_latency_default\n" - f'median_e2e_latency_ms : {res["median_e2e_latency_ms"]:.2f} token/s\n' + f'median_e2e_latency_ms : {res["median_e2e_latency_ms"]:.2f} ms\n' ) self.assertLess(res["median_e2e_latency_ms"], 12000) self.assertLess(res["median_ttft_ms"], 86) diff --git a/test/srt/test_fused_moe.py b/test/srt/test_fused_moe.py new file mode 100644 index 00000000000..7b50c551a82 --- /dev/null +++ b/test/srt/test_fused_moe.py @@ -0,0 +1,126 @@ +import unittest + +import torch +from vllm.model_executor.layers.fused_moe import fused_moe as fused_moe_vllm + +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe + + +class TestFusedMOE(unittest.TestCase): + NUM_EXPERTS = [8, 64] + TOP_KS = [2, 6] + + def torch_naive_moe(self, a, w1, w2, score, topk): + B, D = a.shape + a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + topk_weight = topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + for i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[ + i + ].transpose(0, 1) + return ( + out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype) + ).sum(dim=1) + + def _test_case(self, m, n, k, e, topk, dtype, use_fp8_w8a8=False): + if use_fp8_w8a8: + # AssertionError: fp8e4nv data type is not supported on CUDA arch < 89 + capability = torch.cuda.get_device_capability() + if not (capability[0] >= 9 or capability == (8, 9)): + return + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + w1 = w1.to(torch.float8_e4m3fn) + w2 = w2.to(torch.float8_e4m3fn) + score = torch.randn((m, e), device="cuda", dtype=dtype) + + w1_scale = torch.randn(e, dtype=torch.float32, device="cuda") + w2_scale = torch.randn(e, dtype=torch.float32, device="cuda") + a1_scale = torch.randn(1, dtype=torch.float32, device="cuda") + a2_scale = torch.randn(1, dtype=torch.float32, device="cuda") + + sglang_output = fused_moe( + a, + w1, + w2, + score, + topk, + renormalize=False, + use_fp8_w8a8=True, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + ) + + vllm_output = fused_moe_vllm( + a, + w1, + w2, + score, + topk, + renormalize=False, + use_fp8_w8a8=True, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + ) + + torch.testing.assert_close(sglang_output, vllm_output, atol=2e-2, rtol=0) + + else: + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, 
n), device="cuda", dtype=dtype) / 10 + score = torch.randn((m, e), device="cuda", dtype=dtype) + + triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False) + torch_output = self.torch_naive_moe(a, w1, w2, score, topk) + torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) + + def test_various_configurations(self): + m_values = [1, 33, 64, 222, 1024 * 128] + n_values = [128, 1024, 2048] + k_values = [128, 511, 1024] + dtypes = [torch.float16, torch.bfloat16] + fp8_modes = [False, True] + + for m in m_values: + for n in n_values: + for k in k_values: + for e in self.NUM_EXPERTS: + for topk in self.TOP_KS: + for dtype in dtypes: + for use_fp8_w8a8 in fp8_modes: + with self.subTest( + m=m, + n=n, + k=k, + e=e, + topk=topk, + dtype=dtype, + fp8=use_fp8_w8a8, + ): + self._test_case( + m, + n, + k, + e, + topk, + dtype, + use_fp8_w8a8=use_fp8_w8a8, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_json_constrained.py b/test/srt/test_json_constrained.py index 28acdabd9d0..adb5c18fbe2 100644 --- a/test/srt/test_json_constrained.py +++ b/test/srt/test_json_constrained.py @@ -1,5 +1,6 @@ """ -python3 -m unittest test_json_constrained.TestJSONConstrained.test_json_generate +python3 -m unittest test_json_constrained.TestJSONConstrainedOutlinesBackend.test_json_generate +python3 -m unittest test_json_constrained.TestJSONConstrainedXGrammarBackend.test_json_generate """ import json @@ -11,38 +12,50 @@ from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( - DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, popen_launch_server, ) +def setup_class(cls, backend: str, disable_overlap: bool): + cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.json_schema = json.dumps( + { + "type": "object", + "properties": { + "name": {"type": "string", "pattern": "^[\\w]+$"}, + "population": {"type": "integer"}, + }, + "required": ["name", "population"], + } + ) + + other_args = [ + "--max-running-requests", + "10", + "--grammar-backend", + backend, + ] + + if disable_overlap: + other_args += ["--disable-overlap-schedule"] + + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, + ) + + class TestJSONConstrainedOutlinesBackend(unittest.TestCase): @classmethod def setUpClass(cls): - cls.model = DEFAULT_MODEL_NAME_FOR_TEST - cls.base_url = DEFAULT_URL_FOR_TEST - cls.json_schema = json.dumps( - { - "type": "object", - "properties": { - "name": {"type": "string", "pattern": "^[\\w]+$"}, - "population": {"type": "integer"}, - }, - "required": ["name", "population"], - } - ) - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=300, - other_args=[ - "--max-running-requests", - "10", - "--grammar-backend", - "outlines", - ], - ) + setup_class(cls, backend="outlines", disable_overlap=False) + cls.check_jump_forward = False @classmethod def tearDownClass(cls): @@ -82,13 +95,6 @@ def run_decode(self, json_schema, return_logprob=False, top_logprobs_num=0, n=1) self.assertIsInstance(js_obj["name"], str) self.assertIsInstance(js_obj["population"], int) - # Make sure jump forward is triggered - # NOTE: This is skipped because overlap scheduler does not support jump forward - # self.assertGreater( - # ret["meta_info"]["completion_tokens"], - # ret["meta_info"]["completion_tokens_wo_jump_forward"], - # ) - def 
test_json_generate(self): self.run_decode(json_schema=self.json_schema) @@ -126,32 +132,18 @@ def test_mix_json_and_other(self): list(executor.map(self.run_decode, json_schemas)) +class TestJumpForwardOutlinesBackend(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_class(cls, backend="outlines", disable_overlap=True) + cls.check_jump_forward = True + + class TestJSONConstrainedXGrammarBackend(TestJSONConstrainedOutlinesBackend): @classmethod def setUpClass(cls): - cls.model = DEFAULT_MODEL_NAME_FOR_TEST - cls.base_url = DEFAULT_URL_FOR_TEST - cls.json_schema = json.dumps( - { - "type": "object", - "properties": { - "name": {"type": "string"}, - "population": {"type": "integer"}, - }, - "required": ["name", "population"], - } - ) - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=300, - other_args=[ - "--max-running-requests", - "10", - "--grammar-backend", - "xgrammar", - ], - ) + setup_class(cls, backend="xgrammar", disable_overlap=False) + cls.check_jump_forward = False if __name__ == "__main__": diff --git a/test/srt/test_moe_ep.py b/test/srt/test_moe_ep.py new file mode 100644 index 00000000000..4d9fd435edb --- /dev/null +++ b/test/srt/test_moe_ep.py @@ -0,0 +1,113 @@ +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_process_tree +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import ( + DEFAULT_MLA_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +class TestEpMoE(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--trust-remote-code", + "--tp", + "2", + "--ep-size", + "2", + "--enable-ep-moe", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_mmlu(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + ) + + metrics = run_eval(args) + assert metrics["score"] >= 0.5 + + def test_mgsm_en(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mgsm_en", + num_examples=None, + num_threads=1024, + ) + + metrics = run_eval(args) + assert metrics["score"] >= 0.8 + + +class TestEpMoEFP8(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--trust-remote-code", + "--tp", + "2", + "--ep-size", + "2", + "--enable-ep-moe", + "--quantization", + "fp8", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_mmlu(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + ) + + metrics = run_eval(args) + assert metrics["score"] >= 0.5 + + def test_mgsm_en(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mgsm_en", + num_examples=None, + num_threads=1024, + ) + + metrics = run_eval(args) + assert metrics["score"] >= 0.8 + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_srt_engine.py b/test/srt/test_srt_engine.py index a985c8dda9e..7479b646837 100644 
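Note: the new EP-MoE tests reuse the same evaluation pattern as the router tests: build a SimpleNamespace of eval arguments, run it against an already-launched server, and assert a score threshold. A hedged standalone sketch (the URL and model name are placeholders; the tests use DEFAULT_URL_FOR_TEST and DEFAULT_MLA_MODEL_NAME_FOR_TEST):

    from types import SimpleNamespace

    from sglang.test.run_eval import run_eval  # same helper the new tests import

    args = SimpleNamespace(
        base_url="http://127.0.0.1:30000",          # placeholder server address
        model="meta-llama/Llama-3.1-8B-Instruct",   # placeholder model name
        eval_name="mmlu",
        num_examples=64,
        num_threads=32,
    )

    metrics = run_eval(args)
    assert metrics["score"] >= 0.5   # threshold used by the EP-MoE MMLU check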
--- a/test/srt/test_srt_engine.py +++ b/test/srt/test_srt_engine.py @@ -188,7 +188,7 @@ def test_8_engine_offline_throughput(self): ) bench_args = BenchArgs(num_prompts=10) result = throughput_test(server_args=server_args, bench_args=bench_args) - self.assertGreater(result["total_throughput"], 3500) + self.assertGreater(result["total_throughput"], 3000) if __name__ == "__main__": diff --git a/test/srt/test_torch_compile_moe.py b/test/srt/test_torch_compile_moe.py index 89d4ed6bdf9..fb78dd7f4b8 100644 --- a/test/srt/test_torch_compile_moe.py +++ b/test/srt/test_torch_compile_moe.py @@ -14,7 +14,7 @@ ) -class TestTorchCompile(unittest.TestCase): +class TestTorchCompileMoe(unittest.TestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST @@ -23,7 +23,7 @@ def setUpClass(cls): cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=["--enable-torch-compile", "--torch-compile-max-bs", "1"], + other_args=["--enable-torch-compile", "--torch-compile-max-bs", "8"], ) @classmethod diff --git a/test/srt/test_triton_attention_kernels.py b/test/srt/test_triton_attention_kernels.py index 44abfd61bd7..048d27c5658 100644 --- a/test/srt/test_triton_attention_kernels.py +++ b/test/srt/test_triton_attention_kernels.py @@ -182,6 +182,7 @@ def _test_decode_attention_once(self, B, H_Q, H_KV, D): seq_len = 10 # This represents the number of tokens already in the sequence total_tokens = B * seq_len sm_scale = 1.0 / (D**0.5) + num_kv_splits = 8 # q represents the new token being generated, one per batch q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda") @@ -195,12 +196,11 @@ def _test_decode_attention_once(self, B, H_Q, H_KV, D): req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len) b_req_idx = torch.arange(B, device="cuda") - b_start_loc = torch.arange(0, total_tokens, seq_len, device="cuda") b_seq_len = torch.full((B,), seq_len, device="cuda") attn_logits = torch.empty( - (H_Q, total_tokens), - dtype=dtype, + (B, H_Q, num_kv_splits, D + 1), + dtype=torch.float32, device="cuda", ) @@ -211,10 +211,9 @@ def _test_decode_attention_once(self, B, H_Q, H_KV, D): o, req_to_token, b_req_idx, - b_start_loc, b_seq_len, attn_logits, - seq_len, + num_kv_splits, sm_scale, ) @@ -235,9 +234,10 @@ def test_decode_attention(self): def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V): dtype = torch.bfloat16 - seq_len = 10 # This represents the number of tokens already in the sequence + seq_len = 128 # This represents the number of tokens already in the sequence total_tokens = B * seq_len sm_scale = 1.0 / (D**0.5) + num_kv_splits = 8 # q represents the new token being generated, one per batch q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda") @@ -247,17 +247,16 @@ def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V): v_buffer = torch.randn(total_tokens, H_KV, D_V, dtype=dtype, device="cuda") # o will have the same shape as q - o = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda") - o_grouped = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda") + o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda") + o_grouped = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda") req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len) b_req_idx = torch.arange(B, device="cuda") - b_start_loc = torch.arange(0, total_tokens, seq_len, device="cuda") b_seq_len = torch.full((B,), seq_len, device="cuda") attn_logits = torch.empty( - (H_Q, total_tokens), - dtype=dtype, + (B, H_Q, num_kv_splits, D_V + 1), + 
dtype=torch.float32, device="cuda", ) @@ -268,13 +267,18 @@ def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V): o, req_to_token, b_req_idx, - b_start_loc, b_seq_len, attn_logits, - seq_len, + num_kv_splits, sm_scale, ) + attn_logits1 = torch.empty( + (B, H_Q, num_kv_splits, D_V + 1), + dtype=torch.float32, + device="cuda", + ) + decode_attention_fwd_grouped( q, k_buffer, @@ -282,21 +286,22 @@ def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V): o_grouped, req_to_token, b_req_idx, - b_start_loc, b_seq_len, - attn_logits, - seq_len, + attn_logits1, + num_kv_splits, sm_scale, ) cos_sim = torch.nn.functional.cosine_similarity( o.flatten(), o_grouped.flatten(), dim=0 ) + print(cos_sim.item()) self.assertTrue(cos_sim.item() > 0.99) self.assertTrue(torch.allclose(o, o_grouped, atol=3e-2)) def test_grouped_decode_attention(self): configs = [ + (2, 16, 16, 64, 64), (2, 16, 1, 64, 64), (2, 64, 1, 13, 13), (2, 128, 1, 80, 80),
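Note: the updated decode-attention tests above replace the (H_Q, total_tokens) logits buffer with a per-split float32 buffer of shape (B, H_Q, num_kv_splits, D_V + 1) and compare the grouped kernel against the reference one by cosine similarity plus an allclose check. A CPU-only sketch of that comparison; the meaning of the extra "+ 1" slot is an assumption here, since the kernel itself is not part of this diff:

    import torch

    B, H_Q, D_V, num_kv_splits = 2, 16, 64, 8

    # Shape of the intermediate buffer the new tests allocate for the split-KV kernels;
    # the extra slot presumably holds a per-split accumulator (assumption, not shown in the diff).
    attn_logits = torch.empty(B, H_Q, num_kv_splits, D_V + 1, dtype=torch.float32)

    # Stand-ins for the two kernel outputs being compared in the tests.
    o = torch.randn(B, H_Q, D_V)
    o_grouped = o + 1e-3 * torch.randn_like(o)

    cos_sim = torch.nn.functional.cosine_similarity(o.flatten(), o_grouped.flatten(), dim=0)
    print(cos_sim.item())
    assert cos_sim.item() > 0.99
    assert torch.allclose(o, o_grouped, atol=3e-2)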