Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: refactoring http request creation and sending #395

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion api_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func TestRequestAuthHeader(t *testing.T) {
az.OrgID = c.OrgID

cli := NewClientWithConfig(az)
req, err := cli.newStreamRequest(context.Background(), "POST", "/chat/completions", nil, "")
req, err := cli.newRequest(context.Background(), "POST", "/chat/completions")
if err != nil {
t.Errorf("Failed to create request: %v", err)
}
Expand Down
4 changes: 2 additions & 2 deletions audio.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,11 @@ func (c *Client) callAudioAPI(
}

urlSuffix := fmt.Sprintf("/audio/%s", endpointSuffix)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), &formBody)
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model),
withBody(&formBody), withContentType(builder.FormDataContentType()))
if err != nil {
return AudioResponse{}, err
}
req.Header.Add("Content-Type", builder.FormDataContentType())

if request.HasJSONResponse() {
err = c.sendRequest(req, &response)
Expand Down
2 changes: 1 addition & 1 deletion chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ func (c *Client) CreateChatCompletion(
return
}

req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request)
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request))
if err != nil {
return
}
Expand Down
22 changes: 5 additions & 17 deletions chat_stream.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
package openai

import (
"bufio"
"context"

utils "github.com/sashabaranov/go-openai/internal"
"net/http"
)

type ChatCompletionStreamChoiceDelta struct {
Expand Down Expand Up @@ -47,27 +45,17 @@ func (c *Client) CreateChatCompletionStream(
}

request.Stream = true
req, err := c.newStreamRequest(ctx, "POST", urlSuffix, request, request.Model)
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request))
if err != nil {
return
return nil, err
}

resp, err := c.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close()
resp, err := sendRequestStream[ChatCompletionStreamResponse](c, req)
if err != nil {
return
}
if isFailureStatusCode(resp) {
return nil, c.handleErrorResp(resp)
}

stream = &ChatCompletionStream{
streamReader: &streamReader[ChatCompletionStreamResponse]{
emptyMessagesLimit: c.config.EmptyMessagesLimit,
reader: bufio.NewReader(resp.Body),
response: resp,
errAccumulator: utils.NewErrorAccumulator(),
unmarshaler: &utils.JSONUnmarshaler{},
},
streamReader: resp,
}
return
}
94 changes: 72 additions & 22 deletions client.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package openai

