Skip to content

Commit

Permalink
Merge pull request #5547 from projectdiscovery/fix_race_condition
Browse files Browse the repository at this point in the history
fix race condition
  • Loading branch information
Mzack9999 authored Aug 21, 2024
2 parents b1152ef + 5e102b7 commit b53b530
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 10 deletions.
3 changes: 2 additions & 1 deletion lib/sdk.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,8 @@ func (e *NucleiEngine) SignTemplate(tmplSigner *signer.TemplateSigner, data []by
if err != nil {
return data, err
}
buff := bytes.NewBuffer(signer.RemoveSignatureFromData(data))
_, content := signer.ExtractSignatureAndContent(data)
buff := bytes.NewBuffer(content)
buff.WriteString("\n" + signatureData)
return buff.Bytes(), err
}
Expand Down
15 changes: 13 additions & 2 deletions pkg/protocols/common/protocolstate/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"net"
"net/url"
"sync"

"github.com/go-sql-driver/mysql"
"github.com/pkg/errors"
Expand All @@ -19,9 +20,17 @@ import (

// Dialer is a shared fastdialer instance for host DNS resolution
var (
Dialer *fastdialer.Dialer
muDialer sync.RWMutex
Dialer *fastdialer.Dialer
)

func GetDialer() *fastdialer.Dialer {
muDialer.RLock()
defer muDialer.RUnlock()

return Dialer
}

func ShouldInit() bool {
return Dialer == nil
}
Expand Down Expand Up @@ -210,10 +219,12 @@ func interfaceAddresses(interfaceName string) ([]net.Addr, error) {

// Close closes the global shared fastdialer
func Close() {
muDialer.Lock()
defer muDialer.Unlock()

if Dialer != nil {
Dialer.Close()
Dialer = nil
}
Dialer = nil
StopActiveMemGuardian()
}
25 changes: 21 additions & 4 deletions pkg/protocols/http/httpclientpool/clientpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (

var (
rawHttpClient *rawhttp.Client
rawHttpClientOnce sync.Once
forceMaxRedirects int
normalClient *retryablehttp.Client
clientPool *mapsutil.SyncLockMap[string, *retryablehttp.Client]
Expand Down Expand Up @@ -102,6 +103,22 @@ type Configuration struct {
ResponseHeaderTimeout time.Duration
}

func (c *Configuration) Clone() *Configuration {
clone := *c
if c.Connection != nil {
cloneConnection := &ConnectionConfiguration{
DisableKeepAlive: c.Connection.DisableKeepAlive,
}
if c.Connection.HasCookieJar() {
cookiejar := *c.Connection.GetCookieJar()
cloneConnection.SetCookieJar(&cookiejar)
}
clone.Connection = cloneConnection
}

return &clone
}

// Hash returns the hash of the configuration to allow client pooling
func (c *Configuration) Hash() string {
builder := &strings.Builder{}
Expand Down Expand Up @@ -131,7 +148,7 @@ func (c *Configuration) HasStandardOptions() bool {

// GetRawHTTP returns the rawhttp request client
func GetRawHTTP(options *protocols.ExecutorOptions) *rawhttp.Client {
if rawHttpClient == nil {
rawHttpClientOnce.Do(func() {
rawHttpOptions := rawhttp.DefaultOptions
if types.ProxyURL != "" {
rawHttpOptions.Proxy = types.ProxyURL
Expand All @@ -142,7 +159,7 @@ func GetRawHTTP(options *protocols.ExecutorOptions) *rawhttp.Client {
}
rawHttpOptions.Timeout = options.Options.GetTimeouts().HttpTimeout
rawHttpClient = rawhttp.NewClient(rawHttpOptions)
}
})
return rawHttpClient
}

Expand Down Expand Up @@ -233,15 +250,15 @@ func wrappedGet(options *types.Options, configuration *Configuration) (*retryabl

transport := &http.Transport{
ForceAttemptHTTP2: options.ForceAttemptHTTP2,
DialContext: protocolstate.Dialer.Dial,
DialContext: protocolstate.GetDialer().Dial,
DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
if options.TlsImpersonate {
return protocolstate.Dialer.DialTLSWithConfigImpersonate(ctx, network, addr, tlsConfig, impersonate.Random, nil)
}
if options.HasClientCertificates() || options.ForceAttemptHTTP2 {
return protocolstate.Dialer.DialTLSWithConfig(ctx, network, addr, tlsConfig)
}
return protocolstate.Dialer.DialTLS(ctx, network, addr)
return protocolstate.GetDialer().DialTLS(ctx, network, addr)
},
MaxIdleConns: maxIdleConns,
MaxIdleConnsPerHost: maxIdleConnsPerHost,
Expand Down
11 changes: 8 additions & 3 deletions pkg/protocols/http/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -770,15 +770,16 @@ func (request *Request) executeRequest(input *contextargs.Context, generatedRequ

// check for cookie related configuration
if input.CookieJar != nil {
connConfiguration := request.connConfiguration
connConfiguration := request.connConfiguration.Clone()
connConfiguration.Connection.SetCookieJar(input.CookieJar)
modifiedConfig = connConfiguration
}
// check for request updatedTimeout annotation
updatedTimeout, ok := generatedRequest.request.Context().Value(httpclientpool.WithCustomTimeout{}).(httpclientpool.WithCustomTimeout)
if ok {
if modifiedConfig == nil {
modifiedConfig = request.connConfiguration
connConfiguration := request.connConfiguration.Clone()
modifiedConfig = connConfiguration
}
modifiedConfig.ResponseHeaderTimeout = updatedTimeout.Timeout
}
Expand Down Expand Up @@ -941,7 +942,11 @@ func (request *Request) executeRequest(input *contextargs.Context, generatedRequ
if input.MetaInput.CustomIP != "" {
outputEvent["ip"] = input.MetaInput.CustomIP
} else {
outputEvent["ip"] = protocolstate.Dialer.GetDialedIP(hostname)
dialer := protocolstate.GetDialer()
if dialer != nil {
outputEvent["ip"] = dialer.GetDialedIP(hostname)
}

// try getting cname
request.addCNameIfAvailable(hostname, outputEvent)
}
Expand Down

0 comments on commit b53b530

Please sign in to comment.