diff --git a/env.go b/env.go index 5876579..1c118e3 100644 --- a/env.go +++ b/env.go @@ -3,11 +3,14 @@ package ejson2env import ( "errors" "fmt" + "regexp" ) var errNoEnv = errors.New("environment is not set in ejson") var errEnvNotMap = errors.New("environment is not a map[string]interface{}") +var validIdentifierPattern = regexp.MustCompile(`\A[a-zA-Z_][a-zA-Z0-9_]*\z`) + // ExtractEnv extracts the environment values from the map[string]interface{} // containing all secrets, and returns a map[string]string containing the // key value pairs. If there's an issue (the environment key doesn't exist, for @@ -26,6 +29,12 @@ func ExtractEnv(secrets map[string]interface{}) (map[string]string, error) { envSecrets := make(map[string]string, len(envMap)) for key, rawValue := range envMap { + // Reject keys that would be invalid environment variable identifiers + if !validIdentifierPattern.MatchString(key) { + err := fmt.Errorf("invalid identifier as key in environment: %q", key) + + return nil, err + } // Only export values that convert to strings properly. if value, ok := rawValue.(string); ok { diff --git a/env_test.go b/env_test.go index 208db2e..a0f6384 100644 --- a/env_test.go +++ b/env_test.go @@ -94,6 +94,12 @@ func TestInvalidEnvironments(t *testing.T) { "environment": "bad", } + testBadInvalidKey := map[string]interface{}{ + "environment": map[string]interface{}{ + "invalid key": "test_value", + }, + } + var testNoEnv map[string]interface{} _, err := ExtractEnv(testBadNonMap) @@ -103,6 +109,13 @@ func TestInvalidEnvironments(t *testing.T) { t.Errorf("wrong error when passed a non-map environment: %s", err) } + _, err = ExtractEnv(testBadInvalidKey) + if nil == err { + t.Errorf("no error when passed an environment with invalid key") + } else if `invalid identifier as key in environment: "invalid key"` != err.Error() { + t.Errorf("wrong error when passed an environment with invalid key: %s", err) + } + _, err = ExtractEnv(testNoEnv) if nil == err { t.Errorf("no error when passed a non-existent environment") @@ -133,3 +146,45 @@ func TestEscaping(t *testing.T) { } } + +func TestIdentifierPattern(t *testing.T) { + key := "ALL_CAPS123" + if !validIdentifierPattern.MatchString(key) { + t.Errorf("key should match pattern %q: %q", validIdentifierPattern, key) + } + + key = "lowercase" + if !validIdentifierPattern.MatchString(key) { + t.Errorf("key should match pattern %q: %q", validIdentifierPattern, key) + } + + key = "a" + if !validIdentifierPattern.MatchString(key) { + t.Errorf("key should match pattern %q: %q", validIdentifierPattern, key) + } + + key = "_leading_underscore" + if !validIdentifierPattern.MatchString(key) { + t.Errorf("key should match pattern %q: %q", validIdentifierPattern, key) + } + + key = "1_leading_digit" + if validIdentifierPattern.MatchString(key) { + t.Errorf("key should not match pattern %q: %q", validIdentifierPattern, key) + } + + key = "contains whitespace" + if validIdentifierPattern.MatchString(key) { + t.Errorf("key should not match pattern %q: %q", validIdentifierPattern, key) + } + + key = "contains-dash" + if validIdentifierPattern.MatchString(key) { + t.Errorf("key should not match pattern %q: %q", validIdentifierPattern, key) + } + + key = "contains_special_character;" + if validIdentifierPattern.MatchString(key) { + t.Errorf("key should not match pattern %q: %q", validIdentifierPattern, key) + } +}