refactor vllm/inputs/data.py to use newly defined functions
Signed-off-by: Tobias Pitters <[email protected]>
CloseChoice committed Jan 2, 2025
1 parent d6add6a · commit 5986992
Showing 1 changed file with 12 additions and 17 deletions.
vllm/inputs/data.py (29 changes: 12 additions & 17 deletions)

@@ -268,7 +268,7 @@ class SingletonInputsAdapter:
     def prompt(self) -> Optional[str]:
         inputs = self.inputs
 
-        if inputs["type"] == "token" or inputs["type"] == "multimodal":
+        if is_token_inputs(inputs) or is_multimodal_inputs(inputs):
             return inputs.get("prompt")
 
         assert_never(inputs)  # type: ignore[arg-type]
@@ -277,7 +277,7 @@ def prompt(self) -> Optional[str]:
     def prompt_token_ids(self) -> List[int]:
         inputs = self.inputs
 
-        if inputs["type"] == "token" or inputs["type"] == "multimodal":
+        if is_token_inputs(inputs) or is_multimodal_inputs(inputs):
            return inputs.get("prompt_token_ids", [])
 
         assert_never(inputs)  # type: ignore[arg-type]
@@ -286,7 +286,7 @@ def prompt_token_ids(self) -> List[int]:
     def token_type_ids(self) -> List[int]:
         inputs = self.inputs
 
-        if inputs["type"] == "token" or inputs["type"] == "multimodal":
+        if is_token_inputs(inputs) or is_multimodal_inputs(inputs):
             return inputs.get("token_type_ids", [])
 
         assert_never(inputs)  # type: ignore[arg-type]
@@ -295,7 +295,7 @@ def token_type_ids(self) -> List[int]:
     def prompt_embeds(self) -> Optional[torch.Tensor]:
         inputs = self.inputs
 
-        if inputs["type"] == "token" or inputs["type"] == "multimodal":
+        if is_token_inputs(inputs) or is_multimodal_inputs(inputs):
             return None
 
         assert_never(inputs)  # type: ignore[arg-type]
@@ -304,10 +304,9 @@ def prompt_embeds(self) -> Optional[torch.Tensor]:
     def multi_modal_data(self) -> "MultiModalDataDict":
         inputs = self.inputs
 
-        if inputs["type"] == "token":
+        if is_token_inputs(inputs):
             return inputs.get("multi_modal_data", {})
-
-        if inputs["type"] == "multimodal":
+        elif is_multimodal_inputs(inputs):
             return inputs.get("mm_kwargs", {})
 
         assert_never(inputs)  # type: ignore[arg-type]
@@ -316,10 +315,9 @@ def multi_modal_data(self) -> "MultiModalDataDict":
     def multi_modal_inputs(self) -> Union[Dict, "MultiModalKwargs"]:
         inputs = self.inputs
 
-        if inputs["type"] == "token":
+        if is_token_inputs(inputs):
             return inputs.get("multi_modal_inputs", {})
-
-        if inputs["type"] == "multimodal":
+        elif is_multimodal_inputs(inputs):
             return inputs.get("mm_kwargs", {})
 
         assert_never(inputs)  # type: ignore[arg-type]
@@ -331,7 +329,6 @@ def multi_modal_hashes(self) -> List[str]:
         if is_token_inputs(inputs):
             return inputs.get("multi_modal_hashes", [])
         elif is_multimodal_inputs(inputs):
-            # only the case when we use MultiModalInputsV2
             return inputs.get("mm_hashes", [])
 
         assert_never(inputs)  # type: ignore[arg-type]
@@ -340,10 +337,9 @@ def multi_modal_hashes(self) -> List[str]:
     def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict":
         inputs = self.inputs
 
-        if inputs["type"] == "token":
+        if is_token_inputs(inputs):
             return inputs.get("multi_modal_placeholders", {})
-
-        if inputs["type"] == "multimodal":
+        elif is_multimodal_inputs(inputs):
             return inputs.get("mm_placeholders", {})
 
         assert_never(inputs)  # type: ignore[arg-type]
@@ -352,10 +348,9 @@ def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict":
     def mm_processor_kwargs(self) -> Dict[str, Any]:
         inputs = self.inputs
 
-        if inputs["type"] == "token":
+        if is_token_inputs(inputs):
             return inputs.get("mm_processor_kwargs", {})
-
-        if inputs["type"] == "multimodal":
+        elif is_multimodal_inputs(inputs):
             return {}
 
         assert_never(inputs)  # type: ignore[arg-type]
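Note: the "newly defined functions" themselves are not part of this diff; only their call sites are. As a minimal sketch, assuming they are TypeGuard-style predicates over the inputs' "type" tag, they could look like the following. The TypedDict shapes, fields, and imports here are illustrative assumptions, not the file's actual definitions.

# Hypothetical sketch of the helper predicates this commit switches to.
# The real definitions live elsewhere in vllm/inputs/data.py and may
# differ; the TypedDict fields below are illustrative only.
from typing import Any, Dict, List, Literal, TypedDict, Union

from typing_extensions import TypeGuard


class TokenInputs(TypedDict):
    type: Literal["token"]
    prompt_token_ids: List[int]


class MultiModalInputsV2(TypedDict):
    type: Literal["multimodal"]
    mm_kwargs: Dict[str, Any]


SingletonInputs = Union[TokenInputs, MultiModalInputsV2]


def is_token_inputs(inputs: SingletonInputs) -> TypeGuard[TokenInputs]:
    # Centralizes the tag comparison and narrows the union for mypy.
    return inputs["type"] == "token"


def is_multimodal_inputs(inputs: SingletonInputs) -> TypeGuard[MultiModalInputsV2]:
    return inputs["type"] == "multimodal"


# Example usage mirroring the refactored call sites:
inputs: SingletonInputs = {"type": "token", "prompt_token_ids": [1, 2, 3]}
if is_token_inputs(inputs):
    print(inputs["prompt_token_ids"])  # mypy narrows inputs to TokenInputs here

Under these assumptions, the appeal of the refactor is that the tag comparison lives in one place and the type checker can narrow the union in each branch, so the trailing assert_never(inputs) exhaustiveness checks keep working without extra casts.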
