diff --git a/cmd/lekko/gen/ts.go b/cmd/lekko/gen/ts.go index 7055b4b2..761877e8 100644 --- a/cmd/lekko/gen/ts.go +++ b/cmd/lekko/gen/ts.go @@ -305,14 +305,26 @@ export function {{$.FuncName}}({{$.Parameters}}): {{$.RetType}} { } } + optionalVariables := make(map[string]string) + translateFeatureToOptionalVariables(f, optionalVariables) + usedVariables := make(map[string]string) - code := translateFeatureTS(f, usedVariables) + code := translateFeatureTS(f, usedVariables, optionalVariables) if len(parameters) == 0 && len(usedVariables) > 0 { var keys []string var keyAndTypes []string for k, t := range usedVariables { + if _, exists := optionalVariables[k]; !exists { + keys = append(keys, k) + keyAndTypes = append(keyAndTypes, fmt.Sprintf("%s: %s", k, t)) + } + } + for k, t := range optionalVariables { keys = append(keys, k) - keyAndTypes = append(keyAndTypes, fmt.Sprintf("%s: %s", k, t)) + if paramType, exists := usedVariables[k]; exists { + t = paramType + } + keyAndTypes = append(keyAndTypes, fmt.Sprintf("%s?: %s", k, t)) } parameters = fmt.Sprintf("{%s}: {%s}", strings.Join(keys, ","), strings.Join(keyAndTypes, ",")) } @@ -345,14 +357,20 @@ export function {{$.FuncName}}({{$.Parameters}}): {{$.RetType}} { return ret.String(), nil } -func translateFeatureTS(f *featurev1beta1.Feature, usedVariables map[string]string) []string { +func translateFeatureToOptionalVariables(f *featurev1beta1.Feature, optionalVariables map[string]string) { + for _, constraint := range f.Tree.Constraints { + getOptionalVariables(constraint.GetRuleAstNew(), optionalVariables) + } +} + +func translateFeatureTS(f *featurev1beta1.Feature, usedVariables map[string]string, optionalVariables map[string]string) []string { var buffer []string for i, constraint := range f.Tree.Constraints { ifToken := "} else if" if i == 0 { ifToken = "if" } - rule := translateRuleTS(constraint.GetRuleAstNew(), usedVariables) + rule := translateRuleTS(constraint.GetRuleAstNew(), usedVariables, optionalVariables) buffer = append(buffer, fmt.Sprintf("\t%s %s {", ifToken, rule)) // TODO this doesn't work for proto, but let's try @@ -378,7 +396,42 @@ func structpbValueToKindString(v *structpb.Value) string { return "unknown" } -func translateRuleTS(rule *rulesv1beta3.Rule, usedVariables map[string]string) string { +func getOptionalVariables(rule *rulesv1beta3.Rule, optionalVariables map[string]string) { + if rule == nil { + return + } + switch v := rule.GetRule().(type) { + case *rulesv1beta3.Rule_Atom: + switch v.Atom.GetComparisonOperator() { + case rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_PRESENT: + optionalVariables[v.Atom.ContextKey] = "string | boolean | number" + } + case *rulesv1beta3.Rule_LogicalExpression: + for _, rule := range v.LogicalExpression.Rules { + getOptionalVariables(rule, optionalVariables) + } + } +} + +func createUndefinedSafeJSExpression(contextKey string, comparisonValue *structpb.Value, operator string, optionalVariables map[string]string, marshalOptions protojson.MarshalOptions, usedVariables map[string]string) string { + usedVariables[contextKey] = structpbValueToKindString(comparisonValue) + + marshaledValue, err := marshalOptions.Marshal(comparisonValue) + if err != nil { + return fmt.Sprintf("Error marshaling comparison value: %v", err) + } + + marshaledValueStr := string(marshaledValue) + + methodCall := fmt.Sprintf("%s.%s(%s)", contextKey, operator, marshaledValueStr) + if _, ok := optionalVariables[contextKey]; ok { + methodCall = fmt.Sprintf("%s?.%s(%s)", contextKey, operator, marshaledValueStr) + } + + return fmt.Sprintf("(%s)", methodCall) +} + +func translateRuleTS(rule *rulesv1beta3.Rule, usedVariables map[string]string, optionalVariables map[string]string) string { marshalOptions := protojson.MarshalOptions{ UseProtoNames: true, } @@ -415,14 +468,16 @@ func translateRuleTS(rule *rulesv1beta3.Rule, usedVariables map[string]string) s usedVariables[v.Atom.ContextKey] = structpbValueToKindString(v.Atom.ComparisonValue) return fmt.Sprintf("(%s >= %s)", v.Atom.ContextKey, try.To1(marshalOptions.Marshal(v.Atom.ComparisonValue))) case rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_CONTAINS: - usedVariables[v.Atom.ContextKey] = structpbValueToKindString(v.Atom.ComparisonValue) - return fmt.Sprintf("(%s.includes(%s))", v.Atom.ContextKey, try.To1(marshalOptions.Marshal(v.Atom.ComparisonValue))) + return createUndefinedSafeJSExpression(v.Atom.ContextKey, v.Atom.ComparisonValue, "includes", optionalVariables, marshalOptions, usedVariables) case rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_STARTS_WITH: - usedVariables[v.Atom.ContextKey] = structpbValueToKindString(v.Atom.ComparisonValue) - return fmt.Sprintf("(%s.startsWith(%s))", v.Atom.ContextKey, try.To1(marshalOptions.Marshal(v.Atom.ComparisonValue))) + return createUndefinedSafeJSExpression(v.Atom.ContextKey, v.Atom.ComparisonValue, "startsWith", optionalVariables, marshalOptions, usedVariables) case rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_ENDS_WITH: - usedVariables[v.Atom.ContextKey] = structpbValueToKindString(v.Atom.ComparisonValue) - return fmt.Sprintf("(%s.endsWith(%s))", v.Atom.ContextKey, try.To1(marshalOptions.Marshal(v.Atom.ComparisonValue))) + return createUndefinedSafeJSExpression(v.Atom.ContextKey, v.Atom.ComparisonValue, "endsWith", optionalVariables, marshalOptions, usedVariables) + case rulesv1beta3.ComparisonOperator_COMPARISON_OPERATOR_PRESENT: + if _, ok := usedVariables[v.Atom.ContextKey]; !ok { + usedVariables[v.Atom.ContextKey] = "string | boolean | number" + } + return fmt.Sprintf("(%s !== undefined)", v.Atom.ContextKey) } case *rulesv1beta3.Rule_LogicalExpression: operator := " && " @@ -433,7 +488,7 @@ func translateRuleTS(rule *rulesv1beta3.Rule, usedVariables map[string]string) s var result []string for _, rule := range v.LogicalExpression.Rules { // worry about inner parens later - result = append(result, translateRuleTS(rule, usedVariables)) + result = append(result, translateRuleTS(rule, usedVariables, optionalVariables)) } return "(" + strings.Join(result, operator) + ")" }