Commit
In sft, make the number of entries in data for training configurable. Use trl DataCollatorForCompletionOnlyLM instead of the customized one. Debug: cannot use ConstantLengthDataset or packing when using DataCollatorForCompletionOnlyLM.
Diyi Hu committed Jan 29, 2024
1 parent 60bcb60 commit fe1cc98
Showing 2 changed files with 13 additions and 32 deletions.
1 change: 1 addition & 0 deletions example/rlhf/supervised_finetuning_demo_d2l.py
@@ -50,6 +50,7 @@
data_collator="DataCollatorForCompletionOnlyLM",
no_evaluation=True,
prepare_text="d2l",
split = "train[:10%]"
)
rlhf_step1_sft = SupervisedFinetuning(config)
rlhf_step1_sft.train_and_save("./models/rlhf_step1_sft")
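
Side note: the new split value is forwarded to Hugging Face datasets.load_dataset (see the change in create_datasets below), so standard split-slicing syntax controls how many entries are used for training. A minimal sketch of that behavior, with a placeholder CSV path:

from datasets import load_dataset

# "train[:10%]" keeps the first 10% of the rows; "train[:500]" would keep
# the first 500 entries, and plain "train" loads everything.
dataset = load_dataset("csv", data_files="qa_pairs.csv", split="train[:10%]")
print(len(dataset))
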
44 changes: 12 additions & 32 deletions pykoi/rlhf/supervised_finetuning.py
@@ -11,16 +11,15 @@
AutoModelForSequenceClassification, AutoTokenizer,
TrainingArguments)
from trl import SFTTrainer
from trl.trainer.utils import ConstantLengthDataset

from pykoi.chat.db.constants import (QA_CSV_HEADER_ANSWER, QA_CSV_HEADER_ID,
QA_CSV_HEADER_QUESTION,
QA_CSV_HEADER_VOTE_STATUS)
from pykoi.chat.db.qa_database import QuestionAnswerDatabase
from pykoi.rlhf.config import RLHFConfig
from pykoi.telemetry.events import SFTStartEvent, SFTStopEvent
from pykoi.telemetry.telemetry import Telemetry
from pykoi.rlhf.customize_data_collator import DataCollatorForCompletionOnlyLM
from trl import DataCollatorForCompletionOnlyLM
# from pykoi.rlhf.customize_data_collator import DataCollatorForCompletionOnlyLM


class SupervisedFinetuning:
@@ -101,8 +100,9 @@ def __init__(
if self._rlhf_config.data_collator == "DataCollatorForCompletionOnlyLM":
# dh: use trl's completion-only data collator so that the loss is
# computed only on the answer part
data_collator = DataCollatorForCompletionOnlyLM(
tokenizer=self.tokenizer, mlm=False, return_tensors="pt", pad_to_multiple_of=8)
response_template = RESPONSE_KEY
data_collator = DataCollatorForCompletionOnlyLM(response_template,
tokenizer=self.tokenizer)

self.trainer = SFTTrainer(
model=self.model,
@@ -111,9 +111,10 @@
eval_dataset=self.dataset["eval"],
peft_config=self._rlhf_config.lora_config_rl,
# TODO: DH: LoraConfig MAY BE IGNORED IF USING FROM_PRETRAINED
packing=True,
packing=False,  # required for compatibility with the completion-only data collator
data_collator=data_collator,
dataset_text_field="text",
max_seq_length=self._rlhf_config.max_seq_length,
)

def train(self):
@@ -270,14 +271,14 @@ def create_datasets(self, tokenizer, args):
elif args.dataset_type == "local_csv":
# this way will load 1660 entries
# dataset = load_dataset("csv", data_files=args.dataset_name)
# dataset = dataset[args.split] # Convert DatasetDict to Dataset
# dataset = dataset["train"] # Convert DatasetDict to Dataset

# this way will load 166 entries

dataset = load_dataset(
"csv",
data_files=args.dataset_name,
split='train[:10%]')
split=args.split)

elif args.dataset_type == "huggingface":
dataset = load_dataset(
@@ -305,14 +306,7 @@ def create_datasets(self, tokenizer, args):
f"Size of the train set: {len(dataset)}. "
)

train_dataset = ConstantLengthDataset(
tokenizer,
dataset,
formatting_func=self.prepare_text,
infinite=True,
seq_length=args.max_seq_length,
# chars_per_token=chars_per_token,
)
train_dataset = dataset
eval_dataset = None
else:
dataset = dataset.train_test_split(
@@ -322,22 +316,8 @@
f"Size of the train set: {len(dataset['train'])}. "
f" Size of the validation set: {len(dataset['test'])}")

train_dataset = ConstantLengthDataset(
tokenizer,
dataset["train"],
formatting_func=self.prepare_text,
infinite=True,
seq_length=args.max_seq_length,
# chars_per_token=chars_per_token,
)
train_dataset = dataset["train"]

eval_dataset = ConstantLengthDataset(
tokenizer,
dataset["test"],
formatting_func=self.prepare_text,
infinite=False,
seq_length=args.max_seq_length,
# chars_per_token=chars_per_token,
)
eval_dataset = dataset["test"]

return {"train": train_dataset, "eval": eval_dataset}
