Skip to content
This repository has been archived by the owner on Nov 5, 2021. It is now read-only.

Commit

Permalink
rds.client: If RDS resource's IP is not a valid IP, try to resolve it…
Browse files Browse the repository at this point in the history
… using system's DNS resolver.

This is to support cases where resources have a hostname but not an IP, for example, AWS ELB:
#418

PiperOrigin-RevId: 319832714
  • Loading branch information
manugarg committed Jul 6, 2020
1 parent c165da1 commit daf7ac0
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 22 deletions.
38 changes: 25 additions & 13 deletions rds/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,21 @@ import (
pb "github.com/google/cloudprober/rds/proto"
spb "github.com/google/cloudprober/rds/proto"
"github.com/google/cloudprober/targets/endpoint"
dnsRes "github.com/google/cloudprober/targets/resolver"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
grpcoauth "google.golang.org/grpc/credentials/oauth"
)

// globalResolver is a singleton DNS resolver that is used as the default
// resolver by targets. It is a singleton because dnsRes.Resolver provides a
// cache layer that is best shared by all probes.
var (
globalResolver *dnsRes.Resolver
)

type cacheRecord struct {
ip net.IP
ip string
port int
labels map[string]string
}
Expand All @@ -59,6 +67,7 @@ type Client struct {
cache map[string]*cacheRecord
names []string
listResources func(context.Context, *pb.ListResourcesRequest) (*pb.ListResourcesResponse, error)
resolver *dnsRes.Resolver
l *logger.Logger
}

