diff --git a/lib/netext/dialer.go b/lib/netext/dialer.go index ecf65c8fef29..a1a686ff86d0 100644 --- a/lib/netext/dialer.go +++ b/lib/netext/dialer.go @@ -65,27 +65,35 @@ func (b BlackListedIPError) Error() string { return fmt.Sprintf("IP (%s) is in a blacklisted range (%s)", b.ip, b.net) } -// DialContext wraps the net.Dialer.DialContext and handles the k6 specifics -func (d *Dialer) DialContext(ctx context.Context, proto, addr string) (net.Conn, error) { - host, port, err := net.SplitHostPort(addr) - if err != nil { - return nil, err - } - +func resolveHost(host string, hosts map[string]net.IP, resolver *dnscache.Resolver) (net.IP, error) { ip := net.ParseIP(host) if ip == nil { // It's not an IP address, so lookup the hostname in the Hosts // option before trying to resolve DNS. var ok bool - ip, ok = d.Hosts[host] + ip, ok = hosts[host] if !ok { var dnsErr error - ip, dnsErr = d.Resolver.FetchOne(host) + ip, dnsErr = resolver.FetchOne(host) if dnsErr != nil { return nil, dnsErr } } } + return ip, nil +} + +// DialContext wraps the net.Dialer.DialContext and handles the k6 specifics +func (d *Dialer) DialContext(ctx context.Context, proto, addr string) (net.Conn, error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + + ip, resErr := resolveHost(host, d.Hosts, d.Resolver) + if resErr != nil { + return nil, resErr + } for _, ipnet := range d.Blacklist { if (*net.IPNet)(ipnet).Contains(ip) { diff --git a/lib/netext/dialer_test.go b/lib/netext/dialer_test.go new file mode 100644 index 000000000000..8a80c86759ad --- /dev/null +++ b/lib/netext/dialer_test.go @@ -0,0 +1,46 @@ +package netext + +import ( + "net" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/viki-org/dnscache" +) + +func TestDialerResolveHost(t *testing.T) { + t.Parallel() + + testCases := []struct { + host string + hosts map[string]net.IP + ipVer int + expErr string + }{ + {"1.2.3.4", nil, 4, ""}, + {"256.1.1.1", nil, 4, "lookup 256.1.1.1: no such host"}, + {"example.com", nil, 4, ""}, + {"::1", nil, 6, ""}, + {"::1.2.3.4", nil, 6, ""}, + {"abcd:ef01:2345:6789", nil, 6, "lookup abcd:ef01:2345:6789: no such host"}, + {"2001:db8:aaaa:1::100", nil, 6, ""}, + {"ipv6.google.com", nil, 6, ""}, + {"mycustomv4.host", map[string]net.IP{"mycustomv4.host": net.ParseIP("1.2.3.4")}, 4, ""}, + {"mycustomv6.host", map[string]net.IP{"mycustomv6.host": net.ParseIP("::1")}, 6, ""}, + } + + resolver := dnscache.New(0) + for _, tc := range testCases { + tc := tc + t.Run(tc.host, func(t *testing.T) { + ip, err := resolveHost(tc.host, tc.hosts, resolver) + if tc.expErr != "" { + assert.EqualError(t, err, tc.expErr) + return + } + require.NoError(t, err) + assert.Equal(t, tc.ipVer == 6, ip.To4() == nil) + }) + } +}