Skip to content

Commit

Permalink
tuntest: split out testing package
Browse files Browse the repository at this point in the history
This code is useful to other packages writing tests.

Signed-off-by: David Crawshaw <[email protected]>
  • Loading branch information
crawshaw committed Apr 1, 2020
1 parent d127a16 commit e852f4c
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 141 deletions.
146 changes: 5 additions & 141 deletions device/device_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,12 @@ package device
import (
"bufio"
"bytes"
"encoding/binary"
"io"
"net"
"os"
"strings"
"testing"
"time"

"golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/tun/tuntest"
)

func TestTwoDevicePing(t *testing.T) {
Expand All @@ -29,7 +26,7 @@ protocol_version=1
replace_allowed_ips=true
allowed_ip=1.0.0.2/32
endpoint=127.0.0.1:53512`
tun1 := NewChannelTUN()
tun1 := tuntest.NewChannelTUN()
dev1 := NewDevice(tun1.TUN(), NewLogger(LogLevelDebug, "dev1: "))
dev1.Up()
defer dev1.Close()
Expand All @@ -45,7 +42,7 @@ protocol_version=1
replace_allowed_ips=true
allowed_ip=1.0.0.1/32
endpoint=127.0.0.1:53511`
tun2 := NewChannelTUN()
tun2 := tuntest.NewChannelTUN()
dev2 := NewDevice(tun2.TUN(), NewLogger(LogLevelDebug, "dev2: "))
dev2.Up()
defer dev2.Close()
Expand All @@ -54,7 +51,7 @@ endpoint=127.0.0.1:53511`
}

t.Run("ping 1.0.0.1", func(t *testing.T) {
msg2to1 := ping(net.ParseIP("1.0.0.1"), net.ParseIP("1.0.0.2"))
msg2to1 := tuntest.Ping(net.ParseIP("1.0.0.1"), net.ParseIP("1.0.0.2"))
tun2.Outbound <- msg2to1
select {
case msgRecv := <-tun1.Inbound:
Expand All @@ -67,7 +64,7 @@ endpoint=127.0.0.1:53511`
})

