Add LoRA support #366

Merged · 7 commits · Mar 17, 2023
11 changes: 10 additions & 1 deletion css/main.css
@@ -1,12 +1,15 @@
.tabs.svelte-710i53 {
margin-top: 0
}

.py-6 {
padding-top: 2.5rem
}

.dark #refresh-button {
background-color: #ffffff1f;
}

#refresh-button {
flex: none;
margin: 0;
@@ -17,22 +20,28 @@
border-radius: 10px;
background-color: #0000000d;
}

#download-label, #upload-label {
min-height: 0
}

#accordion {
}

.dark svg {
fill: white;
}

svg {
display: unset !important;
vertical-align: middle !important;
margin: 5px;
}

ol li p, ul li p {
display: inline-block;
}
#main, #parameters, #chat-settings, #interface-mode {

#main, #parameters, #chat-settings, #interface-mode, #lora {
border: 0;
}
17 changes: 11 additions & 6 deletions download-model.py
@@ -101,6 +101,7 @@ def get_download_links_from_huggingface(model, branch):
classifications = []
has_pytorch = False
has_safetensors = False
is_lora = False
while True:
content = requests.get(f"{base}{page}{cursor.decode()}").content

@@ -110,8 +111,10 @@ def get_download_links_from_huggingface(model, branch):

for i in range(len(dict)):
fname = dict[i]['path']
if not is_lora and fname.endswith(('adapter_config.json', 'adapter_model.bin')):
is_lora = True

is_pytorch = re.match("pytorch_model.*\.bin", fname)
is_pytorch = re.match("(pytorch|adapter)_model.*\.bin", fname)
is_safetensors = re.match("model.*\.safetensors", fname)
is_tokenizer = re.match("tokenizer.*\.model", fname)
is_text = re.match(".*\.(txt|json)", fname) or is_tokenizer
@@ -130,6 +133,7 @@ def get_download_links_from_huggingface(model, branch):
has_pytorch = True
classifications.append('pytorch')


cursor = base64.b64encode(f'{{"file_name":"{dict[-1]["path"]}"}}'.encode()) + b':50'
cursor = base64.b64encode(cursor)
cursor = cursor.replace(b'=', b'%3D')
@@ -140,7 +144,7 @@ def get_download_links_from_huggingface(model, branch):
if classifications[i] == 'pytorch':
links.pop(i)

return links
return links, is_lora

if __name__ == '__main__':
model = args.MODEL
@@ -159,15 +163,16 @@ def get_download_links_from_huggingface(model, branch):
except ValueError as err_branch:
print(f"Error: {err_branch}")
sys.exit()

links, is_lora = get_download_links_from_huggingface(model, branch)
base_folder = 'models' if not is_lora else 'loras'
if branch != 'main':
output_folder = Path("models") / (model.split('/')[-1] + f'_{branch}')
output_folder = Path(base_folder) / (model.split('/')[-1] + f'_{branch}')
else:
output_folder = Path("models") / model.split('/')[-1]
output_folder = Path(base_folder) / model.split('/')[-1]
if not output_folder.exists():
output_folder.mkdir()

links = get_download_links_from_huggingface(model, branch)

# Downloading the files
print(f"Downloading the model to {output_folder}")
pool = multiprocessing.Pool(processes=args.threads)
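For reference, the detection added above keys off the PEFT adapter filenames and redirects the output folder accordingly. A minimal standalone sketch of the same checks (the folder names come from the diff; the repository owner in the example is only an illustration):

```python
from pathlib import Path

# Filenames that mark a repository as a LoRA/PEFT adapter (as checked in the diff above).
LORA_MARKERS = ('adapter_config.json', 'adapter_model.bin')

def pick_output_folder(model, branch, file_names):
    """Mirror of the new logic: LoRA repos go to loras/, full models to models/."""
    is_lora = any(name.endswith(LORA_MARKERS) for name in file_names)
    base_folder = 'loras' if is_lora else 'models'
    suffix = f'_{branch}' if branch != 'main' else ''
    return Path(base_folder) / (model.split('/')[-1] + suffix)

# A repo shipping only adapter files lands under loras/<name>:
print(pick_output_folder('someuser/alpaca-lora-7b', 'main',
                         ['adapter_config.json', 'adapter_model.bin']))
```

Invoked as usual (`python download-model.py organization/model`), the script would now place such adapter repositories under loras/ instead of models/.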
Empty file added loras/place-your-loras-here.txt
Empty file.
17 changes: 17 additions & 0 deletions modules/LoRA.py
@@ -0,0 +1,17 @@
from pathlib import Path

from peft import PeftModel

import modules.shared as shared
from modules.models import load_model


def add_lora_to_model(lora_name):

# Is there a more efficient way of returning to the base model?
if lora_name == "None":
print("Reloading the model to remove the LoRA...")
shared.model, shared.tokenizer = load_model(shared.model_name)
else:
print(f"Adding the LoRA {lora_name} to the model...")
shared.model = PeftModel.from_pretrained(shared.model, Path(f"loras/{lora_name}"))
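As a rough usage sketch (the base model name is a placeholder; `alpaca-lora-7b` is the example LoRA referenced elsewhere in this PR), the helper is meant to be called after a model has already been loaded:

```python
import modules.shared as shared
from modules.LoRA import add_lora_to_model
from modules.models import load_model

# Load the base model first, then attach an adapter from the loras/ folder.
shared.model_name = 'llama-7b'  # placeholder; any folder under models/
shared.model, shared.tokenizer = load_model(shared.model_name)

add_lora_to_model('alpaca-lora-7b')  # expects loras/alpaca-lora-7b/ to exist

# Passing "None" reloads the base model, which drops the adapter again.
add_lora_to_model('None')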
1 change: 1 addition & 0 deletions modules/callbacks.py
@@ -7,6 +7,7 @@

import modules.shared as shared


# Copied from https://github.com/PygmalionAI/gradio-ui/
class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):

3 changes: 2 additions & 1 deletion modules/chat.py
@@ -12,7 +12,8 @@
import modules.shared as shared
from modules.extensions import apply_extensions
from modules.html_generator import generate_chat_html
from modules.text_generation import encode, generate_reply, get_max_prompt_length
from modules.text_generation import (encode, generate_reply,
get_max_prompt_length)


# This gets the new line characters right.
8 changes: 7 additions & 1 deletion modules/shared.py
@@ -2,7 +2,8 @@

model = None
tokenizer = None
model_name = ""
model_name = "None"
lora_name = "None"
soft_prompt_tensor = None
soft_prompt = False
is_RWKV = False
@@ -52,6 +53,10 @@
'^(gpt4chan|gpt-4chan|4chan)': '-----\n--- 865467536\nInput text\n--- 865467537\n',
'(rosey|chip|joi)_.*_instruct.*': 'User: \n',
'oasst-*': '<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>'
},
'lora_prompts': {
'default': 'Common sense questions and answers\n\nQuestion: \nFactual answer:',
'alpaca-lora-7b': "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nWrite a poem about the transformers Python library. \nMention the word \"large language models\" in that poem.\n### Response:\n"
}
}

@@ -67,6 +72,7 @@ def str2bool(v):

parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54))
parser.add_argument('--model', type=str, help='Name of the model to load by default.')
parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.')
parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.')
parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode.')
parser.add_argument('--cai-chat', action='store_true', help='Launch the web UI in chat mode with a style similar to Character.AI\'s. If the file img_bot.png or img_bot.jpg exists in the same folder as server.py, this image will be used as the bot\'s profile picture. Similarly, img_me.png or img_me.jpg will be used as your profile picture.')
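With the new `--lora` flag, an adapter can also be applied at startup, e.g. `python server.py --model llama-7b --lora alpaca-lora-7b` (both folder names are placeholders and must exist under models/ and loras/ respectively).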
1 change: 1 addition & 0 deletions requirements.txt
@@ -4,6 +4,7 @@ flexgen==0.1.7
gradio==3.18.0
markdown
numpy
peft==0.2.0
requests
rwkv==0.4.2
safetensors==0.3.0
29 changes: 28 additions & 1 deletion server.py
@@ -15,6 +15,7 @@
import modules.shared as shared
import modules.ui as ui
from modules.html_generator import generate_chat_html
from modules.LoRA import add_lora_to_model
from modules.models import load_model, load_soft_prompt
from modules.text_generation import generate_reply

@@ -48,6 +49,9 @@ def get_available_extensions():
def get_available_softprompts():
return ['None'] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('softprompts').glob('*.zip'))), key=str.lower)

def get_available_loras():
return ['None'] + sorted([item.name for item in list(Path('loras/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)

def load_model_wrapper(selected_model):
if selected_model != shared.model_name:
shared.model_name = selected_model
@@ -59,6 +63,17 @@ def load_model_wrapper(selected_model):

return selected_model

def load_lora_wrapper(selected_lora):
shared.lora_name = selected_lora
default_text = shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')]

if not shared.args.cpu:
gc.collect()
torch.cuda.empty_cache()
add_lora_to_model(selected_lora)

return selected_lora, default_text

def load_preset_values(preset_menu, return_dict=False):
generate_params = {
'do_sample': True,
@@ -145,6 +160,10 @@ def create_settings_menus(default_preset):
shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')

with gr.Row():
shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
ui.create_refresh_button(shared.gradio['lora_menu'], lambda : None, lambda : {'choices': get_available_loras()}, 'refresh-button')

with gr.Accordion('Soft prompt', open=False):
with gr.Row():
shared.gradio['softprompts_menu'] = gr.Dropdown(choices=available_softprompts, value='None', label='Soft prompt')
@@ -156,6 +175,7 @@

shared.gradio['model_menu'].change(load_model_wrapper, [shared.gradio['model_menu']], [shared.gradio['model_menu']], show_progress=True)
shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio['do_sample'], shared.gradio['temperature'], shared.gradio['top_p'], shared.gradio['typical_p'], shared.gradio['repetition_penalty'], shared.gradio['encoder_repetition_penalty'], shared.gradio['top_k'], shared.gradio['min_length'], shared.gradio['no_repeat_ngram_size'], shared.gradio['num_beams'], shared.gradio['penalty_alpha'], shared.gradio['length_penalty'], shared.gradio['early_stopping']])
shared.gradio['lora_menu'].change(load_lora_wrapper, [shared.gradio['lora_menu']], [shared.gradio['lora_menu'], shared.gradio['textbox']], show_progress=True)
shared.gradio['softprompts_menu'].change(load_soft_prompt, [shared.gradio['softprompts_menu']], [shared.gradio['softprompts_menu']], show_progress=True)
shared.gradio['upload_softprompt'].upload(upload_soft_prompt, [shared.gradio['upload_softprompt']], [shared.gradio['softprompts_menu']])

@@ -181,6 +201,7 @@ def set_interface_arguments(interface_mode, extensions, cmd_active):
available_presets = get_available_presets()
available_characters = get_available_characters()
available_softprompts = get_available_softprompts()
available_loras = get_available_loras()

# Default extensions
extensions_module.available_extensions = get_available_extensions()
@@ -213,10 +234,16 @@ def set_interface_arguments(interface_mode, extensions, cmd_active):
print()
shared.model_name = available_models[i]
shared.model, shared.tokenizer = load_model(shared.model_name)
if shared.args.lora:
print(shared.args.lora)
shared.lora_name = shared.args.lora
add_lora_to_model(shared.lora_name)

# Default UI settings
default_preset = shared.settings['presets'][next((k for k in shared.settings['presets'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
default_text = shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
default_text = shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')]
if default_text == '':
default_text = shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
title ='Text generation web UI'
description = '\n\n# Text generation lab\nGenerate text using Large Language Models.\n'
suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else ''
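The default prompt for a LoRA is picked with the same regex-keyed lookup already used for presets and prompts. A small sketch of that lookup, using the `lora_prompts` entries introduced in this PR (the second prompt is abridged here):

```python
import re

lora_prompts = {
    'default': 'Common sense questions and answers\n\nQuestion: \nFactual answer:',
    'alpaca-lora-7b': 'Below is an instruction that describes a task. ...',  # abridged
}

def pick_default_text(lora_name, prompts=lora_prompts):
    # First key whose regex matches the LoRA name wins; otherwise fall back to 'default'.
    key = next((k for k in prompts if re.match(k.lower(), lora_name.lower())), 'default')
    return prompts[key]

print(pick_default_text('alpaca-lora-7b'))   # instruction-style prompt
print(pick_default_text('my-custom-lora'))   # falls back to the default prompt
```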
7 changes: 5 additions & 2 deletions settings-template.json
@@ -23,13 +23,16 @@
"presets": {
"default": "NovelAI-Sphinx Moth",
"pygmalion-*": "Pygmalion",
"RWKV-*": "Naive",
"(rosey|chip|joi)_.*_instruct.*": "Instruct Joi (Contrastive Search)"
"RWKV-*": "Naive"
},
"prompts": {
"default": "Common sense questions and answers\n\nQuestion: \nFactual answer:",
"^(gpt4chan|gpt-4chan|4chan)": "-----\n--- 865467536\nInput text\n--- 865467537\n",
"(rosey|chip|joi)_.*_instruct.*": "User: \n",
"oasst-*": "<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>"
},
"lora_prompts": {
"default": "Common sense questions and answers\n\nQuestion: \nFactual answer:",
"alpaca-lora-7b": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nWrite a poem about the transformers Python library. \nMention the word \"large language models\" in that poem.\n### Response:\n"
}
}