diff --git a/jsonschema/json.go b/jsonschema/json.go index c02d250aa..e4eef98e7 100644 --- a/jsonschema/json.go +++ b/jsonschema/json.go @@ -4,6 +4,8 @@ // and/or pass in the schema in []byte format. package jsonschema +import "encoding/json" + type DataType string const ( @@ -17,7 +19,7 @@ const ( ) // Definition is a struct for describing a JSON Schema. -// It is fairly limited and you may have better luck using a third-party library. +// It is fairly limited, and you may have better luck using a third-party library. type Definition struct { // Type specifies the data type of the schema. Type DataType `json:"type,omitempty"` @@ -33,3 +35,24 @@ type Definition struct { // Items specifies which data type an array contains, if the schema type is Array. Items *Definition `json:"items,omitempty"` } + +func (d *Definition) MarshalJSON() ([]byte, error) { + d.initializeProperties() + return json.Marshal(*d) +} + +func (d *Definition) initializeProperties() { + if d.Properties == nil { + d.Properties = make(map[string]Definition) + return + } + + for k, v := range d.Properties { + if v.Properties == nil { + v.Properties = make(map[string]Definition) + } else { + v.initializeProperties() + } + d.Properties[k] = v + } +} diff --git a/jsonschema/json_test.go b/jsonschema/json_test.go new file mode 100644 index 000000000..0dc31a58a --- /dev/null +++ b/jsonschema/json_test.go @@ -0,0 +1,201 @@ +package jsonschema_test + +import ( + "encoding/json" + "reflect" + "testing" + + . "github.com/sashabaranov/go-openai/jsonschema" +) + +func TestDefinition_MarshalJSON(t *testing.T) { + tests := []struct { + name string + def Definition + want string + }{ + { + name: "Test with empty Definition", + def: Definition{}, + want: `{"properties":{}}`, + }, + { + name: "Test with Definition properties set", + def: Definition{ + Type: String, + Description: "A string type", + Properties: map[string]Definition{ + "name": { + Type: String, + }, + }, + }, + want: `{ + "type":"string", + "description":"A string type", + "properties":{ + "name":{ + "type":"string", + "properties":{} + } + } +}`, + }, + { + name: "Test with nested Definition properties", + def: Definition{ + Type: Object, + Properties: map[string]Definition{ + "user": { + Type: Object, + Properties: map[string]Definition{ + "name": { + Type: String, + }, + "age": { + Type: Integer, + }, + }, + }, + }, + }, + want: `{ + "type":"object", + "properties":{ + "user":{ + "type":"object", + "properties":{ + "name":{ + "type":"string", + "properties":{} + }, + "age":{ + "type":"integer", + "properties":{} + } + } + } + } +}`, + }, + { + name: "Test with complex nested Definition", + def: Definition{ + Type: Object, + Properties: map[string]Definition{ + "user": { + Type: Object, + Properties: map[string]Definition{ + "name": { + Type: String, + }, + "age": { + Type: Integer, + }, + "address": { + Type: Object, + Properties: map[string]Definition{ + "city": { + Type: String, + }, + "country": { + Type: String, + }, + }, + }, + }, + }, + }, + }, + want: `{ + "type":"object", + "properties":{ + "user":{ + "type":"object", + "properties":{ + "name":{ + "type":"string", + "properties":{} + }, + "age":{ + "type":"integer", + "properties":{} + }, + "address":{ + "type":"object", + "properties":{ + "city":{ + "type":"string", + "properties":{} + }, + "country":{ + "type":"string", + "properties":{} + } + } + } + } + } + } +}`, + }, + { + name: "Test with Array type Definition", + def: Definition{ + Type: Array, + Items: &Definition{ + Type: String, + }, + Properties: map[string]Definition{ + "name": { + Type: String, + }, + }, + }, + want: `{ + "type":"array", + "items":{ + "type":"string", + "properties":{ + + } + }, + "properties":{ + "name":{ + "type":"string", + "properties":{} + } + } +}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotBytes, err := json.Marshal(&tt.def) + if err != nil { + t.Errorf("Failed to Marshal JSON: error = %v", err) + return + } + + var got map[string]interface{} + err = json.Unmarshal(gotBytes, &got) + if err != nil { + t.Errorf("Failed to Unmarshal JSON: error = %v", err) + return + } + + wantBytes := []byte(tt.want) + var want map[string]interface{} + err = json.Unmarshal(wantBytes, &want) + if err != nil { + t.Errorf("Failed to Unmarshal JSON: error = %v", err) + return + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("MarshalJSON() got = %v, want %v", got, want) + } + }) + } +}