From 0f2b0549f95217cf720c01ed7a9844ae73948fa1 Mon Sep 17 00:00:00 2001 From: Igor Lazarev Date: Tue, 7 May 2024 23:12:15 +0300 Subject: [PATCH] support basic types --- .github/workflows/release.yaml | 2 +- examples/basic/di/internal/container.go | 36 +++++++ .../di/internal/definitions/container.go | 8 ++ .../basic/di/internal/factories/params.go | 20 ++++ .../basic/di/internal/factories/server.go | 3 +- .../basic/di/internal/lookup/container.go | 8 ++ go.mod | 8 +- go.sum | 19 ++-- pkg/di/factories.go | 73 +++++++++++++- pkg/di/file.go | 4 + pkg/di/file_generation_test.go | 48 ++++++++++ pkg/di/internal_container_generator.go | 15 +-- pkg/di/parsing_test.go | 94 ++++++++++++------- pkg/di/templates.go | 19 ++-- .../single_container_with_basic_types.txt | 81 ++++++++++++++++ 15 files changed, 378 insertions(+), 60 deletions(-) create mode 100644 examples/basic/di/internal/factories/params.go create mode 100644 pkg/di/testdata/single_container_with_basic_types.txt diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index fa0ff15..70759f7 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -21,7 +21,7 @@ jobs: name: Set up Go uses: actions/setup-go@v4 with: - go-version: '1.20' + go-version: '1.21' - name: Run GoReleaser uses: goreleaser/goreleaser-action@v5 diff --git a/examples/basic/di/internal/container.go b/examples/basic/di/internal/container.go index e026fa0..6e86ba8 100644 --- a/examples/basic/di/internal/container.go +++ b/examples/basic/di/internal/container.go @@ -15,6 +15,7 @@ import ( "database/sql" "log" "net/http" + "time" ) type Container struct { @@ -25,6 +26,7 @@ type Container struct { db *sql.DB server *http.Server + params *ParamsContainer api *APIContainer useCases *UseCaseContainer repositories *RepositoryContainer @@ -32,6 +34,7 @@ type Container struct { func NewContainer() *Container { c := &Container{} + c.params = &ParamsContainer{Container: c} c.api = &APIContainer{Container: c} c.useCases = &UseCaseContainer{Container: c} c.repositories = &RepositoryContainer{Container: c} @@ -51,6 +54,14 @@ func (c *Container) SetError(err error) { } } +type ParamsContainer struct { + *Container + + serverPort int + serverHost string + requestTimeout time.Duration +} + type APIContainer struct { *Container @@ -94,6 +105,31 @@ func (c *Container) Server(ctx context.Context) *http.Server { return c.server } +func (c *Container) Params() lookup.ParamsContainer { + return c.params +} + +func (c *ParamsContainer) ServerPort(ctx context.Context) int { + if c.serverPort == 0 && c.err == nil { + c.serverPort = factories.CreateParamsServerPort(ctx, c) + } + return c.serverPort +} + +func (c *ParamsContainer) ServerHost(ctx context.Context) string { + if c.serverHost == "" && c.err == nil { + c.serverHost = factories.CreateParamsServerHost(ctx, c) + } + return c.serverHost +} + +func (c *ParamsContainer) RequestTimeout(ctx context.Context) time.Duration { + if c.requestTimeout == 0 && c.err == nil { + c.requestTimeout = factories.CreateParamsRequestTimeout(ctx, c) + } + return c.requestTimeout +} + func (c *Container) API() lookup.APIContainer { return c.api } diff --git a/examples/basic/di/internal/definitions/container.go b/examples/basic/di/internal/definitions/container.go index 434bbda..9d18bc4 100644 --- a/examples/basic/di/internal/definitions/container.go +++ b/examples/basic/di/internal/definitions/container.go @@ -4,6 +4,7 @@ import ( "database/sql" "log" "net/http" + "time" "basic/app/config" "basic/app/domain" @@ -20,11 +21,18 @@ type Container struct { Server *http.Server `di:"public,close" factory-file:"server"` + Params ParamsContainer API APIContainer UseCases UseCaseContainer Repositories RepositoryContainer } +type ParamsContainer struct { + ServerPort int + ServerHost string + RequestTimeout time.Duration +} + type APIContainer struct { FindEntityHandler *httphandler.FindEntity `di:"public"` } diff --git a/examples/basic/di/internal/factories/params.go b/examples/basic/di/internal/factories/params.go new file mode 100644 index 0000000..4dcc34f --- /dev/null +++ b/examples/basic/di/internal/factories/params.go @@ -0,0 +1,20 @@ +package factories + +import ( + "context" + "time" + + "basic/di/internal/lookup" +) + +func CreateParamsServerPort(ctx context.Context, c lookup.Container) int { + return 3000 +} + +func CreateParamsServerHost(ctx context.Context, c lookup.Container) string { + return "127.0.0.1" +} + +func CreateParamsRequestTimeout(ctx context.Context, c lookup.Container) time.Duration { + return time.Second +} diff --git a/examples/basic/di/internal/factories/server.go b/examples/basic/di/internal/factories/server.go index 55af64f..757af22 100644 --- a/examples/basic/di/internal/factories/server.go +++ b/examples/basic/di/internal/factories/server.go @@ -9,6 +9,7 @@ import ( func CreateServer(ctx context.Context, c lookup.Container) *http.Server { return &http.Server{ - Handler: c.API().FindEntityHandler(ctx), + Handler: c.API().FindEntityHandler(ctx), + IdleTimeout: c.Params().RequestTimeout(ctx), } } diff --git a/examples/basic/di/internal/lookup/container.go b/examples/basic/di/internal/lookup/container.go index a2a390a..812bab5 100644 --- a/examples/basic/di/internal/lookup/container.go +++ b/examples/basic/di/internal/lookup/container.go @@ -13,6 +13,7 @@ import ( "database/sql" "log" "net/http" + "time" ) type Container interface { @@ -24,11 +25,18 @@ type Container interface { DB(ctx context.Context) *sql.DB Server(ctx context.Context) *http.Server + Params() ParamsContainer API() APIContainer UseCases() UseCaseContainer Repositories() RepositoryContainer } +type ParamsContainer interface { + ServerPort(ctx context.Context) int + ServerHost(ctx context.Context) string + RequestTimeout(ctx context.Context) time.Duration +} + type APIContainer interface { FindEntityHandler(ctx context.Context) *httphandler.FindEntity } diff --git a/go.mod b/go.mod index 0e98a97..e63169a 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/strider2038/digen -go 1.20 +go 1.21 require ( github.com/iancoleman/strcase v0.3.0 @@ -10,7 +10,7 @@ require ( github.com/spf13/cobra v1.7.0 github.com/spf13/viper v1.8.1 github.com/stretchr/testify v1.7.0 - golang.org/x/mod v0.13.0 + golang.org/x/mod v0.17.0 ) require ( @@ -21,19 +21,21 @@ require ( github.com/gookit/color v1.4.2 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/kr/pretty v0.3.1 // indirect github.com/magiconair/properties v1.8.5 // indirect github.com/mattn/go-runewidth v0.0.13 // indirect github.com/mitchellh/mapstructure v1.4.1 // indirect github.com/pelletier/go-toml v1.9.3 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rivo/uniseg v0.2.0 // indirect + github.com/rogpeppe/go-internal v1.12.0 // indirect github.com/spf13/afero v1.6.0 // indirect github.com/spf13/cast v1.4.0 // indirect github.com/spf13/jwalterweatherman v1.1.0 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/subosito/gotenv v1.2.0 // indirect github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778 // indirect - golang.org/x/sys v0.0.0-20210910150752-751e447fb3d0 // indirect + golang.org/x/sys v0.20.0 // indirect golang.org/x/term v0.0.0-20210615171337-6886f2dfbf5b // indirect golang.org/x/text v0.3.7 // indirect gopkg.in/ini.v1 v1.62.0 // indirect diff --git a/go.sum b/go.sum index b80af46..37d8e8c 100644 --- a/go.sum +++ b/go.sum @@ -63,6 +63,7 @@ github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnht github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -183,11 +184,13 @@ github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfV github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= -github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/magiconair/properties v1.8.5 h1:b6kJs+EmPFMYGkow9GiUyCyOvIwYetYJ3fSaWak/Gls= github.com/magiconair/properties v1.8.5/go.mod h1:y3VJvCyxH9uVvJTWEGAELF3aiYNyPKd5NZ3oSwXrF60= github.com/manifoldco/promptui v0.9.0 h1:3V4HzJk1TtXW1MTZMP7mdlwbBpIinw3HztaIlYthEiA= @@ -214,6 +217,7 @@ github.com/muonsoft/errors v0.4.1/go.mod h1:+vu8wBT7mW3vPKTfPQF+un6Vw3uG3EX3Cyj4 github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= github.com/pelletier/go-toml v1.9.3 h1:zeC5b1GviRUyKYd6OJPvBU/mcVDVoL1OhT17FCt5dSQ= github.com/pelletier/go-toml v1.9.3/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/sftp v1.10.1/go.mod h1:lYOWFsE0bwd1+KfKJaKeuokY15vzFx25BLbzYYoAxZI= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -227,6 +231,9 @@ github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= +github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/ryanuber/columnize v0.0.0-20160712163229-9b3edd62028f/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc= @@ -319,8 +326,8 @@ golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.13.0 h1:I/DsJXRlw/8l/0c24sM9yb0T4z9liZTduXvdAWYiysY= -golang.org/x/mod v0.13.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= +golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181023162649-9b4f9f5ad519/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -423,8 +430,8 @@ golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210910150752-751e447fb3d0 h1:xrCZDmdtoloIiooiA9q0OQb9r8HejIHYoHGhGCe1pGg= -golang.org/x/sys v0.0.0-20210910150752-751e447fb3d0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= +golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210220032956-6a3ed077a48d/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210615171337-6886f2dfbf5b h1:9zKuko04nR4gjZ4+DNjHqRlAJqbJETHwiNKDqTfOjfE= diff --git a/pkg/di/factories.go b/pkg/di/factories.go index 00d44c2..7c946a4 100644 --- a/pkg/di/factories.go +++ b/pkg/di/factories.go @@ -2,6 +2,7 @@ package di import ( "go/ast" + "slices" "strings" ) @@ -99,6 +100,39 @@ type TypeDefinition struct { Name string } +var basicTypes = []string{ + "string", + "int", + "uint", + "uint8", + "uint16", + "uint32", + "uint64", + "int8", + "int16", + "int32", + "int64", + "float32", + "float64", + "bool", +} + +func (d TypeDefinition) IsBasicType() bool { + return d.Package == "" && slices.Contains(basicTypes, d.Name) +} + +func (d TypeDefinition) IsTime() bool { + return d.Package == "time" && d.Name == "Time" +} + +func (d TypeDefinition) IsDuration() bool { + return d.Package == "time" && d.Name == "Duration" +} + +func (d TypeDefinition) IsURL() bool { + return d.Package == "url" && d.Name == "URL" +} + func (d TypeDefinition) String() string { var s strings.Builder @@ -106,12 +140,49 @@ func (d TypeDefinition) String() string { s.WriteString("*") } s.WriteString(d.Package) - s.WriteString(".") + if d.Package != "" { + s.WriteString(".") + } s.WriteString(d.Name) return s.String() } +func (d TypeDefinition) ZeroComparison() string { + if d.IsPointer { + return " == nil" + } + if d.IsBasicType() { + return d.basicZeroComparison() + } + if d.IsTime() { + return ".IsZero()" + } + if d.IsDuration() { + return " == 0" + } + if d.IsURL() { + return " == url.URL{}" + } + + return " == nil" +} + +func (d TypeDefinition) basicZeroComparison() string { + switch d.Name { + case "bool": + return " == false" + case "string": + return " == \"\"" + case "int", "int8", "int16", "int32", "int64", "uint", "uint8", "uint16", "uint32", "uint64": + return " == 0" + case "float32", "float64": + return " == 0.0" + default: + return " == nil" + } +} + type FactoryFile struct { Imports map[string]*ImportDefinition Services []string diff --git a/pkg/di/file.go b/pkg/di/file.go index 5e18873..7a79a86 100644 --- a/pkg/di/file.go +++ b/pkg/di/file.go @@ -56,6 +56,10 @@ func NewFileBuilder(filename, packageName string, packageType PackageType) *File } func (b *FileBuilder) AddImport(imp string) { + if imp == "" { + return + } + for _, existingImport := range b.imports { if existingImport == imp { return diff --git a/pkg/di/file_generation_test.go b/pkg/di/file_generation_test.go index 486137f..b850504 100644 --- a/pkg/di/file_generation_test.go +++ b/pkg/di/file_generation_test.go @@ -1,6 +1,7 @@ package di_test import ( + _ "embed" "testing" "github.com/stretchr/testify/assert" @@ -166,6 +167,50 @@ func TestGenerate(t *testing.T) { assert.Equal(t, singleContainerWithStaticTypeInternalContainer, string(files[0].Content)) }, }, + { + name: "single container with basic types", + container: &di.RootContainerDefinition{ + Name: "Container", + Package: "testpkg", + Imports: map[string]*di.ImportDefinition{ + "time": {Path: `"time"`}, + "url": {Path: `"net/url"`}, + }, + Services: []*di.ServiceDefinition{ + { + Name: "StringOption", + Type: di.TypeDefinition{Name: "string"}, + }, + { + Name: "StringPointer", + Type: di.TypeDefinition{IsPointer: true, Name: "string"}, + }, + { + Name: "IntOption", + Type: di.TypeDefinition{Name: "int"}, + }, + { + Name: "TimeOption", + Type: di.TypeDefinition{Package: "time", Name: "Time"}, + }, + { + Name: "DurationOption", + Type: di.TypeDefinition{Package: "time", Name: "Duration"}, + }, + { + Name: "URLOption", + Type: di.TypeDefinition{Package: "url", Name: "URL"}, + }, + }, + }, + assert: func(t *testing.T, files []*di.File) { + t.Helper() + + require.GreaterOrEqual(t, len(files), 1) + got := string(files[0].Content) + assert.Equal(t, singleContainerWithBasicTypes, got) + }, + }, { name: "single container with closer", container: &di.RootContainerDefinition{ @@ -677,6 +722,9 @@ func (c *Container) SetConfiguration(s config.Configuration) { func (c *Container) Close() {} ` +//go:embed testdata/single_container_with_basic_types.txt +var singleContainerWithBasicTypes string + const singleContainerWithCloserInternalContainer = `package internal import ( diff --git a/pkg/di/internal_container_generator.go b/pkg/di/internal_container_generator.go index 8ae9ad5..5100c63 100644 --- a/pkg/di/internal_container_generator.go +++ b/pkg/di/internal_container_generator.go @@ -139,13 +139,14 @@ func (g *InternalContainerGenerator) writeServiceGetters(services []*ServiceDefi g.importService(service) parameters := templateParameters{ - ContainerName: containerName, - ServicePrefix: strings.Title(service.Prefix), - ServiceName: strcase.ToLowerCamel(service.Name), - ServiceTitle: service.Title(), - ServiceType: service.Type.String(), - HasDefinition: !service.IsRequired, - PanicOnNil: service.IsExternal, + ContainerName: containerName, + ServicePrefix: strings.Title(service.Prefix), + ServiceName: strcase.ToLowerCamel(service.Name), + ServiceTitle: service.Title(), + ServiceType: service.Type.String(), + ServiceZeroComparison: service.Type.ZeroComparison(), + HasDefinition: !service.IsRequired, + PanicOnNil: service.IsExternal, } err := g.writeGetter(parameters, service) if err != nil { diff --git a/pkg/di/parsing_test.go b/pkg/di/parsing_test.go index 8131104..c69bb1d 100644 --- a/pkg/di/parsing_test.go +++ b/pkg/di/parsing_test.go @@ -23,6 +23,11 @@ type Container struct { EntityRepository domain.EntityRepository ` + "`di:\"required,set,close,public,external\"`" + ` Handler *httpadapter.GetEntityHandler ` + "`factory-file:\"http_handler\"`" + ` + StringOption string + IntOption int + DurationOption time.Duration + StringPointer *string + UseCase UseCaseContainer } @@ -107,38 +112,63 @@ func assertExpectedContainerImports(t *testing.T, imports map[string]*di.ImportD } func assertExpectedContainerServices(t *testing.T, services []*di.ServiceDefinition) { - if assert.Len(t, services, 3) { - assert.Equal(t, "Configuration", services[0].Name) - assert.Equal(t, "config", services[0].Type.Package) - assert.Equal(t, "Configuration", services[0].Type.Name) - assert.False(t, services[0].Type.IsPointer) - assert.False(t, services[0].HasSetter) - assert.False(t, services[0].HasCloser) - assert.False(t, services[0].IsRequired) - assert.False(t, services[0].IsPublic) - assert.False(t, services[0].IsExternal) - - assert.Equal(t, "EntityRepository", services[1].Name) - assert.Equal(t, "domain", services[1].Type.Package) - assert.Equal(t, "EntityRepository", services[1].Type.Name) - assert.False(t, services[1].Type.IsPointer) - assert.True(t, services[1].HasSetter) - assert.True(t, services[1].HasCloser) - assert.True(t, services[1].IsRequired) - assert.True(t, services[1].IsPublic) - assert.True(t, services[1].IsExternal) - - assert.Equal(t, "Handler", services[2].Name) - assert.Equal(t, "httpadapter", services[2].Type.Package) - assert.Equal(t, "GetEntityHandler", services[2].Type.Name) - assert.Equal(t, "http_handler.go", services[2].FactoryFileName) - assert.True(t, services[2].Type.IsPointer) - assert.False(t, services[2].HasSetter) - assert.False(t, services[2].HasCloser) - assert.False(t, services[2].IsRequired) - assert.False(t, services[2].IsPublic) - assert.False(t, services[2].IsExternal) - } + require.Len(t, services, 7) + + assert.Equal(t, "Configuration", services[0].Name) + assert.Equal(t, "config", services[0].Type.Package) + assert.Equal(t, "Configuration", services[0].Type.Name) + assert.False(t, services[0].Type.IsPointer) + assert.False(t, services[0].HasSetter) + assert.False(t, services[0].HasCloser) + assert.False(t, services[0].IsRequired) + assert.False(t, services[0].IsPublic) + assert.False(t, services[0].IsExternal) + + assert.Equal(t, "EntityRepository", services[1].Name) + assert.Equal(t, "domain", services[1].Type.Package) + assert.Equal(t, "EntityRepository", services[1].Type.Name) + assert.False(t, services[1].Type.IsPointer) + assert.True(t, services[1].HasSetter) + assert.True(t, services[1].HasCloser) + assert.True(t, services[1].IsRequired) + assert.True(t, services[1].IsPublic) + assert.True(t, services[1].IsExternal) + + assert.Equal(t, "Handler", services[2].Name) + assert.Equal(t, "httpadapter", services[2].Type.Package) + assert.Equal(t, "GetEntityHandler", services[2].Type.Name) + assert.Equal(t, "http_handler.go", services[2].FactoryFileName) + assert.True(t, services[2].Type.IsPointer) + assert.False(t, services[2].HasSetter) + assert.False(t, services[2].HasCloser) + assert.False(t, services[2].IsRequired) + assert.False(t, services[2].IsPublic) + assert.False(t, services[2].IsExternal) + + assert.Equal(t, "StringOption", services[3].Name) + assert.True(t, services[3].Type.IsBasicType()) + assert.False(t, services[3].Type.IsPointer) + assert.Equal(t, "", services[3].Type.Package) + assert.Equal(t, "string", services[3].Type.Name) + + assert.Equal(t, "IntOption", services[4].Name) + assert.True(t, services[4].Type.IsBasicType()) + assert.False(t, services[4].Type.IsPointer) + assert.Equal(t, "", services[4].Type.Package) + assert.Equal(t, "int", services[4].Type.Name) + + assert.Equal(t, "DurationOption", services[5].Name) + assert.True(t, services[5].Type.IsDuration()) + assert.False(t, services[5].Type.IsPointer) + assert.Equal(t, "time", services[5].Type.Package) + assert.Equal(t, "Duration", services[5].Type.Name) + + assert.Equal(t, "StringPointer", services[6].Name) + assert.True(t, services[6].Type.IsBasicType()) + assert.True(t, services[6].Type.IsPointer) + assert.Equal(t, "", services[6].Type.Package) + assert.Equal(t, "string", services[6].Type.Name) + } func assertExpectedInternalContainers(t *testing.T, containers []*di.ContainerDefinition) { diff --git a/pkg/di/templates.go b/pkg/di/templates.go index d087ea1..e295da0 100644 --- a/pkg/di/templates.go +++ b/pkg/di/templates.go @@ -9,14 +9,15 @@ import ( var readmeTemplate string type templateParameters struct { - ContainerName string - ServicePrefix string - ServicePath string - ServiceName string - ServiceTitle string - ServiceType string - HasDefinition bool - PanicOnNil bool + ContainerName string + ServicePrefix string + ServicePath string + ServiceName string + ServiceTitle string + ServiceType string + ServiceZeroComparison string + HasDefinition bool + PanicOnNil bool } type internalContainerTemplateParameters struct { @@ -70,7 +71,7 @@ func (c *Container) SetError(err error) { var getterTemplate = template.Must(template.New("getter").Parse(` func (c *{{.ContainerName}}) {{.ServiceTitle}}(ctx context.Context) {{.ServiceType}} { -{{ if .HasDefinition }} if c.{{.ServiceName}} == nil && c.err == nil { +{{ if .HasDefinition }} if c.{{.ServiceName}}{{.ServiceZeroComparison}} && c.err == nil { {{ if .PanicOnNil }}panic("missing {{.ServiceTitle}}"){{ else }}c.{{.ServiceName}} = factories.Create{{.ServicePrefix}}{{.ServiceTitle}}(ctx, c){{ end }} } {{ end }} return c.{{.ServiceName}} diff --git a/pkg/di/testdata/single_container_with_basic_types.txt b/pkg/di/testdata/single_container_with_basic_types.txt new file mode 100644 index 0000000..96dc183 --- /dev/null +++ b/pkg/di/testdata/single_container_with_basic_types.txt @@ -0,0 +1,81 @@ +package internal + +import ( + "context" + "time" + "net/url" + "example.com/test/di/internal/factories" +) + +type Container struct { + err error + + stringOption string + stringPointer *string + intOption int + timeOption time.Time + durationOption time.Duration + urloption url.URL +} + +func NewContainer() *Container { + c := &Container{} + + return c +} + +// Error returns the first initialization error, which can be set via SetError in a service definition. +func (c *Container) Error() error { + return c.err +} + +// SetError sets the first error into container. The error is used in the public container to return an initialization error. +func (c *Container) SetError(err error) { + if err != nil && c.err == nil { + c.err = err + } +} + +func (c *Container) StringOption(ctx context.Context) string { + if c.stringOption == "" && c.err == nil { + c.stringOption = factories.CreateStringOption(ctx, c) + } + return c.stringOption +} + +func (c *Container) StringPointer(ctx context.Context) *string { + if c.stringPointer == nil && c.err == nil { + c.stringPointer = factories.CreateStringPointer(ctx, c) + } + return c.stringPointer +} + +func (c *Container) IntOption(ctx context.Context) int { + if c.intOption == 0 && c.err == nil { + c.intOption = factories.CreateIntOption(ctx, c) + } + return c.intOption +} + +func (c *Container) TimeOption(ctx context.Context) time.Time { + if c.timeOption.IsZero() && c.err == nil { + c.timeOption = factories.CreateTimeOption(ctx, c) + } + return c.timeOption +} + +func (c *Container) DurationOption(ctx context.Context) time.Duration { + if c.durationOption == 0 && c.err == nil { + c.durationOption = factories.CreateDurationOption(ctx, c) + } + return c.durationOption +} + +func (c *Container) URLOption(ctx context.Context) url.URL { + if c.urloption == url.URL{} && c.err == nil { + c.urloption = factories.CreateURLOption(ctx, c) + } + return c.urloption +} + +func (c *Container) Close() {}