Skip to content

Commit

Permalink
Refactor scan results
Browse files Browse the repository at this point in the history
  • Loading branch information
v-byte-cpu committed Mar 17, 2021
1 parent bff65ea commit d0acd20
Show file tree
Hide file tree
Showing 11 changed files with 275 additions and 99 deletions.
8 changes: 4 additions & 4 deletions command/log/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@ import (
"io"
"time"

"github.com/v-byte-cpu/sx/pkg/scan/arp"
"github.com/v-byte-cpu/sx/pkg/scan"
"go.uber.org/zap"
)

type Logger interface {
Error(err error)
LogResults(results <-chan *arp.ScanResult)
LogResults(results <-chan scan.Result)
}

type FlushWriter interface {
Expand All @@ -20,7 +20,7 @@ type FlushWriter interface {
}

type ResultWriter interface {
Write(w io.Writer, result *arp.ScanResult) error
Write(w io.Writer, result scan.Result) error
}

type logger struct {
Expand Down Expand Up @@ -75,7 +75,7 @@ func (l *logger) Error(err error) {
l.zapl.Error(l.label, zap.Error(err))
}

func (l *logger) LogResults(results <-chan *arp.ScanResult) {
func (l *logger) LogResults(results <-chan scan.Result) {
bw := bufio.NewWriter(l.w)
defer bw.Flush()
var err error
Expand Down
7 changes: 4 additions & 3 deletions command/log/logger_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@ import (

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/v-byte-cpu/sx/pkg/scan"
"github.com/v-byte-cpu/sx/pkg/scan/arp"
)

func scanResultToJSON(t *testing.T, result *arp.ScanResult) string {
func scanResultToJSON(t *testing.T, result scan.Result) string {
t.Helper()
data, err := result.MarshalJSON()
require.NoError(t, err)
Expand Down Expand Up @@ -58,7 +59,7 @@ func TestJSONLoggerResults(t *testing.T) {
logger, err := NewLogger(&buf, "arp", JSON())
require.NoError(t, err)

resultCh := make(chan *arp.ScanResult, len(tt.results))
resultCh := make(chan scan.Result, len(tt.results))
for _, result := range tt.results {
resultCh <- result
}
Expand Down Expand Up @@ -118,7 +119,7 @@ func TestPlainLoggerResults(t *testing.T) {
logger, err := NewLogger(&buf, "arp", Plain())
require.NoError(t, err)

resultCh := make(chan *arp.ScanResult, len(tt.results))
resultCh := make(chan scan.Result, len(tt.results))
for _, result := range tt.results {
resultCh <- result
}
Expand Down
8 changes: 4 additions & 4 deletions command/log/unique_logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package log
import (
"context"

"github.com/v-byte-cpu/sx/pkg/scan/arp"
"github.com/v-byte-cpu/sx/pkg/scan"
)

type UniqueLogger struct {
Expand All @@ -19,12 +19,12 @@ func (l *UniqueLogger) Error(err error) {
l.logger.Error(err)
}

func (l *UniqueLogger) LogResults(results <-chan *arp.ScanResult) {
func (l *UniqueLogger) LogResults(results <-chan scan.Result) {
l.logger.LogResults(l.uniqResults(results))
}

func (l *UniqueLogger) uniqResults(in <-chan *arp.ScanResult) <-chan *arp.ScanResult {
results := make(chan *arp.ScanResult, cap(in))
func (l *UniqueLogger) uniqResults(in <-chan scan.Result) <-chan scan.Result {
results := make(chan scan.Result, cap(in))
go func() {
defer close(results)
var member struct{}
Expand Down
14 changes: 7 additions & 7 deletions command/log/unique_logger_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/v-byte-cpu/sx/pkg/scan/arp"
"github.com/v-byte-cpu/sx/pkg/scan"
)

func TestUniqueLoggerResults(t *testing.T) {
Expand All @@ -18,7 +18,7 @@ func TestUniqueLoggerResults(t *testing.T) {
tests := []struct {
name string
expected []byte
results []*arp.ScanResult
results []scan.Result
}{
{
name: "emptyResults",
Expand All @@ -28,7 +28,7 @@ func TestUniqueLoggerResults(t *testing.T) {
{
name: "oneResult",
expected: []byte(newScanResult(net.IPv4(192, 168, 0, 3).To4()).String() + "\n"),
results: []*arp.ScanResult{
results: []scan.Result{
newScanResult(net.IPv4(192, 168, 0, 3).To4()),
},
},
Expand All @@ -38,15 +38,15 @@ func TestUniqueLoggerResults(t *testing.T) {
newScanResult(net.IPv4(192, 168, 0, 3).To4()).String(),
newScanResult(net.IPv4(192, 168, 0, 5).To4()).String(),
}, "\n") + "\n"),
results: []*arp.ScanResult{
results: []scan.Result{
newScanResult(net.IPv4(192, 168, 0, 3).To4()),
newScanResult(net.IPv4(192, 168, 0, 5).To4()),
},
},
{
name: "twoEqualResults",
expected: []byte(newScanResult(net.IPv4(192, 168, 0, 3).To4()).String() + "\n"),
results: []*arp.ScanResult{
results: []scan.Result{
newScanResult(net.IPv4(192, 168, 0, 3).To4()),
newScanResult(net.IPv4(192, 168, 0, 3).To4()),
},
Expand All @@ -57,7 +57,7 @@ func TestUniqueLoggerResults(t *testing.T) {
newScanResult(net.IPv4(192, 168, 0, 3).To4()).String(),
newScanResult(net.IPv4(192, 168, 0, 5).To4()).String(),
}, "\n") + "\n"),
results: []*arp.ScanResult{
results: []scan.Result{
newScanResult(net.IPv4(192, 168, 0, 3).To4()),
newScanResult(net.IPv4(192, 168, 0, 5).To4()),
newScanResult(net.IPv4(192, 168, 0, 3).To4()),
Expand All @@ -73,7 +73,7 @@ func TestUniqueLoggerResults(t *testing.T) {
require.NoError(t, err)
logger := NewUniqueLogger(context.Background(), plainLogger)

resultCh := make(chan *arp.ScanResult, len(tt.results))
resultCh := make(chan scan.Result, len(tt.results))
for _, result := range tt.results {
resultCh <- result
}
Expand Down
4 changes: 2 additions & 2 deletions command/log/writer_json.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ import (
"fmt"
"io"

"github.com/v-byte-cpu/sx/pkg/scan/arp"
"github.com/v-byte-cpu/sx/pkg/scan"
)

type JSONResultWriter struct{}

func (*JSONResultWriter) Write(w io.Writer, result *arp.ScanResult) error {
func (*JSONResultWriter) Write(w io.Writer, result scan.Result) error {
data, err := result.MarshalJSON()
if err != nil {
return err
Expand Down
4 changes: 2 additions & 2 deletions command/log/writer_plain.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ import (
"fmt"
"io"

"github.com/v-byte-cpu/sx/pkg/scan/arp"
"github.com/v-byte-cpu/sx/pkg/scan"
)

type PlainResultWriter struct{}

func (*PlainResultWriter) Write(w io.Writer, result *arp.ScanResult) error {
func (*PlainResultWriter) Write(w io.Writer, result scan.Result) error {
_, err := fmt.Fprintf(w, "%s\n", result.String())
return err
}
51 changes: 13 additions & 38 deletions pkg/scan/arp/arp.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,11 @@ import (
)

type ScanMethod struct {
reqgen scan.RequestGenerator
pktgen *scan.PacketMultiGenerator
parser *gopacket.DecodingLayerParser
results chan *ScanResult
internalResults chan *ScanResult
ctx context.Context
reqgen scan.RequestGenerator
pktgen *scan.PacketMultiGenerator
parser *gopacket.DecodingLayerParser
results *scan.ResultChan
ctx context.Context

rcvDecoded []gopacket.LayerType
rcvEth layers.Ethernet
Expand Down Expand Up @@ -57,32 +56,11 @@ func LiveMode(rescanTimeout time.Duration) ScanMethodOption {
}

func NewScanMethod(ctx context.Context, opts ...ScanMethodOption) *ScanMethod {
results := make(chan *ScanResult, 1000)
internalResults := make(chan *ScanResult, 1000)

copyChans := func() {
defer close(results)
for {
select {
case <-ctx.Done():
return
case v := <-internalResults:
select {
case <-ctx.Done():
return
case results <- v:
}
}
}
}
go copyChans()

sm := &ScanMethod{
ctx: ctx,
results: results,
internalResults: internalResults,
reqgen: scan.RequestGeneratorFunc(scan.Requests),
pktgen: scan.NewPacketMultiGenerator(newPacketFiller(), runtime.NumCPU()),
ctx: ctx,
results: scan.NewResultChan(ctx, 1000),
reqgen: scan.RequestGeneratorFunc(scan.Requests),
pktgen: scan.NewPacketMultiGenerator(newPacketFiller(), runtime.NumCPU()),
}
parser := gopacket.NewDecodingLayerParser(layers.LayerTypeEthernet, &sm.rcvEth, &sm.rcvARP)
parser.IgnoreUnsupported = true
Expand All @@ -94,8 +72,8 @@ func NewScanMethod(ctx context.Context, opts ...ScanMethodOption) *ScanMethod {
return sm
}

func (s *ScanMethod) Results() <-chan *ScanResult {
return s.results
func (s *ScanMethod) Results() <-chan scan.Result {
return s.results.Chan()
}

func (s *ScanMethod) Packets(ctx context.Context, r *scan.Range) <-chan *packet.BufferData {
Expand Down Expand Up @@ -126,14 +104,11 @@ func (s *ScanMethod) ProcessPacketData(data []byte, _ *gopacket.CaptureInfo) err
copy(s.rcvMacPrefix[:], s.rcvARP.SourceHwAddress[:3])
hwVendor := macs.ValidMACPrefixMap[s.rcvMacPrefix]

select {
case <-s.ctx.Done():
case s.internalResults <- &ScanResult{
s.results.Put(&ScanResult{
IP: net.IP(s.rcvARP.SourceProtAddress).String(),
MAC: net.HardwareAddr(s.rcvARP.SourceHwAddress).String(),
Vendor: hwVendor,
}:
}
})
return nil
}
}
Expand Down
79 changes: 42 additions & 37 deletions pkg/scan/arp/arp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,51 +15,56 @@ import (
func TestProcessPacketData(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
sm := NewScanMethod(ctx)
done := make(chan interface{})

// generate packet data
packet := gopacket.NewSerializeBuffer()
eth := &layers.Ethernet{
SrcMAC: net.HardwareAddr{0x1, 0x2, 0x3, 0x4, 0x5, 0x6},
DstMAC: net.HardwareAddr{0x10, 0x11, 0x12, 0x13, 0x14, 0x15},
EthernetType: layers.EthernetTypeARP,
}
go func() {
defer close(done)

a := &layers.ARP{
AddrType: layers.LinkTypeEthernet,
Protocol: layers.EthernetTypeIPv4,
HwAddressSize: uint8(6),
ProtAddressSize: uint8(4),
Operation: layers.ARPRequest,
SourceHwAddress: net.HardwareAddr{0x1, 0x2, 0x3, 0x4, 0x5, 0x6},
SourceProtAddress: net.IPv4(192, 168, 0, 3).To4(),
DstHwAddress: net.HardwareAddr{0x10, 0x11, 0x12, 0x13, 0x14, 0x15},
DstProtAddress: net.IPv4(192, 168, 0, 2).To4(),
}
var opt gopacket.SerializeOptions
err := gopacket.SerializeLayers(packet, opt, eth, a)
require.NoError(t, err)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
sm := NewScanMethod(ctx)

err = sm.ProcessPacketData(packet.Bytes(), &gopacket.CaptureInfo{})
require.NoError(t, err)
// generate packet data
packet := gopacket.NewSerializeBuffer()
eth := &layers.Ethernet{
SrcMAC: net.HardwareAddr{0x1, 0x2, 0x3, 0x4, 0x5, 0x6},
DstMAC: net.HardwareAddr{0x10, 0x11, 0x12, 0x13, 0x14, 0x15},
EthernetType: layers.EthernetTypeARP,
}

select {
case result, ok := <-sm.Results():
a := &layers.ARP{
AddrType: layers.LinkTypeEthernet,
Protocol: layers.EthernetTypeIPv4,
HwAddressSize: uint8(6),
ProtAddressSize: uint8(4),
Operation: layers.ARPRequest,
SourceHwAddress: net.HardwareAddr{0x1, 0x2, 0x3, 0x4, 0x5, 0x6},
SourceProtAddress: net.IPv4(192, 168, 0, 3).To4(),
DstHwAddress: net.HardwareAddr{0x10, 0x11, 0x12, 0x13, 0x14, 0x15},
DstProtAddress: net.IPv4(192, 168, 0, 2).To4(),
}
var opt gopacket.SerializeOptions
err := gopacket.SerializeLayers(packet, opt, eth, a)
require.NoError(t, err)

err = sm.ProcessPacketData(packet.Bytes(), &gopacket.CaptureInfo{})
require.NoError(t, err)

result, ok := <-sm.Results()
if !ok {
require.FailNow(t, "results chan is empty")
}
assert.Equal(t, net.HardwareAddr{0x1, 0x2, 0x3, 0x4, 0x5, 0x6}.String(), result.MAC)
assert.Equal(t, net.IPv4(192, 168, 0, 3).To4().String(), result.IP)
arpResult := result.(*ScanResult)
assert.Equal(t, net.HardwareAddr{0x1, 0x2, 0x3, 0x4, 0x5, 0x6}.String(), arpResult.MAC)
assert.Equal(t, net.IPv4(192, 168, 0, 3).To4().String(), arpResult.IP)

cancel()
select {
case _, ok := <-sm.Results():
require.False(t, ok, "results chan is not closed")
case <-time.After(1 * time.Second):
t.Fatal("read timeout")
}
_, ok = <-sm.Results()
require.False(t, ok, "results chan is not closed")
}()
select {
case <-done:
case <-time.After(3 * time.Second):
t.Fatal("read timeout")
t.Fatal("test timeout")
}
}
4 changes: 2 additions & 2 deletions pkg/scan/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ func NewEngine(ps PacketSource, s packet.Sender, r packet.Receiver) *Engine {

func (e *Engine) Start(ctx context.Context, r *Range) (<-chan interface{}, <-chan error) {
packets := e.src.Packets(ctx, r)
errc1 := e.rcv.ReceivePackets(ctx)
done, errc2 := e.snd.SendPackets(ctx, packets)
done, errc1 := e.snd.SendPackets(ctx, packets)
errc2 := e.rcv.ReceivePackets(ctx)
return done, mergeErrChan(ctx, errc1, errc2)
}

Expand Down
Loading

0 comments on commit d0acd20

Please sign in to comment.