From 8011012c420c8d3a56f3ee5812f472445e00669d Mon Sep 17 00:00:00 2001 From: Douglas Danger Manley <127235272+doug-threatmate@users.noreply.github.com> Date: Thu, 30 May 2024 06:34:15 -0400 Subject: [PATCH] Fix `ExecuteCallbackWithCtx` to use the context that was provided (#5236) * Fix `ExecuteCallbackWithCtx` to use the context that was provided This updates `ExecuteCallbackWithCtx` to use the context that was provided. * remove more hardcoded context --------- Co-authored-by: Tarun Koyalwar --- lib/config.go | 8 +++++++- lib/multi.go | 19 ++++++++++++------- lib/sdk.go | 13 +++++++++---- lib/sdk_private.go | 12 ++++++------ lib/tests/sdk_test.go | 9 ++++++--- 5 files changed, 40 insertions(+), 21 deletions(-) diff --git a/lib/config.go b/lib/config.go index 5610b71515..24bc08691f 100644 --- a/lib/config.go +++ b/lib/config.go @@ -164,11 +164,17 @@ func WithConcurrency(opts Concurrency) NucleiSDKOptions { } // WithGlobalRateLimit sets global rate (i.e all hosts combined) limit options +// Deprecated: will be removed in favour of WithGlobalRateLimitCtx in next release func WithGlobalRateLimit(maxTokens int, duration time.Duration) NucleiSDKOptions { + return WithGlobalRateLimitCtx(context.Background(), maxTokens, duration) +} + +// WithGlobalRateLimitCtx allows setting a global rate limit for the entire engine +func WithGlobalRateLimitCtx(ctx context.Context, maxTokens int, duration time.Duration) NucleiSDKOptions { return func(e *NucleiEngine) error { e.opts.RateLimit = maxTokens e.opts.RateLimitDuration = duration - e.rateLimiter = ratelimit.New(context.Background(), uint(e.opts.RateLimit), e.opts.RateLimitDuration) + e.rateLimiter = ratelimit.New(ctx, uint(e.opts.RateLimit), e.opts.RateLimitDuration) return nil } } diff --git a/lib/multi.go b/lib/multi.go index 7c713e8aa0..a2149ddcd2 100644 --- a/lib/multi.go +++ b/lib/multi.go @@ -26,7 +26,7 @@ type unsafeOptions struct { } // createEphemeralObjects creates ephemeral nuclei objects/instances/types -func createEphemeralObjects(base *NucleiEngine, opts *types.Options) (*unsafeOptions, error) { +func createEphemeralObjects(ctx context.Context, base *NucleiEngine, opts *types.Options) (*unsafeOptions, error) { u := &unsafeOptions{} u.executerOpts = protocols.ExecutorOptions{ Output: base.customWriter, @@ -49,9 +49,9 @@ func createEphemeralObjects(base *NucleiEngine, opts *types.Options) (*unsafeOpt opts.RateLimitDuration = time.Second } if opts.RateLimit == 0 && opts.RateLimitDuration == 0 { - u.executerOpts.RateLimiter = ratelimit.NewUnlimited(context.Background()) + u.executerOpts.RateLimiter = ratelimit.NewUnlimited(ctx) } else { - u.executerOpts.RateLimiter = ratelimit.New(context.Background(), uint(opts.RateLimit), opts.RateLimitDuration) + u.executerOpts.RateLimiter = ratelimit.New(ctx, uint(opts.RateLimit), opts.RateLimitDuration) } u.engine = core.New(opts) u.engine.SetExecuterOptions(u.executerOpts) @@ -83,7 +83,7 @@ type ThreadSafeNucleiEngine struct { // NewThreadSafeNucleiEngine creates a new nuclei engine with given options // whose methods are thread-safe and can be used concurrently // Note: Non-thread-safe methods start with Global prefix -func NewThreadSafeNucleiEngine(opts ...NucleiSDKOptions) (*ThreadSafeNucleiEngine, error) { +func NewThreadSafeNucleiEngineCtx(ctx context.Context, opts ...NucleiSDKOptions) (*ThreadSafeNucleiEngine, error) { // default options e := &NucleiEngine{ opts: types.DefaultOptions(), @@ -94,12 +94,17 @@ func NewThreadSafeNucleiEngine(opts ...NucleiSDKOptions) (*ThreadSafeNucleiEngin return nil, err } } - if err := e.init(); err != nil { + if err := e.init(ctx); err != nil { return nil, err } return &ThreadSafeNucleiEngine{eng: e}, nil } +// Deprecated: use NewThreadSafeNucleiEngineCtx instead +func NewThreadSafeNucleiEngine(opts ...NucleiSDKOptions) (*ThreadSafeNucleiEngine, error) { + return NewThreadSafeNucleiEngineCtx(context.Background(), opts...) +} + // GlobalLoadAllTemplates loads all templates from nuclei-templates repo // This method will load all templates based on filters given at the time of nuclei engine creation in opts func (e *ThreadSafeNucleiEngine) GlobalLoadAllTemplates() error { @@ -124,7 +129,7 @@ func (e *ThreadSafeNucleiEngine) ExecuteNucleiWithOptsCtx(ctx context.Context, t } } // create ephemeral nuclei objects/instances/types using base nuclei engine - unsafeOpts, err := createEphemeralObjects(e.eng, tmpEngine.opts) + unsafeOpts, err := createEphemeralObjects(ctx, e.eng, tmpEngine.opts) if err != nil { return err } @@ -156,7 +161,7 @@ func (e *ThreadSafeNucleiEngine) ExecuteNucleiWithOptsCtx(ctx context.Context, t engine := core.New(tmpEngine.opts) engine.SetExecuterOptions(unsafeOpts.executerOpts) - _ = engine.ExecuteScanWithOpts(context.Background(), store.Templates(), inputProvider, false) + _ = engine.ExecuteScanWithOpts(ctx, store.Templates(), inputProvider, false) engine.WorkPool().Wait() return nil diff --git a/lib/sdk.go b/lib/sdk.go index 63925f47ca..2e23aa49cd 100644 --- a/lib/sdk.go +++ b/lib/sdk.go @@ -240,7 +240,7 @@ func (e *NucleiEngine) ExecuteCallbackWithCtx(ctx context.Context, callback ...f } e.resultCallbacks = append(e.resultCallbacks, filtered...) - _ = e.engine.ExecuteScanWithOpts(context.Background(), e.store.Templates(), e.inputProvider, false) + _ = e.engine.ExecuteScanWithOpts(ctx, e.store.Templates(), e.inputProvider, false) defer e.engine.WorkPool().Wait() return nil } @@ -261,8 +261,8 @@ func (e *NucleiEngine) Engine() *core.Engine { return e.engine } -// NewNucleiEngine creates a new nuclei engine instance -func NewNucleiEngine(options ...NucleiSDKOptions) (*NucleiEngine, error) { +// NewNucleiEngineCtx creates a new nuclei engine instance with given context +func NewNucleiEngineCtx(ctx context.Context, options ...NucleiSDKOptions) (*NucleiEngine, error) { // default options e := &NucleiEngine{ opts: types.DefaultOptions(), @@ -273,8 +273,13 @@ func NewNucleiEngine(options ...NucleiSDKOptions) (*NucleiEngine, error) { return nil, err } } - if err := e.init(); err != nil { + if err := e.init(ctx); err != nil { return nil, err } return e, nil } + +// Deprecated: use NewNucleiEngineCtx instead +func NewNucleiEngine(options ...NucleiSDKOptions) (*NucleiEngine, error) { + return NewNucleiEngineCtx(context.Background(), options...) +} diff --git a/lib/sdk_private.go b/lib/sdk_private.go index cfe9c88a25..ae61add221 100644 --- a/lib/sdk_private.go +++ b/lib/sdk_private.go @@ -37,7 +37,7 @@ import ( var sharedInit sync.Once = sync.Once{} // applyRequiredDefaults to options -func (e *NucleiEngine) applyRequiredDefaults() { +func (e *NucleiEngine) applyRequiredDefaults(ctx context.Context) { mockoutput := testutils.NewMockOutputWriter(e.opts.OmitTemplate) mockoutput.WriteCallback = func(event *output.ResultEvent) { if len(e.resultCallbacks) > 0 { @@ -81,7 +81,7 @@ func (e *NucleiEngine) applyRequiredDefaults() { e.interactshOpts = interactsh.DefaultOptions(e.customWriter, e.rc, e.customProgress) } if e.rateLimiter == nil { - e.rateLimiter = ratelimit.New(context.Background(), 150, time.Second) + e.rateLimiter = ratelimit.New(ctx, 150, time.Second) } if e.opts.ExcludeTags == nil { e.opts.ExcludeTags = []string{} @@ -94,7 +94,7 @@ func (e *NucleiEngine) applyRequiredDefaults() { } // init -func (e *NucleiEngine) init() error { +func (e *NucleiEngine) init(ctx context.Context) error { if e.opts.Verbose { gologger.DefaultLogger.SetMaxLevel(levels.LevelVerbose) } else if e.opts.Debug { @@ -121,7 +121,7 @@ func (e *NucleiEngine) init() error { _ = protocolinit.Init(e.opts) }) - e.applyRequiredDefaults() + e.applyRequiredDefaults(ctx) var err error // setup progressbar @@ -204,9 +204,9 @@ func (e *NucleiEngine) init() error { e.opts.RateLimitDuration = time.Second } if e.opts.RateLimit == 0 && e.opts.RateLimitDuration == 0 { - e.executerOpts.RateLimiter = ratelimit.NewUnlimited(context.Background()) + e.executerOpts.RateLimiter = ratelimit.NewUnlimited(ctx) } else { - e.executerOpts.RateLimiter = ratelimit.New(context.Background(), uint(e.opts.RateLimit), e.opts.RateLimitDuration) + e.executerOpts.RateLimiter = ratelimit.New(ctx, uint(e.opts.RateLimit), e.opts.RateLimitDuration) } } diff --git a/lib/tests/sdk_test.go b/lib/tests/sdk_test.go index b20c163d1a..97ec489abc 100644 --- a/lib/tests/sdk_test.go +++ b/lib/tests/sdk_test.go @@ -1,6 +1,7 @@ package sdk_test import ( + "context" "os" "os/exec" "testing" @@ -28,7 +29,8 @@ func TestSimpleNuclei(t *testing.T) { time.Sleep(2 * time.Second) goleak.VerifyNone(t, knownLeaks...) }() - ne, err := nuclei.NewNucleiEngine( + ne, err := nuclei.NewNucleiEngineCtx( + context.TODO(), nuclei.WithTemplateFilters(nuclei.TemplateFilters{ProtocolTypes: "dns"}), // filter dns templates nuclei.EnableStatsWithOpts(nuclei.StatsOptions{JSON: true}), ) @@ -62,7 +64,8 @@ func TestSimpleNucleiRemote(t *testing.T) { time.Sleep(2 * time.Second) goleak.VerifyNone(t, knownLeaks...) }() - ne, err := nuclei.NewNucleiEngine( + ne, err := nuclei.NewNucleiEngineCtx( + context.TODO(), nuclei.WithTemplatesOrWorkflows( nuclei.TemplateSources{ RemoteTemplates: []string{"https://cloud.projectdiscovery.io/public/nameserver-fingerprint.yaml"}, @@ -100,7 +103,7 @@ func TestThreadSafeNuclei(t *testing.T) { goleak.VerifyNone(t, knownLeaks...) }() // create nuclei engine with options - ne, err := nuclei.NewThreadSafeNucleiEngine() + ne, err := nuclei.NewThreadSafeNucleiEngineCtx(context.TODO()) require.Nil(t, err) // scan 1 = run dns templates on scanme.sh