Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement new fine tuning job API #479

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,18 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) {
{"ListFineTuneEvents", func() (any, error) {
return client.ListFineTuneEvents(ctx, "")
}},
{"CreateFineTuningJob", func() (any, error) {
return client.CreateFineTuningJob(ctx, FineTuningJobRequest{})
}},
{"CancelFineTuningJob", func() (any, error) {
return client.CancelFineTuningJob(ctx, "")
}},
{"RetrieveFineTuningJob", func() (any, error) {
return client.RetrieveFineTuningJob(ctx, "")
}},
{"ListFineTuningJobEvents", func() (any, error) {
return client.ListFineTuningJobEvents(ctx, "")
}},
{"Moderations", func() (any, error) {
return client.Moderations(ctx, ModerationRequest{})
}},
Expand Down
153 changes: 153 additions & 0 deletions fine_tuning_job.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
package openai

import (
"context"
"fmt"
"net/http"
"net/url"
)

type FineTuningJob struct {
ID string `json:"id"`
Object string `json:"object"`
CreatedAt int64 `json:"created_at"`
FinishedAt int64 `json:"finished_at"`
Model string `json:"model"`
FineTunedModel string `json:"fine_tuned_model,omitempty"`
OrganizationID string `json:"organization_id"`
Status string `json:"status"`
Hyperparameters Hyperparameters `json:"hyperparameters"`
TrainingFile string `json:"training_file"`
ValidationFile string `json:"validation_file,omitempty"`
ResultFiles []string `json:"result_files"`
TrainedTokens int `json:"trained_tokens"`
}

type Hyperparameters struct {
Epochs int `json:"n_epochs"`
}

type FineTuningJobRequest struct {
TrainingFile string `json:"training_file"`
ValidationFile string `json:"validation_file,omitempty"`
Model string `json:"model,omitempty"`
Hyperparameters *Hyperparameters `json:"hyperparameters,omitempty"`
Suffix string `json:"suffix,omitempty"`
}

type FineTuningJobEventList struct {
Object string `json:"object"`
Data []FineTuneEvent `json:"data"`
HasMore bool `json:"has_more"`
}

type FineTuningJobEvent struct {
Object string `json:"object"`
ID string `json:"id"`
CreatedAt int `json:"created_at"`
Level string `json:"level"`
Message string `json:"message"`
Data any `json:"data"`
Type string `json:"type"`
}

// CreateFineTuningJob create a fine tuning job.
func (c *Client) CreateFineTuningJob(
ctx context.Context,
request FineTuningJobRequest,
) (response FineTuningJob, err error) {
urlSuffix := "/fine_tuning/jobs"
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request))
if err != nil {
return
}

err = c.sendRequest(req, &response)
return
}

// CancelFineTuningJob cancel a fine tuning job.
func (c *Client) CancelFineTuningJob(ctx context.Context, fineTuningJobID string) (response FineTuningJob, err error) {
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/fine_tuning/jobs/"+fineTuningJobID+"/cancel"))
if err != nil {
return
}

err = c.sendRequest(req, &response)
return
}

// RetrieveFineTuningJob retrieve a fine tuning job.
func (c *Client) RetrieveFineTuningJob(
ctx context.Context,
fineTuningJobID string,
) (response FineTuningJob, err error) {
urlSuffix := fmt.Sprintf("/fine_tuning/jobs/%s", fineTuningJobID)
req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix))
if err != nil {
return
}

err = c.sendRequest(req, &response)
return
}

type listFineTuningJobEventsParameters struct {
after *string
limit *int
}

type ListFineTuningJobEventsParameter func(*listFineTuningJobEventsParameters)

func ListFineTuningJobEventsWithAfter(after string) ListFineTuningJobEventsParameter {
return func(args *listFineTuningJobEventsParameters) {
args.after = &after
}
}

func ListFineTuningJobEventsWithLimit(limit int) ListFineTuningJobEventsParameter {
return func(args *listFineTuningJobEventsParameters) {
args.limit = &limit
}
}

