diff --git a/tun/checksum.go b/tun/checksum.go index 29a8fc8fc..2c3a701d8 100644 --- a/tun/checksum.go +++ b/tun/checksum.go @@ -1,102 +1,85 @@ package tun -import "encoding/binary" +import ( + "encoding/binary" + "math/bits" +) // TODO: Explore SIMD and/or other assembly optimizations. -// TODO: Test native endian loads. See RFC 1071 section 2 part B. func checksumNoFold(b []byte, initial uint64) uint64 { - ac := initial + tmp := make([]byte, 8) + binary.NativeEndian.PutUint64(tmp, initial) + ac := binary.BigEndian.Uint64(tmp) + var carry uint64 for len(b) >= 128 { - ac += uint64(binary.BigEndian.Uint32(b[:4])) - ac += uint64(binary.BigEndian.Uint32(b[4:8])) - ac += uint64(binary.BigEndian.Uint32(b[8:12])) - ac += uint64(binary.BigEndian.Uint32(b[12:16])) - ac += uint64(binary.BigEndian.Uint32(b[16:20])) - ac += uint64(binary.BigEndian.Uint32(b[20:24])) - ac += uint64(binary.BigEndian.Uint32(b[24:28])) - ac += uint64(binary.BigEndian.Uint32(b[28:32])) - ac += uint64(binary.BigEndian.Uint32(b[32:36])) - ac += uint64(binary.BigEndian.Uint32(b[36:40])) - ac += uint64(binary.BigEndian.Uint32(b[40:44])) - ac += uint64(binary.BigEndian.Uint32(b[44:48])) - ac += uint64(binary.BigEndian.Uint32(b[48:52])) - ac += uint64(binary.BigEndian.Uint32(b[52:56])) - ac += uint64(binary.BigEndian.Uint32(b[56:60])) - ac += uint64(binary.BigEndian.Uint32(b[60:64])) - ac += uint64(binary.BigEndian.Uint32(b[64:68])) - ac += uint64(binary.BigEndian.Uint32(b[68:72])) - ac += uint64(binary.BigEndian.Uint32(b[72:76])) - ac += uint64(binary.BigEndian.Uint32(b[76:80])) - ac += uint64(binary.BigEndian.Uint32(b[80:84])) - ac += uint64(binary.BigEndian.Uint32(b[84:88])) - ac += uint64(binary.BigEndian.Uint32(b[88:92])) - ac += uint64(binary.BigEndian.Uint32(b[92:96])) - ac += uint64(binary.BigEndian.Uint32(b[96:100])) - ac += uint64(binary.BigEndian.Uint32(b[100:104])) - ac += uint64(binary.BigEndian.Uint32(b[104:108])) - ac += uint64(binary.BigEndian.Uint32(b[108:112])) - ac += uint64(binary.BigEndian.Uint32(b[112:116])) - ac += uint64(binary.BigEndian.Uint32(b[116:120])) - ac += uint64(binary.BigEndian.Uint32(b[120:124])) - ac += uint64(binary.BigEndian.Uint32(b[124:128])) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[16:24]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[24:32]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[32:40]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[40:48]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[48:56]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[56:64]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[64:72]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[72:80]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[80:88]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[88:96]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[96:104]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[104:112]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[112:120]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[120:128]), carry) + ac += carry b = b[128:] } if len(b) >= 64 { - ac += uint64(binary.BigEndian.Uint32(b[:4])) - ac += uint64(binary.BigEndian.Uint32(b[4:8])) - ac += uint64(binary.BigEndian.Uint32(b[8:12])) - ac += uint64(binary.BigEndian.Uint32(b[12:16])) - ac += uint64(binary.BigEndian.Uint32(b[16:20])) - ac += uint64(binary.BigEndian.Uint32(b[20:24])) - ac += uint64(binary.BigEndian.Uint32(b[24:28])) - ac += uint64(binary.BigEndian.Uint32(b[28:32])) - ac += uint64(binary.BigEndian.Uint32(b[32:36])) - ac += uint64(binary.BigEndian.Uint32(b[36:40])) - ac += uint64(binary.BigEndian.Uint32(b[40:44])) - ac += uint64(binary.BigEndian.Uint32(b[44:48])) - ac += uint64(binary.BigEndian.Uint32(b[48:52])) - ac += uint64(binary.BigEndian.Uint32(b[52:56])) - ac += uint64(binary.BigEndian.Uint32(b[56:60])) - ac += uint64(binary.BigEndian.Uint32(b[60:64])) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[16:24]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[24:32]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[32:40]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[40:48]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[48:56]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[56:64]), carry) + ac += carry b = b[64:] } if len(b) >= 32 { - ac += uint64(binary.BigEndian.Uint32(b[:4])) - ac += uint64(binary.BigEndian.Uint32(b[4:8])) - ac += uint64(binary.BigEndian.Uint32(b[8:12])) - ac += uint64(binary.BigEndian.Uint32(b[12:16])) - ac += uint64(binary.BigEndian.Uint32(b[16:20])) - ac += uint64(binary.BigEndian.Uint32(b[20:24])) - ac += uint64(binary.BigEndian.Uint32(b[24:28])) - ac += uint64(binary.BigEndian.Uint32(b[28:32])) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[16:24]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[24:32]), carry) + ac += carry b = b[32:] } if len(b) >= 16 { - ac += uint64(binary.BigEndian.Uint32(b[:4])) - ac += uint64(binary.BigEndian.Uint32(b[4:8])) - ac += uint64(binary.BigEndian.Uint32(b[8:12])) - ac += uint64(binary.BigEndian.Uint32(b[12:16])) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry) + ac += carry b = b[16:] } if len(b) >= 8 { - ac += uint64(binary.BigEndian.Uint32(b[:4])) - ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0) + ac += carry b = b[8:] } if len(b) >= 4 { - ac += uint64(binary.BigEndian.Uint32(b)) + ac, carry = bits.Add64(ac, uint64(binary.NativeEndian.Uint32(b[:4])), 0) + ac += carry b = b[4:] } if len(b) >= 2 { - ac += uint64(binary.BigEndian.Uint16(b)) + ac, carry = bits.Add64(ac, uint64(binary.NativeEndian.Uint16(b[:2])), 0) + ac += carry b = b[2:] } if len(b) == 1 { - ac += uint64(b[0]) << 8 + ac, carry = bits.Add64(ac, uint64(b[0]), 0) + ac += carry } - return ac + binary.NativeEndian.PutUint64(tmp, ac) + return binary.BigEndian.Uint64(tmp) } func checksum(b []byte, initial uint64) uint16 { diff --git a/tun/checksum_test.go b/tun/checksum_test.go index c1ccff531..202c83475 100644 --- a/tun/checksum_test.go +++ b/tun/checksum_test.go @@ -1,26 +1,196 @@ package tun import ( + "encoding/binary" "fmt" "math/rand" "testing" + + "golang.org/x/sys/unix" ) -func BenchmarkChecksum(b *testing.B) { - lengths := []int{ - 64, - 128, - 256, - 512, - 1024, - 1500, - 2048, - 4096, - 8192, - 9000, - 9001, +func checksumRef(b []byte, initial uint64) uint16 { + ac := initial + + for len(b) >= 2 { + ac += uint64(binary.BigEndian.Uint16(b)) + b = b[2:] + } + if len(b) == 1 { + ac += uint64(b[0]) << 8 + } + + for (ac >> 16) > 0 { + ac = (ac >> 16) + (ac & 0xffff) + } + return uint16(ac) +} + +func checksumOldNoFold(b []byte, initial uint64) uint64 { + ac := initial + + for len(b) >= 128 { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac += uint64(binary.BigEndian.Uint32(b[8:12])) + ac += uint64(binary.BigEndian.Uint32(b[12:16])) + ac += uint64(binary.BigEndian.Uint32(b[16:20])) + ac += uint64(binary.BigEndian.Uint32(b[20:24])) + ac += uint64(binary.BigEndian.Uint32(b[24:28])) + ac += uint64(binary.BigEndian.Uint32(b[28:32])) + ac += uint64(binary.BigEndian.Uint32(b[32:36])) + ac += uint64(binary.BigEndian.Uint32(b[36:40])) + ac += uint64(binary.BigEndian.Uint32(b[40:44])) + ac += uint64(binary.BigEndian.Uint32(b[44:48])) + ac += uint64(binary.BigEndian.Uint32(b[48:52])) + ac += uint64(binary.BigEndian.Uint32(b[52:56])) + ac += uint64(binary.BigEndian.Uint32(b[56:60])) + ac += uint64(binary.BigEndian.Uint32(b[60:64])) + ac += uint64(binary.BigEndian.Uint32(b[64:68])) + ac += uint64(binary.BigEndian.Uint32(b[68:72])) + ac += uint64(binary.BigEndian.Uint32(b[72:76])) + ac += uint64(binary.BigEndian.Uint32(b[76:80])) + ac += uint64(binary.BigEndian.Uint32(b[80:84])) + ac += uint64(binary.BigEndian.Uint32(b[84:88])) + ac += uint64(binary.BigEndian.Uint32(b[88:92])) + ac += uint64(binary.BigEndian.Uint32(b[92:96])) + ac += uint64(binary.BigEndian.Uint32(b[96:100])) + ac += uint64(binary.BigEndian.Uint32(b[100:104])) + ac += uint64(binary.BigEndian.Uint32(b[104:108])) + ac += uint64(binary.BigEndian.Uint32(b[108:112])) + ac += uint64(binary.BigEndian.Uint32(b[112:116])) + ac += uint64(binary.BigEndian.Uint32(b[116:120])) + ac += uint64(binary.BigEndian.Uint32(b[120:124])) + ac += uint64(binary.BigEndian.Uint32(b[124:128])) + b = b[128:] + } + if len(b) >= 64 { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac += uint64(binary.BigEndian.Uint32(b[8:12])) + ac += uint64(binary.BigEndian.Uint32(b[12:16])) + ac += uint64(binary.BigEndian.Uint32(b[16:20])) + ac += uint64(binary.BigEndian.Uint32(b[20:24])) + ac += uint64(binary.BigEndian.Uint32(b[24:28])) + ac += uint64(binary.BigEndian.Uint32(b[28:32])) + ac += uint64(binary.BigEndian.Uint32(b[32:36])) + ac += uint64(binary.BigEndian.Uint32(b[36:40])) + ac += uint64(binary.BigEndian.Uint32(b[40:44])) + ac += uint64(binary.BigEndian.Uint32(b[44:48])) + ac += uint64(binary.BigEndian.Uint32(b[48:52])) + ac += uint64(binary.BigEndian.Uint32(b[52:56])) + ac += uint64(binary.BigEndian.Uint32(b[56:60])) + ac += uint64(binary.BigEndian.Uint32(b[60:64])) + b = b[64:] + } + if len(b) >= 32 { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac += uint64(binary.BigEndian.Uint32(b[8:12])) + ac += uint64(binary.BigEndian.Uint32(b[12:16])) + ac += uint64(binary.BigEndian.Uint32(b[16:20])) + ac += uint64(binary.BigEndian.Uint32(b[20:24])) + ac += uint64(binary.BigEndian.Uint32(b[24:28])) + ac += uint64(binary.BigEndian.Uint32(b[28:32])) + b = b[32:] + } + if len(b) >= 16 { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac += uint64(binary.BigEndian.Uint32(b[8:12])) + ac += uint64(binary.BigEndian.Uint32(b[12:16])) + b = b[16:] + } + if len(b) >= 8 { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + b = b[8:] + } + if len(b) >= 4 { + ac += uint64(binary.BigEndian.Uint32(b)) + b = b[4:] + } + if len(b) >= 2 { + ac += uint64(binary.BigEndian.Uint16(b)) + b = b[2:] + } + if len(b) == 1 { + ac += uint64(b[0]) << 8 + } + + return ac +} + +func checksumOld(b []byte, initial uint64) uint16 { + ac := checksumOldNoFold(b, initial) + ac = (ac >> 16) + (ac & 0xffff) + ac = (ac >> 16) + (ac & 0xffff) + ac = (ac >> 16) + (ac & 0xffff) + ac = (ac >> 16) + (ac & 0xffff) + return uint16(ac) +} + +func pseudoHeaderChecksumOldNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 { + sum := checksumOldNoFold(srcAddr, 0) + sum = checksumOldNoFold(dstAddr, sum) + sum = checksumOldNoFold([]byte{0, protocol}, sum) + tmp := make([]byte, 2) + binary.BigEndian.PutUint16(tmp, totalLen) + return checksumOldNoFold(tmp, sum) +} + +func TestChecksum(t *testing.T) { + for length := 0; length <= 9001; length++ { + buf := make([]byte, length) + rng := rand.New(rand.NewSource(1)) + rng.Read(buf) + csum := checksum(buf, 0x1234) + csumRef := checksumRef(buf, 0x1234) + csumOld := checksumOld(buf, 0x1234) + if csum != csumRef { + t.Error("Expected checksum", csumRef, "got", csum) + } else if csum != csumOld { + t.Error("Expected checksumOld", csumOld, "got", csum) + } + } +} + +func TestPseudoHeaderChecksum(t *testing.T) { + for _, addrLen := range []int{4, 16} { + for length := 0; length <= 9001; length++ { + srcAddr := make([]byte, addrLen) + dstAddr := make([]byte, addrLen) + buf := make([]byte, length) + rng := rand.New(rand.NewSource(1)) + rng.Read(srcAddr) + rng.Read(dstAddr) + rng.Read(buf) + phSum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(length)) + csum := checksum(buf, phSum) + phSumOld := pseudoHeaderChecksumOldNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(length)) + csumOld := checksumOld(buf, phSumOld) + if csum != csumOld { + t.Error("Expected checksumOld", csumOld, "got", csum) + } + } } +} +var lengths = [...]int{ + 64, + 128, + 256, + 512, + 1024, + 1500, + 2048, + 4096, + 8192, + 9000, + 9001, +} + +func BenchmarkChecksum(b *testing.B) { for _, length := range lengths { b.Run(fmt.Sprintf("%d", length), func(b *testing.B) { buf := make([]byte, length) @@ -33,3 +203,17 @@ func BenchmarkChecksum(b *testing.B) { }) } } + +func BenchmarkChecksumOld(b *testing.B) { + for _, length := range lengths { + b.Run(fmt.Sprintf("%d", length), func(b *testing.B) { + buf := make([]byte, length) + rng := rand.New(rand.NewSource(1)) + rng.Read(buf) + b.ResetTimer() + for i := 0; i < b.N; i++ { + checksumOld(buf, 0) + } + }) + } +}