Skip to content

Commit

Permalink
feat: exclude models
Browse files Browse the repository at this point in the history
  • Loading branch information
zmh-program committed Dec 12, 2023
1 parent 208e18c commit 58a493e
Show file tree
Hide file tree
Showing 7 changed files with 130 additions and 47 deletions.
9 changes: 6 additions & 3 deletions app/src/i18n.ts
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,8 @@ const resources = {
"mapper-tip": "模型名转换,实现非对称的模型请求",
"mapper-placeholder":
"请输入模型映射,一行一个,格式: model>model\n" +
"前者为请求的模型,后者为映射的模型(需要在模型中存在),中间用 > 分隔",
"前者为请求的模型,后者为映射的模型(需要在模型中存在),中间用 > 分隔\n" +
"格式前加!表示原模型不包含在此渠道的可用范围内,如: !gpt-4-slow>gpt-4,那么 gpt-4 将不会被涵盖在此渠道的可请求模型中",
state: "状态",
action: "操作",
edit: "编辑渠道",
Expand Down Expand Up @@ -829,7 +830,8 @@ const resources = {
"Model name conversion to achieve asymmetric model request",
"mapper-placeholder":
"Please enter the model mapper, one line each, format: model>model\n" +
"The former is the requested model, and the latter is the mapped model (which needs to exist in the model), separated by > in the middle",
"The former is the requested model, and the latter is the mapped model (which needs to exist in the model), separated by > in the middle\n" +
"The format is preceded by! Indicates that the original model is not included in the available range of this channel, such as: !gpt-4-slow>gpt-4, then gpt-4 will not be covered in the available models that can be requested in this channel",
state: "State",
action: "Action",
edit: "Edit Channel",
Expand Down Expand Up @@ -1273,7 +1275,8 @@ const resources = {
"Преобразование имени модели для достижения асимметричного запроса модели",
"mapper-placeholder":
"Введите модельный маппер, по одной строке, формат: model>model\n" +
"Первая модель - запрошенная модель, вторая модель - отображаемая модель (которая должна существовать в модели), разделенная > посередине",
"Первая модель - запрошенная модель, вторая модель - отображаемая модель (которая должна существовать в модели), разделенная > посередине\n" +
"Формат предшествует! Означает, что исходная модель не включена в доступный диапазон этого канала, например: !gpt-4-slow>gpt-4, тогда gpt-4 не будет охвачен в доступных моделях, которые можно запросить в этом канале",
state: "Статус",
action: "Действие",
edit: "Редактировать канал",
Expand Down
72 changes: 42 additions & 30 deletions channel/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,23 +87,55 @@ func (c *Channel) GetMapper() string {
return c.Mapper
}

func (c *Channel) GetReflect() map[string]string {
if c.Reflect == nil {
reflect := make(map[string]string)
arr := strings.Split(c.GetMapper(), "\n")
for _, item := range arr {
pair := strings.Split(item, ">")
if len(pair) == 2 {
reflect[pair[0]] = pair[1]
}
func (c *Channel) Load() {
reflect := make(map[string]string)
exclude := make([]string, 0)
models := c.GetModels()

arr := strings.Split(c.GetMapper(), "\n")
for _, item := range arr {
pair := strings.Split(item, ">")
if len(pair) != 2 {
continue
}

from, to := pair[0], pair[1]
if strings.HasPrefix(from, "!") {
from = strings.TrimPrefix(from, "!")
exclude = append(exclude, to)
}

reflect[from] = to
}

c.Reflect = &reflect
c.ExcludeModels = &exclude

var hits []string

for _, model := range models {
if !utils.Contains(model, hits) && !utils.Contains(model, exclude) {
hits = append(hits, model)
}
}

c.Reflect = &reflect
for model := range reflect {
if !utils.Contains(model, hits) && !utils.Contains(model, exclude) {
hits = append(hits, model)
}
}

c.HitModels = &hits
}

func (c *Channel) GetReflect() map[string]string {
return *c.Reflect
}

func (c *Channel) GetExcludeModels() []string {
return *c.ExcludeModels
}

// GetModelReflect returns the reflection model name if it exists, otherwise returns the original model name
func (c *Channel) GetModelReflect(model string) string {
ref := c.GetReflect()
Expand All @@ -115,26 +147,6 @@ func (c *Channel) GetModelReflect(model string) string {
}

func (c *Channel) GetHitModels() []string {
if c.HitModels == nil {
var res []string

models := c.GetModels()
ref := c.GetReflect()

for _, model := range models {
if !utils.Contains(model, res) {
res = append(res, model)
}
}

for model := range ref {
if !utils.Contains(model, res) {
res = append(res, model)
}
}

c.HitModels = &res
}
return *c.HitModels
}

Expand Down
7 changes: 7 additions & 0 deletions channel/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,13 @@ func NewManager() *Manager {
}

func (m *Manager) Load() {
// load channels
for _, channel := range m.Sequence {
if channel != nil {
channel.Load()
}
}

// init support models
m.Models = []string{}
for _, channel := range m.GetActiveSequence() {
Expand Down
27 changes: 14 additions & 13 deletions channel/types.go
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
package channel

type Channel struct {
Id int `json:"id" mapstructure:"id"`
Name string `json:"name" mapstructure:"name"`
Type string `json:"type" mapstructure:"type"`
Priority int `json:"priority" mapstructure:"priority"`
Weight int `json:"weight" mapstructure:"weight"`
Models []string `json:"models" mapstructure:"models"`
Retry int `json:"retry" mapstructure:"retry"`
Secret string `json:"secret" mapstructure:"secret"`
Endpoint string `json:"endpoint" mapstructure:"endpoint"`
Mapper string `json:"mapper" mapstructure:"mapper"`
State bool `json:"state" mapstructure:"state"`
Reflect *map[string]string `json:"-"`
HitModels *[]string `json:"-"`
Id int `json:"id" mapstructure:"id"`
Name string `json:"name" mapstructure:"name"`
Type string `json:"type" mapstructure:"type"`
Priority int `json:"priority" mapstructure:"priority"`
Weight int `json:"weight" mapstructure:"weight"`
Models []string `json:"models" mapstructure:"models"`
Retry int `json:"retry" mapstructure:"retry"`
Secret string `json:"secret" mapstructure:"secret"`
Endpoint string `json:"endpoint" mapstructure:"endpoint"`
Mapper string `json:"mapper" mapstructure:"mapper"`
State bool `json:"state" mapstructure:"state"`
Reflect *map[string]string `json:"-"`
HitModels *[]string `json:"-"`
ExcludeModels *[]string `json:"-"`
}

type Sequence []*Channel
Expand Down
16 changes: 15 additions & 1 deletion channel/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"chat/globals"
"chat/utils"
"fmt"
"github.com/spf13/viper"
)

func NewChatRequest(props *adapter.ChatProps, hook globals.Hook) error {
Expand All @@ -14,16 +15,29 @@ func NewChatRequest(props *adapter.ChatProps, hook globals.Hook) error {

ticker := ManagerInstance.GetTicker(props.Model)

debug := viper.GetBool("debug")

var err error
for !ticker.IsDone() {
if channel := ticker.Next(); channel != nil {
if debug {
fmt.Println(fmt.Sprintf("[channel] try channel %s for model %s", channel.GetName(), props.Model))
}

props.MaxRetries = utils.ToPtr(channel.GetRetry())
if err = adapter.NewChatRequest(channel, props, hook); err == nil {
if debug {
fmt.Println(fmt.Sprintf("[channel] hit channel %s for model %s", channel.GetName(), props.Model))
}

return nil
}
fmt.Println(fmt.Sprintf("[channel] hit error %s for model %s, goto next channel", err.Error(), props.Model))
fmt.Println(fmt.Sprintf("[channel] caught error %s for model %s at channel %s -> goto next channel", err.Error(), props.Model, channel.GetName()))
}
}

if debug {
fmt.Println(fmt.Sprintf("[channel] channels are exhausted for model %s", props.Model))
}
return err
}
38 changes: 38 additions & 0 deletions utils/char.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,41 @@ func DecodeUnicode(data string) string {
return string(rune(unicode))
})
}

func SortString(arr []string) []string {
// sort string array by first char
// e.g. ["a", "b", "c", "ab", "ac", "bc"] => ["a", "ab", "ac", "b", "bc", "c"]

if len(arr) <= 1 {
return arr
}

var result []string
var temp []string
var first string

for _, item := range arr {
if first == "" {
first = item
continue
}

if strings.HasPrefix(item, first) {
temp = append(temp, item)
} else {
result = append(result, first)
result = append(result, SortString(temp)...)
first = item
temp = []string{}
}
}

if len(temp) > 0 {
result = append(result, first)
result = append(result, SortString(temp)...)
} else {
result = append(result, first)
}

return result
}
8 changes: 8 additions & 0 deletions utils/net.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"crypto/tls"
"fmt"
"github.com/goccy/go-json"
"github.com/spf13/viper"
"io"
"net/http"
"net/url"
Expand Down Expand Up @@ -132,6 +133,13 @@ func EventSource(method string, uri string, headers map[string]string, body inte
defer res.Body.Close()

if res.StatusCode >= 400 {
// print body
if viper.GetBool("debug") {
if content, err := io.ReadAll(res.Body); err == nil {
fmt.Println(fmt.Sprintf("request failed with status: %s, body: %s", res.Status, string(content)))
}
}

return fmt.Errorf("request failed with status: %s", res.Status)
}

Expand Down

0 comments on commit 58a493e

Please sign in to comment.