
[Bug] AttributeError: 'MiniCPM3ForCausalLM' object has no attribute 'get_module_name' #1416

Closed
Lixtt opened this issue Sep 13, 2024 · 7 comments

Lixtt commented Sep 13, 2024

Checklist

  • 1. I have searched related issues but cannot get the expected help.
  • 2. The bug has not been fixed in the latest version.
  • 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
  • 4. If the issue you raised is not a bug but a question, please raise a discussion at https://github.com/sgl-project/sglang/discussions/new/choose Otherwise, it will be closed.
  • 5. Please use English, otherwise it will be closed.

Describe the bug

Except for the Llama models, the other model implementations do not define a get_module_name method, so the LoRA configuration cannot be loaded for them.

python/sglang/srt/lora/lora_manager.py:106

def init_loras(self):
    # get configs and target modules
    self.configs = {}
    self.origin_target_modules = set()
    for path in self.lora_paths:
        self.configs[path] = LoRAConfig(path)
        self.origin_target_modules = set(self.origin_target_modules) | set(
            self.configs[path].target_modules
        )
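    # get_module_name is only defined on the Llama model class, so the call
    # below raises AttributeError for other models (e.g. MiniCPM3ForCausalLM).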
    self.target_modules = set(
        [
            self.base_model.get_module_name(module)
            for module in self.origin_target_modules
        ]
    )
    self.target_weights = set(
        [get_stacked_name(module) for module in self.origin_target_modules]
    )

Reproduction

--lora-path openbmb/MiniCPM3-RAG-LoRA

Environment

not important

merrymercy (Contributor) commented

Hi @lixiangtiandashen, this should be easy to fix if you copy these functions to minicpm3.py. Can you contribute a fix?

cc @Ying1123
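
For reference, a minimal sketch of what such a hook could look like if copied into minicpm3.py, modeled on the Llama implementation; the q/k/v and gate/up stacking names below are assumptions and would need to be checked against MiniCPM3's actual attention layout:

def get_module_name(self, name):
    # Sketch only: map the per-module LoRA target names from the adapter
    # config to the stacked parameter names used by the fused linear layers.
    params_mapping = {
        "q_proj": "qkv_proj",
        "k_proj": "qkv_proj",
        "v_proj": "qkv_proj",
        "gate_proj": "gate_up_proj",
        "up_proj": "gate_up_proj",
    }
    return params_mapping.get(name, name)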

Lixtt (Author) commented Sep 15, 2024

Hi @lixiangtiandashen, this should be easy to fix if you copy these functions to minicpm3.py. Can you contribute a fix?

cc @Ying1123

Oh, yes, indeed, I'd love to do this.

upskyy (Contributor) commented Oct 22, 2024

@merrymercy
I also get the same error when serving the gemma2 model as multi lora. Has the error been resolved?
I tested on image lmsysorg/sglang:v0.3.2-cu121.

python3 -m sglang.launch_server --model-path /base_model --tokenizer-path /base_model --lora-paths /lora_model0 /lora_model1  --disable-radix --disable-cuda-graph --max-loras-per-batch 2 --mem-fraction-static 0.5 --random-seed 0 --enable-torch-compile

AttributeError: 'Gemma2ForCausalLM' object has no attribute 'get_module_name'

Lixtt (Author) commented Oct 22, 2024

@merrymercy I also get the same error when serving the gemma2 model as multi lora. Has the error been resolved? I tested on image lmsysorg/sglang:v0.3.2-cu121.

python3 -m sglang.launch_server --model-path /base_model --tokenizer-path /base_model --lora-paths /lora_model0 /lora_model1  --disable-radix --disable-cuda-graph --max-loras-per-batch 2 --mem-fraction-static 0.5 --random-seed 0 --enable-torch-compile

AttributeError: 'Gemma2ForCausalLM' object has no attribute 'get_module_name'

Because this issue has not been resolved, only the Llama series currently supports LoRA models.
I thought it would be easy to make the change at the time, but after actually trying it, I found it was more troublesome than expected.
There are also parts of the code I don't fully understand, and I have other work to deal with, so I haven't done it yet.

upskyy (Contributor) commented Oct 22, 2024

@lixiangtiandashen
Thank you for sharing your experience.

upskyy (Contributor) commented Oct 22, 2024

@Ying1123 @merrymercy

When using the lmsysorg/sglang:v0.3.4 image, a warning appears and the model loads.
I think this code handles it:
https://github.com/sgl-project/sglang/blob/v0.3.4/python/sglang/srt/lora/lora_manager.py#L148-L152

However, when I do inference, the error below occurs.
Is it true that only the llama model type is officially supported?

[04:31:48 TP0] WARNING: get_hidden_dim() is not defined, which is used to get the hidden dim for different lora modulesUse the default one, but please check if it is correct for your model.
[04:31:48 TP0] WARNING: get_hidden_dim() is not defined, which is used to get the hidden dim for different lora modulesUse the default one, but please check if it is correct for your model.
...
[04:40:36 TP0] Prefill batch. #new-seq: 1, #new-token: 556, #cached-token: 0, cache hit rate: 0.00%, token usage: 0.00, #running-req: 0, #queue-req: 0
[04:40:36 TP0] Traceback (most recent call last):
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 1017, in run_scheduler_process
    scheduler.event_loop()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 238, in event_loop
    self.run_step()
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 417, in run_step
    result = self.run_batch(new_batch)
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 623, in run_batch
    logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker.py", line 113, in forward_batch_generation
    forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/forward_batch_info.py", line 171, in init_new
    model_runner.lora_manager.prepare_lora_batch(ret)
  File "/sgl-workspace/sglang/python/sglang/srt/lora/lora_manager.py", line 290, in prepare_lora_batch
    self.load_lora(uid, index)
  File "/sgl-workspace/sglang/python/sglang/srt/lora/lora_manager.py", line 264, in load_lora
    self.A_buffer[lora_weight_name][i][buffer_id].copy_(weights)
RuntimeError: The size of tensor a (2304) must match the size of tensor b (2048) at non-singleton dimension 1
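
The size mismatch (2304 vs 2048) is consistent with the default fallback assuming every projection's dimension equals hidden_size, whereas Gemma2's attention projections use num_attention_heads * head_dim, which differs from hidden_size. Below is a hedged sketch of a per-model get_hidden_dim() that accounts for this; the module names and config attributes (head_dim, num_attention_heads, num_key_value_heads) are assumptions for illustration, not the actual fix:

def get_hidden_dim(self, module_name):
    # Sketch only: return (input_dim, output_dim) per stacked module, using
    # head_dim * num_heads for attention projections instead of hidden_size.
    config = self.config
    if module_name == "qkv_proj":
        return config.hidden_size, config.head_dim * (
            config.num_attention_heads + 2 * config.num_key_value_heads
        )
    elif module_name == "o_proj":
        return config.head_dim * config.num_attention_heads, config.hidden_size
    elif module_name == "gate_up_proj":
        return config.hidden_size, config.intermediate_size
    elif module_name == "down_proj":
        return config.intermediate_size, config.hidden_size
    else:
        raise NotImplementedError()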

upskyy (Contributor) commented Dec 9, 2024

The gemma2 issue was solved by PR #2330.
