mugpt is a fork of seqax by MatX, modified to run on Google's TPU v4-32 slices, and uses the hyperparameter transfer described by Everett et al. Previously, I experimented with "SharedKV", a modified form of attention that lives in the sharedkv branch, as well as the muP hyperparameter transfer described by Yang et al. (sketched briefly after the links below). For more information on these investigations, take a look here:
- Benchmarking Learning Rate Transfer
- Minimizing HBM usage through SharedKV
- Exploring the best approaches for implementing muP
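To give a flavor of the muP idea referenced above: with Adam, hidden-layer learning rates scale roughly as 1/width, which is what lets hyperparameters tuned on a small proxy model transfer to a wider one. The sketch below is illustrative only; the function name and constants are made up here, and it is not the parameterization this repo (or Everett et al.) actually uses, since real muP treats embeddings, hidden matrices, and the readout differently.

```python
# Minimal sketch of muP-style learning-rate transfer (Yang et al.), NOT this
# repo's actual code: hidden-matrix LRs shrink as 1/width under Adam.
import optax

def mup_hidden_lr(base_lr: float, base_width: int, width: int) -> float:
    # A base_lr tuned on a narrow proxy model carries over to a wider
    # target model when scaled by base_width / width.
    return base_lr * base_width / width

# Tune once at width 256, then reuse the same base_lr at width 2048.
opt = optax.adam(mup_hidden_lr(base_lr=3e-3, base_width=256, width=2048))
```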
The installation procedure is identical to that described in seqax.
- Install `graphviz` from your system package manager: e.g. `brew install graphviz` or `apt install graphviz`.
- Install Python dependencies, typically inside a virtualenv: `python -m pip install -r requirements-cpu.txt`.

NOTE: `requirements-cpu.txt` is configured for a CPU-based installation. For GPU or TPU installation, you may need a different install of JAX and jaxlib; consult the JAX install documentation. If your GPU environment has a Torch-GPU installation, you may need to switch it to a Torch-CPU installation to avoid conflicts with JAX-GPU.
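For reference, at the time of writing the upstream JAX documentation installs TPU support with a command along the lines of `pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html`; that comes from the JAX docs rather than this repo, so check there for the current command before running it.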
For development and testing you can run on CPU. Typically you'd use our synthetic dataset (which is checked into this repository) or the Huggingface data loader and you'd set XLA flags to simulate multiple devices so as to test that parallelism is working as intended:
`XLA_FLAGS=--xla_force_host_platform_device_count=8 python -m train --config-name=local_test_synthetic +paths.model_name=synthetic_000`
The `paths.model_name` flag specifies which subdirectory on disk (inside `/tmp`) to write model checkpoints to. You'll typically want to change this when starting a new model run.
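To confirm that the simulated devices from the XLA flag above are actually visible, a quick standalone check like the following works (this is a generic JAX snippet, not part of the repo); note the flag must be in the environment before `jax` is first imported:

```python
import os
# Must be set before jax is imported, or the flag is ignored.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
print(jax.device_count())  # -> 8
print(jax.devices())       # eight CpuDevice entries
```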
Thanks to the MatX team for their implementation of GPT in seqax, which I used to implement muP and SharedKV attention.
Thanks to the Google TPU Research Cloud, which has supported my investigations.