Skip to content

Commit

Permalink
Browse files: browse the repository at this point in the history
  • Loading branch information
hiyouga authored and zhangzh committed Jul 1, 2024
1 parent 7ba0428 commit bdeb03d
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/llamafactory/model/loader.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -14,6 +14,7 @@

from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead

Expand Down Expand Up @@ -175,6 +176,10 @@ def load_model(

if not is_trainable:
model.requires_grad_(False)
for param in model.parameters():
if param.data.dtype == torch.float32 and model_args.compute_dtype != torch.float32:
param.data = param.data.to(model_args.compute_dtype)

model.eval()
else:
model.train()
Expand Down

0 comments on commit bdeb03d

Please sign in to comment.