Skip to content

Commit

Permalink
support batch_eval_metrics, fix #4826
Browse files Browse the repository at this point in the history
  • Loading branch information
hiyouga committed Jul 16, 2024
1 parent bda302f commit d774b94
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 36 deletions.
15 changes: 14 additions & 1 deletion src/llamafactory/extras/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import gc
import os
from typing import TYPE_CHECKING, Tuple
from typing import TYPE_CHECKING, Tuple, Union

import torch
import transformers.dynamic_module_utils
Expand All @@ -43,6 +43,8 @@


if TYPE_CHECKING:
from numpy.typing import NDArray

from ..hparams import ModelArguments


Expand Down Expand Up @@ -178,6 +180,17 @@ def is_gpu_or_npu_available() -> bool:
return is_torch_npu_available() or is_torch_cuda_available()


def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":
if isinstance(inputs, torch.Tensor):
inputs = inputs.cpu()
if inputs.dtype == torch.bfloat16: # numpy does not support bfloat16 until 1.21.4
inputs = inputs.to(torch.float32)

inputs = inputs.numpy()

return inputs


def skip_check_imports() -> None:
if os.environ.get("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
transformers.dynamic_module_utils.check_imports = get_relative_imports
Expand Down
4 changes: 4 additions & 0 deletions src/llamafactory/hparams/finetuning_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,10 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
default=False,
metadata={"help": "Whether or not to train the multimodal projector for MLLM only."},
)
compute_accuracy: bool = field(
default=False,
metadata={"help": "Whether or not to compute the token-level accuracy at evaluation."},
)
plot_loss: bool = field(
default=False,
metadata={"help": "Whether or not to save the training loss curves."},
Expand Down
3 changes: 3 additions & 0 deletions src/llamafactory/hparams/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if training_args.predict_with_generate and data_args.eval_dataset is None:
raise ValueError("Cannot use `predict_with_generate` if `eval_dataset` is None.")

if training_args.predict_with_generate and finetuning_args.compute_accuracy:
raise ValueError("Cannot use `predict_with_generate` and `compute_accuracy` together.")

if training_args.do_train and model_args.quantization_device_map == "auto":
raise ValueError("Cannot use device map for quantized models in training.")

Expand Down
22 changes: 19 additions & 3 deletions src/llamafactory/train/rm/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Dict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Optional

import numpy as np

from ...extras.misc import numpify


if TYPE_CHECKING:
from transformers import EvalPrediction


def compute_accuracy(eval_preds: "EvalPrediction") -> Dict[str, float]:
return {"accuracy": np.mean(eval_preds.predictions[0] > eval_preds.predictions[1])}
@dataclass
class ComputeAccuracy:
def __post_init__(self):
self.score_dict = {"accuracy": []}

def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]:
chosen_scores, rejected_scores = numpify(eval_preds.predictions[0]), numpify(eval_preds.predictions[1])
if not chosen_scores.shape:
self.score_dict["accuracy"].append(chosen_scores > rejected_scores)
else:
for i in range(len(chosen_scores)):
self.score_dict["accuracy"].append(chosen_scores[i] > rejected_scores[i])

if compute_result:
return {"accuracy": float(np.mean(self.score_dict["accuracy"]))}
4 changes: 2 additions & 2 deletions src/llamafactory/train/rm/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from ...model import load_model, load_tokenizer
from ..callbacks import fix_valuehead_checkpoint
from ..trainer_utils import create_modelcard_and_push
from .metric import compute_accuracy
from .metric import ComputeAccuracy
from .trainer import PairwiseTrainer


Expand Down Expand Up @@ -55,7 +55,7 @@ def run_rm(
finetuning_args=finetuning_args,
data_collator=data_collator,
callbacks=callbacks,
compute_metrics=compute_accuracy,
compute_metrics=ComputeAccuracy(),
**dataset_module,
**tokenizer_module,
)
Expand Down
52 changes: 29 additions & 23 deletions src/llamafactory/train/sft/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@
# limitations under the License.

from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING, Dict, Optional

import numpy as np
import torch
from transformers.utils import is_jieba_available, is_nltk_available

from ...extras.constants import IGNORE_INDEX
from ...extras.misc import numpify
from ...extras.packages import is_rouge_available


Expand All @@ -43,17 +44,6 @@
from rouge_chinese import Rouge


def compute_accuracy(eval_preds: "EvalPrediction") -> Dict[str, float]:
preds, labels = eval_preds.predictions, eval_preds.label_ids
accuracies = []
for i in range(len(preds)):
pred, label = preds[i, :-1], labels[i, 1:]
label_mask = label != IGNORE_INDEX
accuracies.append(np.mean(pred[label_mask] == label[label_mask]))

