From 1f945d6d505aed8a6bcced2bd1d76c74c38728a1 Mon Sep 17 00:00:00 2001 From: Tarun Koyalwar <45962551+tarunKoyalwar@users.noreply.github.com> Date: Sat, 28 Sep 2024 17:20:35 +0400 Subject: [PATCH] consider protocolType in max host error (#5668) * consider protocolType in max host error * add mutex when updating internal-event --- pkg/core/executors.go | 2 +- pkg/core/workflow_execute.go | 2 +- .../common/hosterrorscache/hosterrorscache.go | 20 +++++++++-------- .../hosterrorscache/hosterrorscache_test.go | 22 +++++++++++-------- pkg/protocols/http/request.go | 4 ++-- pkg/protocols/http/request_fuzz.go | 4 ++-- pkg/protocols/network/request.go | 4 ++-- pkg/templates/cluster.go | 4 ++-- pkg/tmplexec/exec.go | 2 ++ pkg/tmplexec/generic/exec.go | 2 +- 10 files changed, 37 insertions(+), 29 deletions(-) diff --git a/pkg/core/executors.go b/pkg/core/executors.go index 8447efed13..39ec97aa40 100644 --- a/pkg/core/executors.go +++ b/pkg/core/executors.go @@ -107,7 +107,7 @@ func (e *Engine) executeTemplateWithTargets(ctx context.Context, template *templ currentInfo.Unlock() // Skip if the host has had errors - if e.executerOpts.HostErrorsCache != nil && e.executerOpts.HostErrorsCache.Check(contextargs.NewWithMetaInput(ctx, scannedValue)) { + if e.executerOpts.HostErrorsCache != nil && e.executerOpts.HostErrorsCache.Check(e.executerOpts.ProtocolType.String(), contextargs.NewWithMetaInput(ctx, scannedValue)) { return true } diff --git a/pkg/core/workflow_execute.go b/pkg/core/workflow_execute.go index 0b5d7e8722..22b1b813f7 100644 --- a/pkg/core/workflow_execute.go +++ b/pkg/core/workflow_execute.go @@ -98,7 +98,7 @@ func (e *Engine) runWorkflowStep(template *workflows.WorkflowTemplate, ctx *scan } if err != nil { if w.Options.HostErrorsCache != nil { - w.Options.HostErrorsCache.MarkFailed(ctx.Input, err) + w.Options.HostErrorsCache.MarkFailed(w.Options.ProtocolType.String(), ctx.Input, err) } if len(template.Executers) == 1 { mainErr = err diff --git a/pkg/protocols/common/hosterrorscache/hosterrorscache.go b/pkg/protocols/common/hosterrorscache/hosterrorscache.go index bcfa27dbe4..bca4803e8a 100644 --- a/pkg/protocols/common/hosterrorscache/hosterrorscache.go +++ b/pkg/protocols/common/hosterrorscache/hosterrorscache.go @@ -20,10 +20,10 @@ import ( // CacheInterface defines the signature of the hosterrorscache so that // users of Nuclei as embedded lib may implement their own cache type CacheInterface interface { - SetVerbose(verbose bool) // log verbosely - Close() // close the cache - Check(ctx *contextargs.Context) bool // return true if the host should be skipped - MarkFailed(ctx *contextargs.Context, err error) // record a failure (and cause) for the host + SetVerbose(verbose bool) // log verbosely + Close() // close the cache + Check(protoType string, ctx *contextargs.Context) bool // return true if the host should be skipped + MarkFailed(protoType string, ctx *contextargs.Context, err error) // record a failure (and cause) for the host } var ( @@ -115,7 +115,7 @@ func (c *Cache) NormalizeCacheValue(value string) string { // - URL: https?:// type // - Host:port type // - host type -func (c *Cache) Check(ctx *contextargs.Context) bool { +func (c *Cache) Check(protoType string, ctx *contextargs.Context) bool { finalValue := c.GetKeyFromContext(ctx, nil) existingCacheItem, err := c.failedTargets.GetIFPresent(finalValue) @@ -138,8 +138,8 @@ func (c *Cache) Check(ctx *contextargs.Context) bool { } // MarkFailed marks a host as failed previously -func (c *Cache) MarkFailed(ctx *contextargs.Context, err error) { - if !c.checkError(err) { +func (c *Cache) MarkFailed(protoType string, ctx *contextargs.Context, err error) { + if !c.checkError(protoType, err) { return } finalValue := c.GetKeyFromContext(ctx, err) @@ -186,11 +186,13 @@ var reCheckError = regexp.MustCompile(`(no address found for host|could not reso // added to the host skipping table. // it first parses error and extracts the cause and checks for blacklisted // or common errors that should be skipped -func (c *Cache) checkError(err error) bool { +func (c *Cache) checkError(protoType string, err error) bool { if err == nil { return false } - + if protoType != "http" { + return false + } kind := errkit.GetErrorKind(err, nucleierr.ErrTemplateLogic) switch kind { case nucleierr.ErrTemplateLogic: diff --git a/pkg/protocols/common/hosterrorscache/hosterrorscache_test.go b/pkg/protocols/common/hosterrorscache/hosterrorscache_test.go index 3c93177674..9977b968d9 100644 --- a/pkg/protocols/common/hosterrorscache/hosterrorscache_test.go +++ b/pkg/protocols/common/hosterrorscache/hosterrorscache_test.go @@ -11,12 +11,16 @@ import ( "github.com/stretchr/testify/require" ) +const ( + protoType = "http" +) + func TestCacheCheck(t *testing.T) { cache := New(3, DefaultMaxHostsCount, nil) for i := 0; i < 100; i++ { - cache.MarkFailed(newCtxArgs("test"), fmt.Errorf("could not resolve host")) - got := cache.Check(newCtxArgs("test")) + cache.MarkFailed(protoType, newCtxArgs("test"), fmt.Errorf("could not resolve host")) + got := cache.Check(protoType, newCtxArgs("test")) if i < 2 { // till 3 the host is not flagged to skip require.False(t, got) @@ -26,7 +30,7 @@ func TestCacheCheck(t *testing.T) { } } - value := cache.Check(newCtxArgs("test")) + value := cache.Check(protoType, newCtxArgs("test")) require.Equal(t, true, value, "could not get checked value") } @@ -34,8 +38,8 @@ func TestTrackErrors(t *testing.T) { cache := New(3, DefaultMaxHostsCount, []string{"custom error"}) for i := 0; i < 100; i++ { - cache.MarkFailed(newCtxArgs("custom"), fmt.Errorf("got: nested: custom error")) - got := cache.Check(newCtxArgs("custom")) + cache.MarkFailed(protoType, newCtxArgs("custom"), fmt.Errorf("got: nested: custom error")) + got := cache.Check(protoType, newCtxArgs("custom")) if i < 2 { // till 3 the host is not flagged to skip require.False(t, got) @@ -44,7 +48,7 @@ func TestTrackErrors(t *testing.T) { require.True(t, got) } } - value := cache.Check(newCtxArgs("custom")) + value := cache.Check(protoType, newCtxArgs("custom")) require.Equal(t, true, value, "could not get checked value") } @@ -86,7 +90,7 @@ func TestCacheMarkFailed(t *testing.T) { for _, test := range tests { normalizedCacheValue := cache.GetKeyFromContext(newCtxArgs(test.host), nil) - cache.MarkFailed(newCtxArgs(test.host), fmt.Errorf("no address found for host")) + cache.MarkFailed(protoType, newCtxArgs(test.host), fmt.Errorf("no address found for host")) failedTarget, err := cache.failedTargets.Get(normalizedCacheValue) require.Nil(t, err) require.NotNil(t, failedTarget) @@ -122,14 +126,14 @@ func TestCacheMarkFailedConcurrent(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - cache.MarkFailed(newCtxArgs(currentTest.host), fmt.Errorf("could not resolve host")) + cache.MarkFailed(protoType, newCtxArgs(currentTest.host), fmt.Errorf("could not resolve host")) }() } } wg.Wait() for _, test := range tests { - require.True(t, cache.Check(newCtxArgs(test.host))) + require.True(t, cache.Check(protoType, newCtxArgs(test.host))) normalizedCacheValue := cache.NormalizeCacheValue(test.host) failedTarget, err := cache.failedTargets.Get(normalizedCacheValue) diff --git a/pkg/protocols/http/request.go b/pkg/protocols/http/request.go index 180ee512ce..994e065582 100644 --- a/pkg/protocols/http/request.go +++ b/pkg/protocols/http/request.go @@ -1177,14 +1177,14 @@ func (request *Request) markUnresponsiveAddress(input *contextargs.Context, err return } if request.options.HostErrorsCache != nil { - request.options.HostErrorsCache.MarkFailed(input, err) + request.options.HostErrorsCache.MarkFailed(request.options.ProtocolType.String(), input, err) } } // isUnresponsiveAddress checks if the error is a unreponsive based on its execution history func (request *Request) isUnresponsiveAddress(input *contextargs.Context) bool { if request.options.HostErrorsCache != nil { - return request.options.HostErrorsCache.Check(input) + return request.options.HostErrorsCache.Check(request.options.ProtocolType.String(), input) } return false } diff --git a/pkg/protocols/http/request_fuzz.go b/pkg/protocols/http/request_fuzz.go index fdf862eb67..a7c6e80c07 100644 --- a/pkg/protocols/http/request_fuzz.go +++ b/pkg/protocols/http/request_fuzz.go @@ -161,7 +161,7 @@ func (request *Request) executeAllFuzzingRules(input *contextargs.Context, value func (request *Request) executeGeneratedFuzzingRequest(gr fuzz.GeneratedRequest, input *contextargs.Context, callback protocols.OutputEventCallback) bool { hasInteractMatchers := interactsh.HasMatchers(request.CompiledOperators) hasInteractMarkers := len(gr.InteractURLs) > 0 - if request.options.HostErrorsCache != nil && request.options.HostErrorsCache.Check(input) { + if request.options.HostErrorsCache != nil && request.options.HostErrorsCache.Check(request.options.ProtocolType.String(), input) { return false } request.options.RateLimitTake() @@ -215,7 +215,7 @@ func (request *Request) executeGeneratedFuzzingRequest(gr fuzz.GeneratedRequest, } if requestErr != nil { if request.options.HostErrorsCache != nil { - request.options.HostErrorsCache.MarkFailed(input, requestErr) + request.options.HostErrorsCache.MarkFailed(request.options.ProtocolType.String(), input, requestErr) } gologger.Verbose().Msgf("[%s] Error occurred in request: %s\n", request.options.TemplateID, requestErr) } diff --git a/pkg/protocols/network/request.go b/pkg/protocols/network/request.go index 5fa8609d51..32d4ae3494 100644 --- a/pkg/protocols/network/request.go +++ b/pkg/protocols/network/request.go @@ -504,14 +504,14 @@ func (request *Request) markUnresponsiveAddress(input *contextargs.Context, err return } if request.options.HostErrorsCache != nil { - request.options.HostErrorsCache.MarkFailed(input, err) + request.options.HostErrorsCache.MarkFailed(request.options.ProtocolType.String(), input, err) } } // isUnresponsiveAddress checks if the error is a unreponsive based on its execution history func (request *Request) isUnresponsiveAddress(input *contextargs.Context) bool { if request.options.HostErrorsCache != nil { - return request.options.HostErrorsCache.Check(input) + return request.options.HostErrorsCache.Check(request.options.ProtocolType.String(), input) } return false } diff --git a/pkg/templates/cluster.go b/pkg/templates/cluster.go index 8f1af96b74..63b065d346 100644 --- a/pkg/templates/cluster.go +++ b/pkg/templates/cluster.go @@ -274,7 +274,7 @@ func (e *ClusterExecuter) Execute(ctx *scan.ScanContext) (bool, error) { } }) if err != nil && e.options.HostErrorsCache != nil { - e.options.HostErrorsCache.MarkFailed(ctx.Input, err) + e.options.HostErrorsCache.MarkFailed(e.options.ProtocolType.String(), ctx.Input, err) } return results, err } @@ -310,7 +310,7 @@ func (e *ClusterExecuter) ExecuteWithResults(ctx *scan.ScanContext) ([]*output.R } if err != nil && e.options.HostErrorsCache != nil { - e.options.HostErrorsCache.MarkFailed(ctx.Input, err) + e.options.HostErrorsCache.MarkFailed(e.options.ProtocolType.String(), ctx.Input, err) } return scanCtx.GenerateResult(), err } diff --git a/pkg/tmplexec/exec.go b/pkg/tmplexec/exec.go index 149deaa4d5..279d03d849 100644 --- a/pkg/tmplexec/exec.go +++ b/pkg/tmplexec/exec.go @@ -206,7 +206,9 @@ func (e *TemplateExecuter) Execute(ctx *scan.ScanContext) (bool, error) { ctx.LogError(errx) if lastMatcherEvent != nil { + lastMatcherEvent.Lock() lastMatcherEvent.InternalEvent["error"] = getErrorCause(ctx.GenerateErrorMessage()) + lastMatcherEvent.Unlock() writeFailureCallback(lastMatcherEvent, e.options.Options.MatcherStatus) } diff --git a/pkg/tmplexec/generic/exec.go b/pkg/tmplexec/generic/exec.go index 3b337f6b20..c8303f70d9 100644 --- a/pkg/tmplexec/generic/exec.go +++ b/pkg/tmplexec/generic/exec.go @@ -85,7 +85,7 @@ func (g *Generic) ExecuteWithResults(ctx *scan.ScanContext) error { if err != nil { ctx.LogError(err) if g.options.HostErrorsCache != nil { - g.options.HostErrorsCache.MarkFailed(ctx.Input, err) + g.options.HostErrorsCache.MarkFailed(g.options.ProtocolType.String(), ctx.Input, err) } gologger.Warning().Msgf("[%s] Could not execute request for %s: %s\n", g.options.TemplateID, ctx.Input.MetaInput.PrettyPrint(), err) }