Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add documentation #20

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,6 @@ data/
*.pkl
*.json
__pycache__/
datasets
checkpoints
.venv
43 changes: 43 additions & 0 deletions docs/data.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Data

This script defines a flexible dataset loading and processing framework designed for machine learning models, particularly those dealing with natural language processing (NLP) and potentially vision tasks. The framework is built to work with the JAX library for high-performance machine learning and supports parallel processing and distributed training. Here's an overview of the main components:

## DatasetFactory Class

A factory class for creating dataset instances based on configuration parameters. It supports loading datasets from Hugging Face's datasets library (huggingface type), as well as custom JSON-formatted datasets (json and json_vision types).
It provides a method to get the default configuration for a dataset, which can be customized with specific parameters.

## TextProcessor Class

Processes text data by encoding strings into token IDs using a provided tokenizer. It supports adding special tokens (like BOS and EOS) and can process multiple fields from the data, concatenating them with a specified separator.
The default configuration and the processing behavior can be customized.

## VisionTextProcessor Class

Designed for processing datasets that include both vision (image or video frames) and text data. It handles encoding of textual data and can integrate special tokens indicating the start and end of vision-related tokens.
Supports custom configurations for handling vision data, including specifying the number of tokens per frame and the maximum number of frames.

## HuggingfaceDataset Class

Loads and processes datasets from the Hugging Face datasets library. It can stream data, making it efficient for large datasets.
The data is processed in chunks, with each chunk transformed into model input and target arrays, along with a loss mask to indicate which tokens should contribute to the loss calculation.

## JsonDataset Class

Loads and processes datasets from newline-delimited JSON files, where each line contains a JSON object representing a data example.
Supports parallel processing to tokenize and encode the data efficiently across multiple CPU cores.
Data examples are batched and padded as necessary to create fixed-size arrays suitable for model training.

## JsonVisionDataset Class

Similar to JsonDataset but specifically designed for datasets that include both vision and text data.
It can handle special tokens for vision data and supports different modes for padding or not padding the batches.

## General Workflow

*Configuration*: The user specifies the dataset type and configuration parameters, including paths to data files, batch sizes, sequence lengths, and any special tokenization or processing requirements.
Dataset Loading: Based on the configuration, the appropriate dataset class is instantiated, which loads the data and prepares it for processing.
Data Processing: Text and/or vision data is tokenized and encoded according to the specified processing rules. The data is then batched, with options for padding batches to a fixed size.

*Iteration*: The dataset objects are iterable, providing batches of data ready for input into a machine learning model. Each batch includes input tokens, target tokens (for supervised learning), and a mask indicating which tokens should be considered for loss calculation.
This framework is highly modular and customizable, making it suitable for a wide range of machine learning tasks and models. It leverages JAX's capabilities for efficient computation and is designed with distributed and parallel processing in mind.
41 changes: 41 additions & 0 deletions docs/llama.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# LLama

This script is structured into multiple sections, each defining classes and functions related to the LLaMA model, its configuration, tokenization, and various utilities for handling model layers and attention mechanisms. Here's a detailed overview:

## LLaMAConfig Class

Defines the configuration for a LLaMA model, including parameters like vocabulary size, hidden layer size, and the number of attention heads. It supports loading configurations for different sizes of LLaMA models (e.g., 200m, 1b, etc.).

## FlaxLLaMAAttention Class

Implements the attention mechanism for LLaMA, including query, key, and value projections, as well as the attention calculation itself. It supports causal attention for autoregressive tasks and incorporates options for efficient attention mechanisms like Flash Attention.

## FlaxLLaMAMLP Class

Defines the feed-forward network (MLP) used within each Transformer block, including two linear layers and a GELU activation function.

## FlaxLLaMABlock Class

Represents a single Transformer block, combining the attention and MLP components, along with layer normalization.

## FlaxLLaMAPreTrainedModel and FlaxLLaMAModule Classes

Provide the base implementation for a LLaMA model in Flax, including methods for weight initialization and handling pretrained models.

## FlaxLLaMABlockCollection Class

Manages a collection of Transformer blocks, allowing for sequential processing of inputs through multiple blocks.

## FlaxLLaMAModel and FlaxLLaMAForCausalLM Classes

Define specific model variants, such as a basic LLaMA model and a causal language model variant for tasks like text generation.

## LLaMATokenizer Class

Implements tokenization for LLaMA using SentencePiece, including methods for encoding text into tokens and decoding tokens back into text.

## Utility Functions and Classes

Include various helper functions and classes such as RMSNorm for RMS normalization, apply_rotary_emb for applying rotary embeddings to queries and keys, and methods for managing model parameters and configurations.

