This repository has been archived by the owner on Nov 25, 2024. It is now read-only.
forked from pytorch/torchtune
-
Notifications
You must be signed in to change notification settings - Fork 0
/
_model_builders.py
181 lines (171 loc) · 6.86 KB
/
_model_builders.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import List
from functools import partial
from torchtune.models.llama3_2._component_builders import llama3_2, lora_llama3_2
from torchtune.modules import TransformerDecoder
from torchtune.modules.peft import LORA_ATTN_MODULES
"""
Model builders build specific instantiations using component builders. For example
the llama3_2_1b model builder uses the llama3_2 component builder to create the
Llama3.2 1B model.
"""
def llama3_2_1b() -> TransformerDecoder:
    """
    Build the Llama3.2 1B model from its default hyperparameters.

    The fixed 1B settings are gathered into a single mapping and forwarded,
    unchanged, as keyword arguments to the ``llama3_2`` component builder.

    Returns:
        TransformerDecoder: Instantiation of Llama3.2 1B model
    """
    default_1b_config = dict(
        vocab_size=128_256,
        num_layers=16,
        num_heads=32,
        num_kv_heads=8,
        embed_dim=2048,
        max_seq_len=131072,
        intermediate_dim=8192,
        attn_dropout=0.0,
        norm_eps=1e-5,
        rope_base=500_000,
        scale_factor=32,
    )
    return llama3_2(**default_1b_config)
def llama3_2_3b() -> TransformerDecoder:
    """
    Build the Llama3.2 3B model from its default hyperparameters.

    The fixed 3B settings are gathered into a single mapping and forwarded,
    unchanged, as keyword arguments to the ``llama3_2`` component builder.

    Returns:
        TransformerDecoder: Instantiation of Llama3.2 3B model
    """
    default_3b_config = dict(
        vocab_size=128_256,
        num_layers=28,
        num_heads=24,
        num_kv_heads=8,
        embed_dim=3072,
        max_seq_len=131072,
        intermediate_dim=8192,
        attn_dropout=0.0,
        norm_eps=1e-5,
        rope_base=500_000,
        scale_factor=32,
    )
    return llama3_2(**default_3b_config)
def lora_llama3_2_1b(
    lora_attn_modules: List[LORA_ATTN_MODULES],
    apply_lora_to_mlp: bool = False,
    apply_lora_to_output: bool = False,
    lora_rank: int = 8,
    lora_alpha: float = 16,
    lora_dropout: float = 0.0,
    use_dora: bool = False,
    quantize_base: bool = False,
) -> TransformerDecoder:
    """
    Build a Llama3.2 1B model with LoRA enabled.

    The base-model hyperparameters match those used by
    :func:`~torchtune.models.llama3_2.llama3_2_1b`, while the LoRA defaults are based on
    https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43.

    Args:
        lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers
            LoRA should be applied to in each self-attention block. Options are
            ``{"q_proj", "k_proj", "v_proj", "output_proj"}``.
        apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer.
            Default: False
        apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection.
            Default: False
        lora_rank (int): rank of each low-rank approximation. Default: 8
        lora_alpha (float): scaling factor for the low-rank approximation. Default: 16
        lora_dropout (float): dropout probability for the low-rank approximation. Default: 0.0
        use_dora (bool): Decompose the LoRA weight into magnitude and direction, as
            introduced in "DoRA: Weight-Decomposed Low-Rank Adaptation" (https://arxiv.org/abs/2402.09353).
            Default: False
        quantize_base (bool): Whether to quantize base model weights. Default: False

    Returns:
        TransformerDecoder: Instantiation of Llama3.2 1B model with LoRA applied
    """
    # Which parts of the network receive LoRA adapters (caller-controlled).
    lora_targets = dict(
        lora_attn_modules=lora_attn_modules,
        apply_lora_to_mlp=apply_lora_to_mlp,
        apply_lora_to_output=apply_lora_to_output,
    )
    # Fixed 1B base-model hyperparameters.
    base_1b_config = dict(
        vocab_size=128_256,
        num_layers=16,
        num_heads=32,
        num_kv_heads=8,
        embed_dim=2048,
        max_seq_len=131072,
        intermediate_dim=8192,
        attn_dropout=0.0,
        norm_eps=1e-5,
        rope_base=500_000,
        scale_factor=32,
    )
    # Adapter hyperparameters, passed through exactly as received.
    adapter_settings = dict(
        lora_rank=lora_rank,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        use_dora=use_dora,
        quantize_base=quantize_base,
    )
    return lora_llama3_2(**lora_targets, **base_1b_config, **adapter_settings)
