Merge branch 'main' into granite
merrymercy authored Dec 11, 2024
2 parents 268f6a6 + ece7249 commit 580cbfa
Showing 3 changed files with 192 additions and 116 deletions.
71 changes: 66 additions & 5 deletions python/sglang/srt/model_parallel.py
@@ -2,25 +2,69 @@
Common utilities for torch model parallelism.
"""

-from typing import Optional
from typing import Optional, Sequence

import torch
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh

try:
-    from torch.distributed.tensor import DTensor, Shard
import torch.distributed.tensor as dt
except ImportError:
# torch 2.4 or older
-    from torch.distributed._tensor import DTensor, Shard
import torch.distributed._tensor as dt

from torch.distributed._functional_collectives import AsyncCollectiveTensor
from torch.distributed.tensor.parallel import (
ColwiseParallel,
RowwiseParallel,
parallelize_module,
)


def _shard_tensor(
full_tensor: torch.Tensor,
device_mesh: DeviceMesh,
placements: Sequence[dt.Shard],
) -> "dt.DTensor":
"""
Locally shards a full tensor based on indicated sharding arrangement, and
returns a DTensor containing the local shard.
.. warning:: This is a private API that is subject to change. It skips the
communication otherwise required by `distribute_tensor`. It is only
applicable to cases where all ranks have the same `full_tensor`. For
example, in distributed inference all ranks load from the same
checkpoint. This API will not check for data equality between ranks, it
is thus user's responsibility to ensure the `full_tensor` is the same
across ranks.
Args:
full_tensor (torch.Tensor): the full tensor to be sharded.
device_mesh (:class:`DeviceMesh`): DeviceMesh to place the
DTensor. Must have same dimension as the number of placements.
placements (Sequence[:class:`Shard`]): the placements that
describes how to place the local tensor on DeviceMesh.
Returns:
A :class:`DTensor` object with the shard as its local tensor.
Examples:
>>> # xdoctest: +SKIP("need world_size and rank")
>>> device_mesh = dist.init_device_mesh("cuda", (world_size,))
>>> full_tensor = torch.arange(world_size, device=f"cuda:{rank}")
>>> dtensor = _shard_tensor(full_tensor, device_mesh, [Shard(1)])
"""
shape, offset = dt._utils.compute_local_shape_and_global_offset(
full_tensor.shape, device_mesh, placements
)
slices = [
slice(cur_offset, cur_offset + cur_shape)
for cur_shape, cur_offset in zip(shape, offset)
]
local_tensor = full_tensor[slices]
return dt.DTensor.from_local(local_tensor, device_mesh, placements)
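
A minimal single-process sketch of the idea behind `_shard_tensor`: when every
rank already holds an identical full tensor, each rank can slice out its own
shard locally instead of paying for a scatter. The `world_size` and `rank`
values below are hypothetical stand-ins for a real distributed environment.

import torch

world_size, rank = 4, 1                  # pretend we are rank 1 of 4
full = torch.arange(8.0).reshape(2, 4)   # identical on every rank

# Shard(1): split dim 1 into world_size contiguous chunks, one per rank
chunk = full.size(1) // world_size
local = full[:, rank * chunk : (rank + 1) * chunk]
print(local)  # tensor([[1.], [5.]]); what DTensor.from_local would wrap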


class ColwiseParallelSharded(ColwiseParallel):
"""
A version of ColwiseParallel where the local weight has been already
@@ -34,7 +78,7 @@ def _partition_linear_fn(self, name, module, device_mesh):
        # means Colwise: since Linear computes input @ weight^T + bias,
        # Shard(0) on the weight makes the effective weight^T Shard(1)
for name, param in module.named_parameters():
-            dtensor = DTensor.from_local(param, device_mesh, [Shard(0)])
dtensor = dt.DTensor.from_local(param, device_mesh, [dt.Shard(0)])
dist_param = torch.nn.Parameter(dtensor, requires_grad=False)
module.register_parameter(name, dist_param)

@@ -47,6 +91,23 @@ class RowwiseParallelMaybeWait(RowwiseParallel):
AsyncCollectiveTensor and custom ops, such as `class RMSNorm(CustomOp)`.
"""

def _partition_linear_fn(self, name, module, device_mesh):
        # Rowwise: shard the weight to Shard(1) and replicate the bias. Since
        # nn.Linear computes input @ weight^T + bias, Shard(1) on the weight
        # corresponds to Shard(0) on the effective weight^T.
module.register_parameter(
"weight",
nn.Parameter(_shard_tensor(module.weight, device_mesh, [dt.Shard(1)])),
)
if getattr(module, "bias", None) is not None:
# The Linear module has bias
module.register_parameter(
"bias",
nn.Parameter(
dt.distribute_tensor(module.bias, device_mesh, [dt.Replicate()])
),
)
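
Why Shard(1) for the rowwise weight and Replicate() for the bias: nn.Linear
computes input @ weight^T + bias, so each rank can multiply its slice of the
input features by its column-slice of the weight, and the partial products
sum to the full output. A self-contained sketch with plain torch (no process
group; the two "ranks" are simulated by slicing):

import torch

x = torch.randn(3, 8)   # batch of 3, 8 input features
w = torch.randn(5, 8)   # nn.Linear(8, 5).weight, shape (out, in)
b = torch.randn(5)

full = x @ w.t() + b

partial0 = x[:, :4] @ w[:, :4].t()   # "rank 0" slice of the contraction
partial1 = x[:, 4:] @ w[:, 4:].t()   # "rank 1" slice of the contraction
reduced = partial0 + partial1 + b    # stands in for the all-reduce

print(torch.allclose(full, reduced, atol=1e-5))  # True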

@staticmethod
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
outputs = super(
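
The class docstring above notes that the rowwise all-reduce output can arrive
as an AsyncCollectiveTensor, which custom ops such as `RMSNorm(CustomOp)` do
not expect. A hedged sketch of the wait pattern the folded method applies;
`maybe_wait` is an illustrative helper, not sglang's API:

from torch.distributed._functional_collectives import AsyncCollectiveTensor

def maybe_wait(tensor):
    # An AsyncCollectiveTensor defers the collective; .wait() forces
    # completion so downstream ops see a plain materialized tensor.
    if isinstance(tensor, AsyncCollectiveTensor):
        return tensor.wait()
    return tensor
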
162 changes: 110 additions & 52 deletions rust/src/router.rs
@@ -106,28 +106,6 @@ pub enum PolicyConfig {
},
}

-fn get_text_from_request(body: &Bytes, route: &str) -> String {
-    // convert body to json
-    let json = serde_json::from_slice::<serde_json::Value>(body).unwrap();
-
-    if route == "generate" {
-        // get the "text" field
-        let text = json.get("text").and_then(|t| t.as_str()).unwrap_or("");
-        return text.to_string();
-    } else if route == "v1/chat/completions" {
-        // get the messages field as raw text
-        if let Some(messages) = json.get("messages") {
-            // Convert messages back to a string, preserving all JSON formatting
-            return serde_json::to_string(messages).unwrap_or_default();
-        }
-    } else if route == "v1/completions" {
-        let prompt = json.get("prompt").and_then(|t| t.as_str()).unwrap_or("");
-        return prompt.to_string();
-    }
-
-    return "".to_string();
-}

impl Router {
pub fn new(worker_urls: Vec<String>, policy_config: PolicyConfig) -> Result<Self, String> {
// Wait until all workers are healthy
@@ -204,20 +182,6 @@ impl Router {
})
}

-    pub fn get_first(&self) -> Option<String> {
-        match self {
-            Router::RoundRobin { worker_urls, .. }
-            | Router::Random { worker_urls }
-            | Router::CacheAware { worker_urls, .. } => {
-                if worker_urls.read().unwrap().is_empty() {
-                    None
-                } else {
-                    Some(worker_urls.read().unwrap()[0].clone())
-                }
-            }
-        }
-    }

fn wait_for_healthy_workers(
worker_urls: &[String],
timeout_secs: u64,
@@ -271,14 +235,76 @@ impl Router {
}
}

-    pub async fn dispatch(
fn select_first_worker(&self) -> Result<String, String> {
match self {
Router::RoundRobin { worker_urls, .. }
| Router::Random { worker_urls }
| Router::CacheAware { worker_urls, .. } => {
if worker_urls.read().unwrap().is_empty() {
Err("No workers are available".to_string())
} else {
Ok(worker_urls.read().unwrap()[0].clone())
}
}
}
}

async fn send_request(
&self,
client: &reqwest::Client,
-        req: HttpRequest,
-        body: Bytes,
worker_url: &str,
route: &str,
) -> HttpResponse {
-        let text = get_text_from_request(&body, route);
match client.get(format!("{}{}", worker_url, route)).send().await {
Ok(res) => {
let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);

match res.bytes().await {
Ok(body) => HttpResponse::build(status).body(body.to_vec()),
Err(e) => HttpResponse::InternalServerError()
.body(format!("Failed to read response body: {}", e)),
}
}
Err(e) => HttpResponse::InternalServerError().body(format!(
"Failed to send request to worker {}: {}",
worker_url, e
)),
}
}

pub async fn route_to_first(&self, client: &reqwest::Client, route: &str) -> HttpResponse {
match self.select_first_worker() {
Ok(worker_url) => self.send_request(client, &worker_url, route).await,
Err(e) => HttpResponse::InternalServerError().body(e),
}
}
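
Client-side, `route_to_first` backs simple pass-through endpoints. A sketch
of calling one through the router, assuming a local deployment on port 30000
and a `/get_model_info` route (both are assumptions about the deployment):

import requests

router = "http://localhost:30000"  # assumed router address
resp = requests.get(f"{router}/get_model_info", timeout=10)
print(resp.status_code, resp.text)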

fn get_text_from_request(&self, body: &Bytes, route: &str) -> String {
// convert body to json
let json = serde_json::from_slice::<serde_json::Value>(body).unwrap();

if route == "generate" {
// get the "text" field
let text = json.get("text").and_then(|t| t.as_str()).unwrap_or("");
return text.to_string();
} else if route == "v1/chat/completions" {
// get the messages field as raw text
if let Some(messages) = json.get("messages") {
// Convert messages back to a string, preserving all JSON formatting
return serde_json::to_string(messages).unwrap_or_default();
}
} else if route == "v1/completions" {
let prompt = json.get("prompt").and_then(|t| t.as_str()).unwrap_or("");
return prompt.to_string();
}

return "".to_string();
}
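
To reason about what the cache-aware policy matches on, here is an
illustrative Python mirror of `get_text_from_request` (a re-implementation
for clarity, not the router's actual code path):

import json

def text_for_route(body: bytes, route: str) -> str:
    payload = json.loads(body)
    if route == "generate":
        text = payload.get("text")
        return text if isinstance(text, str) else ""
    if route == "v1/chat/completions":
        messages = payload.get("messages")
        return json.dumps(messages) if messages is not None else ""
    if route == "v1/completions":
        prompt = payload.get("prompt")
        return prompt if isinstance(prompt, str) else ""
    return ""

print(text_for_route(b'{"text": "hello world"}', "generate"))  # hello world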

// TODO: return Result<String, String> instead of panicking
fn select_generate_worker(&self, body: &Bytes, route: &str) -> String {
let text = self.get_text_from_request(&body, route);

let worker_url = match self {
Router::RoundRobin {
@@ -366,12 +392,23 @@ impl Router {
}
};

worker_url
}
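
The folded body of `select_generate_worker` implements the routing policies;
for CacheAware it combines an approximate radix tree with the running-request
counters. A toy sketch of that decision rule, with a plain dict standing in
for the tree and an assumed match threshold (numbers are illustrative):

def pick_worker(text, cached_prefix, running, threshold=0.5):
    # cached_prefix: {worker_url: longest cached prefix for that worker}
    # running:       {worker_url: in-flight request count}
    best_url, best_len = None, 0
    for url, prefix in cached_prefix.items():
        matched = 0
        for a, b in zip(text, prefix):
            if a != b:
                break
            matched += 1
        if matched > best_len:
            best_url, best_len = url, matched
    if best_url and text and best_len / len(text) >= threshold:
        return best_url                    # good prefix hit: reuse the cache
    return min(running, key=running.get)   # otherwise go to least loaded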

async fn send_generate_request(
&self,
client: &reqwest::Client,
req: &HttpRequest,
body: &Bytes,
route: &str,
worker_url: &str,
) -> HttpResponse {
let is_stream = serde_json::from_slice::<serde_json::Value>(&body)
.map(|v| v.get("stream").and_then(|s| s.as_bool()).unwrap_or(false))
.unwrap_or(false);

let res = match client
.post(format!("{}/{}", worker_url.clone(), route))
.post(format!("{}{}", worker_url, route))
.header(
"Content-Type",
req.headers()
@@ -403,7 +440,7 @@ impl Router {
// Then decrement running queue counter if using CacheAware
if let Router::CacheAware { running_queue, .. } = self {
if let Ok(mut queue) = running_queue.lock() {
-                    if let Some(count) = queue.get_mut(&worker_url) {
if let Some(count) = queue.get_mut(worker_url) {
*count = count.saturating_sub(1);
}
}
@@ -412,7 +449,7 @@
response
} else if let Router::CacheAware { running_queue, .. } = self {
let running_queue = Arc::clone(running_queue);
-            let worker_url = worker_url.clone();
let worker_url = worker_url.to_string();

HttpResponse::build(status)
.insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
@@ -431,7 +468,7 @@
let mut locked_queue = running_queue.lock().unwrap();
let count = locked_queue.get_mut(&worker_url).unwrap();
*count = count.saturating_sub(1);
debug!("streaming is done!!")
debug!("Streaming is done!!")
}
}),
)
@@ -444,7 +481,19 @@
}
}

-    pub async fn add_worker(&self, worker_url: String) -> Result<String, String> {
pub async fn route_generate_request(
&self,
client: &reqwest::Client,
req: &HttpRequest,
body: &Bytes,
route: &str,
) -> HttpResponse {
let worker_url = self.select_generate_worker(&body, route);
self.send_generate_request(client, req, body, route, &worker_url)
.await
}
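
The `is_stream` flag above decides between buffering the worker's reply and
forwarding it as a server-sent event stream. A client-side sketch of the
streaming path, assuming a router on localhost:30000 and the `/generate`
payload shape used by sglang workers (both assumptions):

import requests

payload = {
    "text": "The capital of France is",
    "stream": True,
    "sampling_params": {"max_new_tokens": 16},
}
with requests.post("http://localhost:30000/generate",
                   json=payload, stream=True) as resp:
    for line in resp.iter_lines():
        if line:
            print(line.decode())  # each "data: {...}" SSE chunk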

pub async fn add_worker(&self, worker_url: &str) -> Result<String, String> {
let interval_secs = 10; // check every 10 seconds
let timeout_secs = 300; // 5 minutes

@@ -468,11 +517,11 @@
| Router::CacheAware { worker_urls, .. } => {
info!("Worker {} health check passed", worker_url);
let mut urls = worker_urls.write().unwrap();
-                    if urls.contains(&worker_url) {
if urls.contains(&worker_url.to_string()) {
return Err(format!("Worker {} already exists", worker_url));
}
info!("Added worker: {}", worker_url);
-                    urls.push(worker_url.clone());
urls.push(worker_url.to_string());
}
}

@@ -485,13 +534,16 @@
} = self
{
// Add worker to running queue with initial count of 0
-            running_queue.lock().unwrap().insert(worker_url.clone(), 0);
running_queue
.lock()
.unwrap()
.insert(worker_url.to_string(), 0);

// Add worker to processed queue with initial count of 0
processed_queue
.lock()
.unwrap()
-                .insert(worker_url.clone(), 0);
.insert(worker_url.to_string(), 0);

// Add worker to tree
tree.lock().unwrap().insert(&"".to_string(), &worker_url);
@@ -532,7 +584,7 @@ impl Router {
}
}

-    pub fn remove_worker(&self, worker_url: String) {
pub fn remove_worker(&self, worker_url: &str) {
match self {
Router::RoundRobin { worker_urls, .. }
| Router::Random { worker_urls }
@@ -553,8 +605,14 @@
} = self
{
tree.lock().unwrap().remove_tenant(&worker_url);
-            running_queue.lock().unwrap().remove(&worker_url);
-            processed_queue.lock().unwrap().remove(&worker_url);
running_queue
.lock()
.unwrap()
.remove(&worker_url.to_string());
processed_queue
.lock()
.unwrap()
.remove(&worker_url.to_string());
info!(
"Removed worker from tree and cleaned up queues: {}",
worker_url
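
Operationally, `add_worker` and `remove_worker` back the router's
worker-management endpoints. A hedged sketch of driving them from a client,
assuming `/add_worker` and `/remove_worker` accept a `url` query parameter
(verify against your router version):

import requests

router = "http://localhost:30000"   # assumed router address
worker = "http://localhost:8001"    # assumed new worker URL

r = requests.post(f"{router}/add_worker", params={"url": worker})
print(r.status_code, r.text)        # errors if the worker already exists

r = requests.post(f"{router}/remove_worker", params={"url": worker})
print(r.status_code, r.text)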
