Skip to content

Commit

Permalink
feature: randomized port iterator (#90)
Browse files Browse the repository at this point in the history
  • Loading branch information
v-byte-cpu authored Jun 26, 2021
1 parent c3c93c7 commit e75f9ff
Show file tree
Hide file tree
Showing 5 changed files with 248 additions and 249 deletions.
15 changes: 0 additions & 15 deletions pkg/ip/ip.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,6 @@ import (

var ErrInvalidAddr = errors.New("invalid IP subnet/host")

func Inc(ip net.IP) {
for j := len(ip) - 1; j >= 0; j-- {
ip[j]++
if ip[j] > 0 {
break
}
}
}

func DupIP(ip net.IP) net.IP {
dup := make([]byte, 4)
copy(dup, ip.To4())
return dup
}

func ParseIPNet(subnet string) (*net.IPNet, error) {
_, result, err := net.ParseCIDR(subnet)
if err == nil {
Expand Down
50 changes: 0 additions & 50 deletions pkg/ip/ip_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,56 +7,6 @@ import (
"github.com/stretchr/testify/assert"
)

func TestInc(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input net.IP
expected net.IP
}{
{
name: "ZeroNet",
input: net.IPv4(0, 0, 0, 0),
expected: net.IPv4(0, 0, 0, 1),
},
{
name: "Inc3rd",
input: net.IPv4(1, 1, 0, 255),
expected: net.IPv4(1, 1, 1, 0),
},
{
name: "Inc2nd",
input: net.IPv4(1, 1, 255, 255),
expected: net.IPv4(1, 2, 0, 0),
},
{
name: "Inc1st",
input: net.IPv4(1, 255, 255, 255),
expected: net.IPv4(2, 0, 0, 0),
},
}

for _, vtt := range tests {
tt := vtt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
Inc(tt.input)
assert.Equal(t, tt.expected, tt.input)
})
}
}

func TestDupIP(t *testing.T) {
t.Parallel()
ipAddr := net.IPv4(192, 168, 0, 1).To4()

dupAddr := DupIP(ipAddr)
assert.Equal(t, ipAddr, dupAddr)

dupAddr[3]++
assert.Equal(t, net.IPv4(192, 168, 0, 1).To4(), ipAddr)
}

func TestParseIPNetWithError(t *testing.T) {
t.Parallel()
_, err := ParseIPNet("")
Expand Down
74 changes: 18 additions & 56 deletions pkg/scan/mock_request_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

57 changes: 46 additions & 11 deletions pkg/scan/request.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//go:generate mockgen -package scan -destination=mock_request_test.go -source request.go
//go:generate mockgen -package scan -destination=mock_request_test.go . PortGenerator,IPGenerator,RequestGenerator,IPContainer
//go:generate easyjson -output_filename request_easyjson.go request.go

package scan
Expand Down Expand Up @@ -31,37 +31,67 @@ type Request struct {
Err error
}

type PortGetter interface {
GetPort() (uint16, error)
}

type WrapPort uint16

func (p WrapPort) GetPort() (uint16, error) {
return uint16(p), nil
}

type portError struct {
error
}

func (err *portError) GetPort() (uint16, error) {
return 0, err
}

type PortGenerator interface {
Ports(ctx context.Context, r *Range) (<-chan uint16, error)
Ports(ctx context.Context, r *Range) (<-chan PortGetter, error)
}

func NewPortGenerator() PortGenerator {
return &portGenerator{}
}

// TODO randomizedPortGenerator
type portGenerator struct{}

func (*portGenerator) Ports(ctx context.Context, r *Range) (<-chan uint16, error) {
func (*portGenerator) Ports(ctx context.Context, r *Range) (<-chan PortGetter, error) {
if err := validatePorts(r.Ports); err != nil {
return nil, err
}
out := make(chan uint16, 100)
out := make(chan PortGetter, 100)
go func() {
defer close(out)
for _, portRange := range r.Ports {
for port := int(portRange.StartPort); port <= int(portRange.EndPort); port++ {
select {
case <-ctx.Done():
return
case out <- uint16(port):
it, err := newRangeIterator(int64(portRange.EndPort) - int64(portRange.StartPort) + 1)
if err != nil {
writePort(ctx, out, &portError{err})
continue
}
basePort := int64(portRange.StartPort) - 1
for {
writePort(ctx, out, WrapPort(basePort+it.Int().Int64()))
if !it.Next() {
break
}
}
}
}()
return out, nil
}

func writePort(ctx context.Context, out chan<- PortGetter, port PortGetter) {
select {
case <-ctx.Done():
return
case out <- port:
}
}

func validatePorts(ports []*PortRange) error {
if len(ports) == 0 {
return ErrPortRange
Expand Down Expand Up @@ -153,7 +183,12 @@ func (rg *ipPortGenerator) GenerateRequests(ctx context.Context, r *Range) (<-ch
out := make(chan *Request, 100)
go func() {
defer close(out)
for port := range ports {
for p := range ports {
port, err := p.GetPort()
if err != nil {
writeRequest(ctx, out, &Request{Err: err})
continue
}
for ipaddr := range ips {
dstip, err := ipaddr.GetIP()
writeRequest(ctx, out, &Request{
Expand Down
Loading

0 comments on commit e75f9ff

Please sign in to comment.