Skip to content

Commit

Permalink
update sparkdesk
Browse files Browse the repository at this point in the history
  • Loading branch information
zmh-program committed Oct 31, 2023
1 parent 14b04d9 commit e74902a
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 9 deletions.
33 changes: 30 additions & 3 deletions adapter/sparkdesk/struct.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package sparkdesk

import (
"chat/globals"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
Expand All @@ -19,13 +20,39 @@ type ChatInstance struct {
Endpoint string
}

func NewChatInstance() *ChatInstance {
func TransformAddr(model string) string {
switch model {
case globals.SparkDesk:
return "v1.1"
case globals.SparkDeskV2:
return "v2.1"
case globals.SparkDeskV3:
return "v3.1"
default:
return "v1.1"
}
}

func TransformModel(model string) string {
switch model {
case globals.SparkDesk:
return "general"
case globals.SparkDeskV2:
return "generalv2"
case globals.SparkDeskV3:
return "generalv3"
default:
return "general"
}
}

func NewChatInstance(model string) *ChatInstance {
return &ChatInstance{
AppId: viper.GetString("sparkdesk.app_id"),
ApiSecret: viper.GetString("sparkdesk.api_secret"),
ApiKey: viper.GetString("sparkdesk.api_key"),
Model: viper.GetString("sparkdesk.model"),
Endpoint: viper.GetString("sparkdesk.endpoint"),
Model: TransformModel(model),
Endpoint: fmt.Sprintf("%s/%s/chat", viper.GetString("sparkdesk.endpoint"), TransformAddr(model)),
}
}

Expand Down
2 changes: 1 addition & 1 deletion auth/rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ func CanEnableModel(db *sql.DB, user *User, model string) bool {
return user != nil && user.GetQuota(db) >= 5
case globals.GPT432k, globals.GPT432k0613, globals.GPT432k0314:
return user != nil && user.GetQuota(db) >= 50
case globals.SparkDesk:
case globals.SparkDesk, globals.SparkDeskV2, globals.SparkDeskV3:
return user != nil && user.GetQuota(db) >= 1
case globals.Claude2100k:
return user != nil && user.GetQuota(db) >= 1
Expand Down
14 changes: 12 additions & 2 deletions globals/variables.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ const (
Claude2 = "claude-1" // claude v1.3
Claude2100k = "claude-2"
ClaudeSlack = "claude-slack"
SparkDesk = "spark-desk"
SparkDesk = "spark-desk-v1.5"
SparkDeskV2 = "spark-desk-v2"
SparkDeskV3 = "spark-desk-v3"
ChatBison001 = "chat-bison-001"
BingCreative = "bing-creative"
BingBalanced = "bing-balanced"
Expand Down Expand Up @@ -105,6 +107,12 @@ var ZhiPuModelArray = []string{
ZhiPuChatGLMLite,
}

var SparkDeskModelArray = []string{
SparkDesk,
SparkDeskV2,
SparkDeskV3,
}

var LongContextModelArray = []string{
GPT3Turbo16k,
GPT3Turbo16k0613,
Expand Down Expand Up @@ -149,6 +157,8 @@ var AllModels = []string{
Claude2100k,
ClaudeSlack,
SparkDesk,
SparkDeskV2,
SparkDeskV3,
ChatBison001,
BingCreative,
BingBalanced,
Expand Down Expand Up @@ -196,7 +206,7 @@ func IsSlackModel(model string) bool {
}

func IsSparkDeskModel(model string) bool {
return model == SparkDesk
return in(model, SparkDeskModelArray)
}

func IsPalm2Model(model string) bool {
Expand Down
12 changes: 9 additions & 3 deletions utils/tokenizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ func GetWeightByModel(model string) int {
globals.GPT4,
globals.GPT40314,
globals.GPT40613,
globals.SparkDesk:
globals.SparkDesk,
globals.SparkDeskV2,
globals.SparkDeskV3:
return 3
case globals.GPT3Turbo0301,
globals.GPT3Turbo16k0301,
Expand Down Expand Up @@ -109,7 +111,9 @@ func CountInputToken(model string, v []globals.Message) float32 {
case globals.GPT432k:
return float32(CountTokenPrice(v, model)) / 1000 * 4.2
case globals.SparkDesk:
return float32(CountTokenPrice(v, model)) / 1000 * 0.36
return float32(CountTokenPrice(v, model)) / 1000 * 0.15
case globals.SparkDeskV2, globals.SparkDeskV3:
return float32(CountTokenPrice(v, model)) / 1000 * 0.3
case globals.Claude2:
return 0
case globals.Claude2100k:
Expand All @@ -134,7 +138,9 @@ func CountOutputToken(model string, t int) float32 {
case globals.GPT432k:
return float32(t*GetWeightByModel(model)) / 1000 * 8.6
case globals.SparkDesk:
return float32(t*GetWeightByModel(model)) / 1000 * 0.36
return float32(t*GetWeightByModel(model)) / 1000 * 0.15
case globals.SparkDeskV2, globals.SparkDeskV3:
return float32(t*GetWeightByModel(model)) / 1000 * 0.3
case globals.Claude2:
return 0
case globals.Claude2100k:
Expand Down

0 comments on commit e74902a

Please sign in to comment.