Skip to content

Commit

Permalink
Merge branch 'oobabooga:main' into stt-extension
Browse files Browse the repository at this point in the history
  • Loading branch information
EliasVincent authored Mar 12, 2023
2 parents 1c0bda3 + 3375eae commit 3b41459
Show file tree
Hide file tree
Showing 15 changed files with 460 additions and 183 deletions.
17 changes: 8 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Text generation web UI

A gradio web UI for running Large Language Models like GPT-J 6B, OPT, GALACTICA, GPT-Neo, and Pygmalion.
A gradio web UI for running Large Language Models like GPT-J 6B, OPT, GALACTICA, LLaMA, and Pygmalion.

Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) of text generation.

Expand All @@ -27,6 +27,7 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.
* [FlexGen offload](https://github.com/oobabooga/text-generation-webui/wiki/FlexGen).
* [DeepSpeed ZeRO-3 offload](https://github.com/oobabooga/text-generation-webui/wiki/DeepSpeed).
* Get responses via API, [with](https://github.com/oobabooga/text-generation-webui/blob/main/api-example-streaming.py) or [without](https://github.com/oobabooga/text-generation-webui/blob/main/api-example.py) streaming.
* [Supports the LLaMA model, including 4-bit mode](https://github.com/oobabooga/text-generation-webui/wiki/LLaMA-model).
* [Supports the RWKV model](https://github.com/oobabooga/text-generation-webui/wiki/RWKV-model).
* Supports softprompts.
* [Supports extensions](https://github.com/oobabooga/text-generation-webui/wiki/Extensions).
Expand All @@ -53,7 +54,7 @@ The third line assumes that you have an NVIDIA GPU.
pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.2
```

* If you are running in CPU mode, replace the third command with this one:
* If you are running it in CPU mode, replace the third command with this one:

```
conda install pytorch torchvision torchaudio git -c pytorch
Expand Down Expand Up @@ -137,6 +138,8 @@ Optionally, you can use the following command-line flags:
| `--cai-chat` | Launch the web UI in chat mode with a style similar to Character.AI's. If the file `img_bot.png` or `img_bot.jpg` exists in the same folder as server.py, this image will be used as the bot's profile picture. Similarly, `img_me.png` or `img_me.jpg` will be used as your profile picture. |
| `--cpu` | Use the CPU to generate text.|
| `--load-in-8bit` | Load the model with 8-bit precision.|
| `--load-in-4bit` | Load the model with 4-bit precision. Currently only works with LLaMA.|
| `--gptq-bits GPTQ_BITS` | Load a pre-quantized model with specified precision. 2, 3, 4 and 8 (bit) are supported. Currently only works with LLaMA. |
| `--bf16` | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. |
| `--auto-devices` | Automatically split the model across the available GPU(s) and CPU.|
| `--disk` | If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk. |
Expand Down Expand Up @@ -176,14 +179,10 @@ Check the [wiki](https://github.com/oobabooga/text-generation-webui/wiki/System-

Pull requests, suggestions, and issue reports are welcome.

Before reporting a bug, make sure that you have created a conda environment and installed the dependencies exactly as in the *Installation* section above.
Before reporting a bug, make sure that you have:

These issues are known:

* 8-bit doesn't work properly on Windows or older GPUs.
* DeepSpeed doesn't work properly on Windows.

For these two, please try commenting on an existing issue instead of creating a new one.
1. Created a conda environment and installed the dependencies exactly as in the *Installation* section above.
2. [Searched](https://github.com/oobabooga/text-generation-webui/issues) to see if an issue already exists for the issue you encountered.

## Credits

Expand Down
20 changes: 14 additions & 6 deletions download-model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
python download-model.py facebook/opt-1.3b
'''

import argparse
import base64
import json
import multiprocessing
import re
Expand Down Expand Up @@ -93,23 +95,28 @@ def select_model_from_default_options():
def get_download_links_from_huggingface(model, branch):
base = "https://huggingface.co"
page = f"/api/models/{model}/tree/{branch}?cursor="
cursor = b""

links = []
classifications = []
has_pytorch = False
has_safetensors = False
while page is not None:
content = requests.get(f"{base}{page}").content
while True:
content = requests.get(f"{base}{page}{cursor.decode()}").content

dict = json.loads(content)
if len(dict) == 0:
break

for i in range(len(dict)):
fname = dict[i]['path']

is_pytorch = re.match("pytorch_model.*\.bin", fname)
is_safetensors = re.match("model.*\.safetensors", fname)
is_text = re.match(".*\.(txt|json)", fname)
is_tokenizer = re.match("tokenizer.*\.model", fname)
is_text = re.match(".*\.(txt|json)", fname) or is_tokenizer

if is_text or is_safetensors or is_pytorch:
if any((is_pytorch, is_safetensors, is_text, is_tokenizer)):
if is_text:
links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
classifications.append('text')
Expand All @@ -123,8 +130,9 @@ def get_download_links_from_huggingface(model, branch):
has_pytorch = True
classifications.append('pytorch')

#page = dict['nextUrl']
page = None
cursor = base64.b64encode(f'{{"file_name":"{dict[-1]["path"]}"}}'.encode()) + b':50'
cursor = base64.b64encode(cursor)
cursor = cursor.replace(b'=', b'%3D')

# If both pytorch and safetensors are available, download safetensors only
if has_pytorch and has_safetensors:
Expand Down
18 changes: 18 additions & 0 deletions extensions/llama_prompts/script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import gradio as gr
import modules.shared as shared
import pandas as pd

df = pd.read_csv("https://raw.githubusercontent.com/devbrones/llama-prompts/main/prompts/prompts.csv")

def get_prompt_by_name(name):
if name == 'None':
return ''
else:
return df[df['Prompt name'] == name].iloc[0]['Prompt'].replace('\\n', '\n')

def ui():
if not shared.args.chat or shared.args.cai_chat:
choices = ['None'] + list(df['Prompt name'])

prompts_menu = gr.Dropdown(value=choices[0], choices=choices, label='Prompt')
prompts_menu.change(get_prompt_by_name, prompts_menu, shared.gradio['textbox'])
130 changes: 121 additions & 9 deletions extensions/silero_tts/script.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,45 @@
import re
import time
from pathlib import Path

import gradio as gr
import torch

import modules.chat as chat
import modules.shared as shared

torch._C._jit_set_profiling_mode(False)

params = {
'activate': True,
'speaker': 'en_56',
'speaker': 'en_5',
'language': 'en',
'model_id': 'v3_en',
'sample_rate': 48000,
'device': 'cpu',
'show_text': False,
'autoplay': True,
'voice_pitch': 'medium',
'voice_speed': 'medium',
}

current_params = params.copy()
voices_by_gender = ['en_99', 'en_45', 'en_18', 'en_117', 'en_49', 'en_51', 'en_68', 'en_0', 'en_26', 'en_56', 'en_74', 'en_5', 'en_38', 'en_53', 'en_21', 'en_37', 'en_107', 'en_10', 'en_82', 'en_16', 'en_41', 'en_12', 'en_67', 'en_61', 'en_14', 'en_11', 'en_39', 'en_52', 'en_24', 'en_97', 'en_28', 'en_72', 'en_94', 'en_36', 'en_4', 'en_43', 'en_88', 'en_25', 'en_65', 'en_6', 'en_44', 'en_75', 'en_91', 'en_60', 'en_109', 'en_85', 'en_101', 'en_108', 'en_50', 'en_96', 'en_64', 'en_92', 'en_76', 'en_33', 'en_116', 'en_48', 'en_98', 'en_86', 'en_62', 'en_54', 'en_95', 'en_55', 'en_111', 'en_3', 'en_83', 'en_8', 'en_47', 'en_59', 'en_1', 'en_2', 'en_7', 'en_9', 'en_13', 'en_15', 'en_17', 'en_19', 'en_20', 'en_22', 'en_23', 'en_27', 'en_29', 'en_30', 'en_31', 'en_32', 'en_34', 'en_35', 'en_40', 'en_42', 'en_46', 'en_57', 'en_58', 'en_63', 'en_66', 'en_69', 'en_70', 'en_71', 'en_73', 'en_77', 'en_78', 'en_79', 'en_80', 'en_81', 'en_84', 'en_87', 'en_89', 'en_90', 'en_93', 'en_100', 'en_102', 'en_103', 'en_104', 'en_105', 'en_106', 'en_110', 'en_112', 'en_113', 'en_114', 'en_115']
wav_idx = 0
voice_pitches = ['x-low', 'low', 'medium', 'high', 'x-high']
voice_speeds = ['x-slow', 'slow', 'medium', 'fast', 'x-fast']
last_msg_id = 0

# Used for making text xml compatible, needed for voice pitch and speed control
table = str.maketrans({
"<": "&lt;",
">": "&gt;",
"&": "&amp;",
"'": "&apos;",
'"': "&quot;",
})

def xmlesc(txt):
return txt.translate(table)

def load_model():
model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id'])
Expand All @@ -33,20 +57,67 @@ def remove_surrounded_chars(string):
new_string += char
return new_string

def remove_tts_from_history():
suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else ''
for i, entry in enumerate(shared.history['internal']):
reply = entry[1]
reply = re.sub("(<USER>|<user>|{{user}})", shared.settings[f'name1{suffix}'], reply)
if shared.args.chat:
reply = reply.replace('\n', '<br>')
shared.history['visible'][i][1] = reply

if shared.args.cai_chat:
return chat.generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name1{suffix}'], shared.character)
else:
return shared.history['visible']

def toggle_text_in_history():
suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else ''
audio_str='\n\n' # The '\n\n' used after </audio>
if shared.args.chat:
audio_str='<br><br>'

if params['show_text']==True:
#for i, entry in enumerate(shared.history['internal']):
for i, entry in enumerate(shared.history['visible']):
vis_reply = entry[1]
if vis_reply.startswith('<audio'):
reply = shared.history['internal'][i][1]
reply = re.sub("(<USER>|<user>|{{user}})", shared.settings[f'name1{suffix}'], reply)
if shared.args.chat:
reply = reply.replace('\n', '<br>')
shared.history['visible'][i][1] = vis_reply.split(audio_str,1)[0]+audio_str+reply
else:
for i, entry in enumerate(shared.history['visible']):
vis_reply = entry[1]
if vis_reply.startswith('<audio'):
shared.history['visible'][i][1] = vis_reply.split(audio_str,1)[0]+audio_str

if shared.args.cai_chat:
return chat.generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name1{suffix}'], shared.character)
else:
return shared.history['visible']

def input_modifier(string):
"""
This function is applied to your text inputs before
they are fed into the model.
"""

# Remove autoplay from previous chat history
if (shared.args.chat or shared.args.cai_chat)and len(shared.history['internal'])>0:
[visible_text, visible_reply] = shared.history['visible'][-1]
vis_rep_clean = visible_reply.replace('controls autoplay>','controls>')
shared.history['visible'][-1] = [visible_text, vis_rep_clean]

return string

def output_modifier(string):
"""
This function is applied to the model outputs.
"""

global wav_idx, model, current_params
global model, current_params

for i in params:
if params[i] != current_params[i]:
Expand All @@ -57,20 +128,34 @@ def output_modifier(string):
if params['activate'] == False:
return string

orig_string = string
string = remove_surrounded_chars(string)
string = string.replace('"', '')
string = string.replace('“', '')
string = string.replace('\n', ' ')
string = string.strip()

silent_string = False # Used to prevent unnecessary audio file generation
if string == '':
string = 'empty reply, try regenerating'
silent_string = True

pitch = params['voice_pitch']
speed = params['voice_speed']
prosody=f'<prosody rate="{speed}" pitch="{pitch}">'
string = '<speak>'+prosody+xmlesc(string)+'</prosody></speak>'

output_file = Path(f'extensions/silero_tts/outputs/{wav_idx:06d}.wav')
model.save_wav(text=string, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file))
if not shared.still_streaming and not silent_string:
output_file = Path(f'extensions/silero_tts/outputs/{shared.character}_{int(time.time())}.wav')
model.save_wav(ssml_text=string, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file))
autoplay_str = ' autoplay' if params['autoplay'] else ''
string = f'<audio src="file/{output_file.as_posix()}" controls{autoplay_str}></audio>\n\n'
else:
# Placeholder so text doesn't shift around so much
string = '<audio controls></audio>\n\n'

string = f'<audio src="file/{output_file.as_posix()}" controls></audio>'
wav_idx += 1
if params['show_text']:
string += orig_string

return string

Expand All @@ -85,9 +170,36 @@ def bot_prefix_modifier(string):

def ui():
# Gradio elements
activate = gr.Checkbox(value=params['activate'], label='Activate TTS')
voice = gr.Dropdown(value=params['speaker'], choices=voices_by_gender, label='TTS voice')
with gr.Accordion("Silero TTS"):
with gr.Row():
activate = gr.Checkbox(value=params['activate'], label='Activate TTS')
autoplay = gr.Checkbox(value=params['autoplay'], label='Play TTS automatically')
show_text = gr.Checkbox(value=params['show_text'], label='Show message text under audio player')
voice = gr.Dropdown(value=params['speaker'], choices=voices_by_gender, label='TTS voice')
with gr.Row():
v_pitch = gr.Dropdown(value=params['voice_pitch'], choices=voice_pitches, label='Voice pitch')
v_speed = gr.Dropdown(value=params['voice_speed'], choices=voice_speeds, label='Voice speed')
with gr.Row():
convert = gr.Button('Permanently replace chat history audio with message text')
convert_confirm = gr.Button('Confirm (cannot be undone)', variant="stop", visible=False)
convert_cancel = gr.Button('Cancel', visible=False)

# Convert history with confirmation
convert_arr = [convert_confirm, convert, convert_cancel]
convert.click(lambda :[gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr)
convert_confirm.click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
convert_confirm.click(remove_tts_from_history, [], shared.gradio['display'])
convert_confirm.click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
convert_cancel.click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)

# Toggle message text in history
show_text.change(lambda x: params.update({"show_text": x}), show_text, None)
show_text.change(toggle_text_in_history, [], shared.gradio['display'])
show_text.change(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)

# Event functions to update the parameters in the backend
activate.change(lambda x: params.update({"activate": x}), activate, None)
autoplay.change(lambda x: params.update({"autoplay": x}), autoplay, None)
voice.change(lambda x: params.update({"speaker": x}), voice, None)
v_pitch.change(lambda x: params.update({"voice_pitch": x}), v_pitch, None)
v_speed.change(lambda x: params.update({"voice_speed": x}), v_speed, None)
48 changes: 6 additions & 42 deletions modules/RWKV.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import os
from pathlib import Path
from queue import Queue
from threading import Thread

import numpy as np
from tokenizers import Tokenizer

import modules.shared as shared
from modules.callbacks import Iteratorize

np.set_printoptions(precision=4, suppress=True, linewidth=200)

Expand Down Expand Up @@ -49,11 +48,11 @@ def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50,
return context+self.pipeline.generate(context, token_count=token_count, args=args, callback=callback)

def generate_with_streaming(self, **kwargs):
iterable = Iteratorize(self.generate, kwargs, callback=None)
reply = kwargs['context']
for token in iterable:
reply += token
yield reply
with Iteratorize(self.generate, kwargs, callback=None) as generator:
reply = kwargs['context']
for token in generator:
reply += token
yield reply

class RWKVTokenizer:
def __init__(self):
Expand All @@ -73,38 +72,3 @@ def encode(self, prompt):

def decode(self, ids):
return self.tokenizer.decode(ids)

class Iteratorize:

"""
Transforms a function that takes a callback
into a lazy iterator (generator).
"""

def __init__(self, func, kwargs={}, callback=None):
self.mfunc=func
self.c_callback=callback
self.q = Queue(maxsize=1)
self.sentinel = object()
self.kwargs = kwargs

def _callback(val):
self.q.put(val)

def gentask():
ret = self.mfunc(callback=_callback, **self.kwargs)
self.q.put(self.sentinel)
if self.c_callback:
self.c_callback(ret)

Thread(target=gentask).start()

def __iter__(self):
return self

def __next__(self):
obj = self.q.get(True,None)
if obj is self.sentinel:
raise StopIteration
else:
return obj
Loading

0 comments on commit 3b41459

Please sign in to comment.