Merge pull request #10 from for-ai/translate_preference_pairs
Translate preference pairs
Showing 6 changed files with 369 additions and 2 deletions.
@@ -0,0 +1,71 @@
#!/bin/bash

# Define the languages and their FLORES-200 codes
declare -A languages=(
    ["arb_Arab"]="Arabic"
    ["zho_Hans"]="Chinese_Simplified"
    ["zho_Hant"]="Chinese_Traditional"
    ["ces_Latn"]="Czech"
    ["nld_Latn"]="Dutch"
    ["fra_Latn"]="French"
    ["deu_Latn"]="German"
    ["ell_Grek"]="Greek"
    ["heb_Hebr"]="Hebrew"
    ["hin_Deva"]="Hindi"
    ["ind_Latn"]="Indonesian"
    ["ita_Latn"]="Italian"
    ["jpn_Jpan"]="Japanese"
    ["kor_Hang"]="Korean"
    ["pes_Arab"]="Persian"
    ["pol_Latn"]="Polish"
    ["por_Latn"]="Portuguese"
    ["ron_Latn"]="Romanian"
    ["rus_Cyrl"]="Russian"
    ["spa_Latn"]="Spanish"
    ["tur_Latn"]="Turkish"
    ["ukr_Cyrl"]="Ukrainian"
    ["vie_Latn"]="Vietnamese"
)

# Define the dataset name and columns to translate
DATASET_NAME="allenai/reward-bench"
MODEL_NAME="facebook/nllb-200-distilled-600M"
COLUMNS_TO_TRANSLATE=("prompt" "chosen" "rejected")
MAX_LENGTH=512
SUBSET_SIZE=1000000
OUTPUT_DIR="translations"

# Get the list of GPU IDs
GPUS=($(nvidia-smi --query-gpu=index --format=csv,noheader | tr '\n' ' '))
NUM_GPUS=${#GPUS[@]}

# Run translation for one language on a specific GPU (launched in the background)
run_translation() {
    local lang_code=$1
    local gpu_id=$2
    local language=${languages[$lang_code]}
    echo "Translating to $language ($lang_code) on GPU $gpu_id"

    CUDA_VISIBLE_DEVICES=$gpu_id python -m scripts.translate_preference_pairs_nllb \
        --dataset_name "$DATASET_NAME" \
        --model_name "$MODEL_NAME" \
        --target_language "$lang_code" \
        --columns_to_translate "${COLUMNS_TO_TRANSLATE[@]}" \
        --max_length "$MAX_LENGTH" \
        --subset_size "$SUBSET_SIZE" \
        --output_dir "$OUTPUT_DIR" &
}

# Loop through the languages in groups of NUM_GPUS, assigning one language per GPU
lang_codes=("${!languages[@]}")
total_langs=${#lang_codes[@]}

for ((i=0; i<total_langs; i+=NUM_GPUS)); do
    for ((j=0; j<NUM_GPUS && i+j<total_langs; j++)); do
        lang_code=${lang_codes[i+j]}
        run_translation "$lang_code" "${GPUS[j]}"
    done
    wait  # Wait for all background processes to finish before starting the next group
done

echo "All translations completed."
@@ -0,0 +1,131 @@
"""Convert multilingual ultrafeedback into a format acceptable for RewardBench

We need to follow the load_preference_dataset setup in RewardBench as
shown here: https://github.com/allenai/reward-bench/blob/main/rewardbench/utils.py#L136
So we need three columns:
- prompt (str)
- chosen (list[dict[str, str]]), and
- rejected (list[dict[str, str]])

** Translation: 2000/2000 [2:36:00<00:00, 4.68s/example]
"""

import argparse
import logging
import unicodedata
from pathlib import Path

import ctranslate2
import transformers
from datasets import load_dataset

logging.basicConfig(level=logging.INFO)


def get_args():
    parser = argparse.ArgumentParser(description="Translate a HuggingFace dataset into the RewardBench format.")
    parser.add_argument(
        "--dataset", type=str, default="nthakur/multilingual-ultrafeedback-dpo-v0.1", help="Dataset to convert."
    )
    # parser.add_argument("--output_path", type=Path, default="data/multilingual-ultrafeedback-dpo-v0.1-test-ben_Beng.json", help="Path to save converted dataset as JSON file.")
    parser.add_argument("--target", type=str, help="Target language FLORES-200 code (e.g., fra_Latn).")
    return parser.parse_args()


def main():
    args = get_args()

    # model_id = "facebook/nllb-moe-54b"
    model_id = "facebook/nllb-200-3.3B"
    src_lang = "eng_Latn"

    # tgt_lang = "fra_Latn"
    # tgt_lang = "spa_Latn"
    # tgt_lang = "ben_Beng"
    tgt_lang = args.target

    output_path = Path(f"data/multilingual-ultrafeedback-dpo-v0.1-test-{tgt_lang}.json")

    # The CTranslate2 model directory is produced beforehand with:
    # ct2-transformers-converter \
    #   --model facebook/nllb-200-3.3B --output_dir facebook-nllb-200-3.3B
    model_dir = f"./nllb/{model_id.replace('/', '-')}/"

    translator = ctranslate2.Translator(model_dir, device="cuda")
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, src_lang=src_lang)

    target_prefix = [tgt_lang]

    def translate(source, unicode_norm="NFKC"):
        # Translate line by line so newlines in the source text are preserved
        batched_input = source.split("\n")
        tokenized_input = tokenizer(batched_input, return_attention_mask=False).input_ids
        source = [tokenizer.convert_ids_to_tokens(x) for x in tokenized_input]
        results = translator.translate_batch(source, target_prefix=[target_prefix] * len(batched_input))
        # Drop the target-language prefix token from each hypothesis
        target = [result.hypotheses[0][1:] for result in results]
        target = [tokenizer.convert_tokens_to_ids(x) for x in target]
        translated = tokenizer.batch_decode(target)

        translated = [x.replace("\n", "") for x in translated]
        translated = "\n".join(translated)
        translated = unicodedata.normalize(unicode_norm, translated)
        return translated

    if output_path:
        output_path.parents[0].mkdir(parents=True, exist_ok=True)

    dataset = load_dataset(args.dataset, split="test")
    # dataset = dataset.train_test_split(test_size=2.0/2000)['test']

    def _convert_to_turn_based(example):
        input = translate(example["en_input"])
        print(f"{src_lang}: {example['en_input']}\n{tgt_lang}: {input}")
        chosen = translate(example["en_chosen"])
        rejected = translate(example["en_rejected"])

        example["language"] = tgt_lang
        example["prompt"] = input

        example["chosen"] = [
            {"content": example["prompt"], "role": "user"},
            {"content": chosen, "role": "assistant"},
        ]
        example["rejected"] = [
            {"content": example["prompt"], "role": "user"},
            {"content": rejected, "role": "assistant"},
        ]
        return example

    rename_map = {"input": "prompt", "chosen": "chosen_raw", "rejected": "rejected_raw"}
    cols = [
        "id",
        "source",
        "language",
        "input",
        "chosen",
        "rejected",
        "en_input",
        "en_chosen",
        "en_rejected",
    ]
    remove_cols = ["chosen_raw", "rejected_raw", "en_input", "en_chosen", "en_rejected"]

    dataset = (
        dataset.select_columns(cols).rename_columns(rename_map).map(_convert_to_turn_based).remove_columns(remove_cols)
    )
    dataset.to_json(output_path)
    logging.info(f"Saved file to {output_path}.")


if __name__ == "__main__":
    main()
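For orientation, one record in the JSON file written by this script would look roughly like the following. The field names come from the code above; the id, source, and text values are illustrative placeholders, not real data.

```python
# Illustrative shape of one output record; values are placeholders.
record = {
    "id": "example-id",          # carried over from the source dataset
    "source": "ultrafeedback",   # carried over from the source dataset
    "language": "fra_Latn",      # the tgt_lang passed via --target
    "prompt": "<translated prompt>",
    "chosen": [
        {"content": "<translated prompt>", "role": "user"},
        {"content": "<translated chosen response>", "role": "assistant"},
    ],
    "rejected": [
        {"content": "<translated prompt>", "role": "user"},
        {"content": "<translated rejected response>", "role": "assistant"},
    ],
}
```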
@@ -0,0 +1,153 @@
import argparse
import json
import os
import random

import torch
from datasets import load_dataset
from sentence_splitter import SentenceSplitter
from tqdm import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


def translate_text(text, model, tokenizer, target_lang_code, device, max_length):
    inputs = tokenizer(text, return_tensors="pt").to(device)
    translated_tokens = model.generate(
        **inputs, forced_bos_token_id=tokenizer.lang_code_to_id[target_lang_code], max_length=max_length
    )
    return tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]


def validate_columns(dataset, columns):
    for subset in dataset.keys():
        for column in columns:
            if column not in dataset[subset].column_names:
                raise ValueError(f"Column '{column}' not found in subset '{subset}' of the dataset")


def validate_language_code(tokenizer, target_language):
    if target_language not in tokenizer.lang_code_to_id:
        raise ValueError(f"Target language code '{target_language}' is not valid for the given tokenizer")


def translate_dataset(
    dataset, model_name, columns_to_translate, target_language, max_length, subset_size=None, output_dir="translations"
):
    # Check if GPU is available and set the device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)

    # Validate language code
    validate_language_code(tokenizer, target_language)

    # Validate columns
    validate_columns(dataset, columns_to_translate)

    # Initialize the sentence splitter
    splitter = SentenceSplitter(language="en")

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    for subset in dataset.keys():
        translated_data = []
        data_length = len(dataset[subset])

        # Randomly select a subset of the data if subset_size is specified
        if subset_size:
            indices = random.sample(range(data_length), min(subset_size, data_length))
            dataset[subset] = dataset[subset].select(indices)

        for example in tqdm(dataset[subset], desc=f"Translating {subset} subset"):
            translated_example = {}
            for col in columns_to_translate:
                # Split text into sentences
                sentences = splitter.split(text=example[col])
                # Translate each sentence individually
                translated_sentences = [
                    translate_text(sentence, model, tokenizer, target_language, device, max_length)
                    for sentence in sentences
                ]
                # Join translated sentences back together
                translated_text = " ".join(translated_sentences)
                translated_example[col] = translated_text

            translated_example["target_language"] = target_language
            # Add other columns as-is
            for key in example.keys():
                if key not in translated_example:
                    translated_example[key] = example[key]
            translated_data.append(translated_example)

        # Save translated data to a JSON file
        # (dataset_name comes from the module-level args parsed in __main__)
        dataset_name = args.dataset_name.replace("/", "_")
        output_file = os.path.join(output_dir, f"{dataset_name}_{subset}_{target_language}_translated.json")
        with open(output_file, "w", encoding="utf-8") as f:
            json.dump(translated_data, f, ensure_ascii=False, indent=4)

        print(f"Translated data for subset '{subset}' saved to {output_file}")


if __name__ == "__main__":
    # fmt: off
    parser = argparse.ArgumentParser(description="Translate dataset columns using a specified translation model.")
    parser.add_argument("--dataset_name", type=str, required=True, help="Hugging Face dataset name.")
    parser.add_argument("--target_language", type=str, required=True, help="Target language code (e.g., fra_Latn).")
    parser.add_argument("--model_name", type=str, default="facebook/nllb-200-distilled-600M", required=False, help="Hugging Face model name.")
    parser.add_argument("--columns_to_translate", type=str, nargs="+", required=True, help="Columns to translate.")
    parser.add_argument("--max_length", type=int, default=30, help="Maximum length for translation.")
    parser.add_argument("--subset_size", type=int, help="Size of the random subset to translate.")
    parser.add_argument("--output_dir", type=str, default="translations", help="Output directory to save translations.")
    # fmt: on

    args = parser.parse_args()

    # Load dataset
    dataset = load_dataset(args.dataset_name)

    # Translate dataset
    translate_dataset(
        dataset,
        args.model_name,
        args.columns_to_translate,
        args.target_language,
        args.max_length,
        args.subset_size,
        args.output_dir,
    )

    # Reference: Language and FLORES-200 codes
    # ```markdown
    # | Language                     | FLORES-200 code |
    # |------------------------------|-----------------|
    # | Arabic                       | arb_Arab        |
    # | Chinese (Simplified)         | zho_Hans        |
    # | Chinese (Traditional)        | zho_Hant        |
    # | Czech                        | ces_Latn        |
    # | Dutch                        | nld_Latn        |
    # | English                      | eng_Latn        |
    # | French                       | fra_Latn        |
    # | German                       | deu_Latn        |
    # | Greek                        | ell_Grek        |
    # | Hebrew                       | heb_Hebr        |
    # | Hindi                        | hin_Deva        |
    # | Indonesian                   | ind_Latn        |
    # | Italian                      | ita_Latn        |
    # | Japanese                     | jpn_Jpan        |
    # | Korean                       | kor_Hang        |
    # | Persian                      | pes_Arab        |
    # | Polish                       | pol_Latn        |
    # | Portuguese                   | por_Latn        |
    # | Romanian                     | ron_Latn        |
    # | Russian                      | rus_Cyrl        |
    # | Spanish                      | spa_Latn        |
    # | Turkish                      | tur_Latn        |
    # | Ukrainian                    | ukr_Cyrl        |
    # | Vietnamese                   | vie_Latn        |
    # ```

    # Example command to run the script:
    # python translate_dataset.py --dataset_name my_dataset --target_language fra_Latn --columns_to_translate prompt chosen rejected --max_length 30 --subset_size 100 --output_dir translations
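A minimal sketch for reading one of the JSON files this script produces back into a datasets.Dataset, assuming the example invocation above; the file name is hypothetical but follows the pattern built in translate_dataset.

```python
import json

from datasets import Dataset

# File name pattern: {dataset_name}_{subset}_{target_language}_translated.json
with open("translations/my_dataset_train_fra_Latn_translated.json", encoding="utf-8") as f:
    records = json.load(f)  # the script writes a plain JSON list of dicts

translated = Dataset.from_list(records)
print(translated.column_names)  # translated columns plus "target_language" and any passthrough columns
```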