Skip to content

Commit

Permalink
fix #2961
Browse files Browse the repository at this point in the history
  • Loading branch information
hiyouga committed Mar 26, 2024
1 parent 7ea1a1f commit 511f675
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 31 deletions.
2 changes: 2 additions & 0 deletions src/llmtuner/data/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,8 @@ def get_template_and_fix_tokenizer(

_register_template(
name="vanilla",
format_separator=EmptyFormatter(slots=["\n"]),
efficient_eos=True,
)


Expand Down
37 changes: 20 additions & 17 deletions src/llmtuner/eval/template.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,48 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Tuple
from typing import Dict, List, Sequence, Tuple

from ..data import Role
from ..extras.constants import CHOICES


if TYPE_CHECKING:
from datasets import Dataset


@dataclass
class EvalTemplate:
system: str
choice: str
answer: str
prefix: str

def parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
def _parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
r"""
input: a dict with keys {"question", "A", "B", "C", "D", "answer"}
output: a tuple of (prompt, response)
"""
candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example]
return "".join([example["question"]] + candidates + [self.answer]), example["answer"]

def format_example(
self, target_data: Dict[str, str], support_set: "Dataset", subject_name: str
self, target_data: Dict[str, str], support_set: Sequence[Dict[str, str]], subject_name: str
) -> List[Dict[str, str]]:
r"""
Converts dataset examples to messages.
"""
messages = []
for k in range(len(support_set)):
prompt, response = self.parse_example(support_set[k])
messages.append({"role": Role.USER, "content": prompt})
messages.append({"role": Role.ASSISTANT, "content": response})
prompt, response = self._parse_example(support_set[k])
messages.append({"role": Role.USER.value, "content": prompt})
messages.append({"role": Role.ASSISTANT.value, "content": response})

prompt, response = self.parse_example(target_data)
messages.append({"role": Role.USER, "content": prompt})
messages.append({"role": Role.ASSISTANT, "content": response})
prompt, response = self._parse_example(target_data)
messages.append({"role": Role.USER.value, "content": prompt})
messages.append({"role": Role.ASSISTANT.value, "content": response})
messages[0]["content"] = self.system.format(subject=subject_name) + messages[0]["content"]
return messages


eval_templates: Dict[str, "EvalTemplate"] = {}


def register_eval_template(name: str, system: str, choice: str, answer: str, prefix: str) -> None:
def _register_eval_template(name: str, system: str, choice: str, answer: str, prefix: str) -> None:
eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer, prefix=prefix)


Expand All @@ -49,7 +52,7 @@ def get_eval_template(name: str) -> "EvalTemplate":
return eval_template


register_eval_template(
_register_eval_template(
name="en",
system="The following are multiple choice questions (with answers) about {subject}.\n\n",
choice="\n{choice}. {content}",
Expand All @@ -58,10 +61,10 @@ def get_eval_template(name: str) -> "EvalTemplate":
)


register_eval_template(
_register_eval_template(
name="zh",
system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
choice="\n{choice}. {content}",
answer="\n答案:",
prefix="\n",
prefix=" ",
)
7 changes: 7 additions & 0 deletions src/llmtuner/hparams/finetuning_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@ class RLHFArguments:
default="sigmoid",
metadata={"help": "The type of DPO loss to use."},
)
dpo_label_smoothing = field(
default=0.0,
metadata={"help": "The robust DPO label smoothing parameter in cDPO that should be between 0 and 0.5."},
)
dpo_ftx: float = field(
default=0.0,
metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."},
Expand Down Expand Up @@ -248,6 +252,9 @@ def split_arg(arg):
if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
raise ValueError("`reward_model_type` cannot be lora for Freeze/Full PPO training.")

if self.stage == "dpo" and self.dpo_loss != "sigmoid" and self.dpo_label_smoothing > 1e-6:
raise ValueError("`dpo_label_smoothing` is only valid for sigmoid loss function.")

if self.use_llama_pro and self.finetuning_type == "full":
raise ValueError("`use_llama_pro` is only valid for the Freeze or LoRA method.")

Expand Down
2 changes: 1 addition & 1 deletion src/llmtuner/model/patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def _configure_quantization(
"""
if getattr(config, "quantization_config", None): # ptq
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantized models.")

init_kwargs["device_map"] = {"": get_current_device()}
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
Expand Down
13 changes: 5 additions & 8 deletions src/llmtuner/train/dpo/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,9 @@
class CustomDPOTrainer(DPOTrainer):
def __init__(
self,
beta: float,
loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"],
ftx_gamma: float,
model: Union["PreTrainedModel", torch.nn.Module],
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]],
finetuning_args: "FinetuningArguments",
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
disable_dropout: bool = True,
**kwargs,
):
Expand All @@ -47,10 +44,10 @@ def __init__(
self._peft_has_been_casted_to_bf16 = False

self.ref_model = ref_model
self.beta = beta
self.label_smoothing = 0
self.loss_type = loss_type
self.ftx_gamma = ftx_gamma
self.beta = finetuning_args.dpo_beta
self.label_smoothing = finetuning_args.dpo_label_smoothing
self.loss_type = finetuning_args.dpo_loss
self.ftx_gamma = finetuning_args.dpo_ftx
self._stored_metrics = defaultdict(lambda: defaultdict(list))

Trainer.__init__(self, model=model, **kwargs)
Expand Down
5 changes: 1 addition & 4 deletions src/llmtuner/train/dpo/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,10 @@ def run_dpo(

# Initialize our Trainer
trainer = CustomDPOTrainer(
beta=finetuning_args.dpo_beta,
loss_type=finetuning_args.dpo_loss,
ftx_gamma=finetuning_args.dpo_ftx,
finetuning_args=finetuning_args,
model=model,
ref_model=ref_model,
args=training_args,
finetuning_args=finetuning_args,
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=callbacks,
Expand Down
1 change: 0 additions & 1 deletion src/llmtuner/train/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from transformers.optimization import get_scheduler
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names
from transformers.utils.versions import require_version

from ..extras.logging import get_logger
from ..extras.packages import is_galore_available
Expand Down

0 comments on commit 511f675

Please sign in to comment.