From d49f4e9fe36f344ff1fa75523bdac18bdcb31945 Mon Sep 17 00:00:00 2001 From: David Anderson Date: Sun, 1 Mar 2020 00:39:24 -0800 Subject: [PATCH 01/12] device: make Peer fields safe for atomic access on 32-bit. All atomic access must be aligned to 64 bits, even on 32-bit platforms. Go promises that the start of allocated structs is aligned to 64 bits. So, place the atomically-accessed things first in the struct so that they benefit from that alignment. As a side bonus, it cleanly separates fields that are accessed by atomic ops, and those that should be accessed under mu. Also adds a test that will fail consistently on 32-bit platforms if the struct ever changes again to violate the rules. This is likely not needed because unaligned access crashes reliably, but this will reliably fail even if tests accidentally pass due to lucky alignment. Signed-Off-By: David Anderson --- device/peer.go | 25 ++++++++++++++++--------- device/peer_test.go | 29 +++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 9 deletions(-) create mode 100644 device/peer_test.go diff --git a/device/peer.go b/device/peer.go index 8a8224c66..65581d524 100644 --- a/device/peer.go +++ b/device/peer.go @@ -19,20 +19,27 @@ const ( ) type Peer struct { - isRunning AtomicBool - sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer - keypairs Keypairs - handshake Handshake - device *Device - endpoint Endpoint - persistentKeepaliveInterval uint16 - - // This must be 64-bit aligned, so make sure the above members come out to even alignment and pad accordingly + // These fields are accessed with atomic operations, which must be + // 64-bit aligned even on 32-bit platforms. Go guarantees that an + // allocated struct will be 64-bit aligned. So we place + // atomically-accessed fields up front, so that they can share in + // this alignment before smaller fields throw it off. stats struct { txBytes uint64 // bytes send to peer (endpoint) rxBytes uint64 // bytes received from peer lastHandshakeNano int64 // nano seconds since epoch } + // This field is only 32 bits wide, but is still aligned to 64 + // bits. Don't place other atomic fields after this one. + isRunning AtomicBool + + // Mostly protects endpoint, but is generally taken whenever we modify peer + sync.RWMutex + keypairs Keypairs + handshake Handshake + device *Device + endpoint Endpoint + persistentKeepaliveInterval uint16 timers struct { retransmitHandshake *Timer diff --git a/device/peer_test.go b/device/peer_test.go new file mode 100644 index 000000000..de87ab638 --- /dev/null +++ b/device/peer_test.go @@ -0,0 +1,29 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "testing" + "unsafe" +) + +func checkAlignment(t *testing.T, name string, offset uintptr) { + t.Helper() + if offset%8 != 0 { + t.Errorf("offset of %q within struct is %d bytes, which does not align to 64-bit word boundaries (missing %d bytes). Atomic operations will crash on 32-bit systems.", name, offset, 8-(offset%8)) + } +} + +// TestPeerAlignment checks that atomically-accessed fields are +// aligned to 64-bit boundaries, as required by the atomic package. +// +// Unfortunately, violating this rule on 32-bit platforms results in a +// hard segfault at runtime. +func TestPeerAlignment(t *testing.T) { + var p Peer + checkAlignment(t, "Peer.stats", unsafe.Offsetof(p.stats)) + checkAlignment(t, "Peer.isRunning", unsafe.Offsetof(p.isRunning)) +} From 66793239d1fb15e6b3dfe5c7beaa68552d2f5bc5 Mon Sep 17 00:00:00 2001 From: Avery Pennarun Date: Wed, 23 Oct 2019 00:08:52 -0400 Subject: [PATCH 02/12] wintun: split error message for create vs open namespace. Signed-off-by: Avery Pennarun --- tun/wintun/namespace_windows.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tun/wintun/namespace_windows.go b/tun/wintun/namespace_windows.go index f4316fe84..5f8a04138 100644 --- a/tun/wintun/namespace_windows.go +++ b/tun/wintun/namespace_windows.go @@ -59,9 +59,12 @@ func initializeNamespace() error { if err == windows.ERROR_PATH_NOT_FOUND { continue } + if err != nil { + return fmt.Errorf("OpenPrivateNamespace failed: %v", err) + } } if err != nil { - return fmt.Errorf("Create/OpenPrivateNamespace failed: %v", err) + return fmt.Errorf("CreatePrivateNamespace failed: %v", err) } break } From c4a8eab3ddc37021433a4d08075a02b355aca21c Mon Sep 17 00:00:00 2001 From: David Crawshaw Date: Thu, 7 Nov 2019 11:13:05 -0500 Subject: [PATCH 03/12] conn: new package that splits out the Bind and Endpoint types The sticky socket code stays in the device package for now, as it reaches deeply into the peer list. This is the first step in an effort to split some code out of the very busy device package. Signed-off-by: David Crawshaw --- {device => conn}/boundif_windows.go | 19 +-- conn/conn.go | 101 +++++++++++ {device => conn}/conn_default.go | 13 +- {device => conn}/conn_linux.go | 249 +++------------------------- {device => conn}/mark_default.go | 2 +- {device => conn}/mark_unix.go | 2 +- device/bind_test.go | 14 +- device/bindsocketshim.go | 36 ++++ device/conn.go | 187 --------------------- device/device.go | 146 ++++++++++++++-- device/peer.go | 6 +- device/receive.go | 9 +- device/sticky_default.go | 12 ++ device/sticky_linux.go | 215 ++++++++++++++++++++++++ device/uapi.go | 3 +- 15 files changed, 562 insertions(+), 452 deletions(-) rename {device => conn}/boundif_windows.go (66%) create mode 100644 conn/conn.go rename {device => conn}/conn_default.go (94%) rename {device => conn}/conn_linux.go (63%) rename {device => conn}/mark_default.go (93%) rename {device => conn}/mark_unix.go (98%) create mode 100644 device/bindsocketshim.go delete mode 100644 device/conn.go create mode 100644 device/sticky_default.go create mode 100644 device/sticky_linux.go diff --git a/device/boundif_windows.go b/conn/boundif_windows.go similarity index 66% rename from device/boundif_windows.go rename to conn/boundif_windows.go index 690841528..fe38d05f5 100644 --- a/device/boundif_windows.go +++ b/conn/boundif_windows.go @@ -3,11 +3,10 @@ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. */ -package device +package conn import ( "encoding/binary" - "errors" "unsafe" "golang.org/x/sys/windows" @@ -18,17 +17,13 @@ const ( sockoptIPV6_UNICAST_IF = 31 ) -func (device *Device) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { +func (bind *nativeBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { /* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */ bytes := make([]byte, 4) binary.BigEndian.PutUint32(bytes, interfaceIndex) interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0])) - if device.net.bind == nil { - return errors.New("Bind is not yet initialized") - } - - sysconn, err := device.net.bind.(*nativeBind).ipv4.SyscallConn() + sysconn, err := bind.ipv4.SyscallConn() if err != nil { return err } @@ -41,12 +36,12 @@ func (device *Device) BindSocketToInterface4(interfaceIndex uint32, blackhole bo if err != nil { return err } - device.net.bind.(*nativeBind).blackhole4 = blackhole + bind.blackhole4 = blackhole return nil } -func (device *Device) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { - sysconn, err := device.net.bind.(*nativeBind).ipv6.SyscallConn() +func (bind *nativeBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { + sysconn, err := bind.ipv6.SyscallConn() if err != nil { return err } @@ -59,6 +54,6 @@ func (device *Device) BindSocketToInterface6(interfaceIndex uint32, blackhole bo if err != nil { return err } - device.net.bind.(*nativeBind).blackhole6 = blackhole + bind.blackhole6 = blackhole return nil } diff --git a/conn/conn.go b/conn/conn.go new file mode 100644 index 000000000..6b7db12ab --- /dev/null +++ b/conn/conn.go @@ -0,0 +1,101 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +// Package conn implements WireGuard's network connections. +package conn + +import ( + "errors" + "net" + "strings" +) + +// A Bind listens on a port for both IPv6 and IPv4 UDP traffic. +type Bind interface { + // LastMark reports the last mark set for this Bind. + LastMark() uint32 + + // SetMark sets the mark for each packet sent through this Bind. + // This mark is passed to the kernel as the socket option SO_MARK. + SetMark(mark uint32) error + + // ReceiveIPv6 reads an IPv6 UDP packet into b. + // + // It reports the number of bytes read, n, + // the packet source address ep, + // and any error. + ReceiveIPv6(buff []byte) (n int, ep Endpoint, err error) + + // ReceiveIPv4 reads an IPv4 UDP packet into b. + // + // It reports the number of bytes read, n, + // the packet source address ep, + // and any error. + ReceiveIPv4(b []byte) (n int, ep Endpoint, err error) + + // Send writes a packet b to address ep. + Send(b []byte, ep Endpoint) error + + // Close closes the Bind connection. + Close() error +} + +// CreateBind creates a Bind bound to a port. +// +// The value actualPort reports the actual port number the Bind +// object gets bound to. +func CreateBind(port uint16) (b Bind, actualPort uint16, err error) { + return createBind(port) +} + +// BindToInterface is implemented by Bind objects that support being +// tied to a single network interface. +type BindToInterface interface { + BindToInterface4(interfaceIndex uint32, blackhole bool) error + BindToInterface6(interfaceIndex uint32, blackhole bool) error +} + +// An Endpoint maintains the source/destination caching for a peer. +// +// dst : the remote address of a peer ("endpoint" in uapi terminology) +// src : the local address from which datagrams originate going to the peer +type Endpoint interface { + ClearSrc() // clears the source address + SrcToString() string // returns the local source address (ip:port) + DstToString() string // returns the destination address (ip:port) + DstToBytes() []byte // used for mac2 cookie calculations + DstIP() net.IP + SrcIP() net.IP +} + +func parseEndpoint(s string) (*net.UDPAddr, error) { + // ensure that the host is an IP address + + host, _, err := net.SplitHostPort(s) + if err != nil { + return nil, err + } + if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 { + // Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just + // trying to make sure with a small sanity test that this is a real IP address and + // not something that's likely to incur DNS lookups. + host = host[:i] + } + if ip := net.ParseIP(host); ip == nil { + return nil, errors.New("Failed to parse IP address: " + host) + } + + // parse address and port + + addr, err := net.ResolveUDPAddr("udp", s) + if err != nil { + return nil, err + } + ip4 := addr.IP.To4() + if ip4 != nil { + addr.IP = ip4 + } + return addr, err +} diff --git a/device/conn_default.go b/conn/conn_default.go similarity index 94% rename from device/conn_default.go rename to conn/conn_default.go index 661f57d97..bad9d4df8 100644 --- a/device/conn_default.go +++ b/conn/conn_default.go @@ -5,7 +5,7 @@ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. */ -package device +package conn import ( "net" @@ -67,16 +67,13 @@ func (e *NativeEndpoint) SrcToString() string { } func listenNet(network string, port int) (*net.UDPConn, int, error) { - - // listen - conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port}) if err != nil { return nil, 0, err } - // retrieve port - + // Retrieve port. + // TODO(crawshaw): under what circumstances is this necessary? laddr := conn.LocalAddr() uaddr, err := net.ResolveUDPAddr( laddr.Network(), @@ -100,7 +97,7 @@ func extractErrno(err error) error { return syscallErr.Err } -func CreateBind(uport uint16, device *Device) (Bind, uint16, error) { +func createBind(uport uint16) (Bind, uint16, error) { var err error var bind nativeBind @@ -135,6 +132,8 @@ func (bind *nativeBind) Close() error { return err2 } +func (bind *nativeBind) LastMark() uint32 { return 0 } + func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { if bind.ipv4 == nil { return 0, nil, syscall.EAFNOSUPPORT diff --git a/device/conn_linux.go b/conn/conn_linux.go similarity index 63% rename from device/conn_linux.go rename to conn/conn_linux.go index e90b0e35b..523da4a45 100644 --- a/device/conn_linux.go +++ b/conn/conn_linux.go @@ -3,18 +3,9 @@ /* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - * - * This implements userspace semantics of "sticky sockets", modeled after - * WireGuard's kernelspace implementation. This is more or less a straight port - * of the sticky-sockets.c example code: - * https://git.zx2c4.com/wireguard-tools/tree/contrib/sticky-sockets/sticky-sockets.c - * - * Currently there is no way to achieve this within the net package: - * See e.g. https://github.com/golang/go/issues/17930 - * So this code is remains platform dependent. */ -package device +package conn import ( "errors" @@ -25,7 +16,6 @@ import ( "unsafe" "golang.org/x/sys/unix" - "golang.zx2c4.com/wireguard/rwcancel" ) const ( @@ -33,8 +23,8 @@ const ( ) type IPv4Source struct { - src [4]byte - ifindex int32 + Src [4]byte + Ifindex int32 } type IPv6Source struct { @@ -49,6 +39,10 @@ type NativeEndpoint struct { isV6 bool } +func (endpoint *NativeEndpoint) Src4() *IPv4Source { return endpoint.src4() } +func (endpoint *NativeEndpoint) Dst4() *unix.SockaddrInet4 { return endpoint.dst4() } +func (endpoint *NativeEndpoint) IsV6() bool { return endpoint.isV6 } + func (endpoint *NativeEndpoint) src4() *IPv4Source { return (*IPv4Source)(unsafe.Pointer(&endpoint.src[0])) } @@ -66,11 +60,9 @@ func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 { } type nativeBind struct { - sock4 int - sock6 int - netlinkSock int - netlinkCancel *rwcancel.RWCancel - lastMark uint32 + sock4 int + sock6 int + lastMark uint32 } var _ Endpoint = (*NativeEndpoint)(nil) @@ -111,59 +103,25 @@ func CreateEndpoint(s string) (Endpoint, error) { return nil, errors.New("Invalid IP address") } -func createNetlinkRouteSocket() (int, error) { - sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE) - if err != nil { - return -1, err - } - saddr := &unix.SockaddrNetlink{ - Family: unix.AF_NETLINK, - Groups: unix.RTMGRP_IPV4_ROUTE, - } - err = unix.Bind(sock, saddr) - if err != nil { - unix.Close(sock) - return -1, err - } - return sock, nil - -} - -func CreateBind(port uint16, device *Device) (*nativeBind, uint16, error) { +func createBind(port uint16) (Bind, uint16, error) { var err error var bind nativeBind var newPort uint16 - bind.netlinkSock, err = createNetlinkRouteSocket() - if err != nil { - return nil, 0, err - } - bind.netlinkCancel, err = rwcancel.NewRWCancel(bind.netlinkSock) - if err != nil { - unix.Close(bind.netlinkSock) - return nil, 0, err - } - - go bind.routineRouteListener(device) - - // attempt ipv6 bind, update port if successful - + // Attempt ipv6 bind, update port if successful. bind.sock6, newPort, err = create6(port) if err != nil { if err != syscall.EAFNOSUPPORT { - bind.netlinkCancel.Cancel() return nil, 0, err } } else { port = newPort } - // attempt ipv4 bind, update port if successful - + // Attempt ipv4 bind, update port if successful. bind.sock4, newPort, err = create4(port) if err != nil { if err != syscall.EAFNOSUPPORT { - bind.netlinkCancel.Cancel() unix.Close(bind.sock6) return nil, 0, err } @@ -178,6 +136,10 @@ func CreateBind(port uint16, device *Device) (*nativeBind, uint16, error) { return &bind, port, nil } +func (bind *nativeBind) LastMark() uint32 { + return bind.lastMark +} + func (bind *nativeBind) SetMark(value uint32) error { if bind.sock6 != -1 { err := unix.SetsockoptInt( @@ -216,22 +178,18 @@ func closeUnblock(fd int) error { } func (bind *nativeBind) Close() error { - var err1, err2, err3 error + var err1, err2 error if bind.sock6 != -1 { err1 = closeUnblock(bind.sock6) } if bind.sock4 != -1 { err2 = closeUnblock(bind.sock4) } - err3 = bind.netlinkCancel.Cancel() if err1 != nil { return err1 } - if err2 != nil { - return err2 - } - return err3 + return err2 } func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { @@ -278,10 +236,10 @@ func (bind *nativeBind) Send(buff []byte, end Endpoint) error { func (end *NativeEndpoint) SrcIP() net.IP { if !end.isV6 { return net.IPv4( - end.src4().src[0], - end.src4().src[1], - end.src4().src[2], - end.src4().src[3], + end.src4().Src[0], + end.src4().Src[1], + end.src4().Src[2], + end.src4().Src[3], ) } else { return end.src6().src[:] @@ -478,8 +436,8 @@ func send4(sock int, end *NativeEndpoint, buff []byte) error { Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr, }, unix.Inet4Pktinfo{ - Spec_dst: end.src4().src, - Ifindex: end.src4().ifindex, + Spec_dst: end.src4().Src, + Ifindex: end.src4().Ifindex, }, } @@ -573,8 +531,8 @@ func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) { if cmsg.cmsghdr.Level == unix.IPPROTO_IP && cmsg.cmsghdr.Type == unix.IP_PKTINFO && cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo { - end.src4().src = cmsg.pktinfo.Spec_dst - end.src4().ifindex = cmsg.pktinfo.Ifindex + end.src4().Src = cmsg.pktinfo.Spec_dst + end.src4().Ifindex = cmsg.pktinfo.Ifindex } return size, nil @@ -611,156 +569,3 @@ func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) { return size, nil } - -func (bind *nativeBind) routineRouteListener(device *Device) { - type peerEndpointPtr struct { - peer *Peer - endpoint *Endpoint - } - var reqPeer map[uint32]peerEndpointPtr - var reqPeerLock sync.Mutex - - defer unix.Close(bind.netlinkSock) - - for msg := make([]byte, 1<<16); ; { - var err error - var msgn int - for { - msgn, _, _, _, err = unix.Recvmsg(bind.netlinkSock, msg[:], nil, 0) - if err == nil || !rwcancel.RetryAfterError(err) { - break - } - if !bind.netlinkCancel.ReadyRead() { - return - } - } - if err != nil { - return - } - - for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; { - - hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0])) - - if uint(hdr.Len) > uint(len(remain)) { - break - } - - switch hdr.Type { - case unix.RTM_NEWROUTE, unix.RTM_DELROUTE: - if hdr.Seq <= MaxPeers && hdr.Seq > 0 { - if uint(len(remain)) < uint(hdr.Len) { - break - } - if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg { - attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:] - for { - if uint(len(attr)) < uint(unix.SizeofRtAttr) { - break - } - attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0])) - if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) { - break - } - if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 { - ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr])) - reqPeerLock.Lock() - if reqPeer == nil { - reqPeerLock.Unlock() - break - } - pePtr, ok := reqPeer[hdr.Seq] - reqPeerLock.Unlock() - if !ok { - break - } - pePtr.peer.Lock() - if &pePtr.peer.endpoint != pePtr.endpoint { - pePtr.peer.Unlock() - break - } - if uint32(pePtr.peer.endpoint.(*NativeEndpoint).src4().ifindex) == ifidx { - pePtr.peer.Unlock() - break - } - pePtr.peer.endpoint.(*NativeEndpoint).ClearSrc() - pePtr.peer.Unlock() - } - attr = attr[attrhdr.Len:] - } - } - break - } - reqPeerLock.Lock() - reqPeer = make(map[uint32]peerEndpointPtr) - reqPeerLock.Unlock() - go func() { - device.peers.RLock() - i := uint32(1) - for _, peer := range device.peers.keyMap { - peer.RLock() - if peer.endpoint == nil || peer.endpoint.(*NativeEndpoint) == nil { - peer.RUnlock() - continue - } - if peer.endpoint.(*NativeEndpoint).isV6 || peer.endpoint.(*NativeEndpoint).src4().ifindex == 0 { - peer.RUnlock() - break - } - nlmsg := struct { - hdr unix.NlMsghdr - msg unix.RtMsg - dsthdr unix.RtAttr - dst [4]byte - srchdr unix.RtAttr - src [4]byte - markhdr unix.RtAttr - mark uint32 - }{ - unix.NlMsghdr{ - Type: uint16(unix.RTM_GETROUTE), - Flags: unix.NLM_F_REQUEST, - Seq: i, - }, - unix.RtMsg{ - Family: unix.AF_INET, - Dst_len: 32, - Src_len: 32, - }, - unix.RtAttr{ - Len: 8, - Type: unix.RTA_DST, - }, - peer.endpoint.(*NativeEndpoint).dst4().Addr, - unix.RtAttr{ - Len: 8, - Type: unix.RTA_SRC, - }, - peer.endpoint.(*NativeEndpoint).src4().src, - unix.RtAttr{ - Len: 8, - Type: unix.RTA_MARK, - }, - uint32(bind.lastMark), - } - nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg)) - reqPeerLock.Lock() - reqPeer[i] = peerEndpointPtr{ - peer: peer, - endpoint: &peer.endpoint, - } - reqPeerLock.Unlock() - peer.RUnlock() - i++ - _, err := bind.netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:]) - if err != nil { - break - } - } - device.peers.RUnlock() - }() - } - remain = remain[hdr.Len:] - } - } -} diff --git a/device/mark_default.go b/conn/mark_default.go similarity index 93% rename from device/mark_default.go rename to conn/mark_default.go index 7de2524c0..fc41ba993 100644 --- a/device/mark_default.go +++ b/conn/mark_default.go @@ -5,7 +5,7 @@ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. */ -package device +package conn func (bind *nativeBind) SetMark(mark uint32) error { return nil diff --git a/device/mark_unix.go b/conn/mark_unix.go similarity index 98% rename from device/mark_unix.go rename to conn/mark_unix.go index 669b32814..5334582e9 100644 --- a/device/mark_unix.go +++ b/conn/mark_unix.go @@ -5,7 +5,7 @@ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. */ -package device +package conn import ( "runtime" diff --git a/device/bind_test.go b/device/bind_test.go index 0c2e2cfd2..c5f7f68a9 100644 --- a/device/bind_test.go +++ b/device/bind_test.go @@ -5,11 +5,15 @@ package device -import "errors" +import ( + "errors" + + "golang.zx2c4.com/wireguard/conn" +) type DummyDatagram struct { msg []byte - endpoint Endpoint + endpoint conn.Endpoint world bool // better type } @@ -25,7 +29,7 @@ func (b *DummyBind) SetMark(v uint32) error { return nil } -func (b *DummyBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { +func (b *DummyBind) ReceiveIPv6(buff []byte) (int, conn.Endpoint, error) { datagram, ok := <-b.in6 if !ok { return 0, nil, errors.New("closed") @@ -34,7 +38,7 @@ func (b *DummyBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { return len(datagram.msg), datagram.endpoint, nil } -func (b *DummyBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { +func (b *DummyBind) ReceiveIPv4(buff []byte) (int, conn.Endpoint, error) { datagram, ok := <-b.in4 if !ok { return 0, nil, errors.New("closed") @@ -50,6 +54,6 @@ func (b *DummyBind) Close() error { return nil } -func (b *DummyBind) Send(buff []byte, end Endpoint) error { +func (b *DummyBind) Send(buff []byte, end conn.Endpoint) error { return nil } diff --git a/device/bindsocketshim.go b/device/bindsocketshim.go new file mode 100644 index 000000000..c4dd4effb --- /dev/null +++ b/device/bindsocketshim.go @@ -0,0 +1,36 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "errors" + + "golang.zx2c4.com/wireguard/conn" +) + +// TODO(crawshaw): this method is a compatibility shim. Replace with direct use of conn. +func (device *Device) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { + if device.net.bind == nil { + return errors.New("Bind is not yet initialized") + } + + if iface, ok := device.net.bind.(conn.BindToInterface); ok { + return iface.BindToInterface4(interfaceIndex, blackhole) + } + return nil +} + +// TODO(crawshaw): this method is a compatibility shim. Replace with direct use of conn. +func (device *Device) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { + if device.net.bind == nil { + return errors.New("Bind is not yet initialized") + } + + if iface, ok := device.net.bind.(conn.BindToInterface); ok { + return iface.BindToInterface6(interfaceIndex, blackhole) + } + return nil +} diff --git a/device/conn.go b/device/conn.go deleted file mode 100644 index 7b341f6b2..000000000 --- a/device/conn.go +++ /dev/null @@ -1,187 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package device - -import ( - "errors" - "net" - "strings" - - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" -) - -const ( - ConnRoutineNumber = 2 -) - -/* A Bind handles listening on a port for both IPv6 and IPv4 UDP traffic - */ -type Bind interface { - SetMark(value uint32) error - ReceiveIPv6(buff []byte) (int, Endpoint, error) - ReceiveIPv4(buff []byte) (int, Endpoint, error) - Send(buff []byte, end Endpoint) error - Close() error -} - -/* An Endpoint maintains the source/destination caching for a peer - * - * dst : the remote address of a peer ("endpoint" in uapi terminology) - * src : the local address from which datagrams originate going to the peer - */ -type Endpoint interface { - ClearSrc() // clears the source address - SrcToString() string // returns the local source address (ip:port) - DstToString() string // returns the destination address (ip:port) - DstToBytes() []byte // used for mac2 cookie calculations - DstIP() net.IP - SrcIP() net.IP -} - -func parseEndpoint(s string) (*net.UDPAddr, error) { - // ensure that the host is an IP address - - host, _, err := net.SplitHostPort(s) - if err != nil { - return nil, err - } - if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 { - // Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just - // trying to make sure with a small sanity test that this is a real IP address and - // not something that's likely to incur DNS lookups. - host = host[:i] - } - if ip := net.ParseIP(host); ip == nil { - return nil, errors.New("Failed to parse IP address: " + host) - } - - // parse address and port - - addr, err := net.ResolveUDPAddr("udp", s) - if err != nil { - return nil, err - } - ip4 := addr.IP.To4() - if ip4 != nil { - addr.IP = ip4 - } - return addr, err -} - -func unsafeCloseBind(device *Device) error { - var err error - netc := &device.net - if netc.bind != nil { - err = netc.bind.Close() - netc.bind = nil - } - netc.stopping.Wait() - return err -} - -func (device *Device) BindSetMark(mark uint32) error { - - device.net.Lock() - defer device.net.Unlock() - - // check if modified - - if device.net.fwmark == mark { - return nil - } - - // update fwmark on existing bind - - device.net.fwmark = mark - if device.isUp.Get() && device.net.bind != nil { - if err := device.net.bind.SetMark(mark); err != nil { - return err - } - } - - // clear cached source addresses - - device.peers.RLock() - for _, peer := range device.peers.keyMap { - peer.Lock() - defer peer.Unlock() - if peer.endpoint != nil { - peer.endpoint.ClearSrc() - } - } - device.peers.RUnlock() - - return nil -} - -func (device *Device) BindUpdate() error { - - device.net.Lock() - defer device.net.Unlock() - - // close existing sockets - - if err := unsafeCloseBind(device); err != nil { - return err - } - - // open new sockets - - if device.isUp.Get() { - - // bind to new port - - var err error - netc := &device.net - netc.bind, netc.port, err = CreateBind(netc.port, device) - if err != nil { - netc.bind = nil - netc.port = 0 - return err - } - - // set fwmark - - if netc.fwmark != 0 { - err = netc.bind.SetMark(netc.fwmark) - if err != nil { - return err - } - } - - // clear cached source addresses - - device.peers.RLock() - for _, peer := range device.peers.keyMap { - peer.Lock() - defer peer.Unlock() - if peer.endpoint != nil { - peer.endpoint.ClearSrc() - } - } - device.peers.RUnlock() - - // start receiving routines - - device.net.starting.Add(ConnRoutineNumber) - device.net.stopping.Add(ConnRoutineNumber) - go device.RoutineReceiveIncoming(ipv4.Version, netc.bind) - go device.RoutineReceiveIncoming(ipv6.Version, netc.bind) - device.net.starting.Wait() - - device.log.Debug.Println("UDP bind has been updated") - } - - return nil -} - -func (device *Device) BindClose() error { - device.net.Lock() - err := unsafeCloseBind(device) - device.net.Unlock() - return err -} diff --git a/device/device.go b/device/device.go index 8c08f1c34..a9fedea86 100644 --- a/device/device.go +++ b/device/device.go @@ -11,15 +11,14 @@ import ( "sync/atomic" "time" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" + "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/ratelimiter" + "golang.zx2c4.com/wireguard/rwcancel" "golang.zx2c4.com/wireguard/tun" ) -const ( - DeviceRoutineNumberPerCPU = 3 - DeviceRoutineNumberAdditional = 2 -) - type Device struct { isUp AtomicBool // device is (going) up isClosed AtomicBool // device is closed? (acting as guard) @@ -39,9 +38,10 @@ type Device struct { starting sync.WaitGroup stopping sync.WaitGroup sync.RWMutex - bind Bind // bind interface - port uint16 // listening port - fwmark uint32 // mark value (0 = disabled) + bind conn.Bind // bind interface + netlinkCancel *rwcancel.RWCancel + port uint16 // listening port + fwmark uint32 // mark value (0 = disabled) } staticIdentity struct { @@ -299,14 +299,16 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device { cpus := runtime.NumCPU() device.state.starting.Wait() device.state.stopping.Wait() - device.state.stopping.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional) - device.state.starting.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional) for i := 0; i < cpus; i += 1 { + device.state.starting.Add(3) + device.state.stopping.Add(3) go device.RoutineEncryption() go device.RoutineDecryption() go device.RoutineHandshake() } + device.state.starting.Add(2) + device.state.stopping.Add(2) go device.RoutineReadFromTUN() go device.RoutineTUNEventReader() @@ -413,3 +415,127 @@ func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() { } device.peers.RUnlock() } + +func unsafeCloseBind(device *Device) error { + var err error + netc := &device.net + if netc.netlinkCancel != nil { + netc.netlinkCancel.Cancel() + } + if netc.bind != nil { + err = netc.bind.Close() + netc.bind = nil + } + netc.stopping.Wait() + return err +} + +func (device *Device) BindSetMark(mark uint32) error { + + device.net.Lock() + defer device.net.Unlock() + + // check if modified + + if device.net.fwmark == mark { + return nil + } + + // update fwmark on existing bind + + device.net.fwmark = mark + if device.isUp.Get() && device.net.bind != nil { + if err := device.net.bind.SetMark(mark); err != nil { + return err + } + } + + // clear cached source addresses + + device.peers.RLock() + for _, peer := range device.peers.keyMap { + peer.Lock() + defer peer.Unlock() + if peer.endpoint != nil { + peer.endpoint.ClearSrc() + } + } + device.peers.RUnlock() + + return nil +} + +func (device *Device) BindUpdate() error { + + device.net.Lock() + defer device.net.Unlock() + + // close existing sockets + + if err := unsafeCloseBind(device); err != nil { + return err + } + + // open new sockets + + if device.isUp.Get() { + + // bind to new port + + var err error + netc := &device.net + netc.bind, netc.port, err = conn.CreateBind(netc.port) + if err != nil { + netc.bind = nil + netc.port = 0 + return err + } + netc.netlinkCancel, err = device.startRouteListener(netc.bind) + if err != nil { + netc.bind.Close() + netc.bind = nil + netc.port = 0 + return err + } + + // set fwmark + + if netc.fwmark != 0 { + err = netc.bind.SetMark(netc.fwmark) + if err != nil { + return err + } + } + + // clear cached source addresses + + device.peers.RLock() + for _, peer := range device.peers.keyMap { + peer.Lock() + defer peer.Unlock() + if peer.endpoint != nil { + peer.endpoint.ClearSrc() + } + } + device.peers.RUnlock() + + // start receiving routines + + device.net.starting.Add(2) + device.net.stopping.Add(2) + go device.RoutineReceiveIncoming(ipv4.Version, netc.bind) + go device.RoutineReceiveIncoming(ipv6.Version, netc.bind) + device.net.starting.Wait() + + device.log.Debug.Println("UDP bind has been updated") + } + + return nil +} + +func (device *Device) BindClose() error { + device.net.Lock() + err := unsafeCloseBind(device) + device.net.Unlock() + return err +} diff --git a/device/peer.go b/device/peer.go index 65581d524..a96f2612a 100644 --- a/device/peer.go +++ b/device/peer.go @@ -12,6 +12,8 @@ import ( "sync" "sync/atomic" "time" + + "golang.zx2c4.com/wireguard/conn" ) const ( @@ -38,7 +40,7 @@ type Peer struct { keypairs Keypairs handshake Handshake device *Device - endpoint Endpoint + endpoint conn.Endpoint persistentKeepaliveInterval uint16 timers struct { @@ -293,7 +295,7 @@ func (peer *Peer) Stop() { var RoamingDisabled bool -func (peer *Peer) SetEndpointFromPacket(endpoint Endpoint) { +func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) { if RoamingDisabled { return } diff --git a/device/receive.go b/device/receive.go index 7d0693e1f..4818d649e 100644 --- a/device/receive.go +++ b/device/receive.go @@ -17,12 +17,13 @@ import ( "golang.org/x/crypto/chacha20poly1305" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" + "golang.zx2c4.com/wireguard/conn" ) type QueueHandshakeElement struct { msgType uint32 packet []byte - endpoint Endpoint + endpoint conn.Endpoint buffer *[MaxMessageSize]byte } @@ -33,7 +34,7 @@ type QueueInboundElement struct { packet []byte counter uint64 keypair *Keypair - endpoint Endpoint + endpoint conn.Endpoint } func (elem *QueueInboundElement) Drop() { @@ -90,7 +91,7 @@ func (peer *Peer) keepKeyFreshReceiving() { * Every time the bind is updated a new routine is started for * IPv4 and IPv6 (separately) */ -func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) { +func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) { logDebug := device.log.Debug defer func() { @@ -108,7 +109,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) { var ( err error size int - endpoint Endpoint + endpoint conn.Endpoint ) for { diff --git a/device/sticky_default.go b/device/sticky_default.go new file mode 100644 index 000000000..1cc52f69b --- /dev/null +++ b/device/sticky_default.go @@ -0,0 +1,12 @@ +// +build !linux + +package device + +import ( + "golang.zx2c4.com/wireguard/conn" + "golang.zx2c4.com/wireguard/rwcancel" +) + +func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) { + return nil, nil +} diff --git a/device/sticky_linux.go b/device/sticky_linux.go new file mode 100644 index 000000000..f9522c23e --- /dev/null +++ b/device/sticky_linux.go @@ -0,0 +1,215 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * + * This implements userspace semantics of "sticky sockets", modeled after + * WireGuard's kernelspace implementation. This is more or less a straight port + * of the sticky-sockets.c example code: + * https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c + * + * Currently there is no way to achieve this within the net package: + * See e.g. https://github.com/golang/go/issues/17930 + * So this code is remains platform dependent. + */ + +package device + +import ( + "sync" + "unsafe" + + "golang.org/x/sys/unix" + "golang.zx2c4.com/wireguard/conn" + "golang.zx2c4.com/wireguard/rwcancel" +) + +func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) { + netlinkSock, err := createNetlinkRouteSocket() + if err != nil { + return nil, err + } + netlinkCancel, err := rwcancel.NewRWCancel(netlinkSock) + if err != nil { + unix.Close(netlinkSock) + return nil, err + } + + go device.routineRouteListener(bind, netlinkSock, netlinkCancel) + + return netlinkCancel, nil +} + +func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) { + type peerEndpointPtr struct { + peer *Peer + endpoint *conn.Endpoint + } + var reqPeer map[uint32]peerEndpointPtr + var reqPeerLock sync.Mutex + + defer unix.Close(netlinkSock) + + for msg := make([]byte, 1<<16); ; { + var err error + var msgn int + for { + msgn, _, _, _, err = unix.Recvmsg(netlinkSock, msg[:], nil, 0) + if err == nil || !rwcancel.RetryAfterError(err) { + break + } + if !netlinkCancel.ReadyRead() { + return + } + } + if err != nil { + return + } + + for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; { + + hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0])) + + if uint(hdr.Len) > uint(len(remain)) { + break + } + + switch hdr.Type { + case unix.RTM_NEWROUTE, unix.RTM_DELROUTE: + if hdr.Seq <= MaxPeers && hdr.Seq > 0 { + if uint(len(remain)) < uint(hdr.Len) { + break + } + if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg { + attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:] + for { + if uint(len(attr)) < uint(unix.SizeofRtAttr) { + break + } + attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0])) + if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) { + break + } + if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 { + ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr])) + reqPeerLock.Lock() + if reqPeer == nil { + reqPeerLock.Unlock() + break + } + pePtr, ok := reqPeer[hdr.Seq] + reqPeerLock.Unlock() + if !ok { + break + } + pePtr.peer.Lock() + if &pePtr.peer.endpoint != pePtr.endpoint { + pePtr.peer.Unlock() + break + } + if uint32(pePtr.peer.endpoint.(*conn.NativeEndpoint).Src4().Ifindex) == ifidx { + pePtr.peer.Unlock() + break + } + pePtr.peer.endpoint.(*conn.NativeEndpoint).ClearSrc() + pePtr.peer.Unlock() + } + attr = attr[attrhdr.Len:] + } + } + break + } + reqPeerLock.Lock() + reqPeer = make(map[uint32]peerEndpointPtr) + reqPeerLock.Unlock() + go func() { + device.peers.RLock() + i := uint32(1) + for _, peer := range device.peers.keyMap { + peer.RLock() + if peer.endpoint == nil { + peer.RUnlock() + continue + } + nativeEP, _ := peer.endpoint.(*conn.NativeEndpoint) + if nativeEP == nil { + peer.RUnlock() + continue + } + if nativeEP.IsV6() || nativeEP.Src4().Ifindex == 0 { + peer.RUnlock() + break + } + nlmsg := struct { + hdr unix.NlMsghdr + msg unix.RtMsg + dsthdr unix.RtAttr + dst [4]byte + srchdr unix.RtAttr + src [4]byte + markhdr unix.RtAttr + mark uint32 + }{ + unix.NlMsghdr{ + Type: uint16(unix.RTM_GETROUTE), + Flags: unix.NLM_F_REQUEST, + Seq: i, + }, + unix.RtMsg{ + Family: unix.AF_INET, + Dst_len: 32, + Src_len: 32, + }, + unix.RtAttr{ + Len: 8, + Type: unix.RTA_DST, + }, + nativeEP.Dst4().Addr, + unix.RtAttr{ + Len: 8, + Type: unix.RTA_SRC, + }, + nativeEP.Src4().Src, + unix.RtAttr{ + Len: 8, + Type: unix.RTA_MARK, + }, + uint32(bind.LastMark()), + } + nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg)) + reqPeerLock.Lock() + reqPeer[i] = peerEndpointPtr{ + peer: peer, + endpoint: &peer.endpoint, + } + reqPeerLock.Unlock() + peer.RUnlock() + i++ + _, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:]) + if err != nil { + break + } + } + device.peers.RUnlock() + }() + } + remain = remain[hdr.Len:] + } + } +} + +func createNetlinkRouteSocket() (int, error) { + sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE) + if err != nil { + return -1, err + } + saddr := &unix.SockaddrNetlink{ + Family: unix.AF_NETLINK, + Groups: uint32(1 << (unix.RTNLGRP_IPV4_ROUTE - 1)), + } + err = unix.Bind(sock, saddr) + if err != nil { + unix.Close(sock) + return -1, err + } + return sock, nil +} diff --git a/device/uapi.go b/device/uapi.go index 72611ab5e..6cdccd615 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -15,6 +15,7 @@ import ( "sync/atomic" "time" + "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/ipc" ) @@ -306,7 +307,7 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError { err := func() error { peer.Lock() defer peer.Unlock() - endpoint, err := CreateEndpoint(value) + endpoint, err := conn.CreateEndpoint(value) if err != nil { return err } From 900ae645d16b2cd813e494f4e9593b299ed8b50a Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Fri, 28 Feb 2020 08:53:29 -0800 Subject: [PATCH 04/12] tun: remove unused isUp method Signed-off-by: Brad Fitzpatrick --- tun/tun_linux.go | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tun/tun_linux.go b/tun/tun_linux.go index 7ab062314..17b8822f9 100644 --- a/tun/tun_linux.go +++ b/tun/tun_linux.go @@ -12,7 +12,6 @@ import ( "bytes" "errors" "fmt" - "net" "os" "sync" "syscall" @@ -164,11 +163,6 @@ func (tun *NativeTun) routineNetlinkListener() { } } -func (tun *NativeTun) isUp() (bool, error) { - inter, err := net.InterfaceByName(tun.name) - return inter.Flags&net.FlagUp != 0, err -} - func getIFIndex(name string) (int32, error) { fd, err := unix.Socket( unix.AF_INET, From d5d70756bbf34474645d4863cd49ffce82e35261 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Fri, 28 Feb 2020 09:10:16 -0800 Subject: [PATCH 05/12] tun: fix data race on name field Signed-off-by: Brad Fitzpatrick --- tun/tun_linux.go | 44 +++++++++++++++++++++++++++++++------------- 1 file changed, 31 insertions(+), 13 deletions(-) diff --git a/tun/tun_linux.go b/tun/tun_linux.go index 17b8822f9..2f97ebbc5 100644 --- a/tun/tun_linux.go +++ b/tun/tun_linux.go @@ -31,7 +31,6 @@ const ( type NativeTun struct { tunFile *os.File index int32 // if index - name string // name of interface errors chan error // async error handling events chan Event // device related events nopi bool // the device was passed IFF_NO_PI @@ -39,6 +38,10 @@ type NativeTun struct { netlinkCancel *rwcancel.RWCancel hackListenerClosed sync.Mutex statusListenersShutdown chan struct{} + + nameOnce sync.Once // guards calling initNameCache, which sets following fields + nameCache string // name of interface + nameErr error } func (tun *NativeTun) File() *os.File { @@ -192,6 +195,11 @@ func getIFIndex(name string) (int32, error) { } func (tun *NativeTun) setMTU(n int) error { + name, err := tun.Name() + if err != nil { + return err + } + // open datagram socket fd, err := unix.Socket( unix.AF_INET, @@ -206,9 +214,8 @@ func (tun *NativeTun) setMTU(n int) error { defer unix.Close(fd) // do ioctl call - var ifr [ifReqSize]byte - copy(ifr[:], tun.name) + copy(ifr[:], name) *(*uint32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = uint32(n) _, _, errno := unix.Syscall( unix.SYS_IOCTL, @@ -225,6 +232,11 @@ func (tun *NativeTun) setMTU(n int) error { } func (tun *NativeTun) MTU() (int, error) { + name, err := tun.Name() + if err != nil { + return 0, err + } + // open datagram socket fd, err := unix.Socket( unix.AF_INET, @@ -241,7 +253,7 @@ func (tun *NativeTun) MTU() (int, error) { // do ioctl call var ifr [ifReqSize]byte - copy(ifr[:], tun.name) + copy(ifr[:], name) _, _, errno := unix.Syscall( unix.SYS_IOCTL, uintptr(fd), @@ -256,6 +268,15 @@ func (tun *NativeTun) MTU() (int, error) { } func (tun *NativeTun) Name() (string, error) { + tun.nameOnce.Do(tun.initNameCache) + return tun.nameCache, tun.nameErr +} + +func (tun *NativeTun) initNameCache() { + tun.nameCache, tun.nameErr = tun.nameSlow() +} + +func (tun *NativeTun) nameSlow() (string, error) { sysconn, err := tun.tunFile.SyscallConn() if err != nil { return "", err @@ -276,13 +297,11 @@ func (tun *NativeTun) Name() (string, error) { if errno != 0 { return "", errors.New("failed to get name of TUN device: " + errno.Error()) } - nullStr := ifr[:] - i := bytes.IndexByte(nullStr, 0) - if i != -1 { - nullStr = nullStr[:i] + name := ifr[:] + if i := bytes.IndexByte(name, 0); i != -1 { + name = name[:i] } - tun.name = string(nullStr) - return tun.name, nil + return string(name), nil } func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { @@ -402,16 +421,15 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { statusListenersShutdown: make(chan struct{}), nopi: false, } - var err error - _, err = tun.Name() + name, err := tun.Name() if err != nil { return nil, err } // start event listener - tun.index, err = getIFIndex(tun.name) + tun.index, err = getIFIndex(name) if err != nil { return nil, err } From a38504e3994268ee9aee5cb1d42bab82ef5ede76 Mon Sep 17 00:00:00 2001 From: David Crawshaw Date: Wed, 17 Apr 2019 09:41:25 -0400 Subject: [PATCH 06/12] wgcfg: new config package Based on types and config parser from wireguard-windows. Signed-off-by: David Crawshaw --- wgcfg/config.go | 78 +++++++++ wgcfg/ip.go | 128 ++++++++++++++ wgcfg/key.go | 240 ++++++++++++++++++++++++++ wgcfg/key_test.go | 107 ++++++++++++ wgcfg/name.go | 49 ++++++ wgcfg/parser.go | 397 +++++++++++++++++++++++++++++++++++++++++++ wgcfg/parser_test.go | 127 ++++++++++++++ wgcfg/writer.go | 75 ++++++++ 8 files changed, 1201 insertions(+) create mode 100644 wgcfg/config.go create mode 100644 wgcfg/ip.go create mode 100644 wgcfg/key.go create mode 100644 wgcfg/key_test.go create mode 100644 wgcfg/name.go create mode 100644 wgcfg/parser.go create mode 100644 wgcfg/parser_test.go create mode 100644 wgcfg/writer.go diff --git a/wgcfg/config.go b/wgcfg/config.go new file mode 100644 index 000000000..2b5e7148d --- /dev/null +++ b/wgcfg/config.go @@ -0,0 +1,78 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +// Package wgcfg has types and a parser for representing WireGuard config. +package wgcfg + +import ( + "fmt" + "strings" +) + +// Config is a wireguard configuration. +type Config struct { + Name string + PrivateKey PrivateKey + Addresses []CIDR + ListenPort uint16 + MTU uint16 + DNS []IP + Peers []Peer +} + +type Peer struct { + PublicKey Key + PresharedKey SymmetricKey + AllowedIPs []CIDR + Endpoints []Endpoint + PersistentKeepalive uint16 +} + +type Endpoint struct { + Host string + Port uint16 +} + +func (e *Endpoint) String() string { + if strings.IndexByte(e.Host, ':') > 0 { + return fmt.Sprintf("[%s]:%d", e.Host, e.Port) + } + return fmt.Sprintf("%s:%d", e.Host, e.Port) +} + +func (e *Endpoint) IsEmpty() bool { + return len(e.Host) == 0 +} + +// Copy makes a deep copy of Config. +// The result aliases no memory with the original. +func (cfg Config) Copy() Config { + res := cfg + if res.Addresses != nil { + res.Addresses = append([]CIDR{}, res.Addresses...) + } + if res.DNS != nil { + res.DNS = append([]IP{}, res.DNS...) + } + peers := make([]Peer, 0, len(res.Peers)) + for _, peer := range res.Peers { + peers = append(peers, peer.Copy()) + } + res.Peers = peers + return res +} + +// Copy makes a deep copy of Peer. +// The result aliases no memory with the original. +func (peer Peer) Copy() Peer { + res := peer + if res.AllowedIPs != nil { + res.AllowedIPs = append([]CIDR{}, res.AllowedIPs...) + } + if res.Endpoints != nil { + res.Endpoints = append([]Endpoint{}, res.Endpoints...) + } + return res +} diff --git a/wgcfg/ip.go b/wgcfg/ip.go new file mode 100644 index 000000000..ecf5faff7 --- /dev/null +++ b/wgcfg/ip.go @@ -0,0 +1,128 @@ +package wgcfg + +import ( + "fmt" + "net" +) + +// IP is an IPv4 or an IPv6 address. +// +// Internally the address is always represented in its IPv6 form. +// IPv4 addresses use the IPv4-in-IPv6 syntax. +type IP struct { + Addr [16]byte +} + +func (ip IP) String() string { return net.IP(ip.Addr[:]).String() } + +func (ip *IP) IP() net.IP { return net.IP(ip.Addr[:]) } +func (ip *IP) Is6() bool { return !ip.Is4() } +func (ip *IP) Is4() bool { + return ip.Addr[0] == 0 && ip.Addr[1] == 0 && + ip.Addr[2] == 0 && ip.Addr[3] == 0 && + ip.Addr[4] == 0 && ip.Addr[5] == 0 && + ip.Addr[6] == 0 && ip.Addr[7] == 0 && + ip.Addr[8] == 0 && ip.Addr[9] == 0 && + ip.Addr[10] == 0xff && ip.Addr[11] == 0xff +} +func (ip *IP) To4() []byte { + if ip.Is4() { + return ip.Addr[12:16] + } else { + return nil + } +} +func (ip *IP) Equal(x *IP) bool { + if ip == nil || x == nil { + return false + } + // TODO: this isn't hard, write a more efficient implementation. + return ip.IP().Equal(x.IP()) +} + +func (ip IP) MarshalText() ([]byte, error) { + return []byte(ip.String()), nil +} + +func (ip *IP) UnmarshalText(text []byte) error { + parsedIP := ParseIP(string(text)) + if parsedIP == nil { + return fmt.Errorf("wgcfg.IP: UnmarshalText: bad IP address %q", string(text)) + } + *ip = *parsedIP + return nil +} + +func IPv4(b0, b1, b2, b3 byte) (ip IP) { + ip.Addr[10], ip.Addr[11] = 0xff, 0xff // IPv4-in-IPv6 prefix + ip.Addr[12] = b0 + ip.Addr[13] = b1 + ip.Addr[14] = b2 + ip.Addr[15] = b3 + return ip +} + +// ParseIP parses the string representation of an address into an IP. +// +// It accepts IPv4 notation such as "1.2.3.4" and IPv6 notation like ""::0". +// If the string is not a valid IP address, ParseIP returns nil. +func ParseIP(s string) *IP { + netIP := net.ParseIP(s) + if netIP == nil { + return nil + } + ip := new(IP) + copy(ip.Addr[:], netIP.To16()) + return ip +} + +// CIDR is a compact IP address and subnet mask. +type CIDR struct { + IP IP + Mask uint8 // 0-32 for IsIPv4, 4-128 for IsIPv6 +} + +// ParseCIDR parses CIDR notation into a CIDR type. +// Typical CIDR strings look like "192.168.1.0/24". +func ParseCIDR(s string) (cidr *CIDR, err error) { + netIP, netAddr, err := net.ParseCIDR(s) + if err != nil { + return nil, err + } + cidr = new(CIDR) + copy(cidr.IP.Addr[:], netIP.To16()) + ones, _ := netAddr.Mask.Size() + cidr.Mask = uint8(ones) + + return cidr, nil +} + +func (r CIDR) String() string { return r.IPNet().String() } + +func (r *CIDR) IPNet() *net.IPNet { + bits := 128 + if r.IP.Is4() { + bits = 32 + } + return &net.IPNet{IP: r.IP.IP(), Mask: net.CIDRMask(int(r.Mask), bits)} +} +func (r *CIDR) Contains(ip *IP) bool { + if r == nil || ip == nil { + return false + } + // TODO: this isn't hard, write a more efficient implementation. + return r.IPNet().Contains(ip.IP()) +} + +func (r CIDR) MarshalText() ([]byte, error) { + return []byte(r.String()), nil +} + +func (r *CIDR) UnmarshalText(text []byte) error { + cidr, err := ParseCIDR(string(text)) + if err != nil { + return fmt.Errorf("wgcfg.CIDR: UnmarshalText: %v", err) + } + *r = *cidr + return nil +} diff --git a/wgcfg/key.go b/wgcfg/key.go new file mode 100644 index 000000000..1597203b1 --- /dev/null +++ b/wgcfg/key.go @@ -0,0 +1,240 @@ +package wgcfg + +import ( + "bytes" + "crypto/rand" + "crypto/subtle" + "encoding/base64" + "encoding/hex" + "errors" + "fmt" + "strings" + + "golang.org/x/crypto/chacha20poly1305" + "golang.org/x/crypto/curve25519" +) + +const KeySize = 32 + +// Key is curve25519 key. +// It is used by WireGuard to represent public and preshared keys. +type Key [KeySize]byte + +// NewPresharedKey generates a new random key. +func NewPresharedKey() (*Key, error) { + var k [KeySize]byte + _, err := rand.Read(k[:]) + if err != nil { + return nil, err + } + return (*Key)(&k), nil +} + +func ParseKey(b64 string) (*Key, error) { return parseKeyBase64(base64.StdEncoding, b64) } + +func ParseHexKey(s string) (Key, error) { + b, err := hex.DecodeString(s) + if err != nil { + return Key{}, &ParseError{"invalid hex key: " + err.Error(), s} + } + if len(b) != KeySize { + return Key{}, &ParseError{fmt.Sprintf("invalid hex key length: %d", len(b)), s} + } + + var key Key + copy(key[:], b) + return key, nil +} + +func ParsePrivateHexKey(v string) (PrivateKey, error) { + k, err := ParseHexKey(v) + if err != nil { + return PrivateKey{}, err + } + pk := PrivateKey(k) + if pk.IsZero() { + // Do not clamp a zero key, pass the zero through + // (much like NaN propagation) so that IsZero reports + // a useful result. + return pk, nil + } + pk.clamp() + return pk, nil +} + +func (k Key) Base64() string { return base64.StdEncoding.EncodeToString(k[:]) } +func (k Key) String() string { return "pub:" + k.Base64()[:8] } +func (k Key) HexString() string { return hex.EncodeToString(k[:]) } +func (k Key) Equal(k2 Key) bool { return subtle.ConstantTimeCompare(k[:], k2[:]) == 1 } + +func (k *Key) ShortString() string { + if k.IsZero() { + return "[empty]" + } + long := k.String() + if len(long) < 10 { + return "invalid" + } + return "[" + long[0:4] + "ā€¦" + long[len(long)-5:len(long)-1] + "]" +} + +func (k *Key) IsZero() bool { + if k == nil { + return true + } + var zeros Key + return subtle.ConstantTimeCompare(zeros[:], k[:]) == 1 +} + +func (k *Key) MarshalJSON() ([]byte, error) { + if k == nil { + return []byte("null"), nil + } + buf := new(bytes.Buffer) + fmt.Fprintf(buf, `"%x"`, k[:]) + return buf.Bytes(), nil +} + +func (k *Key) UnmarshalJSON(b []byte) error { + if k == nil { + return errors.New("wgcfg.Key: UnmarshalJSON on nil pointer") + } + if len(b) < 3 || b[0] != '"' || b[len(b)-1] != '"' { + return errors.New("wgcfg.Key: UnmarshalJSON not given a string") + } + b = b[1 : len(b)-1] + key, err := ParseHexKey(string(b)) + if err != nil { + return fmt.Errorf("wgcfg.Key: UnmarshalJSON: %v", err) + } + copy(k[:], key[:]) + return nil +} + +func (a *Key) LessThan(b *Key) bool { + for i := range a { + if a[i] < b[i] { + return true + } else if a[i] > b[i] { + return false + } + } + return false +} + +// PrivateKey is curve25519 key. +// It is used by WireGuard to represent private keys. +type PrivateKey [KeySize]byte + +// NewPrivateKey generates a new curve25519 secret key. +// It conforms to the format described on https://cr.yp.to/ecdh.html. +func NewPrivateKey() (PrivateKey, error) { + k, err := NewPresharedKey() + if err != nil { + return PrivateKey{}, err + } + k[0] &= 248 + k[31] = (k[31] & 127) | 64 + return (PrivateKey)(*k), nil +} + +func ParsePrivateKey(b64 string) (*PrivateKey, error) { + k, err := parseKeyBase64(base64.StdEncoding, b64) + return (*PrivateKey)(k), err +} + +func (k *PrivateKey) String() string { return base64.StdEncoding.EncodeToString(k[:]) } +func (k *PrivateKey) HexString() string { return hex.EncodeToString(k[:]) } +func (k *PrivateKey) Equal(k2 PrivateKey) bool { return subtle.ConstantTimeCompare(k[:], k2[:]) == 1 } + +func (k *PrivateKey) IsZero() bool { + pk := Key(*k) + return pk.IsZero() +} + +func (k *PrivateKey) clamp() { + k[0] &= 248 + k[31] = (k[31] & 127) | 64 +} + +// Public computes the public key matching this curve25519 secret key. +func (k *PrivateKey) Public() Key { + pk := Key(*k) + if pk.IsZero() { + panic("Tried to generate emptyPrivateKey.Public()") + } + var p [KeySize]byte + curve25519.ScalarBaseMult(&p, (*[KeySize]byte)(k)) + return (Key)(p) +} + +func (k PrivateKey) MarshalText() ([]byte, error) { + buf := new(bytes.Buffer) + fmt.Fprintf(buf, `privkey:%x`, k[:]) + return buf.Bytes(), nil +} + +func (k *PrivateKey) UnmarshalText(b []byte) error { + s := string(b) + if !strings.HasPrefix(s, `privkey:`) { + return errors.New("wgcfg.PrivateKey: UnmarshalText not given a private-key string") + } + s = strings.TrimPrefix(s, `privkey:`) + key, err := ParseHexKey(s) + if err != nil { + return fmt.Errorf("wgcfg.PrivateKey: UnmarshalText: %v", err) + } + copy(k[:], key[:]) + return nil +} + +func (k PrivateKey) SharedSecret(pub Key) (ss [KeySize]byte) { + apk := (*[KeySize]byte)(&pub) + ask := (*[KeySize]byte)(&k) + curve25519.ScalarMult(&ss, ask, apk) + return ss +} + +func parseKeyBase64(enc *base64.Encoding, s string) (*Key, error) { + k, err := enc.DecodeString(s) + if err != nil { + return nil, &ParseError{"Invalid key: " + err.Error(), s} + } + if len(k) != KeySize { + return nil, &ParseError{"Keys must decode to exactly 32 bytes", s} + } + var key Key + copy(key[:], k) + return &key, nil +} + +func ParseSymmetricKey(b64 string) (SymmetricKey, error) { + k, err := parseKeyBase64(base64.StdEncoding, b64) + if err != nil { + return SymmetricKey{}, err + } + return SymmetricKey(*k), nil +} + +func ParseSymmetricHexKey(s string) (SymmetricKey, error) { + b, err := hex.DecodeString(s) + if err != nil { + return SymmetricKey{}, &ParseError{"invalid symmetric hex key: " + err.Error(), s} + } + if len(b) != chacha20poly1305.KeySize { + return SymmetricKey{}, &ParseError{fmt.Sprintf("invalid symmetric hex key length: %d", len(b)), s} + } + var key SymmetricKey + copy(key[:], b) + return key, nil +} + +// SymmetricKey is a chacha20poly1305 key. +// It is used by WireGuard to represent pre-shared symmetric keys. +type SymmetricKey [chacha20poly1305.KeySize]byte + +func (k SymmetricKey) Base64() string { return base64.StdEncoding.EncodeToString(k[:]) } +func (k SymmetricKey) String() string { return "sym:" + k.Base64()[:8] } +func (k SymmetricKey) HexString() string { return hex.EncodeToString(k[:]) } +func (k SymmetricKey) IsZero() bool { return k.Equal(SymmetricKey{}) } +func (k SymmetricKey) Equal(k2 SymmetricKey) bool { return subtle.ConstantTimeCompare(k[:], k2[:]) == 1 } diff --git a/wgcfg/key_test.go b/wgcfg/key_test.go new file mode 100644 index 000000000..0b82d5fcd --- /dev/null +++ b/wgcfg/key_test.go @@ -0,0 +1,107 @@ +package wgcfg + +import ( + "bytes" + "testing" +) + +func TestKeyBasics(t *testing.T) { + k1, err := NewPresharedKey() + if err != nil { + t.Fatal(err) + } + + b, err := k1.MarshalJSON() + if err != nil { + t.Fatal(err) + } + + t.Run("JSON round-trip", func(t *testing.T) { + // should preserve the keys + k2 := new(Key) + if err := k2.UnmarshalJSON(b); err != nil { + t.Fatal(err) + } + if !bytes.Equal(k1[:], k2[:]) { + t.Fatalf("k1 %v != k2 %v", k1[:], k2[:]) + } + if b1, b2 := k1.String(), k2.String(); b1 != b2 { + t.Fatalf("base64-encoded keys do not match: %s, %s", b1, b2) + } + }) + + t.Run("JSON incompatible with PrivateKey", func(t *testing.T) { + k2 := new(PrivateKey) + if err := k2.UnmarshalText(b); err == nil { + t.Fatalf("successfully decoded key as private key") + } + }) + + t.Run("second key", func(t *testing.T) { + // A second call to NewPresharedKey should make a new key. + k3, err := NewPresharedKey() + if err != nil { + t.Fatal(err) + } + if bytes.Equal(k1[:], k3[:]) { + t.Fatalf("k1 %v == k3 %v", k1[:], k3[:]) + } + // Check for obvious comparables to make sure we are not generating bad strings somewhere. + if b1, b2 := k1.String(), k3.String(); b1 == b2 { + t.Fatalf("base64-encoded keys match: %s, %s", b1, b2) + } + }) +} +func TestPrivateKeyBasics(t *testing.T) { + pri, err := NewPrivateKey() + if err != nil { + t.Fatal(err) + } + + b, err := pri.MarshalText() + if err != nil { + t.Fatal(err) + } + + t.Run("JSON round-trip", func(t *testing.T) { + // should preserve the keys + pri2 := new(PrivateKey) + if err := pri2.UnmarshalText(b); err != nil { + t.Fatal(err) + } + if !bytes.Equal(pri[:], pri2[:]) { + t.Fatalf("pri %v != pri2 %v", pri[:], pri2[:]) + } + if b1, b2 := pri.String(), pri2.String(); b1 != b2 { + t.Fatalf("base64-encoded keys do not match: %s, %s", b1, b2) + } + if pub1, pub2 := pri.Public().String(), pri2.Public().String(); pub1 != pub2 { + t.Fatalf("base64-encoded public keys do not match: %s, %s", pub1, pub2) + } + }) + + t.Run("JSON incompatible with Key", func(t *testing.T) { + k2 := new(Key) + if err := k2.UnmarshalJSON(b); err == nil { + t.Fatalf("successfully decoded private key as key") + } + }) + + t.Run("second key", func(t *testing.T) { + // A second call to New should make a new key. + pri3, err := NewPrivateKey() + if err != nil { + t.Fatal(err) + } + if bytes.Equal(pri[:], pri3[:]) { + t.Fatalf("pri %v == pri3 %v", pri[:], pri3[:]) + } + // Check for obvious comparables to make sure we are not generating bad strings somewhere. + if b1, b2 := pri.String(), pri3.String(); b1 == b2 { + t.Fatalf("base64-encoded keys match: %s, %s", b1, b2) + } + if pub1, pub2 := pri.Public().String(), pri3.Public().String(); pub1 == pub2 { + t.Fatalf("base64-encoded public keys match: %s, %s", pub1, pub2) + } + }) +} diff --git a/wgcfg/name.go b/wgcfg/name.go new file mode 100644 index 000000000..28bc0f08e --- /dev/null +++ b/wgcfg/name.go @@ -0,0 +1,49 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package wgcfg + +import ( + "regexp" + "strings" +) + +var reservedNames = []string{ + "CON", "PRN", "AUX", "NUL", + "COM1", "COM2", "COM3", "COM4", "COM5", "COM6", "COM7", "COM8", "COM9", + "LPT1", "LPT2", "LPT3", "LPT4", "LPT5", "LPT6", "LPT7", "LPT8", "LPT9", +} + +const specialChars = "/\\<>:\"|?*\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x00" + +var allowedNameFormat *regexp.Regexp + +func init() { + allowedNameFormat = regexp.MustCompile("^[a-zA-Z0-9_=+.-]{1,32}$") +} + +func isReserved(name string) bool { + if len(name) == 0 { + return false + } + for _, reserved := range reservedNames { + if strings.EqualFold(name, reserved) { + return true + } + } + return false +} + +func hasSpecialChars(name string) bool { + return strings.ContainsAny(name, specialChars) +} + +func TunnelNameIsValid(name string) bool { + // Aside from our own restrictions, let's impose the Windows restrictions first + if isReserved(name) || hasSpecialChars(name) { + return false + } + return allowedNameFormat.MatchString(name) +} diff --git a/wgcfg/parser.go b/wgcfg/parser.go new file mode 100644 index 000000000..45a60577a --- /dev/null +++ b/wgcfg/parser.go @@ -0,0 +1,397 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package wgcfg + +import ( + "encoding/hex" + "fmt" + "net" + "strconv" + "strings" +) + +type ParseError struct { + why string + offender string +} + +func (e *ParseError) Error() string { + return fmt.Sprintf("%s: ā€˜%sā€™", e.why, e.offender) +} + +func parseEndpoints(s string) ([]Endpoint, error) { + var eps []Endpoint + vals := strings.Split(s, ",") + for _, val := range vals { + e, err := parseEndpoint(val) + if err != nil { + return nil, err + } + eps = append(eps, *e) + } + return eps, nil +} + +func parseEndpoint(s string) (*Endpoint, error) { + i := strings.LastIndexByte(s, ':') + if i < 0 { + return nil, &ParseError{"Missing port from endpoint", s} + } + host, portStr := s[:i], s[i+1:] + if len(host) < 1 { + return nil, &ParseError{"Invalid endpoint host", host} + } + port, err := parsePort(portStr) + if err != nil { + return nil, err + } + hostColon := strings.IndexByte(host, ':') + if host[0] == '[' || host[len(host)-1] == ']' || hostColon > 0 { + err := &ParseError{"Brackets must contain an IPv6 address", host} + if len(host) > 3 && host[0] == '[' && host[len(host)-1] == ']' && hostColon > 0 { + maybeV6 := net.ParseIP(host[1 : len(host)-1]) + if maybeV6 == nil || len(maybeV6) != net.IPv6len { + return nil, err + } + } else { + return nil, err + } + host = host[1 : len(host)-1] + } + return &Endpoint{host, uint16(port)}, nil +} + +func parseMTU(s string) (uint16, error) { + m, err := strconv.Atoi(s) + if err != nil { + return 0, err + } + if m < 576 || m > 65535 { + return 0, &ParseError{"Invalid MTU", s} + } + return uint16(m), nil +} + +func parsePort(s string) (uint16, error) { + m, err := strconv.Atoi(s) + if err != nil { + return 0, err + } + if m < 0 || m > 65535 { + return 0, &ParseError{"Invalid port", s} + } + return uint16(m), nil +} + +func parsePersistentKeepalive(s string) (uint16, error) { + if s == "off" { + return 0, nil + } + m, err := strconv.Atoi(s) + if err != nil { + return 0, err + } + if m < 0 || m > 65535 { + return 0, &ParseError{"Invalid persistent keepalive", s} + } + return uint16(m), nil +} + +func parseKeyHex(s string) (*Key, error) { + k, err := hex.DecodeString(s) + if err != nil { + return nil, &ParseError{"Invalid key: " + err.Error(), s} + } + if len(k) != KeySize { + return nil, &ParseError{"Keys must decode to exactly 32 bytes", s} + } + var key Key + copy(key[:], k) + return &key, nil +} + +func parseBytesOrStamp(s string) (uint64, error) { + b, err := strconv.ParseUint(s, 10, 64) + if err != nil { + return 0, &ParseError{"Number must be a number between 0 and 2^64-1: " + err.Error(), s} + } + return b, nil +} + +func splitList(s string) ([]string, error) { + var out []string + for _, split := range strings.Split(s, ",") { + trim := strings.TrimSpace(split) + if len(trim) == 0 { + return nil, &ParseError{"Two commas in a row", s} + } + out = append(out, trim) + } + return out, nil +} + +type parserState int + +const ( + inInterfaceSection parserState = iota + inPeerSection + notInASection +) + +func (c *Config) maybeAddPeer(p *Peer) { + if p != nil { + c.Peers = append(c.Peers, *p) + } +} + +func FromWgQuick(s string, name string) (*Config, error) { + if !TunnelNameIsValid(name) { + return nil, &ParseError{"Tunnel name is not valid", name} + } + lines := strings.Split(s, "\n") + parserState := notInASection + conf := Config{Name: name} + sawPrivateKey := false + var peer *Peer + for _, line := range lines { + pound := strings.IndexByte(line, '#') + if pound >= 0 { + line = line[:pound] + } + line = strings.TrimSpace(line) + lineLower := strings.ToLower(line) + if len(line) == 0 { + continue + } + if lineLower == "[interface]" { + conf.maybeAddPeer(peer) + parserState = inInterfaceSection + continue + } + if lineLower == "[peer]" { + conf.maybeAddPeer(peer) + peer = &Peer{} + parserState = inPeerSection + continue + } + if parserState == notInASection { + return nil, &ParseError{"Line must occur in a section", line} + } + equals := strings.IndexByte(line, '=') + if equals < 0 { + return nil, &ParseError{"Invalid config key is missing an equals separator", line} + } + key, val := strings.TrimSpace(lineLower[:equals]), strings.TrimSpace(line[equals+1:]) + if len(val) == 0 { + return nil, &ParseError{"Key must have a value", line} + } + if parserState == inInterfaceSection { + switch key { + case "privatekey": + k, err := ParseKey(val) + if err != nil { + return nil, err + } + conf.PrivateKey = PrivateKey(*k) + sawPrivateKey = true + case "listenport": + p, err := parsePort(val) + if err != nil { + return nil, err + } + conf.ListenPort = p + case "mtu": + m, err := parseMTU(val) + if err != nil { + return nil, err + } + conf.MTU = m + case "address": + addresses, err := splitList(val) + if err != nil { + return nil, err + } + for _, address := range addresses { + a, err := ParseCIDR(address) + if err != nil { + return nil, err + } + conf.Addresses = append(conf.Addresses, *a) + } + case "dns": + addresses, err := splitList(val) + if err != nil { + return nil, err + } + for _, address := range addresses { + a := ParseIP(address) + if a == nil { + return nil, &ParseError{"Invalid IP address", address} + } + conf.DNS = append(conf.DNS, *a) + } + default: + return nil, &ParseError{"Invalid key for [Interface] section", key} + } + } else if parserState == inPeerSection { + switch key { + case "publickey": + k, err := ParseKey(val) + if err != nil { + return nil, err + } + peer.PublicKey = *k + case "presharedkey": + k, err := ParseKey(val) + if err != nil { + return nil, err + } + peer.PresharedKey = SymmetricKey(*k) + case "allowedips": + addresses, err := splitList(val) + if err != nil { + return nil, err + } + for _, address := range addresses { + a, err := ParseCIDR(address) + if err != nil { + return nil, err + } + peer.AllowedIPs = append(peer.AllowedIPs, *a) + } + case "persistentkeepalive": + p, err := parsePersistentKeepalive(val) + if err != nil { + return nil, err + } + peer.PersistentKeepalive = p + case "endpoint": + eps, err := parseEndpoints(val) + if err != nil { + return nil, err + } + peer.Endpoints = eps + default: + return nil, &ParseError{"Invalid key for [Peer] section", key} + } + } + } + conf.maybeAddPeer(peer) + + if !sawPrivateKey { + return nil, &ParseError{"An interface must have a private key", "[none specified]"} + } + for _, p := range conf.Peers { + if p.PublicKey.IsZero() { + return nil, &ParseError{"All peers must have public keys", "[none specified]"} + } + } + + return &conf, nil +} + +// TODO(apenwarr): This is incompatibe with current Device.IpcSetOperation. +// It duplicates all the parser stuff in there, but is missing some +// keywords. Nothing useful seems to need it anymore. +func Broken_FromUAPI(s string, existingConfig *Config) (*Config, error) { + lines := strings.Split(s, "\n") + parserState := inInterfaceSection + conf := Config{ + Name: existingConfig.Name, + Addresses: existingConfig.Addresses, + DNS: existingConfig.DNS, + MTU: existingConfig.MTU, + } + var peer *Peer + for _, line := range lines { + if len(line) == 0 { + continue + } + equals := strings.IndexByte(line, '=') + if equals < 0 { + return nil, &ParseError{"Invalid config key is missing an equals separator", line} + } + key, val := line[:equals], line[equals+1:] + if len(val) == 0 { + return nil, &ParseError{"Key must have a value", line} + } + switch key { + case "public_key": + conf.maybeAddPeer(peer) + peer = &Peer{} + parserState = inPeerSection + case "errno": + if val == "0" { + continue + } else { + return nil, &ParseError{"Error in getting configuration", val} + } + } + if parserState == inInterfaceSection { + switch key { + case "private_key": + k, err := parseKeyHex(val) + if err != nil { + return nil, err + } + conf.PrivateKey = PrivateKey(*k) + case "listen_port": + p, err := parsePort(val) + if err != nil { + return nil, err + } + conf.ListenPort = p + case "fwmark": + // Ignored for now. + + default: + return nil, &ParseError{"Invalid key for interface section", key} + } + } else if parserState == inPeerSection { + switch key { + case "public_key": + k, err := parseKeyHex(val) + if err != nil { + return nil, err + } + peer.PublicKey = *k + case "preshared_key": + k, err := parseKeyHex(val) + if err != nil { + return nil, err + } + peer.PresharedKey = SymmetricKey(*k) + case "protocol_version": + if val != "1" { + return nil, &ParseError{"Protocol version must be 1", val} + } + case "allowed_ip": + a, err := ParseCIDR(val) + if err != nil { + return nil, err + } + peer.AllowedIPs = append(peer.AllowedIPs, *a) + case "persistent_keepalive_interval": + p, err := parsePersistentKeepalive(val) + if err != nil { + return nil, err + } + peer.PersistentKeepalive = p + case "endpoint": + eps, err := parseEndpoints(val) + if err != nil { + return nil, err + } + peer.Endpoints = eps + default: + return nil, &ParseError{"Invalid key for peer section", key} + } + } + } + conf.maybeAddPeer(peer) + + return &conf, nil +} diff --git a/wgcfg/parser_test.go b/wgcfg/parser_test.go new file mode 100644 index 000000000..d0df53739 --- /dev/null +++ b/wgcfg/parser_test.go @@ -0,0 +1,127 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package wgcfg + +import ( + "reflect" + "runtime" + "testing" +) + +const testInput = ` +[Interface] +Address = 10.192.122.1/24 +Address = 10.10.0.1/16 +PrivateKey = yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk= +ListenPort = 51820 #comments don't matter + +[Peer] +PublicKey = xTIBA5rboUvnH4htodjb6e697QjLERt1NAB4mZqp8Dg= +Endpoint = 192.95.5.67:1234 +AllowedIPs = 10.192.122.3/32, 10.192.124.1/24 + +[Peer] +PublicKey = TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0= +Endpoint = [2607:5300:60:6b0::c05f:543]:2468 +AllowedIPs = 10.192.122.4/32, 192.168.0.0/16 +PersistentKeepalive = 100 + +[Peer] +PublicKey = gN65BkIKy1eCE9pP1wdc8ROUtkHLF2PfAqYdyYBz6EA= +PresharedKey = TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0= +Endpoint = test.wireguard.com:18981 +AllowedIPs = 10.10.10.230/32` + +func noError(t *testing.T, err error) bool { + if err == nil { + return true + } + _, fn, line, _ := runtime.Caller(1) + t.Errorf("Error at %s:%d: %#v", fn, line, err) + return false +} + +func equal(t *testing.T, expected, actual interface{}) bool { + if reflect.DeepEqual(expected, actual) { + return true + } + _, fn, line, _ := runtime.Caller(1) + t.Errorf("Failed equals at %s:%d\nactual %#v\nexpected %#v", fn, line, actual, expected) + return false +} +func lenTest(t *testing.T, actualO interface{}, expected int) bool { + actual := reflect.ValueOf(actualO).Len() + if reflect.DeepEqual(expected, actual) { + return true + } + _, fn, line, _ := runtime.Caller(1) + t.Errorf("Wrong length at %s:%d\nactual %#v\nexpected %#v", fn, line, actual, expected) + return false +} +func contains(t *testing.T, list, element interface{}) bool { + listValue := reflect.ValueOf(list) + for i := 0; i < listValue.Len(); i++ { + if reflect.DeepEqual(listValue.Index(i).Interface(), element) { + return true + } + } + _, fn, line, _ := runtime.Caller(1) + t.Errorf("Error %s:%d\nelement not found: %#v", fn, line, element) + return false +} + +func TestFromWgQuick(t *testing.T) { + conf, err := FromWgQuick(testInput, "test") + if noError(t, err) { + + lenTest(t, conf.Addresses, 2) + contains(t, conf.Addresses, CIDR{IPv4(10, 10, 0, 1), 16}) + contains(t, conf.Addresses, CIDR{IPv4(10, 192, 122, 1), 24}) + equal(t, "yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk=", conf.PrivateKey.String()) + equal(t, uint16(51820), conf.ListenPort) + + lenTest(t, conf.Peers, 3) + lenTest(t, conf.Peers[0].AllowedIPs, 2) + equal(t, Endpoint{Host: "192.95.5.67", Port: 1234}, conf.Peers[0].Endpoints[0]) + equal(t, "xTIBA5rboUvnH4htodjb6e697QjLERt1NAB4mZqp8Dg=", conf.Peers[0].PublicKey.Base64()) + + lenTest(t, conf.Peers[1].AllowedIPs, 2) + equal(t, Endpoint{Host: "2607:5300:60:6b0::c05f:543", Port: 2468}, conf.Peers[1].Endpoints[0]) + equal(t, "TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0=", conf.Peers[1].PublicKey.Base64()) + equal(t, uint16(100), conf.Peers[1].PersistentKeepalive) + + lenTest(t, conf.Peers[2].AllowedIPs, 1) + equal(t, Endpoint{Host: "test.wireguard.com", Port: 18981}, conf.Peers[2].Endpoints[0]) + equal(t, "gN65BkIKy1eCE9pP1wdc8ROUtkHLF2PfAqYdyYBz6EA=", conf.Peers[2].PublicKey.Base64()) + equal(t, "TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0=", conf.Peers[2].PresharedKey.Base64()) + } +} + +func TestParseEndpoint(t *testing.T) { + _, err := parseEndpoint("[192.168.42.0:]:51880") + if err == nil { + t.Error("Error was expected") + } + e, err := parseEndpoint("192.168.42.0:51880") + if noError(t, err) { + equal(t, "192.168.42.0", e.Host) + equal(t, uint16(51880), e.Port) + } + e, err = parseEndpoint("test.wireguard.com:18981") + if noError(t, err) { + equal(t, "test.wireguard.com", e.Host) + equal(t, uint16(18981), e.Port) + } + e, err = parseEndpoint("[2607:5300:60:6b0::c05f:543]:2468") + if noError(t, err) { + equal(t, "2607:5300:60:6b0::c05f:543", e.Host) + equal(t, uint16(2468), e.Port) + } + _, err = parseEndpoint("[::::::invalid:18981") + if err == nil { + t.Error("Error was expected") + } +} diff --git a/wgcfg/writer.go b/wgcfg/writer.go new file mode 100644 index 000000000..aafb2a7ad --- /dev/null +++ b/wgcfg/writer.go @@ -0,0 +1,75 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package wgcfg + +import ( + "errors" + "fmt" + "net" + "strings" +) + +func (conf *Config) ToUAPI() (string, error) { + output := new(strings.Builder) + fmt.Fprintf(output, "private_key=%s\n", conf.PrivateKey.HexString()) + + if conf.ListenPort > 0 { + fmt.Fprintf(output, "listen_port=%d\n", conf.ListenPort) + } + + output.WriteString("replace_peers=true\n") + + for _, peer := range conf.Peers { + fmt.Fprintf(output, "public_key=%s\n", peer.PublicKey.HexString()) + fmt.Fprintf(output, "protocol_version=1\n") + fmt.Fprintf(output, "replace_allowed_ips=true\n") + + if !peer.PresharedKey.IsZero() { + fmt.Fprintf(output, "preshared_key = %s\n", peer.PresharedKey.String()) + } + + if len(peer.AllowedIPs) > 0 { + for _, address := range peer.AllowedIPs { + fmt.Fprintf(output, "allowed_ip=%s\n", address.String()) + } + } + + if len(peer.Endpoints) > 0 { + var reps []string + for _, ep := range peer.Endpoints { + ips, err := net.LookupIP(ep.Host) + if err != nil { + return "", err + } + var ip net.IP + for _, iterip := range ips { + iterip = iterip.To4() + if iterip != nil { + ip = iterip + break + } + if ip == nil { + ip = iterip + } + } + if ip == nil { + return "", errors.New("Unable to resolve IP address of endpoint") + } + resolvedEndpoint := Endpoint{ip.String(), ep.Port} + reps = append(reps, resolvedEndpoint.String()) + } + fmt.Fprintf(output, "endpoint=%s\n", strings.Join(reps, ",")) + } else { + fmt.Fprint(output, "endpoint=\n") + } + + // Note: this needs to come *after* endpoint definitions, + // because setting it will trigger a handshake to all + // already-defined endpoints. + fmt.Fprintf(output, "persistent_keepalive_interval=%d\n", peer.PersistentKeepalive) + } + return output.String(), nil +} From c7bb15a70df5cfc949c836429b5e39ce57d047f9 Mon Sep 17 00:00:00 2001 From: Tyler Kropp Date: Mon, 2 Mar 2020 19:41:28 -0500 Subject: [PATCH 07/12] wgcfg: add fast CIDR.Contains implementation Signed-off-by: Tyler Kropp --- wgcfg/ip.go | 26 ++++++++++- wgcfg/ip_test.go | 118 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 142 insertions(+), 2 deletions(-) create mode 100644 wgcfg/ip_test.go diff --git a/wgcfg/ip.go b/wgcfg/ip.go index ecf5faff7..7541d185d 100644 --- a/wgcfg/ip.go +++ b/wgcfg/ip.go @@ -2,6 +2,7 @@ package wgcfg import ( "fmt" + "math" "net" ) @@ -106,12 +107,33 @@ func (r *CIDR) IPNet() *net.IPNet { } return &net.IPNet{IP: r.IP.IP(), Mask: net.CIDRMask(int(r.Mask), bits)} } + func (r *CIDR) Contains(ip *IP) bool { if r == nil || ip == nil { return false } - // TODO: this isn't hard, write a more efficient implementation. - return r.IPNet().Contains(ip.IP()) + c := int8(r.Mask) + i := 0 + if r.IP.Is4() { + i = 12 + if ip.Is6() { + return false + } + } + for ; i < 16 && c > 0; i++ { + var x uint8 + if c < 8 { + x = 8 - uint8(c) + } + m := uint8(math.MaxUint8) >> x << x + a := r.IP.Addr[i] & m + b := ip.Addr[i] & m + if a != b { + return false + } + c -= 8 + } + return true } func (r CIDR) MarshalText() ([]byte, error) { diff --git a/wgcfg/ip_test.go b/wgcfg/ip_test.go new file mode 100644 index 000000000..d3682bbdc --- /dev/null +++ b/wgcfg/ip_test.go @@ -0,0 +1,118 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package wgcfg_test + +import ( + "testing" + + "golang.zx2c4.com/wireguard/wgcfg" +) + +func TestCIDRContains(t *testing.T) { + t.Run("home router test", func(t *testing.T) { + r, err := wgcfg.ParseCIDR("192.168.0.0/24") + if err != nil { + t.Fatal(err) + } + ip := wgcfg.ParseIP("192.168.0.1") + if ip == nil { + t.Fatalf("address failed to parse") + } + if !r.Contains(ip) { + t.Fatalf("'%s' should contain '%s'", r, ip) + } + }) + + t.Run("IPv4 outside network", func(t *testing.T) { + r, err := wgcfg.ParseCIDR("192.168.0.0/30") + if err != nil { + t.Fatal(err) + } + ip := wgcfg.ParseIP("192.168.0.4") + if ip == nil { + t.Fatalf("address failed to parse") + } + if r.Contains(ip) { + t.Fatalf("'%s' should not contain '%s'", r, ip) + } + }) + + t.Run("IPv4 does not contain IPv6", func(t *testing.T) { + r, err := wgcfg.ParseCIDR("192.168.0.0/24") + if err != nil { + t.Fatal(err) + } + ip := wgcfg.ParseIP("2001:db8:85a3:0:0:8a2e:370:7334") + if ip == nil { + t.Fatalf("address failed to parse") + } + if r.Contains(ip) { + t.Fatalf("'%s' should not contain '%s'", r, ip) + } + }) + + t.Run("IPv6 inside network", func(t *testing.T) { + r, err := wgcfg.ParseCIDR("2001:db8:1234::/48") + if err != nil { + t.Fatal(err) + } + ip := wgcfg.ParseIP("2001:db8:1234:0000:0000:0000:0000:0001") + if ip == nil { + t.Fatalf("ParseIP returned nil pointer") + } + if !r.Contains(ip) { + t.Fatalf("'%s' should not contain '%s'", r, ip) + } + }) + + t.Run("IPv6 outside network", func(t *testing.T) { + r, err := wgcfg.ParseCIDR("2001:db8:1234:0:190b:0:1982::/126") + if err != nil { + t.Fatal(err) + } + ip := wgcfg.ParseIP("2001:db8:1234:0:190b:0:1982:4") + if ip == nil { + t.Fatalf("ParseIP returned nil pointer") + } + if r.Contains(ip) { + t.Fatalf("'%s' should not contain '%s'", r, ip) + } + }) +} + +func BenchmarkCIDRContainsIPv4(b *testing.B) { + b.Run("IPv4", func(b *testing.B) { + r, err := wgcfg.ParseCIDR("192.168.1.0/24") + if err != nil { + b.Fatal(err) + } + ip := wgcfg.ParseIP("1.2.3.4") + if ip == nil { + b.Fatalf("ParseIP returned nil pointer") + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + r.Contains(ip) + } + }) + + b.Run("IPv6", func(b *testing.B) { + r, err := wgcfg.ParseCIDR("2001:db8:1234::/48") + if err != nil { + b.Fatal(err) + } + ip := wgcfg.ParseIP("2001:db8:1234:0000:0000:0000:0000:0001") + if ip == nil { + b.Fatalf("ParseIP returned nil pointer") + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + r.Contains(ip) + } + }) +} From 30908fdc5d40f1a7e4023306b743c3074a30a467 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Mon, 16 Mar 2020 20:28:29 -0700 Subject: [PATCH 08/12] wgcfg: clean up IP type/method signatures Signed-off-by: Brad Fitzpatrick --- wgcfg/ip.go | 58 +++++++++++++++++++++++++----------------------- wgcfg/ip_test.go | 58 +++++++++++++++++++----------------------------- wgcfg/parser.go | 12 +++++----- 3 files changed, 59 insertions(+), 69 deletions(-) diff --git a/wgcfg/ip.go b/wgcfg/ip.go index 7541d185d..47fa91c27 100644 --- a/wgcfg/ip.go +++ b/wgcfg/ip.go @@ -16,9 +16,14 @@ type IP struct { func (ip IP) String() string { return net.IP(ip.Addr[:]).String() } -func (ip *IP) IP() net.IP { return net.IP(ip.Addr[:]) } -func (ip *IP) Is6() bool { return !ip.Is4() } -func (ip *IP) Is4() bool { +// IP converts ip into a standard library net.IP. +func (ip IP) IP() net.IP { return net.IP(ip.Addr[:]) } + +// Is6 reports whether ip is an IPv6 address. +func (ip IP) Is6() bool { return !ip.Is4() } + +// Is4 reports whether ip is an IPv4 address. +func (ip IP) Is4() bool { return ip.Addr[0] == 0 && ip.Addr[1] == 0 && ip.Addr[2] == 0 && ip.Addr[3] == 0 && ip.Addr[4] == 0 && ip.Addr[5] == 0 && @@ -26,19 +31,20 @@ func (ip *IP) Is4() bool { ip.Addr[8] == 0 && ip.Addr[9] == 0 && ip.Addr[10] == 0xff && ip.Addr[11] == 0xff } -func (ip *IP) To4() []byte { + +// To4 returns either a 4 byte slice for an IPv4 address, or nil if +// it's not IPv4. +func (ip IP) To4() []byte { if ip.Is4() { return ip.Addr[12:16] } else { return nil } } -func (ip *IP) Equal(x *IP) bool { - if ip == nil || x == nil { - return false - } - // TODO: this isn't hard, write a more efficient implementation. - return ip.IP().Equal(x.IP()) + +// Equal reports whether ip == x. +func (ip IP) Equal(x IP) bool { + return ip == x } func (ip IP) MarshalText() ([]byte, error) { @@ -46,11 +52,11 @@ func (ip IP) MarshalText() ([]byte, error) { } func (ip *IP) UnmarshalText(text []byte) error { - parsedIP := ParseIP(string(text)) - if parsedIP == nil { - return fmt.Errorf("wgcfg.IP: UnmarshalText: bad IP address %q", string(text)) + parsedIP, ok := ParseIP(string(text)) + if !ok { + return fmt.Errorf("wgcfg.IP: UnmarshalText: bad IP address %q", text) } - *ip = *parsedIP + *ip = parsedIP return nil } @@ -66,15 +72,14 @@ func IPv4(b0, b1, b2, b3 byte) (ip IP) { // ParseIP parses the string representation of an address into an IP. // // It accepts IPv4 notation such as "1.2.3.4" and IPv6 notation like ""::0". -// If the string is not a valid IP address, ParseIP returns nil. -func ParseIP(s string) *IP { +// The ok result reports whether s was a valid IP and ip is valid. +func ParseIP(s string) (ip IP, ok bool) { netIP := net.ParseIP(s) if netIP == nil { - return nil + return IP{}, false } - ip := new(IP) copy(ip.Addr[:], netIP.To16()) - return ip + return ip, true } // CIDR is a compact IP address and subnet mask. @@ -85,12 +90,12 @@ type CIDR struct { // ParseCIDR parses CIDR notation into a CIDR type. // Typical CIDR strings look like "192.168.1.0/24". -func ParseCIDR(s string) (cidr *CIDR, err error) { +func ParseCIDR(s string) (CIDR, error) { netIP, netAddr, err := net.ParseCIDR(s) if err != nil { - return nil, err + return CIDR{}, err } - cidr = new(CIDR) + var cidr CIDR copy(cidr.IP.Addr[:], netIP.To16()) ones, _ := netAddr.Mask.Size() cidr.Mask = uint8(ones) @@ -100,7 +105,7 @@ func ParseCIDR(s string) (cidr *CIDR, err error) { func (r CIDR) String() string { return r.IPNet().String() } -func (r *CIDR) IPNet() *net.IPNet { +func (r CIDR) IPNet() *net.IPNet { bits := 128 if r.IP.Is4() { bits = 32 @@ -108,10 +113,7 @@ func (r *CIDR) IPNet() *net.IPNet { return &net.IPNet{IP: r.IP.IP(), Mask: net.CIDRMask(int(r.Mask), bits)} } -func (r *CIDR) Contains(ip *IP) bool { - if r == nil || ip == nil { - return false - } +func (r CIDR) Contains(ip IP) bool { c := int8(r.Mask) i := 0 if r.IP.Is4() { @@ -145,6 +147,6 @@ func (r *CIDR) UnmarshalText(text []byte) error { if err != nil { return fmt.Errorf("wgcfg.CIDR: UnmarshalText: %v", err) } - *r = *cidr + *r = cidr return nil } diff --git a/wgcfg/ip_test.go b/wgcfg/ip_test.go index d3682bbdc..6cd41d319 100644 --- a/wgcfg/ip_test.go +++ b/wgcfg/ip_test.go @@ -11,18 +11,24 @@ import ( "golang.zx2c4.com/wireguard/wgcfg" ) +func parseIP(t testing.TB, ipStr string) wgcfg.IP { + t.Helper() + ip, ok := wgcfg.ParseIP(ipStr) + if !ok { + t.Fatalf("failed to parse IP: %q", ipStr) + } + return ip +} + func TestCIDRContains(t *testing.T) { t.Run("home router test", func(t *testing.T) { r, err := wgcfg.ParseCIDR("192.168.0.0/24") if err != nil { t.Fatal(err) } - ip := wgcfg.ParseIP("192.168.0.1") - if ip == nil { - t.Fatalf("address failed to parse") - } + ip := parseIP(t, "192.168.0.1") if !r.Contains(ip) { - t.Fatalf("'%s' should contain '%s'", r, ip) + t.Fatalf("%q should contain %q", r, ip) } }) @@ -31,12 +37,9 @@ func TestCIDRContains(t *testing.T) { if err != nil { t.Fatal(err) } - ip := wgcfg.ParseIP("192.168.0.4") - if ip == nil { - t.Fatalf("address failed to parse") - } + ip := parseIP(t, "192.168.0.4") if r.Contains(ip) { - t.Fatalf("'%s' should not contain '%s'", r, ip) + t.Fatalf("%q should not contain %q", r, ip) } }) @@ -45,12 +48,9 @@ func TestCIDRContains(t *testing.T) { if err != nil { t.Fatal(err) } - ip := wgcfg.ParseIP("2001:db8:85a3:0:0:8a2e:370:7334") - if ip == nil { - t.Fatalf("address failed to parse") - } + ip := parseIP(t, "2001:db8:85a3:0:0:8a2e:370:7334") if r.Contains(ip) { - t.Fatalf("'%s' should not contain '%s'", r, ip) + t.Fatalf("%q should not contain %q", r, ip) } }) @@ -59,12 +59,9 @@ func TestCIDRContains(t *testing.T) { if err != nil { t.Fatal(err) } - ip := wgcfg.ParseIP("2001:db8:1234:0000:0000:0000:0000:0001") - if ip == nil { - t.Fatalf("ParseIP returned nil pointer") - } + ip := parseIP(t, "2001:db8:1234:0000:0000:0000:0000:0001") if !r.Contains(ip) { - t.Fatalf("'%s' should not contain '%s'", r, ip) + t.Fatalf("%q should not contain %q", r, ip) } }) @@ -73,12 +70,9 @@ func TestCIDRContains(t *testing.T) { if err != nil { t.Fatal(err) } - ip := wgcfg.ParseIP("2001:db8:1234:0:190b:0:1982:4") - if ip == nil { - t.Fatalf("ParseIP returned nil pointer") - } + ip := parseIP(t, "2001:db8:1234:0:190b:0:1982:4") if r.Contains(ip) { - t.Fatalf("'%s' should not contain '%s'", r, ip) + t.Fatalf("%q should not contain %q", r, ip) } }) } @@ -89,12 +83,9 @@ func BenchmarkCIDRContainsIPv4(b *testing.B) { if err != nil { b.Fatal(err) } - ip := wgcfg.ParseIP("1.2.3.4") - if ip == nil { - b.Fatalf("ParseIP returned nil pointer") - } - + ip := parseIP(b, "1.2.3.4") b.ResetTimer() + for i := 0; i < b.N; i++ { r.Contains(ip) } @@ -105,12 +96,9 @@ func BenchmarkCIDRContainsIPv4(b *testing.B) { if err != nil { b.Fatal(err) } - ip := wgcfg.ParseIP("2001:db8:1234:0000:0000:0000:0000:0001") - if ip == nil { - b.Fatalf("ParseIP returned nil pointer") - } - + ip := parseIP(b, "2001:db8:1234:0000:0000:0000:0000:0001") b.ResetTimer() + for i := 0; i < b.N; i++ { r.Contains(ip) } diff --git a/wgcfg/parser.go b/wgcfg/parser.go index 45a60577a..e71d32b1f 100644 --- a/wgcfg/parser.go +++ b/wgcfg/parser.go @@ -219,7 +219,7 @@ func FromWgQuick(s string, name string) (*Config, error) { if err != nil { return nil, err } - conf.Addresses = append(conf.Addresses, *a) + conf.Addresses = append(conf.Addresses, a) } case "dns": addresses, err := splitList(val) @@ -227,11 +227,11 @@ func FromWgQuick(s string, name string) (*Config, error) { return nil, err } for _, address := range addresses { - a := ParseIP(address) - if a == nil { + a, ok := ParseIP(address) + if !ok { return nil, &ParseError{"Invalid IP address", address} } - conf.DNS = append(conf.DNS, *a) + conf.DNS = append(conf.DNS, a) } default: return nil, &ParseError{"Invalid key for [Interface] section", key} @@ -260,7 +260,7 @@ func FromWgQuick(s string, name string) (*Config, error) { if err != nil { return nil, err } - peer.AllowedIPs = append(peer.AllowedIPs, *a) + peer.AllowedIPs = append(peer.AllowedIPs, a) } case "persistentkeepalive": p, err := parsePersistentKeepalive(val) @@ -373,7 +373,7 @@ func Broken_FromUAPI(s string, existingConfig *Config) (*Config, error) { if err != nil { return nil, err } - peer.AllowedIPs = append(peer.AllowedIPs, *a) + peer.AllowedIPs = append(peer.AllowedIPs, a) case "persistent_keepalive_interval": p, err := parsePersistentKeepalive(val) if err != nil { From 024d6ea4c2a5f4727e0d483db695ec91c5629af8 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Thu, 19 Mar 2020 22:37:57 -0700 Subject: [PATCH 09/12] wgcfg: fix bug preventing IPv6 addresses from working Signed-off-by: Brad Fitzpatrick --- wgcfg/writer.go | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/wgcfg/writer.go b/wgcfg/writer.go index aafb2a7ad..246a57d0b 100644 --- a/wgcfg/writer.go +++ b/wgcfg/writer.go @@ -6,7 +6,6 @@ package wgcfg import ( - "errors" "fmt" "net" "strings" @@ -46,9 +45,8 @@ func (conf *Config) ToUAPI() (string, error) { } var ip net.IP for _, iterip := range ips { - iterip = iterip.To4() - if iterip != nil { - ip = iterip + if ip4 := iterip.To4(); ip4 != nil { + ip = ip4 break } if ip == nil { @@ -56,7 +54,7 @@ func (conf *Config) ToUAPI() (string, error) { } } if ip == nil { - return "", errors.New("Unable to resolve IP address of endpoint") + return "", fmt.Errorf("unable to resolve IP address of endpoint %q (%v)", ep.Host, ips) } resolvedEndpoint := Endpoint{ip.String(), ep.Port} reps = append(reps, resolvedEndpoint.String()) From ed38ecd90be2a88a8e5d736b1660195c037ae38e Mon Sep 17 00:00:00 2001 From: Avery Pennarun Date: Mon, 14 Oct 2019 22:40:09 -0400 Subject: [PATCH 10/12] device: add a callback when an unexpected IP is used by a peer Signed-off-by: Avery Pennarun --- device/device.go | 22 +++++++++++++++++++++- device/receive.go | 15 +++++++-------- 2 files changed, 28 insertions(+), 9 deletions(-) diff --git a/device/device.go b/device/device.go index a9fedea86..872bd4502 100644 --- a/device/device.go +++ b/device/device.go @@ -17,6 +17,7 @@ import ( "golang.zx2c4.com/wireguard/ratelimiter" "golang.zx2c4.com/wireguard/rwcancel" "golang.zx2c4.com/wireguard/tun" + "golang.zx2c4.com/wireguard/wgcfg" ) type Device struct { @@ -61,6 +62,8 @@ type Device struct { indexTable IndexTable cookieChecker CookieChecker + unexpectedip func(key *wgcfg.Key, ip wgcfg.IP) + rate struct { underLoadUntil atomic.Value limiter ratelimiter.Ratelimiter @@ -253,7 +256,16 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { return nil } -func NewDevice(tunDevice tun.Device, logger *Logger) *Device { +type DeviceOptions struct { + // UnexpectedIP is called when a packet is received from a + // validated peer with an unexpected internal IP address. + // The packet is then dropped. + UnexpectedIP func(key *wgcfg.Key, ip wgcfg.IP) +} + +// TODO move logger into DeviceOptions +// TODO make opts non-vararg +func NewDevice(tunDevice tun.Device, logger *Logger, opts ...DeviceOptions) *Device { device := new(Device) device.isUp.Set(false) @@ -261,6 +273,14 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device { device.log = logger + if len(opts) != 0 && opts[0].UnexpectedIP != nil { + device.unexpectedip = opts[0].UnexpectedIP + } else { + device.unexpectedip = func(key *wgcfg.Key, ip wgcfg.IP) { + device.log.Info.Printf("IPv4 packet with disallowed source address %s from %v", ip, key) + } + } + device.tun.device = tunDevice mtu, err := device.tun.device.MTU() if err != nil { diff --git a/device/receive.go b/device/receive.go index 4818d649e..e4212f017 100644 --- a/device/receive.go +++ b/device/receive.go @@ -18,6 +18,7 @@ import ( "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" "golang.zx2c4.com/wireguard/conn" + "golang.zx2c4.com/wireguard/wgcfg" ) type QueueHandshakeElement struct { @@ -591,10 +592,9 @@ func (peer *Peer) RoutineSequentialReceiver() { src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] if device.allowedips.LookupIPv4(src) != peer { - logInfo.Println( - "IPv4 packet with disallowed source address from", - peer, - ) + ip := wgcfg.IPv4(src[0], src[1], src[2], src[3]) + key := (*wgcfg.Key)(&peer.handshake.remoteStatic) + device.unexpectedip(key, ip) continue } @@ -619,10 +619,9 @@ func (peer *Peer) RoutineSequentialReceiver() { src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] if device.allowedips.LookupIPv6(src) != peer { - logInfo.Println( - "IPv6 packet with disallowed source address from", - peer, - ) + ip := wgcfg.IPv4(src[0], src[1], src[2], src[3]) + key := (*wgcfg.Key)(&peer.handshake.remoteStatic) + device.unexpectedip(key, ip) continue } From 41ad63a8f8e75c4ca5b5e239f5c9eddcf5c95ada Mon Sep 17 00:00:00 2001 From: David Crawshaw Date: Wed, 26 Jun 2019 13:59:48 -0400 Subject: [PATCH 11/12] device: allow config of CreateBind/CreateEndpoint Signed-off-by: David Crawshaw --- device/device.go | 59 ++++++++++++++++++++++++++++++++++++++++++------ device/uapi.go | 6 +++-- 2 files changed, 56 insertions(+), 9 deletions(-) diff --git a/device/device.go b/device/device.go index 872bd4502..70196e649 100644 --- a/device/device.go +++ b/device/device.go @@ -21,9 +21,12 @@ import ( ) type Device struct { - isUp AtomicBool // device is (going) up - isClosed AtomicBool // device is closed? (acting as guard) - log *Logger + isUp AtomicBool // device is (going) up + isClosed AtomicBool // device is closed? (acting as guard) + log *Logger + skipBindUpdate bool + createBind func(uport uint16) (conn.Bind, uint16, error) + createEndpoint func(key [32]byte, s string) (conn.Endpoint, error) // synchronized resources (locks acquired in order) @@ -261,11 +264,24 @@ type DeviceOptions struct { // validated peer with an unexpected internal IP address. // The packet is then dropped. UnexpectedIP func(key *wgcfg.Key, ip wgcfg.IP) + + // CreateEndpoint creates a conn.Endpoint for a given address. + // If unset, conn.CreateEndpoint is used. + CreateEndpoint func(key [32]byte, addr string) (conn.Endpoint, error) + + // CreateBind creates a conn.Bind bound to uport. + // If unset, conn.CreateBind is used. + CreateBind func(uport uint16) (conn.Bind, uint16, error) + + // SkipBindUpdate instructs Device to only call CreateBind once. + // + // TODO(crawshaw): remove this, it isn't useful externally. + SkipBindUpdate bool } // TODO move logger into DeviceOptions // TODO make opts non-vararg -func NewDevice(tunDevice tun.Device, logger *Logger, opts ...DeviceOptions) *Device { +func NewDevice(tunDevice tun.Device, logger *Logger, varOpts ...DeviceOptions) *Device { device := new(Device) device.isUp.Set(false) @@ -273,13 +289,37 @@ func NewDevice(tunDevice tun.Device, logger *Logger, opts ...DeviceOptions) *Dev device.log = logger - if len(opts) != 0 && opts[0].UnexpectedIP != nil { - device.unexpectedip = opts[0].UnexpectedIP + var opts DeviceOptions + if len(varOpts) != 0 { + if len(varOpts) != 1 { + panic("too many DeviceOptions") + } + opts = varOpts[0] + } + if opts.UnexpectedIP != nil { + device.unexpectedip = opts.UnexpectedIP } else { device.unexpectedip = func(key *wgcfg.Key, ip wgcfg.IP) { device.log.Info.Printf("IPv4 packet with disallowed source address %s from %v", ip, key) } } + if opts.CreateEndpoint != nil { + device.createEndpoint = opts.CreateEndpoint + } else { + device.createEndpoint = func(_ [32]byte, s string) (conn.Endpoint, error) { + return conn.CreateEndpoint(s) + } + } + if opts.CreateBind != nil { + device.createBind = func(uport uint16) (conn.Bind, uint16, error) { + return opts.CreateBind(uport) + } + } else { + device.createBind = func(uport uint16) (conn.Bind, uint16, error) { + return conn.CreateBind(uport) + } + } + device.skipBindUpdate = opts.SkipBindUpdate device.tun.device = tunDevice mtu, err := device.tun.device.MTU() @@ -490,6 +530,11 @@ func (device *Device) BindUpdate() error { device.net.Lock() defer device.net.Unlock() + if device.skipBindUpdate && device.net.bind != nil { + device.log.Debug.Println("UDP bind update skipped") + return nil + } + // close existing sockets if err := unsafeCloseBind(device); err != nil { @@ -504,7 +549,7 @@ func (device *Device) BindUpdate() error { var err error netc := &device.net - netc.bind, netc.port, err = conn.CreateBind(netc.port) + netc.bind, netc.port, err = device.createBind(netc.port) if err != nil { netc.bind = nil netc.port = 0 diff --git a/device/uapi.go b/device/uapi.go index 6cdccd615..b3f10f014 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -15,7 +15,6 @@ import ( "sync/atomic" "time" - "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/ipc" ) @@ -307,7 +306,10 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError { err := func() error { peer.Lock() defer peer.Unlock() - endpoint, err := conn.CreateEndpoint(value) + peer.handshake.mutex.Lock() + defer peer.handshake.mutex.Unlock() + key := peer.handshake.remoteStatic + endpoint, err := device.createEndpoint(key, value) if err != nil { return err } From e2104a982d755bdac7e6eb6b3105410645573f87 Mon Sep 17 00:00:00 2001 From: Avery Pennarun Date: Mon, 14 Oct 2019 21:53:10 -0400 Subject: [PATCH 12/12] device: add a HandshakeDone callback Every time a peer handshake completes, we call this function. That lets a GUI immediately notice when the connection has been established. Signed-off-by: Avery Pennarun --- device/device.go | 9 +++++++++ device/send.go | 3 +++ 2 files changed, 12 insertions(+) diff --git a/device/device.go b/device/device.go index 70196e649..8f011270f 100644 --- a/device/device.go +++ b/device/device.go @@ -24,6 +24,7 @@ type Device struct { isUp AtomicBool // device is (going) up isClosed AtomicBool // device is closed? (acting as guard) log *Logger + handshakeDone func() skipBindUpdate bool createBind func(uport uint16) (conn.Bind, uint16, error) createEndpoint func(key [32]byte, s string) (conn.Endpoint, error) @@ -277,6 +278,13 @@ type DeviceOptions struct { // // TODO(crawshaw): remove this, it isn't useful externally. SkipBindUpdate bool + + // HandshakeDone is called every time we complete a peer handshake. + // + // TODO(crawshaw): This isn't quite right. Library users don't care + // about the handshake, per se, they want link status. + // Evolve this in that direction. + HandshakeDone func() } // TODO move logger into DeviceOptions @@ -320,6 +328,7 @@ func NewDevice(tunDevice tun.Device, logger *Logger, varOpts ...DeviceOptions) * } } device.skipBindUpdate = opts.SkipBindUpdate + device.handshakeDone = opts.HandshakeDone device.tun.device = tunDevice mtu, err := device.tun.device.MTU() diff --git a/device/send.go b/device/send.go index 9e29d7778..607ca9b50 100644 --- a/device/send.go +++ b/device/send.go @@ -409,6 +409,9 @@ func (peer *Peer) RoutineNonce() { select { case <-peer.signals.newKeypairArrived: logDebug.Println(peer, "- Obtained awaited keypair") + if device.handshakeDone != nil { + device.handshakeDone() + } case <-peer.signals.flushNonceQueue: device.PutMessageBuffer(elem.buffer)