Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
hiyouga authored and xtchen96 committed Jul 17, 2024
1 parent 31bd45c commit ff575f7
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
15 changes: 10 additions & 5 deletions src/llamafactory/train/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,18 +75,23 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
get_template_and_fix_tokenizer(tokenizer, data_args.template)
model = load_model(tokenizer, model_args, finetuning_args) # must after fixing tokenizer to resize vocab

if getattr(model, "quantization_method", None) and model_args.adapter_name_or_path is not None:
if getattr(model, "quantization_method", None) is not None and model_args.adapter_name_or_path is not None:
raise ValueError("Cannot merge adapters to a quantized model.")

if not isinstance(model, PreTrainedModel):
raise ValueError("The model is not a `PreTrainedModel`, export aborted.")

if getattr(model, "quantization_method", None) is None: # cannot convert dtype of a quantized model
output_dtype = getattr(model.config, "torch_dtype", torch.float16)
if getattr(model, "quantization_method", None) is not None: # quantized model adopts float16 type
setattr(model.config, "torch_dtype", torch.float16)
else:
if model_args.infer_dtype == "auto":
output_dtype = getattr(model.config, "torch_dtype", torch.float16)
else:
output_dtype = getattr(torch, model_args.infer_dtype)

setattr(model.config, "torch_dtype", output_dtype)
model = model.to(output_dtype)
else:
setattr(model.config, "torch_dtype", torch.float16)
logger.info("Convert model dtype to: {}.".format(output_dtype))

model.save_pretrained(
save_directory=model_args.export_dir,
Expand Down
2 changes: 1 addition & 1 deletion src/llamafactory/webui/components/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def save_model(
template: str,
visual_inputs: bool,
export_size: int,
export_quantization_bit: int,
export_quantization_bit: str,
export_quantization_dataset: str,
export_device: str,
export_legacy_format: bool,
Expand Down

0 comments on commit ff575f7

Please sign in to comment.