Skip to content

Commit

Permalink
tun: use add-with-carry in checksumNoFold()
Browse files Browse the repository at this point in the history
Signed-off-by: Tu Dinh Ngoc <[email protected]>
  • Loading branch information
Tu Dinh Ngoc committed Jul 13, 2024
1 parent 12269c2 commit 20d25be
Show file tree
Hide file tree
Showing 2 changed files with 249 additions and 82 deletions.
121 changes: 52 additions & 69 deletions tun/checksum.go
Original file line number Diff line number Diff line change
@@ -1,102 +1,85 @@
package tun

import "encoding/binary"
import (
"encoding/binary"
"math/bits"
)

// TODO: Explore SIMD and/or other assembly optimizations.
// TODO: Test native endian loads. See RFC 1071 section 2 part B.
func checksumNoFold(b []byte, initial uint64) uint64 {
ac := initial
tmp := make([]byte, 8)
binary.NativeEndian.PutUint64(tmp, initial)
ac := binary.BigEndian.Uint64(tmp)
var carry uint64

for len(b) >= 128 {
ac += uint64(binary.BigEndian.Uint32(b[:4]))
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
ac += uint64(binary.BigEndian.Uint32(b[16:20]))
ac += uint64(binary.BigEndian.Uint32(b[20:24]))
ac += uint64(binary.BigEndian.Uint32(b[24:28]))
ac += uint64(binary.BigEndian.Uint32(b[28:32]))
ac += uint64(binary.BigEndian.Uint32(b[32:36]))
ac += uint64(binary.BigEndian.Uint32(b[36:40]))
ac += uint64(binary.BigEndian.Uint32(b[40:44]))
ac += uint64(binary.BigEndian.Uint32(b[44:48]))
ac += uint64(binary.BigEndian.Uint32(b[48:52]))
ac += uint64(binary.BigEndian.Uint32(b[52:56]))
ac += uint64(binary.BigEndian.Uint32(b[56:60]))
ac += uint64(binary.BigEndian.Uint32(b[60:64]))
ac += uint64(binary.BigEndian.Uint32(b[64:68]))
ac += uint64(binary.BigEndian.Uint32(b[68:72]))
ac += uint64(binary.BigEndian.Uint32(b[72:76]))
ac += uint64(binary.BigEndian.Uint32(b[76:80]))
ac += uint64(binary.BigEndian.Uint32(b[80:84]))
ac += uint64(binary.BigEndian.Uint32(b[84:88]))
ac += uint64(binary.BigEndian.Uint32(b[88:92]))
ac += uint64(binary.BigEndian.Uint32(b[92:96]))
ac += uint64(binary.BigEndian.Uint32(b[96:100]))
ac += uint64(binary.BigEndian.Uint32(b[100:104]))
ac += uint64(binary.BigEndian.Uint32(b[104:108]))
ac += uint64(binary.BigEndian.Uint32(b[108:112]))
ac += uint64(binary.BigEndian.Uint32(b[112:116]))
ac += uint64(binary.BigEndian.Uint32(b[116:120]))
ac += uint64(binary.BigEndian.Uint32(b[120:124]))
ac += uint64(binary.BigEndian.Uint32(b[124:128]))
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[16:24]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[24:32]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[32:40]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[40:48]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[48:56]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[56:64]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[64:72]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[72:80]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[80:88]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[88:96]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[96:104]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[104:112]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[112:120]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[120:128]), carry)
ac += carry
b = b[128:]
}
if len(b) >= 64 {
ac += uint64(binary.BigEndian.Uint32(b[:4]))
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
ac += uint64(binary.BigEndian.Uint32(b[16:20]))
ac += uint64(binary.BigEndian.Uint32(b[20:24]))
ac += uint64(binary.BigEndian.Uint32(b[24:28]))
ac += uint64(binary.BigEndian.Uint32(b[28:32]))
ac += uint64(binary.BigEndian.Uint32(b[32:36]))
ac += uint64(binary.BigEndian.Uint32(b[36:40]))
ac += uint64(binary.BigEndian.Uint32(b[40:44]))
ac += uint64(binary.BigEndian.Uint32(b[44:48]))
ac += uint64(binary.BigEndian.Uint32(b[48:52]))
ac += uint64(binary.BigEndian.Uint32(b[52:56]))
ac += uint64(binary.BigEndian.Uint32(b[56:60]))
ac += uint64(binary.BigEndian.Uint32(b[60:64]))
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[16:24]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[24:32]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[32:40]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[40:48]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[48:56]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[56:64]), carry)
ac += carry
b = b[64:]
}
if len(b) >= 32 {
ac += uint64(binary.BigEndian.Uint32(b[:4]))
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
ac += uint64(binary.BigEndian.Uint32(b[16:20]))
ac += uint64(binary.BigEndian.Uint32(b[20:24]))
ac += uint64(binary.BigEndian.Uint32(b[24:28]))
ac += uint64(binary.BigEndian.Uint32(b[28:32]))
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[16:24]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[24:32]), carry)
ac += carry
b = b[32:]
}
if len(b) >= 16 {
ac += uint64(binary.BigEndian.Uint32(b[:4]))
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry)
ac += carry
b = b[16:]
}
if len(b) >= 8 {
ac += uint64(binary.BigEndian.Uint32(b[:4]))
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
ac += carry
b = b[8:]
}
if len(b) >= 4 {
ac += uint64(binary.BigEndian.Uint32(b))
ac, carry = bits.Add64(ac, uint64(binary.NativeEndian.Uint32(b[:4])), 0)
ac += carry
b = b[4:]
}
if len(b) >= 2 {
ac += uint64(binary.BigEndian.Uint16(b))
ac, carry = bits.Add64(ac, uint64(binary.NativeEndian.Uint16(b[:2])), 0)
ac += carry
b = b[2:]
}
if len(b) == 1 {
ac += uint64(b[0]) << 8
ac, carry = bits.Add64(ac, uint64(b[0]), 0)
ac += carry
}

