From e2c9a95c424b10bdb5b00bb87c20f7a16c9da6b9 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Thu, 7 Mar 2024 15:18:30 -0500 Subject: [PATCH] chore(internal/protoveneer): support custom field converters (#9456) Add the ability to specify conversion functions for specific fields. --- internal/protoveneer/cmd/protoveneer/config.go | 2 ++ internal/protoveneer/cmd/protoveneer/protoveneer.go | 10 ++++++++++ .../cmd/protoveneer/testdata/basic/basic.pb.go | 1 + .../cmd/protoveneer/testdata/basic/config.yaml | 3 +++ .../protoveneer/cmd/protoveneer/testdata/basic/golden | 3 +++ 5 files changed, 19 insertions(+) diff --git a/internal/protoveneer/cmd/protoveneer/config.go b/internal/protoveneer/cmd/protoveneer/config.go index c1fc9350a8e4..08322ea3ae36 100644 --- a/internal/protoveneer/cmd/protoveneer/config.go +++ b/internal/protoveneer/cmd/protoveneer/config.go @@ -69,6 +69,8 @@ type fieldConfig struct { Type string // veneer type // Omit from output. Omit bool + // Custom conversion functions: "tofunc, fromfunc" + ConvertToFrom string `yaml:"convertToFrom"` } func (c *config) init() { diff --git a/internal/protoveneer/cmd/protoveneer/protoveneer.go b/internal/protoveneer/cmd/protoveneer/protoveneer.go index dccab87c6294..d34ba8b01a88 100644 --- a/internal/protoveneer/cmd/protoveneer/protoveneer.go +++ b/internal/protoveneer/cmd/protoveneer/protoveneer.go @@ -220,6 +220,9 @@ func generate(conf *config, pkg *ast.Package, fset *token.FileSet) (src []byte, // Use the converters map to give every field a converter. for _, ti := range toWrite { for _, f := range ti.fields { + if f.converter != nil { + continue + } f.converter, err = makeConverter(f.af.Type, f.protoType, converters) if err != nil { return nil, fmt.Errorf("%s.%s: %w", ti.protoName, f.protoName, err) @@ -482,6 +485,13 @@ func processField(af *ast.Field, tc *typeConfig, typeInfos map[string]*typeInfo) } af.Type = expr } + if fc.ConvertToFrom != "" { + c, err := parseCustomConverter(id.Name, fc.ConvertToFrom) + if err != nil { + return nil, err + } + fi.converter = c + } } } af.Type = veneerType(af.Type, typeInfos) diff --git a/internal/protoveneer/cmd/protoveneer/testdata/basic/basic.pb.go b/internal/protoveneer/cmd/protoveneer/testdata/basic/basic.pb.go index fef209a7c405..f87738971cac 100755 --- a/internal/protoveneer/cmd/protoveneer/testdata/basic/basic.pb.go +++ b/internal/protoveneer/cmd/protoveneer/testdata/basic/basic.pb.go @@ -95,6 +95,7 @@ type GenerationConfig struct { HarmCat HarmCategory FinishReason Candidate_FinishReason CitMet *CitationMetadata + TopK *float32 } // A collection of source attributions for a piece of content. diff --git a/internal/protoveneer/cmd/protoveneer/testdata/basic/config.yaml b/internal/protoveneer/cmd/protoveneer/testdata/basic/config.yaml index 9a54cbe88242..1946480b065f 100644 --- a/internal/protoveneer/cmd/protoveneer/testdata/basic/config.yaml +++ b/internal/protoveneer/cmd/protoveneer/testdata/basic/config.yaml @@ -37,6 +37,9 @@ types: type: float32 CandidateCount: type: int32 + TopK: + type: '*int32' + convertToFrom: int32pToFloat32p, float32pToInt32p Citation: docVerb: contains diff --git a/internal/protoveneer/cmd/protoveneer/testdata/basic/golden b/internal/protoveneer/cmd/protoveneer/testdata/basic/golden index 65179acae38b..3de88757abd0 100644 --- a/internal/protoveneer/cmd/protoveneer/testdata/basic/golden +++ b/internal/protoveneer/cmd/protoveneer/testdata/basic/golden @@ -147,6 +147,7 @@ type GenerationConfig struct { HarmCat HarmCategory FinishReason FinishReason CitMet *CitationMetadata + TopK *int32 } func (v *GenerationConfig) toProto() *pb.GenerationConfig { @@ -160,6 +161,7 @@ func (v *GenerationConfig) toProto() *pb.GenerationConfig { HarmCat: pb.HarmCategory(v.HarmCat), FinishReason: pb.Candidate_FinishReason(v.FinishReason), CitMet: v.CitMet.toProto(), + TopK: int32pToFloat32p(v.TopK), } } @@ -174,6 +176,7 @@ func (GenerationConfig) fromProto(p *pb.GenerationConfig) *GenerationConfig { HarmCat: HarmCategory(p.HarmCat), FinishReason: FinishReason(p.FinishReason), CitMet: (CitationMetadata{}).fromProto(p.CitMet), + TopK: float32pToInt32p(p.TopK), } }