Each class and function is designed to be modular and interoperable, allowing for flexible configuration and usage of the LLaMA model components. The use of Flax and JAX libraries facilitates efficient training and inference on hardware accelerators.
53 changes: 53 additions & 0 deletions docs/ring_attention.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Ring Attention

This module implements the forward and backward passes of the ring attention mechanism, which is designed for efficient computation on TPUs, especially when handling large sequences. It supports blockwise computation to reduce memory cost and incorporates fused attention for TPU compatibility. The module is structured to accommodate both standard and ring-flash attention mechanisms, with an emphasis on blockwise processing to optimize performance and memory usage.

## Ring Attention Forward Pass

`_ring_attention_fwd`

This function computes the forward pass of the ring attention mechanism, dividing the computation into blocks for efficiency. It uses a scan operation to iterate over key-value (KV) blocks, applying blockwise attention and rotating KV pairs across TPU cores to implement the ring structure.

## Ring Attention Backward Pass

`_ring_attention_bwd`

This function handles the backward pass, computing gradients with respect to the inputs. It mirrors the forward pass but in reverse, iterating over the blocks and applying the backward computations for blockwise attention.

## Standard Attention Forward Pass

`_ring_attention_standard_fwd`

A variant of the ring attention forward pass that does not use blockwise computation. It's more straightforward but less memory efficient compared to the blockwise version.

## Blockwise Attention Functions

`_blockwise_attention_fwd` and `_blockwise_attention_bwd`

These functions are core to the blockwise computation, handling the forward and backward computations within each block. They are designed to be efficient and compatible with TPU architecture.

## Ring Flash Attention TPU-Compatible Functions

`_ring_flash_attention_fwd_tpu` and `_ring_flash_attention_bwd_tpu`

These functions are specialized versions of the ring attention mechanism, optimized for TPU execution. They leverage TPU-specific operations and structures to achieve high performance.

## Utility Functions

The module includes several utility functions, such as `_chunk_attention_bias` for computing attention bias within chunks and `_flash_attention` for a fused attention mechanism that is efficient on TPUs.

## Data Structures

The module defines several data structures, like SegmentIds and BlockSizes, to organize and manage the dimensions and indices involved in blockwise and ring attention computations.

## Blockwise Computation

This approach divides the input into smaller blocks, allowing for more efficient processing by reducing memory requirements and leveraging parallelism.

## Ring Structure

In the context of TPU computation, the ring structure refers to a method where data (e.g., KV pairs) is passed in a ring-like fashion across TPU cores, enabling efficient parallel computation.
## Fused Attention

This technique combines multiple attention-related operations into a single, more efficient operation, particularly beneficial on TPUs where memory bandwidth can be a limiting factor.
This module is a comprehensive implementation of advanced attention mechanisms tailored for high-performance computing environments, particularly TPUs, with a focus on efficiency and scalability.
34 changes: 34 additions & 0 deletions docs/train.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Training
A training script for LWM model designed for use with the JAX library using a LLaMA (Large Language Model) and its variations, including ones for video and text processing. The script is structured to support distributed training across multiple devices or nodes, using functionalities from JAX for parallel execution and the flax library for model state management. Here's a high-level overview of how it works:

## Configuration and Initialization:

Default configurations and flags, include data modality (text or vision+text), dataset loading, model configuration, optimizer setup, logging, and checkpointing.
The main function initializes distributed training using `JaxDistributedConfig` and sets up logging with `tux.WandBLogger` using the integration with the Weights & Biases platform for experiment tracking.

## Model and Dataset Setup:

Depending on the specified modality (text or vision+text), you can select appropriate model configuration and class (`LLaMAConfig` and `FlaxLLaMAForCausalLMModule` for text, `VideoLLaMAConfig` and `FlaxVideoLLaMAForCausalLMModule` for vision+text).
The dataset is loaded using a `DatasetFactory`, which provides a way to load and preprocess data suitable for training the model. There's support for resuming training from a checkpoint or loading a specific dataset state.

## Model Initialization:

The model is initialized with the specified configuration, and the script prepares for distributed training by setting up a computational mesh using `pjit` (parallel JIT compilation in JAX). This involves defining how the model's parameters and operations should be partitioned across the available hardware.
Training Loop:

The main training loop iterates over the total number of training steps. For each step, it processes a batch of data, performs a forward pass and backward pass (computing gradients), and updates the model parameters using the defined optimizer.
The script supports different data modalities by branching the logic within the training step function (train_step), handling text and vision+text differently in terms of how the model is applied and how losses are computed.

## Evaluation and Logging:

