From 9983d7415c3da36d8fa588fde258516d280c7f57 Mon Sep 17 00:00:00 2001 From: Dwi Siswanto Date: Mon, 23 Sep 2024 17:27:30 +0700 Subject: [PATCH] refactor(runner): adjust `max-host-error` if gt `concurrency` (#5633) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * refactor(common): use `ParseRequestURI` instead when `NormalizeCacheValue` also it exports the method Signed-off-by: Dwi Siswanto * refactor(runner): adjust `max-host-error` if gt `concurrency` Signed-off-by: Dwi Siswanto * fix lint * chore(runner): expose adjusted `max-host-error` value Signed-off-by: Dwi Siswanto --------- Signed-off-by: Dwi Siswanto Co-authored-by: Doğan Can Bakır --- internal/runner/runner.go | 11 ++++- .../common/hosterrorscache/hosterrorscache.go | 42 ++++++++++++------- .../hosterrorscache/hosterrorscache_test.go | 4 +- 3 files changed, 38 insertions(+), 19 deletions(-) diff --git a/internal/runner/runner.go b/internal/runner/runner.go index b36d8ed584..0b6da592d3 100644 --- a/internal/runner/runner.go +++ b/internal/runner/runner.go @@ -501,8 +501,17 @@ func (r *Runner) RunEnumeration() error { } if r.options.ShouldUseHostError() { - cache := hosterrorscache.New(r.options.MaxHostError, hosterrorscache.DefaultMaxHostsCount, r.options.TrackError) + maxHostError := r.options.MaxHostError + if r.options.TemplateThreads > maxHostError { + gologger.Print().Msgf("[%v] The concurrency value is higher than max-host-error", r.colorizer.BrightYellow("WRN")) + gologger.Info().Msgf("Adjusting max-host-error to the concurrency value: %d", r.options.TemplateThreads) + + maxHostError = r.options.TemplateThreads + } + + cache := hosterrorscache.New(maxHostError, hosterrorscache.DefaultMaxHostsCount, r.options.TrackError) cache.SetVerbose(r.options.Verbose) + r.hostErrors = cache executorOpts.HostErrorsCache = cache } diff --git a/pkg/protocols/common/hosterrorscache/hosterrorscache.go b/pkg/protocols/common/hosterrorscache/hosterrorscache.go index 1630e97c7d..bcfa27dbe4 100644 --- a/pkg/protocols/common/hosterrorscache/hosterrorscache.go +++ b/pkg/protocols/common/hosterrorscache/hosterrorscache.go @@ -75,24 +75,34 @@ func (c *Cache) Close() { c.failedTargets.Purge() } -func (c *Cache) normalizeCacheValue(value string) string { - finalValue := value - if strings.HasPrefix(value, "http") { - if parsed, err := url.Parse(value); err == nil { - hostname := parsed.Host - finalPort := parsed.Port() - if finalPort == "" { - if parsed.Scheme == "https" { - finalPort = "443" - } else { - finalPort = "80" - } - hostname = net.JoinHostPort(parsed.Host, finalPort) +// NormalizeCacheValue processes the input value and returns a normalized cache +// value. +func (c *Cache) NormalizeCacheValue(value string) string { + var normalizedValue string = value + + u, err := url.ParseRequestURI(value) + if err != nil || u.Host == "" { + u, err2 := url.ParseRequestURI("https://" + value) + if err2 != nil { + return normalizedValue + } + + normalizedValue = u.Host + } else { + port := u.Port() + if port == "" { + switch u.Scheme { + case "https": + normalizedValue = net.JoinHostPort(u.Host, "443") + case "http": + normalizedValue = net.JoinHostPort(u.Host, "80") } - finalValue = hostname + } else { + normalizedValue = u.Host } } - return finalValue + + return normalizedValue } // ErrUnresponsiveHost is returned when a host is unresponsive @@ -166,7 +176,7 @@ func (c *Cache) GetKeyFromContext(ctx *contextargs.Context, err error) string { address = tmp.String() } } - finalValue := c.normalizeCacheValue(address) + finalValue := c.NormalizeCacheValue(address) return finalValue } diff --git a/pkg/protocols/common/hosterrorscache/hosterrorscache_test.go b/pkg/protocols/common/hosterrorscache/hosterrorscache_test.go index 112690d875..3c93177674 100644 --- a/pkg/protocols/common/hosterrorscache/hosterrorscache_test.go +++ b/pkg/protocols/common/hosterrorscache/hosterrorscache_test.go @@ -109,7 +109,7 @@ func TestCacheMarkFailedConcurrent(t *testing.T) { // the cache is not atomic during items creation, so we pre-create them with counter to zero for _, test := range tests { - normalizedValue := cache.normalizeCacheValue(test.host) + normalizedValue := cache.NormalizeCacheValue(test.host) newItem := &cacheItem{errors: atomic.Int32{}} newItem.errors.Store(0) _ = cache.failedTargets.Set(normalizedValue, newItem) @@ -131,7 +131,7 @@ func TestCacheMarkFailedConcurrent(t *testing.T) { for _, test := range tests { require.True(t, cache.Check(newCtxArgs(test.host))) - normalizedCacheValue := cache.normalizeCacheValue(test.host) + normalizedCacheValue := cache.NormalizeCacheValue(test.host) failedTarget, err := cache.failedTargets.Get(normalizedCacheValue) require.Nil(t, err) require.NotNil(t, failedTarget)