[Model] Support InternLM2 Reward models (vllm-project#11571)
Signed-off-by: Isotr0py <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Signed-off-by: xcnick <[email protected]>
2 people authored and xcnick committed Dec 31, 2024
1 parent 8d3d9b9 commit 278934c
Showing 4 changed files with 67 additions and 1 deletion.
5 changes: 5 additions & 0 deletions docs/source/models/supported_models.md
@@ -450,6 +450,11 @@ of the whole prompt are extracted from the normalized hidden state corresponding
- Example HF Models
- :ref:`LoRA <lora-adapter>`
- :ref:`PP <distributed-serving>`
+* - :code:`InternLM2ForRewardModel`
+  - InternLM2-based
+  - :code:`internlm/internlm2-1_8b-reward`, :code:`internlm/internlm2-7b-reward`, etc.
+  - ✅︎
+  - ✅︎
* - :code:`LlamaForCausalLM`
- Llama-based
- :code:`peiyi9979/math-shepherd-mistral-7b-prm`, etc.
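With this table entry in place, the reward checkpoints load through vLLM's pooling path rather than the sampler. A minimal usage sketch, assuming the `LLM(task="reward")` / `encode()` pooling API of this vLLM release (the prompt is illustrative, and output attribute names may differ across versions):

from vllm import LLM

# Route the model through the pooling runner; trust_remote_code=True is
# needed because the InternLM2 repos ship custom code on the HF Hub.
llm = LLM(model="internlm/internlm2-1_8b-reward",
          task="reward",
          trust_remote_code=True)

# encode() runs the pooler instead of generating tokens. With the ALL
# pooling type configured in internlm2.py below, each output carries one
# scalar score per prompt token.
(output,) = llm.encode("The capital of France is Paris.")
print(output.outputs.data.shape)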
2 changes: 2 additions & 0 deletions tests/models/registry.py
@@ -140,6 +140,8 @@ class _HfExamplesInfo:
"BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
"Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"),
"GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),
"InternLM2ForRewardModel": _HfExamplesInfo("internlm/internlm2-1_8b-reward",
trust_remote_code=True),
"JambaForSequenceClassification": _HfExamplesInfo("ai21labs/Jamba-tiny-reward-dev"), # noqa: E501
"LlamaModel": _HfExamplesInfo("llama", is_available_online=False),
"MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"),
60 changes: 59 additions & 1 deletion vllm/model_executor/models/internlm2.py
@@ -18,14 +18,16 @@
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import IntermediateTensors
+from vllm.sequence import IntermediateTensors, PoolerOutput

from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
@@ -433,3 +435,59 @@ def load_weights(self, weights: Iterable[Tuple[str,
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params


+class InternLM2ForRewardModel(InternLM2ForCausalLM):
+
+    def __init__(
+        self,
+        *,
+        vllm_config: VllmConfig,
+        prefix: str = "",
+        model_type: Type[InternLM2Model] = InternLM2Model,
+    ):
+        super().__init__(vllm_config=vllm_config,
+                         prefix=prefix,
+                         model_type=model_type)
+
+        for attr in ("output", "logits_processor", "sampler"):
+            delattr(self, attr)
+
+        config = vllm_config.model_config.hf_config
+        self.v_head = RowParallelLinear(
+            config.hidden_size,
+            1,
+            bias=False,
+            input_is_parallel=False,
+            prefix=maybe_prefix(prefix, "v_head"),
+        )
+
+        pooler_config = vllm_config.model_config.pooler_config
+        self._pooler = Pooler.from_config_with_defaults(
+            pooler_config,
+            pooling_type=PoolingType.ALL,
+            normalize=False,
+            softmax=False,
+        )
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        hidden_states = self.model(input_ids, positions, kv_caches,
+                                   attn_metadata, intermediate_tensors,
+                                   inputs_embeds)
+        logits, _ = self.v_head(hidden_states)
+        return logits
+
+    def pooler(
+        self,
+        hidden_states: torch.Tensor,
+        pooling_metadata: PoolingMetadata,
+    ) -> Optional[PoolerOutput]:
+        return self._pooler(hidden_states, pooling_metadata)
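In short, the subclass drops the LM head, logits processor, and sampler, adds a width-1 `v_head` projection, and pools with `PoolingType.ALL` so every prompt token receives a scalar score. A self-contained plain-PyTorch sketch of that shape contract (sizes and names are hypothetical, with `nn.Linear` standing in for the tensor-parallel layer above):

import torch
import torch.nn as nn

hidden_size, seq_len = 2048, 6  # hypothetical sizes
hidden_states = torch.randn(seq_len, hidden_size)

# Stand-in for RowParallelLinear(hidden_size, 1, bias=False): project each
# per-token hidden state to a single reward scalar instead of vocab logits.
v_head = nn.Linear(hidden_size, 1, bias=False)
rewards = v_head(hidden_states)
print(rewards.shape)  # torch.Size([6, 1]) -- one score per token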
1 change: 1 addition & 0 deletions vllm/model_executor/models/registry.py
@@ -113,6 +113,7 @@
"Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
"GlmForCausalLM": ("glm", "GlmForCausalLM"),
"GritLM": ("gritlm", "GritLM"),
"InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
"JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"), # noqa: E501
"LlamaModel": ("llama", "LlamaForCausalLM"),
**{
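Once registered, the `InternLM2ForRewardModel` architecture string from a checkpoint's config.json lazily resolves to the class added in internlm2.py. A sketch, assuming `ModelRegistry`'s public resolution helper in this vLLM version:

from vllm import ModelRegistry

# The new mapping points the architecture name at ("internlm2",
# "InternLM2ForRewardModel"), importing the module only on demand.
model_cls, arch = ModelRegistry.resolve_model_cls(["InternLM2ForRewardModel"])
print(model_cls.__name__, arch)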
