Add fittable #140

Open

stephantul wants to merge 53 commits into main from add-fittable
Changes from 19 commits
Commits (53)
4078a3b
Fix tokenizer issue
stephantul Dec 22, 2024
09f888d
fix issue with warning
stephantul Dec 22, 2024
2167a4e
regenerate lock file
stephantul Dec 22, 2024
c95dca5
fix lock file
stephantul Dec 22, 2024
b5d8bb7
Try to not select 2.5.1
stephantul Dec 22, 2024
3e68669
fix: issue with dividers in utils
stephantul Dec 22, 2024
1ae4d61
Try to not select 2.5.0
stephantul Dec 22, 2024
1349b0c
fix: do not up version
stephantul Dec 22, 2024
4b83d59
Attempt special fix
stephantul Dec 22, 2024
9515b83
merge
stephantul Dec 23, 2024
dfd865b
feat: add training
stephantul Dec 23, 2024
c4ba272
merge with old
stephantul Dec 23, 2024
4713bfa
fix: no grad
stephantul Dec 24, 2024
e8058bb
use numpy
stephantul Dec 24, 2024
a59127e
Add train_test_split
stephantul Dec 24, 2024
310fbb5
fix: issue with fit not resetting
stephantul Dec 24, 2024
b1899d1
feat: add lightning
stephantul Dec 28, 2024
e27f9dc
merge
stephantul Dec 28, 2024
8df3aaf
Fix bugs
stephantul Jan 3, 2025
839d88a
fix: reviewer comments
stephantul Jan 5, 2025
8457357
fix train issue
stephantul Jan 5, 2025
a750709
fix issue with trainer
stephantul Jan 7, 2025
e83c54e
fix: truncate during training
stephantul Jan 7, 2025
803565d
feat: tokenize maximum length truncation
stephantul Jan 7, 2025
9052806
fixes
stephantul Jan 8, 2025
2f9fbf4
typo
stephantul Jan 8, 2025
f1e08c3
Add progressbar
stephantul Jan 8, 2025
bb54a76
small code changes, add docs
stephantul Jan 8, 2025
69ee4ee
fix training comments
stephantul Jan 8, 2025
9962be7
Merge branch 'main' into add-fittable
stephantul Jan 16, 2025
ffec235
Add pipeline saving
stephantul Jan 16, 2025
0af84fc
fix bug
stephantul Jan 16, 2025
c829745
fix issue with normalize test
stephantul Jan 16, 2025
9ce65a1
change default batch size
stephantul Jan 17, 2025
e1169fb
feat: add sklearn skops pipeline
stephantul Jan 20, 2025
f096824
Device handling and automatic batch size
stephantul Jan 20, 2025
ff3ebdf
Add docstrings, defaults
stephantul Jan 20, 2025
b4e966a
docs
stephantul Jan 20, 2025
8f65bfd
fix: rename
stephantul Jan 21, 2025
8cdb668
fix: rename
stephantul Jan 21, 2025
e96a72a
fix installation
stephantul Jan 21, 2025
3e76083
rename
stephantul Jan 21, 2025
9f1cb5a
Add training tutorial
stephantul Jan 23, 2025
e2d92b9
Add tutorial link
stephantul Jan 23, 2025
657cef0
Merge branch 'main' into add-fittable
stephantul Jan 24, 2025
773009f
test: add tests
stephantul Jan 24, 2025
7015341
fix tests
stephantul Jan 24, 2025
8ab8456
tests: fix tests
stephantul Jan 24, 2025
e21e61f
Address comments
stephantul Jan 26, 2025
ff75af9
Add inference reqs to train reqs
stephantul Jan 26, 2025
87de7c4
fix normalize
stephantul Jan 26, 2025
1fb33f1
update lock file
stephantul Jan 26, 2025
59f0076
Merge branch 'main' into add-fittable
stephantul Jan 26, 2025
6 changes: 6 additions & 0 deletions model2vec/train/__init__.py
@@ -0,0 +1,6 @@
from model2vec.utils import get_package_extras, importable

_REQUIRED_EXTRA = "train"

for extra_dependency in get_package_extras("model2vec", _REQUIRED_EXTRA):
importable(extra_dependency, _REQUIRED_EXTRA)
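
The new subpackage guards its own import: get_package_extras lists the optional dependencies declared for the "train" extra, and importable verifies that each of them can actually be imported, so a missing torch or lightning fails fast with a pointer to the right extra. A small usage sketch (the install command and the exact failure behaviour are assumptions, not taken from this PR):

# Assumes the train extra is installed first, e.g.: pip install "model2vec[train]"
# (torch + lightning, per the pyproject.toml change at the bottom of this diff).
# Importing the subpackage runs the check above; if a dependency is missing, an
# import-time error with a hint to install the extra is expected instead.
import model2vec.train  # noqa: F401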
168 changes: 168 additions & 0 deletions model2vec/train/base.py
@@ -0,0 +1,168 @@
from __future__ import annotations

from typing import Any, TypeVar

import torch
from tokenizers import Encoding, Tokenizer
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset

from model2vec import StaticModel


class FinetunableStaticModel(nn.Module):
def __init__(self, *, vectors: torch.Tensor, tokenizer: Tokenizer, out_dim: int, pad_id: int = 0) -> None:
"""
Initialize a trainable StaticModel from a StaticModel.

:param vectors: The embeddings of the staticmodel.
:param tokenizer: The tokenizer.
:param out_dim: The output dimension of the head.
:param pad_id: The padding id. This is set to 0 in almost all model2vec models.
"""
super().__init__()
self.pad_id = pad_id
self.out_dim = out_dim
self.embed_dim = vectors.shape[1]
self.vectors = vectors

self.embeddings = nn.Embedding.from_pretrained(vectors.clone(), freeze=False, padding_idx=pad_id)
self.head = self.construct_head()

weights = torch.zeros(len(vectors))
weights[pad_id] = -10_000
self.w = nn.Parameter(weights)
self.tokenizer = tokenizer

def construct_head(self) -> nn.Module:
"""Method should be overridden for various other classes."""
return nn.Linear(self.embed_dim, self.out_dim)

@classmethod
def from_pretrained(
cls: type[ModelType], out_dim: int, model_name: str = "minishlab/potion-base-8m", **kwargs: Any
) -> ModelType:
"""Load the model from a pretrained model2vec model."""
model = StaticModel.from_pretrained(model_name)
return cls.from_static_model(model, out_dim, **kwargs)

@classmethod
def from_static_model(cls: type[ModelType], model: StaticModel, out_dim: int, **kwargs: Any) -> ModelType:
"""Load the model from a static model."""
embeddings_converted = torch.from_numpy(model.embedding)
return cls(
vectors=embeddings_converted,
pad_id=model.tokenizer.token_to_id("[PAD]"),
out_dim=out_dim,
tokenizer=model.tokenizer,
**kwargs,
)

def _encode(self, input_ids: torch.Tensor) -> torch.Tensor:
"""
A forward pass and mean pooling.

This function is analogous to `StaticModel.encode`, but reimplemented to allow gradients
to pass through.

:param input_ids: A 2D tensor of input ids. All input ids have to be within bounds.
:return: The mean over the input ids, weighted by token weights.
"""
w = self.w[input_ids]
w = torch.sigmoid(w)
zeros = (input_ids != self.pad_id).float()
w = w * zeros
# Add a small epsilon to avoid division by zero
length = zeros.sum(1) + 1e-16
embedded = self.embeddings(input_ids)
# Weighted sum over the sequence; padding weights were already zeroed out above.
embedded = torch.bmm(w[:, None, :], embedded).squeeze(1)
# Divide by the number of non-padding tokens to get the weighted mean.
embedded = embedded / length[:, None]

return nn.functional.normalize(embedded)

def forward(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Forward pass through the mean, and a classifier layer after."""
encoded = self._encode(input_ids)
return self.head(encoded), encoded

def tokenize(self, texts: list[str], max_length: int | None = 512) -> torch.Tensor:
"""
Tokenize a bunch of strings into a single padded 2D tensor.

Note that this is not used during training.

:param texts: The texts to tokenize.
:param max_length: The maximum sequence length. Longer sequences are truncated; if None, no truncation is applied.
:return: A 2D padded tensor of token ids.
"""
encoded: list[Encoding] = self.tokenizer.encode_batch_fast(texts, add_special_tokens=False)
encoded_ids: list[torch.Tensor] = [torch.Tensor(encoding.ids[:max_length]).long() for encoding in encoded]
return pad_sequence(encoded_ids, batch_first=True)

@property
def device(self) -> torch.device:
"""Get the device of the model."""
return self.embeddings.weight.device

def to_static_model(self, config: dict[str, Any] | None = None) -> StaticModel:
"""
Convert the model to a static model.

This is useful if you want to discard your head, and consolidate the information learned by
the model to use it in a downstream task.

:param config: The config used in the StaticModel. If this is set to None, it will have no config.
:return: A static model.
"""
# Perform the forward pass on the selected device.
with torch.no_grad():
all_indices = torch.arange(len(self.embeddings.weight))[:, None].to(self.device)
vectors = self._encode(all_indices).cpu().numpy()

new_model = StaticModel(vectors=vectors, tokenizer=self.tokenizer, config=config)

return new_model


class TextDataset(Dataset):
def __init__(self, tokenized_texts: list[list[int]], targets: torch.Tensor) -> None:
"""
A dataset of texts.

:param tokenized_texts: The tokenized texts. Each text is a list of token ids.
:param targets: The targets.
:raises ValueError: If the number of labels does not match the number of texts.
"""
if len(targets) != len(tokenized_texts):
raise ValueError("Number of labels does not match number of texts.")
self.tokenized_texts = tokenized_texts
self.targets = targets

def __len__(self) -> int:
"""Return the length of the dataset."""
return len(self.tokenized_texts)

def __getitem__(self, index: int) -> tuple[list[int], torch.Tensor]:
"""Gets an item."""
return self.tokenized_texts[index], self.targets[index]

@staticmethod
def collate_fn(batch: list[tuple[list[int], torch.Tensor]]) -> tuple[torch.Tensor, torch.Tensor]:
"""Collate function."""
texts, targets = zip(*batch)

tensors = [torch.LongTensor(x) for x in texts]
padded = pad_sequence(tensors, batch_first=True, padding_value=0)

return padded, torch.stack(targets)

def to_dataloader(self, shuffle: bool, batch_size: int = 32) -> DataLoader:
"""Convert the dataset to a DataLoader."""
return DataLoader(self, collate_fn=self.collate_fn, shuffle=shuffle, batch_size=batch_size)


ModelType = TypeVar("ModelType", bound=FinetunableStaticModel)
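
FinetunableStaticModel copies the StaticModel embeddings into a trainable nn.Embedding, learns a per-token weight (self.w, passed through a sigmoid and zeroed on padding) that drives the weighted mean pooling, and puts a configurable head on top. A short usage sketch, not part of the PR, assuming the default minishlab/potion-base-8m checkpoint and a two-dimensional head:

from model2vec.train.base import FinetunableStaticModel

# Load a pretrained model2vec model and attach a fresh linear head with two outputs.
model = FinetunableStaticModel.from_pretrained(out_dim=2)

# Tokenize into a padded 2D LongTensor; sequences are truncated to 512 tokens by default.
input_ids = model.tokenize(["a very short text", "another example sentence"])

# The forward pass returns the head output and the pooled, normalized embeddings.
head_out, pooled = model(input_ids)
print(head_out.shape, pooled.shape)  # (2, 2) and (2, embed_dim)

# After training, the learned token weights can be baked back into a plain StaticModel.
static = model.to_static_model()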
204 changes: 204 additions & 0 deletions model2vec/train/classifier.py
@@ -0,0 +1,204 @@
from __future__ import annotations

import logging
from collections import Counter
from typing import Any

import lightning as pl
import numpy as np
import torch
from lightning.pytorch.callbacks import Callback, EarlyStopping
from lightning.pytorch.utilities.types import OptimizerLRScheduler
from sklearn.model_selection import train_test_split
from tokenizers import Tokenizer
from torch import nn

from model2vec.train.base import FinetunableStaticModel, TextDataset

logger = logging.getLogger(__name__)


class ClassificationStaticModel(FinetunableStaticModel):
def __init__(
self,
*,
vectors: torch.Tensor,
tokenizer: Tokenizer,
n_layers: int,
hidden_dim: int,
out_dim: int,
pad_id: int = 0,
) -> None:
"""Initialize a standard classifier model."""
self.n_layers = n_layers
self.hidden_dim = hidden_dim
# Alias that follows the scikit-learn convention; initialized with dummy classes until fit is called.
self.classes_: list[str] = [str(x) for x in range(out_dim)]
super().__init__(vectors=vectors, out_dim=out_dim, pad_id=pad_id, tokenizer=tokenizer)

@property
def classes(self) -> list[str]:
"""Return all clasess in the correct order."""
return self.classes_

def construct_head(self) -> nn.Module:
"""Constructs a simple classifier head."""
if self.n_layers == 0:
return nn.Linear(self.embed_dim, self.out_dim)
modules = [
nn.Linear(self.embed_dim, self.hidden_dim),
nn.ReLU(),
]
for _ in range(self.n_layers - 1):
modules.extend([nn.Linear(self.hidden_dim, self.hidden_dim), nn.ReLU()])
modules.extend([nn.Linear(self.hidden_dim, self.out_dim)])

for module in modules:
if isinstance(module, nn.Linear):
nn.init.kaiming_uniform_(module.weight)
nn.init.zeros_(module.bias)

return nn.Sequential(*modules)

def predict(self, X: list[str]) -> list[str]:
"""Predict a class for a set of texts."""
pred: list[str] = []
for batch in range(0, len(X), 1024):
logits = self._predict(X[batch : batch + 1024])
pred.extend([self.classes[idx] for idx in logits.argmax(1)])

return pred

@torch.no_grad()
def _predict(self, X: list[str]) -> torch.Tensor:
input_ids = self.tokenize(X)
logits, _ = self.forward(input_ids)
return logits

def predict_proba(self, X: list[str]) -> np.ndarray:
"""Predict the probability of each class."""
pred: list[np.ndarray] = []
for batch in range(0, len(X), 1024):
logits = self._predict(X[batch : batch + 1024])
pred.append(torch.softmax(logits, dim=1).numpy())

return np.concatenate(pred)

def fit(
self,
X: list[str],
y: list[str],
**kwargs: Any,
) -> ClassificationStaticModel:
"""Fit a model."""
pl.seed_everything(42)
classes = sorted(set(y))
self.classes_ = classes

if len(self.classes) != self.out_dim:
self.out_dim = len(self.classes)

self.head = self.construct_head()
self.embeddings = nn.Embedding.from_pretrained(self.vectors.clone(), freeze=False, padding_idx=self.pad_id)

label_mapping = {label: idx for idx, label in enumerate(self.classes)}
label_counts = Counter(y)
if min(label_counts.values()) < 2:
logger.info("Some classes have fewer than 2 samples. Stratification is disabled.")
train_texts, validation_texts, train_labels, validation_labels = train_test_split(
X, y, test_size=0.1, random_state=42, shuffle=True
)
else:
train_texts, validation_texts, train_labels, validation_labels = train_test_split(
X, y, test_size=0.1, random_state=42, shuffle=True, stratify=y
)

# Tokenize the texts and turn the labels into a LongTensor
train_tokenized: list[list[int]] = [
encoding.ids for encoding in self.tokenizer.encode_batch_fast(train_texts, add_special_tokens=False)
]
train_labels_tensor = torch.Tensor([label_mapping[label] for label in train_labels]).long()
train_dataset = TextDataset(train_tokenized, train_labels_tensor)

val_tokenized: list[list[int]] = [
encoding.ids for encoding in self.tokenizer.encode_batch_fast(validation_texts, add_special_tokens=False)
]
val_labels_tensor = torch.Tensor([label_mapping[label] for label in validation_labels]).long()
val_dataset = TextDataset(val_tokenized, val_labels_tensor)

c = ClassifierLightningModule(self)

batch_size = 32
n_train_batches = len(train_dataset) // batch_size
callbacks: list[Callback] = [EarlyStopping(monitor="val_accuracy", mode="max", patience=5)]
if n_train_batches < 250:
trainer = pl.Trainer(max_epochs=500, callbacks=callbacks, check_val_every_n_epoch=1)
else:
val_check_interval = max(250, 2 * len(val_dataset) // batch_size)
trainer = pl.Trainer(
max_epochs=500, callbacks=callbacks, val_check_interval=val_check_interval, check_val_every_n_epoch=None
)

trainer.fit(
c,
train_dataloaders=train_dataset.to_dataloader(shuffle=True, batch_size=batch_size),
val_dataloaders=val_dataset.to_dataloader(shuffle=False, batch_size=batch_size),
)
best_model_path = trainer.checkpoint_callback.best_model_path # type: ignore

state_dict = {
k.removeprefix("model."): v for k, v in torch.load(best_model_path, weights_only=True)["state_dict"].items()
}
self.load_state_dict(state_dict)

self.eval()

return self


class ClassifierLightningModule(pl.LightningModule):
def __init__(self, model: ClassificationStaticModel) -> None:
"""Initialize the lightningmodule."""
super().__init__()
self.model = model

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Simple forward pass."""
return self.model(x)

def training_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
"""Simple training step using cross entropy loss."""
x, y = batch
head_out, _ = self.model(x)
loss = nn.functional.cross_entropy(head_out, y).mean()

self.log("train_loss", loss)
return loss

def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
"""Simple validation step using cross entropy loss and accuracy."""
x, y = batch
head_out, _ = self.model(x)
loss = nn.functional.cross_entropy(head_out, y).mean()
accuracy = (head_out.argmax(1) == y).float().mean()

self.log("val_loss", loss)
self.log("val_accuracy", accuracy, prog_bar=True)

return loss

def configure_optimizers(self) -> OptimizerLRScheduler:
"""Simple Adam optimizer."""
optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer,
mode="min",
factor=0.5,
patience=3,
verbose=True,
min_lr=1e-6,
threshold=0.03,
threshold_mode="rel",
)

return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"}}
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -54,8 +54,10 @@ dev = [
"pytest-cov",
"ruff",
]

distill = ["torch", "transformers", "scikit-learn"]
onnx = ["onnx", "torch"]
train = ["torch", "lightning"]

[project.urls]
"Homepage" = "https://github.com/MinishLab"