diff --git a/rds/client/client.go b/rds/client/client.go index 96e15c09..44cda47e 100644 --- a/rds/client/client.go +++ b/rds/client/client.go @@ -1,4 +1,4 @@ -// Copyright 2018-2019 Google Inc. +// Copyright 2018-2020 Google Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -24,6 +24,7 @@ import ( "fmt" "math/rand" "net" + "strings" "sync" "time" @@ -45,6 +46,9 @@ type cacheRecord struct { labels map[string]string } +// Default RDS port +const defaultRDSPort = "9314" + // Client represents an RDS based client instance. type Client struct { mu sync.Mutex @@ -142,6 +146,16 @@ func (client *Client) Resolve(name string, ipVer int) (net.IP, error) { return nil, fmt.Errorf("no IPv4 address (IP: %s) for %s", ip.String(), name) } +func (client *Client) connect(serverAddr string) (*grpc.ClientConn, error) { + client.l.Infof("rds.client: using RDS servers at: %s", serverAddr) + + if strings.HasPrefix(serverAddr, "srvlist:///") { + client.dialOpts = append(client.dialOpts, grpc.WithResolvers(&srvListBuilder{defaultPort: defaultRDSPort})) + } + + return grpc.Dial(client.serverOpts.GetServerAddress(), client.dialOpts...) +} + // initListResourcesFunc uses server options to establish a connection with the // given RDS server. func (client *Client) initListResourcesFunc() error { @@ -173,8 +187,7 @@ func (client *Client) initListResourcesFunc() error { client.dialOpts = append(client.dialOpts, grpc.WithPerRPCCredentials(grpcoauth.TokenSource{oauthTS})) } - client.l.Infof("rds.client: using RDS server at: %s", client.serverOpts.GetServerAddress()) - conn, err := grpc.Dial(client.serverOpts.GetServerAddress(), client.dialOpts...) + conn, err := client.connect(client.serverOpts.GetServerAddress()) if err != nil { return err } diff --git a/rds/client/srvlist.go b/rds/client/srvlist.go new file mode 100644 index 00000000..9a39a0a8 --- /dev/null +++ b/rds/client/srvlist.go @@ -0,0 +1,168 @@ +// Copyright 2020 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// This file implements a client-side load balancing resolver for gRPC clients. +// This resolver takes a comma separated list of addresses and sets client +// connection to use those addresses in a round-robin manner. It implements +// the APIs defined in google.golang.org/grpc/resolver. + +package client + +import ( + "math/rand" + "net" + "strings" + + cpRes "github.com/google/cloudprober/targets/resolver" + "google.golang.org/grpc/resolver" +) + +// srvListResolver implements the resolver.Resolver interface. +type srvListResolver struct { + hostList []string + portList []string + r *cpRes.Resolver + cc resolver.ClientConn + defaultPort string +} + +func parseAddr(addr, defaultPort string) (host, port string, err error) { + if ipStr, ok := formatIP(addr); ok { + return ipStr, defaultPort, nil + } + + host, port, err = net.SplitHostPort(addr) + if err != nil { + return "", "", err + } + + if port == "" { + port = defaultPort + } + + // target has port, i.e ipv4-host:port, [ipv6-host]:port, host-name:port + if host == "" { + // Keep consistent with net.Dial(): If the host is empty, as in ":80", the local system is assumed. + host = "localhost" + } + + return +} + +// formatIP returns ok = false if addr is not a valid textual representation of an IP address. +// If addr is an IPv4 address, return the addr and ok = true. +// If addr is an IPv6 address, return the addr enclosed in square brackets and ok = true. +func formatIP(addr string) (addrIP string, ok bool) { + ip := net.ParseIP(addr) + if ip == nil { + return "", false + } + if ip.To4() != nil { + return addr, true + } + return "[" + addr + "]", true +} + +func (res *srvListResolver) resolve() (*resolver.State, error) { + state := &resolver.State{} + + for i, host := range res.hostList { + if ipStr, ok := formatIP(host); ok { + state.Addresses = append(state.Addresses, resolver.Address{ + Addr: ipStr + ":" + res.portList[i], + }) + continue + } + + ip, err := res.r.Resolve(host, 0) + if err != nil { + return nil, err + } + state.Addresses = append(state.Addresses, resolver.Address{ + Addr: ip.String() + ":" + res.portList[i], + }) + } + + // Set round robin policy. + state.ServiceConfig = res.cc.ParseServiceConfig("{\"loadBalancingPolicy\": \"round_robin\"}") + return state, nil +} + +func (res *srvListResolver) ResolveNow(_ resolver.ResolveNowOptions) { + state, err := res.resolve() + if err != nil { + res.cc.ReportError(err) + return + } + + res.cc.UpdateState(*state) +} + +func (res *srvListResolver) Close() { +} + +func newSrvListResolver(target, defaultPort string) (*srvListResolver, error) { + res := &srvListResolver{ + r: cpRes.New(), + defaultPort: defaultPort, + } + + addrs := strings.Split(target, ",") + + // Shuffle addresses to create variance in what order different clients start + // connecting to these addresses. Note that round-robin load balancing policy + // takes care of distributing load evenly over time. + rand.Shuffle(len(addrs), func(i, j int) { + addrs[i], addrs[j] = addrs[j], addrs[i] + }) + + for _, addr := range addrs { + host, port, err := parseAddr(addr, defaultPort) + if err != nil { + return nil, err + } + + res.hostList = append(res.hostList, host) + res.portList = append(res.portList, port) + } + + return res, nil +} + +type srvListBuilder struct { + defaultPort string +} + +// Scheme returns the naming scheme of this resolver builder, which is "srvlist". +func (slb *srvListBuilder) Scheme() string { + return "srvlist" +} + +func (slb *srvListBuilder) Build(target resolver.Target, cc resolver.ClientConn, _ resolver.BuildOptions) (resolver.Resolver, error) { + res, err := newSrvListResolver(target.Endpoint, slb.defaultPort) + if err != nil { + return nil, err + } + + res.cc = cc + + state, err := res.resolve() + if err != nil { + res.cc.ReportError(err) + } else { + res.cc.UpdateState(*state) + } + + return res, nil +} diff --git a/rds/client/srvlist_test.go b/rds/client/srvlist_test.go new file mode 100644 index 00000000..592520f0 --- /dev/null +++ b/rds/client/srvlist_test.go @@ -0,0 +1,124 @@ +// Copyright 2020 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package client + +import ( + "fmt" + "reflect" + "sort" + "testing" +) + +var testDefaultPort = "9999" + +func TestParseAddr(t *testing.T) { + var tests = []struct { + addr, host, port string + err error + }{ + { + addr: "rds-service:443", + host: "rds-service", + port: "443", + err: nil, + }, + { + addr: "192.168.1.2:4430", + host: "192.168.1.2", + port: "4430", + err: nil, + }, + { + addr: "192.168.1.4", + host: "192.168.1.4", + port: testDefaultPort, + err: nil, + }, + { + addr: "1620:15c:2c4:201::ff", + host: "[1620:15c:2c4:201::ff]", + port: testDefaultPort, + err: nil, + }, + { + addr: "rds-service:", + host: "rds-service", + port: testDefaultPort, + err: nil, + }, + { + addr: ":9314", + host: "localhost", + port: "9314", + err: nil, + }, + } + + for _, test := range tests { + t.Run(fmt.Sprintf("parsing %s", test.addr), func(t *testing.T) { + host, port, err := parseAddr(test.addr, testDefaultPort) + + if host != test.host || port != test.port || err != test.err { + t.Errorf("parseAddr(%s)=%s, %s, %v, want=%s, %s, %v", test.addr, host, port, err, test.host, test.port, test.err) + } + }) + } +} + +func TestNewResolver(t *testing.T) { + var tests = []struct { + target string + hosts []string + ports []string + }{ + { + target: "rds-service-a:443,rds-service-b:9314", + hosts: []string{"rds-service-a", "rds-service-b"}, + ports: []string{"443", "9314"}, + }, + { + target: "35.14.14.1:443,rds-service-b:9314", + hosts: []string{"35.14.14.1", "rds-service-b"}, + ports: []string{"443", "9314"}, + }, + { + target: "35.14.14.1,35.14.14.2", + hosts: []string{"35.14.14.1", "35.14.14.2"}, + ports: []string{testDefaultPort, testDefaultPort}, + }, + } + + for _, test := range tests { + t.Run(fmt.Sprintf("parsing %s", test.target), func(t *testing.T) { + res, err := newSrvListResolver(test.target, testDefaultPort) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + sort.Strings(res.hostList) + sort.Strings(test.hosts) + if !reflect.DeepEqual(res.hostList, test.hosts) { + t.Errorf("res.hostList, got=%v, want=%v", res.hostList, test.hosts) + } + + sort.Strings(res.portList) + sort.Strings(test.ports) + if !reflect.DeepEqual(res.portList, test.ports) { + t.Errorf("res.portList, got=%v, want=%v", res.portList, test.ports) + } + }) + } + +}