Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

consider protocolType in max host error #5668

Merged
merged 2 commits into from
Sep 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/core/executors.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func (e *Engine) executeTemplateWithTargets(ctx context.Context, template *templ
currentInfo.Unlock()

// Skip if the host has had errors
if e.executerOpts.HostErrorsCache != nil && e.executerOpts.HostErrorsCache.Check(contextargs.NewWithMetaInput(ctx, scannedValue)) {
if e.executerOpts.HostErrorsCache != nil && e.executerOpts.HostErrorsCache.Check(e.executerOpts.ProtocolType.String(), contextargs.NewWithMetaInput(ctx, scannedValue)) {
return true
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/core/workflow_execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func (e *Engine) runWorkflowStep(template *workflows.WorkflowTemplate, ctx *scan
}
if err != nil {
if w.Options.HostErrorsCache != nil {
w.Options.HostErrorsCache.MarkFailed(ctx.Input, err)
w.Options.HostErrorsCache.MarkFailed(w.Options.ProtocolType.String(), ctx.Input, err)
}
if len(template.Executers) == 1 {
mainErr = err
Expand Down
20 changes: 11 additions & 9 deletions pkg/protocols/common/hosterrorscache/hosterrorscache.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ import (
// CacheInterface defines the signature of the hosterrorscache so that
// users of Nuclei as embedded lib may implement their own cache
type CacheInterface interface {
SetVerbose(verbose bool) // log verbosely
Close() // close the cache
Check(ctx *contextargs.Context) bool // return true if the host should be skipped
MarkFailed(ctx *contextargs.Context, err error) // record a failure (and cause) for the host
SetVerbose(verbose bool) // log verbosely
Close() // close the cache
Check(protoType string, ctx *contextargs.Context) bool // return true if the host should be skipped
MarkFailed(protoType string, ctx *contextargs.Context, err error) // record a failure (and cause) for the host
}

var (
Expand Down Expand Up @@ -115,7 +115,7 @@ func (c *Cache) NormalizeCacheValue(value string) string {
// - URL: https?:// type
// - Host:port type
// - host type
func (c *Cache) Check(ctx *contextargs.Context) bool {
func (c *Cache) Check(protoType string, ctx *contextargs.Context) bool {
finalValue := c.GetKeyFromContext(ctx, nil)

existingCacheItem, err := c.failedTargets.GetIFPresent(finalValue)
Expand All @@ -138,8 +138,8 @@ func (c *Cache) Check(ctx *contextargs.Context) bool {
}

// MarkFailed marks a host as failed previously
func (c *Cache) MarkFailed(ctx *contextargs.Context, err error) {
if !c.checkError(err) {
func (c *Cache) MarkFailed(protoType string, ctx *contextargs.Context, err error) {
if !c.checkError(protoType, err) {
return
}
finalValue := c.GetKeyFromContext(ctx, err)
Expand Down Expand Up @@ -186,11 +186,13 @@ var reCheckError = regexp.MustCompile(`(no address found for host|could not reso
// added to the host skipping table.
// it first parses error and extracts the cause and checks for blacklisted
// or common errors that should be skipped
func (c *Cache) checkError(err error) bool {
func (c *Cache) checkError(protoType string, err error) bool {
if err == nil {
return false
}

if protoType != "http" {
return false
}
kind := errkit.GetErrorKind(err, nucleierr.ErrTemplateLogic)
switch kind {
case nucleierr.ErrTemplateLogic:
Expand Down
22 changes: 13 additions & 9 deletions pkg/protocols/common/hosterrorscache/hosterrorscache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,16 @@ import (
"github.com/stretchr/testify/require"
)

const (
protoType = "http"
)

func TestCacheCheck(t *testing.T) {
cache := New(3, DefaultMaxHostsCount, nil)

for i := 0; i < 100; i++ {
cache.MarkFailed(newCtxArgs("test"), fmt.Errorf("could not resolve host"))
got := cache.Check(newCtxArgs("test"))
cache.MarkFailed(protoType, newCtxArgs("test"), fmt.Errorf("could not resolve host"))
got := cache.Check(protoType, newCtxArgs("test"))
if i < 2 {
// till 3 the host is not flagged to skip
require.False(t, got)
Expand All @@ -26,16 +30,16 @@ func TestCacheCheck(t *testing.T) {
}
}

value := cache.Check(newCtxArgs("test"))
value := cache.Check(protoType, newCtxArgs("test"))
require.Equal(t, true, value, "could not get checked value")
}

func TestTrackErrors(t *testing.T) {
cache := New(3, DefaultMaxHostsCount, []string{"custom error"})

for i := 0; i < 100; i++ {
cache.MarkFailed(newCtxArgs("custom"), fmt.Errorf("got: nested: custom error"))
got := cache.Check(newCtxArgs("custom"))
cache.MarkFailed(protoType, newCtxArgs("custom"), fmt.Errorf("got: nested: custom error"))
got := cache.Check(protoType, newCtxArgs("custom"))
if i < 2 {
// till 3 the host is not flagged to skip
require.False(t, got)
Expand All @@ -44,7 +48,7 @@ func TestTrackErrors(t *testing.T) {
require.True(t, got)
}
}
value := cache.Check(newCtxArgs("custom"))
value := cache.Check(protoType, newCtxArgs("custom"))
require.Equal(t, true, value, "could not get checked value")
}

Expand Down Expand Up @@ -86,7 +90,7 @@ func TestCacheMarkFailed(t *testing.T) {

for _, test := range tests {
normalizedCacheValue := cache.GetKeyFromContext(newCtxArgs(test.host), nil)
cache.MarkFailed(newCtxArgs(test.host), fmt.Errorf("no address found for host"))
cache.MarkFailed(protoType, newCtxArgs(test.host), fmt.Errorf("no address found for host"))
failedTarget, err := cache.failedTargets.Get(normalizedCacheValue)
require.Nil(t, err)
require.NotNil(t, failedTarget)
Expand Down Expand Up @@ -122,14 +126,14 @@ func TestCacheMarkFailedConcurrent(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
cache.MarkFailed(newCtxArgs(currentTest.host), fmt.Errorf("could not resolve host"))
cache.MarkFailed(protoType, newCtxArgs(currentTest.host), fmt.Errorf("could not resolve host"))
}()
}
}
wg.Wait()

for _, test := range tests {
require.True(t, cache.Check(newCtxArgs(test.host)))
require.True(t, cache.Check(protoType, newCtxArgs(test.host)))

normalizedCacheValue := cache.NormalizeCacheValue(test.host)
failedTarget, err := cache.failedTargets.Get(normalizedCacheValue)
Expand Down
4 changes: 2 additions & 2 deletions pkg/protocols/http/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -1177,14 +1177,14 @@ func (request *Request) markUnresponsiveAddress(input *contextargs.Context, err
return
}
if request.options.HostErrorsCache != nil {
request.options.HostErrorsCache.MarkFailed(input, err)
request.options.HostErrorsCache.MarkFailed(request.options.ProtocolType.String(), input, err)
}
}

// isUnresponsiveAddress checks if the error is a unreponsive based on its execution history
func (request *Request) isUnresponsiveAddress(input *contextargs.Context) bool {
if request.options.HostErrorsCache != nil {
return request.options.HostErrorsCache.Check(input)
return request.options.HostErrorsCache.Check(request.options.ProtocolType.String(), input)
}
return false
}
4 changes: 2 additions & 2 deletions pkg/protocols/http/request_fuzz.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ func (request *Request) executeAllFuzzingRules(input *contextargs.Context, value
func (request *Request) executeGeneratedFuzzingRequest(gr fuzz.GeneratedRequest, input *contextargs.Context, callback protocols.OutputEventCallback) bool {
hasInteractMatchers := interactsh.HasMatchers(request.CompiledOperators)
hasInteractMarkers := len(gr.InteractURLs) > 0
if request.options.HostErrorsCache != nil && request.options.HostErrorsCache.Check(input) {
if request.options.HostErrorsCache != nil && request.options.HostErrorsCache.Check(request.options.ProtocolType.String(), input) {
return false
}
request.options.RateLimitTake()
Expand Down Expand Up @@ -215,7 +215,7 @@ func (request *Request) executeGeneratedFuzzingRequest(gr fuzz.GeneratedRequest,
}
if requestErr != nil {
if request.options.HostErrorsCache != nil {
request.options.HostErrorsCache.MarkFailed(input, requestErr)
request.options.HostErrorsCache.MarkFailed(request.options.ProtocolType.String(), input, requestErr)
}
gologger.Verbose().Msgf("[%s] Error occurred in request: %s\n", request.options.TemplateID, requestErr)
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/protocols/network/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -504,14 +504,14 @@ func (request *Request) markUnresponsiveAddress(input *contextargs.Context, err
return
}
if request.options.HostErrorsCache != nil {
request.options.HostErrorsCache.MarkFailed(input, err)
request.options.HostErrorsCache.MarkFailed(request.options.ProtocolType.String(), input, err)
}
}

// isUnresponsiveAddress checks if the error is a unreponsive based on its execution history
func (request *Request) isUnresponsiveAddress(input *contextargs.Context) bool {
if request.options.HostErrorsCache != nil {
return request.options.HostErrorsCache.Check(input)
return request.options.HostErrorsCache.Check(request.options.ProtocolType.String(), input)
}
return false
}
4 changes: 2 additions & 2 deletions pkg/templates/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ func (e *ClusterExecuter) Execute(ctx *scan.ScanContext) (bool, error) {
}
})
if err != nil && e.options.HostErrorsCache != nil {
e.options.HostErrorsCache.MarkFailed(ctx.Input, err)
e.options.HostErrorsCache.MarkFailed(e.options.ProtocolType.String(), ctx.Input, err)
}
return results, err
}
Expand Down Expand Up @@ -310,7 +310,7 @@ func (e *ClusterExecuter) ExecuteWithResults(ctx *scan.ScanContext) ([]*output.R
}

if err != nil && e.options.HostErrorsCache != nil {
e.options.HostErrorsCache.MarkFailed(ctx.Input, err)
e.options.HostErrorsCache.MarkFailed(e.options.ProtocolType.String(), ctx.Input, err)
}
return scanCtx.GenerateResult(), err
}
2 changes: 2 additions & 0 deletions pkg/tmplexec/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,9 @@ func (e *TemplateExecuter) Execute(ctx *scan.ScanContext) (bool, error) {
ctx.LogError(errx)

if lastMatcherEvent != nil {
lastMatcherEvent.Lock()
lastMatcherEvent.InternalEvent["error"] = getErrorCause(ctx.GenerateErrorMessage())
lastMatcherEvent.Unlock()
writeFailureCallback(lastMatcherEvent, e.options.Options.MatcherStatus)
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/tmplexec/generic/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func (g *Generic) ExecuteWithResults(ctx *scan.ScanContext) error {
if err != nil {
ctx.LogError(err)
if g.options.HostErrorsCache != nil {
g.options.HostErrorsCache.MarkFailed(ctx.Input, err)
g.options.HostErrorsCache.MarkFailed(g.options.ProtocolType.String(), ctx.Input, err)
}
gologger.Warning().Msgf("[%s] Could not execute request for %s: %s\n", g.options.TemplateID, ctx.Input.MetaInput.PrettyPrint(), err)
}
Expand Down
Loading