diff --git a/internal/agherr/agherr.go b/internal/agherr/agherr.go index aedf2a8ba06..4cef546925f 100644 --- a/internal/agherr/agherr.go +++ b/internal/agherr/agherr.go @@ -1,5 +1,4 @@ -// Package agherr contains the extended error type, and the function for -// wrapping several errors. +// Package agherr contains AdGuard Home's error handling helpers. package agherr import ( @@ -23,8 +22,8 @@ type manyError struct { } // Many wraps several errors and returns a single error. -func Many(message string, underlying ...error) error { - err := &manyError{ +func Many(message string, underlying ...error) (err error) { + err = &manyError{ message: message, underlying: underlying, } @@ -33,7 +32,7 @@ func Many(message string, underlying ...error) error { } // Error implements the error interface for *manyError. -func (e *manyError) Error() string { +func (e *manyError) Error() (msg string) { switch len(e.underlying) { case 0: return e.message @@ -58,7 +57,7 @@ func (e *manyError) Error() string { } // Unwrap implements the hidden errors.wrapper interface for *manyError. -func (e *manyError) Unwrap() error { +func (e *manyError) Unwrap() (err error) { if len(e.underlying) == 0 { return nil } @@ -71,3 +70,37 @@ func (e *manyError) Unwrap() error { type wrapper interface { Unwrap() error } + +// Annotate annotates the error with the message, unless the error is nil. This +// is a helper function to simplify code like this: +// +// func (f *foo) doStuff(s string) (err error) { +// defer func() { +// if err != nil { +// err = fmt.Errorf("bad foo string %q: %w", s, err) +// } +// }() +// +// // … +// } +// +// Instead, write: +// +// func (f *foo) doStuff(s string) (err error) { +// defer agherr.Annotate("bad foo string %q: %w", &err, s) +// +// // … +// } +// +func Annotate(msg string, errPtr *error, args ...interface{}) { + if errPtr == nil { + return + } + + err := *errPtr + if err != nil { + args = append(args, err) + + *errPtr = fmt.Errorf(msg, args...) + } +} diff --git a/internal/agherr/agherr_test.go b/internal/agherr/agherr_test.go index 3ac5aeab426..b9a1ff900c0 100644 --- a/internal/agherr/agherr_test.go +++ b/internal/agherr/agherr_test.go @@ -6,30 +6,32 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestError_Error(t *testing.T) { testCases := []struct { + err error name string want string - err error }{{ + err: Many("a"), name: "simple", want: "a", - err: Many("a"), }, { + err: Many("a", errors.New("b")), name: "wrapping", want: "a: b", - err: Many("a", errors.New("b")), }, { + err: Many("a", errors.New("b"), errors.New("c"), errors.New("d")), name: "wrapping several", want: "a: b (hidden: c, d)", - err: Many("a", errors.New("b"), errors.New("c"), errors.New("d")), }, { + err: Many("a", Many("b", errors.New("c"), errors.New("d"))), name: "wrapping wrapper", want: "a: b: c (hidden: d)", - err: Many("a", Many("b", errors.New("c"), errors.New("d"))), }} + for _, tc := range testCases { assert.Equal(t, tc.want, tc.err.Error(), tc.name) } @@ -43,33 +45,78 @@ func TestError_Unwrap(t *testing.T) { errWrapped errNil ) + errs := []error{ errSimple: errors.New("a"), errWrapped: fmt.Errorf("err: %w", errors.New("nested")), errNil: nil, } + testCases := []struct { - name string want error wrapped error + name string }{{ - name: "simple", want: errs[errSimple], wrapped: Many("a", errs[errSimple]), + name: "simple", }, { - name: "nested", want: errs[errWrapped], wrapped: Many("b", errs[errWrapped]), + name: "nested", }, { - name: "nil passed", want: errs[errNil], wrapped: Many("c", errs[errNil]), + name: "nil passed", }, { - name: "nil not passed", want: nil, wrapped: Many("d"), + name: "nil not passed", }} + for _, tc := range testCases { assert.Equal(t, tc.want, errors.Unwrap(tc.wrapped), tc.name) } } + +func TestAnnotate(t *testing.T) { + const s = "1234" + const wantMsg = `bad string "1234": test` + + // Don't use const, because we can't take a pointer of a contant. + var errTest error = Error("test") + + t.Run("nil", func(t *testing.T) { + var errPtr *error + assert.NotPanics(t, func() { + Annotate("bad string %q: %w", errPtr, s) + }) + }) + + t.Run("non_nil", func(t *testing.T) { + errPtr := &errTest + assert.NotPanics(t, func() { + Annotate("bad string %q: %w", errPtr, s) + }) + + require.NotNil(t, errPtr) + + err := *errPtr + require.NotNil(t, err) + + assert.Equal(t, wantMsg, err.Error()) + }) + + t.Run("defer", func(t *testing.T) { + f := func() (err error) { + defer Annotate("bad string %q: %w", &errTest, s) + + return errTest + } + + err := f() + require.NotNil(t, err) + + assert.Equal(t, wantMsg, err.Error()) + }) +} diff --git a/internal/dhcpd/iprange.go b/internal/dhcpd/iprange.go index fa70c8f6ab8..50242255dff 100644 --- a/internal/dhcpd/iprange.go +++ b/internal/dhcpd/iprange.go @@ -5,6 +5,8 @@ import ( "math" "math/big" "net" + + "github.com/AdguardTeam/AdGuardHome/internal/agherr" ) // ipRange is an inclusive range of IP addresses. @@ -18,17 +20,14 @@ type ipRange struct { end *big.Int } -// maxRangeLen is the maximum IP range length. +// maxRangeLen is the maximum IP range length. The bitsets used in servers only +// accept uints, which can have the size of 32 bit. const maxRangeLen = math.MaxUint32 // newIPRange creates a new IP address range. start must be less than end. The // resulting range must not be greater than maxRangeLen. func newIPRange(start, end net.IP) (r *ipRange, err error) { - defer func() { - if err != nil { - err = fmt.Errorf("invalid ip range: %w", err) - } - }() + defer agherr.Annotate("invalid ip range: %w", &err) // Make sure that both are 16 bytes long to simplify handling in // methods. @@ -85,7 +84,7 @@ func (r *ipRange) find(p ipPredicate) (ip net.IP) { // offset returns the offset of ip from the beginning of r. It returns 0 and // false if ip is not in r. -func (r *ipRange) offset(ip net.IP) (offset uint64, ok bool) { +func (r *ipRange) offset(ip net.IP) (offset uint, ok bool) { ip = ip.To16() ipInt := (&big.Int{}).SetBytes(ip) if !r.containsInt(ipInt) { @@ -96,5 +95,5 @@ func (r *ipRange) offset(ip net.IP) (offset uint64, ok bool) { // Assume that the range was checked against maxRangeLen during // construction. - return offsetInt.Uint64(), true + return uint(offsetInt.Uint64()), true } diff --git a/internal/dhcpd/iprange_test.go b/internal/dhcpd/iprange_test.go index 2a95e691836..01991532a39 100644 --- a/internal/dhcpd/iprange_test.go +++ b/internal/dhcpd/iprange_test.go @@ -115,7 +115,7 @@ func TestIPRange_Offset(t *testing.T) { testCases := []struct { name string in net.IP - wantOffset uint64 + wantOffset uint wantOK bool }{{ name: "in", diff --git a/internal/dhcpd/options.go b/internal/dhcpd/options.go index 9992764efc0..780eeeab132 100644 --- a/internal/dhcpd/options.go +++ b/internal/dhcpd/options.go @@ -100,11 +100,7 @@ func newDHCPOptionParser() (p *dhcpOptionParser) { // parse parses an option. See the handlers' documentation for more info. func (p *dhcpOptionParser) parse(s string) (code uint8, data []byte, err error) { - defer func() { - if err != nil { - err = fmt.Errorf("invalid option string %q: %w", s, err) - } - }() + defer agherr.Annotate("invalid option string %q: %w", &err, s) s = strings.TrimSpace(s) parts := strings.SplitN(s, " ", 3) diff --git a/internal/dhcpd/v4.go b/internal/dhcpd/v4.go index c243157a061..a4d7c5c82e9 100644 --- a/internal/dhcpd/v4.go +++ b/internal/dhcpd/v4.go @@ -20,15 +20,18 @@ import ( // // TODO(a.garipov): Think about unifying this and v6Server. type v4Server struct { - srv *server4.Server - leasesLock sync.Mutex - leases []*Lease + conf V4ServerConf + srv *server4.Server // leasedOffsets contains offsets from conf.ipRange.start that have been // leased. leasedOffsets *bitset.BitSet - conf V4ServerConf + // leases contains all dynamic and static leases. + leases []*Lease + + // leasesLock protects leases and leasedOffsets. + leasesLock sync.Mutex } // WriteDiskConfig4 - write configuration @@ -116,11 +119,16 @@ func (s *v4Server) blacklistLease(lease *Lease) { lease.Expiry = time.Now().Add(s.conf.leaseTime) } -// removeLease removes a lease by its offset from the beginning of the IP range. -func (s *v4Server) removeLease(offset uint) { - l := s.leases[offset] - s.leases = append(s.leases[:offset], s.leases[offset+1:]...) - s.leasedOffsets.Clear(offset) +// rmLeaseByIndex removes a lease by its index in the leases slice. +func (s *v4Server) rmLeaseByIndex(i int) { + l := s.leases[i] + s.leases = append(s.leases[:i], s.leases[i+1:]...) + + r := s.conf.ipRange + offset, ok := r.offset(l.IP) + if ok { + s.leasedOffsets.Clear(offset) + } log.Debug("dhcpv4: removed lease %s (%s)", l.IP, l.HWAddr) } @@ -136,7 +144,7 @@ func (s *v4Server) rmDynamicLease(lease Lease) error { return fmt.Errorf("static lease already exists") } - s.removeLease(uint(i)) + s.rmLeaseByIndex(i) if i == len(s.leases) { break } @@ -149,7 +157,7 @@ func (s *v4Server) rmDynamicLease(lease Lease) error { return fmt.Errorf("static lease already exists") } - s.removeLease(uint(i)) + s.rmLeaseByIndex(i) } } return nil @@ -175,14 +183,13 @@ func (s *v4Server) addLease(l *Lease) { // Remove a lease with the same properties func (s *v4Server) rmLease(lease Lease) error { for i, l := range s.leases { - if net.IP.Equal(l.IP, lease.IP) { - + if l.IP.Equal(lease.IP) { if !bytes.Equal(l.HWAddr, lease.HWAddr) || l.Hostname != lease.Hostname { return fmt.Errorf("lease not found") } - s.removeLease(uint(i)) + s.rmLeaseByIndex(i) return nil } @@ -255,7 +262,7 @@ func (s *v4Server) addrAvailable(target net.IP) bool { pinger.Timeout = time.Duration(s.conf.ICMPTimeout) * time.Millisecond pinger.Count = 1 reply := false - pinger.OnRecv = func(pkt *ping.Packet) { + pinger.OnRecv = func(_ *ping.Packet) { reply = true } log.Debug("dhcpv4: Sending ICMP Echo to %v", target) diff --git a/staticcheck.conf b/staticcheck.conf index 4dd931764d5..b997f6a9477 100644 --- a/staticcheck.conf +++ b/staticcheck.conf @@ -1,6 +1,8 @@ checks = ["all"] initialisms = [ # See https://github.com/dominikh/go-tools/blob/master/config/config.go. + # + # Do not add "PTR" since we use "Ptr" as a suffix. "inherit" , "DHCP" , "DOH" @@ -8,7 +10,6 @@ initialisms = [ , "DOT" , "EDNS" , "MX" -, "PTR" , "QUIC" , "RA" , "SDNS"