
[mypy] Enable following imports for some directories #6681

Merged (13 commits) on Jul 31, 2024
31 changes: 13 additions & 18 deletions .github/workflows/mypy.yaml
@@ -32,22 +32,17 @@ jobs:
         pip install types-setuptools
     - name: Mypy
       run: |
-        mypy tests --config-file pyproject.toml
-        mypy vllm/*.py --config-file pyproject.toml
-        mypy vllm/attention --config-file pyproject.toml
-        mypy vllm/core --config-file pyproject.toml
-        mypy vllm/distributed --config-file pyproject.toml
-        mypy vllm/engine --config-file pyproject.toml
-        mypy vllm/entrypoints --config-file pyproject.toml
-        mypy vllm/executor --config-file pyproject.toml
-        mypy vllm/inputs --config-file pyproject.toml
-        mypy vllm/logging --config-file pyproject.toml
-        mypy vllm/lora --config-file pyproject.toml
-        mypy vllm/model_executor --config-file pyproject.toml
-        mypy vllm/multimodal --config-file pyproject.toml
-        mypy vllm/platforms --config-file pyproject.toml
-        mypy vllm/spec_decode --config-file pyproject.toml
-        mypy vllm/transformers_utils --config-file pyproject.toml
-        mypy vllm/usage --config-file pyproject.toml
-        mypy vllm/worker --config-file pyproject.toml
+        mypy tests --follow-imports skip
+        mypy vllm/attention --follow-imports skip
+        mypy vllm/core --follow-imports skip
+        mypy vllm/distributed --follow-imports skip
+        mypy vllm/engine --follow-imports skip
+        mypy vllm/entrypoints --follow-imports skip
+        mypy vllm/executor --follow-imports skip
+        mypy vllm/lora --follow-imports skip
+        mypy vllm/model_executor --follow-imports skip
+        mypy vllm/prompt_adapter --follow-imports skip
+        mypy vllm/spec_decode --follow-imports skip
+        mypy vllm/worker --follow-imports skip
+        mypy

30 changes: 13 additions & 17 deletions format.sh
@@ -96,23 +96,19 @@ echo 'vLLM yapf: Done'

 # Run mypy
 echo 'vLLM mypy:'
-mypy tests --config-file pyproject.toml
-mypy vllm/*.py --config-file pyproject.toml
-mypy vllm/attention --config-file pyproject.toml
-mypy vllm/core --config-file pyproject.toml
-mypy vllm/distributed --config-file pyproject.toml
-mypy vllm/engine --config-file pyproject.toml
-mypy vllm/entrypoints --config-file pyproject.toml
-mypy vllm/executor --config-file pyproject.toml
-mypy vllm/logging --config-file pyproject.toml
-mypy vllm/lora --config-file pyproject.toml
-mypy vllm/model_executor --config-file pyproject.toml
-mypy vllm/multimodal --config-file pyproject.toml
-mypy vllm/prompt_adapter --config-file pyproject.toml
-mypy vllm/spec_decode --config-file pyproject.toml
-mypy vllm/transformers_utils --config-file pyproject.toml
-mypy vllm/usage --config-file pyproject.toml
-mypy vllm/worker --config-file pyproject.toml
+mypy tests --follow-imports skip
Review comment (Collaborator):
    The new formatter and GitHub Action don't pass --config-file. Is this intended?

Reply (@DarkLight1337, Member, Author, Jul 31, 2024):
    Yes. The TOML file is already being used by default, so there is no need to specify it explicitly.
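For context: when mypy is run without --config-file, it searches for mypy.ini, .mypy.ini, pyproject.toml, and setup.cfg in that order, so the [tool.mypy] table in pyproject.toml is picked up automatically. A minimal sketch of verifying this through mypy's Python API (the target directory is just an example):

```python
# Equivalent to running `mypy vllm/platforms --follow-imports skip` from the
# repo root; mypy discovers the [tool.mypy] settings in pyproject.toml itself.
from mypy import api

stdout, stderr, exit_status = api.run(["vllm/platforms", "--follow-imports", "skip"])
print(stdout)
```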

+mypy vllm/attention --follow-imports skip
+mypy vllm/core --follow-imports skip
+mypy vllm/distributed --follow-imports skip
+mypy vllm/engine --follow-imports skip
+mypy vllm/entrypoints --follow-imports skip
+mypy vllm/executor --follow-imports skip
+mypy vllm/lora --follow-imports skip
+mypy vllm/model_executor --follow-imports skip
+mypy vllm/prompt_adapter --follow-imports skip
+mypy vllm/spec_decode --follow-imports skip
+mypy vllm/worker --follow-imports skip
+mypy


# If git diff returns a file that is in the skip list, the file may be checked anyway:
18 changes: 16 additions & 2 deletions pyproject.toml
@@ -48,9 +48,23 @@ python_version = "3.8"

 ignore_missing_imports = true
 check_untyped_defs = true
-follow_imports = "skip"
+follow_imports = "silent"
 
-files = "vllm"
+# After fixing type errors resulting from follow_imports: "skip" -> "silent",
+# move the directory here and remove it from format.sh and mypy.yaml
+files = [
+    "vllm/*.py",
+    "vllm/adapter_commons",
+    "vllm/assets",
+    "vllm/inputs",
+    "vllm/logging",
+    "vllm/multimodal",
+    "vllm/platforms",
+    "vllm/server",
+    "vllm/transformers_utils",
+    "vllm/triton_utils",
+    "vllm/usage",
+]
 # TODO(woosuk): Include the code from Megatron and HuggingFace.
 exclude = [
     "vllm/model_executor/parallel_utils/|vllm/model_executor/models/",
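To make the skip-to-silent migration mentioned in the comment above concrete, here is a minimal two-module sketch (hypothetical files a.py and b.py, not from the PR) of how the two follow_imports modes differ when only a.py is in the checked set:

```python
# b.py -- imported by a.py but not itself listed for checking
def get_value() -> str:
    return "not an int"

# a.py -- the module being type-checked
from b import get_value

def double(x: int) -> int:
    # follow_imports = "skip":   b.py is not analyzed, `get_value` comes in
    #                            as Any, and this body type-checks cleanly.
    # follow_imports = "silent": b.py is analyzed (with its own errors
    #                            suppressed), `get_value()` is known to be
    #                            str, and mypy reports an incompatible
    #                            return value here.
    return get_value() * x
```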
2 changes: 1 addition & 1 deletion vllm/_custom_ops.py
@@ -239,7 +239,7 @@ def cutlass_scaled_mm(a: torch.Tensor,
                       b: torch.Tensor,
                       scale_a: torch.Tensor,
                       scale_b: torch.Tensor,
-                      out_dtype: Type[torch.dtype],
+                      out_dtype: torch.dtype,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
     assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
     assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
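The annotation fix above works because torch dtypes are instances of torch.dtype rather than subclasses of it, so Type[torch.dtype] never described the values actually passed in. A standalone sketch:

```python
import torch

# torch.float16 is an *instance* of torch.dtype...
assert isinstance(torch.float16, torch.dtype)
# ...and not a class at all, so it could never match Type[torch.dtype].
assert not isinstance(torch.float16, type)
```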
62 changes: 49 additions & 13 deletions vllm/_ipex_ops.py
@@ -25,27 +25,33 @@ def _reshape_activation_tensor(
         x2 = x2.reshape(num, d)
         return x1, x2
 
+    @staticmethod
     def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
         x1, x2 = ipex_ops._reshape_activation_tensor(x)
         ipex.llm.functional.silu_mul(x1, x2, out)
 
+    @staticmethod
     def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
         x1, x2 = ipex_ops._reshape_activation_tensor(x)
         ipex.llm.functional.gelu_mul(x1, x2, out, "none")
 
+    @staticmethod
     def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
         x1, x2 = ipex_ops._reshape_activation_tensor(x)
         ipex.llm.functional.gelu_mul(x1, x2, out, "tanh")
 
+    @staticmethod
     def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
         out.copy_(torch.nn.functional.gelu(x))
 
+    @staticmethod
     def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
         out.copy_(torch.nn.functional.gelu(x))
 
     # TODO add implementation of gelu_quick here
     # def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
 
+    @staticmethod
     def paged_attention_v1(
         out: torch.Tensor,
         query: torch.Tensor,
@@ -78,12 +84,21 @@ def paged_attention_v1(
             ).view(num_kv_heads,
                    1).repeat_interleave(num_queries_per_tokens).flatten()
         # todo: ipex will refactor namespace
-        torch.xpu.paged_attention_v1(out, query.contiguous(),
-                                     key_cache.view_as(value_cache),
-                                     value_cache, head_mapping, scale,
-                                     block_tables, context_lens, block_size,
-                                     max_context_len, alibi_slopes)
+        torch.xpu.paged_attention_v1(  # type: ignore
+            out,
+            query.contiguous(),
+            key_cache.view_as(value_cache),
+            value_cache,
+            head_mapping,
+            scale,
+            block_tables,
+            context_lens,
+            block_size,
+            max_context_len,
+            alibi_slopes,
+        )
 
+    @staticmethod
     def paged_attention_v2(
         out: torch.Tensor,
         exp_sum: torch.Tensor,
@@ -119,13 +134,24 @@ def paged_attention_v2(
             ).view(num_kv_heads,
                    1).repeat_interleave(num_queries_per_tokens).flatten()
         # todo: ipex will refactor namespace
-        torch.xpu.paged_attention_v2(out, exp_sum, max_logits, tmp_out,
-                                     query.contiguous(),
-                                     key_cache.view_as(value_cache),
-                                     value_cache, head_mapping, block_tables,
-                                     context_lens, scale, block_size,
-                                     max_context_len, alibi_slopes)
+        torch.xpu.paged_attention_v2(  # type: ignore
+            out,
+            exp_sum,
+            max_logits,
+            tmp_out,
+            query.contiguous(),
+            key_cache.view_as(value_cache),
+            value_cache,
+            head_mapping,
+            block_tables,
+            context_lens,
+            scale,
+            block_size,
+            max_context_len,
+            alibi_slopes,
+        )
 
+    @staticmethod
     def rotary_embedding(
         positions: torch.Tensor,  # [batch_size, seq_len]
         query: torch.Tensor,  # [batch_size, seq_len, num_heads*head_size]
@@ -158,6 +184,7 @@ def rotary_embedding(
         ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
                                              rotary_dim, is_neox, positions)
 
+    @staticmethod
     def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
                                  key: torch.Tensor, head_size: int,
                                  cos_sin_cache: torch.Tensor, is_neox: bool,
@@ -189,17 +216,20 @@ def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
         ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
                                              rotary_dim, is_neox, positions)
 
+    @staticmethod
     def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
                  epsilon: float) -> None:
         tmp = ipex.llm.functional.rms_norm(input, weight, epsilon)
         out.copy_(tmp)
 
+    @staticmethod
     def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
                            weight: torch.Tensor, epsilon: float) -> None:
         tmp = ipex.llm.functional.add_rms_norm(residual, input, weight, None,
                                                epsilon, True)
         input.copy_(tmp)
 
+    @staticmethod
     def varlen_attention(
         query: torch.Tensor,
         key: torch.Tensor,
@@ -222,6 +252,7 @@ def varlen_attention(
                                              softmax_scale, zero_tensors,
                                              is_causal, return_softmax, gen_)
 
+    @staticmethod
     def reshape_and_cache(
         key: torch.Tensor,
         value: torch.Tensor,
@@ -240,8 +271,13 @@ def copy_blocks(key_caches: List[torch.Tensor],
     def copy_blocks(key_caches: List[torch.Tensor],
                     value_caches: List[torch.Tensor],
                     block_mapping: torch.Tensor) -> None:
-        torch.xpu.copy_blocks(key_caches, value_caches, block_mapping)
+        torch.xpu.copy_blocks(  # type: ignore
+            key_caches,
+            value_caches,
+            block_mapping,
+        )
 
+    @staticmethod
     def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
                     block_mapping: torch.Tensor) -> None:
-        torch.xpu.swap_blocks(src, dst, block_mapping)
+        torch.xpu.swap_blocks(src, dst, block_mapping)  # type: ignore
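The changes in this file follow one pattern: ipex_ops is used purely as a namespace (its methods are called on the class, never on an instance), so mypy insists on @staticmethod, and the torch.xpu custom ops lack type stubs, hence the # type: ignore markers. A minimal sketch of the namespace-class pattern (the ops class here is illustrative, not vLLM's):

```python
import torch

class ops:
    # Without @staticmethod, mypy reports "Self argument missing for a
    # non-static method" because the first parameter is not `self`.
    @staticmethod
    def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
        out.copy_(torch.nn.functional.gelu(x))

x = torch.randn(4)
out = torch.empty_like(x)
ops.gelu_fast(out, x)  # called on the class itself; no instance exists
```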
30 changes: 15 additions & 15 deletions vllm/adapter_commons/models.py
@@ -31,7 +31,7 @@ def __init__(self, capacity: int, deactivate_fn: Callable[[Hashable],
         super().__init__(capacity)
         self.deactivate_fn = deactivate_fn
 
-    def _on_remove(self, key: Hashable, value: T):
+    def _on_remove(self, key: Hashable, value: Optional[T]):
         logger.debug("Removing adapter int id: %d", key)
         self.deactivate_fn(key)
         return super()._on_remove(key, value)
@@ -59,46 +59,46 @@ def __len__(self) -> int:

     @property
     @abstractmethod
-    def adapter_slots(self):
-        ...
+    def adapter_slots(self) -> int:
+        raise NotImplementedError
 
     @property
     @abstractmethod
-    def capacity(self):
-        ...
+    def capacity(self) -> int:
+        raise NotImplementedError
 
     @abstractmethod
     def activate_adapter(self, adapter_id: int) -> bool:
-        ...
+        raise NotImplementedError
 
     @abstractmethod
     def deactivate_adapter(self, adapter_id: int) -> bool:
-        ...
+        raise NotImplementedError
 
     @abstractmethod
     def add_adapter(self, adapter: Any) -> bool:
-        ...
+        raise NotImplementedError
 
     @abstractmethod
     def set_adapter_mapping(self, mapping: Any) -> None:
-        ...
+        raise NotImplementedError
 
     @abstractmethod
     def remove_adapter(self, adapter_id: int) -> bool:
-        ...
+        raise NotImplementedError
 
     @abstractmethod
-    def remove_all_adapters(self):
-        ...
+    def remove_all_adapters(self) -> None:
+        raise NotImplementedError
 
     @abstractmethod
     def get_adapter(self, adapter_id: int) -> Optional[Any]:
-        ...
+        raise NotImplementedError
 
     @abstractmethod
     def list_adapters(self) -> Dict[int, Any]:
-        ...
+        raise NotImplementedError
 
     @abstractmethod
     def pin_adapter(self, adapter_id: int) -> bool:
-        ...
+        raise NotImplementedError
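Replacing the `...` bodies with `raise NotImplementedError` matters when a subclass delegates to the base implementation: a `...` body silently returns None, violating the annotated return type at runtime. A standalone sketch (hypothetical Base/Impl classes, not from vLLM):

```python
from abc import ABC, abstractmethod

class Base(ABC):
    @abstractmethod
    def capacity(self) -> int:
        raise NotImplementedError

class Impl(Base):
    def capacity(self) -> int:
        return 8

class Delegating(Base):
    def capacity(self) -> int:
        # With `...` in Base this would return None and break the `-> int`
        # contract; with `raise NotImplementedError` it fails loudly.
        return super().capacity()

print(Impl().capacity())  # 8
```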
10 changes: 5 additions & 5 deletions vllm/adapter_commons/request.py
@@ -1,19 +1,19 @@
-from abc import abstractmethod
+from abc import ABC, abstractmethod
 from dataclasses import dataclass
 
 
 @dataclass
-class AdapterRequest:
+class AdapterRequest(ABC):
     """
     Base class for adapter requests.
     """
 
     @property
     @abstractmethod
-    def adapter_id(self):
-        ...
+    def adapter_id(self) -> int:
+        raise NotImplementedError
 
-    def __post_init__(self):
+    def __post_init__(self) -> None:
         if self.adapter_id < 1:
             raise ValueError(f"id must be > 0, got {self.adapter_id}")

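A sketch of a concrete request against the new ABC (DummyRequest and its field are hypothetical, not part of vLLM): subclasses supply adapter_id, and the inherited __post_init__ still validates it.

```python
from dataclasses import dataclass

from vllm.adapter_commons.request import AdapterRequest

@dataclass
class DummyRequest(AdapterRequest):
    request_id: int

    @property
    def adapter_id(self) -> int:
        return self.request_id

ok = DummyRequest(request_id=1)   # passes the inherited __post_init__ check
bad = DummyRequest(request_id=0)  # raises ValueError: id must be > 0, got 0
```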
14 changes: 7 additions & 7 deletions vllm/adapter_commons/worker_manager.py
@@ -12,25 +12,25 @@ def __init__(self, device: torch.device):
     @property
     @abstractmethod
     def is_enabled(self) -> bool:
-        ...
+        raise NotImplementedError
 
     @abstractmethod
     def set_active_adapters(self, requests: Set[Any],
                             mapping: Optional[Any]) -> None:
-        ...
+        raise NotImplementedError
 
     @abstractmethod
     def add_adapter(self, adapter_request: Any) -> bool:
-        ...
+        raise NotImplementedError
 
     @abstractmethod
     def remove_adapter(self, adapter_id: int) -> bool:
-        ...
+        raise NotImplementedError
 
     @abstractmethod
-    def remove_all_adapters(self):
-        ...
+    def remove_all_adapters(self) -> None:
+        raise NotImplementedError
 
     @abstractmethod
     def list_adapters(self) -> Set[int]:
-        ...
+        raise NotImplementedError
3 changes: 2 additions & 1 deletion vllm/config.py
@@ -723,7 +723,7 @@ def __init__(
                 backend)
 
         self._verify_args()
-        self.rank = 0
+        self.rank: int = 0
 
     @property
     def use_ray(self) -> bool:
@@ -849,6 +849,7 @@ def _verify_args(self) -> None:


 class DeviceConfig:
+    device: Optional[torch.device]
 
     def __init__(self, device: str = "auto") -> None:
         if device == "auto":
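The class-level annotation gives mypy a single declared type for self.device even though the __init__ branches assign it different values. A minimal sketch of the pattern (the branch bodies are illustrative, not vLLM's full device-resolution logic):

```python
from typing import Optional

import torch

class DeviceConfig:
    # Declared once at class level so mypy knows the attribute's type
    # no matter which branch below assigns it.
    device: Optional[torch.device]

    def __init__(self, device: str = "auto") -> None:
        if device == "auto":
            self.device = None  # resolved later, e.g. by platform probing
        else:
            self.device = torch.device(device)
```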