Skip to content

Commit

Permalink
merge main
Browse files Browse the repository at this point in the history
  • Loading branch information
OPChips committed Nov 6, 2024
2 parents 3086a2f + 6ded4e9 commit 6667ee1
Show file tree
Hide file tree
Showing 11 changed files with 155 additions and 73 deletions.
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,14 @@ iflytek Api Key.

iflytek Api Secret.

### `CHATGLM_API_KEY` (optional)

ChatGLM Api Key.

### `CHATGLM_URL` (optional)

ChatGLM Api Url.

### `HIDE_USER_API_KEY` (optional)

> Default: Empty
Expand Down
7 changes: 7 additions & 0 deletions README_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,13 @@ ByteDance Api Url.

讯飞星火Api Secret.

### `CHATGLM_API_KEY` (可选)

ChatGLM Api Key.

### `CHATGLM_URL` (可选)

ChatGLM Api Url.


### `HIDE_USER_API_KEY` (可选)
Expand Down
4 changes: 2 additions & 2 deletions app/api/common.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import { NextRequest, NextResponse } from "next/server";
import { getServerSideConfig } from "../config/server";
import { OPENAI_BASE_URL, ServiceProvider } from "../constant";
import { isModelAvailableInServer } from "../utils/model";
import { cloudflareAIGatewayUrl } from "../utils/cloudflare";
import { getModelProvider, isModelAvailableInServer } from "../utils/model";

const serverConfig = getServerSideConfig();