Optionally, the script can perform evaluation steps at a specified frequency, computing metrics on a separate evaluation dataset.
Metrics from both training and evaluation are logged using the configured logger, allowing for monitoring of the training process through the Weights & Biases platform.

## Checkpointing:

The script includes functionality for saving model checkpoints at specified intervals, supporting both regular checkpoints and milestone checkpoints. This allows for resuming training from a specific point and provides a way to save model states for later use or analysis.

## Finalization:

After completing the training loop, a final checkpoint may be saved, capturing the final state of the model.
The script is designed with modularity and flexibility in mind, allowing for various configurations and supporting complex distributed training setups. It leverages advanced features of JAX and Flax for efficient, scalable training of potentially large models on specialized hardware.
43 changes: 43 additions & 0 deletions docs/vision_chat.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@

# Vision Chat

The implementation for sampling from a VideoLLaMA model using a VQGAN model for video processing and tokenization. Here's a high-level overview of the script's functionality and key components:

## Flags Definition

The script starts by defining various flags for configuring the sampling process, including the prompt, input file, VQGAN checkpoint, temperature for generation, maximum number of frames to consider, and various configurations related to the model and tokenization.

## Sampler Class

The core of the script is the Sampler class, which encapsulates the logic for sampling from the VideoLLaMA model. Key functionalities include:

- Loading and setting up the VQGAN model for video processing.
- Initializing tokenizers for processing text and vision inputs.
- Defining the forward generation function using pjit for parallel execution across the specified JAX mesh.
- Constructing model inputs from prompts and video data, processing the video frames, and tokenizing the inputs.
- Generating outputs from the VideoLLaMA model and decoding them back to text.

## Main Function
The main function orchestrates the sampling process by initializing the necessary configurations, creating a Sampler instance, and processing the provided prompts to generate responses.

## Video Processing

The script processes video inputs (handling both image files and video formats) by resizing and cropping frames to a consistent size, encoding them using the VQGAN model, and tokenizing the encoded frames for input to the VideoLLaMA model.

## Text Processing

Prompts and other textual inputs are tokenized using the specified tokenizer configurations. Special tokens are added as needed to mark the beginning and end of vision inputs and to structure the overall input sequence for the model.

## Model Sampling

The script uses pjit to define a parallelized forward generation function that leverages the JAX mesh for distributed computation. This function generates sequences from the VideoLLaMA model based on the constructed inputs.

## Output Decoding

Generated sequences are decoded back to text, with special handling to trim outputs at the end-of-sequence token and compile the final responses.

## Usage

The script is designed to be run with command-line arguments corresponding to the defined flags, allowing users to specify the prompt, input video or image file, and various model and sampling parameters.

It is a complex integration of multiple components (video processing, tokenization, model sampling) into a cohesive pipeline for generative tasks with video and text inputs.
28 changes: 28 additions & 0 deletions docs/vision_llama.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# The FlaxVideoLLaMAForCausalLM

A VideoLLaMA model architecture, specifically designed for causal language modeling tasks. This module is built to handle and generate sequences where each token is predicted based on the preceding tokens, making it suitable for tasks like text generation. Additionally, it extends these capabilities to multimodal inputs, allowing it to work with both text and visual data, which is particularly useful in scenarios where the model needs to understand and generate content based on a combination of textual and visual cues.

## Causal Language Modeling

It is tailored for generating sequences in a causal manner, meaning each token is predicted based on the previous tokens in the sequence. This is essential for generative tasks like story generation, where the narrative flows logically from one sentence to the next.

## Multimodal Input Handling
The module can process both text and visual inputs, making it versatile for a range of applications, from generating descriptive captions for images to creating content that seamlessly integrates textual and visual information.

## Configurable Generation

It offers a variety of settings for sequence generation, such as controlling the maximum length of the generated sequences, specifying the end-of-sequence token, and adjusting the randomness of generation through temperature and top-k sampling parameters.

## Efficient Generation with Caching

The module uses a caching mechanism to speed up the generation process, especially for autoregressive generation where each token's prediction can benefit from the computations done for the previous tokens.

## Flexible Output Formats

It can provide outputs in different formats, catering to various downstream needs. For example, it can return just the last hidden state, all hidden states, and attention scores depending on the configuration.

## Generation Strategies Support

The module supports different generation strategies, including greedy decoding and sampling with temperature, allowing users to balance between the diversity and accuracy of the generated sequences.

This module is a part of the broader VideoLLaMA framework with handling large-scale models and data. The FlaxVideoLLaMAForCausalLM is particularly noteworthy for its ability to bridge the gap between traditional NLP tasks and the emerging field of multimodal AI.
Loading