diff --git a/lib/netext/dialer.go b/lib/netext/dialer.go index a1a686ff86d0..62d4ffc0bab9 100644 --- a/lib/netext/dialer.go +++ b/lib/netext/dialer.go @@ -65,47 +65,56 @@ func (b BlackListedIPError) Error() string { return fmt.Sprintf("IP (%s) is in a blacklisted range (%s)", b.ip, b.net) } -func resolveHost(host string, hosts map[string]net.IP, resolver *dnscache.Resolver) (net.IP, error) { +// Resolver is implemented by dnscache.Resolver and used by tests to +// pass a mock resolver. +type Resolver interface { + FetchOne(string) (net.IP, error) +} + +// DialContext wraps the net.Dialer.DialContext and handles the k6 specifics +func (d *Dialer) DialContext(ctx context.Context, proto, addr string) (net.Conn, error) { + address, err := d.checkAndResolveAddress(addr, d.Resolver) + if err != nil { + return nil, err + } + + var conn net.Conn + conn, err = d.Dialer.DialContext(ctx, proto, address) + if err != nil { + return nil, err + } + conn = &Conn{conn, &d.BytesRead, &d.BytesWritten} + return conn, err +} + +func (d *Dialer) checkAndResolveAddress(addr string, resolver Resolver) (string, error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return "", err + } + ip := net.ParseIP(host) if ip == nil { // It's not an IP address, so lookup the hostname in the Hosts // option before trying to resolve DNS. var ok bool - ip, ok = hosts[host] + ip, ok = d.Hosts[host] if !ok { var dnsErr error ip, dnsErr = resolver.FetchOne(host) if dnsErr != nil { - return nil, dnsErr + return "", dnsErr } } } - return ip, nil -} - -// DialContext wraps the net.Dialer.DialContext and handles the k6 specifics -func (d *Dialer) DialContext(ctx context.Context, proto, addr string) (net.Conn, error) { - host, port, err := net.SplitHostPort(addr) - if err != nil { - return nil, err - } - - ip, resErr := resolveHost(host, d.Hosts, d.Resolver) - if resErr != nil { - return nil, resErr - } for _, ipnet := range d.Blacklist { if (*net.IPNet)(ipnet).Contains(ip) { - return nil, BlackListedIPError{ip: ip, net: ipnet} + return "", BlackListedIPError{ip: ip, net: ipnet} } } - conn, err := d.Dialer.DialContext(ctx, proto, net.JoinHostPort(ip.String(), port)) - if err != nil { - return nil, err - } - conn = &Conn{conn, &d.BytesRead, &d.BytesWritten} - return conn, err + + return net.JoinHostPort(ip.String(), port), nil } // GetTrail creates a new NetTrail instance with the Dialer diff --git a/lib/netext/dialer_test.go b/lib/netext/dialer_test.go index 8a80c86759ad..978c481c55cc 100644 --- a/lib/netext/dialer_test.go +++ b/lib/netext/dialer_test.go @@ -1,46 +1,104 @@ +/* + * + * k6 - a next-generation load testing tool + * Copyright (C) 2020 Load Impact + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + * + */ + package netext import ( + "fmt" "net" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/viki-org/dnscache" + + "github.com/loadimpact/k6/lib" ) -func TestDialerResolveHost(t *testing.T) { +type mockResolver struct { + hosts map[string]net.IP +} + +func (r *mockResolver) FetchOne(h string) (net.IP, error) { + var ( + ip net.IP + ok bool + ) + if ip, ok = r.hosts[h]; !ok { + return nil, fmt.Errorf("mock lookup %s: no such host", h) + } + return ip, nil +} + +func TestDialerCheckAndResolveAddress(t *testing.T) { t.Parallel() testCases := []struct { - host string - hosts map[string]net.IP - ipVer int - expErr string + address, expAddress, expErr string }{ - {"1.2.3.4", nil, 4, ""}, - {"256.1.1.1", nil, 4, "lookup 256.1.1.1: no such host"}, - {"example.com", nil, 4, ""}, - {"::1", nil, 6, ""}, - {"::1.2.3.4", nil, 6, ""}, - {"abcd:ef01:2345:6789", nil, 6, "lookup abcd:ef01:2345:6789: no such host"}, - {"2001:db8:aaaa:1::100", nil, 6, ""}, - {"ipv6.google.com", nil, 6, ""}, - {"mycustomv4.host", map[string]net.IP{"mycustomv4.host": net.ParseIP("1.2.3.4")}, 4, ""}, - {"mycustomv6.host", map[string]net.IP{"mycustomv6.host": net.ParseIP("::1")}, 6, ""}, + // IPv4 + {"1.2.3.4:80", "1.2.3.4:80", ""}, + {"example.com:443", "93.184.216.34:443", ""}, + {"mycustomv4.host:443", "1.2.3.4:443", ""}, + {"1.2.3.4", "", "address 1.2.3.4: missing port in address"}, + {"256.1.1.1:80", "", "mock lookup 256.1.1.1: no such host"}, + {"blockedv4.host:443", "", "IP (10.0.0.10) is in a blacklisted range (10.0.0.0/8)"}, + + // IPv6 + {"::1", "", "address ::1: too many colons in address"}, + {"[::1.2.3.4]", "", "address [::1.2.3.4]: missing port in address"}, + {"[::1.2.3.4]:443", "[::102:304]:443", ""}, + {"[abcd:ef01:2345:6789]:443", "", "mock lookup abcd:ef01:2345:6789: no such host"}, + {"[2001:db8:aaaa:1::100]:443", "[2001:db8:aaaa:1::100]:443", ""}, + {"ipv6.google.com:443", "[2a00:1450:4007:812::200e]:443", ""}, + {"blockedv6.host:443", "", "IP (2600::1) is in a blacklisted range (2600::/64)"}, + } + + block4, err := lib.ParseCIDR("10.0.0.0/8") + require.NoError(t, err) + block6, err := lib.ParseCIDR("2600::/64") + require.NoError(t, err) + + dialer := &Dialer{ + Blacklist: []*lib.IPNet{block4, block6}, + Hosts: map[string]net.IP{ + "mycustomv4.host": net.ParseIP("1.2.3.4"), + "mycustomv6.host": net.ParseIP("::1"), + }, } + resolver := &mockResolver{hosts: map[string]net.IP{ + "example.com": net.ParseIP("93.184.216.34"), + "ipv6.google.com": net.ParseIP("2a00:1450:4007:812::200e"), + "blockedv4.host": net.ParseIP("10.0.0.10"), + "blockedv6.host": net.ParseIP("2600::1"), + }} - resolver := dnscache.New(0) for _, tc := range testCases { tc := tc - t.Run(tc.host, func(t *testing.T) { - ip, err := resolveHost(tc.host, tc.hosts, resolver) + t.Run(tc.address, func(t *testing.T) { + address, err := dialer.checkAndResolveAddress(tc.address, resolver) if tc.expErr != "" { assert.EqualError(t, err, tc.expErr) return } require.NoError(t, err) - assert.Equal(t, tc.ipVer == 6, ip.To4() == nil) + assert.Equal(t, tc.expAddress, address) }) } }