From 0d51f4f873872211b217a684722ded8577ab53d0 Mon Sep 17 00:00:00 2001 From: JINO ROHIT Date: Fri, 17 Jan 2025 19:17:54 +0530 Subject: [PATCH] Update TSDAE examples with SentenceTransformerTrainer (#3137) * raises value error when num_label > 1 when using Crossencoder.rank() * updated error message * update tsdae example with SentenceTransformerTrainer * lint * Update the other TSDAE examples as well --------- Co-authored-by: Tom Aarsen --- .../TSDAE/eval_askubuntu.py | 45 ++-- .../TSDAE/train_askubuntu_tsdae.py | 209 ++++++++++++----- .../TSDAE/train_stsb_tsdae.py | 217 +++++++++++------- .../TSDAE/train_tsdae_from_file.py | 190 ++++++++++----- 4 files changed, 436 insertions(+), 225 deletions(-) diff --git a/examples/unsupervised_learning/TSDAE/eval_askubuntu.py b/examples/unsupervised_learning/TSDAE/eval_askubuntu.py index 4b5a3cdfc..fcf2b8d1b 100644 --- a/examples/unsupervised_learning/TSDAE/eval_askubuntu.py +++ b/examples/unsupervised_learning/TSDAE/eval_askubuntu.py @@ -10,19 +10,19 @@ import os import sys -from sentence_transformers import LoggingHandler, SentenceTransformer, evaluation, util +from datasets import Dataset -#### Just some code to print debug information to stdout -logging.basicConfig( - format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()] -) -#### /print debug information to stdout +from sentence_transformers import SentenceTransformer, util +from sentence_transformers.evaluation import RerankingEvaluator + +# Set the log level to INFO to get more information +logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO) model = SentenceTransformer(sys.argv[1]) ################# Download AskUbuntu and extract training corpus ################# -askubuntu_folder = "askubuntu" +askubuntu_folder = "data/askubuntu" training_corpus = os.path.join(askubuntu_folder, "train.unsupervised.txt") @@ -37,15 +37,17 @@ dev_test_ids = set() with gzip.open(os.path.join(askubuntu_folder, "text_tokenized.txt.gz"), "rt", encoding="utf8") as fIn: for line in fIn: - splits = line.strip().split("\t") - id = splits[0] - title = splits[1] + id, title, *_ = line.strip().split("\t") corpus[id] = title # Read dev & test dataset -def read_eval_dataset(filepath): - dataset = [] +def read_eval_dataset(filepath) -> Dataset: + data = { + "query": [], + "positive": [], + "negative": [], + } with open(filepath) as fIn: for line in fIn: query_id, relevant_id, candidate_ids, bm25_scores = line.strip().split("\t") @@ -55,15 +57,12 @@ def read_eval_dataset(filepath): relevant_id = relevant_id.split(" ") candidate_ids = candidate_ids.split(" ") negative_ids = set(candidate_ids) - set(relevant_id) - dataset.append( - { - "query": corpus[query_id], - "positive": [corpus[pid] for pid in relevant_id], - "negative": [corpus[pid] for pid in negative_ids], - } - ) + data["query"].append(corpus[query_id]) + data["positive"].append([corpus[pid] for pid in relevant_id]) + data["negative"].append([corpus[pid] for pid in negative_ids]) dev_test_ids.add(query_id) dev_test_ids.update(candidate_ids) + dataset = Dataset.from_dict(data) return dataset @@ -72,11 +71,11 @@ def read_eval_dataset(filepath): # Create a dev evaluator -dev_evaluator = evaluation.RerankingEvaluator(dev_dataset, name="AskUbuntu dev") +dev_evaluator = RerankingEvaluator(dev_dataset, name="AskUbuntu dev") -logging.info("Dev performance before training") +logging.info("Dev performance") dev_evaluator(model) -test_evaluator = 
evaluation.RerankingEvaluator(test_dataset, name="AskUbuntu test")
-logging.info("Test performance before training")
+test_evaluator = RerankingEvaluator(test_dataset, name="AskUbuntu test")
+logging.info("Test performance")
 test_evaluator(model)
diff --git a/examples/unsupervised_learning/TSDAE/train_askubuntu_tsdae.py b/examples/unsupervised_learning/TSDAE/train_askubuntu_tsdae.py
index dfa72cc1e..e9425a5cf 100644
--- a/examples/unsupervised_learning/TSDAE/train_askubuntu_tsdae.py
+++ b/examples/unsupervised_learning/TSDAE/train_askubuntu_tsdae.py
@@ -1,25 +1,43 @@
import gzip
import logging
import os
-import sys
+import random
+import traceback
from datetime import datetime
-from torch.utils.data import DataLoader
+from datasets import Dataset
-from sentence_transformers import LoggingHandler, SentenceTransformer, datasets, evaluation, losses, models, util
+from sentence_transformers import SentenceTransformer, models, util
+from sentence_transformers.evaluation import RerankingEvaluator
+from sentence_transformers.losses import DenoisingAutoEncoderLoss
+from sentence_transformers.trainer import SentenceTransformerTrainer
+from sentence_transformers.training_args import SentenceTransformerTrainingArguments
-#### Just some code to print debug information to stdout
-logging.basicConfig(
-    format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()]
-)
-#### /print debug information to stdout
+# Set the log level to INFO to get more information
+logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)
+
+# Training parameters
+model_name = "bert-base-uncased"
+train_batch_size = 8
+num_epochs = 1
+max_seq_length = 75
+
+output_dir = f"output/training_askubuntu_tsdae-{model_name.replace('/', '-')}-{train_batch_size}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
-################# Download AskUbuntu and extract training corpus #################
+# 1. Defining our sentence transformer model
+word_embedding_model = models.Transformer(model_name, max_seq_length=max_seq_length)
+pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), "cls")
+model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
+# or to load a pre-trained SentenceTransformer model OR use mean pooling
+# model = SentenceTransformer(model_name)
+# model.max_seq_length = max_seq_length
+
+
+# 2. 
Download the AskUbuntu dataset from https://github.com/taolei87/askubuntu askubuntu_folder = "data/askubuntu" result_folder = "output/askubuntu-tsdae-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") batch_size = 8 -## Download the AskUbuntu dataset from https://github.com/taolei87/askubuntu for filename in ["text_tokenized.txt.gz", "dev.txt", "test.txt", "train_random.txt"]: filepath = os.path.join(askubuntu_folder, filename) if not os.path.exists(filepath): @@ -30,15 +48,17 @@ dev_test_ids = set() with gzip.open(os.path.join(askubuntu_folder, "text_tokenized.txt.gz"), "rt", encoding="utf8") as fIn: for line in fIn: - splits = line.strip().split("\t") - id = splits[0] - title = splits[1] + id, title, *_ = line.strip().split("\t") corpus[id] = title # Read dev & test dataset -def read_eval_dataset(filepath): - dataset = [] +def read_eval_dataset(filepath) -> Dataset: + data = { + "query": [], + "positive": [], + "negative": [], + } with open(filepath) as fIn: for line in fIn: query_id, relevant_id, candidate_ids, bm25_scores = line.strip().split("\t") @@ -48,62 +68,129 @@ def read_eval_dataset(filepath): relevant_id = relevant_id.split(" ") candidate_ids = candidate_ids.split(" ") negative_ids = set(candidate_ids) - set(relevant_id) - dataset.append( - { - "query": corpus[query_id], - "positive": [corpus[pid] for pid in relevant_id], - "negative": [corpus[pid] for pid in negative_ids], - } - ) + data["query"].append(corpus[query_id]) + data["positive"].append([corpus[pid] for pid in relevant_id]) + data["negative"].append([corpus[pid] for pid in negative_ids]) dev_test_ids.add(query_id) dev_test_ids.update(candidate_ids) + dataset = Dataset.from_dict(data) return dataset -dev_dataset = read_eval_dataset(os.path.join(askubuntu_folder, "dev.txt")) +eval_dataset = read_eval_dataset(os.path.join(askubuntu_folder, "dev.txt")) test_dataset = read_eval_dataset(os.path.join(askubuntu_folder, "test.txt")) - ## Now we need a list of train sentences. 
## In this example we simply use all sentences that don't appear in the train/dev set
-train_sentences = []
-for id, sentence in corpus.items():
-    if id not in dev_test_ids:
-        train_sentences.append(sentence)
-
-
-logging.info(f"{len(train_sentences)} train sentences")
-
-################# Initialize an SBERT model #################
-model_name = sys.argv[1] if len(sys.argv) >= 2 else "bert-base-uncased"
-word_embedding_model = models.Transformer(model_name)
-# Apply **cls** pooling to get one fixed sized sentence vector
-pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), "cls")
-model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
-
-################# Train and evaluate the model (it needs about 1 hour for one epoch of AskUbuntu) #################
-# We wrap our training sentences in the DenoisingAutoEncoderDataset to add deletion noise on the fly
-train_dataset = datasets.DenoisingAutoEncoderDataset(train_sentences)
-train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
-train_loss = losses.DenoisingAutoEncoderLoss(model, decoder_name_or_path=model_name, tie_encoder_decoder=True)
-
-# Create a dev evaluator
-dev_evaluator = evaluation.RerankingEvaluator(dev_dataset, name="AskUbuntu dev")
-
-logging.info("Dev performance before training")
+train_sentences = [sentence for id, sentence in corpus.items() if id not in dev_test_ids]
+train_dataset = Dataset.from_dict({"text": train_sentences})
+
+
+def noise_fn(text, del_ratio=0.6):
+    from nltk import word_tokenize
+    from nltk.tokenize.treebank import TreebankWordDetokenizer
+
+    words = word_tokenize(text)
+    n = len(words)
+    if n == 0:
+        return {"noisy": text}
+
+    kept_words = [word for word in words if random.random() > del_ratio]
+    # Guarantee that at least one word remains
+    if len(kept_words) == 0:
+        return {"noisy": random.choice(words)}
+
+    noisy_text = TreebankWordDetokenizer().detokenize(kept_words)
+    return {"noisy": noisy_text}
+
+
+# TSDAE requires a dataset with 2 columns: a text column and a noisified text column
+# Here we are using a function to delete some words, but you can use any other method to noisify your text
+train_dataset = train_dataset.map(noise_fn, input_columns="text")
+print(train_dataset)
+print(train_dataset[0])
+"""
+Dataset({
+    features: ['text', 'noisy'],
+    num_rows: 160436
+})
+{
+    'text': "how to get the `` your battery is broken '' message to go away ?",
+    'noisy': 'how to get "battery is broken go?',
+}
+"""
+print(eval_dataset)
+print(test_dataset)
+"""
+Dataset({
+    features: ['query', 'positive', 'negative'],
+    num_rows: 189
+})
+Dataset({
+    features: ['query', 'positive', 'negative'],
+    num_rows: 186
+})
+"""
+
+# 3. Define our training loss: https://sbert.net/docs/package_reference/sentence_transformer/losses.html#denoisingautoencoderloss
+# Note that this will likely result in warnings as we're loading 'model_name' as a decoder, but it likely won't
+# have weights for that yet. This is fine, as we'll be training it from scratch.
+train_loss = DenoisingAutoEncoderLoss(model, decoder_name_or_path=model_name, tie_encoder_decoder=True)
+
+# 4. Define an evaluator for use during training. This is useful to keep track of alongside the evaluation loss.
+logging.info("Evaluation before training:")
+dev_evaluator = RerankingEvaluator(eval_dataset, name="AskUbuntu-dev")
 dev_evaluator(model)
-total_steps = 20000
-logging.info("Start training")
-model.fit(
-    train_objectives=[(train_dataloader, train_loss)],
+# 5. Define the training arguments
+args = SentenceTransformerTrainingArguments(
+    # Required parameter:
+    output_dir=output_dir,
+    # Optional training parameters:
+    learning_rate=3e-5,
+    num_train_epochs=1,
+    per_device_train_batch_size=train_batch_size,
+    per_device_eval_batch_size=train_batch_size,
+    warmup_ratio=0.1,
+    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
+    bf16=False,  # Set to True if you have a GPU that supports BF16
+    # Optional tracking/debugging parameters:
+    eval_strategy="steps",
+    eval_steps=5000,
+    save_strategy="steps",
+    save_steps=5000,
+    save_total_limit=2,
+    logging_steps=1000,
+    run_name="tsdae-askubuntu",  # Will be used in W&B if `wandb` is installed
+)
+
+# 6. Create the trainer & start training
+trainer = SentenceTransformerTrainer(
+    model=model,
+    args=args,
+    train_dataset=train_dataset,
+    loss=train_loss,
     evaluator=dev_evaluator,
-    evaluation_steps=1000,
-    epochs=1,
-    steps_per_epoch=total_steps,
-    weight_decay=0,
-    scheduler="constantlr",
-    optimizer_params={"lr": 3e-5},
-    output_path=result_folder,
-    show_progress_bar=True,
 )
+trainer.train()
+
+# 7. Evaluate the model performance on the test set after training
+logging.info("Evaluation after training:")
+test_evaluator = RerankingEvaluator(test_dataset, name="AskUbuntu-test")
+test_evaluator(model)
+
+# 8. Save the trained & evaluated model locally
+final_output_dir = f"{output_dir}/final"
+model.save(final_output_dir)
+
+# 9. (Optional) save the model to the Hugging Face Hub!
+# It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
+model_name = model_name if "/" not in model_name else model_name.split("/")[-1]
+try:
+    model.push_to_hub(f"{model_name}-tsdae-askubuntu")
+except Exception:
+    logging.error(
+        f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
+        f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({final_output_dir!r})` "
+        f"and saving it using `model.push_to_hub('{model_name}-tsdae-askubuntu')`."
+ ) diff --git a/examples/unsupervised_learning/TSDAE/train_stsb_tsdae.py b/examples/unsupervised_learning/TSDAE/train_stsb_tsdae.py index d7afc3976..ddfec4eba 100644 --- a/examples/unsupervised_learning/TSDAE/train_stsb_tsdae.py +++ b/examples/unsupervised_learning/TSDAE/train_stsb_tsdae.py @@ -1,19 +1,19 @@ -import csv -import gzip import logging -import os +import random +import traceback from datetime import datetime -from torch.utils.data import DataLoader +from datasets import load_dataset -from sentence_transformers import InputExample, LoggingHandler, SentenceTransformer, datasets, losses, models, util +from sentence_transformers import SentenceTransformer, models from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator +from sentence_transformers.losses import DenoisingAutoEncoderLoss +from sentence_transformers.similarity_functions import SimilarityFunction +from sentence_transformers.trainer import SentenceTransformerTrainer +from sentence_transformers.training_args import SentenceTransformerTrainingArguments -#### Just some code to print debug information to stdout -logging.basicConfig( - format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()] -) -#### /print debug information to stdout +# Set the log level to INFO to get more information +logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO) # Training parameters model_name = "bert-base-uncased" @@ -21,90 +21,131 @@ num_epochs = 1 max_seq_length = 75 -# Save path to store our model -model_save_path = "output/training_stsb_tsdae-{}-{}-{}".format( - model_name, train_batch_size, datetime.now().strftime("%Y-%m-%d_%H-%M-%S") -) - -# Check if dataset exists. If not, download and extract it -sts_dataset_path = "data/stsbenchmark.tsv.gz" +output_dir = f"output/training_stsb_tsdae-{model_name.replace('/', '-')}-{train_batch_size}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}" -if not os.path.exists(sts_dataset_path): - util.http_get("https://sbert.net/datasets/stsbenchmark.tsv.gz", sts_dataset_path) - -# Defining our sentence transformer model +# 1. Defining our sentence transformer model word_embedding_model = models.Transformer(model_name, max_seq_length=max_seq_length) pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), "cls") model = SentenceTransformer(modules=[word_embedding_model, pooling_model]) - - -# We use 1 Million sentences from Wikipedia to train our model -wikipedia_dataset_path = "data/wiki1m_for_simcse.txt" -if not os.path.exists(wikipedia_dataset_path): - util.http_get( - "https://huggingface.co/datasets/princeton-nlp/datasets-for-simcse/resolve/main/wiki1m_for_simcse.txt", - wikipedia_dataset_path, - ) - -# train_samples is a list of InputExample objects where we pass the same sentence twice to texts, i.e. texts=[sent, sent] -train_sentences = [] -with open(wikipedia_dataset_path, encoding="utf8") as fIn: - for line in fIn: - line = line.strip() - if len(line) >= 10: - train_sentences.append(line) - -# Read STSbenchmark dataset and use it as development set -logging.info("Read STSbenchmark dev dataset") -dev_samples = [] -test_samples = [] -with gzip.open(sts_dataset_path, "rt", encoding="utf8") as fIn: - reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE) - for row in reader: - score = float(row["score"]) / 5.0 # Normalize score to range 0 ... 
1
-
-        if row["split"] == "dev":
-            dev_samples.append(InputExample(texts=[row["sentence1"], row["sentence2"]], label=score))
-        elif row["split"] == "test":
-            test_samples.append(InputExample(texts=[row["sentence1"], row["sentence2"]], label=score))
-
-dev_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(
-    dev_samples, batch_size=train_batch_size, name="sts-dev"
+# or to load a pre-trained SentenceTransformer model OR use mean pooling
+# model = SentenceTransformer(model_name)
+# model.max_seq_length = max_seq_length
+
+# 2. We use 1 Million sentences from Wikipedia to train our model:
+# https://huggingface.co/datasets/princeton-nlp/datasets-for-simcse
+dataset = load_dataset("princeton-nlp/datasets-for-simcse", split="train")
+
+
+def noise_fn(text, del_ratio=0.6):
+    from nltk import word_tokenize
+    from nltk.tokenize.treebank import TreebankWordDetokenizer
+
+    words = word_tokenize(text)
+    n = len(words)
+    if n == 0:
+        return {"noisy": text}
+
+    kept_words = [word for word in words if random.random() > del_ratio]
+    # Guarantee that at least one word remains
+    if len(kept_words) == 0:
+        return {"noisy": random.choice(words)}
+
+    noisy_text = TreebankWordDetokenizer().detokenize(kept_words)
+    return {"noisy": noisy_text}
+
+
+# TSDAE requires a dataset with 2 columns: a text column and a noisified text column
+# Here we are using a function to delete some words, but you can use any other method to noisify your text
+dataset = dataset.map(noise_fn, input_columns="text")
+dataset = dataset.train_test_split(test_size=10000)
+train_dataset = dataset["train"]
+eval_dataset = dataset["test"]
+print(train_dataset)
+print(train_dataset[0])
+"""
+Dataset({
+    features: ['text', 'noisy'],
+    num_rows: 990000
+})
+{
+    'text': 'Oseltamivir is considered to be the primary antiviral drug used to combat avian influenza, commonly known as the bird flu.',
+    'noisy': 'to be the primary antiviral drug used combat influenza commonly as the bird flu.',
+}
+"""
+
+# 3. Define our training loss: https://sbert.net/docs/package_reference/sentence_transformer/losses.html#denoisingautoencoderloss
+# Note that this will likely result in warnings as we're loading 'model_name' as a decoder, but it likely won't
+# have weights for that yet. This is fine, as we'll be training it from scratch.
+train_loss = DenoisingAutoEncoderLoss(model, decoder_name_or_path=model_name, tie_encoder_decoder=True)
+
+# 4. Define an evaluator for use during training. This is useful to keep track of alongside the evaluation loss.
+stsb_eval_dataset = load_dataset("sentence-transformers/stsb", split="validation") +dev_evaluator = EmbeddingSimilarityEvaluator( + sentences1=stsb_eval_dataset["sentence1"], + sentences2=stsb_eval_dataset["sentence2"], + scores=stsb_eval_dataset["score"], + main_similarity=SimilarityFunction.COSINE, + name="sts-dev", ) -test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples( - test_samples, batch_size=train_batch_size, name="sts-test" -) - -# We train our model using the MultipleNegativesRankingLoss -train_dataset = datasets.DenoisingAutoEncoderDataset(train_sentences) -train_dataloader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, drop_last=True) -train_loss = losses.DenoisingAutoEncoderLoss(model, decoder_name_or_path=model_name, tie_encoder_decoder=True) - - -evaluation_steps = 1000 -logging.info(f"Training sentences: {len(train_sentences)}") -logging.info("Performance before training") +logging.info("Evaluation before training:") dev_evaluator(model) -# Train the model -model.fit( - train_objectives=[(train_dataloader, train_loss)], - evaluator=dev_evaluator, - epochs=num_epochs, - evaluation_steps=evaluation_steps, - output_path=model_save_path, - weight_decay=0, - warmup_steps=100, - optimizer_params={"lr": 3e-5}, - use_amp=True, # Set to True, if your GPU supports FP16 cores +# 5. Define the training arguments +args = SentenceTransformerTrainingArguments( + # Required parameter: + output_dir=output_dir, + # Optional training parameters: + learning_rate=3e-5, + num_train_epochs=1, + per_device_train_batch_size=train_batch_size, + per_device_eval_batch_size=train_batch_size, + warmup_ratio=0.1, + fp16=True, # Set to False if you get an error that your GPU can't run on FP16 + bf16=False, # Set to True if you have a GPU that supports BF16 + # Optional tracking/debugging parameters: + eval_strategy="steps", + eval_steps=10000, + save_strategy="steps", + save_steps=10000, + save_total_limit=2, + logging_steps=1000, + run_name="tsdae", # Will be used in W&B if `wandb` is installed ) -############################################################################## -# -# Load the stored model and evaluate its performance on STS benchmark dataset -# -############################################################################## - - -model = SentenceTransformer(model_save_path) -test_evaluator(model, output_path=model_save_path) +# 6. Create the trainer & start training +trainer = SentenceTransformerTrainer( + model=model, + args=args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + loss=train_loss, + evaluator=dev_evaluator, +) +trainer.train() + +# 7. Evaluate the model performance on the STS Benchmark test dataset +test_dataset = load_dataset("sentence-transformers/stsb", split="test") +test_evaluator = EmbeddingSimilarityEvaluator( + sentences1=test_dataset["sentence1"], + sentences2=test_dataset["sentence2"], + scores=test_dataset["score"], + main_similarity=SimilarityFunction.COSINE, + name="sts-test", +) +test_evaluator(model) + +# 8. Save the trained & evaluated model locally +final_output_dir = f"{output_dir}/final" +model.save(final_output_dir) + +# 9. (Optional) save the model to the Hugging Face Hub! 
+# It is recommended to run `huggingface-cli login` to log into your Hugging Face account first +model_name = model_name if "/" not in model_name else model_name.split("/")[-1] +try: + model.push_to_hub(f"{model_name}-tsdae") +except Exception: + logging.error( + f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run " + f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({final_output_dir!r})` " + f"and saving it using `model.push_to_hub('{model_name}-tsdae')`." + ) diff --git a/examples/unsupervised_learning/TSDAE/train_tsdae_from_file.py b/examples/unsupervised_learning/TSDAE/train_tsdae_from_file.py index 13cc5eaef..cebe29356 100644 --- a/examples/unsupervised_learning/TSDAE/train_tsdae_from_file.py +++ b/examples/unsupervised_learning/TSDAE/train_tsdae_from_file.py @@ -1,83 +1,167 @@ -""" -This file loads sentences from a provided text file. It is expected, that the there is one sentence per line in that text file. - -TSDAE will be training using these sentences. Checkpoints are stored every 500 steps to the output folder. - -Usage: -python train_tsdae_from_file.py path/to/sentences.txt - -""" - import gzip import logging +import random import sys +import traceback from datetime import datetime import tqdm -from torch.utils.data import DataLoader +from datasets import Dataset, load_dataset -from sentence_transformers import LoggingHandler, SentenceTransformer, datasets, losses, models +from sentence_transformers import SentenceTransformer, models +from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator +from sentence_transformers.losses import DenoisingAutoEncoderLoss +from sentence_transformers.similarity_functions import SimilarityFunction +from sentence_transformers.trainer import SentenceTransformerTrainer +from sentence_transformers.training_args import SentenceTransformerTrainingArguments -#### Just some code to print debug information to stdout -logging.basicConfig( - format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()] -) -#### /print debug information to stdout +# Set the log level to INFO to get more information +logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO) -# Train Parameters +# Training parameters model_name = "bert-base-uncased" -batch_size = 8 +train_batch_size = 8 +num_epochs = 1 +max_seq_length = 75 +output_dir = f"output/training_tsdae-{model_name.replace('/', '-')}-{train_batch_size}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}" + +# 1. Defining our sentence transformer model +word_embedding_model = models.Transformer(model_name, max_seq_length=max_seq_length) +pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), "cls") +model = SentenceTransformer(modules=[word_embedding_model, pooling_model]) +# or to load a pre-trained SentenceTransformer model OR use mean pooling +# model = SentenceTransformer(model_name) +# model.max_seq_length = max_seq_length + +# 2. 
We use a file provided by the user to train our model
# Input file path (a text file, each line a sentence)
if len(sys.argv) < 2:
    print(f"Run this script with: python {sys.argv[0]} path/to/sentences.txt")
    exit()
filepath = sys.argv[1]
-
-# Save path to store our model
-output_name = ""
-if len(sys.argv) >= 3:
-    output_name = "-" + sys.argv[2].replace(" ", "_").replace("/", "_").replace("\\", "_")
-
-model_output_path = "output/train_tsdae{}-{}".format(output_name, datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
-
-
-################# Read the train corpus #################
train_sentences = []
with (
    gzip.open(filepath, "rt", encoding="utf8") if filepath.endswith(".gz") else open(filepath, encoding="utf8") as fIn
):
-    for line in tqdm.tqdm(fIn, desc="Read file"):
+    for line in tqdm.tqdm(fIn, desc="Reading file"):
        line = line.strip()
        if len(line) >= 10:
            train_sentences.append(line)
+dataset = Dataset.from_dict({"text": train_sentences})
-logging.info(f"{len(train_sentences)} train sentences")
+def noise_fn(text, del_ratio=0.6):
+    from nltk import word_tokenize
+    from nltk.tokenize.treebank import TreebankWordDetokenizer
+
-################# Initialize an SBERT model #################
+    words = word_tokenize(text)
+    n = len(words)
+    if n == 0:
+        return {"noisy": text}
+
-word_embedding_model = models.Transformer(model_name)
-# Apply **cls** pooling to get one fixed sized sentence vector
-pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), "cls")
-model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
+    kept_words = [word for word in words if random.random() > del_ratio]
+    # Guarantee that at least one word remains
+    if len(kept_words) == 0:
+        return {"noisy": random.choice(words)}
+
-################# Train and evaluate the model (it needs about 1 hour for one epoch of AskUbuntu) #################
-# We wrap our training sentences in the DenoisingAutoEncoderDataset to add deletion noise on the fly
-train_dataset = datasets.DenoisingAutoEncoderDataset(train_sentences)
-train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
-train_loss = losses.DenoisingAutoEncoderLoss(model, decoder_name_or_path=model_name, tie_encoder_decoder=True)
-
-
-logging.info("Start training")
-model.fit(
-    train_objectives=[(train_dataloader, train_loss)],
-    epochs=1,
-    weight_decay=0,
-    scheduler="constantlr",
-    optimizer_params={"lr": 3e-5},
-    show_progress_bar=True,
-    checkpoint_path=model_output_path,
-    use_amp=False,  # Set to True, if your GPU supports FP16 cores
+    noisy_text = TreebankWordDetokenizer().detokenize(kept_words)
+    return {"noisy": noisy_text}
+
+
+# TSDAE requires a dataset with 2 columns: a text column and a noisified text column
+# Here we are using a function to delete some words, but you can use any other method to noisify your text
+dataset = dataset.map(noise_fn, input_columns="text")
+dataset = dataset.train_test_split(test_size=10000)
+train_dataset = dataset["train"]
+eval_dataset = dataset["test"]
+print(train_dataset)
+print(train_dataset[0])
+"""
+Dataset({
+    features: ['text', 'noisy'],
+    num_rows: 990000
+})
+{
+    'text': 'Oseltamivir is considered to be the primary antiviral drug used to combat avian influenza, commonly known as the bird flu.',
+    'noisy': 'to be the primary antiviral drug used combat influenza commonly as the bird flu.',
+}
+"""
+
+# 3. 
Define our training loss: https://sbert.net/docs/package_reference/sentence_transformer/losses.html#denoisingautoencoderLoss +# Note that this will likely result in warnings as we're loading 'model_name' as a decoder, but it likely won't +# have weights for that yet. This is fine, as we'll be training it from scratch. +train_loss = DenoisingAutoEncoderLoss(model, decoder_name_or_path=model_name, tie_encoder_decoder=True) + +# 4. Define an evaluator for use during training. This is useful to keep track of alongside the evaluation loss. +stsb_eval_dataset = load_dataset("sentence-transformers/stsb", split="validation") +dev_evaluator = EmbeddingSimilarityEvaluator( + sentences1=stsb_eval_dataset["sentence1"], + sentences2=stsb_eval_dataset["sentence2"], + scores=stsb_eval_dataset["score"], + main_similarity=SimilarityFunction.COSINE, + name="sts-dev", +) +logging.info("Evaluation before training:") +dev_evaluator(model) + +# 5. Define the training arguments +args = SentenceTransformerTrainingArguments( + # Required parameter: + output_dir=output_dir, + # Optional training parameters: + learning_rate=3e-5, + num_train_epochs=1, + per_device_train_batch_size=train_batch_size, + per_device_eval_batch_size=train_batch_size, + warmup_ratio=0.1, + fp16=True, # Set to False if you get an error that your GPU can't run on FP16 + bf16=False, # Set to True if you have a GPU that supports BF16 + # Optional tracking/debugging parameters: + eval_strategy="steps", + eval_steps=1000, + save_strategy="steps", + save_steps=1000, + save_total_limit=2, + logging_steps=100, + run_name="tsdae", # Will be used in W&B if `wandb` is installed +) + +# 6. Create the trainer & start training +trainer = SentenceTransformerTrainer( + model=model, + args=args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + loss=train_loss, + evaluator=dev_evaluator, +) +trainer.train() + +# 7. Evaluate the model performance on the STS Benchmark test dataset +test_dataset = load_dataset("sentence-transformers/stsb", split="test") +test_evaluator = EmbeddingSimilarityEvaluator( + sentences1=test_dataset["sentence1"], + sentences2=test_dataset["sentence2"], + scores=test_dataset["score"], + main_similarity=SimilarityFunction.COSINE, + name="sts-test", ) +test_evaluator(model) + +# 8. Save the trained & evaluated model locally +final_output_dir = f"{output_dir}/final" +model.save(final_output_dir) + +# 9. (Optional) save the model to the Hugging Face Hub! +# It is recommended to run `huggingface-cli login` to log into your Hugging Face account first +model_name = model_name if "/" not in model_name else model_name.split("/")[-1] +try: + model.push_to_hub(f"{model_name}-tsdae") +except Exception: + logging.error( + f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run " + f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({final_output_dir!r})` " + f"and saving it using `model.push_to_hub('{model_name}-tsdae')`." + )
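
Once any of these scripts finishes, the checkpoint saved under `{output_dir}/final` can be loaded back like any other Sentence Transformers model. A minimal usage sketch; the path and the example sentences are placeholders, not part of the patch:

from sentence_transformers import SentenceTransformer

# Placeholder path: substitute the actual `{output_dir}/final` directory created by the script
model = SentenceTransformer("output/training_tsdae-bert-base-uncased/final")

sentences = [
    "How do I disable the login sound?",
    "How can I turn off the startup chime?",
    "How do I install a new graphics driver?",
]
embeddings = model.encode(sentences)
# Pairwise similarity matrix between the sentences (cosine similarity by default)
print(model.similarity(embeddings, embeddings))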