Skip to content

Commit

Permalink
Pull request: 2846 cover aghnet vol.2
Browse files Browse the repository at this point in the history
Merge in DNS/adguard-home from 2846-cover-aghnet-vol.2 to master

Updates AdguardTeam#2846.
Closes AdguardTeam#4408.

Squashed commit of the following:

commit 8d62b29
Author: Eugene Burkov <[email protected]>
Date:   Wed Mar 23 20:42:04 2022 +0300

    home: recover panic

commit 1d98109
Merge: ac11d75 9ce2a0f
Author: Eugene Burkov <[email protected]>
Date:   Wed Mar 23 20:32:05 2022 +0300

    Merge branch 'master' into 2846-cover-aghnet-vol.2

commit ac11d75
Author: Eugene Burkov <[email protected]>
Date:   Wed Mar 23 20:29:41 2022 +0300

    aghnet: use iotest

commit 7c923df
Author: Eugene Burkov <[email protected]>
Date:   Wed Mar 23 20:17:19 2022 +0300

    aghnet: cover more

commit 3bfd4d5
Author: Eugene Burkov <[email protected]>
Date:   Wed Mar 23 14:13:59 2022 +0300

    aghnet: cover arpdb more

commit cd5cf7b
Author: Eugene Burkov <[email protected]>
Date:   Wed Mar 23 13:05:35 2022 +0300

    all: rm arpdb initial refresh

commit 0fb8d9e
Author: Eugene Burkov <[email protected]>
Date:   Tue Mar 22 21:13:16 2022 +0300

    aghnet: cover arpdb
  • Loading branch information
EugeneOne1 authored and heyxkhoa committed Mar 17, 2023
1 parent ada429e commit f7dfa16
Show file tree
Hide file tree
Showing 17 changed files with 143 additions and 88 deletions.
11 changes: 2 additions & 9 deletions internal/aghnet/arpdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,8 @@ type ARPDB interface {
}