// ListFineTuningJobs list fine tuning jobs events.
func (c *Client) ListFineTuningJobEvents(
ctx context.Context,
fineTuningJobID string,
setters ...ListFineTuningJobEventsParameter,
) (response FineTuningJobEventList, err error) {
parameters := &listFineTuningJobEventsParameters{
after: nil,
limit: nil,
}

for _, setter := range setters {
setter(parameters)
}

urlValues := url.Values{}
if parameters.after != nil {
urlValues.Add("after", *parameters.after)
}
if parameters.limit != nil {
urlValues.Add("limit", fmt.Sprintf("%d", *parameters.limit))
}

encodedValues := ""
if len(urlValues) > 0 {
encodedValues = "?" + urlValues.Encode()
}

req, err := c.newRequest(
ctx,
http.MethodGet,
c.fullURL("/fine_tuning/jobs/"+fineTuningJobID+"/events"+encodedValues),
)
if err != nil {
return
}

err = c.sendRequest(req, &response)
return
}
90 changes: 90 additions & 0 deletions fine_tuning_job_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package openai_test

import (
"context"

. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"

"encoding/json"
"fmt"
"net/http"
"testing"
)

const testFineTuninigJobID = "fine-tuning-job-id"

// TestFineTuningJob Tests the fine tuning job endpoint of the API using the mocked server.
func TestFineTuningJob(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler(
"/v1/fine_tuning/jobs",
func(w http.ResponseWriter, r *http.Request) {
var resBytes []byte
resBytes, _ = json.Marshal(FineTuningJob{})
fmt.Fprintln(w, string(resBytes))
},
)

server.RegisterHandler(
"/fine_tuning/jobs/"+testFineTuninigJobID+"/cancel",
func(w http.ResponseWriter, r *http.Request) {
resBytes, _ := json.Marshal(FineTuningJob{})
fmt.Fprintln(w, string(resBytes))
},
)

server.RegisterHandler(
"/v1/fine_tuning/jobs/"+testFineTuninigJobID,
func(w http.ResponseWriter, r *http.Request) {
var resBytes []byte
resBytes, _ = json.Marshal(FineTuningJob{})
fmt.Fprintln(w, string(resBytes))
},
)

server.RegisterHandler(
"/v1/fine_tuning/jobs/"+testFineTuninigJobID+"/events",
func(w http.ResponseWriter, r *http.Request) {
resBytes, _ := json.Marshal(FineTuningJobEventList{})
fmt.Fprintln(w, string(resBytes))
},
)

ctx := context.Background()

_, err := client.CreateFineTuningJob(ctx, FineTuningJobRequest{})
checks.NoError(t, err, "CreateFineTuningJob error")

_, err = client.CancelFineTuningJob(ctx, testFineTuninigJobID)
checks.NoError(t, err, "CancelFineTuningJob error")

_, err = client.RetrieveFineTuningJob(ctx, testFineTuninigJobID)
checks.NoError(t, err, "RetrieveFineTuningJob error")

_, err = client.ListFineTuningJobEvents(ctx, testFineTuninigJobID)
checks.NoError(t, err, "ListFineTuningJobEvents error")

_, err = client.ListFineTuningJobEvents(
ctx,
testFineTuninigJobID,
ListFineTuningJobEventsWithAfter("last-event-id"),
)
checks.NoError(t, err, "ListFineTuningJobEvents error")

_, err = client.ListFineTuningJobEvents(
ctx,
testFineTuninigJobID,
ListFineTuningJobEventsWithLimit(10),
)
checks.NoError(t, err, "ListFineTuningJobEvents error")

_, err = client.ListFineTuningJobEvents(
ctx,
testFineTuninigJobID,
ListFineTuningJobEventsWithAfter("last-event-id"),
ListFineTuningJobEventsWithLimit(10),
)
checks.NoError(t, err, "ListFineTuningJobEvents error")
}