Skip to content

Commit

Permalink
Add baize dataset (#25)
Browse files Browse the repository at this point in the history
* add baize dataset

* add baize dataset
  • Loading branch information
GT9505 authored Jun 3, 2023
1 parent a8a06b7 commit 9c73e47
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
wandb/

checkpoints/
tests/

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,10 @@ conda env create -f environment.yml
You can also customize the data path in the [configs/dataset_config.py](configs/dataset_config.py).
8. [Baize](https://github.com/project-baize/baize-chatbot)
Download it from [this link](https://github.com/project-baize/baize-chatbot/blob/main/data/quora_chat_data.json) and place it in `data/baize/quora_chat_data.json`.
## Start training
Expand Down
4 changes: 4 additions & 0 deletions configs/dataset_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,8 @@
type="alpaca_gpt4",
ann_path="data/alpaca_gpt4/alpaca_gpt4_data.json",
),
dict(
type="baize",
ann_path="data/baize/quora_chat_data.json",
),
]
86 changes: 86 additions & 0 deletions mmgpt/datasets/baize_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import json

from mmgpt.datasets.dolly_dataset import DollyDataset


TEMPLATE = {
    "description": "Template used by Alpaca-LoRA.",
    "prompt_choice": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{question}\n\n### Input:\n{options}\n\n### Response:\n",
    "prompt_qa": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{question}\n\n### Response:\n",
    "prompt_dial": "\n\n### Instruction:\n{question}\n\n### Response:\n",
    "response_split": "### Response:",
}


class LangDialPrompter:
    """Format dialogue turns into Alpaca-style instruction prompts.

    With ``options`` present the multiple-choice template
    (``prompt_choice``) is used; otherwise the short dialogue template
    (``prompt_dial``) is used.
    """

    def __call__(self, question: str, options=None) -> str:
        """Return the formatted prompt for *question*.

        Args:
            question: The human utterance to embed in the prompt.
            options: Optional iterable of choice strings, joined with ", ".
        """
        if options:
            # Fix: removed the `image="<image>"` format kwarg the original
            # passed here — `prompt_choice` has no {image} placeholder, so
            # str.format silently ignored it; it was misleading dead code.
            joined = ", ".join(options)
            return TEMPLATE["prompt_choice"].format(question=question, options=joined)
        return TEMPLATE["prompt_dial"].format(question=question)

    def get_response(self, output: str) -> str:
        """Return the text after the last "### Response:" marker, stripped."""
        return output.split(TEMPLATE["response_split"])[-1].strip()

class BaiZeDataset(DollyDataset):
    """Multi-turn dialogue dataset in the Baize annotation format.

    Each annotation entry stores a whole conversation in its ``input``
    field, with turns delimited by "[|Human|] " / "[|AI|] " markers.
    Example annotation file::

        [
            {
                "instruction": "Identify the odd one out.",
                "input": "Twitter, Instagram, Telegram",
                "output": "The odd one out is Telegram. ..."
            },
        ]
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Replace the parent's prompter with the dialogue-aware one so
        # each turn is wrapped in the "### Instruction:/### Response:" frame.
        self.prompter = LangDialPrompter()

    def load_annotation(self, ann_path):
        """Load the JSON list of conversations from *ann_path*.

        Fix: the file is now opened via a context manager so the handle is
        closed deterministically — the original `json.load(open(...))`
        leaked it.
        """
        with open(ann_path, "r", encoding="utf-8") as f:
            self.annotation = json.load(f)

    def process_text(self, anns):
        """Split one conversation string into per-turn instruction/answer dicts.

        Returns a list of ``dict(instruction=..., answer=...)``, one per
        human/AI exchange. The leading instruction preamble is prepended
        only to the first turn.
        """
        # TODO remove this
        begin_string = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
        # convs[0] is the text before the first "[|Human|] " marker and the
        # last element is the trailing fragment; both are dropped — assumes
        # Baize transcripts start with a marker and end with an unanswered
        # human turn (TODO confirm against the data file).
        convs = anns['input'].split("[|Human|] ")
        conv_list = []
        for conv_id, one_conv in enumerate(convs[1:-1]):
            # NOTE(review): two-way unpack assumes exactly one "[|AI|] "
            # marker per turn; a malformed turn raises ValueError here.
            question, answer = one_conv.split("[|AI|] ")
            question = question.replace("\n", "")
            answer = answer.replace("\n", "")
            instruction = self.prompter(question)
            if conv_id == 0:
                single_conv = dict(instruction=begin_string + instruction, answer=answer)
            else:
                single_conv = dict(instruction=instruction, answer=answer)
            conv_list.append(single_conv)
        return conv_list

    def __getitem__(self, index):
        """Tokenize every turn of conversation *index* and concatenate them.

        Returns a dict with flat ``input_ids`` / ``attention_mask`` /
        ``labels`` covering the whole conversation, plus per-turn
        ``instruction`` and ``answer`` string lists.
        """
        ann = self.annotation[index]
        text_list = self.process_text(ann)
        res_list = []
        for text in text_list:
            # self.tokenize comes from DollyDataset; presumably returns
            # dict(input_ids=..., attention_mask=..., labels=...) lists —
            # TODO confirm against the parent class.
            single_res = self.tokenize(text)
            single_res["instruction"] = text["instruction"]
            single_res["answer"] = text["answer"]
            res_list.append(single_res)

        input_ids = []
        attention_mask = []
        labels = []
        instruction = []
        answer = []
        for res in res_list:
            input_ids.extend(res["input_ids"])
            attention_mask.extend(res["attention_mask"])
            labels.extend(res["labels"])
            instruction.append(res["instruction"])
            answer.append(res["answer"])

        res = dict(
            input_ids=input_ids, attention_mask=attention_mask, labels=labels, instruction=instruction, answer=answer
        )
        return res
6 changes: 6 additions & 0 deletions mmgpt/datasets/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .snli_ve_datasets import SNLIVEDataset # noqa: F401
from .text_ocr_dataset import TextOCRDataset # noqa: F401
from .vqa_dataset import ConcatDataset, VQADataset # noqa: F401
from .baize_dataset import BaiZeDataset # noqa: F401


def build_dataset(dataset_config, **kwargs):
Expand Down Expand Up @@ -108,6 +109,11 @@ def build_dataset(dataset_config, **kwargs):
**dataset_config,
**kwargs,
)
elif dataset_type == "baize":
dataset = BaiZeDataset(
**dataset_config,
**kwargs,
)
else:
raise NotImplementedError

Expand Down

0 comments on commit 9c73e47

Please sign in to comment.