From daf7ac05987898c752c6d1b1205c63cf61bbed7f Mon Sep 17 00:00:00 2001 From: Manu Garg Date: Mon, 6 Jul 2020 12:31:46 -0700 Subject: [PATCH] rds.client: If RDS resource's IP is not a valid IP, try to resolve it using system's DNS resolver. This is to support cases where resources have a hostname but not an IP, for example, AWS ELB: https://github.com/google/cloudprober/issues/418 PiperOrigin-RevId: 319832714 --- rds/client/client.go | 38 +++++++++++++------- rds/client/client_test.go | 69 ++++++++++++++++++++++++++++++++---- targets/resolver/resolver.go | 12 +++++-- 3 files changed, 97 insertions(+), 22 deletions(-) diff --git a/rds/client/client.go b/rds/client/client.go index 9710cb12..bed36b50 100644 --- a/rds/client/client.go +++ b/rds/client/client.go @@ -36,13 +36,21 @@ import ( pb "github.com/google/cloudprober/rds/proto" spb "github.com/google/cloudprober/rds/proto" "github.com/google/cloudprober/targets/endpoint" + dnsRes "github.com/google/cloudprober/targets/resolver" "google.golang.org/grpc" "google.golang.org/grpc/credentials" grpcoauth "google.golang.org/grpc/credentials/oauth" ) +// globalResolver is a singleton DNS resolver that is used as the default +// resolver by targets. It is a singleton because dnsRes.Resolver provides a +// cache layer that is best shared by all probes. +var ( + globalResolver *dnsRes.Resolver +) + type cacheRecord struct { - ip net.IP + ip string port int labels map[string]string } @@ -59,6 +67,7 @@ type Client struct { cache map[string]*cacheRecord names []string listResources func(context.Context, *pb.ListResourcesRequest) (*pb.ListResourcesResponse, error) + resolver *dnsRes.Resolver l *logger.Logger } @@ -85,16 +94,7 @@ func (client *Client) updateState(response *pb.ListResourcesResponse) { client.names = make([]string, len(response.GetResources())) for i, res := range response.GetResources() { - var ip net.IP - - if res.GetIp() != "" { - ip = net.ParseIP(res.GetIp()) - if ip == nil { - client.l.Errorf("rds.client: errors parsing IP address for %s, IP string: %s", res.GetName(), res.GetIp()) - continue - } - } - client.cache[res.GetName()] = &cacheRecord{ip, int(res.GetPort()), res.Labels} + client.cache[res.GetName()] = &cacheRecord{res.GetIp(), int(res.GetPort()), res.Labels} client.names[i] = res.GetName() } } @@ -115,11 +115,17 @@ func (client *Client) ListEndpoints() []endpoint.Endpoint { func (client *Client) Resolve(name string, ipVer int) (net.IP, error) { client.mu.RLock() defer client.mu.RUnlock() + cr, ok := client.cache[name] - if !ok || cr.ip == nil { + if !ok || cr.ip == "" { return nil, fmt.Errorf("no IP address for the resource: %s", name) } - ip := cr.ip + + ip := net.ParseIP(cr.ip) + // If not a valid IP, use DNS resolver to resolve it. + if ip == nil { + return client.resolver.Resolve(cr.ip, ipVer) + } if ipVer == 0 || iputils.IPVersion(ip) == ipVer { return ip, nil @@ -189,6 +195,7 @@ func New(c *configpb.ClientConf, listResources ListResourcesFunc, l *logger.Logg serverOpts: c.GetServerOptions(), cache: make(map[string]*cacheRecord), listResources: listResources, + resolver: globalResolver, l: l, } @@ -213,3 +220,8 @@ func New(c *configpb.ClientConf, listResources ListResourcesFunc, l *logger.Logg return client, nil } + +// init initializes the package by creating a new global resolver. +func init() { + globalResolver = dnsRes.New() +} diff --git a/rds/client/client_test.go b/rds/client/client_test.go index 9f688e6f..2c708ca6 100644 --- a/rds/client/client_test.go +++ b/rds/client/client_test.go @@ -16,6 +16,8 @@ package client import ( "context" + "fmt" + "net" "reflect" "testing" "time" @@ -26,6 +28,7 @@ import ( pb "github.com/google/cloudprober/rds/proto" "github.com/google/cloudprober/rds/server" serverpb "github.com/google/cloudprober/rds/server/proto" + dnsRes "github.com/google/cloudprober/targets/resolver" ) type testProvider struct { @@ -46,9 +49,48 @@ var testResourcesMap = map[string][]*pb.Resource{ Port: proto.Int32(8080), Labels: map[string]string{"zone": "us-central1-a"}, }, + { + Name: proto.String("testR22v6"), + Ip: proto.String("::1"), + Port: proto.Int32(8080), + Labels: map[string]string{"zone": "us-central1-a"}, + }, + { + Name: proto.String("testR3"), + Ip: proto.String("testR3.test.com"), + Port: proto.Int32(80), + Labels: map[string]string{"zone": "us-central1-c"}, + }, }, } +var expectedIPByVersion = map[string]map[int]string{ + "testR21": map[int]string{ + 0: "10.0.2.1", + 4: "10.0.2.1", + 6: "err", + }, + "testR22": map[int]string{ + 0: "10.0.2.2", + 4: "10.0.2.2", + 6: "err", + }, + "testR22v6": map[int]string{ + 0: "::1", + 4: "err", + 6: "::1", + }, + "testR3": map[int]string{ + 0: "10.1.1.2", + 4: "10.1.1.2", + 6: "::2", + }, +} + +var testNameToIP = map[string][]net.IP{ + "testR3.test.com": []net.IP{net.ParseIP("10.1.1.2"), net.ParseIP("::2")}, +} + func (tp *testProvider) ListResources(*pb.ListResourcesRequest) (*pb.ListResourcesResponse, error) { return &pb.ListResourcesResponse{ Resources: tp.resources, @@ -88,6 +130,10 @@ func TestListAndResolve(t *testing.T) { if err != nil { t.Fatalf("Got error initializing RDS client: %v", err) } + client.resolver = dnsRes.NewWithResolve(func(name string) ([]net.IP, error) { + return testNameToIP[name], nil + }) + client.refreshState(time.Second) // Test ListEndpoint() @@ -108,12 +154,23 @@ func TestListAndResolve(t *testing.T) { // Test Resolve() for _, res := range testResources { - ip, err := client.Resolve(res.GetName(), 4) - if err != nil { - t.Errorf("Error resolving %s, err: %v", res.GetName(), err) - } - if ip.String() != res.GetIp() { - t.Errorf("Didn't get expected IP for %s. Got: %s, Want: %s", res.GetName(), ip.String(), res.GetIp()) + for _, ipVer := range []int{0, 4, 6} { + t.Run(fmt.Sprintf("resolve_%s_for_IPv%d", res.GetName(), ipVer), func(t *testing.T) { + expectedIP := expectedIPByVersion[res.GetName()][ipVer] + + var gotIP string + ip, err := client.Resolve(res.GetName(), ipVer) + if err != nil { + t.Logf("Error resolving %s, err: %v", res.GetName(), err) + gotIP = "err" + } else { + gotIP = ip.String() + } + + if gotIP != expectedIP { + t.Errorf("Didn't get expected IP for %s. Got: %s, Want: %s", res.GetName(), gotIP, expectedIP) + } + }) } } } diff --git a/targets/resolver/resolver.go b/targets/resolver/resolver.go index ee2324da..8474afea 100644 --- a/targets/resolver/resolver.go +++ b/targets/resolver/resolver.go @@ -188,11 +188,17 @@ func (cr *cacheRecord) refreshIfRequired(name string, resolve func(string) ([]ne } } -// New returns a new Resolver. -func New() *Resolver { +// NewWithResolve returns a new Resolver with the given backend resolver. +// This is useful for testing. +func NewWithResolve(resolveFunc func(string) ([]net.IP, error)) *Resolver { return &Resolver{ cache: make(map[string]*cacheRecord), - resolve: func(name string) ([]net.IP, error) { return net.LookupIP(name) }, + resolve: resolveFunc, DefaultMaxAge: defaultMaxAge, } } + +// New returns a new Resolver. +func New() *Resolver { + return NewWithResolve(net.LookupIP) +}