Merge pull request vllm-project#8 from ri938/organise
Organise
ri938 authored Aug 24, 2023
2 parents aaea899 + 878a370 commit 2617c55
Showing 8 changed files with 7 additions and 6 deletions.
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -2,7 +2,7 @@
 
 #include <pybind11/pybind11.h>
 #include <torch/extension.h>
-#include "quantization/gemm_cuda.h"
+#include "gemm_cuda.h"
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
 {
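The include flattens from "quantization/gemm_cuda.h" to "gemm_cuda.h" because the header now sits in the same awq_kernels/ directory as pybind.cpp (see the setup.py hunk below). Once the extension builds, it is importable under the name declared in setup.py; a minimal smoke test, assuming an in-place build:

# Sketch: check that the compiled extension loads; assumes it was
# built first, e.g. `python setup.py build_ext --inplace`.
import awq_inference_engine

print(awq_inference_engine.__file__)  # path to the compiled .so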
4 changes: 2 additions & 2 deletions vllm/awq_quantization/kernels/setup.py → awq_ext/setup.py
@@ -15,8 +15,8 @@
         CUDAExtension(
             name="awq_inference_engine",
             sources=[
-                "csrc/pybind.cpp",
-                "csrc/quantization/gemm_cuda_gen.cu",
+                "awq_kernels/pybind.cpp",
+                "awq_kernels/gemm_cuda_gen.cu",
             ],
             extra_compile_args=extra_compile_args,
         ),
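For context, a minimal setup.py around this CUDAExtension could look like the sketch below; only the extension name and sources come from the diff, while the setup() metadata, the BuildExtension cmdclass, and the compiler flags are assumptions:

# Sketch of a minimal setup.py wrapping the diffed CUDAExtension.
# Only name= and sources= are taken from the commit; everything
# else here is an assumed, conventional torch extension setup.
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

extra_compile_args = {
    "cxx": ["-O3"],   # placeholder host-compiler flags
    "nvcc": ["-O3"],  # placeholder nvcc flags; real ones not shown
}

setup(
    name="awq_inference_engine",
    ext_modules=[
        CUDAExtension(
            name="awq_inference_engine",
            sources=[
                "awq_kernels/pybind.cpp",
                "awq_kernels/gemm_cuda_gen.cu",
            ],
            extra_compile_args=extra_compile_args,
        ),
    ],
    cmdclass={"build_ext": BuildExtension},  # compiles .cu via nvcc
)

With this layout, running `pip install .` from awq_ext/ builds the kernels and installs the awq_inference_engine module that the guarded import below relies on.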
Empty file removed vllm/awq_quantization/__init__.py
@@ -3,6 +3,7 @@
 import torch
 import torch.nn as nn
 
+
 try:
     import awq_inference_engine  # with CUDA kernels
 except ImportError as ex:
@@ -21,7 +22,7 @@ def forward(self, x):
         return self.act(x) / self.scales.view(1, 1, -1).to(x.device)
 
 
-class WQLinear(nn.Module):
+class AWQLinear(nn.Module):
     def __init__(
         self,
         w_bit,
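The class is renamed from WQLinear to AWQLinear; its constructor arguments are unchanged. A sketch of constructing it directly under the new name, with values assumed for illustration (4 bits and group size 128 are common AWQ settings, not taken from this commit; the out_features keyword is inferred from get_quantized_layer's signature, since the diff is truncated there):

# Sketch: instantiating the renamed layer; the keyword names mirror
# get_quantized_layer in llama.py, the concrete values are assumed.
from vllm.model_executor.layers import quant

layer = quant.AWQLinear(
    w_bit=4,           # assumed: AWQ typically packs weights to 4 bits
    group_size=128,    # assumed grouping for the quantization scales
    in_features=4096,  # assumed hidden size
    out_features=4096,
)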
4 changes: 2 additions & 2 deletions vllm/model_executor/models/llama.py
@@ -37,14 +37,14 @@
 from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.sampler import Sampler
+from vllm.model_executor.layers import quant
 from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
                                               load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.model_executor.parallel_utils.tensor_parallel import (
     VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
 from vllm.sequence import SequenceOutputs
-from vllm.awq_quantization import qmodule
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
@@ -141,7 +141,7 @@ def forward(
 
 
 def get_quantized_layer(in_features, out_features, quant_config):
-    layer = qmodule.WQLinear(
+    layer = quant.AWQLinear(
         w_bit=quant_config.bits,
         group_size=quant_config.group_size,
         in_features=in_features,
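A usage sketch of the helper after the rename; quant_config is a stand-in object with the two attributes the helper reads in this diff (bits and group_size), and the layer sizes are assumed:

# Sketch: calling get_quantized_layer with a minimal stand-in config.
from types import SimpleNamespace

quant_config = SimpleNamespace(bits=4, group_size=128)  # assumed values
layer = get_quantized_layer(
    in_features=4096,   # assumed hidden size for illustration
    out_features=4096,
    quant_config=quant_config,
)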
