[mypy] Forward pass function type hints in lora (#11740)
Signed-off-by: lucast2021 <[email protected]>
Co-authored-by: lucast2021 <[email protected]>
lucas-tucker and lucast2021 authored Jan 6, 2025
1 parent 022c5c6 commit 9c74971
Showing 3 changed files with 14 additions and 5 deletions.
vllm/lora/layers.py (9 additions & 3 deletions)

@@ -405,7 +405,9 @@ def __init__(self, base_layer: ReplicatedLinear) -> None:
         self.output_size = self.base_layer.output_size
         self.n_slices = 1
 
-    def forward(self, input_):
+    def forward(
+        self, input_: torch.Tensor
+    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
         """Forward of ReplicatedLinearWithLoRA
         Args:

@@ -496,7 +498,9 @@ def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
             bias = bias[start_idx:end_idx]
         return bias
 
-    def forward(self, input_):
+    def forward(
+        self, input_: torch.Tensor
+    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
         """Forward of ColumnParallelLinear
         Args:

@@ -833,7 +837,9 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
     def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
         return bias
 
-    def forward(self, input_):
+    def forward(
+        self, input_: torch.Tensor
+    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
         """Forward of RowParallelLinear
         Args:
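The new `Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]` annotation reflects the `skip_bias_add` convention these linear layers follow: `forward` returns an `(output, output_bias)` pair rather than a bare tensor, so the caller can fuse the bias add into a later op. A minimal sketch of that pattern, using a hypothetical `TinyReplicatedLinear` stand-in rather than the real vLLM layers:

```python
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyReplicatedLinear(nn.Module):
    """Toy stand-in for vLLM's linear-with-LoRA layers: forward returns
    an (output, output_bias) pair instead of a bare tensor, so callers
    can defer the bias add when skip_bias_add is set."""

    def __init__(self, in_features: int, out_features: int,
                 skip_bias_add: bool = False) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.bias = nn.Parameter(torch.zeros(out_features))
        self.skip_bias_add = skip_bias_add

    def forward(
        self, input_: torch.Tensor
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        # Apply the bias here only when the caller has not asked to
        # defer it; otherwise hand it back as the second tuple element.
        bias = None if self.skip_bias_add else self.bias
        output = F.linear(input_, self.weight, bias)
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias


layer = TinyReplicatedLinear(8, 4, skip_bias_add=True)
out, out_bias = layer(torch.randn(2, 8))
assert out_bias is not None  # the caller now owns the bias add
```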
vllm/lora/models.py (2 additions & 1 deletion)

@@ -4,7 +4,7 @@
 import os
 import re
 from dataclasses import dataclass, field
-from typing import Any, Callable, Dict, List, Optional, Sequence, Type
+from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union
 
 import safetensors.torch
 import torch

@@ -219,6 +219,7 @@ def from_local_checkpoint(
 
         config["vllm_max_position_embeddings"] = max_position_embeddings
         peft_helper = PEFTHelper.from_dict(config)
+        unexpected_modules: List[Union[list[str], str]]
         if os.path.isfile(lora_tensor_path):
             tensors: Dict[str, torch.Tensor] = {}
             # Find unexpected modules.
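The bare `unexpected_modules` annotation is declared before the branches that assign it, presumably so mypy checks both branches against one declared type instead of inferring the type from whichever assignment it sees first. A hypothetical minimal reproduction of that failure mode (the values are illustrative, not from the checkpoint loader):

```python
from typing import List, Union


def find_unexpected(from_file: bool) -> List[Union[list[str], str]]:
    # Without this up-front annotation, mypy would pin the variable to
    # the type of the first branch's value (e.g. list[str]) and then
    # reject the other branch's nested-list elements.
    unexpected_modules: List[Union[list[str], str]]
    if from_file:
        # One code path collects plain module names.
        unexpected_modules = ["foo.lora_A.weight"]
    else:
        # The other collects grouped names, i.e. list[str] elements.
        unexpected_modules = [["embed_tokens", "lm_head"]]
    return unexpected_modules


print(find_unexpected(True))
print(find_unexpected(False))
```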
vllm/model_executor/layers/linear.py (3 additions & 1 deletion)

@@ -238,7 +238,9 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         assert param.size() == loaded_weight.size()
         param.data.copy_(loaded_weight)
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self, x: torch.Tensor
+    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
         bias = self.bias if not self.skip_bias_add else None
         assert self.quant_method is not None
         output = self.quant_method.apply(self, x, bias)
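On the caller side, the corrected annotation makes the contract visible to mypy: the second element of the pair is a bias the caller may still have to apply. A hedged sketch of such a consumer, assuming a hypothetical `add_bias_and_residual` helper that is not part of this diff:

```python
from typing import Optional, Tuple

import torch


def add_bias_and_residual(
    layer_out: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]],
    residual: torch.Tensor,
) -> torch.Tensor:
    # Hypothetical consumer of the (output, output_bias) pair: fusing
    # the deferred bias with a residual add is the usual reason a layer
    # skips the bias internally and returns it instead.
    output, output_bias = layer_out
    assert output is not None
    if output_bias is not None:
        return output + output_bias + residual
    return output + residual


# Plain tensors standing in for a layer's return value:
pair = (torch.randn(2, 4), torch.zeros(4))
print(add_bias_and_residual(pair, torch.randn(2, 4)).shape)  # (2, 4)
```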
