Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement optional io.Reader in AudioRequest (#303) (#265) #331

Merged
merged 5 commits into from
Jun 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 35 additions & 10 deletions audio.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"os"

Expand All @@ -27,8 +28,14 @@ const (
// AudioRequest represents a request structure for audio API.
// ResponseFormat is not supported for now. We only return JSON text, which may be sufficient.
type AudioRequest struct {
Model string
FilePath string
Model string

// FilePath is either an existing file in your filesystem or a filename representing the contents of Reader.
FilePath string

// Reader is an optional io.Reader when you do not want to use an existing file.
Reader io.Reader

Prompt string // For translation, it should be in English
Temperature float32
Language string // For translation, just do not use it. It seems "en" works, not confirmed...
Expand Down Expand Up @@ -95,15 +102,9 @@ func (r AudioRequest) HasJSONResponse() bool {
// audioMultipartForm creates a form with audio file contents and the name of the model to use for
// audio processing.
func audioMultipartForm(request AudioRequest, b utils.FormBuilder) error {
f, err := os.Open(request.FilePath)
if err != nil {
return fmt.Errorf("opening audio file: %w", err)
}
defer f.Close()

err = b.CreateFormFile("file", f)
err := createFileField(request, b)
if err != nil {
return fmt.Errorf("creating form file: %w", err)
return err
}

err = b.WriteField("model", request.Model)
Expand Down Expand Up @@ -146,3 +147,27 @@ func audioMultipartForm(request AudioRequest, b utils.FormBuilder) error {
// Close the multipart writer
return b.Close()
}

// createFileField creates the "file" form field from either an existing file or by using the reader.
func createFileField(request AudioRequest, b utils.FormBuilder) error {
if request.Reader != nil {
err := b.CreateFormFileReader("file", request.Reader, request.FilePath)
if err != nil {
return fmt.Errorf("creating form using reader: %w", err)
}
return nil
}

f, err := os.Open(request.FilePath)
if err != nil {
return fmt.Errorf("opening audio file: %w", err)
}
defer f.Close()

err = b.CreateFormFile("file", f)
if err != nil {
return fmt.Errorf("creating form file: %w", err)
}

return nil
}
66 changes: 63 additions & 3 deletions audio_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package openai //nolint:testpackage // testing private field

import (
"bytes"
"context"
"errors"
"fmt"
"io"
Expand All @@ -11,12 +12,10 @@ import (
"os"
"path/filepath"
"strings"
"testing"

"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"

"context"
"testing"
)

// TestAudio Tests the transcription and translation endpoints of the API using the mocked server.
Expand Down Expand Up @@ -65,6 +64,16 @@ func TestAudio(t *testing.T) {
_, err = tc.createFn(ctx, req)
checks.NoError(t, err, "audio API error")
})

t.Run(tc.name+" (with reader)", func(t *testing.T) {
req := AudioRequest{
FilePath: "fake.webm",
Reader: bytes.NewBuffer([]byte(`some webm binary data`)),
Model: "whisper-3",
}
_, err = tc.createFn(ctx, req)
checks.NoError(t, err, "audio API error")
})
}
}

Expand Down Expand Up @@ -213,3 +222,54 @@ func TestAudioWithFailingFormBuilder(t *testing.T) {
checks.ErrorIs(t, err, mockFailedErr, "audioMultipartForm should return error if form builder fails")
}
}

func TestCreateFileField(t *testing.T) {
t.Run("createFileField failing file", func(t *testing.T) {
dir, cleanup := test.CreateTestDirectory(t)
defer cleanup()
path := filepath.Join(dir, "fake.mp3")
test.CreateTestFile(t, path)

req := AudioRequest{
FilePath: path,
}

mockFailedErr := fmt.Errorf("mock form builder fail")
mockBuilder := &mockFormBuilder{
mockCreateFormFile: func(string, *os.File) error {
return mockFailedErr
},
}

err := createFileField(req, mockBuilder)
checks.ErrorIs(t, err, mockFailedErr, "createFileField using a file should return error if form builder fails")
})

t.Run("createFileField failing reader", func(t *testing.T) {
req := AudioRequest{
FilePath: "test.wav",
Reader: bytes.NewBuffer([]byte(`wav test contents`)),
}

mockFailedErr := fmt.Errorf("mock form builder fail")
mockBuilder := &mockFormBuilder{
mockCreateFormFileReader: func(string, io.Reader, string) error {
return mockFailedErr
},
}

err := createFileField(req, mockBuilder)
checks.ErrorIs(t, err, mockFailedErr, "createFileField using a reader should return error if form builder fails")
})

t.Run("createFileField failing open", func(t *testing.T) {
req := AudioRequest{
FilePath: "non_existing_file.wav",
}

mockBuilder := &mockFormBuilder{}

err := createFileField(req, mockBuilder)
checks.HasError(t, err, "createFileField using file should return error when open file fails")
})
}
11 changes: 8 additions & 3 deletions image_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,15 +264,20 @@ func handleVariateImageEndpoint(w http.ResponseWriter, r *http.Request) {
}

type mockFormBuilder struct {
mockCreateFormFile func(string, *os.File) error
mockWriteField func(string, string) error
mockClose func() error
mockCreateFormFile func(string, *os.File) error
mockCreateFormFileReader func(string, io.Reader, string) error
mockWriteField func(string, string) error
mockClose func() error
}

func (fb *mockFormBuilder) CreateFormFile(fieldname string, file *os.File) error {
return fb.mockCreateFormFile(fieldname, file)
}

func (fb *mockFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error {
return fb.mockCreateFormFileReader(fieldname, r, filename)
}

func (fb *mockFormBuilder) WriteField(fieldname, value string) error {
return fb.mockWriteField(fieldname, value)
}
Expand Down
20 changes: 18 additions & 2 deletions internal/form_builder.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
package openai

import (
"fmt"
"io"
"mime/multipart"
"os"
"path"
)

type FormBuilder interface {
CreateFormFile(fieldname string, file *os.File) error
CreateFormFileReader(fieldname string, r io.Reader, filename string) error
WriteField(fieldname, value string) error
Close() error
FormDataContentType() string
Expand All @@ -24,15 +27,28 @@ func NewFormBuilder(body io.Writer) *DefaultFormBuilder {
}

func (fb *DefaultFormBuilder) CreateFormFile(fieldname string, file *os.File) error {
fieldWriter, err := fb.writer.CreateFormFile(fieldname, file.Name())
return fb.createFormFile(fieldname, file, file.Name())
}

func (fb *DefaultFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error {
return fb.createFormFile(fieldname, r, path.Base(filename))
}

func (fb *DefaultFormBuilder) createFormFile(fieldname string, r io.Reader, filename string) error {
if filename == "" {
return fmt.Errorf("filename cannot be empty")
}

fieldWriter, err := fb.writer.CreateFormFile(fieldname, filename)
if err != nil {
return err
}

_, err = io.Copy(fieldWriter, file)
_, err = io.Copy(fieldWriter, r)
if err != nil {
return err
}

return nil
}

Expand Down