Skip to content

Commit

Permalink
feat: add azure models and fix dashscope models
Browse files Browse the repository at this point in the history
  • Loading branch information
zmh-program committed Dec 29, 2023
1 parent 61c9177 commit 7ab1629
Show file tree
Hide file tree
Showing 8 changed files with 165 additions and 27 deletions.
25 changes: 18 additions & 7 deletions adapter/dashscope/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,31 @@ func (c *ChatInstance) GetHeader() map[string]string {
}
}

func (c *ChatInstance) FormatMessages(message []globals.Message) []Message {
var messages []Message
for _, v := range message {
if v.Role == globals.Tool {
continue
}

messages = append(messages, Message{
Role: v.Role,
Content: v.Content,
})
}

return messages
}

func (c *ChatInstance) GetChatBody(props *ChatProps) ChatRequest {
if props.Token <= 0 || props.Token > 1500 {
props.Token = 1500
}

return ChatRequest{
Model: strings.TrimSuffix(props.Model, "-net"),
Input: ChatInput{
Messages: utils.EachNotNil(props.Message, func(message globals.Message) *globals.Message {
if message.Role == globals.Tool {
return nil
}

return &message
}),
Messages: c.FormatMessages(props.Message),
},
Parameters: ChatParam{
MaxTokens: props.Token,
Expand Down
10 changes: 6 additions & 4 deletions adapter/dashscope/types.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
package dashscope

import "chat/globals"

// ChatRequest is the request body for dashscope
type ChatRequest struct {
Model string `json:"model"`
Input ChatInput `json:"input"`
Parameters ChatParam `json:"parameters"`
}

type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}

type ChatInput struct {
Prompt string `json:"prompt"`
Messages []globals.Message `json:"messages"`
Messages []Message `json:"messages"`
}

type ChatParam struct {
Expand Down
56 changes: 56 additions & 0 deletions app/src/conf.ts
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,49 @@ export const supportModels: Model[] = [
tag: ["official", "unstable", "image-generation"],
},

{
id: "azure-gpt-3.5-turbo",
name: "Azure GPT-3.5",
free: false,
auth: true,
tag: ["official"],
},
{
id: "azure-gpt-3.5-turbo-16k",
name: "Azure GPT-3.5 16K",
free: false,
auth: true,
tag: ["official"],
},
{
id: "azure-gpt-4",
name: "Azure GPT-4",
free: false,
auth: true,
tag: ["official", "high-quality"],
},
{
id: "azure-gpt-4-1106-preview",
name: "Azure GPT-4 Turbo 128k",
free: false,
auth: true,
tag: ["official", "high-context", "unstable"],
},
{
id: "azure-gpt-4-vision-preview",
name: "Azure GPT-4 Vision 128k",
free: false,
auth: true,
tag: ["official", "high-context", "multi-modal"],
},
{
id: "azure-gpt-4-32k",
name: "Azure GPT-4 32k",
free: false,
auth: true,
tag: ["official", "multi-modal"],
},

// spark desk
{
id: "spark-desk-v3",
Expand Down Expand Up @@ -352,6 +395,13 @@ export const defaultModels = [
"gpt-4-v",
"gpt-4-dalle",

"azure-gpt-3.5-turbo",
"azure-gpt-3.5-turbo-16k",
"azure-gpt-4",
"azure-gpt-4-1106-preview",
"azure-gpt-4-vision-preview",
"azure-gpt-4-32k",

"claude-1-100k",
"claude-2",
"claude-2.1",
Expand Down Expand Up @@ -414,6 +464,12 @@ export const modelAvatars: Record<string, string> = {
"gpt-4-32k-0613": "gpt432k.webp",
"gpt-4-v": "gpt4v.png",
"gpt-4-dalle": "gpt4dalle.png",
"azure-gpt-3.5-turbo": "gpt35turbo.png",
"azure-gpt-3.5-turbo-16k": "gpt35turbo16k.webp",
"azure-gpt-4": "gpt4.png",
"azure-gpt-4-1106-preview": "gpt432k.webp",
"azure-gpt-4-vision-preview": "gpt4v.png",
"azure-gpt-4-32k": "gpt432k.webp",
"claude-1-100k": "claude.png",
"claude-2": "claude100k.png",
"claude-2.1": "claude100k.png",
Expand Down
18 changes: 9 additions & 9 deletions manager/chat_completions.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
"time"
)

func TranshipmentAPI(c *gin.Context) {
func ChatRelayAPI(c *gin.Context) {
username := utils.GetUserFromContext(c)
if username == "" {
abortWithErrorResponse(c, fmt.Errorf("access denied for invalid api key"), "authentication_error")
Expand All @@ -28,7 +28,7 @@ func TranshipmentAPI(c *gin.Context) {
return
}

var form TranshipmentForm
var form RelayForm
if err := c.ShouldBindJSON(&form); err != nil {
abortWithErrorResponse(c, fmt.Errorf("invalid request body: %s", err.Error()), "invalid_request_error")
return
Expand Down Expand Up @@ -67,7 +67,7 @@ func TranshipmentAPI(c *gin.Context) {
}
}

func GetChatProps(form TranshipmentForm, messages []globals.Message, buffer *utils.Buffer, plan bool) *adapter.ChatProps {
func GetChatProps(form RelayForm, messages []globals.Message, buffer *utils.Buffer, plan bool) *adapter.ChatProps {
return &adapter.ChatProps{
Model: form.Model,
Message: messages,
Expand All @@ -85,7 +85,7 @@ func GetChatProps(form TranshipmentForm, messages []globals.Message, buffer *uti
}
}

func sendTranshipmentResponse(c *gin.Context, form TranshipmentForm, messages []globals.Message, id string, created int64, user *auth.User, plan bool) {
func sendTranshipmentResponse(c *gin.Context, form RelayForm, messages []globals.Message, id string, created int64, user *auth.User, plan bool) {
db := utils.GetDBFromContext(c)
cache := utils.GetCacheFromContext(c)

Expand All @@ -105,7 +105,7 @@ func sendTranshipmentResponse(c *gin.Context, form TranshipmentForm, messages []
}

CollectQuota(c, user, buffer, plan, err)
c.JSON(http.StatusOK, TranshipmentResponse{
c.JSON(http.StatusOK, RelayResponse{
Id: fmt.Sprintf("chatcmpl-%s", id),
Object: "chat.completion",
Created: created,
Expand All @@ -126,8 +126,8 @@ func sendTranshipmentResponse(c *gin.Context, form TranshipmentForm, messages []
})
}

func getStreamTranshipmentForm(id string, created int64, form TranshipmentForm, data string, buffer *utils.Buffer, end bool, err error) TranshipmentStreamResponse {
return TranshipmentStreamResponse{
func getStreamTranshipmentForm(id string, created int64, form RelayForm, data string, buffer *utils.Buffer, end bool, err error) RelayStreamResponse {
return RelayStreamResponse{
Id: fmt.Sprintf("chatcmpl-%s", id),
Object: "chat.completion.chunk",
Created: created,
Expand All @@ -152,8 +152,8 @@ func getStreamTranshipmentForm(id string, created int64, form TranshipmentForm,
}
}

func sendStreamTranshipmentResponse(c *gin.Context, form TranshipmentForm, messages []globals.Message, id string, created int64, user *auth.User, plan bool) {
partial := make(chan TranshipmentStreamResponse)
func sendStreamTranshipmentResponse(c *gin.Context, form RelayForm, messages []globals.Message, id string, created int64, user *auth.User, plan bool) {
partial := make(chan RelayStreamResponse)
db := utils.GetDBFromContext(c)
cache := utils.GetCacheFromContext(c)

Expand Down
56 changes: 56 additions & 0 deletions manager/images.go
Original file line number Diff line number Diff line change
@@ -1 +1,57 @@
package manager

import (
"chat/auth"
"chat/utils"
"fmt"
"github.com/gin-gonic/gin"
"strings"
"time"
)

func ImagesRelayAPI(c *gin.Context) {
username := utils.GetUserFromContext(c)
if username == "" {
abortWithErrorResponse(c, fmt.Errorf("access denied for invalid api key"), "authentication_error")
return
}

if utils.GetAgentFromContext(c) != "api" {
abortWithErrorResponse(c, fmt.Errorf("access denied for invalid agent"), "authentication_error")
return
}

var form RelayImageForm
if err := c.ShouldBindJSON(&form); err != nil {
abortWithErrorResponse(c, fmt.Errorf("invalid request body: %s", err.Error()), "invalid_request_error")
return
}

prompt := strings.TrimSpace(form.Prompt)
if prompt == "" {
sendErrorResponse(c, fmt.Errorf("prompt is required"), "invalid_request_error")
}

db := utils.GetDBFromContext(c)
user := &auth.User{
Username: username,
}

created := time.Now().Unix()

if strings.HasSuffix(form.Model, "-official") {
form.Model = strings.TrimSuffix(form.Model, "-official")
}

check := auth.CanEnableModel(db, user, form.Model)
if !check {
sendErrorResponse(c, fmt.Errorf("quota exceeded"), "quota_exceeded_error")
return
}

createRelayImageObject(c, form, prompt, created, user, false)
}

func createRelayImageObject(c *gin.Context, form RelayImageForm, prompt string, created int64, user *auth.User, plan bool) {

}
2 changes: 1 addition & 1 deletion manager/transhipment.go → manager/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func sendErrorResponse(c *gin.Context, err error, types ...string) {
errType = "chatnio_api_error"
}

c.JSON(http.StatusServiceUnavailable, TranshipmentErrorResponse{
c.JSON(http.StatusServiceUnavailable, RelayErrorResponse{
Error: TranshipmentError{
Message: err.Error(),
Type: errType,
Expand Down
2 changes: 1 addition & 1 deletion manager/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ func Register(app *gin.RouterGroup) {
app.GET("/v1/charge", ChargeAPI)
app.GET("/dashboard/billing/usage", GetBillingUsage)
app.GET("/dashboard/billing/subscription", GetSubscription)
app.POST("/v1/chat/completions", TranshipmentAPI)
app.POST("/v1/chat/completions", ChatRelayAPI)

broadcast.Register(app)
}
23 changes: 18 additions & 5 deletions manager/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ type MessageContent struct {

type MessageContents []MessageContent

type TranshipmentForm struct {
type RelayForm struct {
Model string `json:"model" binding:"required"`
Messages []Message `json:"messages" binding:"required"`
Stream bool `json:"stream"`
Expand Down Expand Up @@ -54,7 +54,7 @@ type Usage struct {
TotalTokens int `json:"total_tokens"`
}

type TranshipmentResponse struct {
type RelayResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Expand All @@ -70,7 +70,7 @@ type ChoiceDelta struct {
FinishReason interface{} `json:"finish_reason"`
}

type TranshipmentStreamResponse struct {
type RelayStreamResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Expand All @@ -81,7 +81,7 @@ type TranshipmentStreamResponse struct {
Error error `json:"error,omitempty"`
}

type TranshipmentErrorResponse struct {
type RelayErrorResponse struct {
Error TranshipmentError `json:"error"`
}

Expand All @@ -90,6 +90,19 @@ type TranshipmentError struct {
Type string `json:"type"`
}

type RelayImageForm struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
N *int `json:"n,omitempty"`
}

type RelayImageResponse struct {
Created int `json:"created"`
Data []struct {
Url string `json:"url"`
} `json:"data"`
}

func transformContent(content interface{}) string {
switch v := content.(type) {
case string:
Expand All @@ -100,7 +113,7 @@ func transformContent(content interface{}) string {
if data == nil || len(*data) == 0 {
return ""
}

for _, v := range *data {
if v.Text != nil {
result += *v.Text
Expand Down

0 comments on commit 7ab1629

Please sign in to comment.