Expand All @@ -85,16 +94,7 @@ func (client *Client) updateState(response *pb.ListResourcesResponse) {

client.names = make([]string, len(response.GetResources()))
for i, res := range response.GetResources() {
var ip net.IP

if res.GetIp() != "" {
ip = net.ParseIP(res.GetIp())
if ip == nil {
client.l.Errorf("rds.client: errors parsing IP address for %s, IP string: %s", res.GetName(), res.GetIp())
continue
}
}
client.cache[res.GetName()] = &cacheRecord{ip, int(res.GetPort()), res.Labels}
client.cache[res.GetName()] = &cacheRecord{res.GetIp(), int(res.GetPort()), res.Labels}
client.names[i] = res.GetName()
}
}
Expand All @@ -115,11 +115,17 @@ func (client *Client) ListEndpoints() []endpoint.Endpoint {
func (client *Client) Resolve(name string, ipVer int) (net.IP, error) {
client.mu.RLock()
defer client.mu.RUnlock()

cr, ok := client.cache[name]
if !ok || cr.ip == nil {
if !ok || cr.ip == "" {
return nil, fmt.Errorf("no IP address for the resource: %s", name)
}
ip := cr.ip

ip := net.ParseIP(cr.ip)
// If not a valid IP, use DNS resolver to resolve it.
if ip == nil {
return client.resolver.Resolve(cr.ip, ipVer)
}

if ipVer == 0 || iputils.IPVersion(ip) == ipVer {
return ip, nil
Expand Down Expand Up @@ -189,6 +195,7 @@ func New(c *configpb.ClientConf, listResources ListResourcesFunc, l *logger.Logg
serverOpts: c.GetServerOptions(),
cache: make(map[string]*cacheRecord),
listResources: listResources,
resolver: globalResolver,
l: l,
}

Expand All @@ -213,3 +220,8 @@ func New(c *configpb.ClientConf, listResources ListResourcesFunc, l *logger.Logg

return client, nil
}

// init initializes the package by creating a new global resolver.
func init() {
globalResolver = dnsRes.New()
}
69 changes: 63 additions & 6 deletions rds/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ package client

import (
"context"
"fmt"
"net"
"reflect"
"testing"
"time"
Expand All @@ -26,6 +28,7 @@ import (
pb "github.com/google/cloudprober/rds/proto"
"github.com/google/cloudprober/rds/server"
serverpb "github.com/google/cloudprober/rds/server/proto"
dnsRes "github.com/google/cloudprober/targets/resolver"
)

type testProvider struct {
Expand All @@ -46,9 +49,48 @@ var testResourcesMap = map[string][]*pb.Resource{
Port: proto.Int32(8080),
Labels: map[string]string{"zone": "us-central1-a"},
},
{
Name: proto.String("testR22v6"),
Ip: proto.String("::1"),
Port: proto.Int32(8080),
Labels: map[string]string{"zone": "us-central1-a"},
},
{
Name: proto.String("testR3"),
Ip: proto.String("testR3.test.com"),
Port: proto.Int32(80),
Labels: map[string]string{"zone": "us-central1-c"},
},
},
}

var expectedIPByVersion = map[string]map[int]string{
"testR21": map[int]string{
0: "10.0.2.1",
4: "10.0.2.1",
6: "err",
},
"testR22": map[int]string{
0: "10.0.2.2",
4: "10.0.2.2",
6: "err",
},
"testR22v6": map[int]string{
0: "::1",
4: "err",
6: "::1",
},
"testR3": map[int]string{
0: "10.1.1.2",
4: "10.1.1.2",
6: "::2",
},
}

var testNameToIP = map[string][]net.IP{
"testR3.test.com": []net.IP{net.ParseIP("10.1.1.2"), net.ParseIP("::2")},
}

func (tp *testProvider) ListResources(*pb.ListResourcesRequest) (*pb.ListResourcesResponse, error) {
return &pb.ListResourcesResponse{
Resources: tp.resources,
Expand Down Expand Up @@ -88,6 +130,10 @@ func TestListAndResolve(t *testing.T) {
if err != nil {
t.Fatalf("Got error initializing RDS client: %v", err)
}
client.resolver = dnsRes.NewWithResolve(func(name string) ([]net.IP, error) {
return testNameToIP[name], nil
})

client.refreshState(time.Second)

// Test ListEndpoint()
Expand All @@ -108,12 +154,23 @@ func TestListAndResolve(t *testing.T) {

// Test Resolve()
for _, res := range testResources {
ip, err := client.Resolve(res.GetName(), 4)
if err != nil {
t.Errorf("Error resolving %s, err: %v", res.GetName(), err)
}
if ip.String() != res.GetIp() {
t.Errorf("Didn't get expected IP for %s. Got: %s, Want: %s", res.GetName(), ip.String(), res.GetIp())
for _, ipVer := range []int{0, 4, 6} {
t.Run(fmt.Sprintf("resolve_%s_for_IPv%d", res.GetName(), ipVer), func(t *testing.T) {
expectedIP := expectedIPByVersion[res.GetName()][ipVer]

var gotIP string
ip, err := client.Resolve(res.GetName(), ipVer)
if err != nil {
t.Logf("Error resolving %s, err: %v", res.GetName(), err)
gotIP = "err"
} else {
gotIP = ip.String()
}

if gotIP != expectedIP {
t.Errorf("Didn't get expected IP for %s. Got: %s, Want: %s", res.GetName(), gotIP, expectedIP)
}
})
}
}
}
Expand Down
12 changes: 9 additions & 3 deletions targets/resolver/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,11 +188,17 @@ func (cr *cacheRecord) refreshIfRequired(name string, resolve func(string) ([]ne
}
}

// New returns a new Resolver.
func New() *Resolver {
// NewWithResolve returns a new Resolver with the given backend resolver.
// This is useful for testing.
func NewWithResolve(resolveFunc func(string) ([]net.IP, error)) *Resolver {
return &Resolver{
cache: make(map[string]*cacheRecord),
resolve: func(name string) ([]net.IP, error) { return net.LookupIP(name) },
resolve: resolveFunc,
DefaultMaxAge: defaultMaxAge,
}
}

// New returns a new Resolver.
func New() *Resolver {
return NewWithResolve(net.LookupIP)
}

0 comments on commit daf7ac0

Please sign in to comment.