Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/overlapping tcp segments #1898

Merged
merged 3 commits into from
Jun 27, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ https://github.com/elastic/beats/compare/v5.0.0-alpha3...master[Check the HEAD d
*Packetbeat*
- Add missing nil-check to memcached GapInStream handler. {issue}1162[1162]
- Fix NFSv4 Operation returning the first found first-class operation available in compound requests. {pull}1821[1821]
- Fix TCP overlapping segments not being handled correctly. {pull}1898[1898]

*Topbeat*

Expand Down
89 changes: 66 additions & 23 deletions packetbeat/protos/tcp/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@ type Processor interface {
Process(flow *flows.FlowID, hdr *layers.TCP, pkt *protos.Packet)
}

type seqCompare int

const (
seqLT seqCompare = -1
seqEq seqCompare = 0
seqGT seqCompare = 1
)

var (
debugf = logp.MakeDebug("tcp")
isDebug = false
Expand Down Expand Up @@ -119,55 +127,78 @@ func (tcp *Tcp) Process(id *flows.FlowID, tcphdr *layers.TCP, pkt *protos.Packet
// protocol modules.
defer logp.Recover("Process tcp exception")

debugf("tcp flow id: %p", id)

stream, created := tcp.getStream(pkt)
if stream.conn == nil {
return
}

conn := stream.conn
if id != nil {
id.AddConnectionID(uint64(stream.conn.id))
id.AddConnectionID(uint64(conn.id))
}

if isDebug {
debugf("tcp flow id: %p", id)
}

if len(pkt.Payload) == 0 && !tcphdr.FIN {
// return early if packet is not interesting. Still need to find/create
// stream first in order to update the TCP stream timer
return
}
conn := stream.conn

tcp_start_seq := tcphdr.Seq
tcp_seq := tcp_start_seq + uint32(len(pkt.Payload))
tcpStartSeq := tcphdr.Seq
tcpSeq := tcpStartSeq + uint32(len(pkt.Payload))
lastSeq := conn.lastSeq[stream.dir]
if isDebug {
debugf("pkt.start_seq=%v pkt.last_seq=%v stream.last_seq=%v (len=%d)",
tcp_start_seq, tcp_seq, lastSeq, len(pkt.Payload))
tcpStartSeq, tcpSeq, lastSeq, len(pkt.Payload))
}

if len(pkt.Payload) > 0 && lastSeq != 0 {
if tcpSeqBeforeEq(tcp_seq, lastSeq) {
if tcpSeqBeforeEq(tcpSeq, lastSeq) {
if isDebug {
debugf("Ignoring retransmitted segment. pkt.seq=%v len=%v stream.seq=%v",
tcphdr.Seq, len(pkt.Payload), lastSeq)
}
return
}

if tcpSeqBefore(lastSeq, tcp_start_seq) {
if !created {
gap := int(tcp_start_seq - lastSeq)
logp.Warn("Gap in tcp stream. last_seq: %d, seq: %d, gap: %d", lastSeq, tcp_start_seq, gap)
drop := stream.gapInStream(gap)
if drop {
if isDebug {
debugf("Dropping connection state because of gap")
}

// drop application layer connection state and
// update stream_id for app layer analysers using stream_id for lookups
conn.id = tcp.getId()
conn.data = nil
switch tcpSeqCompare(lastSeq, tcpStartSeq) {
case seqLT: // lastSeq < tcpStartSeq => Gap in tcp stream detected
if created {
break
}

gap := int(tcpStartSeq - lastSeq)
logp.Warn("Gap in tcp stream. last_seq: %d, seq: %d, gap: %d", lastSeq, tcpStartSeq, gap)
drop := stream.gapInStream(gap)
if drop {
if isDebug {
debugf("Dropping connection state because of gap")
}

// drop application layer connection state and
// update stream_id for app layer analysers using stream_id for lookups
conn.id = tcp.getId()
conn.data = nil
}

case seqGT:
// lastSeq > tcpStartSeq => overlapping TCP segment detected. shrink packet
delta := lastSeq - tcpStartSeq

if isDebug {
debugf("Overlapping tcp segment. last_seq %d, seq: %d, delta: %d",
lastSeq, tcpStartSeq, delta)
}

pkt.Payload = pkt.Payload[delta:]
tcphdr.Seq += delta
}
}

conn.lastSeq[stream.dir] = tcp_seq
conn.lastSeq[stream.dir] = tcpSeq
stream.addPacket(pkt, tcphdr)
}

Expand Down Expand Up @@ -209,6 +240,18 @@ func (tcp *Tcp) getStream(pkt *protos.Packet) (stream TcpStream, created bool) {
return TcpStream{conn: conn, dir: TcpDirectionOriginal}, true
}

func tcpSeqCompare(seq1, seq2 uint32) seqCompare {
i := int32(seq1 - seq2)
switch {
case i == 0:
return seqEq
case i < 0:
return seqLT
default:
return seqGT
}
}

func tcpSeqBefore(seq1 uint32, seq2 uint32) bool {
return int32(seq1-seq2) < 0
}
Expand Down
189 changes: 155 additions & 34 deletions packetbeat/protos/tcp/tcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
package tcp

import (
"fmt"
"math/rand"
"net"
"testing"
Expand Down Expand Up @@ -186,45 +185,128 @@ func (p protocols) GetAllTcp() map[protos.Protocol]protos.TcpPlugin { retur
func (p protocols) GetAllUdp() map[protos.Protocol]protos.UdpPlugin { return nil }
func (p protocols) Register(proto protos.Protocol, plugin protos.Plugin) { return }

func TestGapInStreamShouldDropState(t *testing.T) {
gap := 0
var state []byte

data1 := []byte{1, 2, 3, 4}
data2 := []byte{5, 6, 7, 8}

tp := &TestProtocol{Ports: []int{ServerPort}}
tp.gap = func(t *common.TcpTuple, d uint8, n int, p protos.ProtocolData) (protos.ProtocolData, bool) {
fmt.Printf("lost: %v\n", n)
gap += n
return p, true // drop state
}
tp.parse = func(p *protos.Packet, t *common.TcpTuple, d uint8, priv protos.ProtocolData) protos.ProtocolData {
if priv == nil {
state = nil
}
state = append(state, p.Payload...)
return state
func TestTCSeqPayload(t *testing.T) {
type segment struct {
seq uint32
payload []byte
}

p := protocols{}
p.tcp = map[protos.Protocol]protos.TcpPlugin{
httpProtocol: tp,
tests := []struct {
name string
segments []segment
expectedGaps int
expectedState []byte
}{
{"No overlap",
[]segment{
{1, []byte{1, 2, 3, 4, 5}},
{6, []byte{6, 7, 8, 9, 10}},
},
0,
[]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
},
{"Gap drop state",
[]segment{
{1, []byte{1, 2, 3, 4}},
{15, []byte{5, 6, 7, 8}},
},
10,
[]byte{5, 6, 7, 8},
},
{"ACK same sequence number",
[]segment{
{1, []byte{1, 2}},
{3, nil},
{3, []byte{3, 4}},
{5, []byte{5, 6}},
},
0,
[]byte{1, 2, 3, 4, 5, 6},
},
{"ACK same sequence number 2",
[]segment{
{1, nil},
{2, nil},
{2, []byte{1, 2}},
{4, nil},
{4, []byte{3, 4}},
{6, []byte{5, 6}},
{8, []byte{7, 8}},
{10, nil},
},
0,
[]byte{1, 2, 3, 4, 5, 6, 7, 8},
},
{"Overlap, first segment bigger",
[]segment{
{1, []byte{1, 2}},
{3, []byte{3, 4}},
{3, []byte{3}},
{5, []byte{5, 6}},
},
0,
[]byte{1, 2, 3, 4, 5, 6},
},
{"Overlap, second segment bigger",
[]segment{
{1, []byte{1, 2}},
{3, []byte{3}},
{3, []byte{3, 4}},
{5, []byte{5, 6}},
},
0,
[]byte{1, 2, 3, 4, 5, 6},
},
{"Overlap, covered",
[]segment{
{1, []byte{1, 2, 3, 4}},
{2, []byte{2, 3}},
{5, []byte{5, 6}},
},
0,
[]byte{1, 2, 3, 4, 5, 6},
},
}
tcp, _ := NewTcp(p)

addr := common.NewIpPortTuple(4,
net.ParseIP(ServerIp), ServerPort,
net.ParseIP(ClientIp), uint16(rand.Intn(65535)))
for i, test := range tests {
t.Logf("Test (%v): %v", i, test.name)

gap := 0
var state []byte
tcp, err := NewTcp(protocols{
tcp: map[protos.Protocol]protos.TcpPlugin{
httpProtocol: &TestProtocol{
Ports: []int{ServerPort},
gap: makeCountGaps(nil, &gap),
parse: makeCollectPayload(&state, true),
},
},
})
if err != nil {
t.Fatal(err)
}

hdr := &layers.TCP{}
tcp.Process(nil, hdr, &protos.Packet{Ts: time.Now(), Tuple: addr, Payload: data1})
hdr.Seq += uint32(len(data1) + 10)
tcp.Process(nil, hdr, &protos.Packet{Ts: time.Now(), Tuple: addr, Payload: data2})
addr := common.NewIpPortTuple(4,
net.ParseIP(ServerIp), ServerPort,
net.ParseIP(ClientIp), uint16(rand.Intn(65535)))

for _, segment := range test.segments {
hdr := &layers.TCP{Seq: segment.seq}
pkt := &protos.Packet{
Ts: time.Now(),
Tuple: addr,
Payload: segment.payload,
}
tcp.Process(nil, hdr, pkt)
}

// validate
assert.Equal(t, 10, gap)
assert.Equal(t, data2, state)
assert.Equal(t, test.expectedGaps, gap)
if len(test.expectedState) != len(state) {
assert.Equal(t, len(test.expectedState), len(state))
continue
}
assert.Equal(t, test.expectedState, state)
}
}

// Benchmark that runs with parallelism to help find concurrency related
Expand All @@ -251,3 +333,42 @@ func BenchmarkParallelProcess(b *testing.B) {
}
})
}

func makeCountGaps(
counter *int,
bytes *int,
) func(*common.TcpTuple, uint8, int, protos.ProtocolData) (protos.ProtocolData, bool) {
return func(
t *common.TcpTuple,
d uint8,
n int,
p protos.ProtocolData,
) (protos.ProtocolData, bool) {
if counter != nil {
(*counter)++
}
if bytes != nil {
*bytes += n
}

return p, true // drop state
}
}

func makeCollectPayload(
state *[]byte,
resetOnNil bool,
) func(*protos.Packet, *common.TcpTuple, uint8, protos.ProtocolData) protos.ProtocolData {
return func(
p *protos.Packet,
t *common.TcpTuple,
d uint8,
priv protos.ProtocolData,
) protos.ProtocolData {
if resetOnNil && priv == nil {
(*state) = nil
}
*state = append(*state, p.Payload...)
return *state
}
}