Skip to content

Commit

Permalink
refactoring: simplify cli flags parsing (#74)
Browse files Browse the repository at this point in the history
  • Loading branch information
v-byte-cpu authored May 25, 2021
1 parent 345e86a commit 30b669f
Show file tree
Hide file tree
Showing 17 changed files with 1,333 additions and 996 deletions.
121 changes: 71 additions & 50 deletions command/arp.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,67 +16,88 @@ import (
"github.com/v-byte-cpu/sx/pkg/scan/arp"
)

var (
cliARPLiveTimeoutFlag string
cliARPLiveTimeout time.Duration
)
func newARPCmd() *arpCmd {
c := &arpCmd{}

func init() {
addPacketScanOptions(arpCmd, withoutGatewayMAC())
arpCmd.Flags().StringVar(&cliARPLiveTimeoutFlag, "live", "", "enable live mode")
rootCmd.AddCommand(arpCmd)
}
cmd := &cobra.Command{
Use: "arp [flags] subnet",
Example: strings.Join([]string{"arp 192.168.0.1/24", "arp 10.0.0.1"}, "\n"),
Short: "Perform ARP scan",
RunE: func(cmd *cobra.Command, args []string) (err error) {
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
defer cancel()

var arpCmd = &cobra.Command{
Use: "arp [flags] subnet",
Example: strings.Join([]string{"arp 192.168.0.1/24", "arp 10.0.0.1"}, "\n"),
Short: "Perform ARP scan",
PreRunE: func(cmd *cobra.Command, args []string) (err error) {
if len(cliARPLiveTimeoutFlag) > 0 {
if cliARPLiveTimeout, err = time.ParseDuration(cliARPLiveTimeoutFlag); err != nil {
if len(args) != 1 {
return errors.New("requires one ip subnet argument")
}
dstSubnet, err := ip.ParseIPNet(args[0])
if err != nil {
return
}
}
if len(args) != 1 {
return errors.New("requires one ip subnet argument")
}
cliDstSubnet, err = ip.ParseIPNet(args[0])
return
},
RunE: func(cmd *cobra.Command, args []string) (err error) {
var r *scan.Range
if r, err = getScanRange(cliDstSubnet); err != nil {
return err
}

ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
defer cancel()
if err = c.opts.parseRawOptions(); err != nil {
return
}
var r *scan.Range
if r, err = c.opts.getScanRange(dstSubnet); err != nil {
return err
}
var logger log.Logger
if logger, err = c.opts.getLogger(); err != nil {
return err
}

var logger log.Logger
if logger, err = getLogger("arp", os.Stdout); err != nil {
return err
}
if cliARPLiveTimeout > 0 {
logger = log.NewUniqueLogger(logger)
}
m := c.opts.newARPScanMethod(ctx)

m := newARPScanMethod(ctx)
return startPacketScanEngine(ctx, newPacketScanConfig(
withPacketScanMethod(m),
withPacketBPFFilter(arp.BPFFilter),
withRateCount(c.opts.rateCount),
withRateWindow(c.opts.rateWindow),
withPacketEngineConfig(newEngineConfig(
withLogger(logger),
withScanRange(r),
withExitDelay(c.opts.exitDelay),
)),
))
},
}

return startPacketScanEngine(ctx, newPacketScanConfig(
withPacketScanMethod(m),
withPacketBPFFilter(arp.BPFFilter),
withPacketEngineConfig(newEngineConfig(
withLogger(logger),
withScanRange(r),
)),
))
},
c.opts.initCliFlags(cmd)

c.cmd = cmd
return c
}

type arpCmd struct {
cmd *cobra.Command
opts arpCmdOpts
}

type arpCmdOpts struct {
packetScanCmdOpts
liveTimeout time.Duration
}

func (o *arpCmdOpts) initCliFlags(cmd *cobra.Command) {
o.packetScanCmdOpts.initCliFlags(cmd)
cmd.Flags().DurationVar(&o.liveTimeout, "live", 0, "enable live mode")
}

func (o *arpCmdOpts) getLogger() (logger log.Logger, err error) {
if logger, err = o.packetScanCmdOpts.getLogger("arp", os.Stdout); err != nil {
return
}
if o.liveTimeout > 0 {
logger = log.NewUniqueLogger(logger)
}
return
}

func newARPScanMethod(ctx context.Context) *arp.ScanMethod {
func (o *arpCmdOpts) newARPScanMethod(ctx context.Context) *arp.ScanMethod {
var reqgen scan.RequestGenerator = scan.NewIPRequestGenerator(scan.NewIPGenerator())
if cliARPLiveTimeout > 0 {
reqgen = scan.NewLiveRequestGenerator(reqgen, cliARPLiveTimeout)
if o.liveTimeout > 0 {
reqgen = scan.NewLiveRequestGenerator(reqgen, o.liveTimeout)
}
pktgen := scan.NewPacketMultiGenerator(arp.NewPacketFiller(), runtime.NumCPU())
psrc := scan.NewPacketSource(reqgen, pktgen)
Expand Down
21 changes: 21 additions & 0 deletions command/arp_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package command

import (
"testing"

"github.com/stretchr/testify/require"
)

func TestArpCmdDstSubnetRequiredArg(t *testing.T) {
cmd := newARPCmd().cmd
err := cmd.Execute()
require.Error(t, err)
require.Equal(t, "requires one ip subnet argument", err.Error())
}

func TestArpCmdInvalidDstSubnet(t *testing.T) {
cmd := newARPCmd().cmd
cmd.SetArgs([]string{"invalid_ip_address"})
err := cmd.Execute()
require.Error(t, err)
}
Loading

0 comments on commit 30b669f

Please sign in to comment.