diff --git a/.github/ISSUE_TEMPLATE/bug_report_template.yml b/.github/ISSUE_TEMPLATE/bug_report_template.yml new file mode 100644 index 0000000000..bd30a0c9c1 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report_template.yml @@ -0,0 +1,53 @@ +name: "Bug report" +description: Report a bug +labels: [ "bug" ] +body: + - type: markdown + attributes: + value: | + Thanks for taking the time to fill out this bug report! + - type: textarea + id: bug-description + attributes: + label: Describe the bug + description: A clear and concise description of what the bug is. + placeholder: Bug description + validations: + required: true + - type: checkboxes + attributes: + label: Is there an existing issue for this? + description: Please search to see if an issue already exists for the issue you encountered. + options: + - label: I have searched the existing issues + required: true + - type: textarea + id: reproduction + attributes: + label: Reproduction + description: Please provide the steps necessary to reproduce your issue. + placeholder: Reproduction + validations: + required: true + - type: textarea + id: screenshot + attributes: + label: Screenshot + description: "If possible, please include screenshot(s) so that we can understand what the issue is." + - type: textarea + id: logs + attributes: + label: Logs + description: "Please include the full stacktrace of the errors you get in the command-line (if any)." + render: shell + validations: + required: true + - type: textarea + id: system-info + attributes: + label: System Info + description: "Please share your system info with us: operating system, GPU brand, and GPU model. If you are using a Google Colab notebook, mention that instead." + render: shell + placeholder: + validations: + required: true diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 0000000000..b94974f865 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,16 @@ +--- +name: Feature request +about: Suggest an improvement or new feature for the web UI +title: '' +labels: 'enhancement' +assignees: '' + +--- + +**Description** + +A clear and concise description of what you want to be implemented. + +**Additional Context** + +If applicable, please provide any extra information, external links, or screenshots that could be useful. diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000000..91abb11fdf --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,11 @@ +# To get started with Dependabot version updates, you'll need to specify which +# package ecosystems to update and where the package manifests are located. +# Please see the documentation for all configuration options: +# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates + +version: 2 +updates: + - package-ecosystem: "pip" # See documentation for possible values + directory: "/" # Location of package manifests + schedule: + interval: "weekly" diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml new file mode 100644 index 0000000000..82cd170189 --- /dev/null +++ b/.github/workflows/stale.yml @@ -0,0 +1,22 @@ +name: Close inactive issues +on: + schedule: + - cron: "10 23 * * *" + +jobs: + close-issues: + runs-on: ubuntu-latest + permissions: + issues: write + pull-requests: write + steps: + - uses: actions/stale@v5 + with: + stale-issue-message: "" + close-issue-message: "This issue has been closed due to inactivity for 30 days. If you believe it is still relevant, you can reopen it (if you are the author) or leave a comment below." + days-before-issue-stale: 30 + days-before-issue-close: 0 + stale-issue-label: "stale" + days-before-pr-stale: -1 + days-before-pr-close: -1 + repo-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/README.md b/README.md index dbc8c59cc8..c9834558bb 100644 --- a/README.md +++ b/README.md @@ -60,7 +60,9 @@ pip3 install torch torchvision torchaudio --extra-index-url https://download.pyt conda install pytorch torchvision torchaudio git -c pytorch ``` -See also: [Installation instructions for human beings](https://github.com/oobabooga/text-generation-webui/wiki/Installation-instructions-for-human-beings). +> **Note** +> 1. If you are on Windows, it may be easier to run the commands above in a WSL environment. The performance may also be better. +> 2. For a more detailed, user-contributed guide, see: [Installation instructions for human beings](https://github.com/oobabooga/text-generation-webui/wiki/Installation-instructions-for-human-beings). ## Installation option 2: one-click installers @@ -140,8 +142,9 @@ Optionally, you can use the following command-line flags: | `--cai-chat` | Launch the web UI in chat mode with a style similar to Character.AI's. If the file `img_bot.png` or `img_bot.jpg` exists in the same folder as server.py, this image will be used as the bot's profile picture. Similarly, `img_me.png` or `img_me.jpg` will be used as your profile picture. | | `--cpu` | Use the CPU to generate text.| | `--load-in-8bit` | Load the model with 8-bit precision.| -| `--load-in-4bit` | Load the model with 4-bit precision. Currently only works with LLaMA.| -| `--gptq-bits GPTQ_BITS` | Load a pre-quantized model with specified precision. 2, 3, 4 and 8 (bit) are supported. Currently only works with LLaMA. | +| `--load-in-4bit` | DEPRECATED: use `--gptq-bits 4` instead. | +| `--gptq-bits GPTQ_BITS` | Load a pre-quantized model with specified precision. 2, 3, 4 and 8 (bit) are supported. Currently only works with LLaMA and OPT. | +| `--gptq-model-type MODEL_TYPE` | Model type of pre-quantized model. Currently only LLaMa and OPT are supported. | | `--bf16` | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. | | `--auto-devices` | Automatically split the model across the available GPU(s) and CPU.| | `--disk` | If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk. | diff --git a/api-example-stream.py b/api-example-stream.py index a5ed420252..add1df41c4 100644 --- a/api-example-stream.py +++ b/api-example-stream.py @@ -26,6 +26,7 @@ async def run(context): 'top_p': 0.9, 'typical_p': 1, 'repetition_penalty': 1.05, + 'encoder_repetition_penalty': 1.0, 'top_k': 0, 'min_length': 0, 'no_repeat_ngram_size': 0, @@ -59,6 +60,7 @@ async def run(context): params['top_p'], params['typical_p'], params['repetition_penalty'], + params['encoder_repetition_penalty'], params['top_k'], params['min_length'], params['no_repeat_ngram_size'], diff --git a/api-example.py b/api-example.py index 0306b7ab8a..a6f0c10e77 100644 --- a/api-example.py +++ b/api-example.py @@ -24,6 +24,7 @@ 'top_p': 0.9, 'typical_p': 1, 'repetition_penalty': 1.05, + 'encoder_repetition_penalty': 1.0, 'top_k': 0, 'min_length': 0, 'no_repeat_ngram_size': 0, @@ -45,6 +46,7 @@ params['top_p'], params['typical_p'], params['repetition_penalty'], + params['encoder_repetition_penalty'], params['top_k'], params['min_length'], params['no_repeat_ngram_size'], diff --git a/css/chat.css b/css/chat.css new file mode 100644 index 0000000000..8d9d88a6ee --- /dev/null +++ b/css/chat.css @@ -0,0 +1,25 @@ +.h-\[40vh\], .wrap.svelte-byatnx.svelte-byatnx.svelte-byatnx { + height: 66.67vh +} + +.gradio-container { + margin-left: auto !important; + margin-right: auto !important; +} + +.w-screen { + width: unset +} + +div.svelte-362y77>*, div.svelte-362y77>.form>* { + flex-wrap: nowrap +} + +/* fixes the API documentation in chat mode */ +.api-docs.svelte-1iguv9h.svelte-1iguv9h.svelte-1iguv9h { + display: grid; +} + +.pending.svelte-1ed2p3z { + opacity: 1; +} diff --git a/css/chat.js b/css/chat.js new file mode 100644 index 0000000000..e304f12547 --- /dev/null +++ b/css/chat.js @@ -0,0 +1,4 @@ +document.getElementById("main").childNodes[0].style = "max-width: 800px; margin-left: auto; margin-right: auto"; +document.getElementById("extensions").style.setProperty("max-width", "800px"); +document.getElementById("extensions").style.setProperty("margin-left", "auto"); +document.getElementById("extensions").style.setProperty("margin-right", "auto"); diff --git a/css/html_4chan_style.css b/css/html_4chan_style.css new file mode 100644 index 0000000000..843e8a97fe --- /dev/null +++ b/css/html_4chan_style.css @@ -0,0 +1,103 @@ +#parent #container { + background-color: #eef2ff; + padding: 17px; +} +#parent #container .reply { + background-color: rgb(214, 218, 240); + border-bottom-color: rgb(183, 197, 217); + border-bottom-style: solid; + border-bottom-width: 1px; + border-image-outset: 0; + border-image-repeat: stretch; + border-image-slice: 100%; + border-image-source: none; + border-image-width: 1; + border-left-color: rgb(0, 0, 0); + border-left-style: none; + border-left-width: 0px; + border-right-color: rgb(183, 197, 217); + border-right-style: solid; + border-right-width: 1px; + border-top-color: rgb(0, 0, 0); + border-top-style: none; + border-top-width: 0px; + color: rgb(0, 0, 0); + display: table; + font-family: arial, helvetica, sans-serif; + font-size: 13.3333px; + margin-bottom: 4px; + margin-left: 0px; + margin-right: 0px; + margin-top: 4px; + overflow-x: hidden; + overflow-y: hidden; + padding-bottom: 4px; + padding-left: 2px; + padding-right: 2px; + padding-top: 4px; +} + +#parent #container .number { + color: rgb(0, 0, 0); + font-family: arial, helvetica, sans-serif; + font-size: 13.3333px; + width: 342.65px; + margin-right: 7px; +} + +#parent #container .op { + color: rgb(0, 0, 0); + font-family: arial, helvetica, sans-serif; + font-size: 13.3333px; + margin-bottom: 8px; + margin-left: 0px; + margin-right: 0px; + margin-top: 4px; + overflow-x: hidden; + overflow-y: hidden; +} + +#parent #container .op blockquote { + margin-left: 0px !important; +} + +#parent #container .name { + color: rgb(17, 119, 67); + font-family: arial, helvetica, sans-serif; + font-size: 13.3333px; + font-weight: 700; + margin-left: 7px; +} + +#parent #container .quote { + color: rgb(221, 0, 0); + font-family: arial, helvetica, sans-serif; + font-size: 13.3333px; + text-decoration-color: rgb(221, 0, 0); + text-decoration-line: underline; + text-decoration-style: solid; + text-decoration-thickness: auto; +} + +#parent #container .greentext { + color: rgb(120, 153, 34); + font-family: arial, helvetica, sans-serif; + font-size: 13.3333px; +} + +#parent #container blockquote { + margin: 0px !important; + margin-block-start: 1em; + margin-block-end: 1em; + margin-inline-start: 40px; + margin-inline-end: 40px; + margin-top: 13.33px !important; + margin-bottom: 13.33px !important; + margin-left: 40px !important; + margin-right: 40px !important; +} + +#parent #container .message { + color: black; + border: none; +} \ No newline at end of file diff --git a/css/html_cai_style.css b/css/html_cai_style.css new file mode 100644 index 0000000000..3190b3d1cc --- /dev/null +++ b/css/html_cai_style.css @@ -0,0 +1,73 @@ +.chat { + margin-left: auto; + margin-right: auto; + max-width: 800px; + height: 66.67vh; + overflow-y: auto; + padding-right: 20px; + display: flex; + flex-direction: column-reverse; +} + +.message { + display: grid; + grid-template-columns: 60px 1fr; + padding-bottom: 25px; + font-size: 15px; + font-family: Helvetica, Arial, sans-serif; + line-height: 1.428571429; +} + +.circle-you { + width: 50px; + height: 50px; + background-color: rgb(238, 78, 59); + border-radius: 50%; +} + +.circle-bot { + width: 50px; + height: 50px; + background-color: rgb(59, 78, 244); + border-radius: 50%; +} + +.circle-bot img, +.circle-you img { + border-radius: 50%; + width: 100%; + height: 100%; + object-fit: cover; +} + +.text {} + +.text p { + margin-top: 5px; +} + +.username { + font-weight: bold; +} + +.message-body {} + +.message-body img { + max-width: 300px; + max-height: 300px; + border-radius: 20px; +} + +.message-body p { + margin-bottom: 0 !important; + font-size: 15px !important; + line-height: 1.428571429 !important; +} + +.dark .message-body p em { + color: rgb(138, 138, 138) !important; +} + +.message-body p em { + color: rgb(110, 110, 110) !important; +} \ No newline at end of file diff --git a/css/html_readable_style.css b/css/html_readable_style.css new file mode 100644 index 0000000000..d3f580a53a --- /dev/null +++ b/css/html_readable_style.css @@ -0,0 +1,14 @@ +.container { + max-width: 600px; + margin-left: auto; + margin-right: auto; + background-color: rgb(31, 41, 55); + padding:3em; +} + +.container p { + font-size: 16px !important; + color: white !important; + margin-bottom: 22px; + line-height: 1.4 !important; +} diff --git a/css/main.css b/css/main.css new file mode 100644 index 0000000000..8c5f156eb0 --- /dev/null +++ b/css/main.css @@ -0,0 +1,38 @@ +.tabs.svelte-710i53 { + margin-top: 0 +} +.py-6 { + padding-top: 2.5rem +} +.dark #refresh-button { + background-color: #ffffff1f; +} +#refresh-button { + flex: none; + margin: 0; + padding: 0; + min-width: 50px; + border: none; + box-shadow: none; + border-radius: 10px; + background-color: #0000000d; +} +#download-label, #upload-label { + min-height: 0 +} +#accordion { +} +.dark svg { + fill: white; +} +svg { + display: unset !important; + vertical-align: middle !important; + margin: 5px; +} +ol li p, ul li p { + display: inline-block; +} +#main, #settings, #chat-settings { + border: 0; +} diff --git a/css/main.js b/css/main.js new file mode 100644 index 0000000000..9db3fe8b68 --- /dev/null +++ b/css/main.js @@ -0,0 +1,18 @@ +document.getElementById("main").parentNode.childNodes[0].style = "border: none; background-color: #8080802b; margin-bottom: 40px"; +document.getElementById("main").parentNode.style = "padding: 0; margin: 0"; +document.getElementById("main").parentNode.parentNode.parentNode.style = "padding: 0"; + +// Get references to the elements +let main = document.getElementById('main'); +let main_parent = main.parentNode; +let extensions = document.getElementById('extensions'); + +// Add an event listener to the main element +main_parent.addEventListener('click', function(e) { + // Check if the main element is visible + if (main.offsetHeight > 0 && main.offsetWidth > 0) { + extensions.style.display = 'block'; + } else { + extensions.style.display = 'none'; + } +}); diff --git a/extensions/gallery/script.py b/extensions/gallery/script.py index 709a4e640a..c0d943e51f 100644 --- a/extensions/gallery/script.py +++ b/extensions/gallery/script.py @@ -76,7 +76,7 @@ def generate_html(): return container_html def ui(): - with gr.Accordion("Character gallery"): + with gr.Accordion("Character gallery", open=False): update = gr.Button("Refresh") gallery = gr.HTML(value=generate_html()) update.click(generate_html, [], gallery) diff --git a/extensions/silero_tts/script.py b/extensions/silero_tts/script.py index 1d06822931..f611dc27b7 100644 --- a/extensions/silero_tts/script.py +++ b/extensions/silero_tts/script.py @@ -81,6 +81,7 @@ def input_modifier(string): if (shared.args.chat or shared.args.cai_chat) and len(shared.history['internal']) > 0: shared.history['visible'][-1] = [shared.history['visible'][-1][0], shared.history['visible'][-1][1].replace('controls autoplay>','controls>')] + shared.processing_message = "*Is recording a voice message...*" return string def output_modifier(string): @@ -119,6 +120,7 @@ def output_modifier(string): if params['show_text']: string += f'\n\n{original_string}' + shared.processing_message = "*Is typing...*" return string def bot_prefix_modifier(string): diff --git a/modules/quantized_LLaMA.py b/modules/GPTQ_loader.py similarity index 54% rename from modules/quantized_LLaMA.py rename to modules/GPTQ_loader.py index a5757c682b..c2723490bb 100644 --- a/modules/quantized_LLaMA.py +++ b/modules/GPTQ_loader.py @@ -7,28 +7,40 @@ import modules.shared as shared sys.path.insert(0, str(Path("repositories/GPTQ-for-LLaMa"))) -from llama import load_quant +import llama +import opt -# 4-bit LLaMA -def load_quantized_LLaMA(model_name): - if shared.args.load_in_4bit: - bits = 4 +def load_quantized(model_name): + if not shared.args.gptq_model_type: + # Try to determine model type from model name + model_type = model_name.split('-')[0].lower() + if model_type not in ('llama', 'opt'): + print("Can't determine model type from model name. Please specify it manually using --gptq-model-type " + "argument") + exit() else: - bits = shared.args.gptq_bits + model_type = shared.args.gptq_model_type.lower() + + if model_type == 'llama': + load_quant = llama.load_quant + elif model_type == 'opt': + load_quant = opt.load_quant + else: + print("Unknown pre-quantized model type specified. Only 'llama' and 'opt' are supported") + exit() path_to_model = Path(f'models/{model_name}') - pt_model = '' if path_to_model.name.lower().startswith('llama-7b'): - pt_model = f'llama-7b-{bits}bit.pt' + pt_model = f'llama-7b-{shared.args.gptq_bits}bit.pt' elif path_to_model.name.lower().startswith('llama-13b'): - pt_model = f'llama-13b-{bits}bit.pt' + pt_model = f'llama-13b-{shared.args.gptq_bits}bit.pt' elif path_to_model.name.lower().startswith('llama-30b'): - pt_model = f'llama-30b-{bits}bit.pt' + pt_model = f'llama-30b-{shared.args.gptq_bits}bit.pt' elif path_to_model.name.lower().startswith('llama-65b'): - pt_model = f'llama-65b-{bits}bit.pt' + pt_model = f'llama-65b-{shared.args.gptq_bits}bit.pt' else: - pt_model = f'{model_name}-{bits}bit.pt' + pt_model = f'{model_name}-{shared.args.gptq_bits}bit.pt' # Try to find the .pt both in models/ and in the subfolder pt_path = None @@ -40,7 +52,7 @@ def load_quantized_LLaMA(model_name): print(f"Could not find {pt_model}, exiting...") exit() - model = load_quant(str(path_to_model), str(pt_path), bits) + model = load_quant(str(path_to_model), str(pt_path), shared.args.gptq_bits) # Multiple GPUs or GPU+CPU if shared.args.gpu_memory: diff --git a/modules/chat.py b/modules/chat.py index b3824a3a70..dfe8addc06 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -98,7 +98,7 @@ def extract_message_from_reply(question, reply, name1, name2, check, impersonate def stop_everything_event(): shared.stop_everything = True -def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1, regenerate=False): +def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1, regenerate=False): shared.stop_everything = False just_started = True eos_token = '\n' if check else None @@ -127,13 +127,14 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical else: prompt = custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size) + # Yield *Is typing...* if not regenerate: - yield shared.history['visible']+[[visible_text, '*Is typing...*']] + yield shared.history['visible']+[[visible_text, shared.processing_message]] # Generate reply = '' for i in range(chat_generation_attempts): - for reply in generate_reply(f"{prompt}{' ' if len(reply) > 0 else ''}{reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name1}:"): + for reply in generate_reply(f"{prompt}{' ' if len(reply) > 0 else ''}{reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name1}:"): # Extracting the reply reply, next_character_found = extract_message_from_reply(prompt, reply, name1, name2, check) @@ -160,7 +161,7 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical yield shared.history['visible'] -def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1): +def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1): eos_token = '\n' if check else None if 'pygmalion' in shared.model_name.lower(): @@ -169,28 +170,29 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=True) reply = '' - yield '*Is typing...*' + # Yield *Is typing...* + yield shared.processing_message for i in range(chat_generation_attempts): - for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"): + for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"): reply, next_character_found = extract_message_from_reply(prompt, reply, name1, name2, check, impersonate=True) yield reply if next_character_found: break yield reply -def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1): - for _history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts): +def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1): + for _history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts): yield generate_chat_html(_history, name1, name2, shared.character) -def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1): +def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1): if (shared.character != 'None' and len(shared.history['visible']) == 1) or len(shared.history['internal']) == 0: yield generate_chat_output(shared.history['visible'], name1, name2, shared.character) else: last_visible = shared.history['visible'].pop() last_internal = shared.history['internal'].pop() - - yield generate_chat_output(shared.history['visible']+[[last_visible[0], '*Is typing...*']], name1, name2, shared.character) - for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts, regenerate=True): + # Yield '*Is typing...*' + yield generate_chat_output(shared.history['visible']+[[last_visible[0], shared.processing_message]], name1, name2, shared.character) + for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts, regenerate=True): if shared.args.cai_chat: shared.history['visible'][-1] = [last_visible[0], _history[-1][1]] else: diff --git a/modules/extensions.py b/modules/extensions.py index c8de8a7bc9..89dfcda417 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -1,3 +1,5 @@ +import gradio as gr + import extensions import modules.shared as shared @@ -40,6 +42,8 @@ def create_extensions_block(): extension.params[param] = shared.settings[_id] # Creating the extension ui elements - for extension, name in iterator(): - if hasattr(extension, "ui"): - extension.ui() + with gr.Box(elem_id="extensions"): + gr.Markdown("Extensions") + for extension, name in iterator(): + if hasattr(extension, "ui"): + extension.ui() diff --git a/modules/html_generator.py b/modules/html_generator.py index 162040bac6..9942e6c948 100644 --- a/modules/html_generator.py +++ b/modules/html_generator.py @@ -8,29 +8,22 @@ import re from pathlib import Path +import markdown from PIL import Image # This is to store the paths to the thumbnails of the profile pictures image_cache = {} +with open(Path(__file__).resolve().parent / '../css/html_readable_style.css', 'r') as f: + readable_css = f.read() +with open(Path(__file__).resolve().parent / '../css/html_4chan_style.css', 'r') as css_f: + _4chan_css = css_f.read() +with open(Path(__file__).resolve().parent / '../css/html_cai_style.css', 'r') as f: + cai_css = f.read() + def generate_basic_html(s): - css = """ - .container { - max-width: 600px; - margin-left: auto; - margin-right: auto; - background-color: rgb(31, 41, 55); - padding:3em; - } - .container p { - font-size: 16px !important; - color: white !important; - margin-bottom: 22px; - line-height: 1.4 !important; - } - """ s = '\n'.join([f'

