Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Apply input transformation to multi-protocol templates #5426

Merged
merged 10 commits into from
Aug 1, 2024
13 changes: 12 additions & 1 deletion pkg/input/transform.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ func (h *Helper) Transform(input string, protocol templateTypes.ProtocolType) st
return h.convertInputToType(input, typeHostWithOptionalPort, "")
case templateTypes.WebsocketProtocol:
return h.convertInputToType(input, typeWebsocket, "")
case templateTypes.SSLProtocol:
return h.convertInputToType(input, typeHostWithPort, "443")
}
return input
}
Expand Down Expand Up @@ -94,6 +96,8 @@ func (h *Helper) convertInputToType(input string, inputType inputType, defaultPo
if _, err := filepath.Match(input, ""); err != filepath.ErrBadPattern && !isURL {
return input
}
// if none of these satisfy the condition return empty
return ""
case typeHostOnly:
if hasHost {
return host
Expand All @@ -111,6 +115,10 @@ func (h *Helper) convertInputToType(input string, inputType inputType, defaultPo
return string(probed)
}
}
// try to parse it as absolute url and return
if absUrl, err := urlutil.ParseAbsoluteURL(input, false); err == nil {
return absUrl.String()
}
case typeHostWithPort, typeHostWithOptionalPort:
if hasHost && hasPort {
return net.JoinHostPort(host, port)
Expand All @@ -128,6 +136,9 @@ func (h *Helper) convertInputToType(input string, inputType inputType, defaultPo
if uri != nil && stringsutil.EqualFoldAny(uri.Scheme, "ws", "wss") {
return input
}
// empty if prefix is not given
return ""
}
return ""
// do not return empty
return input
}
4 changes: 2 additions & 2 deletions pkg/input/transform_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func TestConvertInputToType(t *testing.T) {
{"https://google.com:443", typeHostOnly, "google.com", ""},

// url
{"test.com", typeURL, "", ""},
{"test.com", typeURL, "test.com", ""},
{"google.com", typeURL, "https://google.com", ""},
{"https://google.com", typeURL, "https://google.com", ""},

Expand All @@ -43,7 +43,7 @@ func TestConvertInputToType(t *testing.T) {
{"input_test.*", typeFilepath, "input_test.*", ""},

// host-port
{"google.com", typeHostWithPort, "", ""},
{"google.com", typeHostWithPort, "google.com", ""},
{"google.com:443", typeHostWithPort, "google.com:443", ""},
{"https://google.com", typeHostWithPort, "google.com:443", ""},
{"https://google.com:443", typeHostWithPort, "google.com:443", ""},
Expand Down
1 change: 1 addition & 0 deletions pkg/protocols/common/contextargs/metainput.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ func (metaInput *MetaInput) Clone() *MetaInput {
input := NewMetaInput()
input.Input = metaInput.Input
input.CustomIP = metaInput.CustomIP
input.hash = metaInput.hash
if metaInput.ReqResp != nil {
input.ReqResp = metaInput.ReqResp.Clone()
}
Expand Down
6 changes: 0 additions & 6 deletions pkg/protocols/dns/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,12 +141,6 @@ func (request *Request) Compile(options *protocols.ExecutorOptions) error {
recursion := true
request.Recursion = &recursion
}
dnsClientOptions := &dnsclientpool.Configuration{
Retries: request.Retries,
}
if len(request.Resolvers) > 0 {
dnsClientOptions.Resolvers = request.Resolvers
}
// Create a dns client for the class
client, err := request.getDnsClient(options, nil)
if err != nil {
Expand Down
23 changes: 2 additions & 21 deletions pkg/protocols/dns/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package dns
import (
"encoding/hex"
"fmt"
"net/url"
"strings"
"sync"

Expand All @@ -23,7 +22,6 @@ import (
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/utils/vardump"
protocolutils "github.com/projectdiscovery/nuclei/v3/pkg/protocols/utils"
templateTypes "github.com/projectdiscovery/nuclei/v3/pkg/templates/types"
"github.com/projectdiscovery/nuclei/v3/pkg/utils"
"github.com/projectdiscovery/retryabledns"
iputil "github.com/projectdiscovery/utils/ip"
syncutil "github.com/projectdiscovery/utils/sync"
Expand All @@ -38,16 +36,8 @@ func (request *Request) Type() templateTypes.ProtocolType {

// ExecuteWithResults executes the protocol requests and returns results instead of writing them.
func (request *Request) ExecuteWithResults(input *contextargs.Context, metadata, previous output.InternalEvent, callback protocols.OutputEventCallback) error {
// Parse the URL and return domain if URL.
var domain string
if utils.IsURL(input.MetaInput.Input) {
domain = extractDomain(input.MetaInput.Input)
} else {
domain = input.MetaInput.Input
}

var err error
domain, err = request.parseDNSInput(domain)
domain, err := request.parseDNSInput(input.MetaInput.Input)
if err != nil {
return errors.Wrap(err, "could not build request")
}
Expand Down Expand Up @@ -230,7 +220,7 @@ func (request *Request) parseDNSInput(host string) (string, error) {
return host, nil
}

func dumpResponse(event *output.InternalWrappedEvent, request *Request, requestOptions *protocols.ExecutorOptions, response, domain string) {
func dumpResponse(event *output.InternalWrappedEvent, request *Request, _ *protocols.ExecutorOptions, response, domain string) {
cliOptions := request.options.Options
if cliOptions.Debug || cliOptions.DebugResponse || cliOptions.StoreResponse {
hexDump := false
Expand Down Expand Up @@ -261,12 +251,3 @@ func dumpTraceData(event *output.InternalWrappedEvent, requestOptions *protocols
gologger.Debug().Msgf("[%s] Dumped DNS Trace data for %s\n\n%s", requestOptions.TemplateID, domain, highlightedResponse)
}
}

// extractDomain extracts the domain name of a URL
func extractDomain(theURL string) string {
u, err := url.Parse(theURL)
if err != nil {
return ""
}
return u.Hostname()
}
16 changes: 1 addition & 15 deletions pkg/protocols/dns/request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,19 +67,5 @@ func TestDNSExecuteWithResults(t *testing.T) {
require.Equal(t, 1, len(finalEvent.Results[0].ExtractedResults), "could not get correct number of extracted results")
require.Equal(t, "93.184.215.14", finalEvent.Results[0].ExtractedResults[0], "could not get correct extracted results")
finalEvent = nil

t.Run("url-to-domain", func(t *testing.T) {
metadata := make(output.InternalEvent)
previous := make(output.InternalEvent)
err := request.ExecuteWithResults(contextargs.NewWithInput(context.Background(), "https://example.com"), metadata, previous, func(event *output.InternalWrappedEvent) {
finalEvent = event
})
require.Nil(t, err, "could not execute dns request")
})
require.NotNil(t, finalEvent, "could not get event output from request")
require.Equal(t, 1, len(finalEvent.Results), "could not get correct number of results")
require.Equal(t, "test", finalEvent.Results[0].MatcherName, "could not get correct matcher name of results")
require.Equal(t, 1, len(finalEvent.Results[0].ExtractedResults), "could not get correct number of extracted results")
require.Equal(t, "93.184.215.14", finalEvent.Results[0].ExtractedResults[0], "could not get correct extracted results")
finalEvent = nil
// Note: changing url to domain is responsible at tmplexec package and is implemented there
}
19 changes: 1 addition & 18 deletions pkg/protocols/ssl/ssl.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ import (
"github.com/projectdiscovery/tlsx/pkg/tlsx/openssl"
errorutil "github.com/projectdiscovery/utils/errors"
stringsutil "github.com/projectdiscovery/utils/strings"
urlutil "github.com/projectdiscovery/utils/url"
)

// Request is a request for the SSL protocol
Expand Down Expand Up @@ -199,10 +198,7 @@ func (request *Request) GetID() string {

// ExecuteWithResults executes the protocol requests and returns results instead of writing them.
func (request *Request) ExecuteWithResults(input *contextargs.Context, dynamicValues, previous output.InternalEvent, callback protocols.OutputEventCallback) error {
hostPort, err := getAddress(input.MetaInput.Input)
if err != nil {
return err
}
hostPort := input.MetaInput.Input
hostname, port, _ := net.SplitHostPort(hostPort)

requestOptions := request.options
Expand Down Expand Up @@ -358,19 +354,6 @@ var RequestPartDefinitions = map[string]string{
"matched": "Matched is the input which was matched upon",
}

// getAddress returns the address of the host to make request to
func getAddress(toTest string) (string, error) {
urlx, err := urlutil.Parse(toTest)
if err != nil {
// use given input instead of url parsing failure
return toTest, nil
}
if urlx.Port() == "" {
urlx.UpdatePort("443")
}
return urlx.Host, nil
}

// Match performs matching operation for a matcher on model and returns:
// true and a list of matched snippets if the matcher type is supports it
// otherwise false and an empty string slice
Expand Down
5 changes: 0 additions & 5 deletions pkg/protocols/ssl/ssl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,3 @@ func TestSSLProtocol(t *testing.T) {
require.Nil(t, err, "could not run ssl request")
require.NotEmpty(t, gotEvent, "could not get event items")
}

func TestGetAddress(t *testing.T) {
address, _ := getAddress("https://scanme.sh")
require.Equal(t, "scanme.sh:443", address, "could not get correct address")
}
22 changes: 19 additions & 3 deletions pkg/tmplexec/flow/flow_internal.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,15 @@ func (f *FlowExecutor) requestExecutor(runtime *goja.Runtime, reqMap mapsutil.Ma
// execution logic for http()/dns() etc
for index := range f.allProtocols[opts.protoName] {
req := f.allProtocols[opts.protoName][index]
err := req.ExecuteWithResults(f.ctx.Input, output.InternalEvent(f.options.GetTemplateCtx(f.ctx.Input.MetaInput).GetAll()), nil, f.protocolResultCallback(req, matcherStatus, opts))
// transform input if required
inputItem := f.ctx.Input.Clone()
if f.options.InputHelper != nil && f.ctx.Input.MetaInput.Input != "" {
if inputItem.MetaInput.Input = f.options.InputHelper.Transform(inputItem.MetaInput.Input, req.Type()); inputItem.MetaInput.Input == "" {
f.ctx.LogError(fmt.Errorf("failed to transform input for protocol %s", req.Type()))
return false
}
}
err := req.ExecuteWithResults(inputItem, output.InternalEvent(f.options.GetTemplateCtx(f.ctx.Input.MetaInput).GetAll()), nil, f.protocolResultCallback(req, matcherStatus, opts))
if err != nil {
// save all errors in a map with id as key
// its less likely that there will be race condition but just in case
Expand Down Expand Up @@ -58,7 +66,15 @@ func (f *FlowExecutor) requestExecutor(runtime *goja.Runtime, reqMap mapsutil.Ma
}
return matcherStatus.Load()
}
err := req.ExecuteWithResults(f.ctx.Input, output.InternalEvent(f.options.GetTemplateCtx(f.ctx.Input.MetaInput).GetAll()), nil, f.protocolResultCallback(req, matcherStatus, opts))
// transform input if required
inputItem := f.ctx.Input.Clone()
if f.options.InputHelper != nil && f.ctx.Input.MetaInput.Input != "" {
if inputItem.MetaInput.Input = f.options.InputHelper.Transform(inputItem.MetaInput.Input, req.Type()); inputItem.MetaInput.Input == "" {
f.ctx.LogError(fmt.Errorf("failed to transform input for protocol %s", req.Type()))
return false
}
}
err := req.ExecuteWithResults(inputItem, output.InternalEvent(f.options.GetTemplateCtx(f.ctx.Input.MetaInput).GetAll()), nil, f.protocolResultCallback(req, matcherStatus, opts))
if err != nil {
index := id
err = f.allErrs.Set(opts.protoName+":"+index, err)
Expand All @@ -72,7 +88,7 @@ func (f *FlowExecutor) requestExecutor(runtime *goja.Runtime, reqMap mapsutil.Ma

// protocolResultCallback returns a callback that is executed
// after execution of each protocol request
func (f *FlowExecutor) protocolResultCallback(req protocols.Request, matcherStatus *atomic.Bool, opts *ProtoOptions) func(result *output.InternalWrappedEvent) {
func (f *FlowExecutor) protocolResultCallback(req protocols.Request, matcherStatus *atomic.Bool, _ *ProtoOptions) func(result *output.InternalWrappedEvent) {
return func(result *output.InternalWrappedEvent) {
if result != nil {
// Note: flow specific implicit behaviours should be handled here
Expand Down
14 changes: 9 additions & 5 deletions pkg/tmplexec/multiproto/multi.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,22 +109,26 @@ func (m *MultiProtocol) ExecuteWithResults(ctx *scan.ScanContext) error {
return ctx.Context().Err()
default:
}

values := m.options.GetTemplateCtx(ctx.Input.MetaInput).GetAll()
err := req.ExecuteWithResults(ctx.Input, output.InternalEvent(values), nil, multiProtoCallback)
inputItem := ctx.Input.Clone()
if m.options.InputHelper != nil && ctx.Input.MetaInput.Input != "" {
if inputItem.MetaInput.Input = m.options.InputHelper.Transform(inputItem.MetaInput.Input, req.Type()); inputItem.MetaInput.Input == "" {
return nil
}
}
// FIXME: this hack of using hash to get templateCtx has known issues scan context based approach should be adopted ASAP
values := m.options.GetTemplateCtx(inputItem.MetaInput).GetAll()
err := req.ExecuteWithResults(inputItem, output.InternalEvent(values), nil, multiProtoCallback)
// in case of fatal error skip execution of next protocols
if err != nil {
// always log errors
ctx.LogError(err)

// for some classes of protocols (i.e ssl) errors like tls handshake are a legitimate behavior so we don't stop execution
// connection failures are already tracked by the internal host error cache
// we use strings comparison as the error is not formalized into instance within the standard library
// within a flow instead we consider ssl errors as fatal, since a specific logic was requested
if req.Type() == types.SSLProtocol && stringsutil.ContainsAnyI(err.Error(), "protocol version not supported", "could not do tls handshake") {
continue
}

return err
}
}
Expand Down
10 changes: 8 additions & 2 deletions pkg/tmplexec/multiproto/multi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ package multiproto_test
import (
"context"
"log"
"os"
"testing"
"time"

"github.com/projectdiscovery/nuclei/v3/pkg/catalog/config"
"github.com/projectdiscovery/nuclei/v3/pkg/catalog/disk"
"github.com/projectdiscovery/nuclei/v3/pkg/input"
"github.com/projectdiscovery/nuclei/v3/pkg/loader/workflow"
"github.com/projectdiscovery/nuclei/v3/pkg/progress"
"github.com/projectdiscovery/nuclei/v3/pkg/protocols"
Expand Down Expand Up @@ -36,6 +38,7 @@ func setup() {
Catalog: disk.NewCatalog(config.DefaultConfig.TemplatesDirectory),
RateLimiter: ratelimit.New(context.Background(), uint(options.RateLimit), time.Second),
Parser: templates.NewParser(),
InputHelper: input.NewHelper(),
}
workflowLoader, err := workflow.NewLoader(&executerOpts)
if err != nil {
Expand All @@ -45,7 +48,6 @@ func setup() {
}

func TestMultiProtoWithDynamicExtractor(t *testing.T) {
setup()
Template, err := templates.Parse("testcases/multiprotodynamic.yaml", nil, executerOpts)
require.Nil(t, err, "could not parse template")

Expand All @@ -62,7 +64,6 @@ func TestMultiProtoWithDynamicExtractor(t *testing.T) {
}

func TestMultiProtoWithProtoPrefix(t *testing.T) {
setup()
Template, err := templates.Parse("testcases/multiprotowithprefix.yaml", nil, executerOpts)
require.Nil(t, err, "could not parse template")

Expand All @@ -77,3 +78,8 @@ func TestMultiProtoWithProtoPrefix(t *testing.T) {
require.Nil(t, err, "could not execute template")
require.True(t, gotresults)
}

func TestMain(m *testing.M) {
setup()
os.Exit(m.Run())
}
Loading