Update TSDAE examples with SentenceTransformerTrainer (#3137)
* raises ValueError when num_labels > 1 when using CrossEncoder.rank()

* updated error message

* update tsdae example with SentenceTransformerTrainer

* lint

* Update the other TSDAE examples as well

---------

Co-authored-by: Tom Aarsen <[email protected]>
JINO-ROHIT and tomaarsen authored Jan 17, 2025
1 parent cccab83 commit 0d51f4f
Showing 4 changed files with 436 additions and 225 deletions.
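
At a glance, both examples move from the legacy `model.fit` training loop to the `SentenceTransformerTrainer` API. Below is a minimal sketch of that migration; the toy dataset, column names, and output directory are placeholders rather than values from the diff, and the (damaged, original) column order mirrors the pair order used by the older `DenoisingAutoEncoderDataset`, which this sketch assumes still applies:

```python
from datasets import Dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import DenoisingAutoEncoderLoss
from sentence_transformers.trainer import SentenceTransformerTrainer
from sentence_transformers.training_args import SentenceTransformerTrainingArguments

model = SentenceTransformer("bert-base-uncased")

# Toy TSDAE pairs: the damaged sentence first, the original sentence second
train_dataset = Dataset.from_dict(
    {
        "damaged": ["how resize partition", "battery broken message"],
        "original": ["how do I resize a partition?", "how to remove the battery is broken message?"],
    }
)

# TSDAE loss: a decoder is trained to reconstruct the original sentence from the encoder embedding
loss = DenoisingAutoEncoderLoss(model, decoder_name_or_path="bert-base-uncased", tie_encoder_decoder=True)

# Previously: model.fit(train_objectives=[(dataloader, loss)], epochs=1, ...)
# Now: a Hugging Face style trainer with explicit training arguments
args = SentenceTransformerTrainingArguments(output_dir="output/tsdae-demo", num_train_epochs=1)
trainer = SentenceTransformerTrainer(model=model, args=args, train_dataset=train_dataset, loss=loss)
trainer.train()
```

The training script in this commit applies the same pattern to the AskUbuntu corpus, with evaluation and checkpointing arguments added.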
45 changes: 22 additions & 23 deletions examples/unsupervised_learning/TSDAE/eval_askubuntu.py
@@ -10,19 +10,19 @@
import os
import sys

from sentence_transformers import LoggingHandler, SentenceTransformer, evaluation, util
from datasets import Dataset

#### Just some code to print debug information to stdout
logging.basicConfig(
format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()]
)
#### /print debug information to stdout
from sentence_transformers import SentenceTransformer, util
from sentence_transformers.evaluation import RerankingEvaluator

# Set the log level to INFO to get more information
logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)

model = SentenceTransformer(sys.argv[1])


################# Download AskUbuntu and extract training corpus #################
askubuntu_folder = "askubuntu"
askubuntu_folder = "data/askubuntu"
training_corpus = os.path.join(askubuntu_folder, "train.unsupervised.txt")


@@ -37,15 +37,17 @@
dev_test_ids = set()
with gzip.open(os.path.join(askubuntu_folder, "text_tokenized.txt.gz"), "rt", encoding="utf8") as fIn:
for line in fIn:
splits = line.strip().split("\t")
id = splits[0]
title = splits[1]
id, title, *_ = line.strip().split("\t")
corpus[id] = title


# Read dev & test dataset
def read_eval_dataset(filepath):
dataset = []
def read_eval_dataset(filepath) -> Dataset:
data = {
"query": [],
"positive": [],
"negative": [],
}
with open(filepath) as fIn:
for line in fIn:
query_id, relevant_id, candidate_ids, bm25_scores = line.strip().split("\t")
@@ -55,15 +57,12 @@ def read_eval_dataset(filepath):
relevant_id = relevant_id.split(" ")
candidate_ids = candidate_ids.split(" ")
negative_ids = set(candidate_ids) - set(relevant_id)
dataset.append(
{
"query": corpus[query_id],
"positive": [corpus[pid] for pid in relevant_id],
"negative": [corpus[pid] for pid in negative_ids],
}
)
data["query"].append(corpus[query_id])
data["positive"].append([corpus[pid] for pid in relevant_id])
data["negative"].append([corpus[pid] for pid in negative_ids])
dev_test_ids.add(query_id)
dev_test_ids.update(candidate_ids)
dataset = Dataset.from_dict(data)
return dataset


@@ -72,11 +71,11 @@ def read_eval_dataset(filepath):


# Create a dev evaluator
dev_evaluator = evaluation.RerankingEvaluator(dev_dataset, name="AskUbuntu dev")
dev_evaluator = RerankingEvaluator(dev_dataset, name="AskUbuntu dev")

logging.info("Dev performance before training")
logging.info("Dev performance")
dev_evaluator(model)

test_evaluator = evaluation.RerankingEvaluator(test_dataset, name="AskUbuntu test")
logging.info("Test performance before training")
test_evaluator = RerankingEvaluator(test_dataset, name="AskUbuntu test")
logging.info("Test performance")
test_evaluator(model)
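
For reference, here is a minimal sketch of exercising the reworked evaluator outside the script; the model name and toy samples are placeholders (the script itself loads whatever model path is passed as `sys.argv[1]`):

```python
from datasets import Dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import RerankingEvaluator

model = SentenceTransformer("all-MiniLM-L6-v2")  # placeholder model

# Each sample pairs a query with relevant (positive) and irrelevant (negative) texts
toy_dataset = Dataset.from_dict(
    {
        "query": ["how do I resize a partition?"],
        "positive": [["resize an ext4 partition with gparted"]],
        "negative": [["change the desktop wallpaper in Unity"]],
    }
)

evaluator = RerankingEvaluator(toy_dataset, name="toy-dev")
metrics = evaluator(model)  # reranking metrics such as MAP and MRR@k
print(metrics)
```

The updated script is still invoked the same way, e.g. `python eval_askubuntu.py <path_or_name_of_model>`.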
209 changes: 148 additions & 61 deletions examples/unsupervised_learning/TSDAE/train_askubuntu_tsdae.py
@@ -1,25 +1,43 @@
import gzip
import logging
import os
import sys
import random
import traceback
from datetime import datetime

from torch.utils.data import DataLoader
from datasets import Dataset

from sentence_transformers import LoggingHandler, SentenceTransformer, datasets, evaluation, losses, models, util
from sentence_transformers import SentenceTransformer, models, util
from sentence_transformers.evaluation import RerankingEvaluator
from sentence_transformers.losses import DenoisingAutoEncoderLoss
from sentence_transformers.trainer import SentenceTransformerTrainer
from sentence_transformers.training_args import SentenceTransformerTrainingArguments

#### Just some code to print debug information to stdout
logging.basicConfig(
format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()]
)
#### /print debug information to stdout
# Set the log level to INFO to get more information
logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)

# Training parameters
model_name = "bert-base-uncased"
train_batch_size = 8
num_epochs = 1
max_seq_length = 75

output_dir = f"output/training_stsb_tsdae-{model_name.replace('/', '-')}-{train_batch_size}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"

################# Download AskUbuntu and extract training corpus #################
# 1. Defining our sentence transformer model
word_embedding_model = models.Transformer(model_name, max_seq_length=max_seq_length)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), "cls")
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
# or to load a pre-trained SentenceTransformer model OR use mean pooling
# model = SentenceTransformer(model_name)
# model.max_seq_length = max_seq_length


# 2. Download the AskUbuntu dataset from https://github.com/taolei87/askubuntu
askubuntu_folder = "data/askubuntu"
result_folder = "output/askubuntu-tsdae-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
batch_size = 8

## Download the AskUbuntu dataset from https://github.com/taolei87/askubuntu
for filename in ["text_tokenized.txt.gz", "dev.txt", "test.txt", "train_random.txt"]:
filepath = os.path.join(askubuntu_folder, filename)
if not os.path.exists(filepath):
@@ -30,15 +48,17 @@
dev_test_ids = set()
with gzip.open(os.path.join(askubuntu_folder, "text_tokenized.txt.gz"), "rt", encoding="utf8") as fIn:
for line in fIn:
splits = line.strip().split("\t")
id = splits[0]
title = splits[1]
id, title, *_ = line.strip().split("\t")
corpus[id] = title


# Read dev & test dataset
def read_eval_dataset(filepath):
dataset = []
def read_eval_dataset(filepath) -> Dataset:
data = {
"query": [],
"positive": [],
"negative": [],
}
with open(filepath) as fIn:
for line in fIn:
query_id, relevant_id, candidate_ids, bm25_scores = line.strip().split("\t")
@@ -48,62 +68,129 @@ def read_eval_dataset(filepath):
relevant_id = relevant_id.split(" ")
candidate_ids = candidate_ids.split(" ")
negative_ids = set(candidate_ids) - set(relevant_id)
dataset.append(
{
"query": corpus[query_id],
"positive": [corpus[pid] for pid in relevant_id],
"negative": [corpus[pid] for pid in negative_ids],
}
)
data["query"].append(corpus[query_id])
data["positive"].append([corpus[pid] for pid in relevant_id])
data["negative"].append([corpus[pid] for pid in negative_ids])
dev_test_ids.add(query_id)
dev_test_ids.update(candidate_ids)
dataset = Dataset.from_dict(data)
return dataset


dev_dataset = read_eval_dataset(os.path.join(askubuntu_folder, "dev.txt"))
eval_dataset = read_eval_dataset(os.path.join(askubuntu_folder, "dev.txt"))
test_dataset = read_eval_dataset(os.path.join(askubuntu_folder, "test.txt"))


## Now we need a list of train sentences.
## In this example we simply use all sentences that don't appear in the train/dev set
train_sentences = []
for id, sentence in corpus.items():
if id not in dev_test_ids:
train_sentences.append(sentence)


logging.info(f"{len(train_sentences)} train sentences")

################# Initialize an SBERT model #################
model_name = sys.argv[1] if len(sys.argv) >= 2 else "bert-base-uncased"
word_embedding_model = models.Transformer(model_name)
# Apply **cls** pooling to get one fixed sized sentence vector
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), "cls")
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

################# Train and evaluate the model (it needs about 1 hour for one epoch of AskUbuntu) #################
# We wrap our training sentences in the DenoisingAutoEncoderDataset to add deletion noise on the fly
train_dataset = datasets.DenoisingAutoEncoderDataset(train_sentences)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
train_loss = losses.DenoisingAutoEncoderLoss(model, decoder_name_or_path=model_name, tie_encoder_decoder=True)

# Create a dev evaluator
dev_evaluator = evaluation.RerankingEvaluator(dev_dataset, name="AskUbuntu dev")

logging.info("Dev performance before training")
train_sentences = [sentence for id, sentence in corpus.items() if id not in dev_test_ids]
train_dataset = Dataset.from_dict({"text": train_sentences})


def noise_fn(text, del_ratio=0.6):
from nltk import word_tokenize
from nltk.tokenize.treebank import TreebankWordDetokenizer

words = word_tokenize(text)
n = len(words)
if n == 0:
return {"noisy": text}  # Dataset.map expects a dict, even when the text is left unchanged

kept_words = [word for word in words if random.random() < del_ratio]
# Guarantee that at least one word remains
if len(kept_words) == 0:
return {"noisy": random.choice(words)}

noisy_text = TreebankWordDetokenizer().detokenize(kept_words)
return {"noisy": noisy_text}


# TSDAE requires a dataset with 2 columns: a text column and a noisified text column
# Here we are using a function to delete some words, but you can use any other method to noisify your text
train_dataset = train_dataset.map(noise_fn, input_columns="text")
print(train_dataset)
print(train_dataset[0])
"""
Dataset({
features: ['text', 'noisy'],
num_rows: 160436
})
{
'text': "how to get the `` your battery is broken '' message to go away ?",
'noisy': 'how to get "battery is broken go?',
}
"""
print(eval_dataset)
print(test_dataset)
"""
Dataset({
features: ['query', 'positive', 'negative'],
num_rows: 189
})
Dataset({
features: ['query', 'positive', 'negative'],
num_rows: 186
})
"""

# 3. Define our training loss: https://sbert.net/docs/package_reference/sentence_transformer/losses.html#denoisingautoencoderLoss
# Note that this will likely result in warnings as we're loading 'model_name' as a decoder, but it likely won't
# have weights for that yet. This is fine, as we'll be training it from scratch.
train_loss = DenoisingAutoEncoderLoss(model, decoder_name_or_path=model_name, tie_encoder_decoder=True)

# 4. Define an evaluator for use during training. This is useful to keep track of alongside the evaluation loss.
logging.info("Evaluation before training:")
dev_evaluator = RerankingEvaluator(eval_dataset, name="AskUbuntu-dev")
dev_evaluator(model)

total_steps = 20000
logging.info("Start training")
model.fit(
train_objectives=[(train_dataloader, train_loss)],
# 5. Define the training arguments
args = SentenceTransformerTrainingArguments(
# Required parameter:
output_dir=output_dir,
# Optional training parameters:
learning_rate=3e-5,
num_train_epochs=1,
per_device_train_batch_size=train_batch_size,
per_device_eval_batch_size=train_batch_size,
warmup_ratio=0.1,
fp16=True, # Set to False if you get an error that your GPU can't run on FP16
bf16=False, # Set to True if you have a GPU that supports BF16
# Optional tracking/debugging parameters:
eval_strategy="steps",
eval_steps=5000,
save_strategy="steps",
save_steps=5000,
save_total_limit=2,
logging_steps=1000,
run_name="tsdae-askubuntu", # Will be used in W&B if `wandb` is installed
)

# 6. Create the trainer & start training
trainer = SentenceTransformerTrainer(
model=model,
args=args,
train_dataset=train_dataset,
loss=train_loss,
evaluator=dev_evaluator,
evaluation_steps=1000,
epochs=1,
steps_per_epoch=total_steps,
weight_decay=0,
scheduler="constantlr",
optimizer_params={"lr": 3e-5},
output_path=result_folder,
show_progress_bar=True,
)
trainer.train()

# 7. Evaluate the model performance on the test set after training
logging.info("Evaluation after training:")
test_evaluator = RerankingEvaluator(test_dataset, name="AskUbuntu-test")
test_evaluator(model)

# 8. Save the trained & evaluated model locally
final_output_dir = f"{output_dir}/final"
model.save(final_output_dir)

# 9. (Optional) save the model to the Hugging Face Hub!
# It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
model_name = model_name if "/" not in model_name else model_name.split("/")[-1]
try:
model.push_to_hub(f"{model_name}-tsdae-askubuntu")
except Exception:
logging.error(
f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({final_output_dir!r})` "
f"and saving it using `model.push_to_hub('{model_name}-tsdae')`."
)
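
Once training finishes, the saved model loads like any other Sentence Transformer. A minimal inference sketch follows; the path is a placeholder for the script's `final_output_dir`:

```python
from sentence_transformers import SentenceTransformer, util

# Placeholder path: point this at the directory the training script saved to (output_dir + "/final")
model = SentenceTransformer("output/tsdae-askubuntu/final")

queries = ["how do I resize a partition?"]
corpus = [
    "resize an ext4 partition with gparted",
    "change the desktop wallpaper in Unity",
]

query_embeddings = model.encode(queries)
corpus_embeddings = model.encode(corpus)

# Cosine similarity between each query and each corpus sentence
scores = util.cos_sim(query_embeddings, corpus_embeddings)
print(scores)
```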