release v0.8.0
hiyouga committed Jun 7, 2024
1 parent 12d79f8 commit 5aa4ce4
Showing 6 changed files with 142 additions and 13 deletions.
11 changes: 1 addition & 10 deletions src/llamafactory/data/template.py
@@ -700,17 +700,8 @@ def get_template_and_fix_tokenizer(
 _register_template(
     name="llama2",
     format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
+    format_assistant=StringFormatter(slots=[" {{content}} ", {"eos_token"}]),
     format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
-    default_system=(
-        "You are a helpful, respectful and honest assistant. "
-        "Always answer as helpfully as possible, while being safe. "
-        "Your answers should not include any harmful, unethical, "
-        "racist, sexist, toxic, dangerous, or illegal content. "
-        "Please ensure that your responses are socially unbiased and positive in nature.\n\n"
-        "If a question does not make any sense, or is not factually coherent, "
-        "explain why instead of answering something not correct. "
-        "If you don't know the answer to a question, please don't share false information."
-    ),
 )


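The llama2 template now declares an explicit format_assistant and drops the hard-coded default system prompt, so a prompt carries a system message only when the caller supplies one. A minimal sketch of how the updated slots render a single-turn exchange (render_llama2 is a hypothetical helper, not LLaMA-Factory code; it assumes bos_token="<s>" and eos_token="</s>"):

def render_llama2(user: str, assistant: str, system: str = "") -> str:
    # {"bos_token"} slot at the start of format_user
    prompt = "<s>"
    if system:
        # format_system wraps the system text and prefixes the first user turn
        user = f"<<SYS>>\n{system}\n<</SYS>>\n\n{user}"
    prompt += f"[INST] {user} [/INST]"
    # format_assistant: " {{content}} " followed by the eos token
    prompt += f" {assistant} </s>"
    return prompt


print(render_llama2("Hi", "Hello!"))  # <s>[INST] Hi [/INST] Hello! </s>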
2 changes: 1 addition & 1 deletion src/llamafactory/extras/env.py
@@ -12,7 +12,7 @@
 from .packages import is_vllm_available


-VERSION = "0.7.2.dev0"
+VERSION = "0.8.0"


 def print_env() -> None:
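A quick way to confirm the bump after installing from this revision, assuming the package is importable as llamafactory (as the tests below do):

from llamafactory.extras.env import VERSION

assert VERSION == "0.8.0"
print(VERSION)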
44 changes: 44 additions & 0 deletions tests/data/test_supervised.py
@@ -0,0 +1,44 @@
import os

import pytest
from datasets import load_dataset

from llamafactory.data import get_dataset
from llamafactory.hparams import get_train_args
from llamafactory.model import load_tokenizer


TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-LlamaForCausalLM")

TRAINING_ARGS = {
    "model_name_or_path": TINY_LLAMA,
    "stage": "sft",
    "do_train": True,
    "finetuning_type": "full",
    "dataset": "llamafactory/tiny_dataset",
    "dataset_dir": "ONLINE",
    "template": "llama3",
    "cutoff_len": 1024,
    "overwrite_cache": True,
    "output_dir": "dummy_dir",
    "overwrite_output_dir": True,
    "fp16": True,
}


@pytest.mark.parametrize("test_num", [5])
def test_supervised(test_num: int):
    model_args, data_args, training_args, _, _ = get_train_args(TRAINING_ARGS)
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    tokenized_data = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)

    original_data = load_dataset(TRAINING_ARGS["dataset"], split="train")
    for test_idx in range(test_num):
        decode_result = tokenizer.decode(tokenized_data["input_ids"][test_idx])
        messages = [
            {"role": "user", "content": original_data[test_idx]["instruction"]},
            {"role": "assistant", "content": original_data[test_idx]["output"]},
        ]
        templated_result = tokenizer.apply_chat_template(messages, tokenize=False)
        assert decode_result == templated_result
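The test asserts that SFT preprocessing round-trips: decoding the tokenized example reproduces exactly what the tokenizer's own chat template would emit. One way to run just this file locally, sketched under the assumption that pytest is installed and the tiny hub model and dataset above are reachable:

import os
import sys

import pytest

# Optionally point at a local copy of the tiny model instead of the hub default.
os.environ.setdefault("TINY_LLAMA", "llamafactory/tiny-random-LlamaForCausalLM")
sys.exit(pytest.main(["-q", "tests/data/test_supervised.py"]))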
4 changes: 2 additions & 2 deletions tests/model/test_attention.py
@@ -30,8 +30,8 @@ def test_attention():
"flash_attn": requested_attention,
}
)
tokenizer = load_tokenizer(model_args)
model = load_model(tokenizer["tokenizer"], model_args, finetuning_args)
tokenizer_module = load_tokenizer(model_args)
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args)
for module in model.modules():
if "Attention" in module.__class__.__name__:
assert module.__class__.__name__ == llama_attention_classes[requested_attention]
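The only change here is a rename, but it documents the calling convention the new tests also rely on: load_tokenizer returns a mapping that is indexed for the tokenizer (and unpacked wholesale into get_dataset above). A sketch of the access pattern, with the dict shape assumed from this diff alone:

from llamafactory.model import load_tokenizer

def get_tokenizer(model_args):
    tokenizer_module = load_tokenizer(model_args)  # dict-like, not a bare tokenizer
    return tokenizer_module["tokenizer"]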
61 changes: 61 additions & 0 deletions tests/model/test_freeze.py
@@ -0,0 +1,61 @@
import os

import torch

from llamafactory.hparams import get_train_args
from llamafactory.model import load_model, load_tokenizer


TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-LlamaForCausalLM")

TRAINING_ARGS = {
    "model_name_or_path": TINY_LLAMA,
    "stage": "sft",
    "do_train": True,
    "finetuning_type": "freeze",
    "dataset": "llamafactory/tiny_dataset",
    "dataset_dir": "ONLINE",
    "template": "llama3",
    "cutoff_len": 1024,
    "overwrite_cache": True,
    "output_dir": "dummy_dir",
    "overwrite_output_dir": True,
    "fp16": True,
}


def test_freeze_all_modules():
    model_args, _, _, finetuning_args, _ = get_train_args(
        {
            "freeze_trainable_layers": 1,
            **TRAINING_ARGS,
        }
    )
    tokenizer_module = load_tokenizer(model_args)
    model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
    for name, param in model.named_parameters():
        if name.startswith("model.layers.1."):
            assert param.requires_grad is True
            assert param.dtype == torch.float32
        else:
            assert param.requires_grad is False
            assert param.dtype == torch.float16


def test_freeze_extra_modules():
    model_args, _, _, finetuning_args, _ = get_train_args(
        {
            "freeze_trainable_layers": 1,
            "freeze_extra_modules": "embed_tokens,lm_head",
            **TRAINING_ARGS,
        }
    )
    tokenizer_module = load_tokenizer(model_args)
    model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
    for name, param in model.named_parameters():
        if name.startswith("model.layers.1.") or any(module in name for module in ["embed_tokens", "lm_head"]):
            assert param.requires_grad is True
            assert param.dtype == torch.float32
        else:
            assert param.requires_grad is False
            assert param.dtype == torch.float16
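Both tests encode the freeze-tuning convention: with freeze_trainable_layers=1, only the last decoder layer trains (model.layers.1, assuming the tiny model has two layers), in float32, while everything else stays frozen in float16; freeze_extra_modules widens the trainable set by module name. A small helper for eyeballing the same split on any model (a sketch, not part of this commit):

from torch import nn

def summarize_trainable(model: nn.Module) -> None:
    # Print each parameter's trainability and dtype, mirroring the assertions above.
    for name, param in model.named_parameters():
        tag = "trainable" if param.requires_grad else "frozen"
        print(f"{tag:9s} {param.dtype} {name}")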
33 changes: 33 additions & 0 deletions tests/model/test_full.py
@@ -0,0 +1,33 @@
import os

import torch

from llamafactory.hparams import get_train_args
from llamafactory.model import load_model, load_tokenizer


TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-LlamaForCausalLM")

TRAINING_ARGS = {
    "model_name_or_path": TINY_LLAMA,
    "stage": "sft",
    "do_train": True,
    "finetuning_type": "full",
    "dataset": "llamafactory/tiny_dataset",
    "dataset_dir": "ONLINE",
    "template": "llama3",
    "cutoff_len": 1024,
    "overwrite_cache": True,
    "output_dir": "dummy_dir",
    "overwrite_output_dir": True,
    "fp16": True,
}


def test_full():
    model_args, _, _, finetuning_args, _ = get_train_args(TRAINING_ARGS)
    tokenizer_module = load_tokenizer(model_args)
    model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
    for param in model.parameters():
        assert param.requires_grad is True
        assert param.dtype == torch.float32
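Note the pairing in all three model tests: fp16=True in the config, yet trainable parameters are asserted to be float32. This matches the usual mixed-precision setup, where trainable weights are kept (or upcast) in full precision and only frozen ones stay in half precision. The checks can be factored into a standalone helper, sketched here under that reading:

import torch
from torch import nn

def assert_fully_trainable_fp32(model: nn.Module) -> None:
    # Mirrors test_full: every parameter trainable and stored in float32.
    for name, param in model.named_parameters():
        assert param.requires_grad, f"{name} is frozen"
        assert param.dtype == torch.float32, f"{name} is {param.dtype}"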
