From 6eaa82cf9b33a3c860907779ffa7c739cb6f9b79 Mon Sep 17 00:00:00 2001 From: Garrett Graves <35381942+GRVYDEV@users.noreply.github.com> Date: Wed, 13 Mar 2024 16:37:50 -0700 Subject: [PATCH] llms/anthropic: Implement WithBaseURL and WithHTTPClient for Anthropic LLM (#671) * implement WithBaseUrl and WithHttpClient for anthropic model * fix build errors * URL not url * fix more errors --- llms/anthropic/anthropicllm.go | 7 +++-- llms/anthropic/anthropicllm_option.go | 26 +++++++++++++++++-- .../anthropicclient/anthropicclient.go | 9 ++++--- .../internal/anthropicclient/completions.go | 2 +- 4 files changed, 35 insertions(+), 9 deletions(-) diff --git a/llms/anthropic/anthropicllm.go b/llms/anthropic/anthropicllm.go index 3a8ba2f55..7d69f8063 100644 --- a/llms/anthropic/anthropicllm.go +++ b/llms/anthropic/anthropicllm.go @@ -3,6 +3,7 @@ package anthropic import ( "context" "errors" + "net/http" "os" "github.com/tmc/langchaingo/callbacks" @@ -34,7 +35,9 @@ func New(opts ...Option) (*LLM, error) { func newClient(opts ...Option) (*anthropicclient.Client, error) { options := &options{ - token: os.Getenv(tokenEnvVarName), + token: os.Getenv(tokenEnvVarName), + baseURL: anthropicclient.DefaultBaseURL, + httpClient: http.DefaultClient, } for _, opt := range opts { @@ -45,7 +48,7 @@ func newClient(opts ...Option) (*anthropicclient.Client, error) { return nil, ErrMissingToken } - return anthropicclient.New(options.token, options.model) + return anthropicclient.New(options.token, options.model, options.baseURL, options.httpClient) } // Call requests a completion for the given prompt. diff --git a/llms/anthropic/anthropicllm_option.go b/llms/anthropic/anthropicllm_option.go index de543c3b8..cb3d18e57 100644 --- a/llms/anthropic/anthropicllm_option.go +++ b/llms/anthropic/anthropicllm_option.go @@ -1,12 +1,18 @@ package anthropic +import ( + "github.com/tmc/langchaingo/llms/anthropic/internal/anthropicclient" +) + const ( tokenEnvVarName = "ANTHROPIC_API_KEY" //nolint:gosec ) type options struct { - token string - model string + token string + model string + baseURL string + httpClient anthropicclient.Doer } type Option func(*options) @@ -25,3 +31,19 @@ func WithModel(model string) Option { opts.model = model } } + +// WithBaseUrl passes the Anthropic base URL to the client. +// If not set, the default base URL is used. +func WithBaseURL(baseURL string) Option { + return func(opts *options) { + opts.baseURL = baseURL + } +} + +// WithHTTPClient allows setting a custom HTTP client. If not set, the default value +// is http.DefaultClient. +func WithHTTPClient(client anthropicclient.Doer) Option { + return func(opts *options) { + opts.httpClient = client + } +} diff --git a/llms/anthropic/internal/anthropicclient/anthropicclient.go b/llms/anthropic/internal/anthropicclient/anthropicclient.go index 74b74f805..386226349 100644 --- a/llms/anthropic/internal/anthropicclient/anthropicclient.go +++ b/llms/anthropic/internal/anthropicclient/anthropicclient.go @@ -4,10 +4,11 @@ import ( "context" "errors" "net/http" + "strings" ) const ( - defaultBaseURL = "https://api.anthropic.com/v1" + DefaultBaseURL = "https://api.anthropic.com/v1" ) // ErrEmptyResponse is returned when the Anthropic API returns an empty response. @@ -40,12 +41,12 @@ func WithHTTPClient(client Doer) Option { } // New returns a new Anthropic client. -func New(token string, model string, opts ...Option) (*Client, error) { +func New(token string, model string, baseURL string, httpClient Doer, opts ...Option) (*Client, error) { c := &Client{ Model: model, token: token, - baseURL: defaultBaseURL, - httpClient: http.DefaultClient, + baseURL: strings.TrimSuffix(baseURL, "/"), + httpClient: httpClient, } for _, opt := range opts { diff --git a/llms/anthropic/internal/anthropicclient/completions.go b/llms/anthropic/internal/anthropicclient/completions.go index b0aaa7398..5765563fa 100644 --- a/llms/anthropic/internal/anthropicclient/completions.go +++ b/llms/anthropic/internal/anthropicclient/completions.go @@ -80,7 +80,7 @@ func (c *Client) createCompletion(ctx context.Context, payload *completionPayloa } if c.baseURL == "" { - c.baseURL = defaultBaseURL + c.baseURL = DefaultBaseURL } url := fmt.Sprintf("%s/complete", c.baseURL)