Expand Down Expand Up @@ -71,7 +71,7 @@ export async function requestOpenai(req: NextRequest) {
.filter((v) => !!v && !v.startsWith("-") && v.includes(modelName))
.forEach((m) => {
const [fullName, displayName] = m.split("=");
const [_, providerName] = fullName.split("@");
const [_, providerName] = getModelProvider(fullName);
if (providerName === "azure" && !displayName) {
const [_, deployId] = (serverConfig?.azureUrl ?? "").split(
"deployments/",
Expand Down
85 changes: 48 additions & 37 deletions app/components/chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ import { createTTSPlayer } from "../utils/audio";
import { MsEdgeTTS, OUTPUT_FORMAT } from "../utils/ms_edge_tts";

import { isEmpty } from "lodash-es";
import { getModelProvider } from "../utils/model";

const localStorage = safeLocalStorage();

Expand Down Expand Up @@ -148,7 +149,8 @@ export function SessionConfigModel(props: { onClose: () => void }) {
text={Locale.Chat.Config.Reset}
onClick={async () => {
if (await showConfirm(Locale.Memory.ResetConfirm)) {
chatStore.updateCurrentSession(
chatStore.updateTargetSession(
session,
(session) => (session.memoryPrompt = ""),
);
}
Expand All @@ -173,7 +175,10 @@ export function SessionConfigModel(props: { onClose: () => void }) {
updateMask={(updater) => {
const mask = { ...session.mask };
updater(mask);
chatStore.updateCurrentSession((session) => (session.mask = mask));
chatStore.updateTargetSession(
session,
(session) => (session.mask = mask),
);
}}
shouldSyncFromGlobal
extraListItems={
Expand Down Expand Up @@ -345,12 +350,14 @@ export function PromptHints(props: {

function ClearContextDivider() {
const chatStore = useChatStore();
const session = chatStore.currentSession();

return (
<div
className={styles["clear-context"]}
onClick={() =>
chatStore.updateCurrentSession(
chatStore.updateTargetSession(
session,
(session) => (session.clearContextIndex = undefined),
)
}
Expand Down Expand Up @@ -460,6 +467,7 @@ export function ChatActions(props: {
const navigate = useNavigate();
const chatStore = useChatStore();
const pluginStore = usePluginStore();
const session = chatStore.currentSession();

// switch themes
const theme = config.theme;
Expand All @@ -476,10 +484,9 @@ export function ChatActions(props: {
const stopAll = () => ChatControllerPool.stopAll();

// switch model
const currentModel = chatStore.currentSession().mask.modelConfig.model;
const currentModel = session.mask.modelConfig.model;
const currentProviderName =
chatStore.currentSession().mask.modelConfig?.providerName ||
ServiceProvider.OpenAI;
session.mask.modelConfig?.providerName || ServiceProvider.OpenAI;
const allModels = useAllModels();
const models = useMemo(() => {
const filteredModels = allModels.filter((m) => m.available);
Expand Down Expand Up @@ -513,12 +520,9 @@ export function ChatActions(props: {
const dalle3Sizes: DalleSize[] = ["1024x1024", "1792x1024", "1024x1792"];
const dalle3Qualitys: DalleQuality[] = ["standard", "hd"];
const dalle3Styles: DalleStyle[] = ["vivid", "natural"];
const currentSize =
chatStore.currentSession().mask.modelConfig?.size ?? "1024x1024";
const currentQuality =
chatStore.currentSession().mask.modelConfig?.quality ?? "standard";
const currentStyle =
chatStore.currentSession().mask.modelConfig?.style ?? "vivid";
const currentSize = session.mask.modelConfig?.size ?? "1024x1024";
const currentQuality = session.mask.modelConfig?.quality ?? "standard";
const currentStyle = session.mask.modelConfig?.style ?? "vivid";

const isMobileScreen = useMobileScreen();

Expand All @@ -536,7 +540,7 @@ export function ChatActions(props: {
if (isUnavailableModel && models.length > 0) {
// show next model to default model if exist
let nextModel = models.find((model) => model.isDefault) || models[0];
chatStore.updateCurrentSession((session) => {
chatStore.updateTargetSession(session, (session) => {
session.mask.modelConfig.model = nextModel.name;
session.mask.modelConfig.providerName = nextModel?.provider
?.providerName as ServiceProvider;
Expand All @@ -547,7 +551,7 @@ export function ChatActions(props: {
: nextModel.name,
);
}
}, [chatStore, currentModel, models]);
}, [chatStore, currentModel, models, session]);

return (
<div className={styles["chat-input-actions"]}>
Expand Down Expand Up @@ -614,7 +618,7 @@ export function ChatActions(props: {
text={Locale.Chat.InputActions.Clear}
icon={<BreakIcon />}
onClick={() => {
chatStore.updateCurrentSession((session) => {
chatStore.updateTargetSession(session, (session) => {
if (session.clearContextIndex === session.messages.length) {
session.clearContextIndex = undefined;
} else {
Expand Down Expand Up @@ -645,8 +649,8 @@ export function ChatActions(props: {
onClose={() => setShowModelSelector(false)}
onSelection={(s) => {
if (s.length === 0) return;
const [model, providerName] = s[0].split("@");
chatStore.updateCurrentSession((session) => {
const [model, providerName] = getModelProvider(s[0]);
chatStore.updateTargetSession(session, (session) => {
session.mask.modelConfig.model = model as ModelType;
session.mask.modelConfig.providerName =
providerName as ServiceProvider;
Expand Down Expand Up @@ -684,7 +688,7 @@ export function ChatActions(props: {
onSelection={(s) => {
if (s.length === 0) return;
const size = s[0];
chatStore.updateCurrentSession((session) => {
chatStore.updateTargetSession(session, (session) => {
session.mask.modelConfig.size = size;
});
showToast(size);
Expand All @@ -711,7 +715,7 @@ export function ChatActions(props: {
onSelection={(q) => {
if (q.length === 0) return;
const quality = q[0];
chatStore.updateCurrentSession((session) => {
chatStore.updateTargetSession(session, (session) => {
session.mask.modelConfig.quality = quality;
});
showToast(quality);
Expand All @@ -738,7 +742,7 @@ export function ChatActions(props: {
onSelection={(s) => {
if (s.length === 0) return;
const style = s[0];
chatStore.updateCurrentSession((session) => {
chatStore.updateTargetSession(session, (session) => {
session.mask.modelConfig.style = style;
});
showToast(style);
Expand Down Expand Up @@ -769,7 +773,7 @@ export function ChatActions(props: {
}))}
onClose={() => setShowPluginSelector(false)}
onSelection={(s) => {
chatStore.updateCurrentSession((session) => {
chatStore.updateTargetSession(session, (session) => {
session.mask.plugin = s as string[];
});
}}
Expand Down Expand Up @@ -812,7 +816,8 @@ export function EditMessageModal(props: { onClose: () => void }) {
icon={<ConfirmIcon />}
key="ok"
onClick={() => {
chatStore.updateCurrentSession(
chatStore.updateTargetSession(
session,
(session) => (session.messages = messages),
);
props.onClose();
Expand All @@ -829,7 +834,8 @@ export function EditMessageModal(props: { onClose: () => void }) {
type="text"
value={session.topic}
onInput={(e) =>
chatStore.updateCurrentSession(
chatStore.updateTargetSession(
session,
(session) => (session.topic = e.currentTarget.value),
)
}
Expand Down Expand Up @@ -990,7 +996,8 @@ function _Chat() {
prev: () => chatStore.nextSession(-1),
next: () => chatStore.nextSession(1),
clear: () =>
chatStore.updateCurrentSession(
chatStore.updateTargetSession(
session,
(session) => (session.clearContextIndex = session.messages.length),
),
fork: () => chatStore.forkSession(),
Expand Down Expand Up @@ -1061,7 +1068,7 @@ function _Chat() {
};

useEffect(() => {
chatStore.updateCurrentSession((session) => {
chatStore.updateTargetSession(session, (session) => {
const stopTiming = Date.now() - REQUEST_TIMEOUT_MS;
session.messages.forEach((m) => {
// check if should stop all stale messages
Expand All @@ -1087,7 +1094,7 @@ function _Chat() {
}
});
// eslint-disable-next-line react-hooks/exhaustive-deps
}, []);
}, [session]);

// check if should send message
const onInputKeyDown = (e: React.KeyboardEvent<HTMLTextAreaElement>) => {
Expand Down Expand Up @@ -1118,7 +1125,8 @@ function _Chat() {
};

const deleteMessage = (msgId?: string) => {
chatStore.updateCurrentSession(
chatStore.updateTargetSession(
session,
(session) =>
(session.messages = session.messages.filter((m) => m.id !== msgId)),
);
Expand Down Expand Up @@ -1185,7 +1193,7 @@ function _Chat() {
};

const onPinMessage = (message: ChatMessage) => {
chatStore.updateCurrentSession((session) =>
chatStore.updateTargetSession(session, (session) =>
session.mask.context.push(message),
);

Expand Down Expand Up @@ -1607,7 +1615,7 @@ function _Chat() {
title={Locale.Chat.Actions.RefreshTitle}
onClick={() => {
showToast(Locale.Chat.Actions.RefreshToast);
chatStore.summarizeSession(true);
chatStore.summarizeSession(true, session);
}}
/>
</div>
Expand Down Expand Up @@ -1711,14 +1719,17 @@ function _Chat() {
});
}
}
chatStore.updateCurrentSession((session) => {
const m = session.mask.context
.concat(session.messages)
.find((m) => m.id === message.id);
if (m) {
m.content = newContent;
}
});
chatStore.updateTargetSession(
session,
(session) => {
const m = session.mask.context
.concat(session.messages)
.find((m) => m.id === message.id);
if (m) {
m.content = newContent;
}
},
);
}}
></IconButton>
</div>
Expand Down
9 changes: 7 additions & 2 deletions app/components/model-config.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { ListItem, Select } from "./ui-lib";
import { useAllModels } from "../utils/hooks";
import { groupBy } from "lodash-es";
import styles from "./model-config.module.scss";
import { getModelProvider } from "../utils/model";

export function ModelConfigList(props: {
modelConfig: ModelConfig;
Expand All @@ -28,7 +29,9 @@ export function ModelConfigList(props: {
value={value}
align="left"
onChange={(e) => {
const [model, providerName] = e.currentTarget.value.split("@");
const [model, providerName] = getModelProvider(
e.currentTarget.value,
);
props.updateConfig((config) => {
config.model = ModalConfigValidator.model(model);
config.providerName = providerName as ServiceProvider;
Expand Down Expand Up @@ -247,7 +250,9 @@ export function ModelConfigList(props: {
aria-label={Locale.Settings.CompressModel.Title}
value={compressModelValue}
onChange={(e) => {
const [model, providerName] = e.currentTarget.value.split("@");
const [model, providerName] = getModelProvider(
e.currentTarget.value,
);
props.updateConfig((config) => {
config.compressModel = ModalConfigValidator.model(model);
config.compressProviderName = providerName as ServiceProvider;
Expand Down
5 changes: 3 additions & 2 deletions app/constant.ts
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ export const XAI = {

export const ChatGLM = {
ExampleEndpoint: CHATGLM_BASE_URL,
ChatPath: "/api/paas/v4/chat/completions",
ChatPath: "api/paas/v4/chat/completions",
};

export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang
Expand Down Expand Up @@ -327,12 +327,13 @@ const anthropicModels = [
"claude-2.1",
"claude-3-sonnet-20240229",
"claude-3-opus-20240229",
"claude-3-opus-latest",
"claude-3-haiku-20240307",
"claude-3-5-haiku-20241022",
"claude-3-5-sonnet-20240620",
"claude-3-5-sonnet-20241022",
"claude-3-5-sonnet-latest",
"claude-3-opus-latest",
"claude-3-5-haiku-latest",
];

const baiduModels = [
Expand Down
5 changes: 3 additions & 2 deletions app/store/access.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import { getClientConfig } from "../config/client";
import { createPersistStore } from "../utils/store";
import { ensure } from "../utils/clone";
import { DEFAULT_CONFIG } from "./config";
import { getModelProvider } from "../utils/model";

let fetchState = 0; // 0 not fetch, 1 fetching, 2 done

Expand Down Expand Up @@ -226,9 +227,9 @@ export const useAccessStore = createPersistStore(
.then((res) => {
const defaultModel = res.defaultModel ?? "";
if (defaultModel !== "") {
const [model, providerName] = defaultModel.split("@");
const [model, providerName] = getModelProvider(defaultModel);
DEFAULT_CONFIG.modelConfig.model = model;
DEFAULT_CONFIG.modelConfig.providerName = providerName;
DEFAULT_CONFIG.modelConfig.providerName = providerName as any;
}

return res;
Expand Down
Loading

0 comments on commit 6667ee1

Please sign in to comment.