Major and breaking changes
General reward model support for Online DPO
Online DPO initially supported only reward models that shared the tokenizer and chat template of the trained model. Now, you can use any reward model.
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
from trl import OnlineDPOConfig, OnlineDPOTrainer

# model_name, reward_model_name and dataset_name are placeholders: set them to your own checkpoints and dataset.
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")

# The reward model may use a different tokenizer and chat template than the trained model.
reward_model = AutoModelForSequenceClassification.from_pretrained(reward_model_name, num_labels=1)
reward_tokenizer = AutoTokenizer.from_pretrained(reward_model_name, truncation=True, truncation_side="left")

dataset = load_dataset(dataset_name, split="train")
training_args = OnlineDPOConfig(output_dir="...")
trainer = OnlineDPOTrainer(
    model=model,
    reward_model=reward_model,
    args=training_args,
    train_dataset=dataset,
    processing_class=tokenizer,
    reward_processing_class=reward_tokenizer,
)
trainer.train()
by @qgallouedec in #2276
Migration PPOv2 -> PPO
The PPOv2 trainer has been renamed to PPO. The old PPO trainer has been removed. PPOv2 is now deprecated and will be removed in the next release.
- trainer = PPOv2Trainer(...)
+ trainer = PPOTrainer(...)
by @qgallouedec in #2174
Refactor ScriptArguments
We had ScriptArguments, SFTScriptArguments, DPOScriptArguments and RewardScriptArguments. Since they all share mostly the same fields, we've merged them into a single ScriptArguments class.
SFTScriptArguments, DPOScriptArguments and RewardScriptArguments still exist but are deprecated and will be removed in the next release.
- script_args = DPOScriptArguments(...)
+ script_args = ScriptArguments(...)
by @qgallouedec in #2145
Soft judges for PairRM
The judge method of PairRMJudge now accepts a return_scores flag. When it is set, the method returns the probability score of the first completion in each pair (instead of the rank of the preferred completion). The logits used for the probability score can be scaled by an optional temperature parameter.
from trl import PairRMJudge
pairrm_judge = PairRMJudge()
prompts = ["Translate 'hello' to French", "What's the capital of Japan?"]
completions = [["Bonjour", "Salut"], ["Kyoto", "Tokyo"]]
results = pairrm_judge.judge(prompts, completions, return_scores=True)
print(results) # [0.7492601275444031, 0.0005497377132996917]
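To make the effect of the temperature parameter concrete, here is a minimal sketch in plain Python. It assumes the probability is a softmax over the two completion scores after dividing them by the temperature; the function name first_completion_probability and the example scores are hypothetical and are not part of TRL or PairRM.

```python
import math


def first_completion_probability(scores, temperature=1.0):
    """Softmax over a pair of scores; returns the probability of the first one.

    Illustrative only: assumes scores are divided by the temperature
    before the softmax is applied.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [s / temperature for s in scores]
    # Subtract the max before exponentiating for numerical stability.
    shift = max(scaled)
    exps = [math.exp(s - shift) for s in scaled]
    return exps[0] / sum(exps)


# A higher temperature flattens the distribution toward 0.5.
print(first_completion_probability([1.0, 0.0]))
print(first_completion_probability([1.0, 0.0], temperature=10.0))
```

Under this assumption, a temperature above 1 pulls the score toward 0.5 and a temperature below 1 pushes it toward 0 or 1.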
Use pairwise judges for online methods
The OnlineDPOTrainer and any trainers that inherit from it (NashMDTrainer and XPOTrainer) can now accept an initialized PairwiseJudge instead of a reward model.
from datasets import load_dataset
from trl import OnlineDPOConfig, OnlineDPOTrainer, PairRMJudge
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
judge = PairRMJudge()
train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")
training_args = OnlineDPOConfig(output_dir="Qwen2-0.5B-OnlineDPO", logging_steps=10)
trainer = OnlineDPOTrainer(
    model=model, judge=judge, args=training_args, processing_class=tokenizer, train_dataset=train_dataset
)
trainer.train()
Rename trainer arg tokenizer to processing_class
The tokenizer argument in the trainers has been renamed to processing_class to better reflect the fact that it can be not only a tokenizer but also a processor.
- trainer = DPOTrainer(model, args=training_args, train_dataset=dataset, tokenizer=tokenizer)
+ trainer = DPOTrainer(model, args=training_args, train_dataset=dataset, processing_class=tokenizer)
tokenizer is still supported for SFTTrainer and DPOTrainer but deprecated and will be removed in the next release.
by @qgallouedec in #2162
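If your own training scripts expose a tokenizer keyword and you want to migrate callers gradually, a small shim along the following lines may help. This is a sketch only: resolve_processing_class is a hypothetical helper for user code, not a TRL function, and it does not describe how TRL handles the deprecation internally.

```python
import warnings


def resolve_processing_class(processing_class=None, tokenizer=None):
    """Return the object to pass as `processing_class`.

    Hypothetical helper: accepts the old `tokenizer` keyword, warns about it,
    and rejects calls that supply both keywords.
    """
    if tokenizer is not None:
        if processing_class is not None:
            raise ValueError("Pass only one of `tokenizer` or `processing_class`.")
        warnings.warn(
            "`tokenizer` is deprecated; use `processing_class` instead.",
            FutureWarning,
            stacklevel=2,
        )
        return tokenizer
    return processing_class


# Old-style call sites keep working but emit a FutureWarning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    resolved = resolve_processing_class(tokenizer="my-tokenizer")
print(resolved, len(caught))
```

The resolved object can then be passed to the trainer as processing_class=resolved.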
Adding weighted preference optimization (WPO) to DPO
The WPO paper adapts off-policy data to resemble on-policy data more closely by reweighting preference pairs according to their probability under the current policy. To use this method, set the use_weighting flag to True in the DPOConfig.
DPOConfig(..., use_weighting=True)
by @gaetanlop in #2141
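The reweighting idea can be illustrated with a simplified sketch in plain Python. It assumes each pair's weight is the exponential of the summed length-normalized log-probabilities of the chosen and rejected completions under the current policy, capped at 1. The function names and numbers are hypothetical, and this is not the exact computation performed by TRL or the paper.

```python
import math


def pair_weight(chosen_token_logps, rejected_token_logps):
    """Weight for one preference pair from per-token log-probabilities.

    Simplified assumption: average the token log-probs of each completion,
    add the two averages, exponentiate, and cap the result at 1.
    """
    avg_chosen = sum(chosen_token_logps) / len(chosen_token_logps)
    avg_rejected = sum(rejected_token_logps) / len(rejected_token_logps)
    return min(1.0, math.exp(avg_chosen + avg_rejected))


def weighted_mean_loss(losses, weights):
    """Mean of per-pair losses after scaling each loss by its weight."""
    return sum(w * l for w, l in zip(weights, losses)) / len(losses)


# A pair the policy finds likely gets a larger weight than an unlikely one.
likely = pair_weight([-0.1, -0.2], [-0.3])
unlikely = pair_weight([-3.0, -4.0], [-5.0])
print(likely, unlikely)
print(weighted_mean_loss([0.7, 0.7], [likely, unlikely]))
```

In this sketch, pairs that look more on-policy contribute more to the loss, which is the intuition behind the use_weighting flag.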
🃏 Model card for TRL
Using trainer.push_to_hub() now automatically creates a model card that includes:
- A link to the base model used
- A link to the dataset used for training
- A link to the TRL repository
- Sample demo code
- A link to the associated Weights & Biases run
- A link to the paper detailing the training procedure
- Versions of dependencies
- BibTeX citations for both the training procedure and TRL
All links are properly formatted to allow cross-referencing, enabling traceability back to sources (e.g., the model appears linked on the paper’s page).
by @qgallouedec in #2123
Minor
Conversational dataset support
You can now use conversational datasets directly, without needing to apply a chat template beforehand, for the following trainers:
- BCOTrainer (by @qgallouedec in PR #2107)
- CPOTrainer (by @qgallouedec in PR #2144)
- DPOTrainer (by @qgallouedec in PR #2131)
- KTOTrainer (by @qgallouedec in PR #2248)
- ORPOTrainer (by @qgallouedec in PR #2184)
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import DPOConfig, DPOTrainer
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
dataset = load_dataset(dataset_name, split="train")
# Not needed anymore:
#
# def process(example):
#     prompt = tokenizer.apply_chat_template(example["prompt"], tokenize=False, add_generation_prompt=True)
#     prompt_chosen = tokenizer.apply_chat_template(example["prompt"] + example["chosen"], tokenize=False)
#     chosen = prompt_chosen[len(prompt) :]
#     prompt_rejected = tokenizer.apply_chat_template(example["prompt"] + example["rejected"], tokenize=False)
#     rejected = prompt_rejected[len(prompt) :]
#     return {"prompt": prompt, "chosen": chosen, "rejected": rejected}
#
# dataset = dataset.map(process)
training_args = DPOConfig(output_dir="...")
trainer = DPOTrainer(model, args=training_args, train_dataset=dataset, processing_class=tokenizer)
trainer.train()
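For reference, the following sketch contrasts a standard preference row with a conversational one, where each field is a list of role/content messages. The example texts and the looks_conversational helper are hypothetical and only illustrate the shape of the data; they are not TRL's own detection logic.

```python
# Standard format: plain strings, with any chat template already applied.
standard_row = {
    "prompt": "What color is the sky?",
    "chosen": "It is blue.",
    "rejected": "It is green.",
}

# Conversational format: lists of messages with "role" and "content" keys.
conversational_row = {
    "prompt": [{"role": "user", "content": "What color is the sky?"}],
    "chosen": [{"role": "assistant", "content": "It is blue."}],
    "rejected": [{"role": "assistant", "content": "It is green."}],
}


def looks_conversational(row):
    """Hypothetical check: the prompt is a non-empty list of role/content dicts."""
    prompt = row.get("prompt")
    return (
        isinstance(prompt, list)
        and len(prompt) > 0
        and all(isinstance(m, dict) and "role" in m and "content" in m for m in prompt)
    )


print(looks_conversational(standard_row))
print(looks_conversational(conversational_row))
```

Rows shaped like conversational_row can be passed to the trainers listed above without a manual apply_chat_template step.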
Refactor DPO data processing
For more information, see PR #2209.
trl env for printing system info
You can now use trl env to print system information, including the platform, Python version, PyTorch version, CUDA device(s), and versions of various libraries.
$ trl env
Copy-paste the following information when reporting an issue:
- Platform: Linux-5.15.0-1048-aws-x86_64-with-glibc2.31
- Python version: 3.11.9
- PyTorch version: 2.4.0
- CUDA device(s): NVIDIA H100 80GB HBM3
- Transformers version: 4.47.0.dev0
- Accelerate version: 0.19.0
- Accelerate config: not found
- Datasets version: 3.0.2
- HF Hub version: 0.26.1
- TRL version: 0.12.0+14ef1ab
- bitsandbytes version: 0.44.1
- DeepSpeed version: 0.15.3
- Diffusers version: 0.30.3
- Liger-Kernel version: 0.3.0
- LLM-Blender version: 0.0.2
- OpenAI version: 1.46.0
- PEFT version: 0.13.2
by @qgallouedec in #2104
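If you need a similar report from inside a script, a rough equivalent can be assembled with the standard library. This is a sketch under assumptions: collect_env_info and package_version are hypothetical names, the package list is arbitrary, and the output is not guaranteed to match what trl env prints.

```python
import platform
from importlib.metadata import PackageNotFoundError, version


def package_version(name):
    """Installed version of a distribution, or a placeholder if it is missing."""
    try:
        return version(name)
    except PackageNotFoundError:
        return "not installed"


def collect_env_info(packages=("torch", "transformers", "trl")):
    """Gather platform, Python and package versions into a dict."""
    info = {
        "Platform": platform.platform(),
        "Python version": platform.python_version(),
    }
    for name in packages:
        info[f"{name} version"] = package_version(name)
    return info


for key, value in collect_env_info().items():
    print(f"- {key}: {value}")
```

For bug reports, the output of trl env itself remains the thing to paste.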
Sequence-Level KD
From the GKD paper:
Sequence-Level KD (Kim & Rush, 2016). SeqKD maximizes the likelihood of high probability sequences generated by the teacher, and can be viewed as supervised FT on teacher-generated outputs.
SeqKD is taken as a baseline in the paper. It is now possible to use Sequence-Level KD in the GKDTrainer by setting seq_kd=True in the GKDConfig.
training_args = GKDConfig(..., seq_kd=True)
Default dataset_text_field to "text"
Since many users use "text" as the column name for textual data in datasets, we've made it the default (previously a required argument) in SFTConfig. Now, specifying dataset_text_field="text" is no longer necessary.
SFTConfig(
...,
- dataset_text_field="text",
)
by @qgallouedec in #2078
What's Changed
- [SFT] fix neftune_noise_alpha in SFTTrainer by @kashif in #1841
- Standardize training_args by @qgallouedec in #2082
- Fix typo in ORPO example. by @skandermoalla in #2092
- Fix Inconsistency with IsShardedQLoRA Setting by @fabianlim in #2089
- Fixes #2087 - _process_tokens for empty prompts in KTOTrainer by @gabikadlecova in #2093
- KTO: fix logits metric, add logits metric to BCOTrainer by @claralp in #2094
- Clean up README and remove openrlbenchmark dependency by @lewtun in #2085
- Fix PPO/RLOO examples by @lewtun in #2100
- [CLI] trl env for printing system info by @qgallouedec in #2104
- [RewardTrainer] Tokenize inputs within trainer by @lewtun in #2102
- Fix documentation links by @qgallouedec in #2105
- fix formatting by @kashif in #2109
- [online-dpo] allow parse-args as list of floats by @kashif in #2108
- Fix pack test by @qgallouedec in #2111
- BCOTrainer conversational dataset support by @qgallouedec in #2107
- Generalizes VSFT script to support REDACTED by @edbeeching in #2120
- Update example_overview.md by @kashif in #2125
- Remove max_length from RewardDataCollatorWithPadding by @qgallouedec in #2119
- Standardize pushing to Hub in examples by @qgallouedec in #2126
- Eos token encouragement Clarification by @August-murr in #2128
- Tokenize row during in training_step by @qgallouedec in #2117
- ♻️ Standardize script_args by @qgallouedec in #2130
- Add table for WinRateCallback by @lewtun in #2116
- 🧹 Style by @qgallouedec in #2132
- arXiv to HF Papers by @qgallouedec in #2133
- Add correct label for WinRateCallback table by @lewtun in #2134
- 🃏 Model card for TRL by @qgallouedec in #2123
- Rename dpo_visual.py example to dpo_vlm.py by @qgallouedec in #2139
- [GKD] Set custom EOS tokens in generation config by @lewtun in #2142
- Fix attention mask warning chat cli by @qgallouedec in #2147
- [CI] Don't use eval_strategy="steps" when no eval dataset by @qgallouedec in #2152
- Conversational dataset support for DPOTrainer by @qgallouedec in #2131
- 🩹 [Hotfix] Add setter for tokenizer by @qgallouedec in #2163
- ↩️ Revert tokenizer hotfix #2163 by @qgallouedec in #2165
- chore: update test_cli.py by @eltociear in #2168
- 🏷️ Model badges in trainer documentation by @qgallouedec in #2160
- Default dataset_text_field to "text" by @qgallouedec in #2078
- Update trl version in CITATION.cff by @qgallouedec in #2171
- 🗑️ Set deprecation version for DPO and SFT arguments to version 0.13 by @qgallouedec in #2170
- Conversational dataset support for CPOTrainer by @qgallouedec in #2144
- Capybara replaced with ultrafeedback_binarized by @August-murr in #2183
- minor KTO setting changes + KL batch size by @kawine in #2153
- 🏷️ Model badges: select only TRL models by @qgallouedec in #2178
- Rename trainer arg tokenizer to processing_class by @qgallouedec in #2162
- Update documentation CLI Chat by @qgallouedec in #2191
- 🃏 Model card: "unsloth" tag by @qgallouedec in #2173
- [CI] fix dpo gpu ci tests by @kashif in #2189
- Update CONTRIBUTING.md by @kushal34712 in #2181
- Fix RLOO checkpointing by @bartoszzuk in #2114
- Update README.md by @PRIYANKjakharia in #2186
- skip_prompt=True in TextIteratorStreamer by @qgallouedec in #2193
- [CI] Use transformers from source in "tests_no_optional_dep" by @qgallouedec in #2198
- Fix the bug of DPOTrainer where the coefficient of aux_loss is always 0 during training by @muupan in #2200
- Fix the bug of aux_loss coefficient being 0 in BCOTrainer, CPOTrainer, KTOTrainer, and ORPOTrainer by @muupan in #2201
- [DPO] Adding weighted preference optimization (WPO) by @gaetanlop in #2141
- [GKD] interpolate in prob. space by @kashif in #2204
- Drop decoder_input_ids in DPOTrainer by @qgallouedec in #2208
- Update incorrect data processing in DataCollatorForChatML by @ruijunfeng in #2172
- Update log_example_reports.py by @DhruvKadam-git in #2182
- Report to "none" in GKD test by @qgallouedec in #2214
- [Judges] Soft judges for PairRM by @kashif in #2221
- Update README.md by @kushal34712 in #2180
- Updated README.md with CLI examples and additional usage instructions by @Singhal1808 in #2199
- trl env report all cuda devices by @qgallouedec in #2216
- Conversational dataset support for ORPOTrainer by @qgallouedec in #2184
- 🕊️ Migration PPOv2 -> PPO by @qgallouedec in #2174
- Add Sequence-Level KD by @mst272 in #2220
- Update dataset_formats.mdx by @August-murr in #2222
- 📒 Fix type/format confusions by @qgallouedec in #2223
- Update commands for code linting in contributing guidelines by @Ben-Schneider-code in #2225
- Refactor ScriptArguments by @qgallouedec in #2145
- Updated ScriptArguments warning messages by @sergiopaniego in #2230
- DPO support remove_unused_columns by @qgallouedec in #2233
- Setting capture output to False by @August-murr in #2239
- Update SFT examples by @lewtun in #2244
- Enhancements to Log Report Script: Improved Error Handling and Logging by @DhruvKadam-git in #2232
- 🔀 Rename get_batch_sample and add num_items_in_batch to compute_loss by @qgallouedec in #2246
- Refactor DPO data processing by @qgallouedec in #2209
- Update dataset_formats.mdx by @cameronphchen in #2259
- Use processing_class instead of tokenizer in LogCompletionsCallback by @qgallouedec in #2261
- Adjust padding in batch generation by @gaetanlop in #2251
- setup_chat_format: throw error if there is already a template in base model by @ngxson in #2252
- Bump the minimum transformers version to v4.46 by @qgallouedec in #2245
- Conversational dataset support for KTOTrainer by @qgallouedec in #2248
- [Judges] use the pair-judges in online-preference trainers by @kashif in #2243
- Update reward_modeling.py by @cameronphchen in #2266
- ♾️ Fix test generation max_new_tokens by @qgallouedec in #2272
- Refactor log_reports.py for Improved Logging, File Processing, and Slack Payload Handling by @Mefisto04 in #2249
- Replace log(sigmoid(log_odds) with logsigmoid(log_odds) for ORPO by @zhanwenchen in #2274
- [KTO/BCO Trainer] add bos_token_id only if it exists by @seanexp in #2279
- Fix the computation of KL divergence loss in Nash MD by @d-tiapkin in #2277
- Don't pass eval_dataset in to trainers when no eval strategy by @qgallouedec in #2270
- Update callbacks.py for fix small python type error by @anch0vy in #2285
- Use any reward model for online methods by @qgallouedec in #2276
- Clean dependencies by @qgallouedec in #2298
- Fix _save_checkpoint for online methods by @qgallouedec in #2288
- Refactor unit tests to use standard unittest assertion methods by @ccs96307 in #2283
- Remove stale bot by @qgallouedec in #2300
- Fix no optional dependencies by @qgallouedec in #2301
- Add optimizer_cls_and_kwargs attribute to PPO and RLOO by @qgallouedec in #2302
- Specify min versions by @qgallouedec in #2303
New Contributors
- @skandermoalla made their first contribution in #2092
- @fabianlim made their first contribution in #2089
- @gabikadlecova made their first contribution in #2093
- @August-murr made their first contribution in #2128
- @kushal34712 made their first contribution in #2181
- @PRIYANKjakharia made their first contribution in #2186
- @ruijunfeng made their first contribution in #2172
- @DhruvKadam-git made their first contribution in #2182
- @Singhal1808 made their first contribution in #2199
- @mst272 made their first contribution in #2220
- @Ben-Schneider-code made their first contribution in #2225
- @sergiopaniego made their first contribution in #2230
- @cameronphchen made their first contribution in #2259
- @ngxson made their first contribution in #2252
- @Mefisto04 made their first contribution in #2249
- @zhanwenchen made their first contribution in #2274
- @d-tiapkin made their first contribution in #2277
- @anch0vy made their first contribution in #2285
- @ccs96307 made their first contribution in #2283
Full Changelog: v0.11.0...v0.12.0