Skip to content

Commit

Permalink
allow custom voice and speech models (#691)
Browse files Browse the repository at this point in the history
  • Loading branch information
adrianliechti authored Jun 30, 2024
1 parent e311859 commit 03851d2
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 48 deletions.
31 changes: 0 additions & 31 deletions speech.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package openai

import (
"context"
"errors"
"net/http"
)

Expand Down Expand Up @@ -36,11 +35,6 @@ const (
SpeechResponseFormatPcm SpeechResponseFormat = "pcm"
)

var (
ErrInvalidSpeechModel = errors.New("invalid speech model")
ErrInvalidVoice = errors.New("invalid voice")
)

type CreateSpeechRequest struct {
Model SpeechModel `json:"model"`
Input string `json:"input"`
Expand All @@ -49,32 +43,7 @@ type CreateSpeechRequest struct {
Speed float64 `json:"speed,omitempty"` // Optional, default to 1.0
}

func contains[T comparable](s []T, e T) bool {
for _, v := range s {
if v == e {
return true
}
}
return false
}

func isValidSpeechModel(model SpeechModel) bool {
return contains([]SpeechModel{TTSModel1, TTSModel1HD, TTSModelCanary}, model)
}

func isValidVoice(voice SpeechVoice) bool {
return contains([]SpeechVoice{VoiceAlloy, VoiceEcho, VoiceFable, VoiceOnyx, VoiceNova, VoiceShimmer}, voice)
}

func (c *Client) CreateSpeech(ctx context.Context, request CreateSpeechRequest) (response RawResponse, err error) {
if !isValidSpeechModel(request.Model) {
err = ErrInvalidSpeechModel
return
}
if !isValidVoice(request.Voice) {
err = ErrInvalidVoice
return
}
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/audio/speech", string(request.Model)),
withBody(request),
withContentType("application/json"),
Expand Down
17 changes: 0 additions & 17 deletions speech_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,21 +95,4 @@ func TestSpeechIntegration(t *testing.T) {
err = os.WriteFile("test.mp3", buf, 0644)
checks.NoError(t, err, "Create error")
})
t.Run("invalid model", func(t *testing.T) {
_, err := client.CreateSpeech(context.Background(), openai.CreateSpeechRequest{
Model: "invalid_model",
Input: "Hello!",
Voice: openai.VoiceAlloy,
})
checks.ErrorIs(t, err, openai.ErrInvalidSpeechModel, "CreateSpeech error")
})

t.Run("invalid voice", func(t *testing.T) {
_, err := client.CreateSpeech(context.Background(), openai.CreateSpeechRequest{
Model: openai.TTSModel1,
Input: "Hello!",
Voice: "invalid_voice",
})
checks.ErrorIs(t, err, openai.ErrInvalidVoice, "CreateSpeech error")
})
}

0 comments on commit 03851d2

Please sign in to comment.