diff --git a/tun/tcp_offload_linux.go b/tun/tcp_offload_linux.go index e807f0077..4912efd3f 100644 --- a/tun/tcp_offload_linux.go +++ b/tun/tcp_offload_linux.go @@ -397,9 +397,6 @@ func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) if totalLen != len(pkt) { return false } - if iphLen < 20 || iphLen > 60 { - return false - } } if len(pkt) < iphLen { return false @@ -474,13 +471,16 @@ func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) return false } -func isTCP4(b []byte) bool { +func isTCP4NoIPOptions(b []byte) bool { if len(b) < 40 { return false } if b[0]>>4 != 4 { return false } + if b[0]&0x0F != 5 { + return false + } if b[9] != unix.IPPROTO_TCP { return false } @@ -511,7 +511,7 @@ func handleGRO(bufs [][]byte, offset int, tcp4Table, tcp6Table *tcpGROTable, toW } var coalesced bool switch { - case isTCP4(bufs[i][offset:]): + case isTCP4NoIPOptions(bufs[i][offset:]): // ipv4 packets w/IP options do not coalesce coalesced = tcpGRO(bufs, offset, i, tcp4Table, false) case isTCP6NoEH(bufs[i][offset:]): // ipv6 packets w/extension headers do not coalesce coalesced = tcpGRO(bufs, offset, i, tcp6Table, true) diff --git a/tun/tcp_offload_linux_test.go b/tun/tcp_offload_linux_test.go index 11f9e53b5..046e177e0 100644 --- a/tun/tcp_offload_linux_test.go +++ b/tun/tcp_offload_linux_test.go @@ -271,3 +271,53 @@ func Test_handleGRO(t *testing.T) { }) } } + +func Test_isTCP4NoIPOptions(t *testing.T) { + valid := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:] + invalidLen := valid[:39] + invalidHeaderLen := make([]byte, len(valid)) + copy(invalidHeaderLen, valid) + invalidHeaderLen[0] = 0x46 + invalidProtocol := make([]byte, len(valid)) + copy(invalidProtocol, valid) + invalidProtocol[9] = unix.IPPROTO_TCP + 1 + + tests := []struct { + name string + b []byte + want bool + }{ + { + "valid", + valid, + true, + }, + { + "invalid length", + invalidLen, + false, + }, + { + "invalid version", + []byte{0x00}, + false, + }, + { + "invalid header len", + invalidHeaderLen, + false, + }, + { + "invalid protocol", + invalidProtocol, + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := isTCP4NoIPOptions(tt.b); got != tt.want { + t.Errorf("isTCP4NoIPOptions() = %v, want %v", got, tt.want) + } + }) + } +}