
[WIP] GRPO #20

Draft · wants to merge 4 commits into base: main
2 changes: 1 addition & 1 deletion fast_llm/data/data.py
@@ -163,7 +163,7 @@ def setup(self, distributed: Distributed, samples_per_phase: dict[PhaseType, int
data_sample_warn_time_ms=self._config.data_sample_warn_time_ms,
)
)
-            for phase, datasets in self._sampled_datasets.items()
+            for phase, datasets in self._sampled_datasets.items()  # check data/dataset.py
}

def get_iterator(
6 changes: 3 additions & 3 deletions fast_llm/models/auto.py
@@ -1,19 +1,19 @@
from fast_llm.models.custom.config import CustomModelConfig, CustomTrainerConfig
from fast_llm.models.grpo.config import GRPOModelConfig, GRPOTrainerConfig
from fast_llm.models.gpt.config import GPTModelConfig, GPTTrainerConfig
from fast_llm.utils import Registry

model_registry = Registry(
"Model",
{
"gpt": GPTModelConfig,
"gpt_custom": CustomModelConfig,
"grpo": GRPOModelConfig,
},
)

trainer_registry = Registry(
"Model",
{
"gpt": GPTTrainerConfig,
"gpt_custom": CustomTrainerConfig,
"grpo": GRPOTrainerConfig,
},
)
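For reference, the new "grpo" entries make the model and trainer configs resolvable by name. A minimal lookup sketch, assuming `Registry` supports dict-style indexing (an assumption, not shown in this diff):

```python
# Hypothetical usage; dict-style Registry lookup is an assumption, not confirmed by this PR.
from fast_llm.models.auto import model_registry, trainer_registry

model_config_cls = model_registry["grpo"]      # expected: GRPOModelConfig
trainer_config_cls = trainer_registry["grpo"]  # expected: GRPOTrainerConfig
print(model_config_cls.__name__, trainer_config_cls.__name__)
```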
Empty file.
72 changes: 72 additions & 0 deletions fast_llm/models/grpo/config.py
@@ -0,0 +1,72 @@
from fast_llm.config import Field, FieldUpdate, config_class
from fast_llm.data.config import DataConfig
from fast_llm.models.gpt.config import (
GPTArchitectureConfig,
GPTBaseModelConfig,
GPTModelConfig,
GPTTrainerConfig,
PretrainedGPTModelConfig,
)


@config_class()
class GRPOConfig:
epsilon: float = Field(default=0.2, desc="PPO clipping parameter")
kl_coef: float = Field(default=0.1, desc="KL divergence coefficient")
ratio_threshold: float = Field(default=1.5, desc="Early stopping ratio threshold")
use_advantages: bool = Field(default=True, desc="Use advantages instead of raw rewards")


@config_class()
class GRPODataConfig(DataConfig):
# TODO: If needed, inherit from AbstractDataConfig instead and re-implement everything.
pass


@config_class()
class GRPOArchitectureConfig(GPTArchitectureConfig):
# TODO: Add custom base model architecture config parameters, if any.
pass


@config_class()
class GRPOBaseModelConfig(GPTBaseModelConfig, GRPOArchitectureConfig):
# TODO: Add custom other base model config parameters, if any.
architecture_cls = GRPOArchitectureConfig
grpo: GRPOConfig = Field(default_factory=GRPOConfig, desc="GRPO specific configuration")


@config_class()
class GRPOModelConfig(GPTModelConfig):
# TODO: Add custom model config parameters, if any (typically none).
base_model: GRPOBaseModelConfig = FieldUpdate(default_factory=GRPOBaseModelConfig)

@classmethod
def get_model_class(cls):
from fast_llm.models.grpo.model import GRPOModel

return GRPOModel

@classmethod
def get_huggingface_model_class(cls):
from fast_llm.models.grpo.huggingface import HuggingfaceGRPOModelForCausalLM

return HuggingfaceGRPOModelForCausalLM


@config_class()
class PretrainedGRPOModelConfig(PretrainedGPTModelConfig):
model: GRPOModelConfig = FieldUpdate(default_factory=GRPOModelConfig)


@config_class()
class GRPOTrainerConfig(PretrainedGRPOModelConfig, GPTTrainerConfig):
# TODO: Add custom trainer config parameters, if any (typically none).

data: GRPODataConfig = FieldUpdate(default_factory=GRPODataConfig)

@classmethod
def get_trainer_class(cls):
from fast_llm.models.grpo.trainer import GRPOTrainer

return GRPOTrainer
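For reference, the GRPO-specific hyperparameters introduced above and their defaults. A minimal sketch, assuming keyword construction works as for other fast_llm config classes:

```python
# Illustrative only; keyword construction of the config class is assumed here.
from fast_llm.models.grpo.config import GRPOConfig

grpo = GRPOConfig()  # epsilon=0.2, kl_coef=0.1, ratio_threshold=1.5, use_advantages=True
tighter = GRPOConfig(epsilon=0.1, kl_coef=0.05)  # narrower clipping range, weaker KL penalty
print(tighter.epsilon, tighter.kl_coef)
```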
90 changes: 90 additions & 0 deletions fast_llm/models/grpo/data.py
@@ -0,0 +1,90 @@
import json
import torch
from fast_llm.data.data import Data, DatasetSource
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.models.grpo.config import GRPODataConfig
from fast_llm.data.dataset import BlendedDataset, SampledDataset
from fast_llm.utils import Assert


class GRPODataset(SampledDataset):
"""Dataset wrapper that adds GRPO-specific fields (rewards, advantages, etc)"""
def __init__(self, base_dataset: SampledDataset, data_path: str):
self.base_dataset = base_dataset
self.data_path = data_path

# Load the JSONL data
self.data = []
with open(data_path, 'r') as f:
for line in f:
self.data.append(json.loads(line))

def __len__(self):
return len(self.base_dataset)

def __getitem__(self, idx):
item = self.base_dataset[idx]
data_item = self.data[idx]

# Extract fields from the JSONL data
batch = {
"input_ids": item, # Original input tokens
"rewards": torch.tensor(data_item["reward"]),
"old_logprobs": torch.tensor(data_item["logprobs"]), # These are the logprobs from previous iteration
"ref_logprobs": torch.tensor(data_item["ref_logprobs"]),
}

# Compute advantages if not provided in data
# Here we're using rewards as advantages, but you might want to implement
# proper advantage estimation
batch["advantages"] = batch["rewards"].clone()

return batch


class GRPOData(Data):
def __init__(
self,
config: GRPODataConfig,
distributed_config: DistributedConfig,
vocab_size: int,
max_sequence_length: int,
):
super().__init__(config, distributed_config, vocab_size, max_sequence_length)

def setup(self, distributed, samples_per_phase):
# setup the base data infrastructure
super().setup(distributed, samples_per_phase)

# wrap each dataset with GRPO-specific functionality
        for phase in self._blended_datasets:
            if isinstance(self._blended_datasets[phase], BlendedDataset):
                # blended case: wrap each underlying dataset and write the wrapper back into the list
                for i, dataset in enumerate(self._blended_datasets[phase].datasets):
                    self._blended_datasets[phase].datasets[i] = GRPODataset(
                        dataset,
                        data_path=self._dataset_prefixes[f"dataset_{i}"],
                    )
            else:
                # single dataset case
                self._blended_datasets[phase] = GRPODataset(
                    self._blended_datasets[phase],
                    data_path=next(iter(self._dataset_prefixes.values())),
                )

def get_iterator(
self,
batch_config,
phase,
*,
consumed_samples,
num_workers,
prefetch_factor=None,
):
return super().get_iterator(
batch_config,
phase,
consumed_samples=consumed_samples,
num_workers=num_workers,
prefetch_factor=prefetch_factor,
)
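For reference, `GRPODataset` expects one JSON object per line, indexed in step with the base dataset, carrying `reward`, `logprobs`, and `ref_logprobs` fields. A standalone sketch of that layout and of the batch dict `__getitem__` builds from it (field names taken from the code above; values are made up):

```python
# Standalone sketch of the JSONL record format consumed by GRPODataset.
import json

import torch

record = {
    "reward": 0.75,                      # scalar reward for the whole sequence
    "logprobs": [-1.2, -0.8, -2.1],      # per-token logprobs from the previous iteration
    "ref_logprobs": [-1.1, -0.9, -2.0],  # per-token logprobs from the reference model
}
line = json.dumps(record)  # one such line per sample in the JSONL file

# __getitem__ above turns a parsed record into roughly this batch dict:
parsed = json.loads(line)
batch = {
    "rewards": torch.tensor(parsed["reward"]),
    "old_logprobs": torch.tensor(parsed["logprobs"]),
    "ref_logprobs": torch.tensor(parsed["ref_logprobs"]),
}
batch["advantages"] = batch["rewards"].clone()  # rewards reused as advantages for now
```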
77 changes: 77 additions & 0 deletions fast_llm/models/grpo/head.py
@@ -0,0 +1,77 @@
import torch
import torch.nn.functional as F
from fast_llm.layers.language_model.head import LanguageModelHead
from fast_llm.layers.language_model.config import LanguageModelLossNames
from fast_llm.models.grpo.config import GRPOConfig

class GRPOHead(LanguageModelHead):
def masked_mean(self, values: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
"""Calculate mean of values with masks applied"""
return (values * masks).sum() / (masks.sum() + 1e-8)

def compute_grpo_loss(
self,
logits: torch.Tensor,
labels: torch.Tensor,
rewards: torch.Tensor,
advantages: torch.Tensor,
ref_logprobs: torch.Tensor,
old_logprobs: torch.Tensor,
config: GRPOConfig,
) -> torch.Tensor:
masks = labels != -100
masks = masks[:, 1:]

new_log_probs = torch.gather(
F.log_softmax(logits[:, :-1, :], dim=-1),
dim=2,
index=labels[:, 1:].unsqueeze(2),
).squeeze(2)

# surrogate loss calculation
log_ratio_new_old = new_log_probs - old_logprobs
ratio_new_old = torch.exp(log_ratio_new_old)
weights = advantages if config.use_advantages else rewards

surr1 = ratio_new_old * weights
clamped_ratio = torch.clamp(
ratio_new_old,
1 - config.epsilon,
1 + config.epsilon
)
surr2 = clamped_ratio * weights
surrogate_loss = torch.min(surr1, surr2)

# KL divergence approximation
log_ratio_ref_new = ref_logprobs - new_log_probs
approx_kl = torch.exp(log_ratio_ref_new) - log_ratio_ref_new - 1

# Final loss computation
loss = -self.masked_mean(
surrogate_loss - config.kl_coef * approx_kl,
masks
)

# Early stopping based on ratio threshold
if self.masked_mean(ratio_new_old, masks) > config.ratio_threshold:
loss = loss * 0

return loss

def forward(self, input_: torch.Tensor, kwargs: dict):
# Regular language model forward pass
output = super().forward(input_, kwargs)

# If we have GRPO inputs, compute GRPO loss
if all(k in kwargs for k in ["rewards", "advantages", "ref_logprobs", "old_logprobs"]):
grpo_loss = self.compute_grpo_loss(
logits=kwargs["logits"],
labels=kwargs["labels"],
rewards=kwargs["rewards"],
advantages=kwargs["advantages"],
ref_logprobs=kwargs["ref_logprobs"],
old_logprobs=kwargs["old_logprobs"],
config=kwargs["grpo_config"],
)
kwargs[LanguageModelLossNames.grpo_loss] = grpo_loss

return output
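A quick standalone restatement of the loss above (not the PR's code path): the clipped PPO-style surrogate combined with the `exp(r) - r - 1` KL approximation, averaged over unmasked tokens. The ratio-threshold early stopping is omitted here; tensor values are made up.

```python
# Standalone sketch of the GRPO loss computed by compute_grpo_loss above.
import torch


def grpo_loss(new_logprobs, old_logprobs, ref_logprobs, weights, masks, epsilon=0.2, kl_coef=0.1):
    ratio = torch.exp(new_logprobs - old_logprobs)
    surrogate = torch.min(ratio * weights, torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * weights)
    log_ratio_ref_new = ref_logprobs - new_logprobs
    approx_kl = torch.exp(log_ratio_ref_new) - log_ratio_ref_new - 1  # same estimator as above
    masked = (surrogate - kl_coef * approx_kl) * masks
    return -masked.sum() / (masks.sum() + 1e-8)


# One sequence, three scored tokens, nothing masked out.
new_lp = torch.tensor([[-1.0, -0.9, -2.0]])
old_lp = torch.tensor([[-1.1, -1.0, -2.2]])
ref_lp = torch.tensor([[-1.0, -1.0, -2.1]])
weights = torch.full((1, 3), 0.5)  # per-token advantages (or rewards)
masks = torch.ones(1, 3)
print(grpo_loss(new_lp, old_lp, ref_lp, weights, masks))
```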
18 changes: 18 additions & 0 deletions fast_llm/models/grpo/huggingface.py
@@ -0,0 +1,18 @@
from fast_llm.models.grpo.config import GRPOModelConfig
from fast_llm.models.grpo.model import GRPOModel
from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelConfig, HuggingfaceGPTModelForCausalLM


class HuggingfaceGRPOModelConfig(HuggingfaceGPTModelConfig):
    model_type = "fast_llm_grpo"
    model_config_class = GRPOModelConfig
    fast_llm_config: GRPOModelConfig


class HuggingfaceGRPOModelForCausalLM(HuggingfaceGPTModelForCausalLM):
    # TODO: Implement changes in the huggingface interface, if any.
    # Ex.: Return predictions instead of logits.
    config_class = HuggingfaceGRPOModelConfig
    config: HuggingfaceGRPOModelConfig
    model_class = GRPOModel
    _fast_llm_model: GRPOModel
77 changes: 77 additions & 0 deletions fast_llm/models/grpo/model.py
@@ -0,0 +1,77 @@
from fast_llm.engine.distributed.config import DistributedConfig, PhaseType
from fast_llm.layers.language_model.embedding import LanguageModelEmbedding
from fast_llm.layers.transformer.transformer import TransformerLayer
from fast_llm.models.grpo.head import GRPOHead
from fast_llm.models.grpo.config import GRPOBaseModelConfig, GRPOModelConfig
from fast_llm.models.gpt.model import GPTBaseModel, GPTModel


class GRPOBaseModel(GPTBaseModel):
_config: GRPOBaseModelConfig
config_cls = GRPOBaseModelConfig

def __init__(
self,
        config: GRPOBaseModelConfig,
distributed_config: DistributedConfig,
):
super().__init__(config, distributed_config)
assert self._config.transformer.use_rotary_position_embeddings
assert not self._config.use_absolute_position_embeddings

def get_layers(self):
return [
LanguageModelEmbedding(self._config, self._tensor_space),
*[
TransformerLayer(
self._config.transformer,
self._tensor_space,
layer_index=i + 1,
)
for i in range(self._config.transformer.num_layers)
],
GRPOHead(self._config, self._tensor_space), # Use our custom head
]

def preprocess(
self,
batch: dict,
preprocessed_meta=None,
*,
phase: PhaseType,
iteration: int,
metrics=None
):
# Extract GRPO specific inputs
grpo_inputs = {
"rewards": batch.pop("rewards")[:, 1:],
"advantages": batch.pop("advantages")[:, 1:],
"ref_logprobs": batch.pop("ref_logprobs")[:, 1:],
"old_logprobs": batch.pop("old_logprobs")[:, 1:],
"grpo_config": self._config.grpo,
}

# Process the remaining inputs using parent class
preprocessed = super().preprocess(
batch["input_ids"],
preprocessed_meta,
phase=phase,
iteration=iteration,
metrics=metrics
)

# Add GRPO inputs to kwargs
for tokens, kwargs in preprocessed:
kwargs.update(grpo_inputs)

return preprocessed

@property
def loss_defs(self):
# TODO: Adjust or reimplement.
return super().loss_defs


class GRPOModel(GPTModel):
config_class = GRPOModelConfig
base_model_class = GRPOBaseModel
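For orientation, a shape sketch of the batch `GRPOBaseModel.preprocess` expects to pop before delegating to the GPT preprocessing. The `[:, 1:]` slicing implies per-token tensors here, even though `GRPODataset` currently stores `rewards` as a scalar per sample, so the shapes below are an assumption, not something fixed by this PR:

```python
# Hypothetical batch layout; shapes follow the [:, 1:] slicing in preprocess, not the dataset.
import torch

batch_size, sequence_length, vocab_size = 2, 16, 32000
batch = {
    "input_ids": torch.randint(0, vocab_size, (batch_size, sequence_length)),
    "rewards": torch.zeros(batch_size, sequence_length),
    "advantages": torch.zeros(batch_size, sequence_length),
    "ref_logprobs": torch.zeros(batch_size, sequence_length),
    "old_logprobs": torch.zeros(batch_size, sequence_length),
}
# preprocess pops everything except "input_ids", drops the first position,
# and attaches the result (plus the GRPO config) to each kwargs dict consumed by GRPOHead.
```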