From e246d4f4f24e15b1c62181c070b158af33d5c29d Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Fri, 25 Aug 2023 01:11:06 +0200 Subject: [PATCH 1/8] feat: implement new fine tuning job API --- fine_tuning_job.go | 145 ++++++++++++++++++++++++++++++++++++++++ fine_tuning_job_test.go | 68 +++++++++++++++++++ 2 files changed, 213 insertions(+) create mode 100644 fine_tuning_job.go create mode 100644 fine_tuning_job_test.go diff --git a/fine_tuning_job.go b/fine_tuning_job.go new file mode 100644 index 000000000..7e39baf28 --- /dev/null +++ b/fine_tuning_job.go @@ -0,0 +1,145 @@ +package openai + +import ( + "context" + "fmt" + "net/http" +) + +type FineTuningJob struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + FinishedAt int64 `json:"finished_at"` + Model string `json:"model"` + FineTunedModel string `json:"fine_tuned_model,omitempty"` + OrganizationID string `json:"organization_id"` + Status string `json:"status"` + Hyperparameters Hyperparameters `json:"hyperparameters"` + TrainingFile string `json:"training_file"` + ValidationFile string `json:"validation_file,omitempty"` + ResultFiles []string `json:"result_files"` + TrainedTokens int `json:"trained_tokens"` +} + +type Hyperparameters struct { + Epochs int `json:"n_epochs"` +} + +type FineTuningJobRequest struct { + TrainingFile string `json:"training_file"` + ValidationFile string `json:"validation_file,omitempty"` + Model string `json:"model,omitempty"` + Hyperparameters *Hyperparameters `json:"hyperparameters,omitempty"` + Suffix string `json:"suffix,omitempty"` +} + +type FineTuningJobEventList struct { + Object string `json:"object"` + Data []FineTuneEvent `json:"data"` + HasMore bool `json:"has_more"` +} + +type FineTuningJobEvent struct { + Object string `json:"object"` + ID string `json:"id"` + CreatedAt int `json:"created_at"` + Level string `json:"level"` + Message string `json:"message"` + Data interface{} `json:"data"` + Type string `json:"type"` +} + +// CreateFineTuningJob create a fine tuning job. +func (c *Client) CreateFineTuningJob(ctx context.Context, request FineTuningJobRequest) (response FineTuningJob, err error) { + urlSuffix := "/fine_tuning/jobs" + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// CancelFineTuningJob cancel a fine tuning job. +func (c *Client) CancelFineTuningJob(ctx context.Context, fineTuningJobID string) (response FineTuningJob, err error) { + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/fine_tuning/jobs/"+fineTuningJobID+"/cancel")) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// RetrieveFineTuningJob retrieve a fine tuning job. +func (c *Client) RetrieveFineTuningJob(ctx context.Context, fineTuningJobID string) (response FineTuningJob, err error) { + urlSuffix := fmt.Sprintf("/fine_tuning/jobs/%s", fineTuningJobID) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +type listFineTuningJobEventsParameters struct { + after *string + limit *int +} + +type listFineTuningJobEventsParameter func(*listFineTuningJobEventsParameters) + +func ListFineTuningJobEventsWithAfter(after string) listFineTuningJobEventsParameter { + return func(args *listFineTuningJobEventsParameters) { + args.after = &after + } +} + +func ListFineTuningJobEventsWithLimit(limit int) listFineTuningJobEventsParameter { + return func(args *listFineTuningJobEventsParameters) { + args.limit = &limit + } +} + +// ListFineTuningJobs list fine tuning jobs events. +func (c *Client) ListFineTuningJobEvents( + ctx context.Context, + fineTuningJobID string, + setters ...listFineTuningJobEventsParameter, +) (response FineTuningJobEventList, err error) { + + parameters := &listFineTuningJobEventsParameters{ + after: nil, + limit: nil, + } + + for _, setter := range setters { + setter(parameters) + } + + requestParamters := "" + if parameters.after != nil { + requestParamters += fmt.Sprintf("after=%s", *parameters.after) + } + if parameters.limit != nil { + if requestParamters != "" { + requestParamters += "&" + } + requestParamters += fmt.Sprintf("limit=%d", *parameters.limit) + } + + if requestParamters != "" { + requestParamters = "?" + requestParamters + } + + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL("/fine_tuning/jobs/"+fineTuningJobID+"/events"+requestParamters)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} diff --git a/fine_tuning_job_test.go b/fine_tuning_job_test.go new file mode 100644 index 000000000..e9d8b0532 --- /dev/null +++ b/fine_tuning_job_test.go @@ -0,0 +1,68 @@ +package openai_test + +import ( + "context" + + . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" + + "encoding/json" + "fmt" + "net/http" + "testing" +) + +const testFineTuninigJobID = "fine-tuning-job-id" + +// TestFineTuningJob Tests the fine tuning job endpoint of the API using the mocked server. +func TestFineTuningJob(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler( + "/v1/fine_tuning/jobs", + func(w http.ResponseWriter, r *http.Request) { + var resBytes []byte + resBytes, _ = json.Marshal(FineTuningJob{}) + fmt.Fprintln(w, string(resBytes)) + }, + ) + + server.RegisterHandler( + "/fine_tuning/jobs/"+testFineTuninigJobID+"/cancel", + func(w http.ResponseWriter, r *http.Request) { + resBytes, _ := json.Marshal(FineTuningJob{}) + fmt.Fprintln(w, string(resBytes)) + }, + ) + + server.RegisterHandler( + "/v1/fine_tuning/jobs/"+testFineTuninigJobID, + func(w http.ResponseWriter, r *http.Request) { + var resBytes []byte + resBytes, _ = json.Marshal(FineTuningJob{}) + fmt.Fprintln(w, string(resBytes)) + }, + ) + + server.RegisterHandler( + "/v1/fine_tuning/jobs/"+testFineTuninigJobID+"/events", + func(w http.ResponseWriter, r *http.Request) { + resBytes, _ := json.Marshal(FineTuningJobEventList{}) + fmt.Fprintln(w, string(resBytes)) + }, + ) + + ctx := context.Background() + + _, err := client.CreateFineTuningJob(ctx, FineTuningJobRequest{}) + checks.NoError(t, err, "CreateFineTuningJob error") + + _, err = client.CancelFineTuningJob(ctx, testFineTuninigJobID) + checks.NoError(t, err, "CancelFineTuningJob error") + + _, err = client.RetrieveFineTuningJob(ctx, testFineTuninigJobID) + checks.NoError(t, err, "RetrieveFineTuningJob error") + + _, err = client.ListFineTuningJobEvents(ctx, testFineTuninigJobID, nil, nil) + checks.NoError(t, err, "ListFineTuningJobEvents error") +} From 2b99af19426c69383fd57459c38169447df9ee72 Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Fri, 25 Aug 2023 01:19:51 +0200 Subject: [PATCH 2/8] fix: export ListFineTuningJobEventsParameter --- fine_tuning_job.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/fine_tuning_job.go b/fine_tuning_job.go index 7e39baf28..cefaccf1f 100644 --- a/fine_tuning_job.go +++ b/fine_tuning_job.go @@ -90,15 +90,15 @@ type listFineTuningJobEventsParameters struct { limit *int } -type listFineTuningJobEventsParameter func(*listFineTuningJobEventsParameters) +type ListFineTuningJobEventsParameter func(*listFineTuningJobEventsParameters) -func ListFineTuningJobEventsWithAfter(after string) listFineTuningJobEventsParameter { +func ListFineTuningJobEventsWithAfter(after string) ListFineTuningJobEventsParameter { return func(args *listFineTuningJobEventsParameters) { args.after = &after } } -func ListFineTuningJobEventsWithLimit(limit int) listFineTuningJobEventsParameter { +func ListFineTuningJobEventsWithLimit(limit int) ListFineTuningJobEventsParameter { return func(args *listFineTuningJobEventsParameters) { args.limit = &limit } @@ -108,7 +108,7 @@ func ListFineTuningJobEventsWithLimit(limit int) listFineTuningJobEventsParamete func (c *Client) ListFineTuningJobEvents( ctx context.Context, fineTuningJobID string, - setters ...listFineTuningJobEventsParameter, + setters ...ListFineTuningJobEventsParameter, ) (response FineTuningJobEventList, err error) { parameters := &listFineTuningJobEventsParameters{ From d7b71184b586982579832fc4484d6c77c1c7ca47 Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Fri, 25 Aug 2023 01:23:33 +0200 Subject: [PATCH 3/8] fix: lint errors --- fine_tuning_job.go | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/fine_tuning_job.go b/fine_tuning_job.go index cefaccf1f..7588d9520 100644 --- a/fine_tuning_job.go +++ b/fine_tuning_job.go @@ -51,7 +51,10 @@ type FineTuningJobEvent struct { } // CreateFineTuningJob create a fine tuning job. -func (c *Client) CreateFineTuningJob(ctx context.Context, request FineTuningJobRequest) (response FineTuningJob, err error) { +func (c *Client) CreateFineTuningJob( + ctx context.Context, + request FineTuningJobRequest, +) (response FineTuningJob, err error) { urlSuffix := "/fine_tuning/jobs" req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request)) if err != nil { @@ -74,7 +77,10 @@ func (c *Client) CancelFineTuningJob(ctx context.Context, fineTuningJobID string } // RetrieveFineTuningJob retrieve a fine tuning job. -func (c *Client) RetrieveFineTuningJob(ctx context.Context, fineTuningJobID string) (response FineTuningJob, err error) { +func (c *Client) RetrieveFineTuningJob( + ctx context.Context, + fineTuningJobID string, +) (response FineTuningJob, err error) { urlSuffix := fmt.Sprintf("/fine_tuning/jobs/%s", fineTuningJobID) req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) if err != nil { @@ -110,7 +116,6 @@ func (c *Client) ListFineTuningJobEvents( fineTuningJobID string, setters ...ListFineTuningJobEventsParameter, ) (response FineTuningJobEventList, err error) { - parameters := &listFineTuningJobEventsParameters{ after: nil, limit: nil, @@ -135,7 +140,11 @@ func (c *Client) ListFineTuningJobEvents( requestParamters = "?" + requestParamters } - req, err := c.newRequest(ctx, http.MethodGet, c.fullURL("/fine_tuning/jobs/"+fineTuningJobID+"/events"+requestParamters)) + req, err := c.newRequest( + ctx, + http.MethodGet, + c.fullURL("/fine_tuning/jobs/"+fineTuningJobID+"/events"+requestParamters), + ) if err != nil { return } From 132a697e995ef7bfd36169168cb4ed1e8bc7eb0f Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Fri, 25 Aug 2023 01:25:40 +0200 Subject: [PATCH 4/8] fix: test errors --- fine_tuning_job_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fine_tuning_job_test.go b/fine_tuning_job_test.go index e9d8b0532..97487dba7 100644 --- a/fine_tuning_job_test.go +++ b/fine_tuning_job_test.go @@ -63,6 +63,6 @@ func TestFineTuningJob(t *testing.T) { _, err = client.RetrieveFineTuningJob(ctx, testFineTuninigJobID) checks.NoError(t, err, "RetrieveFineTuningJob error") - _, err = client.ListFineTuningJobEvents(ctx, testFineTuninigJobID, nil, nil) + _, err = client.ListFineTuningJobEvents(ctx, testFineTuninigJobID) checks.NoError(t, err, "ListFineTuningJobEvents error") } From fa038f386dba9b3d537debd5e03fccdad83f302b Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Fri, 25 Aug 2023 01:42:12 +0200 Subject: [PATCH 5/8] fix: code test coverage --- fine_tuning_job_test.go | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/fine_tuning_job_test.go b/fine_tuning_job_test.go index 97487dba7..519c6cd2d 100644 --- a/fine_tuning_job_test.go +++ b/fine_tuning_job_test.go @@ -65,4 +65,26 @@ func TestFineTuningJob(t *testing.T) { _, err = client.ListFineTuningJobEvents(ctx, testFineTuninigJobID) checks.NoError(t, err, "ListFineTuningJobEvents error") + + _, err = client.ListFineTuningJobEvents( + ctx, + testFineTuninigJobID, + ListFineTuningJobEventsWithAfter("last-event-id"), + ) + checks.NoError(t, err, "ListFineTuningJobEvents error") + + _, err = client.ListFineTuningJobEvents( + ctx, + testFineTuninigJobID, + ListFineTuningJobEventsWithLimit(10), + ) + checks.NoError(t, err, "ListFineTuningJobEvents error") + + _, err = client.ListFineTuningJobEvents( + ctx, + testFineTuninigJobID, + ListFineTuningJobEventsWithAfter("last-event-id"), + ListFineTuningJobEventsWithLimit(10), + ) + checks.NoError(t, err, "ListFineTuningJobEvents error") } From ee63a55a33c29ad881f55c884f20de6486110d00 Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Fri, 25 Aug 2023 01:57:48 +0200 Subject: [PATCH 6/8] fix: code test coverage --- client_test.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/client_test.go b/client_test.go index 29d84edfa..9b5046899 100644 --- a/client_test.go +++ b/client_test.go @@ -223,6 +223,18 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { {"ListFineTuneEvents", func() (any, error) { return client.ListFineTuneEvents(ctx, "") }}, + {"CreateFineTuningJob", func() (any, error) { + return client.CreateFineTuningJob(ctx, FineTuningJobRequest{}) + }}, + {"CancelFineTuningJob", func() (any, error) { + return client.CancelFineTuningJob(ctx, "") + }}, + {"RetrieveFineTuningJob", func() (any, error) { + return client.RetrieveFineTuningJob(ctx, "") + }}, + {"ListFineTuningJobEvents", func() (any, error) { + return client.ListFineTuningJobEvents(ctx, "") + }}, {"Moderations", func() (any, error) { return client.Moderations(ctx, ModerationRequest{}) }}, From 80cd9eaa3bd5b0111dda1a30059488f5e8f8c6eb Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Fri, 25 Aug 2023 06:48:29 +0200 Subject: [PATCH 7/8] fix: use any --- fine_tuning_job.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/fine_tuning_job.go b/fine_tuning_job.go index 7588d9520..b2780453b 100644 --- a/fine_tuning_job.go +++ b/fine_tuning_job.go @@ -41,13 +41,13 @@ type FineTuningJobEventList struct { } type FineTuningJobEvent struct { - Object string `json:"object"` - ID string `json:"id"` - CreatedAt int `json:"created_at"` - Level string `json:"level"` - Message string `json:"message"` - Data interface{} `json:"data"` - Type string `json:"type"` + Object string `json:"object"` + ID string `json:"id"` + CreatedAt int `json:"created_at"` + Level string `json:"level"` + Message string `json:"message"` + Data any `json:"data"` + Type string `json:"type"` } // CreateFineTuningJob create a fine tuning job. From 057559a3062eacc0a48bc92e18a1d016d36895aa Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Mon, 28 Aug 2023 22:41:27 +0200 Subject: [PATCH 8/8] chore: use url.Values --- fine_tuning_job.go | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/fine_tuning_job.go b/fine_tuning_job.go index b2780453b..a840b7ec3 100644 --- a/fine_tuning_job.go +++ b/fine_tuning_job.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/http" + "net/url" ) type FineTuningJob struct { @@ -125,25 +126,23 @@ func (c *Client) ListFineTuningJobEvents( setter(parameters) } - requestParamters := "" + urlValues := url.Values{} if parameters.after != nil { - requestParamters += fmt.Sprintf("after=%s", *parameters.after) + urlValues.Add("after", *parameters.after) } if parameters.limit != nil { - if requestParamters != "" { - requestParamters += "&" - } - requestParamters += fmt.Sprintf("limit=%d", *parameters.limit) + urlValues.Add("limit", fmt.Sprintf("%d", *parameters.limit)) } - if requestParamters != "" { - requestParamters = "?" + requestParamters + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() } req, err := c.newRequest( ctx, http.MethodGet, - c.fullURL("/fine_tuning/jobs/"+fineTuningJobID+"/events"+requestParamters), + c.fullURL("/fine_tuning/jobs/"+fineTuningJobID+"/events"+encodedValues), ) if err != nil { return