Skip to content

Commit

Permalink
Refactor Dialer test to mock resolver, increase coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
Ivan Mirić committed Jul 20, 2020
1 parent 437d91f commit c8ab370
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 44 deletions.
57 changes: 33 additions & 24 deletions lib/netext/dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,47 +65,56 @@ func (b BlackListedIPError) Error() string {
return fmt.Sprintf("IP (%s) is in a blacklisted range (%s)", b.ip, b.net)
}

func resolveHost(host string, hosts map[string]net.IP, resolver *dnscache.Resolver) (net.IP, error) {
// Resolver is implemented by dnscache.Resolver and used by tests to
// pass a mock resolver.
type Resolver interface {
FetchOne(string) (net.IP, error)
}

// DialContext wraps the net.Dialer.DialContext and handles the k6 specifics
func (d *Dialer) DialContext(ctx context.Context, proto, addr string) (net.Conn, error) {
address, err := d.checkAndResolveAddress(addr, d.Resolver)
if err != nil {
return nil, err
}

var conn net.Conn
conn, err = d.Dialer.DialContext(ctx, proto, address)
if err != nil {
return nil, err
}
conn = &Conn{conn, &d.BytesRead, &d.BytesWritten}
return conn, err
}

func (d *Dialer) checkAndResolveAddress(addr string, resolver Resolver) (string, error) {
host, port, err := net.SplitHostPort(addr)
if err != nil {
return "", err
}

ip := net.ParseIP(host)
if ip == nil {
// It's not an IP address, so lookup the hostname in the Hosts
// option before trying to resolve DNS.
var ok bool
ip, ok = hosts[host]
ip, ok = d.Hosts[host]
if !ok {
var dnsErr error
ip, dnsErr = resolver.FetchOne(host)
if dnsErr != nil {
return nil, dnsErr
return "", dnsErr
}
}
}
return ip, nil
}

// DialContext wraps the net.Dialer.DialContext and handles the k6 specifics
func (d *Dialer) DialContext(ctx context.Context, proto, addr string) (net.Conn, error) {
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}

ip, resErr := resolveHost(host, d.Hosts, d.Resolver)
if resErr != nil {
return nil, resErr
}

for _, ipnet := range d.Blacklist {
if (*net.IPNet)(ipnet).Contains(ip) {
return nil, BlackListedIPError{ip: ip, net: ipnet}
return "", BlackListedIPError{ip: ip, net: ipnet}
}
}
conn, err := d.Dialer.DialContext(ctx, proto, net.JoinHostPort(ip.String(), port))
if err != nil {
return nil, err
}
conn = &Conn{conn, &d.BytesRead, &d.BytesWritten}
return conn, err

return net.JoinHostPort(ip.String(), port), nil
}

// GetTrail creates a new NetTrail instance with the Dialer
Expand Down
98 changes: 78 additions & 20 deletions lib/netext/dialer_test.go
Original file line number Diff line number Diff line change
@@ -1,46 +1,104 @@
/*
*
* k6 - a next-generation load testing tool
* Copyright (C) 2020 Load Impact
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*
*/

package netext

import (
"fmt"
"net"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/viki-org/dnscache"

"github.com/loadimpact/k6/lib"
)

func TestDialerResolveHost(t *testing.T) {
type mockResolver struct {
hosts map[string]net.IP
}

func (r *mockResolver) FetchOne(h string) (net.IP, error) {
var (
ip net.IP
ok bool
)
if ip, ok = r.hosts[h]; !ok {
return nil, fmt.Errorf("mock lookup %s: no such host", h)
}
return ip, nil
}

func TestDialerCheckAndResolveAddress(t *testing.T) {
t.Parallel()

testCases := []struct {
host string
hosts map[string]net.IP
ipVer int
expErr string
address, expAddress, expErr string
}{
{"1.2.3.4", nil, 4, ""},
{"256.1.1.1", nil, 4, "lookup 256.1.1.1: no such host"},
{"example.com", nil, 4, ""},
{"::1", nil, 6, ""},
{"::1.2.3.4", nil, 6, ""},
{"abcd:ef01:2345:6789", nil, 6, "lookup abcd:ef01:2345:6789: no such host"},
{"2001:db8:aaaa:1::100", nil, 6, ""},
{"ipv6.google.com", nil, 6, ""},
{"mycustomv4.host", map[string]net.IP{"mycustomv4.host": net.ParseIP("1.2.3.4")}, 4, ""},
{"mycustomv6.host", map[string]net.IP{"mycustomv6.host": net.ParseIP("::1")}, 6, ""},
// IPv4
{"1.2.3.4:80", "1.2.3.4:80", ""},
{"example.com:443", "93.184.216.34:443", ""},
{"mycustomv4.host:443", "1.2.3.4:443", ""},
{"1.2.3.4", "", "address 1.2.3.4: missing port in address"},
{"256.1.1.1:80", "", "mock lookup 256.1.1.1: no such host"},
{"blockedv4.host:443", "", "IP (10.0.0.10) is in a blacklisted range (10.0.0.0/8)"},

// IPv6
{"::1", "", "address ::1: too many colons in address"},
{"[::1.2.3.4]", "", "address [::1.2.3.4]: missing port in address"},
{"[::1.2.3.4]:443", "[::102:304]:443", ""},
{"[abcd:ef01:2345:6789]:443", "", "mock lookup abcd:ef01:2345:6789: no such host"},
{"[2001:db8:aaaa:1::100]:443", "[2001:db8:aaaa:1::100]:443", ""},
{"ipv6.google.com:443", "[2a00:1450:4007:812::200e]:443", ""},
{"blockedv6.host:443", "", "IP (2600::1) is in a blacklisted range (2600::/64)"},
}

block4, err := lib.ParseCIDR("10.0.0.0/8")
require.NoError(t, err)
block6, err := lib.ParseCIDR("2600::/64")
require.NoError(t, err)

dialer := &Dialer{
Blacklist: []*lib.IPNet{block4, block6},
Hosts: map[string]net.IP{
"mycustomv4.host": net.ParseIP("1.2.3.4"),
"mycustomv6.host": net.ParseIP("::1"),
},
}
resolver := &mockResolver{hosts: map[string]net.IP{
"example.com": net.ParseIP("93.184.216.34"),
"ipv6.google.com": net.ParseIP("2a00:1450:4007:812::200e"),
"blockedv4.host": net.ParseIP("10.0.0.10"),
"blockedv6.host": net.ParseIP("2600::1"),
}}

resolver := dnscache.New(0)
for _, tc := range testCases {
tc := tc
t.Run(tc.host, func(t *testing.T) {
ip, err := resolveHost(tc.host, tc.hosts, resolver)
t.Run(tc.address, func(t *testing.T) {
address, err := dialer.checkAndResolveAddress(tc.address, resolver)
if tc.expErr != "" {
assert.EqualError(t, err, tc.expErr)
return
}
require.NoError(t, err)
assert.Equal(t, tc.ipVer == 6, ip.To4() == nil)
assert.Equal(t, tc.expAddress, address)
})
}
}

0 comments on commit c8ab370

Please sign in to comment.