diff --git a/network/domain.go b/network/domain.go index a205b4eece..ee571e0bce 100644 --- a/network/domain.go +++ b/network/domain.go @@ -26,8 +26,9 @@ import ( ) const ( - resolverFileName = "/etc/resolv.conf" - defaultDomainName = "cluster.local" + resolverFileName = "/etc/resolv.conf" + clusterDomainEnvKey = "CLUSTER_DOMAIN" + defaultDomainName = "cluster.local" ) var ( @@ -55,6 +56,7 @@ func GetClusterDomainName() string { } func getClusterDomainName(r io.Reader) string { + // First look in the conf file. for scanner := bufio.NewScanner(r); scanner.Scan(); { elements := strings.Split(scanner.Text(), " ") if elements[0] != "search" { @@ -66,6 +68,12 @@ func getClusterDomainName(r io.Reader) string { } } } + + // Then look in the ENV. + if domain := os.Getenv(clusterDomainEnvKey); len(domain) > 0 { + return domain + } + // For all abnormal cases return default domain name. return defaultDomainName } diff --git a/network/domain_test.go b/network/domain_test.go index 6ce3ba3e63..e6a61d2358 100644 --- a/network/domain_test.go +++ b/network/domain_test.go @@ -17,6 +17,7 @@ limitations under the License. package network import ( + "os" "strings" "testing" ) @@ -24,46 +25,71 @@ import ( func TestGetDomainName(t *testing.T) { tests := []struct { name string + env string resolvConf string want string - }{ - { - name: "all good", - resolvConf: ` + }{{ + name: "all good", + resolvConf: ` nameserver 1.1.1.1 search default.svc.abc.com svc.abc.com abc.com options ndots:5 `, - want: "abc.com", - }, - { - name: "all good with trailing dot", - resolvConf: ` + want: "abc.com", + }, { + name: "all good with env set but ignored", + resolvConf: ` +nameserver 1.1.1.1 +search default.svc.abc.com svc.abc.com abc.com +options ndots:5 +`, + env: "ignored.com", + want: "abc.com", + }, { + name: "all good from env", + resolvConf: ``, + env: "abc.com", + want: "abc.com", + }, { + name: "all good with trailing dot", + resolvConf: ` nameserver 1.1.1.1 search default.svc.abc.com. svc.abc.com. abc.com. options ndots:5 `, - want: "abc.com", - }, - { - name: "missing search line", - resolvConf: ` + want: "abc.com", + }, { + name: "missing search line", + resolvConf: ` nameserver 1.1.1.1 options ndots:5 `, - want: defaultDomainName, - }, - { - name: "non k8s resolv.conf format", - resolvConf: ` + want: defaultDomainName, + }, { + name: "non k8s resolv.conf format", + resolvConf: ` nameserver 1.1.1.1 search abc.com xyz.com options ndots:5 `, - want: defaultDomainName, - }, - } + want: defaultDomainName, + }} + + domainWas := os.Getenv(clusterDomainEnvKey) + t.Cleanup(func() { + if len(domainWas) > 0 { + _ = os.Setenv(clusterDomainEnvKey, domainWas) + } else { + _ = os.Unsetenv(clusterDomainEnvKey) + } + }) + for _, tt := range tests { + if len(tt.env) > 0 { + _ = os.Setenv(clusterDomainEnvKey, tt.env) + } else { + _ = os.Unsetenv(clusterDomainEnvKey) + } got := getClusterDomainName(strings.NewReader(tt.resolvConf)) if got != tt.want { t.Errorf("Test %s failed expected: %s but got: %s", tt.name, tt.want, got)