Skip to content

Commit

Permalink
Add more support for intel Gaudi accelerators (#2357)
Browse files Browse the repository at this point in the history
  • Loading branch information
YangQun1 authored Dec 6, 2024
1 parent 34b364e commit 37ee906
Show file tree
Hide file tree
Showing 8 changed files with 88 additions and 14 deletions.
16 changes: 13 additions & 3 deletions examples/runtime/engine/offline_batch_inference.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import argparse
import dataclasses

import sglang as sgl
from sglang.srt.server_args import ServerArgs


def main():
def main(
server_args: ServerArgs,
):
# Sample prompts.
prompts = [
"Hello, my name is",
Expand All @@ -13,7 +19,7 @@ def main():
sampling_params = {"temperature": 0.8, "top_p": 0.95}

# Create an LLM.
llm = sgl.Engine(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")
llm = sgl.Engine(**dataclasses.asdict(server_args))

outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
Expand All @@ -25,4 +31,8 @@ def main():
# The __main__ condition is necessary here because we use "spawn" to create subprocesses
# Spawn starts a fresh program every time, if there is no __main__, it will run into infinite loop to keep spawning processes from sgl.Engine
if __name__ == "__main__":
main()
parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
args = parser.parse_args()
server_args = ServerArgs.from_cli_args(args)
main(server_args)
2 changes: 2 additions & 0 deletions python/sglang/srt/layers/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,5 +111,7 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
sampled_index = torch.multinomial(probs_sort, num_samples=1)
# int32 range is enough to represent the token ids
probs_idx = probs_idx.to(torch.int32)
batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
return batch_next_token_ids
6 changes: 3 additions & 3 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -993,7 +993,7 @@ def process_batch_result(self, batch: ScheduleBatch, result):
self.process_batch_result_prefill(batch, result)
elif batch.forward_mode.is_dummy_first():
batch.next_batch_sampling_info.update_regex_vocab_mask()
torch.cuda.current_stream().synchronize()
torch.get_device_module(self.device).current_stream().synchronize()
batch.next_batch_sampling_info.sampling_info_done.set()

def process_batch_result_prefill(self, batch: ScheduleBatch, result):
Expand Down Expand Up @@ -1055,7 +1055,7 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result):

if batch.next_batch_sampling_info:
batch.next_batch_sampling_info.update_regex_vocab_mask()
torch.cuda.current_stream().synchronize()
torch.get_device_module(self.device).current_stream().synchronize()
batch.next_batch_sampling_info.sampling_info_done.set()

else: # embedding or reward model
Expand Down Expand Up @@ -1130,7 +1130,7 @@ def process_batch_result_decode(self, batch: ScheduleBatch, result):

if batch.next_batch_sampling_info:
batch.next_batch_sampling_info.update_regex_vocab_mask()
torch.cuda.current_stream().synchronize()
torch.get_device_module(self.device).current_stream().synchronize()
batch.next_batch_sampling_info.sampling_info_done.set()

self.stream_output(batch.reqs)
Expand Down
11 changes: 6 additions & 5 deletions python/sglang/srt/managers/tp_worker_overlap_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,13 @@
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_compiler_backend
from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)


@torch.compile(dynamic=True)
@torch.compile(dynamic=True, backend=get_compiler_backend())
def resolve_future_token_ids(input_ids, future_token_ids_map):
input_ids[:] = torch.where(
input_ids < 0,
Expand Down Expand Up @@ -73,7 +74,7 @@ def __init__(
# Launch threads
self.input_queue = Queue()
self.output_queue = Queue()
self.forward_stream = torch.cuda.Stream()
self.forward_stream = torch.get_device_module(self.device).Stream()
self.forward_thread = threading.Thread(
target=self.forward_thread_func,
)
Expand All @@ -97,7 +98,7 @@ def get_memory_pool(self):

def forward_thread_func(self):
try:
with torch.cuda.stream(self.forward_stream):
with torch.get_device_module(self.device).stream(self.forward_stream):
self.forward_thread_func_()
except Exception:
traceback = get_exception_traceback()
Expand All @@ -122,7 +123,7 @@ def forward_thread_func_(self):

# Create event
self.launch_done = threading.Event()
copy_done = torch.cuda.Event()
copy_done = torch.get_device_module(self.device).Event()

# Resolve future tokens in the input
input_ids = model_worker_batch.input_ids
Expand Down Expand Up @@ -190,7 +191,7 @@ def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
)

# A cuda stream sync here to avoid the cuda illegal memory access error.
torch.cuda.current_stream().synchronize()
torch.get_device_module(self.device).current_stream().synchronize()

# Push a new batch to the queue
self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
Expand Down
6 changes: 5 additions & 1 deletion python/sglang/srt/mem_cache/memory_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import torch

from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.utils import get_compiler_backend

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -129,6 +130,9 @@ def alloc(self, need_size: int):
return select_index.to(self.device, non_blocking=True)

def free(self, free_index: torch.Tensor):
if free_index.numel() == 0:
return

if self.is_not_in_free_group:
self.free_slots = torch.concat((self.free_slots, free_index.cpu()))
else:
Expand Down Expand Up @@ -234,7 +238,7 @@ def set_kv_buffer(

# This compiled version is slower in the unit test
# python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
@torch.compile(dynamic=True)
@torch.compile(dynamic=True, backend=get_compiler_backend())
def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
dst_1[loc] = src_1.to(dtype).view(store_dtype)
dst_2[loc] = src_2.to(dtype).view(store_dtype)
Expand Down
4 changes: 2 additions & 2 deletions python/sglang/srt/models/commandr.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import set_weight_attrs
from sglang.srt.utils import get_compiler_backend, set_weight_attrs


@torch.compile
@torch.compile(backend=get_compiler_backend())
def layer_norm_func(hidden_states, weight, variance_epsilon):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
Expand Down
7 changes: 7 additions & 0 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from sglang.srt.hf_transformers_utils import check_gguf_file
from sglang.srt.utils import (
get_amdgpu_memory_capacity,
get_hpu_memory_capacity,
get_nvgpu_memory_capacity,
is_flashinfer_available,
is_hip,
Expand Down Expand Up @@ -158,6 +159,8 @@ def __post_init__(self):
gpu_mem = get_amdgpu_memory_capacity()
elif torch.cuda.is_available():
gpu_mem = get_nvgpu_memory_capacity()
elif self.device == "hpu":
gpu_mem = get_hpu_memory_capacity()
else:
# GPU memory is not known yet or no GPU is available.
gpu_mem = None
Expand Down Expand Up @@ -194,6 +197,10 @@ def __post_init__(self):
self.cuda_graph_max_bs = 160

# Choose kernel backends
if self.device == "hpu":
self.attention_backend = "torch_native"
self.sampling_backend = "pytorch"

if self.attention_backend is None:
self.attention_backend = (
"flashinfer" if is_flashinfer_available() else "triton"
Expand Down
50 changes: 50 additions & 0 deletions python/sglang/srt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,18 @@ def get_available_gpu_memory(device, gpu_id, distributed=False):
total_gpu_memory = torch.xpu.get_device_properties(gpu_id).total_memory
free_gpu_memory = total_gpu_memory - used_memory

elif device == "hpu":
num_gpus = torch.hpu.device_count()
assert gpu_id < num_gpus

if torch.hpu.current_device() != gpu_id:
print(
f"WARNING: current device is not {gpu_id}, but {torch.hpu.current_device()}, ",
"which may cause useless memory allocation for torch HPU context.",
)

free_gpu_memory, total_gpu_memory = torch.hpu.mem_get_info()

if distributed:
tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
torch.device(device, gpu_id)
Expand Down Expand Up @@ -939,6 +951,37 @@ def get_nvgpu_memory_capacity():
)


def get_hpu_memory_capacity():
try:
# Run hl-smi and capture the output
result = subprocess.run(
["hl-smi --query | grep 'Total'"],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
shell=True,
text=True,
)

if result.returncode != 0:
raise RuntimeError(f"hl-smi error: {result.stderr.strip()}")

# Parse the output to extract memory values in MiB
memory_values = [
float(mem.split(" ")[-2]) for mem in result.stdout.strip().split("\n")
]

if not memory_values:
raise ValueError("No GPU memory values found.")

# Return the minimum memory value
return min(memory_values)

except FileNotFoundError:
raise RuntimeError(
"hl-smi not found. Ensure Habana drivers are installed and accessible."
)


# Copy from pytorch and OpenRLHF to allow creating multiple main groups.
# https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
# https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/utils/distributed_util.py
Expand Down Expand Up @@ -1062,6 +1105,13 @@ def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
return major, minor


def get_compiler_backend() -> str:
if hasattr(torch, "hpu") and torch.hpu.is_available():
return "hpu_backend"

return "inductor"


sglang_lib = Library("sglang", "FRAGMENT") # noqa


Expand Down

0 comments on commit 37ee906

Please sign in to comment.