From 29ca59d50b1a6b5796b45a8c428b1adbc46cc91b Mon Sep 17 00:00:00 2001 From: v-byte-cpu <65545655+v-byte-cpu@users.noreply.github.com> Date: Sat, 10 Jul 2021 14:53:18 +0300 Subject: [PATCH] feature: vpn support (#94) --- README.md | 8 ++ command/arp.go | 3 + command/config.go | 54 ++++----- command/config_test.go | 4 +- command/icmp.go | 21 ++-- command/root.go | 9 +- command/tcp.go | 37 +++--- command/tcp_fin.go | 12 +- command/tcp_null.go | 12 +- command/tcp_syn.go | 12 +- command/tcp_xmas.go | 12 +- command/udp.go | 22 ++-- pkg/packet/afpacket/readwriter.go | 13 ++- pkg/scan/icmp/icmp.go | 50 ++++++--- pkg/scan/icmp/icmp_test.go | 109 +++++++++++++++++- pkg/scan/tcp/tcp.go | 60 +++++++--- pkg/scan/tcp/tcp_test.go | 179 +++++++++++++++++++++++++++++- pkg/scan/udp/udp.go | 24 ++-- pkg/scan/udp/udp_test.go | 109 +++++++++++++++++- 19 files changed, 606 insertions(+), 144 deletions(-) diff --git a/README.md b/README.md index 21a5f37..0f96012 100644 --- a/README.md +++ b/README.md @@ -196,6 +196,14 @@ cat arp.cache | sx tcp syn -p 22 192.168.0.171 `tcp` subcomand is just a shorthand for `tcp syn` subcommand unless `--flags` option is passed, see below. +### VPN interfaces + +`sx` supports scanning with virtual network interfaces (wireguard, openvpn, etc.) and in this case it is **not** necessary to use the arp cache, since these interfaces require raw IP packets instead of Ethernet frames as input. For instance, scanning an IP address on a vpn network: + +``` +sx tcp 10.1.27.1 -p 80 --json +``` + ### TCP FIN scan Most network scanners try to interpret results of the scan. For instance they say "this port is closed" instead of "I received a RST". Sometimes they are right. Sometimes not. It's easier for beginners, but when you know what you're doing, you keep on trying to deduce what really happened from the program's interpretation, especially for more advanced scan techniques. diff --git a/command/arp.go b/command/arp.go index c9f0897..b5bd6bd 100644 --- a/command/arp.go +++ b/command/arp.go @@ -42,6 +42,9 @@ func newARPCmd() *arpCmd { if r, err = c.opts.getScanRange(dstSubnet); err != nil { return err } + if r.SrcMAC == nil { + return errSrcMAC + } var logger log.Logger if logger, err = c.opts.getLogger(); err != nil { return err diff --git a/command/config.go b/command/config.go index 5481559..8e8d5df 100644 --- a/command/config.go +++ b/command/config.go @@ -125,9 +125,6 @@ func (o *packetScanCmdOpts) getScanRange(dstSubnet *net.IPNet) (*scan.Range, err if o.srcMAC != nil { srcMAC = o.srcMAC } - if srcMAC == nil { - return nil, errSrcMAC - } return &scan.Range{ Interface: iface, @@ -178,6 +175,11 @@ type ipScanCmdOpts struct { ipFile string arpCacheFile string gatewayMAC net.HardwareAddr + vpnMode bool + + logger log.Logger + scanRange *scan.Range + cache *arp.Cache rawGatewayMAC string } @@ -202,52 +204,42 @@ func (o *ipScanCmdOpts) parseRawOptions() (err error) { return } -type scanConfig struct { - logger log.Logger - scanRange *scan.Range - cache *arp.Cache - gatewayMAC net.HardwareAddr -} - -func (o *ipScanCmdOpts) parseScanConfig(scanName string, args []string) (c *scanConfig, err error) { - if err = o.validateStdin(); err != nil { - return - } +func (o *ipScanCmdOpts) parseOptions(scanName string, args []string) (err error) { dstSubnet, err := o.parseDstSubnet(args) if err != nil { return } - var r *scan.Range - if r, err = o.getScanRange(dstSubnet); err != nil { + if o.scanRange, err = o.getScanRange(dstSubnet); err != nil { return } + if o.scanRange.SrcMAC == nil { + o.vpnMode = true + } - var logger log.Logger - if logger, err = o.getLogger(scanName, os.Stdout); err != nil { + if o.logger, err = o.getLogger(scanName, os.Stdout); err != nil { return } - var cache *arp.Cache - if cache, err = o.parseARPCache(); err != nil { + // disable arp cache parsing for vpn mode + if o.vpnMode { + return + } + if err = o.validateARPStdin(); err != nil { return } - var gatewayMAC net.HardwareAddr - if gatewayMAC, err = o.getGatewayMAC(r.Interface, cache); err != nil { + if o.cache, err = o.parseARPCache(); err != nil { return } - c = &scanConfig{ - logger: logger, - scanRange: r, - cache: cache, - gatewayMAC: gatewayMAC, + if o.gatewayMAC, err = o.getGatewayMAC(o.scanRange.Interface, o.cache); err != nil { + return } return } -func (o *ipScanCmdOpts) validateStdin() (err error) { +func (o *ipScanCmdOpts) validateARPStdin() (err error) { if o.isARPCacheFromStdin() && o.ipFile == "-" { return errARPStdin } @@ -333,11 +325,11 @@ func (o *ipPortScanCmdOpts) parseRawOptions() (err error) { return } -func (o *ipPortScanCmdOpts) parseScanConfig(scanName string, args []string) (c *scanConfig, err error) { - if c, err = o.ipScanCmdOpts.parseScanConfig(scanName, args); err != nil { +func (o *ipPortScanCmdOpts) parseOptions(scanName string, args []string) (err error) { + if err = o.ipScanCmdOpts.parseOptions(scanName, args); err != nil { return } - c.scanRange.Ports = o.portRanges + o.scanRange.Ports = o.portRanges return } diff --git a/command/config_test.go b/command/config_test.go index 7d18aaa..2ada237 100644 --- a/command/config_test.go +++ b/command/config_test.go @@ -217,7 +217,7 @@ func TestIPScanCmdOptsIsARPCacheFromStdin(t *testing.T) { } } -func TestIPScanCmdOptsValidateStdin(t *testing.T) { +func TestIPScanCmdOptsValidateARPStdin(t *testing.T) { t.Parallel() tests := []struct { name string @@ -265,7 +265,7 @@ func TestIPScanCmdOptsValidateStdin(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := tt.opts.validateStdin() + err := tt.opts.validateARPStdin() if tt.shouldErr { require.Error(t, err) } else { diff --git a/command/icmp.go b/command/icmp.go index c6a4427..8ef49a0 100644 --- a/command/icmp.go +++ b/command/icmp.go @@ -32,21 +32,21 @@ func newICMPCmd() *icmpCmd { if err = c.opts.parseRawOptions(); err != nil { return } - var conf *scanConfig - if conf, err = c.opts.parseScanConfig(icmp.ScanType, args); err != nil { + if err = c.opts.parseOptions(icmp.ScanType, args); err != nil { return } - m := c.opts.newICMPScanMethod(ctx, conf) + m := c.opts.newICMPScanMethod(ctx) return startPacketScanEngine(ctx, newPacketScanConfig( withPacketScanMethod(m), withPacketBPFFilter(icmp.BPFFilter), withRateCount(c.opts.rateCount), withRateWindow(c.opts.rateWindow), + withPacketVPNmode(c.opts.vpnMode), withPacketEngineConfig(newEngineConfig( - withLogger(conf.logger), - withScanRange(conf.scanRange), + withLogger(c.opts.logger), + withScanRange(c.opts.scanRange), withExitDelay(c.opts.exitDelay), )), )) @@ -112,7 +112,7 @@ func (o *icmpCmdOpts) parseRawOptions() (err error) { return } -func (o *icmpCmdOpts) newICMPScanMethod(ctx context.Context, conf *scanConfig) *icmp.ScanMethod { +func (o *icmpCmdOpts) newICMPScanMethod(ctx context.Context) *icmp.ScanMethod { ipgen := scan.NewIPGenerator() if len(o.ipFile) > 0 { ipgen = scan.NewFileIPGenerator(func() (io.ReadCloser, error) { @@ -123,11 +123,13 @@ func (o *icmpCmdOpts) newICMPScanMethod(ctx context.Context, conf *scanConfig) * if o.excludeIPs != nil { reqgen = scan.NewFilterIPRequestGenerator(reqgen, o.excludeIPs) } - reqgen = arp.NewCacheRequestGenerator(reqgen, conf.gatewayMAC, conf.cache) + if o.cache != nil { + reqgen = arp.NewCacheRequestGenerator(reqgen, o.gatewayMAC, o.cache) + } pktgen := scan.NewPacketMultiGenerator(icmp.NewPacketFiller(o.getICMPOptions()...), runtime.NumCPU()) psrc := scan.NewPacketSource(reqgen, pktgen) results := scan.NewResultChan(ctx, 1000) - return icmp.NewScanMethod(psrc, results) + return icmp.NewScanMethod(psrc, results, o.vpnMode) } func (o *icmpCmdOpts) getICMPOptions() (opts []icmp.PacketFillerOption) { @@ -137,7 +139,8 @@ func (o *icmpCmdOpts) getICMPOptions() (opts []icmp.PacketFillerOption) { icmp.WithIPFlags(o.ipFlags), icmp.WithIPTotalLength(o.ipTotalLen), icmp.WithType(o.icmpType), - icmp.WithCode(o.icmpCode)) + icmp.WithCode(o.icmpCode), + icmp.WithVPNmode(o.vpnMode)) if len(o.icmpPayload) > 0 { opts = append(opts, icmp.WithPayload(o.icmpPayload)) diff --git a/command/root.go b/command/root.go index 40cd4e9..6291d17 100644 --- a/command/root.go +++ b/command/root.go @@ -94,6 +94,7 @@ type packetScanConfig struct { bpfFilter bpfFilterFunc rateCount int rateWindow time.Duration + vpnMode bool } type packetScanConfigOption func(c *packetScanConfig) @@ -128,6 +129,12 @@ func withRateWindow(rateWindow time.Duration) packetScanConfigOption { } } +func withPacketVPNmode(vpnMode bool) packetScanConfigOption { + return func(c *packetScanConfig) { + c.vpnMode = vpnMode + } +} + func newPacketScanConfig(opts ...packetScanConfigOption) *packetScanConfig { c := &packetScanConfig{} for _, o := range opts { @@ -140,7 +147,7 @@ func startPacketScanEngine(ctx context.Context, conf *packetScanConfig) error { r := conf.scanRange // setup network interface to read/write packets - ps, err := afpacket.NewPacketSource(r.Interface.Name) + ps, err := afpacket.NewPacketSource(r.Interface.Name, conf.vpnMode) if err != nil { return err } diff --git a/command/tcp.go b/command/tcp.go index f269f50..79b820f 100644 --- a/command/tcp.go +++ b/command/tcp.go @@ -52,8 +52,7 @@ func newTCPFlagsCmd() *tcpFlagsCmd { } scanName := tcp.FlagsScanType - var conf *scanConfig - if conf, err = c.opts.parseScanConfig(scanName, args); err != nil { + if err = c.opts.parseOptions(scanName, args); err != nil { return } @@ -62,9 +61,9 @@ func newTCPFlagsCmd() *tcpFlagsCmd { opts = append(opts, tcpPacketFlagOptions[flag]) } - m := c.opts.newTCPScanMethod(ctx, conf, + m := c.opts.newTCPScanMethod(ctx, withTCPScanName(scanName), - withTCPPacketFiller(tcp.NewPacketFiller(opts...)), + withTCPPacketFillerOptions(opts...), withTCPPacketFilterFunc(tcp.TrueFilter), withTCPPacketFlags(tcp.AllFlags), ) @@ -74,9 +73,10 @@ func newTCPFlagsCmd() *tcpFlagsCmd { withPacketBPFFilter(tcp.BPFFilter), withRateCount(c.opts.rateCount), withRateWindow(c.opts.rateWindow), + withPacketVPNmode(c.opts.vpnMode), withPacketEngineConfig(newEngineConfig( - withLogger(conf.logger), - withScanRange(conf.scanRange), + withLogger(c.opts.logger), + withScanRange(c.opts.scanRange), withExitDelay(c.opts.exitDelay), )), )) @@ -146,26 +146,31 @@ type tcpCmdOpts struct { ipPortScanCmdOpts } -func (o *tcpCmdOpts) newTCPScanMethod(ctx context.Context, conf *scanConfig, opts ...tcpScanConfigOption) *tcp.ScanMethod { +func (o *tcpCmdOpts) newTCPScanMethod(ctx context.Context, opts ...tcpScanConfigOption) *tcp.ScanMethod { c := &tcpScanConfig{} for _, opt := range opts { opt(c) } - reqgen := arp.NewCacheRequestGenerator(o.newIPPortGenerator(), conf.gatewayMAC, conf.cache) - pktgen := scan.NewPacketMultiGenerator(c.packetFiller, runtime.NumCPU()) + reqgen := o.newIPPortGenerator() + if o.cache != nil { + reqgen = arp.NewCacheRequestGenerator(reqgen, o.gatewayMAC, o.cache) + } + c.packetFillerOpts = append(c.packetFillerOpts, tcp.WithFillerVPNmode(o.vpnMode)) + pktgen := scan.NewPacketMultiGenerator(tcp.NewPacketFiller(c.packetFillerOpts...), runtime.NumCPU()) psrc := scan.NewPacketSource(reqgen, pktgen) results := scan.NewResultChan(ctx, 1000) return tcp.NewScanMethod( c.scanName, psrc, results, tcp.WithPacketFilterFunc(c.packetFilter), - tcp.WithPacketFlagsFunc(c.packetFlags)) + tcp.WithPacketFlagsFunc(c.packetFlags), + tcp.WithScanVPNmode(o.vpnMode)) } type tcpScanConfig struct { - scanName string - packetFiller scan.PacketFiller - packetFilter tcp.PacketFilterFunc - packetFlags tcp.PacketFlagsFunc + scanName string + packetFillerOpts []tcp.PacketFillerOption + packetFilter tcp.PacketFilterFunc + packetFlags tcp.PacketFlagsFunc } type tcpScanConfigOption func(c *tcpScanConfig) @@ -176,9 +181,9 @@ func withTCPScanName(scanName string) tcpScanConfigOption { } } -func withTCPPacketFiller(filler scan.PacketFiller) tcpScanConfigOption { +func withTCPPacketFillerOptions(opts ...tcp.PacketFillerOption) tcpScanConfigOption { return func(c *tcpScanConfig) { - c.packetFiller = filler + c.packetFillerOpts = opts } } diff --git a/command/tcp_fin.go b/command/tcp_fin.go index c52e7ed..a2facab 100644 --- a/command/tcp_fin.go +++ b/command/tcp_fin.go @@ -26,14 +26,13 @@ func newTCPFINCmd() *tcpFINCmd { } scanName := tcp.FINScanType - var conf *scanConfig - if conf, err = c.opts.parseScanConfig(scanName, args); err != nil { + if err = c.opts.parseOptions(scanName, args); err != nil { return } - m := c.opts.newTCPScanMethod(ctx, conf, + m := c.opts.newTCPScanMethod(ctx, withTCPScanName(scanName), - withTCPPacketFiller(tcp.NewPacketFiller(tcp.WithFIN())), + withTCPPacketFillerOptions(tcp.WithFIN()), withTCPPacketFilterFunc(tcp.TrueFilter), withTCPPacketFlags(tcp.AllFlags), ) @@ -43,9 +42,10 @@ func newTCPFINCmd() *tcpFINCmd { withPacketBPFFilter(tcp.BPFFilter), withRateCount(c.opts.rateCount), withRateWindow(c.opts.rateWindow), + withPacketVPNmode(c.opts.vpnMode), withPacketEngineConfig(newEngineConfig( - withLogger(conf.logger), - withScanRange(conf.scanRange), + withLogger(c.opts.logger), + withScanRange(c.opts.scanRange), withExitDelay(c.opts.exitDelay), )), )) diff --git a/command/tcp_null.go b/command/tcp_null.go index 7429ff7..efa34c3 100644 --- a/command/tcp_null.go +++ b/command/tcp_null.go @@ -26,14 +26,13 @@ func newTCPNULLCmd() *tcpNULLCmd { } scanName := tcp.NULLScanType - var conf *scanConfig - if conf, err = c.opts.parseScanConfig(scanName, args); err != nil { + if err = c.opts.parseOptions(scanName, args); err != nil { return } - m := c.opts.newTCPScanMethod(ctx, conf, + m := c.opts.newTCPScanMethod(ctx, withTCPScanName(scanName), - withTCPPacketFiller(tcp.NewPacketFiller()), + withTCPPacketFillerOptions(), withTCPPacketFilterFunc(tcp.TrueFilter), withTCPPacketFlags(tcp.AllFlags), ) @@ -43,9 +42,10 @@ func newTCPNULLCmd() *tcpNULLCmd { withPacketBPFFilter(tcp.BPFFilter), withRateCount(c.opts.rateCount), withRateWindow(c.opts.rateWindow), + withPacketVPNmode(c.opts.vpnMode), withPacketEngineConfig(newEngineConfig( - withLogger(conf.logger), - withScanRange(conf.scanRange), + withLogger(c.opts.logger), + withScanRange(c.opts.scanRange), withExitDelay(c.opts.exitDelay), )), )) diff --git a/command/tcp_syn.go b/command/tcp_syn.go index 8c1937c..11112f7 100644 --- a/command/tcp_syn.go +++ b/command/tcp_syn.go @@ -51,14 +51,13 @@ func newTCPSYNCmdOpts(opts tcpCmdOpts) *tcpSYNCmdOpts { func (o *tcpSYNCmdOpts) startScan(ctx context.Context, args []string) (err error) { scanName := tcp.SYNScanType - var conf *scanConfig - if conf, err = o.parseScanConfig(scanName, args); err != nil { + if err = o.parseOptions(scanName, args); err != nil { return } - m := o.newTCPScanMethod(ctx, conf, + m := o.newTCPScanMethod(ctx, withTCPScanName(scanName), - withTCPPacketFiller(tcp.NewPacketFiller(tcp.WithSYN())), + withTCPPacketFillerOptions(tcp.WithSYN()), withTCPPacketFilterFunc(func(pkt *layers.TCP) bool { // port is open return pkt.SYN && pkt.ACK @@ -71,9 +70,10 @@ func (o *tcpSYNCmdOpts) startScan(ctx context.Context, args []string) (err error withPacketBPFFilter(tcp.SYNACKBPFFilter), withRateCount(o.rateCount), withRateWindow(o.rateWindow), + withPacketVPNmode(o.vpnMode), withPacketEngineConfig(newEngineConfig( - withLogger(conf.logger), - withScanRange(conf.scanRange), + withLogger(o.logger), + withScanRange(o.scanRange), withExitDelay(o.exitDelay), )), )) diff --git a/command/tcp_xmas.go b/command/tcp_xmas.go index cd28a6b..33a5d24 100644 --- a/command/tcp_xmas.go +++ b/command/tcp_xmas.go @@ -26,14 +26,13 @@ func newTCPXmasCmd() *tcpXmasCmd { } scanName := tcp.XmasScanType - var conf *scanConfig - if conf, err = c.opts.parseScanConfig(scanName, args); err != nil { + if err = c.opts.parseOptions(scanName, args); err != nil { return } - m := c.opts.newTCPScanMethod(ctx, conf, + m := c.opts.newTCPScanMethod(ctx, withTCPScanName(scanName), - withTCPPacketFiller(tcp.NewPacketFiller(tcp.WithFIN(), tcp.WithPSH(), tcp.WithURG())), + withTCPPacketFillerOptions(tcp.WithFIN(), tcp.WithPSH(), tcp.WithURG()), withTCPPacketFilterFunc(tcp.TrueFilter), withTCPPacketFlags(tcp.AllFlags), ) @@ -43,9 +42,10 @@ func newTCPXmasCmd() *tcpXmasCmd { withPacketBPFFilter(tcp.BPFFilter), withRateCount(c.opts.rateCount), withRateWindow(c.opts.rateWindow), + withPacketVPNmode(c.opts.vpnMode), withPacketEngineConfig(newEngineConfig( - withLogger(conf.logger), - withScanRange(conf.scanRange), + withLogger(c.opts.logger), + withScanRange(c.opts.scanRange), withExitDelay(c.opts.exitDelay), )), )) diff --git a/command/udp.go b/command/udp.go index 51b3419..bdf26df 100644 --- a/command/udp.go +++ b/command/udp.go @@ -33,21 +33,21 @@ func newUDPCmd() *udpCmd { if err = c.opts.parseRawOptions(); err != nil { return } - var conf *scanConfig - if conf, err = c.opts.parseScanConfig(udp.ScanType, args); err != nil { + if err = c.opts.parseOptions(udp.ScanType, args); err != nil { return } - m := c.opts.newUDPScanMethod(ctx, conf) + m := c.opts.newUDPScanMethod(ctx) return startPacketScanEngine(ctx, newPacketScanConfig( withPacketScanMethod(m), withPacketBPFFilter(icmp.BPFFilter), withRateCount(c.opts.rateCount), withRateWindow(c.opts.rateWindow), + withPacketVPNmode(c.opts.vpnMode), withPacketEngineConfig(newEngineConfig( - withLogger(conf.logger), - withScanRange(conf.scanRange), + withLogger(c.opts.logger), + withScanRange(c.opts.scanRange), withExitDelay(c.opts.exitDelay), )), )) @@ -108,12 +108,15 @@ func (o *udpCmdOpts) parseRawOptions() (err error) { return } -func (o *udpCmdOpts) newUDPScanMethod(ctx context.Context, conf *scanConfig) *udp.ScanMethod { - reqgen := arp.NewCacheRequestGenerator(o.newIPPortGenerator(), conf.gatewayMAC, conf.cache) +func (o *udpCmdOpts) newUDPScanMethod(ctx context.Context) *udp.ScanMethod { + reqgen := o.newIPPortGenerator() + if o.cache != nil { + reqgen = arp.NewCacheRequestGenerator(o.newIPPortGenerator(), o.gatewayMAC, o.cache) + } pktgen := scan.NewPacketMultiGenerator(udp.NewPacketFiller(o.getUDPOptions()...), runtime.NumCPU()) psrc := scan.NewPacketSource(reqgen, pktgen) results := scan.NewResultChan(ctx, 1000) - return udp.NewScanMethod(psrc, results) + return udp.NewScanMethod(psrc, results, o.vpnMode) } func (o *udpCmdOpts) getUDPOptions() (opts []udp.PacketFillerOption) { @@ -121,7 +124,8 @@ func (o *udpCmdOpts) getUDPOptions() (opts []udp.PacketFillerOption) { udp.WithTTL(o.ipTTL), udp.WithIPProtocol(o.ipProtocol), udp.WithIPFlags(o.ipFlags), - udp.WithIPTotalLength(o.ipTotalLen)) + udp.WithIPTotalLength(o.ipTotalLen), + udp.WithVPNmode(o.vpnMode)) if len(o.udpPayload) > 0 { opts = append(opts, udp.WithPayload(o.udpPayload)) diff --git a/pkg/packet/afpacket/readwriter.go b/pkg/packet/afpacket/readwriter.go index b61530d..0e736e3 100644 --- a/pkg/packet/afpacket/readwriter.go +++ b/pkg/packet/afpacket/readwriter.go @@ -12,25 +12,30 @@ import ( ) type Source struct { - handle *afp.TPacket + handle *afp.TPacket + linkType layers.LinkType } // Assert that AfPacketSource conforms to the packet.ReadWriter interface var _ packet.ReadWriter = (*Source)(nil) -func NewPacketSource(iface string) (*Source, error) { +func NewPacketSource(iface string, vpnMode bool) (*Source, error) { handle, err := afp.NewTPacket(afp.SocketRaw, afp.OptInterface(iface)) if err != nil { return nil, err } - return &Source{handle}, nil + linkType := layers.LinkTypeEthernet + if vpnMode { + linkType = layers.LinkTypeIPv4 + } + return &Source{handle, linkType}, nil } // maxPacketLength is the maximum size of packets to capture in bytes. // pcap calls it "snaplen" and default value used in tcpdump is 262144 bytes, // that is redundant for most scans, see pcap(3) and tcpdump(1) for more info func (s *Source) SetBPFFilter(bpfFilter string, maxPacketLength int) error { - pcapBPF, err := pcap.CompileBPFFilter(layers.LinkTypeEthernet, maxPacketLength, bpfFilter) + pcapBPF, err := pcap.CompileBPFFilter(s.linkType, maxPacketLength, bpfFilter) if err != nil { return err } diff --git a/pkg/scan/icmp/icmp.go b/pkg/scan/icmp/icmp.go index a4b0d09..b138069 100644 --- a/pkg/scan/icmp/icmp.go +++ b/pkg/scan/icmp/icmp.go @@ -44,8 +44,8 @@ type ScanMethod struct { // Assert that icmp.ScanMethod conforms to the scan.PacketMethod interface var _ scan.PacketMethod = (*ScanMethod)(nil) -func NewScanMethod(psrc scan.PacketSource, results scan.ResultChan) *ScanMethod { - pp := NewPacketProcessor(ScanType, results) +func NewScanMethod(psrc scan.PacketSource, results scan.ResultChan, vpnMode bool) *ScanMethod { + pp := NewPacketProcessor(ScanType, results, vpnMode) return &ScanMethod{ PacketSource: psrc, Processor: pp, @@ -64,9 +64,14 @@ type PacketProcessor struct { rcvICMP layers.ICMPv4 } -func NewPacketProcessor(scanType string, results scan.ResultChan) *PacketProcessor { +func NewPacketProcessor(scanType string, results scan.ResultChan, vpnMode bool) *PacketProcessor { p := &PacketProcessor{scanType: scanType, results: results} - parser := gopacket.NewDecodingLayerParser(layers.LayerTypeEthernet, &p.rcvEth, &p.rcvIP, &p.rcvICMP) + + layerType := layers.LayerTypeEthernet + if vpnMode { + layerType = layers.LayerTypeIPv4 + } + parser := gopacket.NewDecodingLayerParser(layerType, &p.rcvEth, &p.rcvIP, &p.rcvICMP) parser.IgnoreUnsupported = true p.parser = parser return p @@ -76,12 +81,12 @@ func (p *PacketProcessor) Results() <-chan scan.Result { return p.results.Chan() } -func (p *PacketProcessor) ProcessPacketData(data []byte, _ *gopacket.CaptureInfo) error { - if err := p.parser.DecodeLayers(data, &p.rcvDecoded); err != nil { - return err +func (p *PacketProcessor) ProcessPacketData(data []byte, _ *gopacket.CaptureInfo) (err error) { + if err = p.parser.DecodeLayers(data, &p.rcvDecoded); err != nil { + return } - if len(p.rcvDecoded) != 3 { - return nil + if !validPacket(p.rcvDecoded) { + return } p.results.Put(&ScanResult{ @@ -93,7 +98,11 @@ func (p *PacketProcessor) ProcessPacketData(data []byte, _ *gopacket.CaptureInfo Code: p.rcvICMP.TypeCode.Code(), }, }) - return nil + return +} + +func validPacket(decoded []gopacket.LayerType) bool { + return len(decoded) == 3 || (len(decoded) == 2 && decoded[0] == layers.LayerTypeIPv4) } type PacketFiller struct { @@ -104,6 +113,7 @@ type PacketFiller struct { typ uint8 code uint8 payload []byte + vpnMode bool } // Assert that icmp.PacketFiller conforms to the scan.PacketFiller interface @@ -155,6 +165,12 @@ func WithPayload(payload []byte) PacketFillerOption { } } +func WithVPNmode(vpnMode bool) PacketFillerOption { + return func(f *PacketFiller) { + f.vpnMode = vpnMode + } +} + func NewPacketFiller(opts ...PacketFillerOption) *PacketFiller { payload := make([]byte, 48) rand.Read(payload) @@ -174,11 +190,6 @@ func NewPacketFiller(opts ...PacketFillerOption) *PacketFiller { } func (f *PacketFiller) Fill(packet gopacket.SerializeBuffer, r *scan.Request) (err error) { - eth := &layers.Ethernet{ - SrcMAC: r.SrcMAC, - DstMAC: r.DstMAC, - EthernetType: layers.EthernetTypeIPv4, - } ip := &layers.IPv4{ Version: 4, @@ -206,5 +217,14 @@ func (f *PacketFiller) Fill(packet gopacket.SerializeBuffer, r *scan.Request) (e if ip.Length == 0 { opt.FixLengths = true } + + if f.vpnMode { + return gopacket.SerializeLayers(packet, opt, ip, icmp, gopacket.Payload(f.payload)) + } + eth := &layers.Ethernet{ + SrcMAC: r.SrcMAC, + DstMAC: r.DstMAC, + EthernetType: layers.EthernetTypeIPv4, + } return gopacket.SerializeLayers(packet, opt, eth, ip, icmp, gopacket.Payload(f.payload)) } diff --git a/pkg/scan/icmp/icmp_test.go b/pkg/scan/icmp/icmp_test.go index b17a69b..6d0dd92 100644 --- a/pkg/scan/icmp/icmp_test.go +++ b/pkg/scan/icmp/icmp_test.go @@ -13,7 +13,7 @@ import ( "github.com/v-byte-cpu/sx/pkg/scan" ) -func TestPacketFiller(t *testing.T) { +func TestPacketFillerEthernet(t *testing.T) { t.Parallel() filler := NewPacketFiller( @@ -55,6 +55,45 @@ func TestPacketFiller(t *testing.T) { require.Equal(t, 48, len(icmp.Payload)) } +func TestPacketFillerIPv4(t *testing.T) { + t.Parallel() + + filler := NewPacketFiller( + WithType(layers.ICMPv4TypeTimestampRequest), WithCode(1), WithVPNmode(true)) + packet := gopacket.NewSerializeBuffer() + err := filler.Fill(packet, &scan.Request{ + SrcIP: net.IPv4(192, 168, 0, 3).To4(), + DstIP: net.IPv4(192, 168, 0, 2).To4(), + SrcMAC: net.HardwareAddr{0x1, 0x2, 0x3, 0x4, 0x5, 0x6}, + DstMAC: net.HardwareAddr{0x10, 0x11, 0x12, 0x13, 0x14, 0x15}, + }) + require.NoError(t, err) + + resultPacket := gopacket.NewPacket(packet.Bytes(), layers.LayerTypeIPv4, gopacket.Default) + + ethLayer := resultPacket.Layer(layers.LayerTypeEthernet) + require.Nil(t, ethLayer, "ethernet layer is not empty") + + ipLayer := resultPacket.Layer(layers.LayerTypeIPv4) + require.NotNil(t, ipLayer, "ip layer is empty") + ip := ipLayer.(*layers.IPv4) + require.Equal(t, net.IPv4(192, 168, 0, 3).To4(), ip.SrcIP.To4()) + require.Equal(t, net.IPv4(192, 168, 0, 2).To4(), ip.DstIP.To4()) + require.Equal(t, uint8(64), ip.TTL) + require.Equal(t, uint8(5), ip.IHL) + // IP header + ICMP header + payload length + require.Equal(t, uint16(20+8+48), ip.Length) + require.Equal(t, layers.IPProtocolICMPv4, ip.Protocol) + require.Equal(t, layers.IPv4DontFragment, ip.Flags) + + icmpLayer := resultPacket.Layer(layers.LayerTypeICMPv4) + require.NotNil(t, icmpLayer, "icmp layer is empty") + icmp := icmpLayer.(*layers.ICMPv4) + require.Equal(t, uint8(layers.ICMPv4TypeTimestampRequest), icmp.TypeCode.Type()) + require.Equal(t, uint8(1), icmp.TypeCode.Code()) + require.Equal(t, 48, len(icmp.Payload)) +} + func TestPacketFillerPayload(t *testing.T) { t.Parallel() @@ -173,7 +212,7 @@ func TestPacketFillerIPFlags(t *testing.T) { require.Equal(t, layers.IPv4DontFragment|layers.IPv4MoreFragments, ip.Flags) } -func TestProcessPacketData(t *testing.T) { +func TestProcessPacketDataEthernet(t *testing.T) { t.Parallel() done := make(chan interface{}) @@ -184,7 +223,7 @@ func TestProcessPacketData(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() results := scan.NewResultChan(ctx, 1000) - p := NewPacketProcessor(ScanType, results) + p := NewPacketProcessor(ScanType, results, false) // generate packet data packet := gopacket.NewSerializeBuffer() @@ -241,3 +280,67 @@ func TestProcessPacketData(t *testing.T) { t.Fatal("test timeout") } } + +func TestProcessPacketDataIPv4(t *testing.T) { + t.Parallel() + + done := make(chan interface{}) + + go func() { + defer close(done) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + results := scan.NewResultChan(ctx, 1000) + p := NewPacketProcessor(ScanType, results, true) + + // generate packet data + packet := gopacket.NewSerializeBuffer() + + ip := &layers.IPv4{ + Version: 4, + Id: 12345, + Flags: layers.IPv4DontFragment, + TTL: 64, + Protocol: layers.IPProtocolICMPv4, + SrcIP: net.IPv4(192, 168, 0, 2).To4(), + DstIP: net.IPv4(192, 168, 0, 3).To4(), + } + + icmp := &layers.ICMPv4{ + TypeCode: layers.CreateICMPv4TypeCode( + layers.ICMPv4TypeDestinationUnreachable, layers.ICMPv4CodeHost), + } + + opt := gopacket.SerializeOptions{ + FixLengths: true, + ComputeChecksums: true, + } + err := gopacket.SerializeLayers(packet, opt, ip, icmp) + require.NoError(t, err) + + err = p.ProcessPacketData(packet.Bytes(), &gopacket.CaptureInfo{}) + require.NoError(t, err) + + result, ok := <-p.Results() + if !ok { + require.FailNow(t, "results chan is empty") + } + icmpResult := result.(*ScanResult) + assert.Equal(t, ScanType, icmpResult.ScanType) + assert.Equal(t, net.IPv4(192, 168, 0, 2).To4().String(), icmpResult.IP) + assert.Equal(t, uint8(64), icmpResult.TTL) + require.NotNil(t, icmpResult.ICMP) + assert.Equal(t, uint8(layers.ICMPv4TypeDestinationUnreachable), icmpResult.ICMP.Type) + assert.Equal(t, uint8(layers.ICMPv4CodeHost), icmpResult.ICMP.Code) + + cancel() + _, ok = <-p.Results() + require.False(t, ok, "results chan is not closed") + }() + select { + case <-done: + case <-time.After(3 * time.Second): + t.Fatal("test timeout") + } +} diff --git a/pkg/scan/tcp/tcp.go b/pkg/scan/tcp/tcp.go index 42e2a35..00c1144 100644 --- a/pkg/scan/tcp/tcp.go +++ b/pkg/scan/tcp/tcp.go @@ -88,6 +88,7 @@ type ScanMethod struct { pktFilter PacketFilterFunc pktFlags PacketFlagsFunc results scan.ResultChan + vpnMode bool rcvDecoded []gopacket.LayerType rcvEth layers.Ethernet @@ -112,6 +113,12 @@ func WithPacketFlagsFunc(pktFlags PacketFlagsFunc) ScanMethodOption { } } +func WithScanVPNmode(vpnMode bool) ScanMethodOption { + return func(s *ScanMethod) { + s.vpnMode = vpnMode + } +} + func NewScanMethod(scanType string, psrc scan.PacketSource, results scan.ResultChan, opts ...ScanMethodOption) *ScanMethod { sm := &ScanMethod{ @@ -121,14 +128,18 @@ func NewScanMethod(scanType string, psrc scan.PacketSource, pktFilter: TrueFilter, pktFlags: AllFlags, } - parser := gopacket.NewDecodingLayerParser(layers.LayerTypeEthernet, &sm.rcvEth, &sm.rcvIP, &sm.rcvTCP) - parser.IgnoreUnsupported = true - sm.parser = parser - // options pattern for _, o := range opts { o(sm) } + + layerType := layers.LayerTypeEthernet + if sm.vpnMode { + layerType = layers.LayerTypeIPv4 + } + parser := gopacket.NewDecodingLayerParser(layerType, &sm.rcvEth, &sm.rcvIP, &sm.rcvTCP) + parser.IgnoreUnsupported = true + sm.parser = parser return sm } @@ -136,12 +147,12 @@ func (s *ScanMethod) Results() <-chan scan.Result { return s.results.Chan() } -func (s *ScanMethod) ProcessPacketData(data []byte, _ *gopacket.CaptureInfo) error { - if err := s.parser.DecodeLayers(data, &s.rcvDecoded); err != nil { - return err +func (s *ScanMethod) ProcessPacketData(data []byte, _ *gopacket.CaptureInfo) (err error) { + if err = s.parser.DecodeLayers(data, &s.rcvDecoded); err != nil { + return } - if len(s.rcvDecoded) != 3 { - return nil + if !validPacket(s.rcvDecoded) { + return } if s.pktFilter(&s.rcvTCP) { @@ -152,7 +163,11 @@ func (s *ScanMethod) ProcessPacketData(data []byte, _ *gopacket.CaptureInfo) err Flags: s.pktFlags(&s.rcvTCP), }) } - return nil + return +} + +func validPacket(decoded []gopacket.LayerType) bool { + return len(decoded) == 3 || (len(decoded) == 2 && decoded[0] == layers.LayerTypeIPv4) } type PacketFiller struct { @@ -165,6 +180,8 @@ type PacketFiller struct { ECE bool CWR bool NS bool + + vpnMode bool } // Assert that tcp.PacketFiller conforms to the scan.PacketFiller interface @@ -226,6 +243,12 @@ func WithNS() PacketFillerOption { } } +func WithFillerVPNmode(vpnMode bool) PacketFillerOption { + return func(f *PacketFiller) { + f.vpnMode = vpnMode + } +} + func NewPacketFiller(opts ...PacketFillerOption) *PacketFiller { f := &PacketFiller{} for _, o := range opts { @@ -235,11 +258,6 @@ func NewPacketFiller(opts ...PacketFillerOption) *PacketFiller { } func (f *PacketFiller) Fill(packet gopacket.SerializeBuffer, r *scan.Request) (err error) { - eth := &layers.Ethernet{ - SrcMAC: r.SrcMAC, - DstMAC: r.DstMAC, - EthernetType: layers.EthernetTypeIPv4, - } ip := &layers.IPv4{ Version: 4, @@ -253,7 +271,6 @@ func (f *PacketFiller) Fill(packet gopacket.SerializeBuffer, r *scan.Request) (e SrcIP: r.SrcIP, DstIP: r.DstIP, } - tcp := &layers.TCP{ // emulate Linux default ephemeral ports range: 32768 60999 // cat /proc/sys/net/ipv4/ip_local_port_range @@ -289,9 +306,16 @@ func (f *PacketFiller) Fill(packet gopacket.SerializeBuffer, r *scan.Request) (e }, } if err = tcp.SetNetworkLayerForChecksum(ip); err != nil { - return err + return } - opt := gopacket.SerializeOptions{FixLengths: true, ComputeChecksums: true} + if f.vpnMode { + return gopacket.SerializeLayers(packet, opt, ip, tcp) + } + eth := &layers.Ethernet{ + SrcMAC: r.SrcMAC, + DstMAC: r.DstMAC, + EthernetType: layers.EthernetTypeIPv4, + } return gopacket.SerializeLayers(packet, opt, eth, ip, tcp) } diff --git a/pkg/scan/tcp/tcp_test.go b/pkg/scan/tcp/tcp_test.go index 2f8ed9a..2c5c90e 100644 --- a/pkg/scan/tcp/tcp_test.go +++ b/pkg/scan/tcp/tcp_test.go @@ -15,7 +15,7 @@ import ( "github.com/v-byte-cpu/sx/pkg/scan/arp" ) -func TestPacketFiller(t *testing.T) { +func TestPacketFillerEthernet(t *testing.T) { t.Parallel() tests := []struct { @@ -127,7 +127,116 @@ func TestPacketFiller(t *testing.T) { } } -func TestProcessPacketData(t *testing.T) { +func TestPacketFillerIPv4(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + filler *PacketFiller + SYN bool + ACK bool + FIN bool + RST bool + PSH bool + URG bool + ECE bool + CWR bool + NS bool + }{ + { + name: "SYN", + filler: NewPacketFiller(WithSYN(), WithFillerVPNmode(true)), + SYN: true, + }, + { + name: "ACK", + filler: NewPacketFiller(WithACK(), WithFillerVPNmode(true)), + ACK: true, + }, + { + name: "FIN", + filler: NewPacketFiller(WithFIN(), WithFillerVPNmode(true)), + FIN: true, + }, + { + name: "RST", + filler: NewPacketFiller(WithRST(), WithFillerVPNmode(true)), + RST: true, + }, + { + name: "PSH", + filler: NewPacketFiller(WithPSH(), WithFillerVPNmode(true)), + PSH: true, + }, + { + name: "URG", + filler: NewPacketFiller(WithURG(), WithFillerVPNmode(true)), + URG: true, + }, + { + name: "ECE", + filler: NewPacketFiller(WithECE(), WithFillerVPNmode(true)), + ECE: true, + }, + { + name: "CWR", + filler: NewPacketFiller(WithCWR(), WithFillerVPNmode(true)), + CWR: true, + }, + { + name: "NS", + filler: NewPacketFiller(WithNS(), WithFillerVPNmode(true)), + NS: true, + }, + } + + for _, vtt := range tests { + tt := vtt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + packet := gopacket.NewSerializeBuffer() + err := tt.filler.Fill(packet, &scan.Request{ + SrcIP: net.IPv4(192, 168, 0, 3).To4(), + DstIP: net.IPv4(192, 168, 0, 2).To4(), + SrcMAC: net.HardwareAddr{0x1, 0x2, 0x3, 0x4, 0x5, 0x6}, + DstMAC: net.HardwareAddr{0x10, 0x11, 0x12, 0x13, 0x14, 0x15}, + DstPort: 4567, + }) + require.NoError(t, err) + + resultPacket := gopacket.NewPacket(packet.Bytes(), layers.LayerTypeIPv4, gopacket.Default) + + ethLayer := resultPacket.Layer(layers.LayerTypeEthernet) + require.Nil(t, ethLayer, "ethernet layer is not empty") + + ipLayer := resultPacket.Layer(layers.LayerTypeIPv4) + require.NotNil(t, ipLayer, "ip layer is empty") + ip := ipLayer.(*layers.IPv4) + require.Equal(t, net.IPv4(192, 168, 0, 3).To4(), ip.SrcIP.To4()) + require.Equal(t, net.IPv4(192, 168, 0, 2).To4(), ip.DstIP.To4()) + + tcpLayer := resultPacket.Layer(layers.LayerTypeTCP) + require.NotNil(t, tcpLayer, "tcp layer is empty") + tcp := tcpLayer.(*layers.TCP) + require.GreaterOrEqual(t, tcp.SrcPort, uint16(32768)) + require.LessOrEqual(t, tcp.SrcPort, uint16(60999)) + require.Equal(t, uint16(4567), uint16(tcp.DstPort)) + + require.Equal(t, tt.SYN, tcp.SYN) + require.Equal(t, tt.ACK, tcp.ACK) + require.Equal(t, tt.FIN, tcp.FIN) + require.Equal(t, tt.RST, tcp.RST) + require.Equal(t, tt.PSH, tcp.PSH) + require.Equal(t, tt.URG, tcp.URG) + require.Equal(t, tt.ECE, tcp.ECE) + require.Equal(t, tt.CWR, tcp.CWR) + require.Equal(t, tt.NS, tcp.NS) + }) + } +} + +func TestProcessPacketDataEthernet(t *testing.T) { t.Parallel() done := make(chan interface{}) @@ -198,6 +307,72 @@ func TestProcessPacketData(t *testing.T) { } } +func TestProcessPacketDataIPv4(t *testing.T) { + t.Parallel() + + done := make(chan interface{}) + + go func() { + defer close(done) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + results := scan.NewResultChan(ctx, 1000) + sm := NewScanMethod(SYNScanType, nil, results, WithScanVPNmode(true)) + + // generate packet data + packet := gopacket.NewSerializeBuffer() + + ip := &layers.IPv4{ + Version: 4, + Id: 12345, + Flags: layers.IPv4DontFragment, + TTL: 64, + Protocol: layers.IPProtocolTCP, + SrcIP: net.IPv4(192, 168, 0, 2).To4(), + DstIP: net.IPv4(192, 168, 0, 3).To4(), + } + + tcp := &layers.TCP{ + SrcPort: layers.TCPPort(22), + DstPort: layers.TCPPort(45678), + Seq: 1234567, + SYN: true, + ACK: true, + } + err := tcp.SetNetworkLayerForChecksum(ip) + require.NoError(t, err) + + opt := gopacket.SerializeOptions{ + FixLengths: true, + ComputeChecksums: true, + } + err = gopacket.SerializeLayers(packet, opt, ip, tcp) + require.NoError(t, err) + + err = sm.ProcessPacketData(packet.Bytes(), &gopacket.CaptureInfo{}) + require.NoError(t, err) + + result, ok := <-sm.Results() + if !ok { + require.FailNow(t, "results chan is empty") + } + tcpResult := result.(*ScanResult) + assert.Equal(t, SYNScanType, tcpResult.ScanType) + assert.Equal(t, net.IPv4(192, 168, 0, 2).To4().String(), tcpResult.IP) + assert.Equal(t, uint16(22), tcpResult.Port) + + cancel() + _, ok = <-sm.Results() + require.False(t, ok, "results chan is not closed") + }() + select { + case <-done: + case <-time.After(3 * time.Second): + t.Fatal("test timeout") + } +} + func TestAllFlags(t *testing.T) { t.Parallel() tests := []struct { diff --git a/pkg/scan/udp/udp.go b/pkg/scan/udp/udp.go index 61b15bc..6a0a1a4 100644 --- a/pkg/scan/udp/udp.go +++ b/pkg/scan/udp/udp.go @@ -25,8 +25,8 @@ type ScanMethod struct { // Assert that udp.ScanMethod conforms to the scan.PacketMethod interface var _ scan.PacketMethod = (*ScanMethod)(nil) -func NewScanMethod(psrc scan.PacketSource, results scan.ResultChan) *ScanMethod { - pp := icmp.NewPacketProcessor(ScanType, results) +func NewScanMethod(psrc scan.PacketSource, results scan.ResultChan, vpnMode bool) *ScanMethod { + pp := icmp.NewPacketProcessor(ScanType, results, vpnMode) return &ScanMethod{ PacketSource: psrc, Processor: pp, @@ -40,6 +40,7 @@ type PacketFiller struct { proto layers.IPProtocol flags layers.IPv4Flag payload []byte + vpnMode bool } // Assert that udp.PacketFiller conforms to the scan.PacketFiller interface @@ -79,6 +80,12 @@ func WithPayload(payload []byte) PacketFillerOption { } } +func WithVPNmode(vpnMode bool) PacketFillerOption { + return func(f *PacketFiller) { + f.vpnMode = vpnMode + } +} + func NewPacketFiller(opts ...PacketFillerOption) *PacketFiller { f := &PacketFiller{ // typical TTL value for Linux @@ -93,11 +100,6 @@ func NewPacketFiller(opts ...PacketFillerOption) *PacketFiller { } func (f *PacketFiller) Fill(packet gopacket.SerializeBuffer, r *scan.Request) (err error) { - eth := &layers.Ethernet{ - SrcMAC: r.SrcMAC, - DstMAC: r.DstMAC, - EthernetType: layers.EthernetTypeIPv4, - } ip := &layers.IPv4{ Version: 4, @@ -130,5 +132,13 @@ func (f *PacketFiller) Fill(packet gopacket.SerializeBuffer, r *scan.Request) (e if ip.Length == 0 { opt.FixLengths = true } + if f.vpnMode { + return gopacket.SerializeLayers(packet, opt, ip, udp, gopacket.Payload(f.payload)) + } + eth := &layers.Ethernet{ + SrcMAC: r.SrcMAC, + DstMAC: r.DstMAC, + EthernetType: layers.EthernetTypeIPv4, + } return gopacket.SerializeLayers(packet, opt, eth, ip, udp, gopacket.Payload(f.payload)) } diff --git a/pkg/scan/udp/udp_test.go b/pkg/scan/udp/udp_test.go index b77b4fe..b4d158e 100644 --- a/pkg/scan/udp/udp_test.go +++ b/pkg/scan/udp/udp_test.go @@ -14,7 +14,7 @@ import ( "github.com/v-byte-cpu/sx/pkg/scan/icmp" ) -func TestPacketFiller(t *testing.T) { +func TestPacketFillerEthernet(t *testing.T) { t.Parallel() filler := NewPacketFiller() @@ -57,6 +57,46 @@ func TestPacketFiller(t *testing.T) { require.Equal(t, []byte{}, udp.Payload) } +func TestPacketFillerIPv4(t *testing.T) { + t.Parallel() + + filler := NewPacketFiller(WithVPNmode(true)) + packet := gopacket.NewSerializeBuffer() + err := filler.Fill(packet, &scan.Request{ + SrcIP: net.IPv4(192, 168, 0, 3).To4(), + DstIP: net.IPv4(192, 168, 0, 2).To4(), + SrcMAC: net.HardwareAddr{0x1, 0x2, 0x3, 0x4, 0x5, 0x6}, + DstMAC: net.HardwareAddr{0x10, 0x11, 0x12, 0x13, 0x14, 0x15}, + DstPort: 4567, + }) + require.NoError(t, err) + + resultPacket := gopacket.NewPacket(packet.Bytes(), layers.LayerTypeIPv4, gopacket.Default) + + ethLayer := resultPacket.Layer(layers.LayerTypeEthernet) + require.Nil(t, ethLayer, "ethernet layer is not empty") + + ipLayer := resultPacket.Layer(layers.LayerTypeIPv4) + require.NotNil(t, ipLayer, "ip layer is empty") + ip := ipLayer.(*layers.IPv4) + require.Equal(t, net.IPv4(192, 168, 0, 3).To4(), ip.SrcIP.To4()) + require.Equal(t, net.IPv4(192, 168, 0, 2).To4(), ip.DstIP.To4()) + require.Equal(t, uint8(64), ip.TTL) + require.Equal(t, uint8(5), ip.IHL) + // IP header + UDP header length + require.Equal(t, uint16(20+8), ip.Length) + require.Equal(t, layers.IPProtocolUDP, ip.Protocol) + require.Equal(t, layers.IPv4DontFragment, ip.Flags) + + udpLayer := resultPacket.Layer(layers.LayerTypeUDP) + require.NotNil(t, udpLayer, "udp layer is empty") + udp := udpLayer.(*layers.UDP) + require.GreaterOrEqual(t, udp.SrcPort, uint16(32768)) + require.LessOrEqual(t, udp.SrcPort, uint16(60999)) + require.Equal(t, uint16(4567), uint16(udp.DstPort)) + require.Equal(t, []byte{}, udp.Payload) +} + func TestPacketFillerPayload(t *testing.T) { t.Parallel() @@ -199,7 +239,7 @@ func TestPacketFillerIPFlags(t *testing.T) { require.Equal(t, layers.IPv4DontFragment|layers.IPv4MoreFragments, ip.Flags) } -func TestProcessPacketData(t *testing.T) { +func TestProcessPacketDataEthernet(t *testing.T) { t.Parallel() done := make(chan interface{}) @@ -210,7 +250,7 @@ func TestProcessPacketData(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() results := scan.NewResultChan(ctx, 1000) - sm := NewScanMethod(nil, results) + sm := NewScanMethod(nil, results, false) // generate packet data packet := gopacket.NewSerializeBuffer() @@ -266,3 +306,66 @@ func TestProcessPacketData(t *testing.T) { t.Fatal("test timeout") } } + +func TestProcessPacketDataIPv4(t *testing.T) { + t.Parallel() + + done := make(chan interface{}) + + go func() { + defer close(done) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + results := scan.NewResultChan(ctx, 1000) + sm := NewScanMethod(nil, results, true) + + // generate packet data + packet := gopacket.NewSerializeBuffer() + + ip := &layers.IPv4{ + Version: 4, + Id: 12345, + Flags: layers.IPv4DontFragment, + TTL: 64, + Protocol: layers.IPProtocolICMPv4, + SrcIP: net.IPv4(192, 168, 0, 2).To4(), + DstIP: net.IPv4(192, 168, 0, 3).To4(), + } + + icmpLayer := &layers.ICMPv4{ + TypeCode: layers.CreateICMPv4TypeCode( + layers.ICMPv4TypeDestinationUnreachable, layers.ICMPv4CodePort), + } + + opt := gopacket.SerializeOptions{ + FixLengths: true, + ComputeChecksums: true, + } + err := gopacket.SerializeLayers(packet, opt, ip, icmpLayer) + require.NoError(t, err) + + err = sm.ProcessPacketData(packet.Bytes(), &gopacket.CaptureInfo{}) + require.NoError(t, err) + + result, ok := <-sm.Results() + if !ok { + require.FailNow(t, "results chan is empty") + } + icmpResult := result.(*icmp.ScanResult) + assert.Equal(t, ScanType, icmpResult.ScanType) + assert.Equal(t, net.IPv4(192, 168, 0, 2).To4().String(), icmpResult.IP) + require.NotNil(t, icmpResult.ICMP) + assert.Equal(t, uint8(layers.ICMPv4TypeDestinationUnreachable), icmpResult.ICMP.Type) + assert.Equal(t, uint8(layers.ICMPv4CodePort), icmpResult.ICMP.Code) + + cancel() + _, ok = <-sm.Results() + require.False(t, ok, "results chan is not closed") + }() + select { + case <-done: + case <-time.After(3 * time.Second): + t.Fatal("test timeout") + } +}