import (
"bufio"
"context"
"encoding/json"
"fmt"
Expand Down Expand Up @@ -45,6 +46,42 @@ func NewOrgClient(authToken, org string) *Client {
return NewClientWithConfig(config)
}

type requestOptions struct {
body any
header http.Header
}

type requestOption func(*requestOptions)

func withBody(body any) requestOption {
return func(args *requestOptions) {
args.body = body
}
}

func withContentType(contentType string) requestOption {
return func(args *requestOptions) {
args.header.Set("Content-Type", contentType)
}
}

func (c *Client) newRequest(ctx context.Context, method, url string, setters ...requestOption) (*http.Request, error) {
// Default Options
args := &requestOptions{
body: nil,
header: make(http.Header),
}
for _, setter := range setters {
setter(args)
}
req, err := c.requestBuilder.Build(ctx, method, url, args.body, args.header)
if err != nil {
return nil, err
}
c.setCommonHeaders(req)
return req, nil
}

func (c *Client) sendRequest(req *http.Request, v any) error {
req.Header.Set("Accept", "application/json; charset=utf-8")

Expand All @@ -55,8 +92,6 @@ func (c *Client) sendRequest(req *http.Request, v any) error {
req.Header.Set("Content-Type", "application/json; charset=utf-8")
}

c.setCommonHeaders(req)

res, err := c.config.HTTPClient.Do(req)
if err != nil {
return err
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wanted to cover the above line.
#395 (comment)

Expand All @@ -71,6 +106,41 @@ func (c *Client) sendRequest(req *http.Request, v any) error {
return decodeResponse(res.Body, v)
}

func (c *Client) sendRequestRaw(req *http.Request) (body io.ReadCloser, err error) {
resp, err := c.config.HTTPClient.Do(req)
if err != nil {
return
}

if isFailureStatusCode(resp) {
err = c.handleErrorResp(resp)
return
}
return resp.Body, nil
}

func sendRequestStream[T streamable](client *Client, req *http.Request) (*streamReader[T], error) {
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "text/event-stream")
req.Header.Set("Cache-Control", "no-cache")
req.Header.Set("Connection", "keep-alive")

resp, err := client.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close()
if err != nil {
return new(streamReader[T]), err
}
if isFailureStatusCode(resp) {
return new(streamReader[T]), client.handleErrorResp(resp)
}
return &streamReader[T]{
emptyMessagesLimit: client.config.EmptyMessagesLimit,
reader: bufio.NewReader(resp.Body),
response: resp,
errAccumulator: utils.NewErrorAccumulator(),
unmarshaler: &utils.JSONUnmarshaler{},
}, nil
}

func (c *Client) setCommonHeaders(req *http.Request) {
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication
// Azure API Key authentication
Expand Down Expand Up @@ -138,26 +208,6 @@ func (c *Client) fullURL(suffix string, args ...any) string {
return fmt.Sprintf("%s%s", c.config.BaseURL, suffix)
}

func (c *Client) newStreamRequest(
ctx context.Context,
method string,
urlSuffix string,
body any,
model string) (*http.Request, error) {
req, err := c.requestBuilder.Build(ctx, method, c.fullURL(urlSuffix, model), body)
if err != nil {
return nil, err
}

req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "text/event-stream")
req.Header.Set("Cache-Control", "no-cache")
req.Header.Set("Connection", "keep-alive")

c.setCommonHeaders(req)
return req, nil
}

func (c *Client) handleErrorResp(resp *http.Response) error {
var errRes ErrorResponse
err := json.NewDecoder(resp.Body).Decode(&errRes)
Expand Down
25 changes: 20 additions & 5 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ var errTestRequestBuilderFailed = errors.New("test request builder failed")

type failingRequestBuilder struct{}

func (*failingRequestBuilder) Build(_ context.Context, _, _ string, _ any) (*http.Request, error) {
func (*failingRequestBuilder) Build(_ context.Context, _, _ string, _ any, _ http.Header) (*http.Request, error) {
return nil, errTestRequestBuilderFailed
}

Expand All @@ -41,9 +41,10 @@ func TestDecodeResponse(t *testing.T) {
stringInput := ""

testCases := []struct {
name string
value interface{}
body io.Reader
name string
value interface{}
body io.Reader
hasError bool
}{
{
name: "nil input",
Expand All @@ -60,18 +61,32 @@ func TestDecodeResponse(t *testing.T) {
value: &map[string]interface{}{},
body: bytes.NewReader([]byte(`{"test": "test"}`)),
},
{
name: "reader return error",
value: &stringInput,
body: &errorReader{err: errors.New("dummy")},
hasError: true,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := decodeResponse(tc.body, tc.value)
if err != nil {
if (err != nil) != tc.hasError {
t.Errorf("Unexpected error: %v", err)
}
})
}
}

type errorReader struct {
err error
}

func (e *errorReader) Read(_ []byte) (n int, err error) {
return 0, e.err
}

func TestHandleErrorResp(t *testing.T) {
// var errRes *ErrorResponse
var errRes ErrorResponse
Expand Down
2 changes: 1 addition & 1 deletion completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ func (c *Client) CreateCompletion(
return
}

req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request)
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request))
if err != nil {
return
}
Expand Down
2 changes: 1 addition & 1 deletion edits.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ type EditsResponse struct {

// Perform an API call to the Edits endpoint.
func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) {
req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL("/edits", fmt.Sprint(request.Model)), request)
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/edits", fmt.Sprint(request.Model)), withBody(request))
if err != nil {
return
}
Expand Down
2 changes: 1 addition & 1 deletion embeddings.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ type EmbeddingRequest struct {
// CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|.
// https://beta.openai.com/docs/api-reference/embeddings/create
func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) {
req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL("/embeddings", request.Model.String()), request)
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", request.Model.String()), withBody(request))
if err != nil {
return
}
Expand Down
4 changes: 2 additions & 2 deletions engines.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ type EnginesList struct {
// ListEngines Lists the currently available engines, and provides basic
// information about each option such as the owner and availability.
func (c *Client) ListEngines(ctx context.Context) (engines EnginesList, err error) {
req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/engines"), nil)
req, err := c.newRequest(ctx, http.MethodGet, c.fullURL("/engines"))
if err != nil {
return
}
Expand All @@ -38,7 +38,7 @@ func (c *Client) GetEngine(
engineID string,
) (engine Engine, err error) {
urlSuffix := fmt.Sprintf("/engines/%s", engineID)
req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix))
if err != nil {
return
}
Expand Down
28 changes: 7 additions & 21 deletions files.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,21 +57,19 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File
return
}

req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL("/files"), &b)
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/files"),
withBody(&b), withContentType(builder.FormDataContentType()))
if err != nil {
return
}

req.Header.Set("Content-Type", builder.FormDataContentType())

err = c.sendRequest(req, &file)

return
}

// DeleteFile deletes an existing file.
func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) {
req, err := c.requestBuilder.Build(ctx, http.MethodDelete, c.fullURL("/files/"+fileID), nil)
req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL("/files/"+fileID))
if err != nil {
return
}
Expand All @@ -83,7 +81,7 @@ func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) {
// ListFiles Lists the currently available files,
// and provides basic information about each file such as the file name and purpose.
func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) {
req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/files"), nil)
req, err := c.newRequest(ctx, http.MethodGet, c.fullURL("/files"))
if err != nil {
return
}
Expand All @@ -96,7 +94,7 @@ func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) {
// such as the file name and purpose.
func (c *Client) GetFile(ctx context.Context, fileID string) (file File, err error) {
urlSuffix := fmt.Sprintf("/files/%s", fileID)
req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix))
if err != nil {
return
}
Expand All @@ -107,23 +105,11 @@ func (c *Client) GetFile(ctx context.Context, fileID string) (file File, err err

func (c *Client) GetFileContent(ctx context.Context, fileID string) (content io.ReadCloser, err error) {
urlSuffix := fmt.Sprintf("/files/%s/content", fileID)
req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
if err != nil {
return
}

c.setCommonHeaders(req)

res, err := c.config.HTTPClient.Do(req)
req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix))
if err != nil {
return
}

if isFailureStatusCode(res) {
err = c.handleErrorResp(res)
return
}

content = res.Body
content, err = c.sendRequestRaw(req)
return
}
Loading