t.Run("ping 1.0.0.2", func(t *testing.T) {
msg1to2 := ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1"))
msg1to2 := tuntest.Ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1"))
tun1.Outbound <- msg1to2
select {
case msgRecv := <-tun2.Inbound:
Expand All @@ -80,139 +77,6 @@ endpoint=127.0.0.1:53511`
})
}

func ping(dst, src net.IP) []byte {
localPort := uint16(1337)
seq := uint16(0)

payload := make([]byte, 4)
binary.BigEndian.PutUint16(payload[0:], localPort)
binary.BigEndian.PutUint16(payload[2:], seq)

return genICMPv4(payload, dst, src)
}

// checksum is the "internet checksum" from https://tools.ietf.org/html/rfc1071.
func checksum(buf []byte, initial uint16) uint16 {
v := uint32(initial)
for i := 0; i < len(buf)-1; i += 2 {
v += uint32(binary.BigEndian.Uint16(buf[i:]))
}
if len(buf)%2 == 1 {
v += uint32(buf[len(buf)-1]) << 8
}
for v > 0xffff {
v = (v >> 16) + (v & 0xffff)
}
return ^uint16(v)
}

func genICMPv4(payload []byte, dst, src net.IP) []byte {
const (
icmpv4ProtocolNumber = 1
icmpv4Echo = 8
icmpv4ChecksumOffset = 2
icmpv4Size = 8
ipv4Size = 20
ipv4TotalLenOffset = 2
ipv4ChecksumOffset = 10
ttl = 65
)

hdr := make([]byte, ipv4Size+icmpv4Size)

ip := hdr[0:ipv4Size]
icmpv4 := hdr[ipv4Size : ipv4Size+icmpv4Size]

// https://tools.ietf.org/html/rfc792
icmpv4[0] = icmpv4Echo // type
icmpv4[1] = 0 // code
chksum := ^checksum(icmpv4, checksum(payload, 0))
binary.BigEndian.PutUint16(icmpv4[icmpv4ChecksumOffset:], chksum)

// https://tools.ietf.org/html/rfc760 section 3.1
length := uint16(len(hdr) + len(payload))
ip[0] = (4 << 4) | (ipv4Size / 4)
binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length)
ip[8] = ttl
ip[9] = icmpv4ProtocolNumber
copy(ip[12:], src.To4())
copy(ip[16:], dst.To4())
chksum = ^checksum(ip[:], 0)
binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum)

var v []byte
v = append(v, hdr...)
v = append(v, payload...)
return []byte(v)
}

// TODO(crawshaw): find a reusable home for this. package devicetest?
type ChannelTUN struct {
Inbound chan []byte // incoming packets, closed on TUN close
Outbound chan []byte // outbound packets, blocks forever on TUN close

closed chan struct{}
events chan tun.Event
tun chTun
}

func NewChannelTUN() *ChannelTUN {
c := &ChannelTUN{
Inbound: make(chan []byte),
Outbound: make(chan []byte),
closed: make(chan struct{}),
events: make(chan tun.Event, 1),
}
c.tun.c = c
c.events <- tun.EventUp
return c
}

func (c *ChannelTUN) TUN() tun.Device {
return &c.tun
}

type chTun struct {
c *ChannelTUN
}

func (t *chTun) File() *os.File { return nil }

func (t *chTun) Read(data []byte, offset int) (int, error) {
select {
case <-t.c.closed:
return 0, io.EOF // TODO(crawshaw): what is the correct error value?
case msg := <-t.c.Outbound:
return copy(data[offset:], msg), nil
}
}

// Write is called by the wireguard device to deliver a packet for routing.
func (t *chTun) Write(data []byte, offset int) (int, error) {
if offset == -1 {
close(t.c.closed)
close(t.c.events)
return 0, io.EOF
}
msg := make([]byte, len(data)-offset)
copy(msg, data[offset:])
select {
case <-t.c.closed:
return 0, io.EOF // TODO(crawshaw): what is the correct error value?
case t.c.Inbound <- msg:
return len(data) - offset, nil
}
}

func (t *chTun) Flush() error { return nil }
func (t *chTun) MTU() (int, error) { return DefaultMTU, nil }
func (t *chTun) Name() (string, error) { return "loopbackTun1", nil }
func (t *chTun) Events() chan tun.Event { return t.c.events }
func (t *chTun) Close() error {
t.Write(nil, -1)
return nil
}

func assertNil(t *testing.T, err error) {
if err != nil {
t.Fatal(err)
Expand Down
150 changes: 150 additions & 0 deletions tun/tuntest/tuntest.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/

package tuntest

import (
"encoding/binary"
"io"
"net"
"os"

"golang.zx2c4.com/wireguard/tun"
)

func Ping(dst, src net.IP) []byte {
localPort := uint16(1337)
seq := uint16(0)

payload := make([]byte, 4)
binary.BigEndian.PutUint16(payload[0:], localPort)
binary.BigEndian.PutUint16(payload[2:], seq)

return genICMPv4(payload, dst, src)
}

// Checksum is the "internet checksum" from https://tools.ietf.org/html/rfc1071.
func checksum(buf []byte, initial uint16) uint16 {
v := uint32(initial)
for i := 0; i < len(buf)-1; i += 2 {
v += uint32(binary.BigEndian.Uint16(buf[i:]))
}
if len(buf)%2 == 1 {
v += uint32(buf[len(buf)-1]) << 8
}
for v > 0xffff {
v = (v >> 16) + (v & 0xffff)
}
return ^uint16(v)
}

func genICMPv4(payload []byte, dst, src net.IP) []byte {
const (
icmpv4ProtocolNumber = 1
icmpv4Echo = 8
icmpv4ChecksumOffset = 2
icmpv4Size = 8
ipv4Size = 20
ipv4TotalLenOffset = 2
ipv4ChecksumOffset = 10
ttl = 65
)

hdr := make([]byte, ipv4Size+icmpv4Size)

ip := hdr[0:ipv4Size]
icmpv4 := hdr[ipv4Size : ipv4Size+icmpv4Size]

// https://tools.ietf.org/html/rfc792
icmpv4[0] = icmpv4Echo // type
icmpv4[1] = 0 // code
chksum := ^checksum(icmpv4, checksum(payload, 0))
binary.BigEndian.PutUint16(icmpv4[icmpv4ChecksumOffset:], chksum)

// https://tools.ietf.org/html/rfc760 section 3.1
length := uint16(len(hdr) + len(payload))
ip[0] = (4 << 4) | (ipv4Size / 4)
binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length)
ip[8] = ttl
ip[9] = icmpv4ProtocolNumber
copy(ip[12:], src.To4())
copy(ip[16:], dst.To4())
chksum = ^checksum(ip[:], 0)
binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum)

var v []byte
v = append(v, hdr...)
v = append(v, payload...)
return []byte(v)
}

// TODO(crawshaw): find a reusable home for this. package devicetest?
type ChannelTUN struct {
Inbound chan []byte // incoming packets, closed on TUN close
Outbound chan []byte // outbound packets, blocks forever on TUN close

closed chan struct{}
events chan tun.Event
tun chTun
}

func NewChannelTUN() *ChannelTUN {
c := &ChannelTUN{
Inbound: make(chan []byte),
Outbound: make(chan []byte),
closed: make(chan struct{}),
events: make(chan tun.Event, 1),
}
c.tun.c = c
c.events <- tun.EventUp
return c
}

func (c *ChannelTUN) TUN() tun.Device {
return &c.tun
}

type chTun struct {
c *ChannelTUN
}

func (t *chTun) File() *os.File { return nil }

func (t *chTun) Read(data []byte, offset int) (int, error) {
select {
case <-t.c.closed:
return 0, io.EOF // TODO(crawshaw): what is the correct error value?
case msg := <-t.c.Outbound:
return copy(data[offset:], msg), nil
}
}

// Write is called by the wireguard device to deliver a packet for routing.
func (t *chTun) Write(data []byte, offset int) (int, error) {
if offset == -1 {
close(t.c.closed)
close(t.c.events)
return 0, io.EOF
}
msg := make([]byte, len(data)-offset)
copy(msg, data[offset:])
select {
case <-t.c.closed:
return 0, io.EOF // TODO(crawshaw): what is the correct error value?
case t.c.Inbound <- msg:
return len(data) - offset, nil
}
}

const DefaultMTU = 1420

func (t *chTun) Flush() error { return nil }
func (t *chTun) MTU() (int, error) { return DefaultMTU, nil }
func (t *chTun) Name() (string, error) { return "loopbackTun1", nil }
func (t *chTun) Events() chan tun.Event { return t.c.events }
func (t *chTun) Close() error {
t.Write(nil, -1)
return nil
}

0 comments on commit e852f4c

Please sign in to comment.