[Model] Add support for falcon-11B #5069

Merged · merged 3 commits on May 27, 2024
Changes from all commits
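With this merged, Falcon-11B should load like any other supported architecture. A minimal usage sketch, assuming the Hugging Face model id `tiiuae/falcon-11B` and a vLLM build that includes this PR:

```python
from vllm import LLM, SamplingParams

# Hypothetical quick check that the new architecture loads and generates.
llm = LLM(model="tiiuae/falcon-11B")
params = SamplingParams(temperature=0.8, max_tokens=32)
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)
```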
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 40 additions & 15 deletions vllm/model_executor/models/falcon.py
@@ -41,7 +41,7 @@
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding)
+    ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import SamplerOutput
@@ -246,18 +246,26 @@ def __init__(
         self.mlp = FalconMLP(config, quant_config)
         self.config = config
 
-        if config.new_decoder_architecture:
-            # The layer norm before self-attention
-            self.ln_attn = LayerNorm(hidden_size,
-                                     eps=config.layer_norm_epsilon)
-            # The layer norm before the MLP
-            self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        else:
+        if (config.num_ln_in_parallel_attn is None
+                and config.new_decoder_architecture):
+            config.num_ln_in_parallel_attn = 2
+
+        if not config.parallel_attn:
+            self.post_attention_layernorm = LayerNorm(
+                hidden_size, eps=config.layer_norm_epsilon)
             self.input_layernorm = LayerNorm(hidden_size,
                                              eps=config.layer_norm_epsilon)
-            if not config.parallel_attn:
-                self.post_attention_layernorm = LayerNorm(
-                    hidden_size, eps=config.layer_norm_epsilon)
+        else:
+            if config.num_ln_in_parallel_attn == 2:
+                # The layer norm before self-attention
+                self.ln_attn = LayerNorm(hidden_size,
+                                         eps=config.layer_norm_epsilon)
+                # The layer norm before the MLP
+                self.ln_mlp = LayerNorm(hidden_size,
+                                        eps=config.layer_norm_epsilon)
+            else:
+                self.input_layernorm = LayerNorm(hidden_size,
+                                                 eps=config.layer_norm_epsilon)
 
         self.reduce_row_parallel_results = not (config.new_decoder_architecture
                                                 or config.parallel_attn)
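In short, the constructor now chooses layer norms based on `parallel_attn` and `num_ln_in_parallel_attn` rather than on `new_decoder_architecture` alone. A standalone sketch of the selection rule (the config class below is a hypothetical stand-in, not `transformers.FalconConfig`):

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class LayerNormConfig:
    """Hypothetical stand-in for the relevant FalconConfig fields."""
    new_decoder_architecture: bool
    parallel_attn: bool
    num_ln_in_parallel_attn: Optional[int] = None

def layer_norms_to_build(cfg: LayerNormConfig) -> List[str]:
    """Return which layer norms a decoder block instantiates under this PR."""
    # Older "new architecture" checkpoints (e.g. Falcon-40B) typically leave
    # the field unset, so it defaults to 2 and they keep two norms per block.
    if cfg.num_ln_in_parallel_attn is None and cfg.new_decoder_architecture:
        cfg.num_ln_in_parallel_attn = 2

    if not cfg.parallel_attn:
        # Sequential attention then MLP: one norm before each.
        return ["input_layernorm", "post_attention_layernorm"]
    if cfg.num_ln_in_parallel_attn == 2:
        # Parallel attention and MLP, each with its own norm.
        return ["ln_attn", "ln_mlp"]
    # Parallel attention and MLP sharing a single norm (the Falcon-11B case).
    return ["input_layernorm"]
```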
@@ -271,7 +279,7 @@
     ) -> torch.Tensor:
         residual = hidden_states
 
-        if self.config.new_decoder_architecture:
+        if self.config.num_ln_in_parallel_attn == 2:
             attention_layernorm_out = self.ln_attn(hidden_states)
             mlp_layernorm_out = self.ln_mlp(hidden_states)
         else:
@@ -294,6 +302,10 @@ def forward(
                 residual += attention_output
                 mlp_layernorm_out = self.post_attention_layernorm(residual)
 
+        if (self.config.new_decoder_architecture and self.config.parallel_attn
+                and self.config.num_ln_in_parallel_attn == 1):
+            mlp_layernorm_out = attention_layernorm_out
+
         # MLP.
         mlp_output, mlp_bias = self.mlp(mlp_layernorm_out)
         if self.reduce_row_parallel_results and mlp_bias is not None:
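The forward-pass change is what makes the single-norm case work: when `num_ln_in_parallel_attn == 1`, the MLP reuses the attention branch's layer-norm output. A rough dataflow sketch of such a block (hypothetical helper, not the vLLM implementation):

```python
import torch
import torch.nn as nn

def parallel_block_single_norm(hidden_states: torch.Tensor,
                               ln: nn.LayerNorm,
                               attn: nn.Module,
                               mlp: nn.Module) -> torch.Tensor:
    """Falcon-11B style block: new_decoder_architecture=True,
    parallel_attn=True, num_ln_in_parallel_attn=1."""
    residual = hidden_states
    normed = ln(hidden_states)   # the single shared layer norm
    attn_out = attn(normed)      # attention branch
    mlp_out = mlp(normed)        # MLP branch reuses the same normed input
    # Parallel residual: both branch outputs are added back to the input.
    return residual + attn_out + mlp_out
```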
@@ -375,7 +387,20 @@ def __init__(
         self.config = config
         self.quant_config = quant_config
         self.transformer = FalconModel(config, cache_config, quant_config)
-        self.lm_head_weight = self.transformer.word_embeddings.weight
+        # Only Falcon-11B keeps a separate lm_head weight instead of sharing
+        # it with the word embeddings, and earlier Falcon configs have no
+        # tie_word_embeddings field, so default tie_word_embeddings to True.
+        self.tie_word_embeddings = (config.tie_word_embeddings
+                                    if config.tie_word_embeddings is not None
+                                    else True)
+        if self.tie_word_embeddings:
+            self.lm_head_weight = self.transformer.word_embeddings.weight
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+            )
+            self.lm_head_weight = self.lm_head.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
@@ -419,8 +444,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
         params_dict = dict(self.named_parameters(remove_duplicate=False))
         for name, loaded_weight in weights:
-            if name == "lm_head.weight":
-                # Falcon uses tied embeddings.
+            if name == "lm_head.weight" and self.tie_word_embeddings:
+                # Falcon uses tied embeddings, except Falcon-11B.
                 continue
             # Skip loading extra bias for GPTQ models.
             if name.endswith(".bias") and name not in params_dict:
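Correspondingly, `load_weights` now skips the checkpoint's `lm_head.weight` only when the embeddings are tied; for Falcon-11B the tensor is kept and loaded into the separate head. A minimal sketch of that filter (hypothetical helper, not vLLM code):

```python
from typing import Iterable, Iterator, Tuple
import torch

def filter_checkpoint_weights(
    weights: Iterable[Tuple[str, torch.Tensor]],
    tie_word_embeddings: bool,
) -> Iterator[Tuple[str, torch.Tensor]]:
    for name, tensor in weights:
        if name == "lm_head.weight" and tie_word_embeddings:
            # Tied: the word-embedding matrix already provides this weight.
            continue
        yield name, tensor
```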