Skip to content

Commit

Permalink
consider protocolType in max host error (projectdiscovery#5668)
Browse files Browse the repository at this point in the history
* consider protocolType in max host error

* add mutex when updating internal-event
  • Loading branch information
tarunKoyalwar authored Sep 28, 2024
1 parent e4dae52 commit 1f945d6
Show file tree
Hide file tree
Showing 10 changed files with 37 additions and 29 deletions.
2 changes: 1 addition & 1 deletion pkg/core/executors.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func (e *Engine) executeTemplateWithTargets(ctx context.Context, template *templ
currentInfo.Unlock()

// Skip if the host has had errors
if e.executerOpts.HostErrorsCache != nil && e.executerOpts.HostErrorsCache.Check(contextargs.NewWithMetaInput(ctx, scannedValue)) {
if e.executerOpts.HostErrorsCache != nil && e.executerOpts.HostErrorsCache.Check(e.executerOpts.ProtocolType.String(), contextargs.NewWithMetaInput(ctx, scannedValue)) {
return true
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/core/workflow_execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func (e *Engine) runWorkflowStep(template *workflows.WorkflowTemplate, ctx *scan
}
if err != nil {
if w.Options.HostErrorsCache != nil {
w.Options.HostErrorsCache.MarkFailed(ctx.Input, err)
w.Options.HostErrorsCache.MarkFailed(w.Options.ProtocolType.String(), ctx.Input, err)
}
if len(template.Executers) == 1 {
mainErr = err
Expand Down
20 changes: 11 additions & 9 deletions pkg/protocols/common/hosterrorscache/hosterrorscache.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ import (
// CacheInterface defines the signature of the hosterrorscache so that
// users of Nuclei as embedded lib may implement their own cache
type CacheInterface interface {
SetVerbose(verbose bool) // log verbosely
Close() // close the cache
Check(ctx *contextargs.Context) bool // return true if the host should be skipped
MarkFailed(ctx *contextargs.Context, err error) // record a failure (and cause) for the host
SetVerbose(verbose bool) // log verbosely
Close() // close the cache
Check(protoType string, ctx *contextargs.Context) bool // return true if the host should be skipped
MarkFailed(protoType string, ctx *contextargs.Context, err error) // record a failure (and cause) for the host
}

var (
Expand Down Expand Up @@ -115,7 +115,7 @@ func (c *Cache) NormalizeCacheValue(value string) string {
// - URL: https?:// type
// - Host:port type
// - host type
func (c *Cache) Check(ctx *contextargs.Context) bool {
func (c *Cache) Check(protoType string, ctx *contextargs.Context) bool {
finalValue := c.GetKeyFromContext(ctx, nil)

existingCacheItem, err := c.failedTargets.GetIFPresent(finalValue)
Expand All @@ -138,8 +138,8 @@ func (c *Cache) Check(ctx *contextargs.Context) bool {
}

// MarkFailed marks a host as failed previously
func (c *Cache) MarkFailed(ctx *contextargs.Context, err error) {
if !c.checkError(err) {
func (c *Cache) MarkFailed(protoType string, ctx *contextargs.Context, err error) {
if !c.checkError(protoType, err) {
return
}
finalValue := c.GetKeyFromContext(ctx, err)
Expand Down Expand Up @@ -186,11 +186,13 @@ var reCheckError = regexp.MustCompile(`(no address found for host|could not reso
// added to the host skipping table.
// it first parses error and extracts the cause and checks for blacklisted
// or common errors that should be skipped
func (c *Cache) checkError(err error) bool {
func (c *Cache) checkError(protoType string, err error) bool {
if err == nil {
return false
}

if protoType != "http" {
return false
}
kind := errkit.GetErrorKind(err, nucleierr.ErrTemplateLogic)
switch kind {
case nucleierr.ErrTemplateLogic:
Expand Down
22 changes: 13 additions & 9 deletions pkg/protocols/common/hosterrorscache/hosterrorscache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,16 @@ import (
"github.com/stretchr/testify/require"
)

const (
protoType = "http"
)

func TestCacheCheck(t *testing.T) {
cache := New(3, DefaultMaxHostsCount, nil)

for i := 0; i < 100; i++ {
cache.MarkFailed(newCtxArgs("test"), fmt.Errorf("could not resolve host"))
got := cache.Check(newCtxArgs("test"))
cache.MarkFailed(protoType, newCtxArgs("test"), fmt.Errorf("could not resolve host"))
got := cache.Check(protoType, newCtxArgs("test"))
if i < 2 {
// till 3 the host is not flagged to skip
require.False(t, got)
Expand All @@ -26,16 +30,16 @@ func TestCacheCheck(t *testing.T) {
}
}

value := cache.Check(newCtxArgs("test"))
value := cache.Check(protoType, newCtxArgs("test"))
require.Equal(t, true, value, "could not get checked value")
}

func TestTrackErrors(t *testing.T) {
cache := New(3, DefaultMaxHostsCount, []string{"custom error"})

for i := 0; i < 100; i++ {
cache.MarkFailed(newCtxArgs("custom"), fmt.Errorf("got: nested: custom error"))
got := cache.Check(newCtxArgs("custom"))
cache.MarkFailed(protoType, newCtxArgs("custom"), fmt.Errorf("got: nested: custom error"))
got := cache.Check(protoType, newCtxArgs("custom"))
if i < 2 {
// till 3 the host is not flagged to skip
require.False(t, got)
Expand All @@ -44,7 +48,7 @@ func TestTrackErrors(t *testing.T) {
require.True(t, got)
}
}
value := cache.Check(newCtxArgs("custom"))
value := cache.Check(protoType, newCtxArgs("custom"))
require.Equal(t, true, value, "could not get checked value")
}

Expand Down Expand Up @@ -86,7 +90,7 @@ func TestCacheMarkFailed(t *testing.T) {

for _, test := range tests {
normalizedCacheValue := cache.GetKeyFromContext(newCtxArgs(test.host), nil)
cache.MarkFailed(newCtxArgs(test.host), fmt.Errorf("no address found for host"))
cache.MarkFailed(protoType, newCtxArgs(test.host), fmt.Errorf("no address found for host"))
failedTarget, err := cache.failedTargets.Get(normalizedCacheValue)
require.Nil(t, err)
require.NotNil(t, failedTarget)
Expand Down Expand Up @@ -122,14 +126,14 @@ func TestCacheMarkFailedConcurrent(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
cache.MarkFailed(newCtxArgs(currentTest.host), fmt.Errorf("could not resolve host"))
cache.MarkFailed(protoType, newCtxArgs(currentTest.host), fmt.Errorf("could not resolve host"))
}()
}
}
wg.Wait()

for _, test := range tests {
require.True(t, cache.Check(newCtxArgs(test.host)))
require.True(t, cache.Check(protoType, newCtxArgs(test.host)))

normalizedCacheValue := cache.NormalizeCacheValue(test.host)
failedTarget, err := cache.failedTargets.Get(normalizedCacheValue)
Expand Down
4 changes: 2 additions & 2 deletions pkg/protocols/http/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -1177,14 +1177,14 @@ func (request *Request) markUnresponsiveAddress(input *contextargs.Context, err
return
}
if request.options.HostErrorsCache != nil {
request.options.HostErrorsCache.MarkFailed(input, err)
request.options.HostErrorsCache.MarkFailed(request.options.ProtocolType.String(), input, err)
}
}

// isUnresponsiveAddress checks if the error is a unreponsive based on its execution history
func (request *Request) isUnresponsiveAddress(input *contextargs.Context) bool {
if request.options.HostErrorsCache != nil {
return request.options.HostErrorsCache.Check(input)
return request.options.HostErrorsCache.Check(request.options.ProtocolType.String(), input)
}
return false
}
4 changes: 2 additions & 2 deletions pkg/protocols/http/request_fuzz.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ func (request *Request) executeAllFuzzingRules(input *contextargs.Context, value
func (request *Request) executeGeneratedFuzzingRequest(gr fuzz.GeneratedRequest, input *contextargs.Context, callback protocols.OutputEventCallback) bool {
hasInteractMatchers := interactsh.HasMatchers(request.CompiledOperators)
hasInteractMarkers := len(gr.InteractURLs) > 0
if request.options.HostErrorsCache != nil && request.options.HostErrorsCache.Check(input) {
if request.options.HostErrorsCache != nil && request.options.HostErrorsCache.Check(request.options.ProtocolType.String(), input) {
return false
}
request.options.RateLimitTake()
Expand Down Expand Up @@ -215,7 +215,7 @@ func (request *Request) executeGeneratedFuzzingRequest(gr fuzz.GeneratedRequest,
}
if requestErr != nil {
if request.options.HostErrorsCache != nil {
request.options.HostErrorsCache.MarkFailed(input, requestErr)
request.options.HostErrorsCache.MarkFailed(request.options.ProtocolType.String(), input, requestErr)
}
gologger.Verbose().Msgf("[%s] Error occurred in request: %s\n", request.options.TemplateID, requestErr)
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/protocols/network/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -504,14 +504,14 @@ func (request *Request) markUnresponsiveAddress(input *contextargs.Context, err
return
}
if request.options.HostErrorsCache != nil {
request.options.HostErrorsCache.MarkFailed(input, err)
request.options.HostErrorsCache.MarkFailed(request.options.ProtocolType.String(), input, err)
}
}

// isUnresponsiveAddress checks if the error is a unreponsive based on its execution history
func (request *Request) isUnresponsiveAddress(input *contextargs.Context) bool {
if request.options.HostErrorsCache != nil {
return request.options.HostErrorsCache.Check(input)
return request.options.HostErrorsCache.Check(request.options.ProtocolType.String(), input)
}
return false
}
4 changes: 2 additions & 2 deletions pkg/templates/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ func (e *ClusterExecuter) Execute(ctx *scan.ScanContext) (bool, error) {
}
})
if err != nil && e.options.HostErrorsCache != nil {
e.options.HostErrorsCache.MarkFailed(ctx.Input, err)
e.options.HostErrorsCache.MarkFailed(e.options.ProtocolType.String(), ctx.Input, err)
}
return results, err
}
Expand Down Expand Up @@ -310,7 +310,7 @@ func (e *ClusterExecuter) ExecuteWithResults(ctx *scan.ScanContext) ([]*output.R
}

if err != nil && e.options.HostErrorsCache != nil {
e.options.HostErrorsCache.MarkFailed(ctx.Input, err)
e.options.HostErrorsCache.MarkFailed(e.options.ProtocolType.String(), ctx.Input, err)
}
return scanCtx.GenerateResult(), err
}
2 changes: 2 additions & 0 deletions pkg/tmplexec/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,9 @@ func (e *TemplateExecuter) Execute(ctx *scan.ScanContext) (bool, error) {
ctx.LogError(errx)

if lastMatcherEvent != nil {
lastMatcherEvent.Lock()
lastMatcherEvent.InternalEvent["error"] = getErrorCause(ctx.GenerateErrorMessage())
lastMatcherEvent.Unlock()
writeFailureCallback(lastMatcherEvent, e.options.Options.MatcherStatus)
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/tmplexec/generic/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func (g *Generic) ExecuteWithResults(ctx *scan.ScanContext) error {
if err != nil {
ctx.LogError(err)
if g.options.HostErrorsCache != nil {
g.options.HostErrorsCache.MarkFailed(ctx.Input, err)
g.options.HostErrorsCache.MarkFailed(g.options.ProtocolType.String(), ctx.Input, err)
}
gologger.Warning().Msgf("[%s] Could not execute request for %s: %s\n", g.options.TemplateID, ctx.Input.MetaInput.PrettyPrint(), err)
}
Expand Down

0 comments on commit 1f945d6

Please sign in to comment.