-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: first working version of shitty language model from scratch
- Loading branch information
1 parent
d470a8e
commit aa17926
Showing
6 changed files
with
1,455 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,396 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"References:\n", | ||
"* fastai [nb 27](https://github.com/fastai/course22p2/blob/master/nbs/27_attention.ipynb)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Cross attention:\n", | ||
"- https://sebastianraschka.com/blog/2023/self-attention-from-scratch.html\n", | ||
"- https://arxiv.org/abs/2112.10752 -> inserting class / text embedding into K and V of attention" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"ResBlock probably needs a cross-attention call that takes the label (K,V) and the processed image (Q) and returns something of the shape of the processed image" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"paper: U-Net: Convolutional Networks for Biomedical Image Segmentation\n", | ||
"\n", | ||
"unet data: https://forum.image.sc/t/isbi-2012-site-down/57867" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"minbpe\n", | ||
"* https://www.youtube.com/watch?v=zduSFxRajkE" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"%load_ext autoreload\n", | ||
"%autoreload 2" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from pathlib import Path\n", | ||
"\n", | ||
"import torch\n", | ||
"import torch.nn as nn\n", | ||
"import torch.optim as optim\n", | ||
"from torch.utils.data import DataLoader\n", | ||
"\n", | ||
"import random_neural_net_models.learner as rnnm_learner\n", | ||
"import random_neural_net_models.text as rnnm_text\n", | ||
"import random_neural_net_models.tokenization as rnnm_tok\n", | ||
"import random_neural_net_models.transformer as rnnm_trans\n", | ||
"import random_neural_net_models.utils as utils" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"utils.make_deterministic(42)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"device = utils.get_device()\n", | ||
"device" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"tom lehrer's songs: https://tomlehrersongs.com/" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"path = Path(\"../data/tom-lehrer\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"files = rnnm_text.find_files(path, \"*.txt\")\n", | ||
"files" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"body_for_tokenizer = rnnm_text.concat_files(files, \"\\n\")\n", | ||
"body_for_tokenizer" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"vocab_size = 200\n", | ||
"tokenizer = rnnm_tok.TokenizerRegex()\n", | ||
"tokenizer.fit(\n", | ||
" body_for_tokenizer,\n", | ||
" vocab_size=vocab_size,\n", | ||
" pattern=rnnm_tok.GPT4_SPLIT_PATTERN,\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"special_token2id_map = {\n", | ||
" \"<|endoftext|>\": 100257,\n", | ||
" \"<|fim_prefix|>\": 100258,\n", | ||
" \"<|fim_middle|>\": 100259,\n", | ||
" \"<|fim_suffix|>\": 100260,\n", | ||
" \"<|endofprompt|>\": 100276,\n", | ||
"}\n", | ||
"tokenizer.register_special_tokens(special_token2id_map)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"block_size = 128\n", | ||
"ds_train = rnnm_text.TextDataset(\n", | ||
" path=path,\n", | ||
" suffix=\"*.txt\",\n", | ||
" tokenizer=tokenizer,\n", | ||
" block_size=block_size,\n", | ||
" end_of_text_token=\"<|endoftext|>\",\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"ds_train[0]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# from torch.utils.data import Dataset, RandomSampler\n", | ||
"# RandomSampler(\n", | ||
"# self.train_dataset,\n", | ||
"# replacement=True,\n", | ||
"# num_samples=int(1e10),\n", | ||
"# generator=torch.manual_seed(3407),\n", | ||
"# )" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"bs_train = 10\n", | ||
"dl_train = DataLoader(\n", | ||
" ds_train,\n", | ||
" batch_size=bs_train,\n", | ||
" collate_fn=rnnm_text.collate_text_dataset_to_block,\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"next(iter(dl_train))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"num_blocks = 2\n", | ||
"emb_dim = 10\n", | ||
"n_tokens = block_size\n", | ||
"latent_dim = 40\n", | ||
"num_heads = 4\n", | ||
"\n", | ||
"model = rnnm_trans.LanguageModelWithTensordict(\n", | ||
" vocab_size=ds_train.vocab_size,\n", | ||
" emb_dim=emb_dim,\n", | ||
" n_tokens=n_tokens,\n", | ||
" latent_dim=latent_dim,\n", | ||
" num_heads=num_heads,\n", | ||
" num_blocks=num_blocks,\n", | ||
")\n", | ||
"# model = rnnm_trans.EncoderWithTensordict(\n", | ||
"# num_blocks=num_blocks,\n", | ||
"# enc_emb_dim=enc_emb_dim,\n", | ||
"# enc_n_tokens=enc_n_tokens,\n", | ||
"# latent_dim=latent_dim,\n", | ||
"# num_heads=num_heads,\n", | ||
"# causal=True,\n", | ||
"# vocab_size=len(tokenizer.vocab)\n", | ||
"# )" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"learning_rate = 0.1\n", | ||
"optimizer = optim.Adam(model.parameters(), lr=learning_rate)\n", | ||
"loss = rnnm_trans.CrossEntropyLoss()\n", | ||
"loss_callback = rnnm_learner.TrainLossCallback()\n", | ||
"\n", | ||
"save_dir = Path(\"./models\")\n", | ||
"\n", | ||
"callbacks = [loss_callback]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"learner = rnnm_learner.Learner(\n", | ||
" model,\n", | ||
" optimizer,\n", | ||
" loss,\n", | ||
" callbacks=callbacks,\n", | ||
" save_dir=save_dir,\n", | ||
" device=device,\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"lr_find_callback = rnnm_learner.LRFinderCallback(1e-5, 100, 100)\n", | ||
"\n", | ||
"learner.find_learning_rate(\n", | ||
" dl_train, n_epochs=2, lr_find_callback=lr_find_callback\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"lr_find_callback.plot(yscale=\"log\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"learning_rate = 3e-2\n", | ||
"n_epochs = 5\n", | ||
"\n", | ||
"scheduler = optim.lr_scheduler.OneCycleLR(\n", | ||
" optimizer=optimizer,\n", | ||
" max_lr=learning_rate,\n", | ||
" epochs=n_epochs,\n", | ||
" steps_per_epoch=len(dl_train),\n", | ||
")\n", | ||
"scheduler_callback = rnnm_learner.EveryBatchSchedulerCallback(scheduler)\n", | ||
"learner.update_callback(scheduler_callback)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"learner.fit(dl_train, n_epochs=n_epochs) # , dataloader_valid=dl_val" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"loss_callback.plot()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"inp = next(iter(dl_train))\n", | ||
"inp" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"out_ids_dense = model.generate(inp.to(device), max_new_tokens=20)\n", | ||
"out_ids_dense" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"ds_train.dense_ids_to_strings(out_ids_dense.cpu())" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": ".venv", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.3" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Oops, something went wrong.