def lora_llama3_2_3b(
    lora_attn_modules: List[LORA_ATTN_MODULES],
    apply_lora_to_mlp: bool = False,
    apply_lora_to_output: bool = False,
    lora_rank: int = 8,
    lora_alpha: float = 16,
    lora_dropout: float = 0.0,
    use_dora: bool = False,
    quantize_base: bool = False,
) -> TransformerDecoder:
    """
    Build a Llama3.2 3B model with LoRA enabled.

    The base-model hyperparameters match those used by
    :func:`~torchtune.models.llama3_2.llama3_2_3b`, while the LoRA defaults are based on
    https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43.

    Args:
        lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers
            LoRA should be applied to in each self-attention block. Options are
            ``{"q_proj", "k_proj", "v_proj", "output_proj"}``.
        apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer.
            Default: False
        apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection.
            Default: False
        lora_rank (int): rank of each low-rank approximation. Default: 8
        lora_alpha (float): scaling factor for the low-rank approximation. Default: 16
        lora_dropout (float): dropout probability for the low-rank approximation. Default: 0.0
        use_dora (bool): Decompose the LoRA weight into magnitude and direction, as
            introduced in "DoRA: Weight-Decomposed Low-Rank Adaptation" (https://arxiv.org/abs/2402.09353).
            Default: False
        quantize_base (bool): Whether to quantize base model weights. Default: False

    Returns:
        TransformerDecoder: Instantiation of Llama3.2 3B model with LoRA applied
    """
    # Which parts of the network receive LoRA adapters (caller-controlled).
    lora_targets = dict(
        lora_attn_modules=lora_attn_modules,
        apply_lora_to_mlp=apply_lora_to_mlp,
        apply_lora_to_output=apply_lora_to_output,
    )
    # Fixed 3B base-model hyperparameters.
    base_3b_config = dict(
        vocab_size=128_256,
        num_layers=28,
        num_heads=24,
        num_kv_heads=8,
        embed_dim=3072,
        max_seq_len=131072,
        intermediate_dim=8192,
        attn_dropout=0.0,
        norm_eps=1e-5,
        rope_base=500_000,
        scale_factor=32,
    )
    # Adapter hyperparameters, passed through exactly as received.
    adapter_settings = dict(
        lora_rank=lora_rank,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        use_dora=use_dora,
        quantize_base=quantize_base,
    )
    return lora_llama3_2(**lora_targets, **base_3b_config, **adapter_settings)
# QLoRA flavour of the 1B LoRA builder: the same callable as ``lora_llama3_2_1b``
# with ``quantize_base=True`` pre-bound through ``functools.partial``.
qlora_llama3_2_1b = partial(lora_llama3_2_1b, quantize_base=True)
# A partial object does not inherit the wrapped function's docstring, so the
# user-facing description is attached explicitly here.
qlora_llama3_2_1b.__doc__ = """
Builder for creating a Llama3.2 1B model with QLoRA enabled. Base model weights in linear layers
that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314.
Please see `lora_llama3_2_1b` for full API arguments.
"""
# QLoRA flavour of the 3B LoRA builder: the same callable as ``lora_llama3_2_3b``
# with ``quantize_base=True`` pre-bound through ``functools.partial``.
qlora_llama3_2_3b = partial(lora_llama3_2_3b, quantize_base=True)
# A partial object does not inherit the wrapped function's docstring, so the
# user-facing description is attached explicitly here.
qlora_llama3_2_3b.__doc__ = """
Builder for creating a Llama3.2 3B model with QLoRA enabled. Base model weights in linear layers
that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314.
Please see `lora_llama3_2_3b` for full API arguments.
"""