Skip to content

Commit

Permalink
Fix ExecuteCallbackWithCtx to use the context that was provided (pr…
Browse files Browse the repository at this point in the history
…ojectdiscovery#5236)

* Fix `ExecuteCallbackWithCtx` to use the context that was provided

This updates `ExecuteCallbackWithCtx` to use the context that was
provided.

* remove more hardcoded context

---------

Co-authored-by: Tarun Koyalwar <[email protected]>
  • Loading branch information
doug-threatmate and tarunKoyalwar authored May 30, 2024
1 parent 4ae0b39 commit 8011012
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 21 deletions.
8 changes: 7 additions & 1 deletion lib/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,11 +164,17 @@ func WithConcurrency(opts Concurrency) NucleiSDKOptions {
}

// WithGlobalRateLimit sets global rate (i.e all hosts combined) limit options
// Deprecated: will be removed in favour of WithGlobalRateLimitCtx in next release
func WithGlobalRateLimit(maxTokens int, duration time.Duration) NucleiSDKOptions {
return WithGlobalRateLimitCtx(context.Background(), maxTokens, duration)
}

// WithGlobalRateLimitCtx allows setting a global rate limit for the entire engine
func WithGlobalRateLimitCtx(ctx context.Context, maxTokens int, duration time.Duration) NucleiSDKOptions {
return func(e *NucleiEngine) error {
e.opts.RateLimit = maxTokens
e.opts.RateLimitDuration = duration
e.rateLimiter = ratelimit.New(context.Background(), uint(e.opts.RateLimit), e.opts.RateLimitDuration)
e.rateLimiter = ratelimit.New(ctx, uint(e.opts.RateLimit), e.opts.RateLimitDuration)
return nil
}
}
Expand Down
19 changes: 12 additions & 7 deletions lib/multi.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ type unsafeOptions struct {
}

// createEphemeralObjects creates ephemeral nuclei objects/instances/types
func createEphemeralObjects(base *NucleiEngine, opts *types.Options) (*unsafeOptions, error) {
func createEphemeralObjects(ctx context.Context, base *NucleiEngine, opts *types.Options) (*unsafeOptions, error) {
u := &unsafeOptions{}
u.executerOpts = protocols.ExecutorOptions{
Output: base.customWriter,
Expand All @@ -49,9 +49,9 @@ func createEphemeralObjects(base *NucleiEngine, opts *types.Options) (*unsafeOpt
opts.RateLimitDuration = time.Second
}
if opts.RateLimit == 0 && opts.RateLimitDuration == 0 {
u.executerOpts.RateLimiter = ratelimit.NewUnlimited(context.Background())
u.executerOpts.RateLimiter = ratelimit.NewUnlimited(ctx)
} else {
u.executerOpts.RateLimiter = ratelimit.New(context.Background(), uint(opts.RateLimit), opts.RateLimitDuration)
u.executerOpts.RateLimiter = ratelimit.New(ctx, uint(opts.RateLimit), opts.RateLimitDuration)
}
u.engine = core.New(opts)
u.engine.SetExecuterOptions(u.executerOpts)
Expand Down Expand Up @@ -83,7 +83,7 @@ type ThreadSafeNucleiEngine struct {
// NewThreadSafeNucleiEngine creates a new nuclei engine with given options
// whose methods are thread-safe and can be used concurrently
// Note: Non-thread-safe methods start with Global prefix
func NewThreadSafeNucleiEngine(opts ...NucleiSDKOptions) (*ThreadSafeNucleiEngine, error) {
func NewThreadSafeNucleiEngineCtx(ctx context.Context, opts ...NucleiSDKOptions) (*ThreadSafeNucleiEngine, error) {
// default options
e := &NucleiEngine{
opts: types.DefaultOptions(),
Expand All @@ -94,12 +94,17 @@ func NewThreadSafeNucleiEngine(opts ...NucleiSDKOptions) (*ThreadSafeNucleiEngin
return nil, err
}
}
if err := e.init(); err != nil {
if err := e.init(ctx); err != nil {
return nil, err
}
return &ThreadSafeNucleiEngine{eng: e}, nil
}

// Deprecated: use NewThreadSafeNucleiEngineCtx instead
func NewThreadSafeNucleiEngine(opts ...NucleiSDKOptions) (*ThreadSafeNucleiEngine, error) {
return NewThreadSafeNucleiEngineCtx(context.Background(), opts...)
}

// GlobalLoadAllTemplates loads all templates from nuclei-templates repo
// This method will load all templates based on filters given at the time of nuclei engine creation in opts
func (e *ThreadSafeNucleiEngine) GlobalLoadAllTemplates() error {
Expand All @@ -124,7 +129,7 @@ func (e *ThreadSafeNucleiEngine) ExecuteNucleiWithOptsCtx(ctx context.Context, t
}
}
// create ephemeral nuclei objects/instances/types using base nuclei engine
unsafeOpts, err := createEphemeralObjects(e.eng, tmpEngine.opts)
unsafeOpts, err := createEphemeralObjects(ctx, e.eng, tmpEngine.opts)
if err != nil {
return err
}
Expand Down Expand Up @@ -156,7 +161,7 @@ func (e *ThreadSafeNucleiEngine) ExecuteNucleiWithOptsCtx(ctx context.Context, t
engine := core.New(tmpEngine.opts)
engine.SetExecuterOptions(unsafeOpts.executerOpts)

_ = engine.ExecuteScanWithOpts(context.Background(), store.Templates(), inputProvider, false)
_ = engine.ExecuteScanWithOpts(ctx, store.Templates(), inputProvider, false)

engine.WorkPool().Wait()
return nil
Expand Down
13 changes: 9 additions & 4 deletions lib/sdk.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ func (e *NucleiEngine) ExecuteCallbackWithCtx(ctx context.Context, callback ...f
}
e.resultCallbacks = append(e.resultCallbacks, filtered...)

_ = e.engine.ExecuteScanWithOpts(context.Background(), e.store.Templates(), e.inputProvider, false)
_ = e.engine.ExecuteScanWithOpts(ctx, e.store.Templates(), e.inputProvider, false)
defer e.engine.WorkPool().Wait()
return nil
}
Expand All @@ -261,8 +261,8 @@ func (e *NucleiEngine) Engine() *core.Engine {
return e.engine
}

// NewNucleiEngine creates a new nuclei engine instance
func NewNucleiEngine(options ...NucleiSDKOptions) (*NucleiEngine, error) {
// NewNucleiEngineCtx creates a new nuclei engine instance with given context
func NewNucleiEngineCtx(ctx context.Context, options ...NucleiSDKOptions) (*NucleiEngine, error) {
// default options
e := &NucleiEngine{
opts: types.DefaultOptions(),
Expand All @@ -273,8 +273,13 @@ func NewNucleiEngine(options ...NucleiSDKOptions) (*NucleiEngine, error) {
return nil, err
}
}
if err := e.init(); err != nil {
if err := e.init(ctx); err != nil {
return nil, err
}
return e, nil
}

// Deprecated: use NewNucleiEngineCtx instead
func NewNucleiEngine(options ...NucleiSDKOptions) (*NucleiEngine, error) {
return NewNucleiEngineCtx(context.Background(), options...)
}
12 changes: 6 additions & 6 deletions lib/sdk_private.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ import (
var sharedInit sync.Once = sync.Once{}

// applyRequiredDefaults to options
func (e *NucleiEngine) applyRequiredDefaults() {
func (e *NucleiEngine) applyRequiredDefaults(ctx context.Context) {
mockoutput := testutils.NewMockOutputWriter(e.opts.OmitTemplate)
mockoutput.WriteCallback = func(event *output.ResultEvent) {
if len(e.resultCallbacks) > 0 {
Expand Down Expand Up @@ -81,7 +81,7 @@ func (e *NucleiEngine) applyRequiredDefaults() {
e.interactshOpts = interactsh.DefaultOptions(e.customWriter, e.rc, e.customProgress)
}
if e.rateLimiter == nil {
e.rateLimiter = ratelimit.New(context.Background(), 150, time.Second)
e.rateLimiter = ratelimit.New(ctx, 150, time.Second)
}
if e.opts.ExcludeTags == nil {
e.opts.ExcludeTags = []string{}
Expand All @@ -94,7 +94,7 @@ func (e *NucleiEngine) applyRequiredDefaults() {
}

// init
func (e *NucleiEngine) init() error {
func (e *NucleiEngine) init(ctx context.Context) error {
if e.opts.Verbose {
gologger.DefaultLogger.SetMaxLevel(levels.LevelVerbose)
} else if e.opts.Debug {
Expand All @@ -121,7 +121,7 @@ func (e *NucleiEngine) init() error {
_ = protocolinit.Init(e.opts)
})

e.applyRequiredDefaults()
e.applyRequiredDefaults(ctx)
var err error

// setup progressbar
Expand Down Expand Up @@ -204,9 +204,9 @@ func (e *NucleiEngine) init() error {
e.opts.RateLimitDuration = time.Second
}
if e.opts.RateLimit == 0 && e.opts.RateLimitDuration == 0 {
e.executerOpts.RateLimiter = ratelimit.NewUnlimited(context.Background())
e.executerOpts.RateLimiter = ratelimit.NewUnlimited(ctx)
} else {
e.executerOpts.RateLimiter = ratelimit.New(context.Background(), uint(e.opts.RateLimit), e.opts.RateLimitDuration)
e.executerOpts.RateLimiter = ratelimit.New(ctx, uint(e.opts.RateLimit), e.opts.RateLimitDuration)
}
}

Expand Down
9 changes: 6 additions & 3 deletions lib/tests/sdk_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package sdk_test

import (
"context"
"os"
"os/exec"
"testing"
Expand Down Expand Up @@ -28,7 +29,8 @@ func TestSimpleNuclei(t *testing.T) {
time.Sleep(2 * time.Second)
goleak.VerifyNone(t, knownLeaks...)
}()
ne, err := nuclei.NewNucleiEngine(
ne, err := nuclei.NewNucleiEngineCtx(
context.TODO(),
nuclei.WithTemplateFilters(nuclei.TemplateFilters{ProtocolTypes: "dns"}), // filter dns templates
nuclei.EnableStatsWithOpts(nuclei.StatsOptions{JSON: true}),
)
Expand Down Expand Up @@ -62,7 +64,8 @@ func TestSimpleNucleiRemote(t *testing.T) {
time.Sleep(2 * time.Second)
goleak.VerifyNone(t, knownLeaks...)
}()
ne, err := nuclei.NewNucleiEngine(
ne, err := nuclei.NewNucleiEngineCtx(
context.TODO(),
nuclei.WithTemplatesOrWorkflows(
nuclei.TemplateSources{
RemoteTemplates: []string{"https://cloud.projectdiscovery.io/public/nameserver-fingerprint.yaml"},
Expand Down Expand Up @@ -100,7 +103,7 @@ func TestThreadSafeNuclei(t *testing.T) {
goleak.VerifyNone(t, knownLeaks...)
}()
// create nuclei engine with options
ne, err := nuclei.NewThreadSafeNucleiEngine()
ne, err := nuclei.NewThreadSafeNucleiEngineCtx(context.TODO())
require.Nil(t, err)

// scan 1 = run dns templates on scanme.sh
Expand Down

0 comments on commit 8011012

Please sign in to comment.