return ac
binary.NativeEndian.PutUint64(tmp, ac)
return binary.BigEndian.Uint64(tmp)
}

func checksum(b []byte, initial uint64) uint16 {
Expand Down
210 changes: 197 additions & 13 deletions tun/checksum_test.go
Original file line number Diff line number Diff line change
@@ -1,26 +1,196 @@
package tun

import (
"encoding/binary"
"fmt"
"math/rand"
"testing"

"golang.org/x/sys/unix"
)

func BenchmarkChecksum(b *testing.B) {
lengths := []int{
64,
128,
256,
512,
1024,
1500,
2048,
4096,
8192,
9000,
9001,
func checksumRef(b []byte, initial uint64) uint16 {
ac := initial

for len(b) >= 2 {
ac += uint64(binary.BigEndian.Uint16(b))
b = b[2:]
}
if len(b) == 1 {
ac += uint64(b[0]) << 8
}

for (ac >> 16) > 0 {
ac = (ac >> 16) + (ac & 0xffff)
}
return uint16(ac)
}

func checksumOldNoFold(b []byte, initial uint64) uint64 {
ac := initial

for len(b) >= 128 {
ac += uint64(binary.BigEndian.Uint32(b[:4]))
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
ac += uint64(binary.BigEndian.Uint32(b[16:20]))
ac += uint64(binary.BigEndian.Uint32(b[20:24]))
ac += uint64(binary.BigEndian.Uint32(b[24:28]))
ac += uint64(binary.BigEndian.Uint32(b[28:32]))
ac += uint64(binary.BigEndian.Uint32(b[32:36]))
ac += uint64(binary.BigEndian.Uint32(b[36:40]))
ac += uint64(binary.BigEndian.Uint32(b[40:44]))
ac += uint64(binary.BigEndian.Uint32(b[44:48]))
ac += uint64(binary.BigEndian.Uint32(b[48:52]))
ac += uint64(binary.BigEndian.Uint32(b[52:56]))
ac += uint64(binary.BigEndian.Uint32(b[56:60]))
ac += uint64(binary.BigEndian.Uint32(b[60:64]))
ac += uint64(binary.BigEndian.Uint32(b[64:68]))
ac += uint64(binary.BigEndian.Uint32(b[68:72]))
ac += uint64(binary.BigEndian.Uint32(b[72:76]))
ac += uint64(binary.BigEndian.Uint32(b[76:80]))
ac += uint64(binary.BigEndian.Uint32(b[80:84]))
ac += uint64(binary.BigEndian.Uint32(b[84:88]))
ac += uint64(binary.BigEndian.Uint32(b[88:92]))
ac += uint64(binary.BigEndian.Uint32(b[92:96]))
ac += uint64(binary.BigEndian.Uint32(b[96:100]))
ac += uint64(binary.BigEndian.Uint32(b[100:104]))
ac += uint64(binary.BigEndian.Uint32(b[104:108]))
ac += uint64(binary.BigEndian.Uint32(b[108:112]))
ac += uint64(binary.BigEndian.Uint32(b[112:116]))
ac += uint64(binary.BigEndian.Uint32(b[116:120]))
ac += uint64(binary.BigEndian.Uint32(b[120:124]))
ac += uint64(binary.BigEndian.Uint32(b[124:128]))
b = b[128:]
}
if len(b) >= 64 {
ac += uint64(binary.BigEndian.Uint32(b[:4]))
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
ac += uint64(binary.BigEndian.Uint32(b[16:20]))
ac += uint64(binary.BigEndian.Uint32(b[20:24]))
ac += uint64(binary.BigEndian.Uint32(b[24:28]))
ac += uint64(binary.BigEndian.Uint32(b[28:32]))
ac += uint64(binary.BigEndian.Uint32(b[32:36]))
ac += uint64(binary.BigEndian.Uint32(b[36:40]))
ac += uint64(binary.BigEndian.Uint32(b[40:44]))
ac += uint64(binary.BigEndian.Uint32(b[44:48]))
ac += uint64(binary.BigEndian.Uint32(b[48:52]))
ac += uint64(binary.BigEndian.Uint32(b[52:56]))
ac += uint64(binary.BigEndian.Uint32(b[56:60]))
ac += uint64(binary.BigEndian.Uint32(b[60:64]))
b = b[64:]
}
if len(b) >= 32 {
ac += uint64(binary.BigEndian.Uint32(b[:4]))
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
ac += uint64(binary.BigEndian.Uint32(b[16:20]))
ac += uint64(binary.BigEndian.Uint32(b[20:24]))
ac += uint64(binary.BigEndian.Uint32(b[24:28]))
ac += uint64(binary.BigEndian.Uint32(b[28:32]))
b = b[32:]
}
if len(b) >= 16 {
ac += uint64(binary.BigEndian.Uint32(b[:4]))
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
b = b[16:]
}
if len(b) >= 8 {
ac += uint64(binary.BigEndian.Uint32(b[:4]))
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
b = b[8:]
}
if len(b) >= 4 {
ac += uint64(binary.BigEndian.Uint32(b))
b = b[4:]
}
if len(b) >= 2 {
ac += uint64(binary.BigEndian.Uint16(b))
b = b[2:]
}
if len(b) == 1 {
ac += uint64(b[0]) << 8
}

return ac
}

func checksumOld(b []byte, initial uint64) uint16 {
ac := checksumOldNoFold(b, initial)
ac = (ac >> 16) + (ac & 0xffff)
ac = (ac >> 16) + (ac & 0xffff)
ac = (ac >> 16) + (ac & 0xffff)
ac = (ac >> 16) + (ac & 0xffff)
return uint16(ac)
}

func pseudoHeaderChecksumOldNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 {
sum := checksumOldNoFold(srcAddr, 0)
sum = checksumOldNoFold(dstAddr, sum)
sum = checksumOldNoFold([]byte{0, protocol}, sum)
tmp := make([]byte, 2)
binary.BigEndian.PutUint16(tmp, totalLen)
return checksumOldNoFold(tmp, sum)
}

func TestChecksum(t *testing.T) {
for length := 0; length <= 9001; length++ {
buf := make([]byte, length)
rng := rand.New(rand.NewSource(1))
rng.Read(buf)
csum := checksum(buf, 0x1234)
csumRef := checksumRef(buf, 0x1234)
csumOld := checksumOld(buf, 0x1234)
if csum != csumRef {
t.Error("Expected checksum", csumRef, "got", csum)
} else if csum != csumOld {
t.Error("Expected checksumOld", csumOld, "got", csum)
}
}
}

func TestPseudoHeaderChecksum(t *testing.T) {
for _, addrLen := range []int{4, 16} {
for length := 0; length <= 9001; length++ {
srcAddr := make([]byte, addrLen)
dstAddr := make([]byte, addrLen)
buf := make([]byte, length)
rng := rand.New(rand.NewSource(1))
rng.Read(srcAddr)
rng.Read(dstAddr)
rng.Read(buf)
phSum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(length))
csum := checksum(buf, phSum)
phSumOld := pseudoHeaderChecksumOldNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(length))
csumOld := checksumOld(buf, phSumOld)
if csum != csumOld {
t.Error("Expected checksumOld", csumOld, "got", csum)
}
}
}
}

var lengths = [...]int{
64,
128,
256,
512,
1024,
1500,
2048,
4096,
8192,
9000,
9001,
}

func BenchmarkChecksum(b *testing.B) {
for _, length := range lengths {
b.Run(fmt.Sprintf("%d", length), func(b *testing.B) {
buf := make([]byte, length)
Expand All @@ -33,3 +203,17 @@ func BenchmarkChecksum(b *testing.B) {
})
}
}

func BenchmarkChecksumOld(b *testing.B) {
for _, length := range lengths {
b.Run(fmt.Sprintf("%d", length), func(b *testing.B) {
buf := make([]byte, length)
rng := rand.New(rand.NewSource(1))
rng.Read(buf)
b.ResetTimer()
for i := 0; i < b.N; i++ {
checksumOld(buf, 0)
}
})
}
}

0 comments on commit 20d25be

Please sign in to comment.