// NewARPDB returns the ARPDB properly initialized for the OS.
func NewARPDB() (arp ARPDB, err error) {
arp = newARPDB()

err = arp.Refresh()
if err != nil {
return nil, fmt.Errorf("arpdb initial refresh: %w", err)
}

return arp, nil
func NewARPDB() (arp ARPDB) {
return newARPDB()
}

// Empty ARPDB implementation
Expand Down
8 changes: 8 additions & 0 deletions internal/aghnet/arpdb_bsd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@ import (
)

const arpAOutput = `
invalid.mac (1.2.3.4) at 12:34:56:78:910 on el0 ifscope [ethernet]
invalid.ip (1.2.3.4.5) at ab:cd:ef:ab:cd:12 on ek0 ifscope [ethernet]
invalid.fmt 1 at 12:cd:ef:ab:cd:ef on er0 ifscope [ethernet]
hostname.one (192.168.1.2) at ab:cd:ef:ab:cd:ef on en0 ifscope [ethernet]
hostname.two (::ffff:ffff) at ef:cd:ab:ef:cd:ab on em0 expires in 1198 seconds [ethernet]
? (::1234) at aa:bb:cc:dd:ee:ff on ej0 expires in 1918 seconds [ethernet]
`

var wantNeighs = []Neighbor{{
Expand All @@ -20,4 +24,8 @@ var wantNeighs = []Neighbor{{
Name: "hostname.two",
IP: net.ParseIP("::ffff:ffff"),
MAC: net.HardwareAddr{0xEF, 0xCD, 0xAB, 0xEF, 0xCD, 0xAB},
}, {
Name: "",
IP: net.ParseIP("::1234"),
MAC: net.HardwareAddr{0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF},
}}
7 changes: 7 additions & 0 deletions internal/aghnet/arpdb_linux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,21 @@ import (

const arpAOutputWrt = `
IP address HW type Flags HW address Mask Device
1.2.3.4.5 0x1 0x2 aa:bb:cc:dd:ee:ff * wan
1.2.3.4 0x1 0x2 12:34:56:78:910 * wan
192.168.1.2 0x1 0x2 ab:cd:ef:ab:cd:ef * wan
::ffff:ffff 0x1 0x2 ef:cd:ab:ef:cd:ab * wan`

const arpAOutput = `
invalid.mac (1.2.3.4) at 12:34:56:78:910 on el0 ifscope [ethernet]
invalid.ip (1.2.3.4.5) at ab:cd:ef:ab:cd:12 on ek0 ifscope [ethernet]
invalid.fmt 1 at 12:cd:ef:ab:cd:ef on er0 ifscope [ethernet]
? (192.168.1.2) at ab:cd:ef:ab:cd:ef on en0 ifscope [ethernet]
? (::ffff:ffff) at ef:cd:ab:ef:cd:ab on em0 expires in 100 seconds [ethernet]`

const ipNeighOutput = `
1.2.3.4.5 dev enp0s3 lladdr aa:bb:cc:dd:ee:ff DELAY
1.2.3.4 dev enp0s3 lladdr 12:34:56:78:910 DELAY
192.168.1.2 dev enp0s3 lladdr ab:cd:ef:ab:cd:ef DELAY
::ffff:ffff dev enp0s3 lladdr ef:cd:ab:ef:cd:ab router STALE`

Expand Down
2 changes: 2 additions & 0 deletions internal/aghnet/arpdb_openbsd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (

const arpAOutput = `
Host Ethernet Address Netif Expire Flags
1.2.3.4.5 aa:bb:cc:dd:ee:ff em0 permanent
1.2.3.4 12:34:56:78:910 em0 permanent
192.168.1.2 ab:cd:ef:ab:cd:ef em0 19m56s
::ffff:ffff ef:cd:ab:ef:cd:ab em0 permanent l
`
Expand Down
53 changes: 53 additions & 0 deletions internal/aghnet/arpdb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,23 @@ import (
"strings"
"sync"
"testing"
"testing/iotest"

"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestNewARPDB(t *testing.T) {
var a ARPDB
require.NotPanics(t, func() {
a = NewARPDB()
})

assert.NotNil(t, a)
}

// TestARPDB is the mock implementation of ARPDB to use in tests.
type TestARPDB struct {
OnRefresh func() (err error)
Expand Down Expand Up @@ -166,3 +176,46 @@ func TestCmdARPDB_arpa(t *testing.T) {
testutil.AssertErrorMsg(t, "cmd arpdb: running command: can't run", err)
})
}

func TestCmdARPDB_errors(t *testing.T) {
const errRead errors.Error = "can't read"

badReaderRunCmd := runCmdFunc(func() (r io.Reader, err error) {
return iotest.ErrReader(errRead), nil
})

a := &cmdARPDB{
runcmd: badReaderRunCmd,
parse: parseArpA,
ns: &neighs{
mu: &sync.RWMutex{},
ns: make([]Neighbor, 0),
},
}

const wantErrMsg string = "cmd arpdb: scanning the output: " + string(errRead)

testutil.AssertErrorMsg(t, wantErrMsg, a.Refresh())
}

func TestEmptyARPDB(t *testing.T) {
a := EmptyARPDB{}

t.Run("refresh", func(t *testing.T) {
var err error
require.NotPanics(t, func() {
err = a.Refresh()
})

assert.NoError(t, err)
})

t.Run("neighbors", func(t *testing.T) {
var ns []Neighbor
require.NotPanics(t, func() {
ns = a.Neighbors()
})

assert.Empty(t, ns)
})
}
4 changes: 2 additions & 2 deletions internal/aghnet/hostscontainer.go
Original file line number Diff line number Diff line change
Expand Up @@ -368,8 +368,8 @@ func (hp *hostsParser) addPairs(ip net.IP, hosts []string) {
}
}

// writeRules writes the actual rule for the qtype and the PTR for the
// host-ip pair into internal builders.
// writeRules writes the actual rule for the qtype and the PTR for the host-ip
// pair into internal builders.
func (hp *hostsParser) writeRules(host string, ip net.IP) (rule, rulePtr string) {
arpa, err := netutil.IPToReversedAddr(ip)
if err != nil {
Expand Down
3 changes: 1 addition & 2 deletions internal/aghnet/net.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,7 @@ func GetValidNetInterfacesForWeb() ([]*NetInterface, error) {
ifaces, err := net.Interfaces()
if err != nil {
return nil, fmt.Errorf("couldn't get interfaces: %w", err)
}
if len(ifaces) == 0 {
} else if len(ifaces) == 0 {
return nil, errors.Error("couldn't find any legible interface")
}

Expand Down
25 changes: 25 additions & 0 deletions internal/aghnet/net_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"testing"

"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -130,3 +131,27 @@ func TestCheckPort(t *testing.T) {
assert.NoError(t, err)
})
}

func TestCollectAllIfacesAddrs(t *testing.T) {
addrs, err := CollectAllIfacesAddrs()
require.NoError(t, err)

assert.NotEmpty(t, addrs)
}

func TestIsAddrInUse(t *testing.T) {
t.Run("addr_in_use", func(t *testing.T) {
l, err := net.Listen("tcp", "0.0.0.0:0")
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, l.Close)

_, err = net.Listen(l.Addr().Network(), l.Addr().String())
assert.True(t, IsAddrInUse(err))
})

t.Run("another", func(t *testing.T) {
const anotherErr errors.Error = "not addr in use"

assert.False(t, IsAddrInUse(anotherErr))
})
}
40 changes: 5 additions & 35 deletions internal/aghnet/systemresolvers.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
package aghnet

import (
"time"

"github.com/AdguardTeam/golibs/log"
)

// DefaultRefreshIvl is the default period of time between refreshing cached
// addresses.
// const DefaultRefreshIvl = 5 * time.Minute
Expand All @@ -16,51 +10,27 @@ type HostGenFunc func() (host string)

// SystemResolvers helps to work with local resolvers' addresses provided by OS.
type SystemResolvers interface {
// Get returns the slice of local resolvers' addresses. It should be
// safe for concurrent use.
// Get returns the slice of local resolvers' addresses. It must be safe for
// concurrent use.
Get() (rs []string)
// refresh refreshes the local resolvers' addresses cache. It should be
// safe for concurrent use.
// refresh refreshes the local resolvers' addresses cache. It must be safe
// for concurrent use.
refresh() (err error)
}

// refreshWithTicker refreshes the cache of sr after each tick form tickCh.
func refreshWithTicker(sr SystemResolvers, tickCh <-chan time.Time) {
defer log.OnPanic("systemResolvers")

// TODO(e.burkov): Implement a functionality to stop ticker.
for range tickCh {
err := sr.refresh()
if err != nil {
log.Error("systemResolvers: error in refreshing goroutine: %s", err)

continue
}

log.Debug("systemResolvers: local addresses cache is refreshed")
}
}

// NewSystemResolvers returns a SystemResolvers with the cache refresh rate
// defined by refreshIvl. It disables auto-resfreshing if refreshIvl is 0. If
// nil is passed for hostGenFunc, the default generator will be used.
func NewSystemResolvers(
refreshIvl time.Duration,
hostGenFunc HostGenFunc,
) (sr SystemResolvers, err error) {
sr = newSystemResolvers(refreshIvl, hostGenFunc)
sr = newSystemResolvers(hostGenFunc)

// Fill cache.
err = sr.refresh()
if err != nil {
return nil, err
}

if refreshIvl > 0 {
ticker := time.NewTicker(refreshIvl)

go refreshWithTicker(sr, ticker.C)
}

return sr, nil
}
23 changes: 13 additions & 10 deletions internal/aghnet/systemresolvers_others.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,15 @@ func defaultHostGen() (host string) {

// systemResolvers is a default implementation of SystemResolvers interface.
type systemResolvers struct {
resolver *net.Resolver
hostGenFunc HostGenFunc

// addrs is the set that contains cached local resolvers' addresses.
addrs *stringutil.Set
// addrsLock protects addrs.
addrsLock sync.RWMutex
// addrs is the set that contains cached local resolvers' addresses.
addrs *stringutil.Set

// resolver is used to fetch the resolvers' addresses.
resolver *net.Resolver
// hostGenFunc generates hosts to resolve.
hostGenFunc HostGenFunc
}

const (
Expand All @@ -44,6 +47,7 @@ const (
errUnexpectedHostFormat errors.Error = "unexpected host format"
)

// refresh implements the SystemResolvers interface for *systemResolvers.
func (sr *systemResolvers) refresh() (err error) {
defer func() { err = errors.Annotate(err, "systemResolvers: %w") }()

Expand All @@ -56,7 +60,7 @@ func (sr *systemResolvers) refresh() (err error) {
return err
}

func newSystemResolvers(refreshIvl time.Duration, hostGenFunc HostGenFunc) (sr SystemResolvers) {
func newSystemResolvers(hostGenFunc HostGenFunc) (sr SystemResolvers) {
if hostGenFunc == nil {
hostGenFunc = defaultHostGen
}
Expand All @@ -76,19 +80,18 @@ func newSystemResolvers(refreshIvl time.Duration, hostGenFunc HostGenFunc) (sr S
func validateDialedHost(host string) (err error) {
defer func() { err = errors.Annotate(err, "parsing %q: %w", host) }()

var ipStr string
parts := strings.Split(host, "%")
switch len(parts) {
case 1:
ipStr = host
// host
case 2:
// Remove the zone and check the IP address part.
ipStr = parts[0]
host = parts[0]
default:
return errUnexpectedHostFormat
}

if net.ParseIP(ipStr) == nil {
if _, err = netutil.ParseIP(host); err != nil {
return errBadAddrPassed
}

Expand Down
21 changes: 8 additions & 13 deletions internal/aghnet/systemresolvers_others_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,53 +6,48 @@ package aghnet
import (
"context"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func createTestSystemResolversImp(
func createTestSystemResolversImpl(
t *testing.T,
refreshDur time.Duration,
hostGenFunc HostGenFunc,
) (imp *systemResolvers) {
t.Helper()

sr := createTestSystemResolvers(t, refreshDur, hostGenFunc)
sr := createTestSystemResolvers(t, hostGenFunc)
require.IsType(t, (*systemResolvers)(nil), sr)

var ok bool
imp, ok = sr.(*systemResolvers)
require.True(t, ok)

return imp
return sr.(*systemResolvers)
}

func TestSystemResolvers_Refresh(t *testing.T) {
t.Run("expected_error", func(t *testing.T) {
sr := createTestSystemResolvers(t, 0, nil)
sr := createTestSystemResolvers(t, nil)

assert.NoError(t, sr.refresh())
})

t.Run("unexpected_error", func(t *testing.T) {
_, err := NewSystemResolvers(0, func() string {
_, err := NewSystemResolvers(func() string {
return "127.0.0.1::123"
})
assert.Error(t, err)
})
}

func TestSystemResolvers_DialFunc(t *testing.T) {
imp := createTestSystemResolversImp(t, 0, nil)
imp := createTestSystemResolversImpl(t, nil)

testCases := []struct {
want error
name string
address string
}{{
want: errFakeDial,
name: "valid",
name: "valid_ipv4",
address: "127.0.0.1",
}, {
want: errFakeDial,
Expand Down
Loading

0 comments on commit f7dfa16

Please sign in to comment.