{line}

' for line in s.split('\n')]) - s = f'
{s}
' + s = f'
{s}
' return s def process_post(post, c): @@ -48,113 +41,6 @@ def process_post(post, c): return src def generate_4chan_html(f): - css = """ - - #parent #container { - background-color: #eef2ff; - padding: 17px; - } - #parent #container .reply { - background-color: rgb(214, 218, 240); - border-bottom-color: rgb(183, 197, 217); - border-bottom-style: solid; - border-bottom-width: 1px; - border-image-outset: 0; - border-image-repeat: stretch; - border-image-slice: 100%; - border-image-source: none; - border-image-width: 1; - border-left-color: rgb(0, 0, 0); - border-left-style: none; - border-left-width: 0px; - border-right-color: rgb(183, 197, 217); - border-right-style: solid; - border-right-width: 1px; - border-top-color: rgb(0, 0, 0); - border-top-style: none; - border-top-width: 0px; - color: rgb(0, 0, 0); - display: table; - font-family: arial, helvetica, sans-serif; - font-size: 13.3333px; - margin-bottom: 4px; - margin-left: 0px; - margin-right: 0px; - margin-top: 4px; - overflow-x: hidden; - overflow-y: hidden; - padding-bottom: 4px; - padding-left: 2px; - padding-right: 2px; - padding-top: 4px; - } - - #parent #container .number { - color: rgb(0, 0, 0); - font-family: arial, helvetica, sans-serif; - font-size: 13.3333px; - width: 342.65px; - margin-right: 7px; - } - - #parent #container .op { - color: rgb(0, 0, 0); - font-family: arial, helvetica, sans-serif; - font-size: 13.3333px; - margin-bottom: 8px; - margin-left: 0px; - margin-right: 0px; - margin-top: 4px; - overflow-x: hidden; - overflow-y: hidden; - } - - #parent #container .op blockquote { - margin-left: 0px !important; - } - - #parent #container .name { - color: rgb(17, 119, 67); - font-family: arial, helvetica, sans-serif; - font-size: 13.3333px; - font-weight: 700; - margin-left: 7px; - } - - #parent #container .quote { - color: rgb(221, 0, 0); - font-family: arial, helvetica, sans-serif; - font-size: 13.3333px; - text-decoration-color: rgb(221, 0, 0); - text-decoration-line: underline; - text-decoration-style: solid; - text-decoration-thickness: auto; - } - - #parent #container .greentext { - color: rgb(120, 153, 34); - font-family: arial, helvetica, sans-serif; - font-size: 13.3333px; - } - - #parent #container blockquote { - margin: 0px !important; - margin-block-start: 1em; - margin-block-end: 1em; - margin-inline-start: 40px; - margin-inline-end: 40px; - margin-top: 13.33px !important; - margin-bottom: 13.33px !important; - margin-left: 40px !important; - margin-right: 40px !important; - } - - #parent #container .message { - color: black; - border: none; - } - """ - posts = [] post = '' c = -2 @@ -181,7 +67,7 @@ def generate_4chan_html(f): posts[i] = f'
{posts[i]}
\n' output = '' - output += f'
' + output += f'
' for post in posts: output += post output += '
' @@ -208,135 +94,39 @@ def get_image_cache(path): return image_cache[path][1] -def generate_chat_html(history, name1, name2, character): - css = """ - .chat { - margin-left: auto; - margin-right: auto; - max-width: 800px; - height: 66.67vh; - overflow-y: auto; - padding-right: 20px; - display: flex; - flex-direction: column-reverse; - } - - .message { - display: grid; - grid-template-columns: 60px 1fr; - padding-bottom: 25px; - font-size: 15px; - font-family: Helvetica, Arial, sans-serif; - line-height: 1.428571429; - } - - .circle-you { - width: 50px; - height: 50px; - background-color: rgb(238, 78, 59); - border-radius: 50%; - } - - .circle-bot { - width: 50px; - height: 50px; - background-color: rgb(59, 78, 244); - border-radius: 50%; - } - - .circle-bot img, .circle-you img { - border-radius: 50%; - width: 100%; - height: 100%; - object-fit: cover; - } - - .text { - } +def load_html_image(paths): + for str_path in paths: + path = Path(str_path) + if path.exists(): + return f'' + return '' - .text p { - margin-top: 5px; - } - - .username { - font-weight: bold; - } - - .message-body { - } - - .message-body img { - max-width: 300px; - max-height: 300px; - border-radius: 20px; - } - - .message-body p { - margin-bottom: 0 !important; - font-size: 15px !important; - line-height: 1.428571429 !important; - } - - .dark .message-body p em { - color: rgb(138, 138, 138) !important; - } - - .message-body p em { - color: rgb(110, 110, 110) !important; - } - - """ - - output = '' - output += f'
' - img = '' - - for i in [ - f"characters/{character}.png", - f"characters/{character}.jpg", - f"characters/{character}.jpeg", - "img_bot.png", - "img_bot.jpg", - "img_bot.jpeg" - ]: - - path = Path(i) - if path.exists(): - img = f'' - break - - img_me = '' - for i in ["img_me.png", "img_me.jpg", "img_me.jpeg"]: - path = Path(i) - if path.exists(): - img_me = f'' - break +def generate_chat_html(history, name1, name2, character): + output = f'
' + + img_bot = load_html_image([f"characters/{character}.{ext}" for ext in ['png', 'jpg', 'jpeg']] + ["img_bot.png","img_bot.jpg","img_bot.jpeg"]) + img_me = load_html_image(["img_me.png", "img_me.jpg", "img_me.jpeg"]) for i,_row in enumerate(history[::-1]): - row = _row.copy() - row[0] = re.sub(r"(\*\*)([^\*\n]*)(\*\*)", r"\2", row[0]) - row[1] = re.sub(r"(\*\*)([^\*\n]*)(\*\*)", r"\2", row[1]) - row[0] = re.sub(r"(\*)([^\*\n]*)(\*)", r"\2", row[0]) - row[1] = re.sub(r"(\*)([^\*\n]*)(\*)", r"\2", row[1]) - p = '\n'.join([f"