return {"accuracy": float(np.mean(accuracies))}


def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor":
if isinstance(logits, (list, tuple)):
if logits[0].dim() == 3: # (batch_size, seq_len, vocab_size)
Expand All @@ -68,19 +58,34 @@ def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "tor


@dataclass
class ComputeMetrics:
class ComputeAccuracy:
def __post_init__(self):
self.score_dict = {"accuracy": []}

def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]:
preds, labels = numpify(eval_preds.predictions), numpify(eval_preds.label_ids)
for i in range(len(preds)):
pred, label = preds[i, :-1], labels[i, 1:]
label_mask = label != IGNORE_INDEX
self.score_dict["accuracy"].append(np.mean(pred[label_mask] == label[label_mask]))

if compute_result:
return {"accuracy": float(np.mean(self.score_dict["accuracy"]))}


@dataclass
class ComputeSimilarity:
r"""
Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer.
Wraps the tokenizer into metric functions, used in CustomSeq2SeqTrainer.
"""

tokenizer: "PreTrainedTokenizer"

def __call__(self, eval_preds: "EvalPrediction") -> Dict[str, float]:
r"""
Uses the model predictions to compute metrics.
"""
preds, labels = eval_preds.predictions, eval_preds.label_ids
score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
def __post_init__(self):
self.score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}

def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]:
preds, labels = numpify(eval_preds.predictions), numpify(eval_preds.label_ids)

preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)
Expand All @@ -100,9 +105,10 @@ def __call__(self, eval_preds: "EvalPrediction") -> Dict[str, float]:
result = scores[0]

for k, v in result.items():
score_dict[k].append(round(v["f"] * 100, 4))
self.score_dict[k].append(round(v["f"] * 100, 4))

bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
score_dict["bleu-4"].append(round(bleu_score * 100, 4))
self.score_dict["bleu-4"].append(round(bleu_score * 100, 4))

return {k: float(np.mean(v)) for k, v in score_dict.items()}
if compute_result:
return {k: float(np.mean(v)) for k, v in self.score_dict.items()}
21 changes: 14 additions & 7 deletions src/llamafactory/train/sft/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..trainer_utils import create_modelcard_and_push
from .metric import ComputeMetrics, compute_accuracy, eval_logit_processor
from .metric import ComputeAccuracy, ComputeSimilarity, eval_logit_processor
from .trainer import CustomSeq2SeqTrainer


Expand All @@ -46,15 +46,12 @@ def run_sft(
dataset_module = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)

if training_args.predict_with_generate:
tokenizer.padding_side = "left" # use left-padding in generation

if getattr(model, "is_quantized", False) and not training_args.do_train:
setattr(model, "_hf_peft_config_loaded", True) # hack here: make model compatible with prediction

data_collator = SFTDataCollatorWith4DAttentionMask(
tokenizer=tokenizer,
pad_to_multiple_of=8 if tokenizer.padding_side == "right" else None, # for shift short attention
pad_to_multiple_of=8 if training_args.do_train else None, # for shift short attention
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
block_diag_attn=model_args.block_diag_attn,
attn_implementation=getattr(model.config, "_attn_implementation", None),
Expand All @@ -66,17 +63,24 @@ def run_sft(
training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
training_args.remove_unused_columns = False if model_args.visual_inputs else training_args.remove_unused_columns

# Metric utils
metric_module = {}
if training_args.predict_with_generate:
metric_module["compute_metrics"] = ComputeSimilarity(tokenizer=tokenizer)
elif finetuning_args.compute_accuracy:
metric_module["compute_metrics"] = ComputeAccuracy()
metric_module["preprocess_logits_for_metrics"] = eval_logit_processor

# Initialize our Trainer
trainer = CustomSeq2SeqTrainer(
model=model,
args=training_args,
finetuning_args=finetuning_args,
data_collator=data_collator,
callbacks=callbacks,
compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else compute_accuracy,
preprocess_logits_for_metrics=None if training_args.predict_with_generate else eval_logit_processor,
**dataset_module,
**tokenizer_module,
**metric_module,
)

# Keyword arguments for `model.generate`
Expand All @@ -95,6 +99,9 @@ def run_sft(
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "eval_accuracy"])

if training_args.predict_with_generate:
tokenizer.padding_side = "left" # use left-padding in generation

# Evaluation
if training_args.do_eval:
metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
Expand Down

0 comments on commit d774b94

Please sign in to comment.