Deepseek V3 support (#364)
* Changing the hard coded datatype to see if it's enough for the model to work

* Picking the upstream moe kernel version

* Make the upstream fix for V3 also work for ROCm V2

* Conditional fnuz dtype

* Requantizing from fn to fnuz

* Requantizing moe as well

* Actually requantizing moe weights

* Conditional requantization and assert on padding in block quant

* Format

---------

Co-authored-by: charlifu <[email protected]>
gshtras and charlifu authored Jan 16, 2025
1 parent 8bd76fb commit c5a9406
Showing 3 changed files with 119 additions and 36 deletions.
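
The common thread across all three files: on ROCm GPUs other than Navi, the hardware FP8 format is float8_e4m3fnuz rather than the OCP float8_e4m3fn used in serialized checkpoints, so the change makes the target dtype conditional and requantizes already-loaded e4m3fn weights. A minimal sketch of the dtype-selection pattern the diffs below repeat (not part of the commit; current_platform and is_navi are the helpers imported in fp8_utils.py in the third diff):

import torch

from vllm.platforms import current_platform
from vllm.utils import is_navi

# CDNA ROCm GPUs (e.g. MI300) implement the "fnuz" FP8 variant in hardware;
# Navi and non-ROCm platforms keep the OCP e4m3fn format.
fp8_dtype = (torch.float8_e4m3fnuz
             if current_platform.is_rocm() and not is_navi()
             else torch.float8_e4m3fn)
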
104 changes: 71 additions & 33 deletions csrc/moe/moe_align_sum_kernels.cu
@@ -55,11 +55,11 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
__syncthreads();

// For each expert we accumulate the token counts from the different threads.
for (int eid = threadIdx.x; eid < num_experts; eid += blockDim.x) {
tokens_cnts[index(num_experts, 0, eid)] = 0;
if (threadIdx.x < num_experts) {
tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
for (int i = 1; i <= blockDim.x; ++i) {
tokens_cnts[index(num_experts, i, eid)] +=
tokens_cnts[index(num_experts, i - 1, eid)];
tokens_cnts[index(num_experts, i, threadIdx.x)] +=
tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
}
}

@@ -83,9 +83,10 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
* For each expert, each thread processes the tokens of the corresponding
* blocks and stores the corresponding expert_id for each block.
*/
for (int eid = threadIdx.x; eid < num_experts; eid += blockDim.x) {
for (int i = cumsum[eid]; i < cumsum[eid + 1]; i += block_size) {
expert_ids[i / block_size] = eid;
if (threadIdx.x < num_experts) {
for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
i += block_size) {
expert_ids[i / block_size] = threadIdx.x;
}
}

@@ -140,11 +141,11 @@ __global__ void moe_align_block_size_global_mem_kernel(
__syncthreads();

// For each expert we accumulate the token counts from the different threads.
for (int eid = threadIdx.x; eid < num_experts; eid += blockDim.x) {
tokens_cnts[index(num_experts, 0, eid)] = 0;
if (threadIdx.x < num_experts) {
tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
for (int i = 1; i <= blockDim.x; ++i) {
tokens_cnts[index(num_experts, i, eid)] +=
tokens_cnts[index(num_experts, i - 1, eid)];
tokens_cnts[index(num_experts, i, threadIdx.x)] +=
tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
}
}

@@ -168,9 +169,10 @@ __global__ void moe_align_block_size_global_mem_kernel(
* For each expert, each thread processes the tokens of the corresponding
* blocks and stores the corresponding expert_id for each block.
*/
for (int eid = threadIdx.x; eid < num_experts; eid += blockDim.x) {
for (int i = cumsum[eid]; i < cumsum[eid + 1]; i += block_size) {
expert_ids[i / block_size] = eid;
if (threadIdx.x < num_experts) {
for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
i += block_size) {
expert_ids[i / block_size] = threadIdx.x;
}
}

@@ -221,25 +223,61 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad) {
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_INTEGRAL_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
// tensors
const int32_t num_thread = WARP_SIZE;
const int32_t shared_mem =
((num_thread + 1) * num_experts + (num_experts + 1)) *
sizeof(int32_t);

// set dynamic shared mem
auto kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;
AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
(void*)kernel, shared_mem));
kernel<<<1, num_thread, shared_mem, stream>>>(
topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(),
num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
topk_ids.numel());
});

// If we have a very large number of experts, we can no longer use shared
// memory.
// TODO(simon): the right solution should be calculating the exact right
// amount of shared memory and using that. The num_experts >= 96 threshold is
// just a temporary solution to unblock Deepseek V3.
if (num_experts >= 96) {
VLLM_DISPATCH_INTEGRAL_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
// tensors
const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);

const int32_t mem_tokens_cnts =
((num_experts + 1) * num_experts) * sizeof(int32_t);
const int32_t mem_cumsum = (num_experts + 1) * sizeof(int32_t);
// allocate global memory
int32_t* tokens_cnts;
int32_t* cumsum;
cudaMalloc(&tokens_cnts, mem_tokens_cnts);
cudaMalloc(&cumsum, mem_cumsum);

auto kernel =
vllm::moe::moe_align_block_size_global_mem_kernel<scalar_t>;
kernel<<<1, num_thread, 0, stream>>>(
topk_ids.data_ptr<scalar_t>(),
sorted_token_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(),
num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
topk_ids.numel(), tokens_cnts, cumsum);
cudaFree(tokens_cnts);
cudaFree(cumsum);
});
} else {
VLLM_DISPATCH_INTEGRAL_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
// tensors
const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
const int32_t shared_mem =
((num_thread + 1) * num_experts + (num_experts + 1)) *
sizeof(int32_t);

// set dynamic shared mem
auto kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;
AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
(void*)kernel, shared_mem));
kernel<<<1, num_thread, shared_mem, stream>>>(
topk_ids.data_ptr<scalar_t>(),
sorted_token_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(),
num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
topk_ids.numel());
});
}
}

void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size]
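
A back-of-the-envelope check on why the dispatch in moe_align_block_size above needs the global-memory fallback: with num_thread = max(num_experts, WARP_SIZE), the dynamic shared memory requested by moe_align_block_size_kernel grows roughly quadratically in the expert count. A sketch in Python, reusing the formula from the launch code (the 64 KiB figure is the per-workgroup LDS size on CDNA GPUs and WARP_SIZE is 64 on ROCm; neither is queried from a device here):

def moe_align_shared_mem_bytes(num_experts: int, warp_size: int = 64) -> int:
    # Same formula as the kernel launch above:
    # ((num_thread + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t)
    num_thread = max(num_experts, warp_size)
    return ((num_thread + 1) * num_experts + (num_experts + 1)) * 4

for n in (64, 96, 128, 256):
    print(n, moe_align_shared_mem_bytes(n))
# 64  ->  16,900 bytes (~16.5 KiB)  fits comfortably
# 96  ->  37,636 bytes (~36.8 KiB)  would still fit, but the conservative >= 96
#                                   cutoff already routes this to global memory
# 128 ->  66,564 bytes (~65.0 KiB)  exceeds a 64 KiB LDS budget
# 256 -> 264,196 bytes (~258 KiB)   Deepseek V3's 256 routed experts; the
#                                   global-memory kernel is the only option
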
38 changes: 37 additions & 1 deletion vllm/model_executor/layers/quantization/fp8.py
@@ -166,6 +166,8 @@ def create_weights(
weight_loader = extra_weight_attrs.get("weight_loader")

if self.block_quant:
assert not envs.VLLM_FP8_PADDING, (
"FP8 weight padding is not supported in block quantization.")
tp_size = get_tensor_model_parallel_world_size()
assert self.quant_config.weight_block_size is not None
block_n, block_k = (
@@ -196,8 +198,9 @@ def create_weights(
layer.output_size_per_partition = output_size_per_partition
layer.orig_dtype = params_dtype

fp8_dtype = torch.float8_e4m3fn
# WEIGHT
weight_dtype = (torch.float8_e4m3fn
weight_dtype = (fp8_dtype
if self.quant_config.is_checkpoint_fp8_serialized else
params_dtype)

@@ -252,6 +255,15 @@ def create_weights(
def process_weights_after_loading(self, layer: Module) -> None:
# Block quant doesn't need to process weights after loading
if self.block_quant:
if current_platform.is_rocm() and not is_navi():
weight, weight_scale, _ = \
normalize_e4m3fn_to_e4m3fnuz(
weight=layer.weight,
weight_scale=layer.weight_scale_inv,
input_scale=layer.input_scale)
layer.weight = Parameter(weight, requires_grad=False)
layer.weight_scale_inv = Parameter(weight_scale,
requires_grad=False)
return
layer.weight = torch.nn.Parameter(layer.weight.data,
requires_grad=False)
@@ -512,6 +524,30 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
def process_weights_after_loading(self, layer: Module) -> None:
# Block quant doesn't need to process weights after loading
if self.block_quant:
if current_platform.is_rocm() and not is_navi():
w13_weight, w13_weight_scale_inv, w13_input_scale = \
normalize_e4m3fn_to_e4m3fnuz(
layer.w13_weight, layer.w13_weight_scale_inv,
layer.w13_input_scale)
w2_weight, w2_weight_scale_inv, w2_input_scale = \
normalize_e4m3fn_to_e4m3fnuz(
layer.w2_weight, layer.w2_weight_scale_inv,
layer.w2_input_scale)
# Reset the parameter
layer.w13_weight = torch.nn.Parameter(w13_weight,
requires_grad=False)
layer.w13_weight_scale_inv = torch.nn.Parameter(
w13_weight_scale_inv, requires_grad=False)
if w13_input_scale is not None:
layer.w13_input_scale = torch.nn.Parameter(
w13_input_scale, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2_weight,
requires_grad=False)
layer.w2_weight_scale_inv = torch.nn.Parameter(
w2_weight_scale_inv, requires_grad=False)
if w2_input_scale is not None:
layer.w2_input_scale = torch.nn.Parameter(
w2_input_scale, requires_grad=False)
return
# If checkpoint is fp16, quantize in place.
if not self.quant_config.is_checkpoint_fp8_serialized:
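
Both the linear and the MoE branches above lean on normalize_e4m3fn_to_e4m3fnuz. The key fact is that float8_e4m3fn (bias 7, max 448) and float8_e4m3fnuz (bias 8, max 240, no negative zero, single NaN encoding 0x80) decode the same bit pattern to values that differ by exactly a factor of two. A minimal sketch of what such a conversion has to do, assuming the real helper follows this standard recipe (the actual vLLM implementation may differ in details):

from typing import Optional, Tuple

import torch


def normalize_e4m3fn_to_e4m3fnuz_sketch(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    assert weight.dtype == torch.float8_e4m3fn
    # Work on a copy of the raw bytes so the caller's tensor is untouched.
    bits = weight.view(torch.int8).clone()
    # 0x80 encodes -0.0 in e4m3fn but NaN in e4m3fnuz, so flush it to zero.
    bits[bits == -128] = 0
    weight_fnuz = bits.view(torch.float8_e4m3fnuz)
    # Under the fnuz exponent bias the same bits decode to half the value,
    # so double the scales to keep weight * scale unchanged after dequant.
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight_fnuz, weight_scale, input_scale

The MoE path applies the same transformation twice, once to w13_weight / w13_weight_scale_inv and once to w2_weight / w2_weight_scale_inv, which is why both sets of parameters are rebuilt wholesale in process_weights_after_loading.
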
13 changes: 11 additions & 2 deletions vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -5,6 +5,9 @@
import triton
import triton.language as tl

from vllm.platforms import current_platform
from vllm.utils import is_navi


def apply_w8a8_block_fp8_linear(
input: torch.Tensor,
@@ -34,10 +37,13 @@ def apply_w8a8_block_fp8_linear(

def input_to_float8(
x: torch.Tensor,
dtype: torch.dtype = torch.float8_e4m3fn
dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""This function quantizes input values to float8 values "
"with tensor-wise quantization."""
if dtype is None:
dtype = (torch.float8_e4m3fnuz if current_platform.is_rocm()
and not is_navi() else torch.float8_e4m3fn)
finfo = torch.finfo(dtype)
min_val, max_val = x.aminmax()
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
@@ -125,7 +131,7 @@ def per_token_group_quant_fp8(
x: torch.Tensor,
group_size: int,
eps: float = 1e-10,
dtype: torch.dtype = torch.float8_e4m3fn,
dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Function to perform per-token-group quantization on an input tensor `x`.
It converts the tensor values into signed float8 values and returns the
@@ -140,6 +146,9 @@ def per_token_group_quant_fp8(
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
scaling factor for quantization.
"""
if dtype is None:
dtype = (torch.float8_e4m3fnuz if current_platform.is_rocm()
and not is_navi() else torch.float8_e4m3fn)
assert (x.shape[-1] % group_size == 0), (
f"the last dimension of `x` {x.shape[-1]} must be divisible "
f"by `group_size` {group_size}")
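
Functionally, per_token_group_quant_fp8 computes one scale per contiguous group of group_size values along the last dimension and clamps the scaled values into the FP8 range; the diff only changes which FP8 dtype is used by default. An illustrative pure-PyTorch equivalent of that math (a sketch for clarity, not the Triton kernel from this file; here dtype is passed explicitly, while the real function now resolves the default as shown in the diff above):

from typing import Tuple

import torch


def per_token_group_quant_fp8_reference(
    x: torch.Tensor,
    group_size: int,
    dtype: torch.dtype,
    eps: float = 1e-10,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # dtype should be the platform-appropriate FP8 type (see the selection
    # sketch near the top of this page).
    assert x.shape[-1] % group_size == 0, (
        f"last dim {x.shape[-1]} must be divisible by group_size {group_size}")
    finfo = torch.finfo(dtype)
    # View the last dimension as [num_groups, group_size] and scale per group.
    grouped = x.reshape(*x.shape[:-1], x.shape[-1] // group_size, group_size)
    amax = grouped.abs().amax(dim=-1, keepdim=True).clamp(min=eps).float()
    scale = amax / finfo.max
    x_q = (grouped / scale).clamp(min=finfo.min, max=finfo.max).to(dtype)
    return x_q.reshape(x.shape), scale.squeeze(-1)

For a [num_tokens, hidden] activation with group_size = 128 this yields an FP8 tensor of the same shape plus a float32 scale tensor of shape [num_tokens, hidden // 128], which is the layout the block-quantized GEMM path consumes.
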