{x}

" for x in row[1].split('\n')]) + row = [markdown.markdown(re.sub(r"(.)```", r"\1\n```", entry), extensions=['fenced_code']) for entry in _row] + output += f"""
- {img} + {img_bot}
{name2}
- {p} + {row[1]}
""" if not (i == len(history)-1 and len(row[0]) == 0): - p = '\n'.join([f"

{x}

" for x in row[0].split('\n')]) output += f"""
@@ -347,7 +137,7 @@ def generate_chat_html(history, name1, name2, character): {name1}
- {p} + {row[0]}
diff --git a/modules/models.py b/modules/models.py index 7d094ed566..f4bb11fd3f 100644 --- a/modules/models.py +++ b/modules/models.py @@ -1,6 +1,5 @@ import json import os -import sys import time import zipfile from pathlib import Path @@ -35,6 +34,7 @@ ds_config = generate_ds_config(shared.args.bf16, 1 * world_size, shared.args.nvme_offload_dir) dschf = HfDeepSpeedConfig(ds_config) # Keep this object alive for the Transformers integration + def load_model(model_name): print(f"Loading {model_name}...") t0 = time.time() @@ -42,7 +42,7 @@ def load_model(model_name): shared.is_RWKV = model_name.lower().startswith('rwkv-') # Default settings - if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.load_in_4bit, shared.args.gptq_bits > 0, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV]): + if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.gptq_bits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV]): if any(size in shared.model_name.lower() for size in ('13b', '20b', '30b')): model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), device_map='auto', load_in_8bit=True) else: @@ -87,11 +87,11 @@ def load_model(model_name): return model, tokenizer - # 4-bit LLaMA - elif shared.args.gptq_bits > 0 or shared.args.load_in_4bit: - from modules.quantized_LLaMA import load_quantized_LLaMA + # Quantized model + elif shared.args.gptq_bits > 0: + from modules.GPTQ_loader import load_quantized - model = load_quantized_LLaMA(model_name) + model = load_quantized(model_name) # Custom else: diff --git a/modules/shared.py b/modules/shared.py index 8fcd4745c8..ea2eb50b7f 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -11,6 +11,7 @@ history = {'internal': [], 'visible': []} character = 'None' stop_everything = False +processing_message = '*Is typing...*' # UI elements (buttons, sliders, HTML, etc) gradio = {} @@ -68,8 +69,9 @@ def str2bool(v): parser.add_argument('--cai-chat', action='store_true', help='Launch the web UI in chat mode with a style similar to Character.AI\'s. If the file img_bot.png or img_bot.jpg exists in the same folder as server.py, this image will be used as the bot\'s profile picture. Similarly, img_me.png or img_me.jpg will be used as your profile picture.') parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text.') parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.') -parser.add_argument('--load-in-4bit', action='store_true', help='Load the model with 4-bit precision. Currently only works with LLaMA.') -parser.add_argument('--gptq-bits', type=int, default=0, help='Load a pre-quantized model with specified precision. 2, 3, 4 and 8bit are supported. Currently only works with LLaMA.') +parser.add_argument('--load-in-4bit', action='store_true', help='DEPRECATED: use --gptq-bits 4 instead.') +parser.add_argument('--gptq-bits', type=int, default=0, help='Load a pre-quantized model with specified precision. 2, 3, 4 and 8bit are supported. Currently only works with LLaMA and OPT.') +parser.add_argument('--gptq-model-type', type=str, help='Model type of pre-quantized model. Currently only LLaMa and OPT are supported.') parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.') parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.') parser.add_argument('--disk', action='store_true', help='If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk.') @@ -94,3 +96,8 @@ def str2bool(v): parser.add_argument('--auto-launch', action='store_true', default=False, help='Open the web UI in the default browser upon launch.') parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.') args = parser.parse_args() + +# Provisional, this will be deleted later +if args.load_in_4bit: + print("Warning: --load-in-4bit is deprecated and will be removed. Use --gptq-bits 4 instead.\n") + args.gptq_bits = 4 diff --git a/modules/text_generation.py b/modules/text_generation.py index d64481b24e..a29b987f4e 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -89,7 +89,7 @@ def clear_torch_cache(): if not shared.args.cpu: torch.cuda.empty_cache() -def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None): +def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None): clear_torch_cache() t0 = time.time() @@ -101,7 +101,8 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k) yield formatted_outputs(reply, shared.model_name) else: - yield formatted_outputs(question, shared.model_name) + if not (shared.args.chat or shared.args.cai_chat): + yield formatted_outputs(question, shared.model_name) # RWKV has proper streaming, which is very nice. # No need to generate 8 tokens at a time. for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k): @@ -122,7 +123,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi input_ids = encode(question, max_new_tokens) original_input_ids = input_ids output = input_ids[0] - cuda = "" if (shared.args.cpu or shared.args.deepspeed or shared.args.flexgen) else ".cuda()" + cuda = not any((shared.args.cpu, shared.args.deepspeed, shared.args.flexgen)) eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else [] if eos_token is not None: eos_token_ids.append(int(encode(eos_token)[0][-1])) @@ -132,45 +133,49 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi t = encode(stopping_string, 0, add_special_tokens=False) stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=t, starting_idx=len(input_ids[0]))) + generate_params = {} if not shared.args.flexgen: - generate_params = [ - f"max_new_tokens=max_new_tokens", - f"eos_token_id={eos_token_ids}", - f"stopping_criteria=stopping_criteria_list", - f"do_sample={do_sample}", - f"temperature={temperature}", - f"top_p={top_p}", - f"typical_p={typical_p}", - f"repetition_penalty={repetition_penalty}", - f"top_k={top_k}", - f"min_length={min_length if shared.args.no_stream else 0}", - f"no_repeat_ngram_size={no_repeat_ngram_size}", - f"num_beams={num_beams}", - f"penalty_alpha={penalty_alpha}", - f"length_penalty={length_penalty}", - f"early_stopping={early_stopping}", - ] + generate_params.update({ + "max_new_tokens": max_new_tokens, + "eos_token_id": eos_token_ids, + "stopping_criteria": stopping_criteria_list, + "do_sample": do_sample, + "temperature": temperature, + "top_p": top_p, + "typical_p": typical_p, + "repetition_penalty": repetition_penalty, + "encoder_repetition_penalty": encoder_repetition_penalty, + "top_k": top_k, + "min_length": min_length if shared.args.no_stream else 0, + "no_repeat_ngram_size": no_repeat_ngram_size, + "num_beams": num_beams, + "penalty_alpha": penalty_alpha, + "length_penalty": length_penalty, + "early_stopping": early_stopping, + }) else: - generate_params = [ - f"max_new_tokens={max_new_tokens if shared.args.no_stream else 8}", - f"do_sample={do_sample}", - f"temperature={temperature}", - f"stop={eos_token_ids[-1]}", - ] + generate_params.update({ + "max_new_tokens": max_new_tokens if shared.args.no_stream else 8, + "do_sample": do_sample, + "temperature": temperature, + "stop": eos_token_ids[-1], + }) if shared.args.deepspeed: - generate_params.append("synced_gpus=True") + generate_params.update({"synced_gpus": True}) if shared.soft_prompt: inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids) - generate_params.insert(0, "inputs_embeds=inputs_embeds") - generate_params.insert(0, "inputs=filler_input_ids") + generate_params.update({"inputs_embeds": inputs_embeds}) + generate_params.update({"inputs": filler_input_ids}) else: - generate_params.insert(0, "inputs=input_ids") + generate_params.update({"inputs": input_ids}) try: # Generate the entire reply at once. if shared.args.no_stream: with torch.no_grad(): - output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0] + output = shared.model.generate(**generate_params)[0] + if cuda: + output = output.cuda() if shared.soft_prompt: output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:])) @@ -193,8 +198,9 @@ def generate_with_callback(callback=None, **kwargs): def generate_with_streaming(**kwargs): return Iteratorize(generate_with_callback, kwargs, callback=None) - yield formatted_outputs(original_question, shared.model_name) - with eval(f"generate_with_streaming({', '.join(generate_params)})") as generator: + if not (shared.args.chat or shared.args.cai_chat): + yield formatted_outputs(original_question, shared.model_name) + with generate_with_streaming(**generate_params) as generator: for output in generator: if shared.soft_prompt: output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:])) @@ -214,7 +220,7 @@ def generate_with_streaming(**kwargs): for i in range(max_new_tokens//8+1): clear_torch_cache() with torch.no_grad(): - output = eval(f"shared.model.generate({', '.join(generate_params)})")[0] + output = shared.model.generate(**generate_params)[0] if shared.soft_prompt: output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:])) reply = decode(output) diff --git a/modules/ui.py b/modules/ui.py index bb193e35c1..80bd7c1cde 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1,68 +1,17 @@ +from pathlib import Path + import gradio as gr refresh_symbol = '\U0001f504' # 🔄 -css = """ -.tabs.svelte-710i53 { - margin-top: 0 -} -.py-6 { - padding-top: 2.5rem -} -.dark #refresh-button { - background-color: #ffffff1f; -} -#refresh-button { - flex: none; - margin: 0; - padding: 0; - min-width: 50px; - border: none; - box-shadow: none; - border-radius: 10px; - background-color: #0000000d; -} -#download-label, #upload-label { - min-height: 0 -} -#accordion { -} -.dark svg { - fill: white; -} -svg { - display: unset !important; - vertical-align: middle !important; - margin: 5px; -} -ol li p, ul li p { - display: inline-block; -} -""" - -chat_css = """ -.h-\[40vh\], .wrap.svelte-byatnx.svelte-byatnx.svelte-byatnx { - height: 66.67vh -} -.gradio-container { - max-width: 800px !important; - margin-left: auto !important; - margin-right: auto !important; -} -.w-screen { - width: unset -} -div.svelte-362y77>*, div.svelte-362y77>.form>* { - flex-wrap: nowrap -} -/* fixes the API documentation in chat mode */ -.api-docs.svelte-1iguv9h.svelte-1iguv9h.svelte-1iguv9h { - display: grid; -} -.pending.svelte-1ed2p3z { - opacity: 1; -} -""" +with open(Path(__file__).resolve().parent / '../css/main.css', 'r') as f: + css = f.read() +with open(Path(__file__).resolve().parent / '../css/chat.css', 'r') as f: + chat_css = f.read() +with open(Path(__file__).resolve().parent / '../css/main.js', 'r') as f: + main_js = f.read() +with open(Path(__file__).resolve().parent / '../css/chat.js', 'r') as f: + chat_js = f.read() class ToolButton(gr.Button, gr.components.FormComponent): """Small button with single emoji as text, fits inside gradio forms""" diff --git a/requirements.txt b/requirements.txt index 8534be29a7..014d987ba7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,11 @@ -accelerate==0.17.0 -bitsandbytes==0.37.0 +accelerate==0.17.1 +bitsandbytes==0.37.1 flexgen==0.1.7 gradio==3.18.0 +markdown numpy requests -rwkv==0.3.1 +rwkv==0.4.2 safetensors==0.3.0 sentencepiece tqdm diff --git a/server.py b/server.py index 6ae6d9f373..f0187e88e0 100644 --- a/server.py +++ b/server.py @@ -34,7 +34,7 @@ def get_available_models(): if shared.args.flexgen: return sorted([re.sub('-np$', '', item.name) for item in list(Path('models/').glob('*')) if item.name.endswith('-np')], key=str.lower) else: - return sorted([item.name for item in list(Path('models/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt'))], key=str.lower) + return sorted([re.sub('.pth$', '', item.name) for item in list(Path('models/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower) def get_available_presets(): return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower) @@ -67,6 +67,7 @@ def load_preset_values(preset_menu, return_dict=False): 'top_p': 1, 'typical_p': 1, 'repetition_penalty': 1, + 'encoder_repetition_penalty': 1, 'top_k': 50, 'num_beams': 1, 'penalty_alpha': 0, @@ -87,7 +88,7 @@ def load_preset_values(preset_menu, return_dict=False): if return_dict: return generate_params else: - return generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['penalty_alpha'], generate_params['length_penalty'], generate_params['early_stopping'] + return generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['encoder_repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['penalty_alpha'], generate_params['length_penalty'], generate_params['early_stopping'] def upload_soft_prompt(file): with zipfile.ZipFile(io.BytesIO(file)) as zf: @@ -101,9 +102,7 @@ def upload_soft_prompt(file): return name -def create_settings_menus(default_preset): - generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', return_dict=True) - +def create_model_and_preset_menus(): with gr.Row(): with gr.Column(): with gr.Row(): @@ -114,31 +113,40 @@ def create_settings_menus(default_preset): shared.gradio['preset_menu'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset') ui.create_refresh_button(shared.gradio['preset_menu'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button') - with gr.Accordion('Custom generation parameters', open=False, elem_id='accordion'): - with gr.Row(): - with gr.Column(): - shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label='temperature') - shared.gradio['repetition_penalty'] = gr.Slider(1.0, 2.99, value=generate_params['repetition_penalty'],step=0.01,label='repetition_penalty') - shared.gradio['top_k'] = gr.Slider(0,200,value=generate_params['top_k'],step=1,label='top_k') - shared.gradio['top_p'] = gr.Slider(0.0,1.0,value=generate_params['top_p'],step=0.01,label='top_p') - with gr.Column(): - shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample') - shared.gradio['typical_p'] = gr.Slider(0.0,1.0,value=generate_params['typical_p'],step=0.01,label='typical_p') - shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1, value=generate_params['no_repeat_ngram_size'], label='no_repeat_ngram_size') - shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'] if shared.args.no_stream else 0, label='min_length', interactive=shared.args.no_stream) +def create_settings_menus(default_preset): + generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', return_dict=True) - gr.Markdown('Contrastive search:') - shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha') + with gr.Row(): + with gr.Column(): + with gr.Box(): + gr.Markdown('Custom generation parameters') + with gr.Row(): + with gr.Column(): + shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label='temperature') + shared.gradio['top_p'] = gr.Slider(0.0,1.0,value=generate_params['top_p'],step=0.01,label='top_p') + shared.gradio['top_k'] = gr.Slider(0,200,value=generate_params['top_k'],step=1,label='top_k') + shared.gradio['typical_p'] = gr.Slider(0.0,1.0,value=generate_params['typical_p'],step=0.01,label='typical_p') + with gr.Column(): + shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'],step=0.01,label='repetition_penalty') + shared.gradio['encoder_repetition_penalty'] = gr.Slider(0.8, 1.5, value=generate_params['encoder_repetition_penalty'],step=0.01,label='encoder_repetition_penalty') + shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1, value=generate_params['no_repeat_ngram_size'], label='no_repeat_ngram_size') + shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'] if shared.args.no_stream else 0, label='min_length', interactive=shared.args.no_stream) + shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample') + with gr.Column(): + with gr.Box(): + gr.Markdown('Contrastive search') + shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha') - gr.Markdown('Beam search (uses a lot of VRAM):') - with gr.Row(): - with gr.Column(): - shared.gradio['num_beams'] = gr.Slider(1, 20, step=1, value=generate_params['num_beams'], label='num_beams') - with gr.Column(): - shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty') - shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping') + with gr.Box(): + gr.Markdown('Beam search (uses a lot of VRAM)') + with gr.Row(): + with gr.Column(): + shared.gradio['num_beams'] = gr.Slider(1, 20, step=1, value=generate_params['num_beams'], label='num_beams') + with gr.Column(): + shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty') + shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping') - with gr.Accordion('Soft prompt', open=False, elem_id='accordion'): + with gr.Accordion('Soft prompt', open=False): with gr.Row(): shared.gradio['softprompts_menu'] = gr.Dropdown(choices=available_softprompts, value='None', label='Soft prompt') ui.create_refresh_button(shared.gradio['softprompts_menu'], lambda : None, lambda : {'choices': get_available_softprompts()}, 'refresh-button') @@ -148,7 +156,7 @@ def create_settings_menus(default_preset): shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip']) shared.gradio['model_menu'].change(load_model_wrapper, [shared.gradio['model_menu']], [shared.gradio['model_menu']], show_progress=True) - shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio['do_sample'], shared.gradio['temperature'], shared.gradio['top_p'], shared.gradio['typical_p'], shared.gradio['repetition_penalty'], shared.gradio['top_k'], shared.gradio['min_length'], shared.gradio['no_repeat_ngram_size'], shared.gradio['num_beams'], shared.gradio['penalty_alpha'], shared.gradio['length_penalty'], shared.gradio['early_stopping']]) + shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio['do_sample'], shared.gradio['temperature'], shared.gradio['top_p'], shared.gradio['typical_p'], shared.gradio['repetition_penalty'], shared.gradio['encoder_repetition_penalty'], shared.gradio['top_k'], shared.gradio['min_length'], shared.gradio['no_repeat_ngram_size'], shared.gradio['num_beams'], shared.gradio['penalty_alpha'], shared.gradio['length_penalty'], shared.gradio['early_stopping']]) shared.gradio['softprompts_menu'].change(load_soft_prompt, [shared.gradio['softprompts_menu']], [shared.gradio['softprompts_menu']], show_progress=True) shared.gradio['upload_softprompt'].upload(upload_soft_prompt, [shared.gradio['upload_softprompt']], [shared.gradio['softprompts_menu']]) @@ -201,26 +209,39 @@ def create_settings_menus(default_preset): if shared.args.chat or shared.args.cai_chat: with gr.Blocks(css=ui.css+ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']: - if shared.args.cai_chat: - shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'], shared.character)) - else: - shared.gradio['display'] = gr.Chatbot(value=shared.history['visible']).style(color_map=("#326efd", "#212528")) - shared.gradio['textbox'] = gr.Textbox(label='Input') - with gr.Row(): - shared.gradio['Stop'] = gr.Button('Stop') - shared.gradio['Generate'] = gr.Button('Generate') - with gr.Row(): - shared.gradio['Impersonate'] = gr.Button('Impersonate') - shared.gradio['Regenerate'] = gr.Button('Regenerate') - with gr.Row(): - shared.gradio['Copy last reply'] = gr.Button('Copy last reply') - shared.gradio['Replace last reply'] = gr.Button('Replace last reply') - shared.gradio['Remove last'] = gr.Button('Remove last') - - shared.gradio['Clear history'] = gr.Button('Clear history') - shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False) - shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False) - with gr.Tab('Chat settings'): + with gr.Tab("Text generation", elem_id="main"): + if shared.args.cai_chat: + shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'], shared.character)) + else: + shared.gradio['display'] = gr.Chatbot(value=shared.history['visible']).style(color_map=("#326efd", "#212528")) + shared.gradio['textbox'] = gr.Textbox(label='Input') + with gr.Row(): + shared.gradio['Stop'] = gr.Button('Stop', elem_id="stop") + shared.gradio['Generate'] = gr.Button('Generate') + with gr.Row(): + shared.gradio['Impersonate'] = gr.Button('Impersonate') + shared.gradio['Regenerate'] = gr.Button('Regenerate') + with gr.Row(): + shared.gradio['Copy last reply'] = gr.Button('Copy last reply') + shared.gradio['Replace last reply'] = gr.Button('Replace last reply') + shared.gradio['Remove last'] = gr.Button('Remove last') + + shared.gradio['Clear history'] = gr.Button('Clear history') + shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False) + shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False) + + create_model_and_preset_menus() + + with gr.Box(): + with gr.Row(): + with gr.Column(): + shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens']) + shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size']) + with gr.Column(): + shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)') + shared.gradio['check'] = gr.Checkbox(value=shared.settings[f'stop_at_newline{suffix}'], label='Stop generating at new line character?') + + with gr.Tab("Chat settings", elem_id="chat-settings"): shared.gradio['name1'] = gr.Textbox(value=shared.settings[f'name1{suffix}'], lines=1, label='Your name') shared.gradio['name2'] = gr.Textbox(value=shared.settings[f'name2{suffix}'], lines=1, label='Bot\'s name') shared.gradio['context'] = gr.Textbox(value=shared.settings[f'context{suffix}'], lines=5, label='Context') @@ -228,8 +249,6 @@ def create_settings_menus(default_preset): shared.gradio['character_menu'] = gr.Dropdown(choices=available_characters, value='None', label='Character', elem_id='character-menu') ui.create_refresh_button(shared.gradio['character_menu'], lambda : None, lambda : {'choices': get_available_characters()}, 'refresh-button') - with gr.Row(): - shared.gradio['check'] = gr.Checkbox(value=shared.settings[f'stop_at_newline{suffix}'], label='Stop generating at new line character?') with gr.Row(): with gr.Tab('Chat history'): with gr.Row(): @@ -254,23 +273,16 @@ def create_settings_menus(default_preset): with gr.Tab('Upload TavernAI Character Card'): shared.gradio['upload_img_tavern'] = gr.File(type='binary', file_types=['image']) - with gr.Tab('Generation settings'): - with gr.Row(): - with gr.Column(): - shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens']) - with gr.Column(): - shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size']) - shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)') + with gr.Tab("Settings", elem_id="settings"): create_settings_menus(default_preset) - shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'chat_generation_attempts']] if shared.args.extensions is not None: - with gr.Tab('Extensions'): - extensions_module.create_extensions_block() + extensions_module.create_extensions_block() function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper' + shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'chat_generation_attempts']] - gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream, api_name='textgen')) + gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream)) @@ -309,65 +321,75 @@ def create_settings_menus(default_preset): shared.gradio['upload_img_me'].upload(reload_func, reload_inputs, [shared.gradio['display']]) shared.gradio['Stop'].click(reload_func, reload_inputs, [shared.gradio['display']]) + shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}") shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}']), None, None) shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True) elif shared.args.notebook: with gr.Blocks(css=ui.css, analytics_enabled=False, title=title) as shared.gradio['interface']: - gr.Markdown(description) - with gr.Tab('Raw'): - shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=23) - with gr.Tab('Markdown'): - shared.gradio['markdown'] = gr.Markdown() - with gr.Tab('HTML'): - shared.gradio['html'] = gr.HTML() - - shared.gradio['Generate'] = gr.Button('Generate') - shared.gradio['Stop'] = gr.Button('Stop') - shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens']) - - create_settings_menus(default_preset) + with gr.Tab("Text generation", elem_id="main"): + with gr.Tab('Raw'): + shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=25) + with gr.Tab('Markdown'): + shared.gradio['markdown'] = gr.Markdown() + with gr.Tab('HTML'): + shared.gradio['html'] = gr.HTML() + + with gr.Row(): + shared.gradio['Stop'] = gr.Button('Stop') + shared.gradio['Generate'] = gr.Button('Generate') + shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens']) + + create_model_and_preset_menus() + with gr.Tab("Settings", elem_id="settings"): + create_settings_menus(default_preset) + if shared.args.extensions is not None: extensions_module.create_extensions_block() - shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']] + shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']] output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']] gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen')) gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)) shared.gradio['Stop'].click(None, None, None, cancels=gen_events) + shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}") else: with gr.Blocks(css=ui.css, analytics_enabled=False, title=title) as shared.gradio['interface']: - gr.Markdown(description) - with gr.Row(): - with gr.Column(): - shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=15, label='Input') - shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens']) - shared.gradio['Generate'] = gr.Button('Generate') - with gr.Row(): - with gr.Column(): - shared.gradio['Continue'] = gr.Button('Continue') - with gr.Column(): - shared.gradio['Stop'] = gr.Button('Stop') + with gr.Tab("Text generation", elem_id="main"): + with gr.Row(): + with gr.Column(): + shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=15, label='Input') + shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens']) + shared.gradio['Generate'] = gr.Button('Generate') + with gr.Row(): + with gr.Column(): + shared.gradio['Continue'] = gr.Button('Continue') + with gr.Column(): + shared.gradio['Stop'] = gr.Button('Stop') - create_settings_menus(default_preset) - if shared.args.extensions is not None: - extensions_module.create_extensions_block() + create_model_and_preset_menus() - with gr.Column(): - with gr.Tab('Raw'): - shared.gradio['output_textbox'] = gr.Textbox(lines=15, label='Output') - with gr.Tab('Markdown'): - shared.gradio['markdown'] = gr.Markdown() - with gr.Tab('HTML'): - shared.gradio['html'] = gr.HTML() + with gr.Column(): + with gr.Tab('Raw'): + shared.gradio['output_textbox'] = gr.Textbox(lines=25, label='Output') + with gr.Tab('Markdown'): + shared.gradio['markdown'] = gr.Markdown() + with gr.Tab('HTML'): + shared.gradio['html'] = gr.HTML() + with gr.Tab("Settings", elem_id="settings"): + create_settings_menus(default_preset) + + if shared.args.extensions is not None: + extensions_module.create_extensions_block() - shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']] + shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']] output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']] gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen')) gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream)) shared.gradio['Stop'].click(None, None, None, cancels=gen_events) + shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}") shared.gradio['interface'].queue() if shared.args.listen: