
Commit

Moved vendors into their own directory + homogenized their implementation
baalimago committed Apr 2, 2024
1 parent 6e9edac commit fc568f1
Showing 12 changed files with 114 additions and 104 deletions.
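
The restructure introduces an internal/vendors/ package tree. A sketch of the resulting layout, reconstructed from the renamed paths in this diff (the anthropic file names are truncated in this view and therefore omitted):

internal/
├── create_queriers.go          # imports updated to the new vendor paths
├── setup_config_migrations.go  # imports updated to the new vendor paths
└── vendors/
    ├── anthropic/              # Claude querier, moved from internal/anthropic
    └── openai/                 # moved from internal/openai
        ├── gpt.go              # ChatGPT querier: Query, TextQuery, gptReq
        └── gpt_setup.go        # new: loadQuerier + NewTextQuerier, extracted from gpt.go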
4 changes: 2 additions & 2 deletions internal/create_queriers.go
@@ -5,12 +5,12 @@ import (
"os"
"strings"

"github.com/baalimago/clai/internal/anthropic"
"github.com/baalimago/clai/internal/chat"
"github.com/baalimago/clai/internal/models"
"github.com/baalimago/clai/internal/openai"
"github.com/baalimago/clai/internal/photo"
"github.com/baalimago/clai/internal/text"
"github.com/baalimago/clai/internal/vendors/anthropic"
"github.com/baalimago/clai/internal/vendors/openai"
"github.com/baalimago/go_away_boilerplate/pkg/ancli"
"github.com/baalimago/go_away_boilerplate/pkg/misc"
)
2 changes: 1 addition & 1 deletion internal/setup_config_migrations.go
@@ -5,10 +5,10 @@ import (
"os"
"path"

"github.com/baalimago/clai/internal/openai"
"github.com/baalimago/clai/internal/photo"
"github.com/baalimago/clai/internal/text"
"github.com/baalimago/clai/internal/tools"
"github.com/baalimago/clai/internal/vendors/openai"
"github.com/baalimago/go_away_boilerplate/pkg/ancli"
"github.com/baalimago/go_away_boilerplate/pkg/misc"
)
@@ -26,7 +26,7 @@ type Claude struct {
     debug bool `json:"-"`
 }

-type ClaudeReq struct {
+type claudeReq struct {
     Model     string           `json:"model"`
     Messages  []models.Message `json:"messages"`
     MaxTokens int              `json:"max_tokens"`
@@ -1,9 +1,6 @@
 package anthropic

 import (
-    "bytes"
-    "context"
-    "encoding/json"
     "fmt"
     "net/http"
     "os"
@@ -23,40 +20,6 @@ var defaultClaude = Claude{
     MaxTokens: 1024,
 }

-func (c *Claude) constructRequest(ctx context.Context, chat models.Chat) (*http.Request, error) {
-    // ignored for now as error is not used
-    sysMsg, _ := chat.SystemMessage()
-    if c.debug {
-        ancli.PrintOK(fmt.Sprintf("pre-claudified messages: %+v\n", chat.Messages))
-    }
-    claudifiedMsgs := claudifyMessages(chat.Messages)
-    if c.debug {
-        ancli.PrintOK(fmt.Sprintf("claudified messages: %+v\n", claudifiedMsgs))
-    }
-    reqData := ClaudeReq{
-        Model:     c.Model,
-        Messages:  claudifiedMsgs,
-        MaxTokens: c.MaxTokens,
-        Stream:    true,
-        System:    sysMsg.Content,
-    }
-    jsonData, err := json.Marshal(reqData)
-    if err != nil {
-        return nil, fmt.Errorf("failed to marshal ClaudeReq: %w", err)
-    }
-    req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.Url, bytes.NewBuffer(jsonData))
-    if err != nil {
-        return nil, fmt.Errorf("failed to create request: %w", err)
-    }
-    req.Header.Set("Content-Type", "application/json")
-    req.Header.Set("x-api-key", c.apiKey)
-    req.Header.Set("anthropic-version", c.AnthropicVersion)
-    if c.debug {
-        ancli.PrintOK(fmt.Sprintf("Request: %+v\n", req))
-    }
-    return req, nil
-}
-
 func loadQuerier(loadFrom, model string) (*Claude, error) {
     apiKey := os.Getenv("ANTHROPIC_API_KEY")
     if apiKey == "" {
@@ -80,6 +43,8 @@ func loadQuerier(loadFrom, model string) (*Claude, error) {
     return &claudeQuerier, nil
 }

+// NewTextQuerier returns a new Claude querier using the text configurations to load
+// the correct model. The API key is fetched via an environment variable.
 func NewTextQuerier(conf text.Configurations) (models.ChatQuerier, error) {
     home, _ := os.UserConfigDir()
     querier, err := loadQuerier(home, conf.Model)
@@ -2,6 +2,7 @@ package anthropic

 import (
     "bufio"
+    "bytes"
     "context"
     "encoding/json"
     "errors"
@@ -177,3 +178,37 @@ func (c *Claude) clearAndPrettyPrint(termWidth, lineCount int, fullMessage model
         fmt.Print(fullMessage.Content)
     }
 }
+
+func (c *Claude) constructRequest(ctx context.Context, chat models.Chat) (*http.Request, error) {
+    // ignored for now as error is not used
+    sysMsg, _ := chat.SystemMessage()
+    if c.debug {
+        ancli.PrintOK(fmt.Sprintf("pre-claudified messages: %+v\n", chat.Messages))
+    }
+    claudifiedMsgs := claudifyMessages(chat.Messages)
+    if c.debug {
+        ancli.PrintOK(fmt.Sprintf("claudified messages: %+v\n", claudifiedMsgs))
+    }
+    reqData := claudeReq{
+        Model:     c.Model,
+        Messages:  claudifiedMsgs,
+        MaxTokens: c.MaxTokens,
+        Stream:    true,
+        System:    sysMsg.Content,
+    }
+    jsonData, err := json.Marshal(reqData)
+    if err != nil {
+        return nil, fmt.Errorf("failed to marshal ClaudeReq: %w", err)
+    }
+    req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.Url, bytes.NewBuffer(jsonData))
+    if err != nil {
+        return nil, fmt.Errorf("failed to create request: %w", err)
+    }
+    req.Header.Set("Content-Type", "application/json")
+    req.Header.Set("x-api-key", c.apiKey)
+    req.Header.Set("anthropic-version", c.AnthropicVersion)
+    if c.debug {
+        ancli.PrintOK(fmt.Sprintf("Request: %+v\n", req))
+    }
+    return req, nil
+}
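
For context, the relocated constructRequest marshals a claudeReq and POSTs it to c.Url with Content-Type, x-api-key, and anthropic-version headers. Assuming the remaining struct tags follow the same lowercase pattern as the ones shown above, the request body has roughly this shape (values purely illustrative):

{
  "model": "claude-3-opus-20240229",
  "messages": [{"role": "user", "content": "hello"}],
  "max_tokens": 1024,
  "stream": true,
  "system": "You are a helpful assistant."
}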
File renamed without changes.
File renamed without changes.
File renamed without changes.
55 changes: 12 additions & 43 deletions internal/openai/gpt.go → internal/vendors/openai/gpt.go
@@ -4,14 +4,9 @@ import (
"context"
"fmt"
"net/http"
"os"

"github.com/baalimago/clai/internal/models"
"github.com/baalimago/clai/internal/reply"
"github.com/baalimago/clai/internal/text"
"github.com/baalimago/clai/internal/tools"
"github.com/baalimago/go_away_boilerplate/pkg/ancli"
"github.com/baalimago/go_away_boilerplate/pkg/misc"
)

type ChatGPT struct {
@@ -30,6 +25,18 @@ type ChatGPT struct {
     debug bool `json:"-"`
 }

+type gptReq struct {
+    Model            string           `json:"model"`
+    ResponseFormat   responseFormat   `json:"response_format"`
+    Messages         []models.Message `json:"messages"`
+    Stream           bool             `json:"stream"`
+    FrequencyPenalty float32          `json:"frequency_penalty"`
+    MaxTokens        *int             `json:"max_tokens"`
+    PresencePenalty  float32          `json:"presence_penalty"`
+    Temperature      float32          `json:"temperature"`
+    TopP             float32          `json:"top_p"`
+}
+
 // Query performs a streamCompletion and appends the returned message to its internal chat.
 // Then it stores the internal chat as prevQuery.json, so that it may be used in upcoming queries.
 func (q *ChatGPT) Query(ctx context.Context) error {
@@ -56,41 +63,3 @@ func (q *ChatGPT) TextQuery(ctx context.Context, chat models.Chat) (models.Chat,
     chat.Messages = append(chat.Messages, nextMsg)
     return chat, nil
 }
-
-func loadQuerier(loadFrom, model string) (*ChatGPT, error) {
-    apiKey := os.Getenv("OPENAI_API_KEY")
-    if apiKey == "" {
-        return nil, fmt.Errorf("environment variable 'OPENAI_API_KEY' not set")
-    }
-    defaultCpy := defaultGpt
-    defaultCpy.Model = model
-    defaultCpy.Url = ChatURL
-    // Load config based on model, allowing for different configs for each model
-    gptQuerier, err := tools.LoadConfigFromFile[ChatGPT](loadFrom, fmt.Sprintf("openai_gpt_%v.json", model), nil, &defaultCpy)
-    if misc.Truthy(os.Getenv("DEBUG")) {
-        ancli.PrintOK(fmt.Sprintf("ChatGPT config: %+v\n", gptQuerier))
-    }
-    if err != nil {
-        ancli.PrintWarn(fmt.Sprintf("failed to load config for model: %v, error: %v\n", model, err))
-    }
-    gptQuerier.client = &http.Client{}
-    gptQuerier.apiKey = apiKey
-    if err != nil {
-        return nil, fmt.Errorf("failed to load config: %w", err)
-    }
-    return &gptQuerier, nil
-}
-
-func NewTextQuerier(conf text.Configurations) (models.ChatQuerier, error) {
-    home, _ := os.UserConfigDir()
-    querier, err := loadQuerier(home, conf.Model)
-    if err != nil {
-        return nil, fmt.Errorf("failed to load querier of model: %v, error: %w", conf.Model, err)
-    }
-    if misc.Truthy(os.Getenv("DEBUG")) {
-        querier.debug = true
-    }
-    querier.chat = conf.InitialPrompt
-    querier.Raw = conf.Raw
-    return querier, nil
-}
60 changes: 60 additions & 0 deletions internal/vendors/openai/gpt_setup.go
@@ -0,0 +1,60 @@
+package openai
+
+import (
+    "fmt"
+    "net/http"
+    "os"
+
+    "github.com/baalimago/clai/internal/models"
+    "github.com/baalimago/clai/internal/text"
+    "github.com/baalimago/clai/internal/tools"
+    "github.com/baalimago/go_away_boilerplate/pkg/ancli"
+    "github.com/baalimago/go_away_boilerplate/pkg/misc"
+)
+
+var defaultGpt = ChatGPT{
+    Model:       "gpt-4-turbo-preview",
+    Temperature: 1.0,
+    TopP:        1.0,
+    Url:         ChatURL,
+}
+
+func loadQuerier(loadFrom, model string) (*ChatGPT, error) {
+    apiKey := os.Getenv("OPENAI_API_KEY")
+    if apiKey == "" {
+        return nil, fmt.Errorf("environment variable 'OPENAI_API_KEY' not set")
+    }
+    defaultCpy := defaultGpt
+    defaultCpy.Model = model
+    defaultCpy.Url = ChatURL
+    // Load config based on model, allowing for different configs for each model
+    gptQuerier, err := tools.LoadConfigFromFile[ChatGPT](loadFrom, fmt.Sprintf("openai_gpt_%v.json", model), nil, &defaultCpy)
+    if misc.Truthy(os.Getenv("DEBUG")) {
+        ancli.PrintOK(fmt.Sprintf("ChatGPT config: %+v\n", gptQuerier))
+    }
+    if err != nil {
+        ancli.PrintWarn(fmt.Sprintf("failed to load config for model: %v, error: %v\n", model, err))
+    }
+    gptQuerier.client = &http.Client{}
+    gptQuerier.apiKey = apiKey
+    if err != nil {
+        return nil, fmt.Errorf("failed to load config: %w", err)
+    }
+    return &gptQuerier, nil
+}
+
+// NewTextQuerier returns a new ChatGPT querier using the text configurations to load
+// the correct model. The API key is fetched via an environment variable.
+func NewTextQuerier(conf text.Configurations) (models.ChatQuerier, error) {
+    home, _ := os.UserConfigDir()
+    querier, err := loadQuerier(home, conf.Model)
+    if err != nil {
+        return nil, fmt.Errorf("failed to load querier of model: %v, error: %w", conf.Model, err)
+    }
+    if misc.Truthy(os.Getenv("DEBUG")) {
+        querier.debug = true
+    }
+    querier.chat = conf.InitialPrompt
+    querier.Raw = conf.Raw
+    return querier, nil
+}
@@ -20,25 +20,6 @@ type responseFormat struct {
     Type string `json:"type"`
 }

-type request struct {
-    Model            string           `json:"model"`
-    ResponseFormat   responseFormat   `json:"response_format"`
-    Messages         []models.Message `json:"messages"`
-    Stream           bool             `json:"stream"`
-    FrequencyPenalty float32          `json:"frequency_penalty"`
-    MaxTokens        *int             `json:"max_tokens"`
-    PresencePenalty  float32          `json:"presence_penalty"`
-    Temperature      float32          `json:"temperature"`
-    TopP             float32          `json:"top_p"`
-}
-
-var defaultGpt = ChatGPT{
-    Model:       "gpt-4-turbo-preview",
-    Temperature: 1.0,
-    TopP:        1.0,
-    Url:         ChatURL,
-}
-
 type chatCompletionChunk struct {
     Id     string `json:"id"`
     Object string `json:"object"`
@@ -57,7 +38,7 @@ var dataPrefix = []byte("data: ")

 // streamCompletions takes the messages as the prompt conversation. Returns the messages from the chat model.
 func (q *ChatGPT) streamCompletions(ctx context.Context, API_KEY string, messages []models.Message) (models.Message, error) {
-    reqData := request{
+    reqData := gptReq{
         Model:            q.Model,
         FrequencyPenalty: q.FrequencyPenalty,
         MaxTokens:        q.MaxTokens,
File renamed without changes.
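
After this commit both vendor packages expose the same surface: an unexported request struct (claudeReq, gptReq), an unexported loadQuerier, and an exported NewTextQuerier(conf text.Configurations) (models.ChatQuerier, error). A minimal sketch of how a caller such as create_queriers.go might dispatch between them; the model-prefix convention below is a hypothetical illustration, not something shown in this diff:

// Hypothetical dispatcher; only the NewTextQuerier signatures are taken from this commit.
package internal

import (
    "fmt"
    "strings"

    "github.com/baalimago/clai/internal/models"
    "github.com/baalimago/clai/internal/text"
    "github.com/baalimago/clai/internal/vendors/anthropic"
    "github.com/baalimago/clai/internal/vendors/openai"
)

func newTextQuerier(conf text.Configurations) (models.ChatQuerier, error) {
    switch {
    case strings.HasPrefix(conf.Model, "claude"):
        // Both vendors share the same constructor shape after the homogenization.
        return anthropic.NewTextQuerier(conf)
    case strings.HasPrefix(conf.Model, "gpt"):
        return openai.NewTextQuerier(conf)
    default:
        return nil, fmt.Errorf("unknown model: %v", conf.Model)
    }
}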
