feat: first working version of shitty language model from scratch
eschmidt42 committed Mar 4, 2024
1 parent d470a8e commit aa17926
Showing 6 changed files with 1,455 additions and 21 deletions.
396 changes: 396 additions & 0 deletions nbs/attention.ipynb
@@ -0,0 +1,396 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"References:\n",
"* fastai [nb 27](https://github.com/fastai/course22p2/blob/master/nbs/27_attention.ipynb)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Cross attention:\n",
"- https://sebastianraschka.com/blog/2023/self-attention-from-scratch.html\n",
"- https://arxiv.org/abs/2112.10752 -> inserting class / text embedding into K and V of attention"
]
},
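{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of that idea, assuming hypothetical dimensions and using `torch.nn.MultiheadAttention` (not the `rnnm_trans` API): queries come from the processed input `x`, keys and values from a conditioning embedding `c`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"\n",
"class CrossAttention(nn.Module):\n",
"    # queries from x, keys/values from a conditioning embedding c\n",
"    def __init__(self, dim: int, cond_dim: int, num_heads: int):\n",
"        super().__init__()\n",
"        self.attn = nn.MultiheadAttention(\n",
"            dim, num_heads, kdim=cond_dim, vdim=cond_dim, batch_first=True\n",
"        )\n",
"\n",
"    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:\n",
"        # x: (batch, n_tokens, dim), c: (batch, n_cond_tokens, cond_dim)\n",
"        out, _ = self.attn(query=x, key=c, value=c)\n",
"        return out  # same shape as x"
]
},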
{
"cell_type": "markdown",
"metadata": {},
"source": [
"ResBlock probably needs a cross-attention call that takes the label (K,V) and the processed image (Q) and returns something of the shape of the processed image"
]
},
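{
"cell_type": "markdown",
"metadata": {},
"source": [
"A hypothetical shape-preserving block along those lines, reusing the `CrossAttention` sketch above (placeholder names, not the actual implementation):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class CrossAttentionResBlock(nn.Module):\n",
"    # residual block: the processed image (Q) attends to the label embedding (K, V)\n",
"    def __init__(self, dim: int, cond_dim: int, num_heads: int):\n",
"        super().__init__()\n",
"        self.norm = nn.LayerNorm(dim)\n",
"        self.cross_attn = CrossAttention(dim, cond_dim, num_heads)\n",
"\n",
"    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:\n",
"        # the residual sum keeps the output shaped like the processed image\n",
"        return x + self.cross_attn(self.norm(x), c)"
]
},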
{
"cell_type": "markdown",
"metadata": {},
"source": [
"paper: U-Net: Convolutional Networks for Biomedical Image Segmentation\n",
"\n",
"unet data: https://forum.image.sc/t/isbi-2012-site-down/57867"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"minbpe\n",
"* https://www.youtube.com/watch?v=zduSFxRajkE"
]
},
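{
"cell_type": "markdown",
"metadata": {},
"source": [
"The core BPE loop in a toy sketch (independent of the `rnnm_tok` implementation used below): count adjacent id pairs, merge the most frequent pair into a new id, repeat."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from collections import Counter\n",
"\n",
"\n",
"def most_frequent_pair(ids: list[int]) -> tuple[int, int]:\n",
"    # count adjacent id pairs and return the most common one\n",
"    return Counter(zip(ids, ids[1:])).most_common(1)[0][0]\n",
"\n",
"\n",
"def merge(ids: list[int], pair: tuple[int, int], new_id: int) -> list[int]:\n",
"    # replace every occurrence of `pair` with `new_id`\n",
"    out, i = [], 0\n",
"    while i < len(ids):\n",
"        if i + 1 < len(ids) and (ids[i], ids[i + 1]) == pair:\n",
"            out.append(new_id)\n",
"            i += 2\n",
"        else:\n",
"            out.append(ids[i])\n",
"            i += 1\n",
"    return out\n",
"\n",
"\n",
"toy_ids = list(\"aaabdaaabac\".encode(\"utf-8\"))\n",
"pair = most_frequent_pair(toy_ids)  # the byte pair ('a', 'a')\n",
"merge(toy_ids, pair, 256)  # 256 = first id beyond the raw byte values"
]
},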
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"from torch.utils.data import DataLoader\n",
"\n",
"import random_neural_net_models.learner as rnnm_learner\n",
"import random_neural_net_models.text as rnnm_text\n",
"import random_neural_net_models.tokenization as rnnm_tok\n",
"import random_neural_net_models.transformer as rnnm_trans\n",
"import random_neural_net_models.utils as utils"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"utils.make_deterministic(42)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"device = utils.get_device()\n",
"device"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"tom lehrer's songs: https://tomlehrersongs.com/"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"path = Path(\"../data/tom-lehrer\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"files = rnnm_text.find_files(path, \"*.txt\")\n",
"files"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"body_for_tokenizer = rnnm_text.concat_files(files, \"\\n\")\n",
"body_for_tokenizer"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"vocab_size = 200\n",
"tokenizer = rnnm_tok.TokenizerRegex()\n",
"tokenizer.fit(\n",
" body_for_tokenizer,\n",
" vocab_size=vocab_size,\n",
" pattern=rnnm_tok.GPT4_SPLIT_PATTERN,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"special_token2id_map = {\n",
" \"<|endoftext|>\": 100257,\n",
" \"<|fim_prefix|>\": 100258,\n",
" \"<|fim_middle|>\": 100259,\n",
" \"<|fim_suffix|>\": 100260,\n",
" \"<|endofprompt|>\": 100276,\n",
"}\n",
"tokenizer.register_special_tokens(special_token2id_map)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"block_size = 128\n",
"ds_train = rnnm_text.TextDataset(\n",
" path=path,\n",
" suffix=\"*.txt\",\n",
" tokenizer=tokenizer,\n",
" block_size=block_size,\n",
" end_of_text_token=\"<|endoftext|>\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ds_train[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# from torch.utils.data import Dataset, RandomSampler\n",
"# RandomSampler(\n",
"# self.train_dataset,\n",
"# replacement=True,\n",
"# num_samples=int(1e10),\n",
"# generator=torch.manual_seed(3407),\n",
"# )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"bs_train = 10\n",
"dl_train = DataLoader(\n",
" ds_train,\n",
" batch_size=bs_train,\n",
" collate_fn=rnnm_text.collate_text_dataset_to_block,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"next(iter(dl_train))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"num_blocks = 2\n",
"emb_dim = 10\n",
"n_tokens = block_size\n",
"latent_dim = 40\n",
"num_heads = 4\n",
"\n",
"model = rnnm_trans.LanguageModelWithTensordict(\n",
" vocab_size=ds_train.vocab_size,\n",
" emb_dim=emb_dim,\n",
" n_tokens=n_tokens,\n",
" latent_dim=latent_dim,\n",
" num_heads=num_heads,\n",
" num_blocks=num_blocks,\n",
")\n",
"# model = rnnm_trans.EncoderWithTensordict(\n",
"# num_blocks=num_blocks,\n",
"# enc_emb_dim=enc_emb_dim,\n",
"# enc_n_tokens=enc_n_tokens,\n",
"# latent_dim=latent_dim,\n",
"# num_heads=num_heads,\n",
"# causal=True,\n",
"# vocab_size=len(tokenizer.vocab)\n",
"# )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learning_rate = 0.1\n",
"optimizer = optim.Adam(model.parameters(), lr=learning_rate)\n",
"loss = rnnm_trans.CrossEntropyLoss()\n",
"loss_callback = rnnm_learner.TrainLossCallback()\n",
"\n",
"save_dir = Path(\"./models\")\n",
"\n",
"callbacks = [loss_callback]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learner = rnnm_learner.Learner(\n",
" model,\n",
" optimizer,\n",
" loss,\n",
" callbacks=callbacks,\n",
" save_dir=save_dir,\n",
" device=device,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"lr_find_callback = rnnm_learner.LRFinderCallback(1e-5, 100, 100)\n",
"\n",
"learner.find_learning_rate(\n",
" dl_train, n_epochs=2, lr_find_callback=lr_find_callback\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"lr_find_callback.plot(yscale=\"log\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learning_rate = 3e-2\n",
"n_epochs = 5\n",
"\n",
"scheduler = optim.lr_scheduler.OneCycleLR(\n",
" optimizer=optimizer,\n",
" max_lr=learning_rate,\n",
" epochs=n_epochs,\n",
" steps_per_epoch=len(dl_train),\n",
")\n",
"scheduler_callback = rnnm_learner.EveryBatchSchedulerCallback(scheduler)\n",
"learner.update_callback(scheduler_callback)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learner.fit(dl_train, n_epochs=n_epochs) # , dataloader_valid=dl_val"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"loss_callback.plot()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"inp = next(iter(dl_train))\n",
"inp"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"out_ids_dense = model.generate(inp.to(device), max_new_tokens=20)\n",
"out_ids_dense"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ds_train.dense_ids_to_strings(out_ids_dense.cpu())"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}