From e31d69ae18f882c6c4ec69001f11d4540ce4eeb4 Mon Sep 17 00:00:00 2001 From: Zhang Minghan Date: Sat, 21 Oct 2023 22:59:17 +0800 Subject: [PATCH] update model selector --- app/src/assets/ui.less | 24 +++++ app/src/components/Message.tsx | 34 ++++--- app/src/components/SelectGroup.tsx | 118 +++++++++++++++++----- app/src/components/home/ModelSelector.tsx | 56 ++++++++++ app/src/conf.ts | 55 +++++++--- app/src/conversation/connection.ts | 9 +- app/src/conversation/conversation.ts | 12 ++- app/src/conversation/types.ts | 7 ++ app/src/events/connection.ts | 2 +- app/src/routes/Generation.tsx | 29 +----- app/src/routes/Home.tsx | 55 ++++------ app/src/store/chat.ts | 11 +- 12 files changed, 282 insertions(+), 130 deletions(-) create mode 100644 app/src/components/home/ModelSelector.tsx diff --git a/app/src/assets/ui.less b/app/src/assets/ui.less index 70912c56..e4439340 100644 --- a/app/src/assets/ui.less +++ b/app/src/assets/ui.less @@ -17,6 +17,9 @@ } .select-group-item { + display: flex; + flex-direction: row; + align-items: center; padding: 0.35rem 0.8rem; border: 1px solid hsl(var(--border)); border-radius: 4px; @@ -34,6 +37,27 @@ background: hsl(var(--text)); border-color: hsl(var(--border-hover)); color: hsl(var(--background)); + + .badge { + background: hsl(var(--background)) !important; + color: hsl(var(--text)); + + &:hover { + background: hsl(var(--background)) !important; + } + } + } + + .badge { + user-select: none; + transition: .2s; + padding-left: 0.45rem; + padding-right: 0.45rem; + background: hsl(var(--primary)) !important; + + &:hover { + background: hsl(var(--primary)) !important; + } } } } diff --git a/app/src/components/Message.tsx b/app/src/components/Message.tsx index 75518209..2a119764 100644 --- a/app/src/components/Message.tsx +++ b/app/src/components/Message.tsx @@ -6,7 +6,9 @@ import { Copy, File, Loader2, - MousePointerSquare, Power, RotateCcw, + MousePointerSquare, + Power, + RotateCcw, } from "lucide-react"; import { ContextMenu, @@ -165,21 +167,21 @@ function MessageContent({ message, end, onEvent }: MessageProps) { )} - { - (message.role === "assistant" && end === true) && ( -
- { - (message.end !== false) ? - ( - onEvent && onEvent("restart") - )} /> : - ( - onEvent && onEvent("stop") - )} /> - } -
- ) - } + {message.role === "assistant" && end === true && ( +
+ {message.end !== false ? ( + onEvent && onEvent("restart")} + /> + ) : ( + onEvent && onEvent("stop")} + /> + )} +
+ )} ); } diff --git a/app/src/components/SelectGroup.tsx b/app/src/components/SelectGroup.tsx index c0c7fef4..7900872a 100644 --- a/app/src/components/SelectGroup.tsx +++ b/app/src/components/SelectGroup.tsx @@ -7,51 +7,113 @@ import { } from "./ui/select"; import { mobile } from "../utils.ts"; import { useEffect, useState } from "react"; +import { Badge } from "./ui/badge.tsx"; + +export type SelectItemProps = { + name: string; + value: string; + badge?: string; + tag?: any; +}; type SelectGroupProps = { - current: string; - list: string[]; + current: SelectItemProps; + list: SelectItemProps[]; onChange?: (select: string) => void; + maxElements?: number; }; -function SelectGroup(props: SelectGroupProps) { - const [state, setState] = useState(mobile); - useEffect(() => { - window.addEventListener("resize", () => { - setState(mobile); - }); - }, []); +function GroupSelectItem(props: SelectItemProps) { + return ( + <> + {props.value} + {props.badge && {props.badge}} + + ); +} - return state ? ( +function SelectGroupDesktop(props: SelectGroupProps) { + const max: number = props.maxElements || 5; + const range = props.list.length > max ? max : props.list.length; + const display = props.list.slice(0, range); + const hidden = props.list.slice(range); + + return ( +
+ {display.map((select: SelectItemProps, idx: number) => ( +
props.onChange?.(select.name)} + className={`select-group-item ${ + select == props.current ? "active" : "" + }`} + > + +
+ ))} + + {props.list.length > max && ( + + )} +
+ ); +} + +function SelectGroupMobile(props: SelectGroupProps) { + return ( + ); +} + +function SelectGroup(props: SelectGroupProps) { + const [state, setState] = useState(mobile); + useEffect(() => { + window.addEventListener("resize", () => { + setState(mobile); + }); + }, []); + + return state ? ( + ) : ( -
- {props.list.map((select: string, idx: number) => ( -
props.onChange?.(select)} - className={`select-group-item ${ - select == props.current ? "active" : "" - }`} - > - {select} -
- ))} -
+ ); } diff --git a/app/src/components/home/ModelSelector.tsx b/app/src/components/home/ModelSelector.tsx new file mode 100644 index 00000000..9aee97b3 --- /dev/null +++ b/app/src/components/home/ModelSelector.tsx @@ -0,0 +1,56 @@ +import SelectGroup, { SelectItemProps } from "../SelectGroup.tsx"; +import { supportModels } from "../../conf.ts"; +import { selectModel, setModel } from "../../store/chat.ts"; +import { useTranslation } from "react-i18next"; +import { useDispatch, useSelector } from "react-redux"; +import { selectAuthenticated } from "../../store/auth.ts"; +import { useToast } from "../ui/use-toast.ts"; +import { useEffect } from "react"; +import { Model } from "../../conversation/types.ts"; + +function GetModel(name: string): Model { + return supportModels.find((model) => model.id === name) as Model; +} + +function ModelSelector() { + const { t } = useTranslation(); + const dispatch = useDispatch(); + const { toast } = useToast(); + + const model = useSelector(selectModel); + const auth = useSelector(selectAuthenticated); + + useEffect(() => { + if (auth && model === "GPT-3.5") dispatch(setModel("GPT-3.5-16k")); + }, [auth]); + + const list = supportModels.map( + (model: Model): SelectItemProps => ({ + name: model.id, + value: model.name, + badge: model.free ? "free" : undefined, + }), + ); + + return ( + item.name === model) as SelectItemProps} + list={list} + maxElements={6} + onChange={(value: string) => { + const model = GetModel(value); + console.debug(`[model] select model: ${model.name} (id: ${model.id})`); + + if (!auth && model.auth) { + toast({ + title: t("login-require"), + }); + return; + } + dispatch(setModel(value)); + }} + /> + ); +} + +export default ModelSelector; diff --git a/app/src/conf.ts b/app/src/conf.ts index b267b41a..ebae6b6d 100644 --- a/app/src/conf.ts +++ b/app/src/conf.ts @@ -1,6 +1,7 @@ import axios from "axios"; +import { Model } from "./conversation/types.ts"; -export const version = "3.4.5"; +export const version = "3.4.6"; export const deploy: boolean = true; export let rest_api: string = "http://localhost:8094"; export let ws_api: string = "ws://localhost:8094"; @@ -11,19 +12,45 @@ if (deploy) { } export const tokenField = deploy ? "token" : "token-dev"; -export const supportModels: string[] = [ - "GPT-3.5", - "GPT-3.5-16k", - "GPT-4", - "GPT-4-32k", - "Claude-2", - "Claude-2-100k", - "SparkDesk 讯飞星火", - "Palm2", - "New Bing", - "智谱 ChatGLM Pro", - "智谱 ChatGLM Std", - "智谱 ChatGLM Lite", +export const supportModels: Model[] = [ + // openai models + { id: "gpt-3.5-turbo", name: "GPT-3.5", free: true, auth: false }, + { id: "gpt-3.5-turbo-16k", name: "GPT-3.5-16k", free: true, auth: true }, + { id: "gpt-4", name: "GPT-4", free: false, auth: true }, + { id: "gpt-4-32k", name: "GPT-4-32k", free: false, auth: true }, + + // anthropic models + { id: "claude-1", name: "Claude-2", free: true, auth: false }, + { id: "claude-2", name: "Claude-2-100k", free: false, auth: true }, // not claude-2-100k + + // spark desk + { id: "spark-desk", name: "SparkDesk 讯飞星火", free: false, auth: true }, + + // google palm2 + { id: "chat-bison-001", name: "Palm2", free: true, auth: true }, + + // new bing + { id: "bing-creative", name: "New Bing", free: true, auth: true }, + + // zhipu models + { + id: "zhipu-chatglm-pro", + name: "智谱 ChatGLM Pro", + free: false, + auth: true, + }, + { + id: "zhipu-chatglm-std", + name: "智谱 ChatGLM Std", + free: false, + auth: true, + }, + { + id: "zhipu-chatglm-lite", + name: "智谱 ChatGLM Lite", + free: true, + auth: true, + }, ]; export const supportModelConvertor: Record = { diff --git a/app/src/conversation/connection.ts b/app/src/conversation/connection.ts index c72caf24..fdbda9b6 100644 --- a/app/src/conversation/connection.ts +++ b/app/src/conversation/connection.ts @@ -71,10 +71,11 @@ export class Connection { }, 500); } } catch { - if (t !== undefined) this.triggerCallback({ - message: t("request-failed"), - end: true, - }); + if (t !== undefined) + this.triggerCallback({ + message: t("request-failed"), + end: true, + }); } } diff --git a/app/src/conversation/conversation.ts b/app/src/conversation/conversation.ts index edf55a78..dc16ce34 100644 --- a/app/src/conversation/conversation.ts +++ b/app/src/conversation/conversation.ts @@ -1,7 +1,7 @@ import { ChatProps, Connection, StreamMessage } from "./connection.ts"; import { Message } from "./types.ts"; import { sharingEvent } from "../events/sharing.ts"; -import {connectionEvent} from "../events/connection.ts"; +import { connectionEvent } from "../events/connection.ts"; type ConversationCallback = (idx: number, message: Message[]) => void; @@ -34,7 +34,9 @@ export class Conversation { connectionEvent.addEventListener((ev) => { if (ev.id === this.id) { - console.debug(`[conversation] connection event (id: ${this.id}, event: ${ev.event})`); + console.debug( + `[conversation] connection event (id: ${this.id}, event: ${ev.event})`, + ); switch (ev.event) { case "stop": @@ -52,10 +54,12 @@ export class Conversation { break; default: - console.debug(`[conversation] unknown event: ${ev.event} (from: ${ev.id})`); + console.debug( + `[conversation] unknown event: ${ev.event} (from: ${ev.id})`, + ); } } - }) + }); } protected sendEvent(event: string, data?: string) { diff --git a/app/src/conversation/types.ts b/app/src/conversation/types.ts index b8e52be9..1a273b87 100644 --- a/app/src/conversation/types.ts +++ b/app/src/conversation/types.ts @@ -8,6 +8,13 @@ export type Message = { end?: boolean; }; +export type Model = { + id: string; + name: string; + free: boolean; + auth: boolean; +}; + export type Id = number; export type ConversationInstance = { diff --git a/app/src/events/connection.ts b/app/src/events/connection.ts index 95d1fc2e..c7c3141a 100644 --- a/app/src/events/connection.ts +++ b/app/src/events/connection.ts @@ -1,4 +1,4 @@ -import {EventCommitter} from "./struct.ts"; +import { EventCommitter } from "./struct.ts"; export type ConnectionEvent = { id: number; diff --git a/app/src/routes/Generation.tsx b/app/src/routes/Generation.tsx index be09baca..1755762f 100644 --- a/app/src/routes/Generation.tsx +++ b/app/src/routes/Generation.tsx @@ -1,18 +1,17 @@ import "../assets/generation.less"; -import { useDispatch, useSelector } from "react-redux"; -import { selectAuthenticated } from "../store/auth.ts"; +import { useSelector } from "react-redux"; import { useTranslation } from "react-i18next"; import { Button } from "../components/ui/button.tsx"; import { ChevronLeft, Cloud, FileDown, Send } from "lucide-react"; -import { rest_api, supportModelConvertor, supportModels } from "../conf.ts"; +import { rest_api, supportModelConvertor } from "../conf.ts"; import router from "../router.ts"; import { Input } from "../components/ui/input.tsx"; import { useEffect, useRef, useState } from "react"; -import SelectGroup from "../components/SelectGroup.tsx"; import { manager } from "../conversation/generation.ts"; import { useToast } from "../components/ui/use-toast.ts"; import { handleGenerationData } from "../utils.ts"; -import { selectModel, setModel } from "../store/chat.ts"; +import { selectModel } from "../store/chat.ts"; +import ModelSelector from "../components/home/ModelSelector.tsx"; type WrapperProps = { onSend?: (value: string, model: string) => boolean; @@ -20,7 +19,6 @@ type WrapperProps = { function Wrapper({ onSend }: WrapperProps) { const { t } = useTranslation(); - const dispatch = useDispatch(); const ref = useRef(null); const [stayed, setStayed] = useState(false); const [hash, setHash] = useState(""); @@ -28,14 +26,9 @@ function Wrapper({ onSend }: WrapperProps) { const [quota, setQuota] = useState(0); const model = useSelector(selectModel); const modelRef = useRef(model); - const auth = useSelector(selectAuthenticated); const { toast } = useToast(); - useEffect(() => { - if (auth && model === "GPT-3.5") dispatch(setModel("GPT-3.5-16k")); - }, [auth]); - function clear() { setData(""); setQuota(0); @@ -154,19 +147,7 @@ function Wrapper({ onSend }: WrapperProps) {
- { - if (!auth && value !== "GPT-3.5") { - toast({ - title: t("login-require"), - }); - return; - } - dispatch(setModel(value)); - }} - /> +
); diff --git a/app/src/routes/Home.tsx b/app/src/routes/Home.tsx index 414c4e47..674e77ae 100644 --- a/app/src/routes/Home.tsx +++ b/app/src/routes/Home.tsx @@ -22,7 +22,7 @@ import { import { useDispatch, useSelector } from "react-redux"; import type { RootState } from "../store"; import { selectAuthenticated, selectInit } from "../store/auth.ts"; -import { login, supportModels } from "../conf.ts"; +import { login } from "../conf.ts"; import { deleteConversation, toggleConversation, @@ -39,7 +39,7 @@ import { useEffectAsync, copyClipboard, } from "../utils.ts"; -import { toast, useToast } from "../components/ui/use-toast.ts"; +import { useToast } from "../components/ui/use-toast.ts"; import { ConversationInstance, Message } from "../conversation/types.ts"; import { selectCurrent, @@ -47,7 +47,6 @@ import { selectHistory, selectMessages, selectWeb, - setModel, setWeb, } from "../store/chat.ts"; import { @@ -66,10 +65,10 @@ import MessageSegment from "../components/Message.tsx"; import { setMenu } from "../store/menu.ts"; import FileProvider, { FileObject } from "../components/FileProvider.tsx"; import router from "../router.ts"; -import SelectGroup from "../components/SelectGroup.tsx"; import EditorProvider from "../components/EditorProvider.tsx"; import ConversationSegment from "../components/home/ConversationSegment.tsx"; -import {connectionEvent} from "../events/connection.ts"; +import { connectionEvent } from "../events/connection.ts"; +import ModelSelector from "../components/home/ModelSelector.tsx"; function SideBar() { const { t } = useTranslation(); @@ -353,21 +352,19 @@ function ChatInterface() { - { - messages.map((message, i) => - { - connectionEvent.emit({ - id: current, - event: e, - }); - }} - key={i} - /> - ) - } + {messages.map((message, i) => ( + { + connectionEvent.emit({ + id: current, + event: e, + }); + }} + key={i} + /> + ))} ); @@ -390,10 +387,6 @@ function ChatWrapper() { const target = useRef(null); manager.setDispatch(dispatch); - useEffect(() => { - if (auth && model === "GPT-3.5") dispatch(setModel("GPT-3.5-16k")); - }, [auth]); - function clearFile() { clearEvent?.(); } @@ -522,19 +515,7 @@ function ChatWrapper() {
- { - if (!auth && model !== "GPT-3.5") { - toast({ - title: t("login-require"), - }); - return; - } - dispatch(setModel(model)); - }} - /> +
diff --git a/app/src/store/chat.ts b/app/src/store/chat.ts index 49493a14..d4267363 100644 --- a/app/src/store/chat.ts +++ b/app/src/store/chat.ts @@ -1,8 +1,9 @@ import { createSlice } from "@reduxjs/toolkit"; -import { ConversationInstance } from "../conversation/types.ts"; +import {ConversationInstance, Model} from "../conversation/types.ts"; import { Message } from "../conversation/types.ts"; import { insertStart } from "../utils.ts"; import { RootState } from "./index.ts"; +import { supportModels } from "../conf.ts"; type initialStateType = { history: ConversationInstance[]; @@ -12,12 +13,17 @@ type initialStateType = { current: number; }; +function GetModel(model: string | undefined | null): string { + return model && supportModels.filter((item: Model) => item.id === model).length + ? model : supportModels[0].id; +} + const chatSlice = createSlice({ name: "chat", initialState: { history: [], messages: [], - model: "GPT-3.5", + model: GetModel(localStorage.getItem("model")), web: false, current: -1, } as initialStateType, @@ -44,6 +50,7 @@ const chatSlice = createSlice({ state.messages = action.payload as Message[]; }, setModel: (state, action) => { + localStorage.setItem("model", action.payload as string); state.model = action.payload as string; }, setWeb: (state, action) => {