Skip to content
This repository has been archived by the owner on Nov 5, 2021. It is now read-only.

Add load balancing support to RDS client. #362

Merged
merged 1 commit into from
Jan 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions rds/client/client.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2018-2019 Google Inc.
// Copyright 2018-2020 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -24,6 +24,7 @@ import (
"fmt"
"math/rand"
"net"
"strings"
"sync"
"time"

Expand All @@ -45,6 +46,9 @@ type cacheRecord struct {
labels map[string]string
}

// Default RDS port
const defaultRDSPort = "9314"

// Client represents an RDS based client instance.
type Client struct {
mu sync.Mutex
Expand Down Expand Up @@ -142,6 +146,16 @@ func (client *Client) Resolve(name string, ipVer int) (net.IP, error) {
return nil, fmt.Errorf("no IPv4 address (IP: %s) for %s", ip.String(), name)
}

func (client *Client) connect(serverAddr string) (*grpc.ClientConn, error) {
client.l.Infof("rds.client: using RDS servers at: %s", serverAddr)

if strings.HasPrefix(serverAddr, "srvlist:///") {
client.dialOpts = append(client.dialOpts, grpc.WithResolvers(&srvListBuilder{defaultPort: defaultRDSPort}))
}

return grpc.Dial(client.serverOpts.GetServerAddress(), client.dialOpts...)
}

// initListResourcesFunc uses server options to establish a connection with the
// given RDS server.
func (client *Client) initListResourcesFunc() error {
Expand Down Expand Up @@ -173,8 +187,7 @@ func (client *Client) initListResourcesFunc() error {
client.dialOpts = append(client.dialOpts, grpc.WithPerRPCCredentials(grpcoauth.TokenSource{oauthTS}))
}

client.l.Infof("rds.client: using RDS server at: %s", client.serverOpts.GetServerAddress())
conn, err := grpc.Dial(client.serverOpts.GetServerAddress(), client.dialOpts...)
conn, err := client.connect(client.serverOpts.GetServerAddress())
if err != nil {
return err
}
Expand Down
168 changes: 168 additions & 0 deletions rds/client/srvlist.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
// Copyright 2020 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// This file implements a client-side load balancing resolver for gRPC clients.
// This resolver takes a comma separated list of addresses and sets client
// connection to use those addresses in a round-robin manner. It implements
// the APIs defined in google.golang.org/grpc/resolver.

package client

import (
"math/rand"
"net"
"strings"

cpRes "github.com/google/cloudprober/targets/resolver"
"google.golang.org/grpc/resolver"
)

// srvListResolver implements the resolver.Resolver interface.
type srvListResolver struct {
hostList []string
portList []string
r *cpRes.Resolver
cc resolver.ClientConn
defaultPort string
}

func parseAddr(addr, defaultPort string) (host, port string, err error) {
if ipStr, ok := formatIP(addr); ok {
return ipStr, defaultPort, nil
}

host, port, err = net.SplitHostPort(addr)
if err != nil {
return "", "", err
}

if port == "" {
port = defaultPort
}

// target has port, i.e ipv4-host:port, [ipv6-host]:port, host-name:port
if host == "" {
// Keep consistent with net.Dial(): If the host is empty, as in ":80", the local system is assumed.
host = "localhost"
}

return
}

// formatIP returns ok = false if addr is not a valid textual representation of an IP address.
// If addr is an IPv4 address, return the addr and ok = true.
// If addr is an IPv6 address, return the addr enclosed in square brackets and ok = true.
func formatIP(addr string) (addrIP string, ok bool) {
ip := net.ParseIP(addr)
if ip == nil {
return "", false
}
if ip.To4() != nil {
return addr, true
}
return "[" + addr + "]", true
}

func (res *srvListResolver) resolve() (*resolver.State, error) {
state := &resolver.State{}

for i, host := range res.hostList {
if ipStr, ok := formatIP(host); ok {
state.Addresses = append(state.Addresses, resolver.Address{
Addr: ipStr + ":" + res.portList[i],
})
continue
}

ip, err := res.r.Resolve(host, 0)
if err != nil {
return nil, err
}
state.Addresses = append(state.Addresses, resolver.Address{
Addr: ip.String() + ":" + res.portList[i],
})
}

// Set round robin policy.
state.ServiceConfig = res.cc.ParseServiceConfig("{\"loadBalancingPolicy\": \"round_robin\"}")
return state, nil
}

func (res *srvListResolver) ResolveNow(_ resolver.ResolveNowOptions) {
state, err := res.resolve()
if err != nil {
res.cc.ReportError(err)
return
}

res.cc.UpdateState(*state)
}

func (res *srvListResolver) Close() {
}

func newSrvListResolver(target, defaultPort string) (*srvListResolver, error) {
res := &srvListResolver{
r: cpRes.New(),
defaultPort: defaultPort,
}

addrs := strings.Split(target, ",")

// Shuffle addresses to create variance in what order different clients start
// connecting to these addresses. Note that round-robin load balancing policy
// takes care of distributing load evenly over time.
rand.Shuffle(len(addrs), func(i, j int) {
addrs[i], addrs[j] = addrs[j], addrs[i]
})

for _, addr := range addrs {
host, port, err := parseAddr(addr, defaultPort)
if err != nil {
return nil, err
}

res.hostList = append(res.hostList, host)
res.portList = append(res.portList, port)
}

return res, nil
}

type srvListBuilder struct {
defaultPort string
}

// Scheme returns the naming scheme of this resolver builder, which is "srvlist".
func (slb *srvListBuilder) Scheme() string {
return "srvlist"
}

func (slb *srvListBuilder) Build(target resolver.Target, cc resolver.ClientConn, _ resolver.BuildOptions) (resolver.Resolver, error) {
res, err := newSrvListResolver(target.Endpoint, slb.defaultPort)
if err != nil {
return nil, err
}

res.cc = cc

state, err := res.resolve()
if err != nil {
res.cc.ReportError(err)
} else {
res.cc.UpdateState(*state)
}

return res, nil
}
124 changes: 124 additions & 0 deletions rds/client/srvlist_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
// Copyright 2020 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package client

import (
"fmt"
"reflect"
"sort"
"testing"
)

var testDefaultPort = "9999"

func TestParseAddr(t *testing.T) {
var tests = []struct {
addr, host, port string
err error
}{
{
addr: "rds-service:443",
host: "rds-service",
port: "443",
err: nil,
},
{
addr: "192.168.1.2:4430",
host: "192.168.1.2",
port: "4430",
err: nil,
},
{
addr: "192.168.1.4",
host: "192.168.1.4",
port: testDefaultPort,
err: nil,
},
{
addr: "1620:15c:2c4:201::ff",
host: "[1620:15c:2c4:201::ff]",
port: testDefaultPort,
err: nil,
},
{
addr: "rds-service:",
host: "rds-service",
port: testDefaultPort,
err: nil,
},
{
addr: ":9314",
host: "localhost",
port: "9314",
err: nil,
},
}

for _, test := range tests {
t.Run(fmt.Sprintf("parsing %s", test.addr), func(t *testing.T) {
host, port, err := parseAddr(test.addr, testDefaultPort)

if host != test.host || port != test.port || err != test.err {
t.Errorf("parseAddr(%s)=%s, %s, %v, want=%s, %s, %v", test.addr, host, port, err, test.host, test.port, test.err)
}
})
}
}

func TestNewResolver(t *testing.T) {
var tests = []struct {
target string
hosts []string
ports []string
}{
{
target: "rds-service-a:443,rds-service-b:9314",
hosts: []string{"rds-service-a", "rds-service-b"},
ports: []string{"443", "9314"},
},
{
target: "35.14.14.1:443,rds-service-b:9314",
hosts: []string{"35.14.14.1", "rds-service-b"},
ports: []string{"443", "9314"},
},
{
target: "35.14.14.1,35.14.14.2",
hosts: []string{"35.14.14.1", "35.14.14.2"},
ports: []string{testDefaultPort, testDefaultPort},
},
}

for _, test := range tests {
t.Run(fmt.Sprintf("parsing %s", test.target), func(t *testing.T) {
res, err := newSrvListResolver(test.target, testDefaultPort)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}

sort.Strings(res.hostList)
sort.Strings(test.hosts)
if !reflect.DeepEqual(res.hostList, test.hosts) {
t.Errorf("res.hostList, got=%v, want=%v", res.hostList, test.hosts)
}

sort.Strings(res.portList)
sort.Strings(test.ports)
if !reflect.DeepEqual(res.portList, test.ports) {
t.Errorf("res.portList, got=%v, want=%v", res.portList, test.ports)
}
})
}

}