diff --git a/device/boundif_windows.go b/conn/boundif_windows.go similarity index 66% rename from device/boundif_windows.go rename to conn/boundif_windows.go index 690841528..fe38d05f5 100644 --- a/device/boundif_windows.go +++ b/conn/boundif_windows.go @@ -3,11 +3,10 @@ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. */ -package device +package conn import ( "encoding/binary" - "errors" "unsafe" "golang.org/x/sys/windows" @@ -18,17 +17,13 @@ const ( sockoptIPV6_UNICAST_IF = 31 ) -func (device *Device) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { +func (bind *nativeBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { /* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */ bytes := make([]byte, 4) binary.BigEndian.PutUint32(bytes, interfaceIndex) interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0])) - if device.net.bind == nil { - return errors.New("Bind is not yet initialized") - } - - sysconn, err := device.net.bind.(*nativeBind).ipv4.SyscallConn() + sysconn, err := bind.ipv4.SyscallConn() if err != nil { return err } @@ -41,12 +36,12 @@ func (device *Device) BindSocketToInterface4(interfaceIndex uint32, blackhole bo if err != nil { return err } - device.net.bind.(*nativeBind).blackhole4 = blackhole + bind.blackhole4 = blackhole return nil } -func (device *Device) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { - sysconn, err := device.net.bind.(*nativeBind).ipv6.SyscallConn() +func (bind *nativeBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { + sysconn, err := bind.ipv6.SyscallConn() if err != nil { return err } @@ -59,6 +54,6 @@ func (device *Device) BindSocketToInterface6(interfaceIndex uint32, blackhole bo if err != nil { return err } - device.net.bind.(*nativeBind).blackhole6 = blackhole + bind.blackhole6 = blackhole return nil } diff --git a/conn/conn.go b/conn/conn.go new file mode 100644 index 000000000..6b7db12ab --- /dev/null +++ b/conn/conn.go @@ -0,0 +1,101 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +// Package conn implements WireGuard's network connections. +package conn + +import ( + "errors" + "net" + "strings" +) + +// A Bind listens on a port for both IPv6 and IPv4 UDP traffic. +type Bind interface { + // LastMark reports the last mark set for this Bind. + LastMark() uint32 + + // SetMark sets the mark for each packet sent through this Bind. + // This mark is passed to the kernel as the socket option SO_MARK. + SetMark(mark uint32) error + + // ReceiveIPv6 reads an IPv6 UDP packet into b. + // + // It reports the number of bytes read, n, + // the packet source address ep, + // and any error. + ReceiveIPv6(buff []byte) (n int, ep Endpoint, err error) + + // ReceiveIPv4 reads an IPv4 UDP packet into b. + // + // It reports the number of bytes read, n, + // the packet source address ep, + // and any error. + ReceiveIPv4(b []byte) (n int, ep Endpoint, err error) + + // Send writes a packet b to address ep. + Send(b []byte, ep Endpoint) error + + // Close closes the Bind connection. + Close() error +} + +// CreateBind creates a Bind bound to a port. +// +// The value actualPort reports the actual port number the Bind +// object gets bound to. +func CreateBind(port uint16) (b Bind, actualPort uint16, err error) { + return createBind(port) +} + +// BindToInterface is implemented by Bind objects that support being +// tied to a single network interface. +type BindToInterface interface { + BindToInterface4(interfaceIndex uint32, blackhole bool) error + BindToInterface6(interfaceIndex uint32, blackhole bool) error +} + +// An Endpoint maintains the source/destination caching for a peer. +// +// dst : the remote address of a peer ("endpoint" in uapi terminology) +// src : the local address from which datagrams originate going to the peer +type Endpoint interface { + ClearSrc() // clears the source address + SrcToString() string // returns the local source address (ip:port) + DstToString() string // returns the destination address (ip:port) + DstToBytes() []byte // used for mac2 cookie calculations + DstIP() net.IP + SrcIP() net.IP +} + +func parseEndpoint(s string) (*net.UDPAddr, error) { + // ensure that the host is an IP address + + host, _, err := net.SplitHostPort(s) + if err != nil { + return nil, err + } + if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 { + // Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just + // trying to make sure with a small sanity test that this is a real IP address and + // not something that's likely to incur DNS lookups. + host = host[:i] + } + if ip := net.ParseIP(host); ip == nil { + return nil, errors.New("Failed to parse IP address: " + host) + } + + // parse address and port + + addr, err := net.ResolveUDPAddr("udp", s) + if err != nil { + return nil, err + } + ip4 := addr.IP.To4() + if ip4 != nil { + addr.IP = ip4 + } + return addr, err +} diff --git a/device/conn_default.go b/conn/conn_default.go similarity index 94% rename from device/conn_default.go rename to conn/conn_default.go index 661f57d97..bad9d4df8 100644 --- a/device/conn_default.go +++ b/conn/conn_default.go @@ -5,7 +5,7 @@ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. */ -package device +package conn import ( "net" @@ -67,16 +67,13 @@ func (e *NativeEndpoint) SrcToString() string { } func listenNet(network string, port int) (*net.UDPConn, int, error) { - - // listen - conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port}) if err != nil { return nil, 0, err } - // retrieve port - + // Retrieve port. + // TODO(crawshaw): under what circumstances is this necessary? laddr := conn.LocalAddr() uaddr, err := net.ResolveUDPAddr( laddr.Network(), @@ -100,7 +97,7 @@ func extractErrno(err error) error { return syscallErr.Err } -func CreateBind(uport uint16, device *Device) (Bind, uint16, error) { +func createBind(uport uint16) (Bind, uint16, error) { var err error var bind nativeBind @@ -135,6 +132,8 @@ func (bind *nativeBind) Close() error { return err2 } +func (bind *nativeBind) LastMark() uint32 { return 0 } + func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { if bind.ipv4 == nil { return 0, nil, syscall.EAFNOSUPPORT diff --git a/device/conn_linux.go b/conn/conn_linux.go similarity index 63% rename from device/conn_linux.go rename to conn/conn_linux.go index e90b0e35b..523da4a45 100644 --- a/device/conn_linux.go +++ b/conn/conn_linux.go @@ -3,18 +3,9 @@ /* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - * - * This implements userspace semantics of "sticky sockets", modeled after - * WireGuard's kernelspace implementation. This is more or less a straight port - * of the sticky-sockets.c example code: - * https://git.zx2c4.com/wireguard-tools/tree/contrib/sticky-sockets/sticky-sockets.c - * - * Currently there is no way to achieve this within the net package: - * See e.g. https://github.com/golang/go/issues/17930 - * So this code is remains platform dependent. */ -package device +package conn import ( "errors" @@ -25,7 +16,6 @@ import ( "unsafe" "golang.org/x/sys/unix" - "golang.zx2c4.com/wireguard/rwcancel" ) const ( @@ -33,8 +23,8 @@ const ( ) type IPv4Source struct { - src [4]byte - ifindex int32 + Src [4]byte + Ifindex int32 } type IPv6Source struct { @@ -49,6 +39,10 @@ type NativeEndpoint struct { isV6 bool } +func (endpoint *NativeEndpoint) Src4() *IPv4Source { return endpoint.src4() } +func (endpoint *NativeEndpoint) Dst4() *unix.SockaddrInet4 { return endpoint.dst4() } +func (endpoint *NativeEndpoint) IsV6() bool { return endpoint.isV6 } + func (endpoint *NativeEndpoint) src4() *IPv4Source { return (*IPv4Source)(unsafe.Pointer(&endpoint.src[0])) } @@ -66,11 +60,9 @@ func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 { } type nativeBind struct { - sock4 int - sock6 int - netlinkSock int - netlinkCancel *rwcancel.RWCancel - lastMark uint32 + sock4 int + sock6 int + lastMark uint32 } var _ Endpoint = (*NativeEndpoint)(nil) @@ -111,59 +103,25 @@ func CreateEndpoint(s string) (Endpoint, error) { return nil, errors.New("Invalid IP address") } -func createNetlinkRouteSocket() (int, error) { - sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE) - if err != nil { - return -1, err - } - saddr := &unix.SockaddrNetlink{ - Family: unix.AF_NETLINK, - Groups: unix.RTMGRP_IPV4_ROUTE, - } - err = unix.Bind(sock, saddr) - if err != nil { - unix.Close(sock) - return -1, err - } - return sock, nil - -} - -func CreateBind(port uint16, device *Device) (*nativeBind, uint16, error) { +func createBind(port uint16) (Bind, uint16, error) { var err error var bind nativeBind var newPort uint16 - bind.netlinkSock, err = createNetlinkRouteSocket() - if err != nil { - return nil, 0, err - } - bind.netlinkCancel, err = rwcancel.NewRWCancel(bind.netlinkSock) - if err != nil { - unix.Close(bind.netlinkSock) - return nil, 0, err - } - - go bind.routineRouteListener(device) - - // attempt ipv6 bind, update port if successful - + // Attempt ipv6 bind, update port if successful. bind.sock6, newPort, err = create6(port) if err != nil { if err != syscall.EAFNOSUPPORT { - bind.netlinkCancel.Cancel() return nil, 0, err } } else { port = newPort } - // attempt ipv4 bind, update port if successful - + // Attempt ipv4 bind, update port if successful. bind.sock4, newPort, err = create4(port) if err != nil { if err != syscall.EAFNOSUPPORT { - bind.netlinkCancel.Cancel() unix.Close(bind.sock6) return nil, 0, err } @@ -178,6 +136,10 @@ func CreateBind(port uint16, device *Device) (*nativeBind, uint16, error) { return &bind, port, nil } +func (bind *nativeBind) LastMark() uint32 { + return bind.lastMark +} + func (bind *nativeBind) SetMark(value uint32) error { if bind.sock6 != -1 { err := unix.SetsockoptInt( @@ -216,22 +178,18 @@ func closeUnblock(fd int) error { } func (bind *nativeBind) Close() error { - var err1, err2, err3 error + var err1, err2 error if bind.sock6 != -1 { err1 = closeUnblock(bind.sock6) } if bind.sock4 != -1 { err2 = closeUnblock(bind.sock4) } - err3 = bind.netlinkCancel.Cancel() if err1 != nil { return err1 } - if err2 != nil { - return err2 - } - return err3 + return err2 } func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { @@ -278,10 +236,10 @@ func (bind *nativeBind) Send(buff []byte, end Endpoint) error { func (end *NativeEndpoint) SrcIP() net.IP { if !end.isV6 { return net.IPv4( - end.src4().src[0], - end.src4().src[1], - end.src4().src[2], - end.src4().src[3], + end.src4().Src[0], + end.src4().Src[1], + end.src4().Src[2], + end.src4().Src[3], ) } else { return end.src6().src[:] @@ -478,8 +436,8 @@ func send4(sock int, end *NativeEndpoint, buff []byte) error { Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr, }, unix.Inet4Pktinfo{ - Spec_dst: end.src4().src, - Ifindex: end.src4().ifindex, + Spec_dst: end.src4().Src, + Ifindex: end.src4().Ifindex, }, } @@ -573,8 +531,8 @@ func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) { if cmsg.cmsghdr.Level == unix.IPPROTO_IP && cmsg.cmsghdr.Type == unix.IP_PKTINFO && cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo { - end.src4().src = cmsg.pktinfo.Spec_dst - end.src4().ifindex = cmsg.pktinfo.Ifindex + end.src4().Src = cmsg.pktinfo.Spec_dst + end.src4().Ifindex = cmsg.pktinfo.Ifindex } return size, nil @@ -611,156 +569,3 @@ func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) { return size, nil } - -func (bind *nativeBind) routineRouteListener(device *Device) { - type peerEndpointPtr struct { - peer *Peer - endpoint *Endpoint - } - var reqPeer map[uint32]peerEndpointPtr - var reqPeerLock sync.Mutex - - defer unix.Close(bind.netlinkSock) - - for msg := make([]byte, 1<<16); ; { - var err error - var msgn int - for { - msgn, _, _, _, err = unix.Recvmsg(bind.netlinkSock, msg[:], nil, 0) - if err == nil || !rwcancel.RetryAfterError(err) { - break - } - if !bind.netlinkCancel.ReadyRead() { - return - } - } - if err != nil { - return - } - - for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; { - - hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0])) - - if uint(hdr.Len) > uint(len(remain)) { - break - } - - switch hdr.Type { - case unix.RTM_NEWROUTE, unix.RTM_DELROUTE: - if hdr.Seq <= MaxPeers && hdr.Seq > 0 { - if uint(len(remain)) < uint(hdr.Len) { - break - } - if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg { - attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:] - for { - if uint(len(attr)) < uint(unix.SizeofRtAttr) { - break - } - attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0])) - if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) { - break - } - if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 { - ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr])) - reqPeerLock.Lock() - if reqPeer == nil { - reqPeerLock.Unlock() - break - } - pePtr, ok := reqPeer[hdr.Seq] - reqPeerLock.Unlock() - if !ok { - break - } - pePtr.peer.Lock() - if &pePtr.peer.endpoint != pePtr.endpoint { - pePtr.peer.Unlock() - break - } - if uint32(pePtr.peer.endpoint.(*NativeEndpoint).src4().ifindex) == ifidx { - pePtr.peer.Unlock() - break - } - pePtr.peer.endpoint.(*NativeEndpoint).ClearSrc() - pePtr.peer.Unlock() - } - attr = attr[attrhdr.Len:] - } - } - break - } - reqPeerLock.Lock() - reqPeer = make(map[uint32]peerEndpointPtr) - reqPeerLock.Unlock() - go func() { - device.peers.RLock() - i := uint32(1) - for _, peer := range device.peers.keyMap { - peer.RLock() - if peer.endpoint == nil || peer.endpoint.(*NativeEndpoint) == nil { - peer.RUnlock() - continue - } - if peer.endpoint.(*NativeEndpoint).isV6 || peer.endpoint.(*NativeEndpoint).src4().ifindex == 0 { - peer.RUnlock() - break - } - nlmsg := struct { - hdr unix.NlMsghdr - msg unix.RtMsg - dsthdr unix.RtAttr - dst [4]byte - srchdr unix.RtAttr - src [4]byte - markhdr unix.RtAttr - mark uint32 - }{ - unix.NlMsghdr{ - Type: uint16(unix.RTM_GETROUTE), - Flags: unix.NLM_F_REQUEST, - Seq: i, - }, - unix.RtMsg{ - Family: unix.AF_INET, - Dst_len: 32, - Src_len: 32, - }, - unix.RtAttr{ - Len: 8, - Type: unix.RTA_DST, - }, - peer.endpoint.(*NativeEndpoint).dst4().Addr, - unix.RtAttr{ - Len: 8, - Type: unix.RTA_SRC, - }, - peer.endpoint.(*NativeEndpoint).src4().src, - unix.RtAttr{ - Len: 8, - Type: unix.RTA_MARK, - }, - uint32(bind.lastMark), - } - nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg)) - reqPeerLock.Lock() - reqPeer[i] = peerEndpointPtr{ - peer: peer, - endpoint: &peer.endpoint, - } - reqPeerLock.Unlock() - peer.RUnlock() - i++ - _, err := bind.netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:]) - if err != nil { - break - } - } - device.peers.RUnlock() - }() - } - remain = remain[hdr.Len:] - } - } -} diff --git a/device/mark_default.go b/conn/mark_default.go similarity index 93% rename from device/mark_default.go rename to conn/mark_default.go index 7de2524c0..fc41ba993 100644 --- a/device/mark_default.go +++ b/conn/mark_default.go @@ -5,7 +5,7 @@ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. */ -package device +package conn func (bind *nativeBind) SetMark(mark uint32) error { return nil diff --git a/device/mark_unix.go b/conn/mark_unix.go similarity index 98% rename from device/mark_unix.go rename to conn/mark_unix.go index 669b32814..5334582e9 100644 --- a/device/mark_unix.go +++ b/conn/mark_unix.go @@ -5,7 +5,7 @@ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. */ -package device +package conn import ( "runtime" diff --git a/device/bind_test.go b/device/bind_test.go index 0c2e2cfd2..c5f7f68a9 100644 --- a/device/bind_test.go +++ b/device/bind_test.go @@ -5,11 +5,15 @@ package device -import "errors" +import ( + "errors" + + "golang.zx2c4.com/wireguard/conn" +) type DummyDatagram struct { msg []byte - endpoint Endpoint + endpoint conn.Endpoint world bool // better type } @@ -25,7 +29,7 @@ func (b *DummyBind) SetMark(v uint32) error { return nil } -func (b *DummyBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { +func (b *DummyBind) ReceiveIPv6(buff []byte) (int, conn.Endpoint, error) { datagram, ok := <-b.in6 if !ok { return 0, nil, errors.New("closed") @@ -34,7 +38,7 @@ func (b *DummyBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { return len(datagram.msg), datagram.endpoint, nil } -func (b *DummyBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { +func (b *DummyBind) ReceiveIPv4(buff []byte) (int, conn.Endpoint, error) { datagram, ok := <-b.in4 if !ok { return 0, nil, errors.New("closed") @@ -50,6 +54,6 @@ func (b *DummyBind) Close() error { return nil } -func (b *DummyBind) Send(buff []byte, end Endpoint) error { +func (b *DummyBind) Send(buff []byte, end conn.Endpoint) error { return nil } diff --git a/device/bindsocketshim.go b/device/bindsocketshim.go new file mode 100644 index 000000000..c4dd4effb --- /dev/null +++ b/device/bindsocketshim.go @@ -0,0 +1,36 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "errors" + + "golang.zx2c4.com/wireguard/conn" +) + +// TODO(crawshaw): this method is a compatibility shim. Replace with direct use of conn. +func (device *Device) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { + if device.net.bind == nil { + return errors.New("Bind is not yet initialized") + } + + if iface, ok := device.net.bind.(conn.BindToInterface); ok { + return iface.BindToInterface4(interfaceIndex, blackhole) + } + return nil +} + +// TODO(crawshaw): this method is a compatibility shim. Replace with direct use of conn. +func (device *Device) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { + if device.net.bind == nil { + return errors.New("Bind is not yet initialized") + } + + if iface, ok := device.net.bind.(conn.BindToInterface); ok { + return iface.BindToInterface6(interfaceIndex, blackhole) + } + return nil +} diff --git a/device/conn.go b/device/conn.go deleted file mode 100644 index 7b341f6b2..000000000 --- a/device/conn.go +++ /dev/null @@ -1,187 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package device - -import ( - "errors" - "net" - "strings" - - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" -) - -const ( - ConnRoutineNumber = 2 -) - -/* A Bind handles listening on a port for both IPv6 and IPv4 UDP traffic - */ -type Bind interface { - SetMark(value uint32) error - ReceiveIPv6(buff []byte) (int, Endpoint, error) - ReceiveIPv4(buff []byte) (int, Endpoint, error) - Send(buff []byte, end Endpoint) error - Close() error -} - -/* An Endpoint maintains the source/destination caching for a peer - * - * dst : the remote address of a peer ("endpoint" in uapi terminology) - * src : the local address from which datagrams originate going to the peer - */ -type Endpoint interface { - ClearSrc() // clears the source address - SrcToString() string // returns the local source address (ip:port) - DstToString() string // returns the destination address (ip:port) - DstToBytes() []byte // used for mac2 cookie calculations - DstIP() net.IP - SrcIP() net.IP -} - -func parseEndpoint(s string) (*net.UDPAddr, error) { - // ensure that the host is an IP address - - host, _, err := net.SplitHostPort(s) - if err != nil { - return nil, err - } - if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 { - // Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just - // trying to make sure with a small sanity test that this is a real IP address and - // not something that's likely to incur DNS lookups. - host = host[:i] - } - if ip := net.ParseIP(host); ip == nil { - return nil, errors.New("Failed to parse IP address: " + host) - } - - // parse address and port - - addr, err := net.ResolveUDPAddr("udp", s) - if err != nil { - return nil, err - } - ip4 := addr.IP.To4() - if ip4 != nil { - addr.IP = ip4 - } - return addr, err -} - -func unsafeCloseBind(device *Device) error { - var err error - netc := &device.net - if netc.bind != nil { - err = netc.bind.Close() - netc.bind = nil - } - netc.stopping.Wait() - return err -} - -func (device *Device) BindSetMark(mark uint32) error { - - device.net.Lock() - defer device.net.Unlock() - - // check if modified - - if device.net.fwmark == mark { - return nil - } - - // update fwmark on existing bind - - device.net.fwmark = mark - if device.isUp.Get() && device.net.bind != nil { - if err := device.net.bind.SetMark(mark); err != nil { - return err - } - } - - // clear cached source addresses - - device.peers.RLock() - for _, peer := range device.peers.keyMap { - peer.Lock() - defer peer.Unlock() - if peer.endpoint != nil { - peer.endpoint.ClearSrc() - } - } - device.peers.RUnlock() - - return nil -} - -func (device *Device) BindUpdate() error { - - device.net.Lock() - defer device.net.Unlock() - - // close existing sockets - - if err := unsafeCloseBind(device); err != nil { - return err - } - - // open new sockets - - if device.isUp.Get() { - - // bind to new port - - var err error - netc := &device.net - netc.bind, netc.port, err = CreateBind(netc.port, device) - if err != nil { - netc.bind = nil - netc.port = 0 - return err - } - - // set fwmark - - if netc.fwmark != 0 { - err = netc.bind.SetMark(netc.fwmark) - if err != nil { - return err - } - } - - // clear cached source addresses - - device.peers.RLock() - for _, peer := range device.peers.keyMap { - peer.Lock() - defer peer.Unlock() - if peer.endpoint != nil { - peer.endpoint.ClearSrc() - } - } - device.peers.RUnlock() - - // start receiving routines - - device.net.starting.Add(ConnRoutineNumber) - device.net.stopping.Add(ConnRoutineNumber) - go device.RoutineReceiveIncoming(ipv4.Version, netc.bind) - go device.RoutineReceiveIncoming(ipv6.Version, netc.bind) - device.net.starting.Wait() - - device.log.Debug.Println("UDP bind has been updated") - } - - return nil -} - -func (device *Device) BindClose() error { - device.net.Lock() - err := unsafeCloseBind(device) - device.net.Unlock() - return err -} diff --git a/device/device.go b/device/device.go index 8c08f1c34..a9fedea86 100644 --- a/device/device.go +++ b/device/device.go @@ -11,15 +11,14 @@ import ( "sync/atomic" "time" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" + "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/ratelimiter" + "golang.zx2c4.com/wireguard/rwcancel" "golang.zx2c4.com/wireguard/tun" ) -const ( - DeviceRoutineNumberPerCPU = 3 - DeviceRoutineNumberAdditional = 2 -) - type Device struct { isUp AtomicBool // device is (going) up isClosed AtomicBool // device is closed? (acting as guard) @@ -39,9 +38,10 @@ type Device struct { starting sync.WaitGroup stopping sync.WaitGroup sync.RWMutex - bind Bind // bind interface - port uint16 // listening port - fwmark uint32 // mark value (0 = disabled) + bind conn.Bind // bind interface + netlinkCancel *rwcancel.RWCancel + port uint16 // listening port + fwmark uint32 // mark value (0 = disabled) } staticIdentity struct { @@ -299,14 +299,16 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device { cpus := runtime.NumCPU() device.state.starting.Wait() device.state.stopping.Wait() - device.state.stopping.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional) - device.state.starting.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional) for i := 0; i < cpus; i += 1 { + device.state.starting.Add(3) + device.state.stopping.Add(3) go device.RoutineEncryption() go device.RoutineDecryption() go device.RoutineHandshake() } + device.state.starting.Add(2) + device.state.stopping.Add(2) go device.RoutineReadFromTUN() go device.RoutineTUNEventReader() @@ -413,3 +415,127 @@ func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() { } device.peers.RUnlock() } + +func unsafeCloseBind(device *Device) error { + var err error + netc := &device.net + if netc.netlinkCancel != nil { + netc.netlinkCancel.Cancel() + } + if netc.bind != nil { + err = netc.bind.Close() + netc.bind = nil + } + netc.stopping.Wait() + return err +} + +func (device *Device) BindSetMark(mark uint32) error { + + device.net.Lock() + defer device.net.Unlock() + + // check if modified + + if device.net.fwmark == mark { + return nil + } + + // update fwmark on existing bind + + device.net.fwmark = mark + if device.isUp.Get() && device.net.bind != nil { + if err := device.net.bind.SetMark(mark); err != nil { + return err + } + } + + // clear cached source addresses + + device.peers.RLock() + for _, peer := range device.peers.keyMap { + peer.Lock() + defer peer.Unlock() + if peer.endpoint != nil { + peer.endpoint.ClearSrc() + } + } + device.peers.RUnlock() + + return nil +} + +func (device *Device) BindUpdate() error { + + device.net.Lock() + defer device.net.Unlock() + + // close existing sockets + + if err := unsafeCloseBind(device); err != nil { + return err + } + + // open new sockets + + if device.isUp.Get() { + + // bind to new port + + var err error + netc := &device.net + netc.bind, netc.port, err = conn.CreateBind(netc.port) + if err != nil { + netc.bind = nil + netc.port = 0 + return err + } + netc.netlinkCancel, err = device.startRouteListener(netc.bind) + if err != nil { + netc.bind.Close() + netc.bind = nil + netc.port = 0 + return err + } + + // set fwmark + + if netc.fwmark != 0 { + err = netc.bind.SetMark(netc.fwmark) + if err != nil { + return err + } + } + + // clear cached source addresses + + device.peers.RLock() + for _, peer := range device.peers.keyMap { + peer.Lock() + defer peer.Unlock() + if peer.endpoint != nil { + peer.endpoint.ClearSrc() + } + } + device.peers.RUnlock() + + // start receiving routines + + device.net.starting.Add(2) + device.net.stopping.Add(2) + go device.RoutineReceiveIncoming(ipv4.Version, netc.bind) + go device.RoutineReceiveIncoming(ipv6.Version, netc.bind) + device.net.starting.Wait() + + device.log.Debug.Println("UDP bind has been updated") + } + + return nil +} + +func (device *Device) BindClose() error { + device.net.Lock() + err := unsafeCloseBind(device) + device.net.Unlock() + return err +} diff --git a/device/device_test.go b/device/device_test.go index 14cc605ed..87ecfc873 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -8,15 +8,12 @@ package device import ( "bufio" "bytes" - "encoding/binary" - "io" "net" - "os" "strings" "testing" "time" - "golang.zx2c4.com/wireguard/tun" + "golang.zx2c4.com/wireguard/tun/tuntest" ) func TestTwoDevicePing(t *testing.T) { @@ -29,7 +26,7 @@ protocol_version=1 replace_allowed_ips=true allowed_ip=1.0.0.2/32 endpoint=127.0.0.1:53512` - tun1 := NewChannelTUN() + tun1 := tuntest.NewChannelTUN() dev1 := NewDevice(tun1.TUN(), NewLogger(LogLevelDebug, "dev1: ")) dev1.Up() defer dev1.Close() @@ -45,7 +42,7 @@ protocol_version=1 replace_allowed_ips=true allowed_ip=1.0.0.1/32 endpoint=127.0.0.1:53511` - tun2 := NewChannelTUN() + tun2 := tuntest.NewChannelTUN() dev2 := NewDevice(tun2.TUN(), NewLogger(LogLevelDebug, "dev2: ")) dev2.Up() defer dev2.Close() @@ -54,7 +51,7 @@ endpoint=127.0.0.1:53511` } t.Run("ping 1.0.0.1", func(t *testing.T) { - msg2to1 := ping(net.ParseIP("1.0.0.1"), net.ParseIP("1.0.0.2")) + msg2to1 := tuntest.Ping(net.ParseIP("1.0.0.1"), net.ParseIP("1.0.0.2")) tun2.Outbound <- msg2to1 select { case msgRecv := <-tun1.Inbound: @@ -67,7 +64,7 @@ endpoint=127.0.0.1:53511` }) t.Run("ping 1.0.0.2", func(t *testing.T) { - msg1to2 := ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1")) + msg1to2 := tuntest.Ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1")) tun1.Outbound <- msg1to2 select { case msgRecv := <-tun2.Inbound: @@ -80,139 +77,6 @@ endpoint=127.0.0.1:53511` }) } -func ping(dst, src net.IP) []byte { - localPort := uint16(1337) - seq := uint16(0) - - payload := make([]byte, 4) - binary.BigEndian.PutUint16(payload[0:], localPort) - binary.BigEndian.PutUint16(payload[2:], seq) - - return genICMPv4(payload, dst, src) -} - -// checksum is the "internet checksum" from https://tools.ietf.org/html/rfc1071. -func checksum(buf []byte, initial uint16) uint16 { - v := uint32(initial) - for i := 0; i < len(buf)-1; i += 2 { - v += uint32(binary.BigEndian.Uint16(buf[i:])) - } - if len(buf)%2 == 1 { - v += uint32(buf[len(buf)-1]) << 8 - } - for v > 0xffff { - v = (v >> 16) + (v & 0xffff) - } - return ^uint16(v) -} - -func genICMPv4(payload []byte, dst, src net.IP) []byte { - const ( - icmpv4ProtocolNumber = 1 - icmpv4Echo = 8 - icmpv4ChecksumOffset = 2 - icmpv4Size = 8 - ipv4Size = 20 - ipv4TotalLenOffset = 2 - ipv4ChecksumOffset = 10 - ttl = 65 - ) - - hdr := make([]byte, ipv4Size+icmpv4Size) - - ip := hdr[0:ipv4Size] - icmpv4 := hdr[ipv4Size : ipv4Size+icmpv4Size] - - // https://tools.ietf.org/html/rfc792 - icmpv4[0] = icmpv4Echo // type - icmpv4[1] = 0 // code - chksum := ^checksum(icmpv4, checksum(payload, 0)) - binary.BigEndian.PutUint16(icmpv4[icmpv4ChecksumOffset:], chksum) - - // https://tools.ietf.org/html/rfc760 section 3.1 - length := uint16(len(hdr) + len(payload)) - ip[0] = (4 << 4) | (ipv4Size / 4) - binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length) - ip[8] = ttl - ip[9] = icmpv4ProtocolNumber - copy(ip[12:], src.To4()) - copy(ip[16:], dst.To4()) - chksum = ^checksum(ip[:], 0) - binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum) - - var v []byte - v = append(v, hdr...) - v = append(v, payload...) - return []byte(v) -} - -// TODO(crawshaw): find a reusable home for this. package devicetest? -type ChannelTUN struct { - Inbound chan []byte // incoming packets, closed on TUN close - Outbound chan []byte // outbound packets, blocks forever on TUN close - - closed chan struct{} - events chan tun.Event - tun chTun -} - -func NewChannelTUN() *ChannelTUN { - c := &ChannelTUN{ - Inbound: make(chan []byte), - Outbound: make(chan []byte), - closed: make(chan struct{}), - events: make(chan tun.Event, 1), - } - c.tun.c = c - c.events <- tun.EventUp - return c -} - -func (c *ChannelTUN) TUN() tun.Device { - return &c.tun -} - -type chTun struct { - c *ChannelTUN -} - -func (t *chTun) File() *os.File { return nil } - -func (t *chTun) Read(data []byte, offset int) (int, error) { - select { - case <-t.c.closed: - return 0, io.EOF // TODO(crawshaw): what is the correct error value? - case msg := <-t.c.Outbound: - return copy(data[offset:], msg), nil - } -} - -// Write is called by the wireguard device to deliver a packet for routing. -func (t *chTun) Write(data []byte, offset int) (int, error) { - if offset == -1 { - close(t.c.closed) - close(t.c.events) - return 0, io.EOF - } - msg := make([]byte, len(data)-offset) - copy(msg, data[offset:]) - select { - case <-t.c.closed: - return 0, io.EOF // TODO(crawshaw): what is the correct error value? - case t.c.Inbound <- msg: - return len(data) - offset, nil - } -} - -func (t *chTun) Flush() error { return nil } -func (t *chTun) MTU() (int, error) { return DefaultMTU, nil } -func (t *chTun) Name() (string, error) { return "loopbackTun1", nil } -func (t *chTun) Events() chan tun.Event { return t.c.events } -func (t *chTun) Close() error { - t.Write(nil, -1) - return nil -} - func assertNil(t *testing.T, err error) { if err != nil { t.Fatal(err) diff --git a/device/peer.go b/device/peer.go index 8a8224c66..a96f2612a 100644 --- a/device/peer.go +++ b/device/peer.go @@ -12,6 +12,8 @@ import ( "sync" "sync/atomic" "time" + + "golang.zx2c4.com/wireguard/conn" ) const ( @@ -19,20 +21,27 @@ const ( ) type Peer struct { - isRunning AtomicBool - sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer - keypairs Keypairs - handshake Handshake - device *Device - endpoint Endpoint - persistentKeepaliveInterval uint16 - - // This must be 64-bit aligned, so make sure the above members come out to even alignment and pad accordingly + // These fields are accessed with atomic operations, which must be + // 64-bit aligned even on 32-bit platforms. Go guarantees that an + // allocated struct will be 64-bit aligned. So we place + // atomically-accessed fields up front, so that they can share in + // this alignment before smaller fields throw it off. stats struct { txBytes uint64 // bytes send to peer (endpoint) rxBytes uint64 // bytes received from peer lastHandshakeNano int64 // nano seconds since epoch } + // This field is only 32 bits wide, but is still aligned to 64 + // bits. Don't place other atomic fields after this one. + isRunning AtomicBool + + // Mostly protects endpoint, but is generally taken whenever we modify peer + sync.RWMutex + keypairs Keypairs + handshake Handshake + device *Device + endpoint conn.Endpoint + persistentKeepaliveInterval uint16 timers struct { retransmitHandshake *Timer @@ -286,7 +295,7 @@ func (peer *Peer) Stop() { var RoamingDisabled bool -func (peer *Peer) SetEndpointFromPacket(endpoint Endpoint) { +func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) { if RoamingDisabled { return } diff --git a/device/peer_test.go b/device/peer_test.go new file mode 100644 index 000000000..de87ab638 --- /dev/null +++ b/device/peer_test.go @@ -0,0 +1,29 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "testing" + "unsafe" +) + +func checkAlignment(t *testing.T, name string, offset uintptr) { + t.Helper() + if offset%8 != 0 { + t.Errorf("offset of %q within struct is %d bytes, which does not align to 64-bit word boundaries (missing %d bytes). Atomic operations will crash on 32-bit systems.", name, offset, 8-(offset%8)) + } +} + +// TestPeerAlignment checks that atomically-accessed fields are +// aligned to 64-bit boundaries, as required by the atomic package. +// +// Unfortunately, violating this rule on 32-bit platforms results in a +// hard segfault at runtime. +func TestPeerAlignment(t *testing.T) { + var p Peer + checkAlignment(t, "Peer.stats", unsafe.Offsetof(p.stats)) + checkAlignment(t, "Peer.isRunning", unsafe.Offsetof(p.isRunning)) +} diff --git a/device/receive.go b/device/receive.go index 7d0693e1f..4818d649e 100644 --- a/device/receive.go +++ b/device/receive.go @@ -17,12 +17,13 @@ import ( "golang.org/x/crypto/chacha20poly1305" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" + "golang.zx2c4.com/wireguard/conn" ) type QueueHandshakeElement struct { msgType uint32 packet []byte - endpoint Endpoint + endpoint conn.Endpoint buffer *[MaxMessageSize]byte } @@ -33,7 +34,7 @@ type QueueInboundElement struct { packet []byte counter uint64 keypair *Keypair - endpoint Endpoint + endpoint conn.Endpoint } func (elem *QueueInboundElement) Drop() { @@ -90,7 +91,7 @@ func (peer *Peer) keepKeyFreshReceiving() { * Every time the bind is updated a new routine is started for * IPv4 and IPv6 (separately) */ -func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) { +func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) { logDebug := device.log.Debug defer func() { @@ -108,7 +109,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) { var ( err error size int - endpoint Endpoint + endpoint conn.Endpoint ) for { diff --git a/device/sticky_default.go b/device/sticky_default.go new file mode 100644 index 000000000..1cc52f69b --- /dev/null +++ b/device/sticky_default.go @@ -0,0 +1,12 @@ +// +build !linux + +package device + +import ( + "golang.zx2c4.com/wireguard/conn" + "golang.zx2c4.com/wireguard/rwcancel" +) + +func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) { + return nil, nil +} diff --git a/device/sticky_linux.go b/device/sticky_linux.go new file mode 100644 index 000000000..f9522c23e --- /dev/null +++ b/device/sticky_linux.go @@ -0,0 +1,215 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * + * This implements userspace semantics of "sticky sockets", modeled after + * WireGuard's kernelspace implementation. This is more or less a straight port + * of the sticky-sockets.c example code: + * https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c + * + * Currently there is no way to achieve this within the net package: + * See e.g. https://github.com/golang/go/issues/17930 + * So this code is remains platform dependent. + */ + +package device + +import ( + "sync" + "unsafe" + + "golang.org/x/sys/unix" + "golang.zx2c4.com/wireguard/conn" + "golang.zx2c4.com/wireguard/rwcancel" +) + +func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) { + netlinkSock, err := createNetlinkRouteSocket() + if err != nil { + return nil, err + } + netlinkCancel, err := rwcancel.NewRWCancel(netlinkSock) + if err != nil { + unix.Close(netlinkSock) + return nil, err + } + + go device.routineRouteListener(bind, netlinkSock, netlinkCancel) + + return netlinkCancel, nil +} + +func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) { + type peerEndpointPtr struct { + peer *Peer + endpoint *conn.Endpoint + } + var reqPeer map[uint32]peerEndpointPtr + var reqPeerLock sync.Mutex + + defer unix.Close(netlinkSock) + + for msg := make([]byte, 1<<16); ; { + var err error + var msgn int + for { + msgn, _, _, _, err = unix.Recvmsg(netlinkSock, msg[:], nil, 0) + if err == nil || !rwcancel.RetryAfterError(err) { + break + } + if !netlinkCancel.ReadyRead() { + return + } + } + if err != nil { + return + } + + for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; { + + hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0])) + + if uint(hdr.Len) > uint(len(remain)) { + break + } + + switch hdr.Type { + case unix.RTM_NEWROUTE, unix.RTM_DELROUTE: + if hdr.Seq <= MaxPeers && hdr.Seq > 0 { + if uint(len(remain)) < uint(hdr.Len) { + break + } + if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg { + attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:] + for { + if uint(len(attr)) < uint(unix.SizeofRtAttr) { + break + } + attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0])) + if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) { + break + } + if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 { + ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr])) + reqPeerLock.Lock() + if reqPeer == nil { + reqPeerLock.Unlock() + break + } + pePtr, ok := reqPeer[hdr.Seq] + reqPeerLock.Unlock() + if !ok { + break + } + pePtr.peer.Lock() + if &pePtr.peer.endpoint != pePtr.endpoint { + pePtr.peer.Unlock() + break + } + if uint32(pePtr.peer.endpoint.(*conn.NativeEndpoint).Src4().Ifindex) == ifidx { + pePtr.peer.Unlock() + break + } + pePtr.peer.endpoint.(*conn.NativeEndpoint).ClearSrc() + pePtr.peer.Unlock() + } + attr = attr[attrhdr.Len:] + } + } + break + } + reqPeerLock.Lock() + reqPeer = make(map[uint32]peerEndpointPtr) + reqPeerLock.Unlock() + go func() { + device.peers.RLock() + i := uint32(1) + for _, peer := range device.peers.keyMap { + peer.RLock() + if peer.endpoint == nil { + peer.RUnlock() + continue + } + nativeEP, _ := peer.endpoint.(*conn.NativeEndpoint) + if nativeEP == nil { + peer.RUnlock() + continue + } + if nativeEP.IsV6() || nativeEP.Src4().Ifindex == 0 { + peer.RUnlock() + break + } + nlmsg := struct { + hdr unix.NlMsghdr + msg unix.RtMsg + dsthdr unix.RtAttr + dst [4]byte + srchdr unix.RtAttr + src [4]byte + markhdr unix.RtAttr + mark uint32 + }{ + unix.NlMsghdr{ + Type: uint16(unix.RTM_GETROUTE), + Flags: unix.NLM_F_REQUEST, + Seq: i, + }, + unix.RtMsg{ + Family: unix.AF_INET, + Dst_len: 32, + Src_len: 32, + }, + unix.RtAttr{ + Len: 8, + Type: unix.RTA_DST, + }, + nativeEP.Dst4().Addr, + unix.RtAttr{ + Len: 8, + Type: unix.RTA_SRC, + }, + nativeEP.Src4().Src, + unix.RtAttr{ + Len: 8, + Type: unix.RTA_MARK, + }, + uint32(bind.LastMark()), + } + nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg)) + reqPeerLock.Lock() + reqPeer[i] = peerEndpointPtr{ + peer: peer, + endpoint: &peer.endpoint, + } + reqPeerLock.Unlock() + peer.RUnlock() + i++ + _, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:]) + if err != nil { + break + } + } + device.peers.RUnlock() + }() + } + remain = remain[hdr.Len:] + } + } +} + +func createNetlinkRouteSocket() (int, error) { + sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE) + if err != nil { + return -1, err + } + saddr := &unix.SockaddrNetlink{ + Family: unix.AF_NETLINK, + Groups: uint32(1 << (unix.RTNLGRP_IPV4_ROUTE - 1)), + } + err = unix.Bind(sock, saddr) + if err != nil { + unix.Close(sock) + return -1, err + } + return sock, nil +} diff --git a/device/uapi.go b/device/uapi.go index 72611ab5e..1671faa30 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -7,6 +7,7 @@ package device import ( "bufio" + "errors" "fmt" "io" "net" @@ -15,6 +16,7 @@ import ( "sync/atomic" "time" + "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/ipc" ) @@ -30,7 +32,7 @@ func (s IPCError) ErrorCode() int64 { return s.int64 } -func (device *Device) IpcGetOperation(socket *bufio.Writer) *IPCError { +func (device *Device) IpcGetOperation(socket *bufio.Writer) error { lines := make([]string, 0, 100) send := func(line string) { lines = append(lines, line) @@ -105,7 +107,7 @@ func (device *Device) IpcGetOperation(socket *bufio.Writer) *IPCError { return nil } -func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError { +func (device *Device) IpcSetOperation(socket *bufio.Reader) error { scanner := bufio.NewScanner(socket) logError := device.log.Error logDebug := device.log.Debug @@ -306,7 +308,7 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError { err := func() error { peer.Lock() defer peer.Unlock() - endpoint, err := CreateEndpoint(value) + endpoint, err := conn.CreateEndpoint(value) if err != nil { return err } @@ -420,10 +422,20 @@ func (device *Device) IpcHandle(socket net.Conn) { switch op { case "set=1\n": - status = device.IpcSetOperation(buffered.Reader) + err = device.IpcSetOperation(buffered.Reader) + if !errors.As(err, &status) { + // should never happen + device.log.Error.Println("Invalid UAPI error:", err) + status = &IPCError{1} + } case "get=1\n": - status = device.IpcGetOperation(buffered.Writer) + err = device.IpcGetOperation(buffered.Writer) + if !errors.As(err, &status) { + // should never happen + device.log.Error.Println("Invalid UAPI error:", err) + status = &IPCError{1} + } default: device.log.Error.Println("Invalid UAPI operation:", op) diff --git a/go.mod b/go.mod index f4e7b97c9..c2183e022 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module golang.zx2c4.com/wireguard -go 1.12 +go 1.13 require ( golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc diff --git a/tun/tun_linux.go b/tun/tun_linux.go index 7ab062314..2f97ebbc5 100644 --- a/tun/tun_linux.go +++ b/tun/tun_linux.go @@ -12,7 +12,6 @@ import ( "bytes" "errors" "fmt" - "net" "os" "sync" "syscall" @@ -32,7 +31,6 @@ const ( type NativeTun struct { tunFile *os.File index int32 // if index - name string // name of interface errors chan error // async error handling events chan Event // device related events nopi bool // the device was passed IFF_NO_PI @@ -40,6 +38,10 @@ type NativeTun struct { netlinkCancel *rwcancel.RWCancel hackListenerClosed sync.Mutex statusListenersShutdown chan struct{} + + nameOnce sync.Once // guards calling initNameCache, which sets following fields + nameCache string // name of interface + nameErr error } func (tun *NativeTun) File() *os.File { @@ -164,11 +166,6 @@ func (tun *NativeTun) routineNetlinkListener() { } } -func (tun *NativeTun) isUp() (bool, error) { - inter, err := net.InterfaceByName(tun.name) - return inter.Flags&net.FlagUp != 0, err -} - func getIFIndex(name string) (int32, error) { fd, err := unix.Socket( unix.AF_INET, @@ -198,6 +195,11 @@ func getIFIndex(name string) (int32, error) { } func (tun *NativeTun) setMTU(n int) error { + name, err := tun.Name() + if err != nil { + return err + } + // open datagram socket fd, err := unix.Socket( unix.AF_INET, @@ -212,9 +214,8 @@ func (tun *NativeTun) setMTU(n int) error { defer unix.Close(fd) // do ioctl call - var ifr [ifReqSize]byte - copy(ifr[:], tun.name) + copy(ifr[:], name) *(*uint32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = uint32(n) _, _, errno := unix.Syscall( unix.SYS_IOCTL, @@ -231,6 +232,11 @@ func (tun *NativeTun) setMTU(n int) error { } func (tun *NativeTun) MTU() (int, error) { + name, err := tun.Name() + if err != nil { + return 0, err + } + // open datagram socket fd, err := unix.Socket( unix.AF_INET, @@ -247,7 +253,7 @@ func (tun *NativeTun) MTU() (int, error) { // do ioctl call var ifr [ifReqSize]byte - copy(ifr[:], tun.name) + copy(ifr[:], name) _, _, errno := unix.Syscall( unix.SYS_IOCTL, uintptr(fd), @@ -262,6 +268,15 @@ func (tun *NativeTun) MTU() (int, error) { } func (tun *NativeTun) Name() (string, error) { + tun.nameOnce.Do(tun.initNameCache) + return tun.nameCache, tun.nameErr +} + +func (tun *NativeTun) initNameCache() { + tun.nameCache, tun.nameErr = tun.nameSlow() +} + +func (tun *NativeTun) nameSlow() (string, error) { sysconn, err := tun.tunFile.SyscallConn() if err != nil { return "", err @@ -282,13 +297,11 @@ func (tun *NativeTun) Name() (string, error) { if errno != 0 { return "", errors.New("failed to get name of TUN device: " + errno.Error()) } - nullStr := ifr[:] - i := bytes.IndexByte(nullStr, 0) - if i != -1 { - nullStr = nullStr[:i] + name := ifr[:] + if i := bytes.IndexByte(name, 0); i != -1 { + name = name[:i] } - tun.name = string(nullStr) - return tun.name, nil + return string(name), nil } func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { @@ -408,16 +421,15 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { statusListenersShutdown: make(chan struct{}), nopi: false, } - var err error - _, err = tun.Name() + name, err := tun.Name() if err != nil { return nil, err } // start event listener - tun.index, err = getIFIndex(tun.name) + tun.index, err = getIFIndex(name) if err != nil { return nil, err } diff --git a/tun/tuntest/tuntest.go b/tun/tuntest/tuntest.go new file mode 100644 index 000000000..bdd96acd9 --- /dev/null +++ b/tun/tuntest/tuntest.go @@ -0,0 +1,150 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package tuntest + +import ( + "encoding/binary" + "io" + "net" + "os" + + "golang.zx2c4.com/wireguard/tun" +) + +func Ping(dst, src net.IP) []byte { + localPort := uint16(1337) + seq := uint16(0) + + payload := make([]byte, 4) + binary.BigEndian.PutUint16(payload[0:], localPort) + binary.BigEndian.PutUint16(payload[2:], seq) + + return genICMPv4(payload, dst, src) +} + +// Checksum is the "internet checksum" from https://tools.ietf.org/html/rfc1071. +func checksum(buf []byte, initial uint16) uint16 { + v := uint32(initial) + for i := 0; i < len(buf)-1; i += 2 { + v += uint32(binary.BigEndian.Uint16(buf[i:])) + } + if len(buf)%2 == 1 { + v += uint32(buf[len(buf)-1]) << 8 + } + for v > 0xffff { + v = (v >> 16) + (v & 0xffff) + } + return ^uint16(v) +} + +func genICMPv4(payload []byte, dst, src net.IP) []byte { + const ( + icmpv4ProtocolNumber = 1 + icmpv4Echo = 8 + icmpv4ChecksumOffset = 2 + icmpv4Size = 8 + ipv4Size = 20 + ipv4TotalLenOffset = 2 + ipv4ChecksumOffset = 10 + ttl = 65 + ) + + hdr := make([]byte, ipv4Size+icmpv4Size) + + ip := hdr[0:ipv4Size] + icmpv4 := hdr[ipv4Size : ipv4Size+icmpv4Size] + + // https://tools.ietf.org/html/rfc792 + icmpv4[0] = icmpv4Echo // type + icmpv4[1] = 0 // code + chksum := ^checksum(icmpv4, checksum(payload, 0)) + binary.BigEndian.PutUint16(icmpv4[icmpv4ChecksumOffset:], chksum) + + // https://tools.ietf.org/html/rfc760 section 3.1 + length := uint16(len(hdr) + len(payload)) + ip[0] = (4 << 4) | (ipv4Size / 4) + binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length) + ip[8] = ttl + ip[9] = icmpv4ProtocolNumber + copy(ip[12:], src.To4()) + copy(ip[16:], dst.To4()) + chksum = ^checksum(ip[:], 0) + binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum) + + var v []byte + v = append(v, hdr...) + v = append(v, payload...) + return []byte(v) +} + +// TODO(crawshaw): find a reusable home for this. package devicetest? +type ChannelTUN struct { + Inbound chan []byte // incoming packets, closed on TUN close + Outbound chan []byte // outbound packets, blocks forever on TUN close + + closed chan struct{} + events chan tun.Event + tun chTun +} + +func NewChannelTUN() *ChannelTUN { + c := &ChannelTUN{ + Inbound: make(chan []byte), + Outbound: make(chan []byte), + closed: make(chan struct{}), + events: make(chan tun.Event, 1), + } + c.tun.c = c + c.events <- tun.EventUp + return c +} + +func (c *ChannelTUN) TUN() tun.Device { + return &c.tun +} + +type chTun struct { + c *ChannelTUN +} + +func (t *chTun) File() *os.File { return nil } + +func (t *chTun) Read(data []byte, offset int) (int, error) { + select { + case <-t.c.closed: + return 0, io.EOF // TODO(crawshaw): what is the correct error value? + case msg := <-t.c.Outbound: + return copy(data[offset:], msg), nil + } +} + +// Write is called by the wireguard device to deliver a packet for routing. +func (t *chTun) Write(data []byte, offset int) (int, error) { + if offset == -1 { + close(t.c.closed) + close(t.c.events) + return 0, io.EOF + } + msg := make([]byte, len(data)-offset) + copy(msg, data[offset:]) + select { + case <-t.c.closed: + return 0, io.EOF // TODO(crawshaw): what is the correct error value? + case t.c.Inbound <- msg: + return len(data) - offset, nil + } +} + +const DefaultMTU = 1420 + +func (t *chTun) Flush() error { return nil } +func (t *chTun) MTU() (int, error) { return DefaultMTU, nil } +func (t *chTun) Name() (string, error) { return "loopbackTun1", nil } +func (t *chTun) Events() chan tun.Event { return t.c.events } +func (t *chTun) Close() error { + t.Write(nil, -1) + return nil +} diff --git a/tun/wintun/namespace_windows.go b/tun/wintun/namespace_windows.go index f4316fe84..5f8a04138 100644 --- a/tun/wintun/namespace_windows.go +++ b/tun/wintun/namespace_windows.go @@ -59,9 +59,12 @@ func initializeNamespace() error { if err == windows.ERROR_PATH_NOT_FOUND { continue } + if err != nil { + return fmt.Errorf("OpenPrivateNamespace failed: %v", err) + } } if err != nil { - return fmt.Errorf("Create/OpenPrivateNamespace failed: %v", err) + return fmt.Errorf("CreatePrivateNamespace failed: %v", err) } break } diff --git a/wgcfg/config.go b/wgcfg/config.go new file mode 100644 index 000000000..2b5e7148d --- /dev/null +++ b/wgcfg/config.go @@ -0,0 +1,78 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +// Package wgcfg has types and a parser for representing WireGuard config. +package wgcfg + +import ( + "fmt" + "strings" +) + +// Config is a wireguard configuration. +type Config struct { + Name string + PrivateKey PrivateKey + Addresses []CIDR + ListenPort uint16 + MTU uint16 + DNS []IP + Peers []Peer +} + +type Peer struct { + PublicKey Key + PresharedKey SymmetricKey + AllowedIPs []CIDR + Endpoints []Endpoint + PersistentKeepalive uint16 +} + +type Endpoint struct { + Host string + Port uint16 +} + +func (e *Endpoint) String() string { + if strings.IndexByte(e.Host, ':') > 0 { + return fmt.Sprintf("[%s]:%d", e.Host, e.Port) + } + return fmt.Sprintf("%s:%d", e.Host, e.Port) +} + +func (e *Endpoint) IsEmpty() bool { + return len(e.Host) == 0 +} + +// Copy makes a deep copy of Config. +// The result aliases no memory with the original. +func (cfg Config) Copy() Config { + res := cfg + if res.Addresses != nil { + res.Addresses = append([]CIDR{}, res.Addresses...) + } + if res.DNS != nil { + res.DNS = append([]IP{}, res.DNS...) + } + peers := make([]Peer, 0, len(res.Peers)) + for _, peer := range res.Peers { + peers = append(peers, peer.Copy()) + } + res.Peers = peers + return res +} + +// Copy makes a deep copy of Peer. +// The result aliases no memory with the original. +func (peer Peer) Copy() Peer { + res := peer + if res.AllowedIPs != nil { + res.AllowedIPs = append([]CIDR{}, res.AllowedIPs...) + } + if res.Endpoints != nil { + res.Endpoints = append([]Endpoint{}, res.Endpoints...) + } + return res +} diff --git a/wgcfg/ip.go b/wgcfg/ip.go new file mode 100644 index 000000000..47fa91c27 --- /dev/null +++ b/wgcfg/ip.go @@ -0,0 +1,152 @@ +package wgcfg + +import ( + "fmt" + "math" + "net" +) + +// IP is an IPv4 or an IPv6 address. +// +// Internally the address is always represented in its IPv6 form. +// IPv4 addresses use the IPv4-in-IPv6 syntax. +type IP struct { + Addr [16]byte +} + +func (ip IP) String() string { return net.IP(ip.Addr[:]).String() } + +// IP converts ip into a standard library net.IP. +func (ip IP) IP() net.IP { return net.IP(ip.Addr[:]) } + +// Is6 reports whether ip is an IPv6 address. +func (ip IP) Is6() bool { return !ip.Is4() } + +// Is4 reports whether ip is an IPv4 address. +func (ip IP) Is4() bool { + return ip.Addr[0] == 0 && ip.Addr[1] == 0 && + ip.Addr[2] == 0 && ip.Addr[3] == 0 && + ip.Addr[4] == 0 && ip.Addr[5] == 0 && + ip.Addr[6] == 0 && ip.Addr[7] == 0 && + ip.Addr[8] == 0 && ip.Addr[9] == 0 && + ip.Addr[10] == 0xff && ip.Addr[11] == 0xff +} + +// To4 returns either a 4 byte slice for an IPv4 address, or nil if +// it's not IPv4. +func (ip IP) To4() []byte { + if ip.Is4() { + return ip.Addr[12:16] + } else { + return nil + } +} + +// Equal reports whether ip == x. +func (ip IP) Equal(x IP) bool { + return ip == x +} + +func (ip IP) MarshalText() ([]byte, error) { + return []byte(ip.String()), nil +} + +func (ip *IP) UnmarshalText(text []byte) error { + parsedIP, ok := ParseIP(string(text)) + if !ok { + return fmt.Errorf("wgcfg.IP: UnmarshalText: bad IP address %q", text) + } + *ip = parsedIP + return nil +} + +func IPv4(b0, b1, b2, b3 byte) (ip IP) { + ip.Addr[10], ip.Addr[11] = 0xff, 0xff // IPv4-in-IPv6 prefix + ip.Addr[12] = b0 + ip.Addr[13] = b1 + ip.Addr[14] = b2 + ip.Addr[15] = b3 + return ip +} + +// ParseIP parses the string representation of an address into an IP. +// +// It accepts IPv4 notation such as "1.2.3.4" and IPv6 notation like ""::0". +// The ok result reports whether s was a valid IP and ip is valid. +func ParseIP(s string) (ip IP, ok bool) { + netIP := net.ParseIP(s) + if netIP == nil { + return IP{}, false + } + copy(ip.Addr[:], netIP.To16()) + return ip, true +} + +// CIDR is a compact IP address and subnet mask. +type CIDR struct { + IP IP + Mask uint8 // 0-32 for IsIPv4, 4-128 for IsIPv6 +} + +// ParseCIDR parses CIDR notation into a CIDR type. +// Typical CIDR strings look like "192.168.1.0/24". +func ParseCIDR(s string) (CIDR, error) { + netIP, netAddr, err := net.ParseCIDR(s) + if err != nil { + return CIDR{}, err + } + var cidr CIDR + copy(cidr.IP.Addr[:], netIP.To16()) + ones, _ := netAddr.Mask.Size() + cidr.Mask = uint8(ones) + + return cidr, nil +} + +func (r CIDR) String() string { return r.IPNet().String() } + +func (r CIDR) IPNet() *net.IPNet { + bits := 128 + if r.IP.Is4() { + bits = 32 + } + return &net.IPNet{IP: r.IP.IP(), Mask: net.CIDRMask(int(r.Mask), bits)} +} + +func (r CIDR) Contains(ip IP) bool { + c := int8(r.Mask) + i := 0 + if r.IP.Is4() { + i = 12 + if ip.Is6() { + return false + } + } + for ; i < 16 && c > 0; i++ { + var x uint8 + if c < 8 { + x = 8 - uint8(c) + } + m := uint8(math.MaxUint8) >> x << x + a := r.IP.Addr[i] & m + b := ip.Addr[i] & m + if a != b { + return false + } + c -= 8 + } + return true +} + +func (r CIDR) MarshalText() ([]byte, error) { + return []byte(r.String()), nil +} + +func (r *CIDR) UnmarshalText(text []byte) error { + cidr, err := ParseCIDR(string(text)) + if err != nil { + return fmt.Errorf("wgcfg.CIDR: UnmarshalText: %v", err) + } + *r = cidr + return nil +} diff --git a/wgcfg/ip_test.go b/wgcfg/ip_test.go new file mode 100644 index 000000000..6cd41d319 --- /dev/null +++ b/wgcfg/ip_test.go @@ -0,0 +1,106 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package wgcfg_test + +import ( + "testing" + + "golang.zx2c4.com/wireguard/wgcfg" +) + +func parseIP(t testing.TB, ipStr string) wgcfg.IP { + t.Helper() + ip, ok := wgcfg.ParseIP(ipStr) + if !ok { + t.Fatalf("failed to parse IP: %q", ipStr) + } + return ip +} + +func TestCIDRContains(t *testing.T) { + t.Run("home router test", func(t *testing.T) { + r, err := wgcfg.ParseCIDR("192.168.0.0/24") + if err != nil { + t.Fatal(err) + } + ip := parseIP(t, "192.168.0.1") + if !r.Contains(ip) { + t.Fatalf("%q should contain %q", r, ip) + } + }) + + t.Run("IPv4 outside network", func(t *testing.T) { + r, err := wgcfg.ParseCIDR("192.168.0.0/30") + if err != nil { + t.Fatal(err) + } + ip := parseIP(t, "192.168.0.4") + if r.Contains(ip) { + t.Fatalf("%q should not contain %q", r, ip) + } + }) + + t.Run("IPv4 does not contain IPv6", func(t *testing.T) { + r, err := wgcfg.ParseCIDR("192.168.0.0/24") + if err != nil { + t.Fatal(err) + } + ip := parseIP(t, "2001:db8:85a3:0:0:8a2e:370:7334") + if r.Contains(ip) { + t.Fatalf("%q should not contain %q", r, ip) + } + }) + + t.Run("IPv6 inside network", func(t *testing.T) { + r, err := wgcfg.ParseCIDR("2001:db8:1234::/48") + if err != nil { + t.Fatal(err) + } + ip := parseIP(t, "2001:db8:1234:0000:0000:0000:0000:0001") + if !r.Contains(ip) { + t.Fatalf("%q should not contain %q", r, ip) + } + }) + + t.Run("IPv6 outside network", func(t *testing.T) { + r, err := wgcfg.ParseCIDR("2001:db8:1234:0:190b:0:1982::/126") + if err != nil { + t.Fatal(err) + } + ip := parseIP(t, "2001:db8:1234:0:190b:0:1982:4") + if r.Contains(ip) { + t.Fatalf("%q should not contain %q", r, ip) + } + }) +} + +func BenchmarkCIDRContainsIPv4(b *testing.B) { + b.Run("IPv4", func(b *testing.B) { + r, err := wgcfg.ParseCIDR("192.168.1.0/24") + if err != nil { + b.Fatal(err) + } + ip := parseIP(b, "1.2.3.4") + b.ResetTimer() + + for i := 0; i < b.N; i++ { + r.Contains(ip) + } + }) + + b.Run("IPv6", func(b *testing.B) { + r, err := wgcfg.ParseCIDR("2001:db8:1234::/48") + if err != nil { + b.Fatal(err) + } + ip := parseIP(b, "2001:db8:1234:0000:0000:0000:0000:0001") + b.ResetTimer() + + for i := 0; i < b.N; i++ { + r.Contains(ip) + } + }) +} diff --git a/wgcfg/key.go b/wgcfg/key.go new file mode 100644 index 000000000..cdbbeea70 --- /dev/null +++ b/wgcfg/key.go @@ -0,0 +1,239 @@ +package wgcfg + +import ( + "bytes" + "crypto/rand" + "crypto/subtle" + "encoding/base64" + "encoding/hex" + "errors" + "fmt" + "strings" + + "golang.org/x/crypto/chacha20poly1305" + "golang.org/x/crypto/curve25519" +) + +const KeySize = 32 + +// Key is curve25519 key. +// It is used by WireGuard to represent public and preshared keys. +type Key [KeySize]byte + +// NewPresharedKey generates a new random key. +func NewPresharedKey() (*Key, error) { + var k [KeySize]byte + _, err := rand.Read(k[:]) + if err != nil { + return nil, err + } + return (*Key)(&k), nil +} + +func ParseKey(b64 string) (*Key, error) { return parseKeyBase64(base64.StdEncoding, b64) } + +func ParseHexKey(s string) (Key, error) { + b, err := hex.DecodeString(s) + if err != nil { + return Key{}, &ParseError{"invalid hex key: " + err.Error(), s} + } + if len(b) != KeySize { + return Key{}, &ParseError{fmt.Sprintf("invalid hex key length: %d", len(b)), s} + } + + var key Key + copy(key[:], b) + return key, nil +} + +func ParsePrivateHexKey(v string) (PrivateKey, error) { + k, err := ParseHexKey(v) + if err != nil { + return PrivateKey{}, err + } + pk := PrivateKey(k) + if pk.IsZero() { + // Do not clamp a zero key, pass the zero through + // (much like NaN propagation) so that IsZero reports + // a useful result. + return pk, nil + } + pk.clamp() + return pk, nil +} + +func (k Key) Base64() string { return base64.StdEncoding.EncodeToString(k[:]) } +func (k Key) String() string { return "pub:" + k.Base64()[:8] } +func (k Key) HexString() string { return hex.EncodeToString(k[:]) } +func (k Key) Equal(k2 Key) bool { return subtle.ConstantTimeCompare(k[:], k2[:]) == 1 } + +func (k *Key) ShortString() string { + if k.IsZero() { + return "[empty]" + } + long := k.String() + if len(long) < 10 { + return "invalid" + } + return "[" + long[0:4] + "ā€¦" + long[len(long)-5:len(long)-1] + "]" +} + +func (k *Key) IsZero() bool { + if k == nil { + return true + } + var zeros Key + return subtle.ConstantTimeCompare(zeros[:], k[:]) == 1 +} + +func (k *Key) MarshalJSON() ([]byte, error) { + if k == nil { + return []byte("null"), nil + } + buf := new(bytes.Buffer) + fmt.Fprintf(buf, `"%x"`, k[:]) + return buf.Bytes(), nil +} + +func (k *Key) UnmarshalJSON(b []byte) error { + if k == nil { + return errors.New("wgcfg.Key: UnmarshalJSON on nil pointer") + } + if len(b) < 3 || b[0] != '"' || b[len(b)-1] != '"' { + return errors.New("wgcfg.Key: UnmarshalJSON not given a string") + } + b = b[1 : len(b)-1] + key, err := ParseHexKey(string(b)) + if err != nil { + return fmt.Errorf("wgcfg.Key: UnmarshalJSON: %v", err) + } + copy(k[:], key[:]) + return nil +} + +func (a *Key) LessThan(b *Key) bool { + for i := range a { + if a[i] < b[i] { + return true + } else if a[i] > b[i] { + return false + } + } + return false +} + +// PrivateKey is curve25519 key. +// It is used by WireGuard to represent private keys. +type PrivateKey [KeySize]byte + +// NewPrivateKey generates a new curve25519 secret key. +// It conforms to the format described on https://cr.yp.to/ecdh.html. +func NewPrivateKey() (PrivateKey, error) { + k, err := NewPresharedKey() + if err != nil { + return PrivateKey{}, err + } + k[0] &= 248 + k[31] = (k[31] & 127) | 64 + return (PrivateKey)(*k), nil +} + +func ParsePrivateKey(b64 string) (*PrivateKey, error) { + k, err := parseKeyBase64(base64.StdEncoding, b64) + return (*PrivateKey)(k), err +} + +func (k *PrivateKey) String() string { return base64.StdEncoding.EncodeToString(k[:]) } +func (k *PrivateKey) HexString() string { return hex.EncodeToString(k[:]) } +func (k *PrivateKey) Equal(k2 PrivateKey) bool { return subtle.ConstantTimeCompare(k[:], k2[:]) == 1 } + +func (k *PrivateKey) IsZero() bool { + pk := Key(*k) + return pk.IsZero() +} + +func (k *PrivateKey) clamp() { + k[0] &= 248 + k[31] = (k[31] & 127) | 64 +} + +// Public computes the public key matching this curve25519 secret key. +func (k *PrivateKey) Public() Key { + pk := Key(*k) + if pk.IsZero() { + panic("Tried to generate emptyPrivateKey.Public()") + } + var p [KeySize]byte + curve25519.ScalarBaseMult(&p, (*[KeySize]byte)(k)) + return (Key)(p) +} + +func (k PrivateKey) MarshalText() ([]byte, error) { + buf := new(bytes.Buffer) + fmt.Fprintf(buf, `privkey:%x`, k[:]) + return buf.Bytes(), nil +} + +func (k *PrivateKey) UnmarshalText(b []byte) error { + s := string(b) + if !strings.HasPrefix(s, `privkey:`) { + return errors.New("wgcfg.PrivateKey: UnmarshalText not given a private-key string") + } + s = strings.TrimPrefix(s, `privkey:`) + key, err := ParseHexKey(s) + if err != nil { + return fmt.Errorf("wgcfg.PrivateKey: UnmarshalText: %v", err) + } + copy(k[:], key[:]) + return nil +} + +func (k PrivateKey) SharedSecret(pub Key) (ss [KeySize]byte) { + apk := (*[KeySize]byte)(&pub) + ask := (*[KeySize]byte)(&k) + curve25519.ScalarMult(&ss, ask, apk) + return ss +} + +func parseKeyBase64(enc *base64.Encoding, s string) (*Key, error) { + k, err := enc.DecodeString(s) + if err != nil { + return nil, &ParseError{"Invalid key: " + err.Error(), s} + } + if len(k) != KeySize { + return nil, &ParseError{"Keys must decode to exactly 32 bytes", s} + } + var key Key + copy(key[:], k) + return &key, nil +} + +func ParseSymmetricKey(b64 string) (SymmetricKey, error) { + k, err := parseKeyBase64(base64.StdEncoding, b64) + if err != nil { + return SymmetricKey{}, err + } + return SymmetricKey(*k), nil +} + +func ParseSymmetricHexKey(s string) (SymmetricKey, error) { + b, err := hex.DecodeString(s) + if err != nil { + return SymmetricKey{}, &ParseError{"invalid symmetric hex key: " + err.Error(), s} + } + if len(b) != chacha20poly1305.KeySize { + return SymmetricKey{}, &ParseError{fmt.Sprintf("invalid symmetric hex key length: %d", len(b)), s} + } + var key SymmetricKey + copy(key[:], b) + return key, nil +} + +// SymmetricKey is a 32-byte value used as a pre-shared key. +type SymmetricKey [chacha20poly1305.KeySize]byte + +func (k SymmetricKey) Base64() string { return base64.StdEncoding.EncodeToString(k[:]) } +func (k SymmetricKey) String() string { return "sym:" + k.Base64()[:8] } +func (k SymmetricKey) HexString() string { return hex.EncodeToString(k[:]) } +func (k SymmetricKey) IsZero() bool { return k.Equal(SymmetricKey{}) } +func (k SymmetricKey) Equal(k2 SymmetricKey) bool { return subtle.ConstantTimeCompare(k[:], k2[:]) == 1 } diff --git a/wgcfg/key_test.go b/wgcfg/key_test.go new file mode 100644 index 000000000..0b82d5fcd --- /dev/null +++ b/wgcfg/key_test.go @@ -0,0 +1,107 @@ +package wgcfg + +import ( + "bytes" + "testing" +) + +func TestKeyBasics(t *testing.T) { + k1, err := NewPresharedKey() + if err != nil { + t.Fatal(err) + } + + b, err := k1.MarshalJSON() + if err != nil { + t.Fatal(err) + } + + t.Run("JSON round-trip", func(t *testing.T) { + // should preserve the keys + k2 := new(Key) + if err := k2.UnmarshalJSON(b); err != nil { + t.Fatal(err) + } + if !bytes.Equal(k1[:], k2[:]) { + t.Fatalf("k1 %v != k2 %v", k1[:], k2[:]) + } + if b1, b2 := k1.String(), k2.String(); b1 != b2 { + t.Fatalf("base64-encoded keys do not match: %s, %s", b1, b2) + } + }) + + t.Run("JSON incompatible with PrivateKey", func(t *testing.T) { + k2 := new(PrivateKey) + if err := k2.UnmarshalText(b); err == nil { + t.Fatalf("successfully decoded key as private key") + } + }) + + t.Run("second key", func(t *testing.T) { + // A second call to NewPresharedKey should make a new key. + k3, err := NewPresharedKey() + if err != nil { + t.Fatal(err) + } + if bytes.Equal(k1[:], k3[:]) { + t.Fatalf("k1 %v == k3 %v", k1[:], k3[:]) + } + // Check for obvious comparables to make sure we are not generating bad strings somewhere. + if b1, b2 := k1.String(), k3.String(); b1 == b2 { + t.Fatalf("base64-encoded keys match: %s, %s", b1, b2) + } + }) +} +func TestPrivateKeyBasics(t *testing.T) { + pri, err := NewPrivateKey() + if err != nil { + t.Fatal(err) + } + + b, err := pri.MarshalText() + if err != nil { + t.Fatal(err) + } + + t.Run("JSON round-trip", func(t *testing.T) { + // should preserve the keys + pri2 := new(PrivateKey) + if err := pri2.UnmarshalText(b); err != nil { + t.Fatal(err) + } + if !bytes.Equal(pri[:], pri2[:]) { + t.Fatalf("pri %v != pri2 %v", pri[:], pri2[:]) + } + if b1, b2 := pri.String(), pri2.String(); b1 != b2 { + t.Fatalf("base64-encoded keys do not match: %s, %s", b1, b2) + } + if pub1, pub2 := pri.Public().String(), pri2.Public().String(); pub1 != pub2 { + t.Fatalf("base64-encoded public keys do not match: %s, %s", pub1, pub2) + } + }) + + t.Run("JSON incompatible with Key", func(t *testing.T) { + k2 := new(Key) + if err := k2.UnmarshalJSON(b); err == nil { + t.Fatalf("successfully decoded private key as key") + } + }) + + t.Run("second key", func(t *testing.T) { + // A second call to New should make a new key. + pri3, err := NewPrivateKey() + if err != nil { + t.Fatal(err) + } + if bytes.Equal(pri[:], pri3[:]) { + t.Fatalf("pri %v == pri3 %v", pri[:], pri3[:]) + } + // Check for obvious comparables to make sure we are not generating bad strings somewhere. + if b1, b2 := pri.String(), pri3.String(); b1 == b2 { + t.Fatalf("base64-encoded keys match: %s, %s", b1, b2) + } + if pub1, pub2 := pri.Public().String(), pri3.Public().String(); pub1 == pub2 { + t.Fatalf("base64-encoded public keys match: %s, %s", pub1, pub2) + } + }) +} diff --git a/wgcfg/name.go b/wgcfg/name.go new file mode 100644 index 000000000..28bc0f08e --- /dev/null +++ b/wgcfg/name.go @@ -0,0 +1,49 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package wgcfg + +import ( + "regexp" + "strings" +) + +var reservedNames = []string{ + "CON", "PRN", "AUX", "NUL", + "COM1", "COM2", "COM3", "COM4", "COM5", "COM6", "COM7", "COM8", "COM9", + "LPT1", "LPT2", "LPT3", "LPT4", "LPT5", "LPT6", "LPT7", "LPT8", "LPT9", +} + +const specialChars = "/\\<>:\"|?*\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x00" + +var allowedNameFormat *regexp.Regexp + +func init() { + allowedNameFormat = regexp.MustCompile("^[a-zA-Z0-9_=+.-]{1,32}$") +} + +func isReserved(name string) bool { + if len(name) == 0 { + return false + } + for _, reserved := range reservedNames { + if strings.EqualFold(name, reserved) { + return true + } + } + return false +} + +func hasSpecialChars(name string) bool { + return strings.ContainsAny(name, specialChars) +} + +func TunnelNameIsValid(name string) bool { + // Aside from our own restrictions, let's impose the Windows restrictions first + if isReserved(name) || hasSpecialChars(name) { + return false + } + return allowedNameFormat.MatchString(name) +} diff --git a/wgcfg/parser.go b/wgcfg/parser.go new file mode 100644 index 000000000..e71d32b1f --- /dev/null +++ b/wgcfg/parser.go @@ -0,0 +1,397 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package wgcfg + +import ( + "encoding/hex" + "fmt" + "net" + "strconv" + "strings" +) + +type ParseError struct { + why string + offender string +} + +func (e *ParseError) Error() string { + return fmt.Sprintf("%s: ā€˜%sā€™", e.why, e.offender) +} + +func parseEndpoints(s string) ([]Endpoint, error) { + var eps []Endpoint + vals := strings.Split(s, ",") + for _, val := range vals { + e, err := parseEndpoint(val) + if err != nil { + return nil, err + } + eps = append(eps, *e) + } + return eps, nil +} + +func parseEndpoint(s string) (*Endpoint, error) { + i := strings.LastIndexByte(s, ':') + if i < 0 { + return nil, &ParseError{"Missing port from endpoint", s} + } + host, portStr := s[:i], s[i+1:] + if len(host) < 1 { + return nil, &ParseError{"Invalid endpoint host", host} + } + port, err := parsePort(portStr) + if err != nil { + return nil, err + } + hostColon := strings.IndexByte(host, ':') + if host[0] == '[' || host[len(host)-1] == ']' || hostColon > 0 { + err := &ParseError{"Brackets must contain an IPv6 address", host} + if len(host) > 3 && host[0] == '[' && host[len(host)-1] == ']' && hostColon > 0 { + maybeV6 := net.ParseIP(host[1 : len(host)-1]) + if maybeV6 == nil || len(maybeV6) != net.IPv6len { + return nil, err + } + } else { + return nil, err + } + host = host[1 : len(host)-1] + } + return &Endpoint{host, uint16(port)}, nil +} + +func parseMTU(s string) (uint16, error) { + m, err := strconv.Atoi(s) + if err != nil { + return 0, err + } + if m < 576 || m > 65535 { + return 0, &ParseError{"Invalid MTU", s} + } + return uint16(m), nil +} + +func parsePort(s string) (uint16, error) { + m, err := strconv.Atoi(s) + if err != nil { + return 0, err + } + if m < 0 || m > 65535 { + return 0, &ParseError{"Invalid port", s} + } + return uint16(m), nil +} + +func parsePersistentKeepalive(s string) (uint16, error) { + if s == "off" { + return 0, nil + } + m, err := strconv.Atoi(s) + if err != nil { + return 0, err + } + if m < 0 || m > 65535 { + return 0, &ParseError{"Invalid persistent keepalive", s} + } + return uint16(m), nil +} + +func parseKeyHex(s string) (*Key, error) { + k, err := hex.DecodeString(s) + if err != nil { + return nil, &ParseError{"Invalid key: " + err.Error(), s} + } + if len(k) != KeySize { + return nil, &ParseError{"Keys must decode to exactly 32 bytes", s} + } + var key Key + copy(key[:], k) + return &key, nil +} + +func parseBytesOrStamp(s string) (uint64, error) { + b, err := strconv.ParseUint(s, 10, 64) + if err != nil { + return 0, &ParseError{"Number must be a number between 0 and 2^64-1: " + err.Error(), s} + } + return b, nil +} + +func splitList(s string) ([]string, error) { + var out []string + for _, split := range strings.Split(s, ",") { + trim := strings.TrimSpace(split) + if len(trim) == 0 { + return nil, &ParseError{"Two commas in a row", s} + } + out = append(out, trim) + } + return out, nil +} + +type parserState int + +const ( + inInterfaceSection parserState = iota + inPeerSection + notInASection +) + +func (c *Config) maybeAddPeer(p *Peer) { + if p != nil { + c.Peers = append(c.Peers, *p) + } +} + +func FromWgQuick(s string, name string) (*Config, error) { + if !TunnelNameIsValid(name) { + return nil, &ParseError{"Tunnel name is not valid", name} + } + lines := strings.Split(s, "\n") + parserState := notInASection + conf := Config{Name: name} + sawPrivateKey := false + var peer *Peer + for _, line := range lines { + pound := strings.IndexByte(line, '#') + if pound >= 0 { + line = line[:pound] + } + line = strings.TrimSpace(line) + lineLower := strings.ToLower(line) + if len(line) == 0 { + continue + } + if lineLower == "[interface]" { + conf.maybeAddPeer(peer) + parserState = inInterfaceSection + continue + } + if lineLower == "[peer]" { + conf.maybeAddPeer(peer) + peer = &Peer{} + parserState = inPeerSection + continue + } + if parserState == notInASection { + return nil, &ParseError{"Line must occur in a section", line} + } + equals := strings.IndexByte(line, '=') + if equals < 0 { + return nil, &ParseError{"Invalid config key is missing an equals separator", line} + } + key, val := strings.TrimSpace(lineLower[:equals]), strings.TrimSpace(line[equals+1:]) + if len(val) == 0 { + return nil, &ParseError{"Key must have a value", line} + } + if parserState == inInterfaceSection { + switch key { + case "privatekey": + k, err := ParseKey(val) + if err != nil { + return nil, err + } + conf.PrivateKey = PrivateKey(*k) + sawPrivateKey = true + case "listenport": + p, err := parsePort(val) + if err != nil { + return nil, err + } + conf.ListenPort = p + case "mtu": + m, err := parseMTU(val) + if err != nil { + return nil, err + } + conf.MTU = m + case "address": + addresses, err := splitList(val) + if err != nil { + return nil, err + } + for _, address := range addresses { + a, err := ParseCIDR(address) + if err != nil { + return nil, err + } + conf.Addresses = append(conf.Addresses, a) + } + case "dns": + addresses, err := splitList(val) + if err != nil { + return nil, err + } + for _, address := range addresses { + a, ok := ParseIP(address) + if !ok { + return nil, &ParseError{"Invalid IP address", address} + } + conf.DNS = append(conf.DNS, a) + } + default: + return nil, &ParseError{"Invalid key for [Interface] section", key} + } + } else if parserState == inPeerSection { + switch key { + case "publickey": + k, err := ParseKey(val) + if err != nil { + return nil, err + } + peer.PublicKey = *k + case "presharedkey": + k, err := ParseKey(val) + if err != nil { + return nil, err + } + peer.PresharedKey = SymmetricKey(*k) + case "allowedips": + addresses, err := splitList(val) + if err != nil { + return nil, err + } + for _, address := range addresses { + a, err := ParseCIDR(address) + if err != nil { + return nil, err + } + peer.AllowedIPs = append(peer.AllowedIPs, a) + } + case "persistentkeepalive": + p, err := parsePersistentKeepalive(val) + if err != nil { + return nil, err + } + peer.PersistentKeepalive = p + case "endpoint": + eps, err := parseEndpoints(val) + if err != nil { + return nil, err + } + peer.Endpoints = eps + default: + return nil, &ParseError{"Invalid key for [Peer] section", key} + } + } + } + conf.maybeAddPeer(peer) + + if !sawPrivateKey { + return nil, &ParseError{"An interface must have a private key", "[none specified]"} + } + for _, p := range conf.Peers { + if p.PublicKey.IsZero() { + return nil, &ParseError{"All peers must have public keys", "[none specified]"} + } + } + + return &conf, nil +} + +// TODO(apenwarr): This is incompatibe with current Device.IpcSetOperation. +// It duplicates all the parser stuff in there, but is missing some +// keywords. Nothing useful seems to need it anymore. +func Broken_FromUAPI(s string, existingConfig *Config) (*Config, error) { + lines := strings.Split(s, "\n") + parserState := inInterfaceSection + conf := Config{ + Name: existingConfig.Name, + Addresses: existingConfig.Addresses, + DNS: existingConfig.DNS, + MTU: existingConfig.MTU, + } + var peer *Peer + for _, line := range lines { + if len(line) == 0 { + continue + } + equals := strings.IndexByte(line, '=') + if equals < 0 { + return nil, &ParseError{"Invalid config key is missing an equals separator", line} + } + key, val := line[:equals], line[equals+1:] + if len(val) == 0 { + return nil, &ParseError{"Key must have a value", line} + } + switch key { + case "public_key": + conf.maybeAddPeer(peer) + peer = &Peer{} + parserState = inPeerSection + case "errno": + if val == "0" { + continue + } else { + return nil, &ParseError{"Error in getting configuration", val} + } + } + if parserState == inInterfaceSection { + switch key { + case "private_key": + k, err := parseKeyHex(val) + if err != nil { + return nil, err + } + conf.PrivateKey = PrivateKey(*k) + case "listen_port": + p, err := parsePort(val) + if err != nil { + return nil, err + } + conf.ListenPort = p + case "fwmark": + // Ignored for now. + + default: + return nil, &ParseError{"Invalid key for interface section", key} + } + } else if parserState == inPeerSection { + switch key { + case "public_key": + k, err := parseKeyHex(val) + if err != nil { + return nil, err + } + peer.PublicKey = *k + case "preshared_key": + k, err := parseKeyHex(val) + if err != nil { + return nil, err + } + peer.PresharedKey = SymmetricKey(*k) + case "protocol_version": + if val != "1" { + return nil, &ParseError{"Protocol version must be 1", val} + } + case "allowed_ip": + a, err := ParseCIDR(val) + if err != nil { + return nil, err + } + peer.AllowedIPs = append(peer.AllowedIPs, a) + case "persistent_keepalive_interval": + p, err := parsePersistentKeepalive(val) + if err != nil { + return nil, err + } + peer.PersistentKeepalive = p + case "endpoint": + eps, err := parseEndpoints(val) + if err != nil { + return nil, err + } + peer.Endpoints = eps + default: + return nil, &ParseError{"Invalid key for peer section", key} + } + } + } + conf.maybeAddPeer(peer) + + return &conf, nil +} diff --git a/wgcfg/parser_test.go b/wgcfg/parser_test.go new file mode 100644 index 000000000..d0df53739 --- /dev/null +++ b/wgcfg/parser_test.go @@ -0,0 +1,127 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package wgcfg + +import ( + "reflect" + "runtime" + "testing" +) + +const testInput = ` +[Interface] +Address = 10.192.122.1/24 +Address = 10.10.0.1/16 +PrivateKey = yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk= +ListenPort = 51820 #comments don't matter + +[Peer] +PublicKey = xTIBA5rboUvnH4htodjb6e697QjLERt1NAB4mZqp8Dg= +Endpoint = 192.95.5.67:1234 +AllowedIPs = 10.192.122.3/32, 10.192.124.1/24 + +[Peer] +PublicKey = TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0= +Endpoint = [2607:5300:60:6b0::c05f:543]:2468 +AllowedIPs = 10.192.122.4/32, 192.168.0.0/16 +PersistentKeepalive = 100 + +[Peer] +PublicKey = gN65BkIKy1eCE9pP1wdc8ROUtkHLF2PfAqYdyYBz6EA= +PresharedKey = TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0= +Endpoint = test.wireguard.com:18981 +AllowedIPs = 10.10.10.230/32` + +func noError(t *testing.T, err error) bool { + if err == nil { + return true + } + _, fn, line, _ := runtime.Caller(1) + t.Errorf("Error at %s:%d: %#v", fn, line, err) + return false +} + +func equal(t *testing.T, expected, actual interface{}) bool { + if reflect.DeepEqual(expected, actual) { + return true + } + _, fn, line, _ := runtime.Caller(1) + t.Errorf("Failed equals at %s:%d\nactual %#v\nexpected %#v", fn, line, actual, expected) + return false +} +func lenTest(t *testing.T, actualO interface{}, expected int) bool { + actual := reflect.ValueOf(actualO).Len() + if reflect.DeepEqual(expected, actual) { + return true + } + _, fn, line, _ := runtime.Caller(1) + t.Errorf("Wrong length at %s:%d\nactual %#v\nexpected %#v", fn, line, actual, expected) + return false +} +func contains(t *testing.T, list, element interface{}) bool { + listValue := reflect.ValueOf(list) + for i := 0; i < listValue.Len(); i++ { + if reflect.DeepEqual(listValue.Index(i).Interface(), element) { + return true + } + } + _, fn, line, _ := runtime.Caller(1) + t.Errorf("Error %s:%d\nelement not found: %#v", fn, line, element) + return false +} + +func TestFromWgQuick(t *testing.T) { + conf, err := FromWgQuick(testInput, "test") + if noError(t, err) { + + lenTest(t, conf.Addresses, 2) + contains(t, conf.Addresses, CIDR{IPv4(10, 10, 0, 1), 16}) + contains(t, conf.Addresses, CIDR{IPv4(10, 192, 122, 1), 24}) + equal(t, "yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk=", conf.PrivateKey.String()) + equal(t, uint16(51820), conf.ListenPort) + + lenTest(t, conf.Peers, 3) + lenTest(t, conf.Peers[0].AllowedIPs, 2) + equal(t, Endpoint{Host: "192.95.5.67", Port: 1234}, conf.Peers[0].Endpoints[0]) + equal(t, "xTIBA5rboUvnH4htodjb6e697QjLERt1NAB4mZqp8Dg=", conf.Peers[0].PublicKey.Base64()) + + lenTest(t, conf.Peers[1].AllowedIPs, 2) + equal(t, Endpoint{Host: "2607:5300:60:6b0::c05f:543", Port: 2468}, conf.Peers[1].Endpoints[0]) + equal(t, "TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0=", conf.Peers[1].PublicKey.Base64()) + equal(t, uint16(100), conf.Peers[1].PersistentKeepalive) + + lenTest(t, conf.Peers[2].AllowedIPs, 1) + equal(t, Endpoint{Host: "test.wireguard.com", Port: 18981}, conf.Peers[2].Endpoints[0]) + equal(t, "gN65BkIKy1eCE9pP1wdc8ROUtkHLF2PfAqYdyYBz6EA=", conf.Peers[2].PublicKey.Base64()) + equal(t, "TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0=", conf.Peers[2].PresharedKey.Base64()) + } +} + +func TestParseEndpoint(t *testing.T) { + _, err := parseEndpoint("[192.168.42.0:]:51880") + if err == nil { + t.Error("Error was expected") + } + e, err := parseEndpoint("192.168.42.0:51880") + if noError(t, err) { + equal(t, "192.168.42.0", e.Host) + equal(t, uint16(51880), e.Port) + } + e, err = parseEndpoint("test.wireguard.com:18981") + if noError(t, err) { + equal(t, "test.wireguard.com", e.Host) + equal(t, uint16(18981), e.Port) + } + e, err = parseEndpoint("[2607:5300:60:6b0::c05f:543]:2468") + if noError(t, err) { + equal(t, "2607:5300:60:6b0::c05f:543", e.Host) + equal(t, uint16(2468), e.Port) + } + _, err = parseEndpoint("[::::::invalid:18981") + if err == nil { + t.Error("Error was expected") + } +} diff --git a/wgcfg/writer.go b/wgcfg/writer.go new file mode 100644 index 000000000..246a57d0b --- /dev/null +++ b/wgcfg/writer.go @@ -0,0 +1,73 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package wgcfg + +import ( + "fmt" + "net" + "strings" +) + +func (conf *Config) ToUAPI() (string, error) { + output := new(strings.Builder) + fmt.Fprintf(output, "private_key=%s\n", conf.PrivateKey.HexString()) + + if conf.ListenPort > 0 { + fmt.Fprintf(output, "listen_port=%d\n", conf.ListenPort) + } + + output.WriteString("replace_peers=true\n") + + for _, peer := range conf.Peers { + fmt.Fprintf(output, "public_key=%s\n", peer.PublicKey.HexString()) + fmt.Fprintf(output, "protocol_version=1\n") + fmt.Fprintf(output, "replace_allowed_ips=true\n") + + if !peer.PresharedKey.IsZero() { + fmt.Fprintf(output, "preshared_key = %s\n", peer.PresharedKey.String()) + } + + if len(peer.AllowedIPs) > 0 { + for _, address := range peer.AllowedIPs { + fmt.Fprintf(output, "allowed_ip=%s\n", address.String()) + } + } + + if len(peer.Endpoints) > 0 { + var reps []string + for _, ep := range peer.Endpoints { + ips, err := net.LookupIP(ep.Host) + if err != nil { + return "", err + } + var ip net.IP + for _, iterip := range ips { + if ip4 := iterip.To4(); ip4 != nil { + ip = ip4 + break + } + if ip == nil { + ip = iterip + } + } + if ip == nil { + return "", fmt.Errorf("unable to resolve IP address of endpoint %q (%v)", ep.Host, ips) + } + resolvedEndpoint := Endpoint{ip.String(), ep.Port} + reps = append(reps, resolvedEndpoint.String()) + } + fmt.Fprintf(output, "endpoint=%s\n", strings.Join(reps, ",")) + } else { + fmt.Fprint(output, "endpoint=\n") + } + + // Note: this needs to come *after* endpoint definitions, + // because setting it will trigger a handshake to all + // already-defined endpoints. + fmt.Fprintf(output, "persistent_keepalive_interval=%d\n", peer.PersistentKeepalive) + } + return output.String(), nil +}