Skip to content

Commit

Permalink
fix(audio): fix audioTextResponse decode (#638)
Browse files Browse the repository at this point in the history
* fix(audio): fix audioTextResponse decode

* test(audio): add audioTextResponse decode test

* test(audio): simplify code
  • Loading branch information
WqyJh authored Jan 17, 2024
1 parent 4ce03a9 commit eff8dc1
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 11 deletions.
10 changes: 7 additions & 3 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,14 @@ func decodeResponse(body io.Reader, v any) error {
return nil
}

if result, ok := v.(*string); ok {
return decodeString(body, result)
switch o := v.(type) {
case *string:
return decodeString(body, o)
case *audioTextResponse:
return decodeString(body, &o.Text)
default:
return json.NewDecoder(body).Decode(v)
}
return json.NewDecoder(body).Decode(v)
}

func decodeString(body io.Reader, output *string) error {
Expand Down
48 changes: 40 additions & 8 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@ import (
"fmt"
"io"
"net/http"
"reflect"
"testing"

"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"
)

var errTestRequestBuilderFailed = errors.New("test request builder failed")
Expand Down Expand Up @@ -43,38 +45,68 @@ func TestDecodeResponse(t *testing.T) {
testCases := []struct {
name string
value interface{}
expected interface{}
body io.Reader
hasError bool
}{
{
name: "nil input",
value: nil,
body: bytes.NewReader([]byte("")),
name: "nil input",
value: nil,
body: bytes.NewReader([]byte("")),
expected: nil,
},
{
name: "string input",
value: &stringInput,
body: bytes.NewReader([]byte("test")),
name: "string input",
value: &stringInput,
body: bytes.NewReader([]byte("test")),
expected: "test",
},
{
name: "map input",
value: &map[string]interface{}{},
body: bytes.NewReader([]byte(`{"test": "test"}`)),
expected: map[string]interface{}{
"test": "test",
},
},
{
name: "reader return error",
value: &stringInput,
body: &errorReader{err: errors.New("dummy")},
hasError: true,
},
{
name: "audio text input",
value: &audioTextResponse{},
body: bytes.NewReader([]byte("test")),
expected: audioTextResponse{
Text: "test",
},
},
}

assertEqual := func(t *testing.T, expected, actual interface{}) {
t.Helper()
if expected == actual {
return
}
v := reflect.ValueOf(actual).Elem().Interface()
if !reflect.DeepEqual(v, expected) {
t.Fatalf("Unexpected value: %v, expected: %v", v, expected)
}
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := decodeResponse(tc.body, tc.value)
if (err != nil) != tc.hasError {
t.Errorf("Unexpected error: %v", err)
if tc.hasError {
checks.HasError(t, err, "Unexpected nil error")
return
}
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
assertEqual(t, tc.expected, tc.value)
})
}
}
Expand Down

0 comments on commit eff8dc1

Please sign in to comment.