diff --git a/conn/bind_linux.go b/conn/bind_linux.go index 7b970e65a..975a6ab60 100644 --- a/conn/bind_linux.go +++ b/conn/bind_linux.go @@ -14,6 +14,7 @@ import ( "unsafe" "golang.org/x/sys/unix" + "golang.zx2c4.com/go118/netip" ) type ipv4Source struct { @@ -70,32 +71,30 @@ var _ Bind = (*LinuxSocketBind)(nil) func (*LinuxSocketBind) ParseEndpoint(s string) (Endpoint, error) { var end LinuxSocketEndpoint - addr, err := parseEndpoint(s) + e, err := netip.ParseAddrPort(s) if err != nil { return nil, err } - ipv4 := addr.IP.To4() - if ipv4 != nil { + if e.Addr().Is4() { dst := end.dst4() end.isV6 = false - dst.Port = addr.Port - copy(dst.Addr[:], ipv4) + dst.Port = int(e.Port()) + dst.Addr = e.Addr().As4() end.ClearSrc() return &end, nil } - ipv6 := addr.IP.To16() - if ipv6 != nil { - zone, err := zoneToUint32(addr.Zone) + if e.Addr().Is6() { + zone, err := zoneToUint32(e.Addr().Zone()) if err != nil { return nil, err } dst := end.dst6() end.isV6 = true - dst.Port = addr.Port + dst.Port = int(e.Port()) dst.ZoneId = zone - copy(dst.Addr[:], ipv6[:]) + dst.Addr = e.Addr().As16() end.ClearSrc() return &end, nil } @@ -266,29 +265,19 @@ func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error { } } -func (end *LinuxSocketEndpoint) SrcIP() net.IP { +func (end *LinuxSocketEndpoint) SrcIP() netip.Addr { if !end.isV6 { - return net.IPv4( - end.src4().Src[0], - end.src4().Src[1], - end.src4().Src[2], - end.src4().Src[3], - ) + return netip.AddrFrom4(end.src4().Src) } else { - return end.src6().src[:] + return netip.AddrFrom16(end.src6().src) } } -func (end *LinuxSocketEndpoint) DstIP() net.IP { +func (end *LinuxSocketEndpoint) DstIP() netip.Addr { if !end.isV6 { - return net.IPv4( - end.dst4().Addr[0], - end.dst4().Addr[1], - end.dst4().Addr[2], - end.dst4().Addr[3], - ) + return netip.AddrFrom4(end.src4().Src) } else { - return end.dst6().Addr[:] + return netip.AddrFrom16(end.dst6().Addr) } } @@ -305,14 +294,13 @@ func (end *LinuxSocketEndpoint) SrcToString() string { } func (end *LinuxSocketEndpoint) DstToString() string { - var udpAddr net.UDPAddr - udpAddr.IP = end.DstIP() + var port int if !end.isV6 { - udpAddr.Port = end.dst4().Port + port = end.dst4().Port } else { - udpAddr.Port = end.dst6().Port + port = end.dst6().Port } - return udpAddr.String() + return netip.AddrPortFrom(end.DstIP(), uint16(port)).String() } func (end *LinuxSocketEndpoint) ClearDst() { diff --git a/conn/bind_std.go b/conn/bind_std.go index cb85cfdaf..219c7194f 100644 --- a/conn/bind_std.go +++ b/conn/bind_std.go @@ -10,6 +10,8 @@ import ( "net" "sync" "syscall" + + "golang.zx2c4.com/go118/netip" ) // StdNetBind is meant to be a temporary solution on platforms for which @@ -32,18 +34,22 @@ var _ Bind = (*StdNetBind)(nil) var _ Endpoint = (*StdNetEndpoint)(nil) func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) { - addr, err := parseEndpoint(s) - return (*StdNetEndpoint)(addr), err + e, err := netip.ParseAddrPort(s) + return (*StdNetEndpoint)(&net.UDPAddr{ + IP: e.Addr().AsSlice(), + Port: int(e.Port()), + Zone: e.Addr().Zone(), + }), err } func (*StdNetEndpoint) ClearSrc() {} -func (e *StdNetEndpoint) DstIP() net.IP { - return (*net.UDPAddr)(e).IP +func (e *StdNetEndpoint) DstIP() netip.Addr { + return netip.AddrFromSlice((*net.UDPAddr)(e).IP) } -func (e *StdNetEndpoint) SrcIP() net.IP { - return nil // not supported +func (e *StdNetEndpoint) SrcIP() netip.Addr { + return netip.Addr{} // not supported } func (e *StdNetEndpoint) DstToBytes() []byte { diff --git a/conn/bind_windows.go b/conn/bind_windows.go index 42e06adbc..26a3af851 100644 --- a/conn/bind_windows.go +++ b/conn/bind_windows.go @@ -15,6 +15,7 @@ import ( "unsafe" "golang.org/x/sys/windows" + "golang.zx2c4.com/go118/netip" "golang.zx2c4.com/wireguard/conn/winrio" ) @@ -128,18 +129,18 @@ func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) { func (*WinRingEndpoint) ClearSrc() {} -func (e *WinRingEndpoint) DstIP() net.IP { +func (e *WinRingEndpoint) DstIP() netip.Addr { switch e.family { case windows.AF_INET: - return append([]byte{}, e.data[2:6]...) + return netip.AddrFrom4(*(*[4]byte)(e.data[2:6])) case windows.AF_INET6: - return append([]byte{}, e.data[6:22]...) + return netip.AddrFrom16(*(*[16]byte)(e.data[6:22])) } - return nil + return netip.Addr{} } -func (e *WinRingEndpoint) SrcIP() net.IP { - return nil // not supported +func (e *WinRingEndpoint) SrcIP() netip.Addr { + return netip.Addr{} // not supported } func (e *WinRingEndpoint) DstToBytes() []byte { @@ -161,15 +162,13 @@ func (e *WinRingEndpoint) DstToBytes() []byte { func (e *WinRingEndpoint) DstToString() string { switch e.family { case windows.AF_INET: - addr := net.UDPAddr{IP: e.data[2:6], Port: int(binary.BigEndian.Uint16(e.data[0:2]))} - return addr.String() + netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(e.data[2:6])), binary.BigEndian.Uint16(e.data[0:2])).String() case windows.AF_INET6: var zone string if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 { zone = strconv.FormatUint(uint64(scope), 10) } - addr := net.UDPAddr{IP: e.data[6:22], Zone: zone, Port: int(binary.BigEndian.Uint16(e.data[0:2]))} - return addr.String() + return netip.AddrPortFrom(netip.AddrFrom16(*(*[16]byte)(e.data[6:22])).WithZone(zone), binary.BigEndian.Uint16(e.data[0:2])).String() } return "" } diff --git a/conn/bindtest/bindtest.go b/conn/bindtest/bindtest.go index 7d43fb30b..6a4589613 100644 --- a/conn/bindtest/bindtest.go +++ b/conn/bindtest/bindtest.go @@ -10,8 +10,8 @@ import ( "math/rand" "net" "os" - "strconv" + "golang.zx2c4.com/go118/netip" "golang.zx2c4.com/wireguard/conn" ) @@ -61,9 +61,9 @@ func (c ChannelEndpoint) DstToString() string { return fmt.Sprintf("127.0.0.1:%d func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} } -func (c ChannelEndpoint) DstIP() net.IP { return net.IPv4(127, 0, 0, 1) } +func (c ChannelEndpoint) DstIP() netip.Addr { return netip.AddrFrom4([4]byte{127, 0, 0, 1}) } -func (c ChannelEndpoint) SrcIP() net.IP { return nil } +func (c ChannelEndpoint) SrcIP() netip.Addr { return netip.Addr{} } func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { c.closeSignal = make(chan bool) @@ -119,13 +119,9 @@ func (c *ChannelBind) Send(b []byte, ep conn.Endpoint) error { } func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) { - _, port, err := net.SplitHostPort(s) + addr, err := netip.ParseAddrPort(s) if err != nil { return nil, err } - i, err := strconv.ParseUint(port, 10, 16) - if err != nil { - return nil, err - } - return ChannelEndpoint(i), nil + return ChannelEndpoint(addr.Port()), nil } diff --git a/conn/conn.go b/conn/conn.go index 9cce9adea..35fb6b1b0 100644 --- a/conn/conn.go +++ b/conn/conn.go @@ -9,10 +9,11 @@ package conn import ( "errors" "fmt" - "net" "reflect" "runtime" "strings" + + "golang.zx2c4.com/go118/netip" ) // A ReceiveFunc receives a single inbound packet from the network. @@ -68,8 +69,8 @@ type Endpoint interface { SrcToString() string // returns the local source address (ip:port) DstToString() string // returns the destination address (ip:port) DstToBytes() []byte // used for mac2 cookie calculations - DstIP() net.IP - SrcIP() net.IP + DstIP() netip.Addr + SrcIP() netip.Addr } var ( @@ -119,33 +120,3 @@ func (fn ReceiveFunc) PrettyName() string { } return name } - -func parseEndpoint(s string) (*net.UDPAddr, error) { - // ensure that the host is an IP address - - host, _, err := net.SplitHostPort(s) - if err != nil { - return nil, err - } - if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 { - // Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just - // trying to make sure with a small sanity test that this is a real IP address and - // not something that's likely to incur DNS lookups. - host = host[:i] - } - if ip := net.ParseIP(host); ip == nil { - return nil, errors.New("Failed to parse IP address: " + host) - } - - // parse address and port - - addr, err := net.ResolveUDPAddr("udp", s) - if err != nil { - return nil, err - } - ip4 := addr.IP.To4() - if ip4 != nil { - addr.IP = ip4 - } - return addr, err -} diff --git a/device/allowedips.go b/device/allowedips.go index c08399bbf..a6534b194 100644 --- a/device/allowedips.go +++ b/device/allowedips.go @@ -12,6 +12,8 @@ import ( "net" "sync" "unsafe" + + "golang.zx2c4.com/go118/netip" ) type parentIndirection struct { @@ -26,7 +28,7 @@ type trieEntry struct { cidr uint8 bitAtByte uint8 bitAtShift uint8 - bits net.IP + bits []byte perPeerElem *list.Element } @@ -51,7 +53,7 @@ func swapU64(i uint64) uint64 { return bits.ReverseBytes64(i) } -func commonBits(ip1 net.IP, ip2 net.IP) uint8 { +func commonBits(ip1, ip2 []byte) uint8 { size := len(ip1) if size == net.IPv4len { a := (*uint32)(unsafe.Pointer(&ip1[0])) @@ -85,7 +87,7 @@ func (node *trieEntry) removeFromPeerEntries() { } } -func (node *trieEntry) choose(ip net.IP) byte { +func (node *trieEntry) choose(ip []byte) byte { return (ip[node.bitAtByte] >> node.bitAtShift) & 1 } @@ -104,7 +106,7 @@ func (node *trieEntry) zeroizePointers() { node.parent.parentBit = nil } -func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry, exact bool) { +func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) { for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr { parent = node if parent.cidr == cidr { @@ -117,7 +119,7 @@ func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry, return } -func (trie parentIndirection) insert(ip net.IP, cidr uint8, peer *Peer) { +func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) { if *trie.parentBit == nil { node := &trieEntry{ peer: peer, @@ -207,7 +209,7 @@ func (trie parentIndirection) insert(ip net.IP, cidr uint8, peer *Peer) { } } -func (node *trieEntry) lookup(ip net.IP) *Peer { +func (node *trieEntry) lookup(ip []byte) *Peer { var found *Peer size := uint8(len(ip)) for node != nil && commonBits(node.bits, ip) >= node.cidr { @@ -229,13 +231,13 @@ type AllowedIPs struct { mutex sync.RWMutex } -func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint8) bool) { +func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) { table.mutex.RLock() defer table.mutex.RUnlock() for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() { node := elem.Value.(*trieEntry) - if !cb(node.bits, node.cidr) { + if !cb(netip.PrefixFrom(netip.AddrFromSlice(node.bits), int(node.cidr))) { return } } @@ -283,28 +285,29 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) { } } -func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) { +func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) { table.mutex.Lock() defer table.mutex.Unlock() - switch len(ip) { - case net.IPv6len: - parentIndirection{&table.IPv6, 2}.insert(ip, cidr, peer) - case net.IPv4len: - parentIndirection{&table.IPv4, 2}.insert(ip, cidr, peer) - default: + if prefix.Addr().Is6() { + ip := prefix.Addr().As16() + parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer) + } else if prefix.Addr().Is4() { + ip := prefix.Addr().As4() + parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer) + } else { panic(errors.New("inserting unknown address type")) } } -func (table *AllowedIPs) Lookup(address []byte) *Peer { +func (table *AllowedIPs) Lookup(ip []byte) *Peer { table.mutex.RLock() defer table.mutex.RUnlock() - switch len(address) { + switch len(ip) { case net.IPv6len: - return table.IPv6.lookup(address) + return table.IPv6.lookup(ip) case net.IPv4len: - return table.IPv4.lookup(address) + return table.IPv4.lookup(ip) default: panic(errors.New("looking up unknown address type")) } diff --git a/device/allowedips_rand_test.go b/device/allowedips_rand_test.go index 16de1704a..ff56fe6a9 100644 --- a/device/allowedips_rand_test.go +++ b/device/allowedips_rand_test.go @@ -10,6 +10,8 @@ import ( "net" "sort" "testing" + + "golang.zx2c4.com/go118/netip" ) const ( @@ -93,14 +95,14 @@ func TestTrieRandom(t *testing.T) { rand.Read(addr4[:]) cidr := uint8(rand.Intn(32) + 1) index := rand.Intn(NumberOfPeers) - allowedIPs.Insert(addr4[:], cidr, peers[index]) + allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4(addr4), int(cidr)), peers[index]) slow4 = slow4.Insert(addr4[:], cidr, peers[index]) var addr6 [16]byte rand.Read(addr6[:]) cidr = uint8(rand.Intn(128) + 1) index = rand.Intn(NumberOfPeers) - allowedIPs.Insert(addr6[:], cidr, peers[index]) + allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(addr6), int(cidr)), peers[index]) slow6 = slow6.Insert(addr6[:], cidr, peers[index]) } diff --git a/device/allowedips_test.go b/device/allowedips_test.go index 2059a8836..a27499782 100644 --- a/device/allowedips_test.go +++ b/device/allowedips_test.go @@ -9,6 +9,8 @@ import ( "math/rand" "net" "testing" + + "golang.zx2c4.com/go118/netip" ) type testPairCommonBits struct { @@ -98,7 +100,7 @@ func TestTrieIPv4(t *testing.T) { var allowedIPs AllowedIPs insert := func(peer *Peer, a, b, c, d byte, cidr uint8) { - allowedIPs.Insert([]byte{a, b, c, d}, cidr, peer) + allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer) } assertEQ := func(peer *Peer, a, b, c, d byte) { @@ -208,7 +210,7 @@ func TestTrieIPv6(t *testing.T) { addr = append(addr, expand(b)...) addr = append(addr, expand(c)...) addr = append(addr, expand(d)...) - allowedIPs.Insert(addr, cidr, peer) + allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer) } assertEQ := func(peer *Peer, a, b, c, d uint32) { diff --git a/device/device_test.go b/device/device_test.go index 29daeb9c3..84221bed1 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -11,7 +11,6 @@ import ( "fmt" "io" "math/rand" - "net" "runtime" "runtime/pprof" "sync" @@ -19,6 +18,7 @@ import ( "testing" "time" + "golang.zx2c4.com/go118/netip" "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/conn/bindtest" "golang.zx2c4.com/wireguard/tun/tuntest" @@ -96,7 +96,7 @@ type testPair [2]testPeer type testPeer struct { tun *tuntest.ChannelTUN dev *Device - ip net.IP + ip netip.Addr } type SendDirection bool @@ -159,7 +159,7 @@ func genTestPair(tb testing.TB, realSocket bool) (pair testPair) { for i := range pair { p := &pair[i] p.tun = tuntest.NewChannelTUN() - p.ip = net.IPv4(1, 0, 0, byte(i+1)) + p.ip = netip.AddrFrom4([4]byte{1, 0, 0, byte(i + 1)}) level := LogLevelVerbose if _, ok := tb.(*testing.B); ok && !testing.Verbose() { level = LogLevelError diff --git a/device/endpoint_test.go b/device/endpoint_test.go index 57c361ca6..f1ae47e34 100644 --- a/device/endpoint_test.go +++ b/device/endpoint_test.go @@ -7,47 +7,44 @@ package device import ( "math/rand" - "net" + + "golang.zx2c4.com/go118/netip" ) type DummyEndpoint struct { - src [16]byte - dst [16]byte + src, dst netip.Addr } func CreateDummyEndpoint() (*DummyEndpoint, error) { - var end DummyEndpoint - if _, err := rand.Read(end.src[:]); err != nil { + var src, dst [16]byte + if _, err := rand.Read(src[:]); err != nil { return nil, err } - _, err := rand.Read(end.dst[:]) - return &end, err + _, err := rand.Read(dst[:]) + return &DummyEndpoint{netip.AddrFrom16(src), netip.AddrFrom16(dst)}, err } func (e *DummyEndpoint) ClearSrc() {} func (e *DummyEndpoint) SrcToString() string { - var addr net.UDPAddr - addr.IP = e.SrcIP() - addr.Port = 1000 - return addr.String() + return netip.AddrPortFrom(e.SrcIP(), 1000).String() } func (e *DummyEndpoint) DstToString() string { - var addr net.UDPAddr - addr.IP = e.DstIP() - addr.Port = 1000 - return addr.String() + return netip.AddrPortFrom(e.DstIP(), 1000).String() } -func (e *DummyEndpoint) SrcToBytes() []byte { - return e.src[:] +func (e *DummyEndpoint) DstToBytes() []byte { + out := e.DstIP().AsSlice() + out = append(out, byte(1000&0xff)) + out = append(out, byte((1000>>8)&0xff)) + return out } -func (e *DummyEndpoint) DstIP() net.IP { - return e.dst[:] +func (e *DummyEndpoint) DstIP() netip.Addr { + return e.dst } -func (e *DummyEndpoint) SrcIP() net.IP { - return e.src[:] +func (e *DummyEndpoint) SrcIP() netip.Addr { + return e.src } diff --git a/device/receive.go b/device/receive.go index 58574810f..cc3449801 100644 --- a/device/receive.go +++ b/device/receive.go @@ -17,7 +17,6 @@ import ( "golang.org/x/crypto/chacha20poly1305" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" - "golang.zx2c4.com/wireguard/conn" ) diff --git a/device/uapi.go b/device/uapi.go index 66ecd484f..9572fb808 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -18,6 +18,7 @@ import ( "sync/atomic" "time" + "golang.zx2c4.com/go118/netip" "golang.zx2c4.com/wireguard/ipc" ) @@ -121,8 +122,8 @@ func (device *Device) IpcGetOperation(w io.Writer) error { sendf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes)) sendf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval)) - device.allowedips.EntriesForPeer(peer, func(ip net.IP, cidr uint8) bool { - sendf("allowed_ip=%s/%d", ip.String(), cidr) + device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool { + sendf("allowed_ip=%s", prefix.String()) return true }) } @@ -370,16 +371,14 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error case "allowed_ip": device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer) - - _, network, err := net.ParseCIDR(value) + prefix, err := netip.ParsePrefix(value) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err) } if peer.dummy { return nil } - ones, _ := network.Mask.Size() - device.allowedips.Insert(network.IP, uint8(ones), peer.Peer) + device.allowedips.Insert(prefix, peer.Peer) case "protocol_version": if value != "1" { diff --git a/go.mod b/go.mod index 856bb6c21..be7116500 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,8 @@ go 1.17 require ( golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 - golang.org/x/net v0.0.0-20211101193420-4a448f8816b3 + golang.org/x/net v0.0.0-20211104170005-ce137452f963 golang.org/x/sys v0.0.0-20211103235746-7861aae1554b + golang.zx2c4.com/go118/netip v0.0.0-20211105124833-002a02cb0e53 golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 ) diff --git a/go.sum b/go.sum index 37fe06796..4246a0e22 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 h1:7I4JAnoQBe7ZtJcBaYHi5UtiO8tQHbUSXxL+pnGRANg= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20211101193420-4a448f8816b3 h1:VrJZAjbekhoRn7n5FBujY31gboH+iB3pdLxn3gE9FjU= -golang.org/x/net v0.0.0-20211101193420-4a448f8816b3/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20211104170005-ce137452f963 h1:8gJUadZl+kWvZBqG/LautX0X6qe5qTC2VI/3V3NBRAY= +golang.org/x/net v0.0.0-20211104170005-ce137452f963/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -12,5 +12,7 @@ golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9sn golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.zx2c4.com/go118/netip v0.0.0-20211105124833-002a02cb0e53 h1:nFvpdzrHF9IPo9xPgayHWObCATpQYKky8VSSdt9lf9E= +golang.zx2c4.com/go118/netip v0.0.0-20211105124833-002a02cb0e53/go.mod h1:5yyfuiqVIJ7t+3MqrpTQ+QqRkMWiESiyDvPNvKYCecg= golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 h1:Ug9qvr1myri/zFN6xL17LSCBGFDnphBBhzmILHsM5TY= golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= diff --git a/ratelimiter/ratelimiter.go b/ratelimiter/ratelimiter.go index 2f7aa2ae0..8e78d5e46 100644 --- a/ratelimiter/ratelimiter.go +++ b/ratelimiter/ratelimiter.go @@ -6,9 +6,10 @@ package ratelimiter import ( - "net" "sync" "time" + + "golang.zx2c4.com/go118/netip" ) const ( @@ -30,8 +31,7 @@ type Ratelimiter struct { timeNow func() time.Time stopReset chan struct{} // send to reset, close to stop - tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry - tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry + table map[netip.Addr]*RatelimiterEntry } func (rate *Ratelimiter) Close() { @@ -57,8 +57,7 @@ func (rate *Ratelimiter) Init() { } rate.stopReset = make(chan struct{}) - rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry) - rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry) + rate.table = make(map[netip.Addr]*RatelimiterEntry) stopReset := rate.stopReset // store in case Init is called again. @@ -87,71 +86,39 @@ func (rate *Ratelimiter) cleanup() (empty bool) { rate.mu.Lock() defer rate.mu.Unlock() - for key, entry := range rate.tableIPv4 { + for key, entry := range rate.table { entry.mu.Lock() if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime { - delete(rate.tableIPv4, key) + delete(rate.table, key) } entry.mu.Unlock() } - for key, entry := range rate.tableIPv6 { - entry.mu.Lock() - if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime { - delete(rate.tableIPv6, key) - } - entry.mu.Unlock() - } - - return len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0 + return len(rate.table) == 0 } -func (rate *Ratelimiter) Allow(ip net.IP) bool { +func (rate *Ratelimiter) Allow(ip netip.Addr) bool { var entry *RatelimiterEntry - var keyIPv4 [net.IPv4len]byte - var keyIPv6 [net.IPv6len]byte - // lookup entry - - IPv4 := ip.To4() - IPv6 := ip.To16() - rate.mu.RLock() - - if IPv4 != nil { - copy(keyIPv4[:], IPv4) - entry = rate.tableIPv4[keyIPv4] - } else { - copy(keyIPv6[:], IPv6) - entry = rate.tableIPv6[keyIPv6] - } - + entry = rate.table[ip] rate.mu.RUnlock() // make new entry if not found - if entry == nil { entry = new(RatelimiterEntry) entry.tokens = maxTokens - packetCost entry.lastTime = rate.timeNow() rate.mu.Lock() - if IPv4 != nil { - rate.tableIPv4[keyIPv4] = entry - if len(rate.tableIPv4) == 1 && len(rate.tableIPv6) == 0 { - rate.stopReset <- struct{}{} - } - } else { - rate.tableIPv6[keyIPv6] = entry - if len(rate.tableIPv6) == 1 && len(rate.tableIPv4) == 0 { - rate.stopReset <- struct{}{} - } + rate.table[ip] = entry + if len(rate.table) == 1 { + rate.stopReset <- struct{}{} } rate.mu.Unlock() return true } // add tokens to entry - entry.mu.Lock() now := rate.timeNow() entry.tokens += now.Sub(entry.lastTime).Nanoseconds() @@ -161,7 +128,6 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool { } // subtract cost of packet - if entry.tokens > packetCost { entry.tokens -= packetCost entry.mu.Unlock() diff --git a/ratelimiter/ratelimiter_test.go b/ratelimiter/ratelimiter_test.go index f231fe5b7..3e06ff778 100644 --- a/ratelimiter/ratelimiter_test.go +++ b/ratelimiter/ratelimiter_test.go @@ -6,9 +6,10 @@ package ratelimiter import ( - "net" "testing" "time" + + "golang.zx2c4.com/go118/netip" ) type result struct { @@ -71,21 +72,21 @@ func TestRatelimiter(t *testing.T) { text: "packet following 2 packet burst", }) - ips := []net.IP{ - net.ParseIP("127.0.0.1"), - net.ParseIP("192.168.1.1"), - net.ParseIP("172.167.2.3"), - net.ParseIP("97.231.252.215"), - net.ParseIP("248.97.91.167"), - net.ParseIP("188.208.233.47"), - net.ParseIP("104.2.183.179"), - net.ParseIP("72.129.46.120"), - net.ParseIP("2001:0db8:0a0b:12f0:0000:0000:0000:0001"), - net.ParseIP("f5c2:818f:c052:655a:9860:b136:6894:25f0"), - net.ParseIP("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"), - net.ParseIP("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"), - net.ParseIP("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"), - net.ParseIP("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"), + ips := []netip.Addr{ + netip.MustParseAddr("127.0.0.1"), + netip.MustParseAddr("192.168.1.1"), + netip.MustParseAddr("172.167.2.3"), + netip.MustParseAddr("97.231.252.215"), + netip.MustParseAddr("248.97.91.167"), + netip.MustParseAddr("188.208.233.47"), + netip.MustParseAddr("104.2.183.179"), + netip.MustParseAddr("72.129.46.120"), + netip.MustParseAddr("2001:0db8:0a0b:12f0:0000:0000:0000:0001"), + netip.MustParseAddr("f5c2:818f:c052:655a:9860:b136:6894:25f0"), + netip.MustParseAddr("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"), + netip.MustParseAddr("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"), + netip.MustParseAddr("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"), + netip.MustParseAddr("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"), } now := time.Now() diff --git a/tun/netstack/examples/http_client.go b/tun/netstack/examples/http_client.go index 6ac28596b..b39b45316 100644 --- a/tun/netstack/examples/http_client.go +++ b/tun/netstack/examples/http_client.go @@ -1,4 +1,5 @@ //go:build ignore +// +build ignore /* SPDX-License-Identifier: MIT * @@ -10,9 +11,9 @@ package main import ( "io" "log" - "net" "net/http" + "golang.zx2c4.com/go118/netip" "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun/netstack" @@ -20,8 +21,8 @@ import ( func main() { tun, tnet, err := netstack.CreateNetTUN( - []net.IP{net.ParseIP("192.168.4.29")}, - []net.IP{net.ParseIP("8.8.8.8")}, + []netip.Addr{netip.MustParseAddr("192.168.4.29")}, + []netip.Addr{netip.MustParseAddr("8.8.8.8")}, 1420) if err != nil { log.Panic(err) diff --git a/tun/netstack/examples/http_server.go b/tun/netstack/examples/http_server.go index 577c6ea4c..40f780447 100644 --- a/tun/netstack/examples/http_server.go +++ b/tun/netstack/examples/http_server.go @@ -1,4 +1,5 @@ //go:build ignore +// +build ignore /* SPDX-License-Identifier: MIT * @@ -13,6 +14,7 @@ import ( "net" "net/http" + "golang.zx2c4.com/go118/netip" "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun/netstack" @@ -20,8 +22,8 @@ import ( func main() { tun, tnet, err := netstack.CreateNetTUN( - []net.IP{net.ParseIP("192.168.4.29")}, - []net.IP{net.ParseIP("8.8.8.8"), net.ParseIP("8.8.4.4")}, + []netip.Addr{netip.MustParseAddr("192.168.4.29")}, + []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("8.8.4.4")}, 1420, ) if err != nil { diff --git a/tun/netstack/go.mod b/tun/netstack/go.mod index 8db9f4bcd..46b57bab1 100644 --- a/tun/netstack/go.mod +++ b/tun/netstack/go.mod @@ -6,6 +6,7 @@ require ( golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6 golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7 // indirect golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba // indirect + golang.zx2c4.com/go118/netip v0.0.0-20211105124833-002a02cb0e53 golang.zx2c4.com/wireguard v0.0.0-20210424170727-c9db4b7aaa22 gvisor.dev/gvisor v0.0.0-20211020211948-f76a604701b6 ) diff --git a/tun/netstack/go.sum b/tun/netstack/go.sum index 78c025c54..01bfbc70c 100644 --- a/tun/netstack/go.sum +++ b/tun/netstack/go.sum @@ -805,6 +805,10 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.zx2c4.com/go118/netip v0.0.0-20211104120624-f0ae7a6e37c5 h1:mV4w4F7AtWXoDNkko9odoTdWpNwyDh8jx+S1fOZKDLg= +golang.zx2c4.com/go118/netip v0.0.0-20211104120624-f0ae7a6e37c5/go.mod h1:5yyfuiqVIJ7t+3MqrpTQ+QqRkMWiESiyDvPNvKYCecg= +golang.zx2c4.com/go118/netip v0.0.0-20211105124833-002a02cb0e53 h1:nFvpdzrHF9IPo9xPgayHWObCATpQYKky8VSSdt9lf9E= +golang.zx2c4.com/go118/netip v0.0.0-20211105124833-002a02cb0e53/go.mod h1:5yyfuiqVIJ7t+3MqrpTQ+QqRkMWiESiyDvPNvKYCecg= golang.zx2c4.com/wireguard v0.0.0-20210424170727-c9db4b7aaa22 h1:ytS28bw9HtZVDRMDxviC6ryCJuccw+zXhh04u2IRWJw= golang.zx2c4.com/wireguard v0.0.0-20210424170727-c9db4b7aaa22/go.mod h1:a057zjmoc00UN7gVkaJt2sXVK523kMJcogDTEvPIasg= google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE= diff --git a/tun/netstack/tun.go b/tun/netstack/tun.go index 24d0835dd..f1c03f48c 100644 --- a/tun/netstack/tun.go +++ b/tun/netstack/tun.go @@ -18,6 +18,7 @@ import ( "strings" "time" + "golang.zx2c4.com/go118/netip" "golang.zx2c4.com/wireguard/tun" "golang.org/x/net/dns/dnsmessage" @@ -38,7 +39,7 @@ type netTun struct { events chan tun.Event incomingPacket chan buffer.VectorisedView mtu int - dnsServers []net.IP + dnsServers []netip.Addr hasV4, hasV6 bool } type endpoint netTun @@ -94,7 +95,7 @@ func (*endpoint) ARPHardwareType() header.ARPHardwareType { func (e *endpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) { } -func CreateNetTUN(localAddresses, dnsServers []net.IP, mtu int) (tun.Device, *Net, error) { +func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) { opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol}, @@ -112,25 +113,23 @@ func CreateNetTUN(localAddresses, dnsServers []net.IP, mtu int) (tun.Device, *Ne return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr) } for _, ip := range localAddresses { - if ip4 := ip.To4(); ip4 != nil { - protoAddr := tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: tcpip.Address(ip4).WithPrefix(), - } - tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}) - if tcpipErr != nil { - return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip4, tcpipErr) - } + var protoNumber tcpip.NetworkProtocolNumber + if ip.Is4() { + protoNumber = ipv4.ProtocolNumber + } else if ip.Is6() { + protoNumber = ipv6.ProtocolNumber + } + protoAddr := tcpip.ProtocolAddress{ + Protocol: protoNumber, + AddressWithPrefix: tcpip.Address(ip.AsSlice()).WithPrefix(), + } + tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}) + if tcpipErr != nil { + return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr) + } + if ip.Is4() { dev.hasV4 = true - } else { - protoAddr := tcpip.ProtocolAddress{ - Protocol: ipv6.ProtocolNumber, - AddressWithPrefix: tcpip.Address(ip).WithPrefix(), - } - tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}) - if tcpipErr != nil { - return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr) - } + } else if ip.Is6() { dev.hasV6 = true } } @@ -202,62 +201,83 @@ func (tun *netTun) MTU() (int, error) { return tun.mtu, nil } -func convertToFullAddr(ip net.IP, port int) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) { - if ip4 := ip.To4(); ip4 != nil { - return tcpip.FullAddress{ - NIC: 1, - Addr: tcpip.Address(ip4), - Port: uint16(port), - }, ipv4.ProtocolNumber +func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) { + var protoNumber tcpip.NetworkProtocolNumber + if endpoint.Addr().Is4() { + protoNumber = ipv4.ProtocolNumber } else { - return tcpip.FullAddress{ - NIC: 1, - Addr: tcpip.Address(ip), - Port: uint16(port), - }, ipv6.ProtocolNumber + protoNumber = ipv6.ProtocolNumber } + return tcpip.FullAddress{ + NIC: 1, + Addr: tcpip.Address(endpoint.Addr().AsSlice()), + Port: endpoint.Port(), + }, protoNumber +} + +func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) { + fa, pn := convertToFullAddr(addr) + return gonet.DialContextTCP(ctx, net.stack, fa, pn) } func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) { if addr == nil { - panic("todo: deal with auto addr semantics for nil addr") + return net.DialContextTCPAddrPort(ctx, netip.AddrPort{}) } - fa, pn := convertToFullAddr(addr.IP, addr.Port) - return gonet.DialContextTCP(ctx, net.stack, fa, pn) + return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port))) +} + +func (net *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) { + fa, pn := convertToFullAddr(addr) + return gonet.DialTCP(net.stack, fa, pn) } func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) { if addr == nil { - panic("todo: deal with auto addr semantics for nil addr") + return net.DialTCPAddrPort(netip.AddrPort{}) } - fa, pn := convertToFullAddr(addr.IP, addr.Port) - return gonet.DialTCP(net.stack, fa, pn) + return net.DialTCPAddrPort(netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port))) +} + +func (net *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) { + fa, pn := convertToFullAddr(addr) + return gonet.ListenTCP(net.stack, fa, pn) } func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) { if addr == nil { - panic("todo: deal with auto addr semantics for nil addr") + return net.ListenTCPAddrPort(netip.AddrPort{}) } - fa, pn := convertToFullAddr(addr.IP, addr.Port) - return gonet.ListenTCP(net.stack, fa, pn) + return net.ListenTCPAddrPort(netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port))) } -func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) { +func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) { var lfa, rfa *tcpip.FullAddress var pn tcpip.NetworkProtocolNumber - if laddr != nil { + if laddr.IsValid() || laddr.Port() > 0 { var addr tcpip.FullAddress - addr, pn = convertToFullAddr(laddr.IP, laddr.Port) + addr, pn = convertToFullAddr(laddr) lfa = &addr } - if raddr != nil { + if raddr.IsValid() || raddr.Port() > 0 { var addr tcpip.FullAddress - addr, pn = convertToFullAddr(raddr.IP, raddr.Port) + addr, pn = convertToFullAddr(raddr) rfa = &addr } return gonet.DialUDP(net.stack, lfa, rfa, pn) } +func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) { + var la, ra netip.AddrPort + if laddr != nil { + la = netip.AddrPortFrom(netip.AddrFromSlice(laddr.IP), uint16(laddr.Port)) + } + if raddr != nil { + ra = netip.AddrPortFrom(netip.AddrFromSlice(raddr.IP), uint16(raddr.Port)) + } + return net.DialUDPAddrPort(la, ra) +} + var ( errNoSuchHost = errors.New("no such host") errLameReferral = errors.New("lame referral") @@ -433,7 +453,7 @@ func dnsStreamRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []by return p, h, nil } -func (tnet *Net) exchange(ctx context.Context, server net.IP, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) { +func (tnet *Net) exchange(ctx context.Context, server netip.Addr, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) { q.Class = dnsmessage.ClassINET id, udpReq, tcpReq, err := newRequest(q) if err != nil { @@ -447,9 +467,9 @@ func (tnet *Net) exchange(ctx context.Context, server net.IP, q dnsmessage.Quest var c net.Conn var err error if useUDP { - c, err = tnet.DialUDP(nil, &net.UDPAddr{IP: server, Port: 53}) + c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, netip.AddrPortFrom(server, 53)) } else { - c, err = tnet.DialContextTCP(ctx, &net.TCPAddr{IP: server, Port: 53}) + c, err = tnet.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(server, 53)) } if err != nil { @@ -600,8 +620,8 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, zlen = zidx } } - if ip := net.ParseIP(host[:zlen]); ip != nil { - return []string{host[:zlen]}, nil + if ip, err := netip.ParseAddr(host[:zlen]); err == nil { + return []string{ip.String()}, nil } if !isDomainName(host) { @@ -612,7 +632,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, server string error } - var addrsV4, addrsV6 []net.IP + var addrsV4, addrsV6 []netip.Addr lanes := 0 if tnet.hasV4 { lanes++ @@ -667,7 +687,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, } break loop } - addrsV4 = append(addrsV4, net.IP(a.A[:])) + addrsV4 = append(addrsV4, netip.AddrFrom4(a.A)) case dnsmessage.TypeAAAA: aaaa, err := result.p.AAAAResource() @@ -679,7 +699,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, } break loop } - addrsV6 = append(addrsV6, net.IP(aaaa.AAAA[:])) + addrsV6 = append(addrsV6, netip.AddrFrom16(aaaa.AAAA)) default: if err := result.p.SkipAnswer(); err != nil { @@ -695,7 +715,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, } } // We don't do RFC6724. Instead just put V6 addresess first if an IPv6 address is enabled - var addrs []net.IP + var addrs []netip.Addr if tnet.hasV6 { addrs = append(addrsV6, addrsV4...) } else { @@ -764,12 +784,11 @@ func (tnet *Net) DialContext(ctx context.Context, network, address string) (net. if err != nil { return nil, &net.OpError{Op: "dial", Err: err} } - var addrs []net.IP + var addrs []netip.AddrPort for _, addr := range allAddr { - if strings.IndexByte(addr, ':') != -1 && acceptV6 { - addrs = append(addrs, net.ParseIP(addr)) - } else if strings.IndexByte(addr, '.') != -1 && acceptV4 { - addrs = append(addrs, net.ParseIP(addr)) + ip, err := netip.ParseAddr(addr) + if err == nil && ((ip.Is4() && acceptV4) || (ip.Is6() && acceptV6)) { + addrs = append(addrs, netip.AddrPortFrom(ip, uint16(port))) } } if len(addrs) == 0 && len(allAddr) != 0 { @@ -808,9 +827,9 @@ func (tnet *Net) DialContext(ctx context.Context, network, address string) (net. var c net.Conn if useUDP { - c, err = tnet.DialUDP(nil, &net.UDPAddr{IP: addr, Port: port}) + c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr) } else { - c, err = tnet.DialContextTCP(dialCtx, &net.TCPAddr{IP: addr, Port: port}) + c, err = tnet.DialContextTCPAddrPort(dialCtx, addr) } if err == nil { return c, nil diff --git a/tun/tuntest/tuntest.go b/tun/tuntest/tuntest.go index d89db7187..bdf0467b4 100644 --- a/tun/tuntest/tuntest.go +++ b/tun/tuntest/tuntest.go @@ -8,13 +8,13 @@ package tuntest import ( "encoding/binary" "io" - "net" "os" + "golang.zx2c4.com/go118/netip" "golang.zx2c4.com/wireguard/tun" ) -func Ping(dst, src net.IP) []byte { +func Ping(dst, src netip.Addr) []byte { localPort := uint16(1337) seq := uint16(0) @@ -40,7 +40,7 @@ func checksum(buf []byte, initial uint16) uint16 { return ^uint16(v) } -func genICMPv4(payload []byte, dst, src net.IP) []byte { +func genICMPv4(payload []byte, dst, src netip.Addr) []byte { const ( icmpv4ProtocolNumber = 1 icmpv4Echo = 8 @@ -70,8 +70,8 @@ func genICMPv4(payload []byte, dst, src net.IP) []byte { binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length) ip[8] = ttl ip[9] = icmpv4ProtocolNumber - copy(ip[12:], src.To4()) - copy(ip[16:], dst.To4()) + copy(ip[12:], src.AsSlice()) + copy(ip[16:], dst.AsSlice()) chksum = ^checksum(ip[:], 0) binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum)