Skip to content

Commit 1f945d6

Browse files
consider protocolType in max host error (projectdiscovery#5668)
* consider protocolType in max host error * add mutex when updating internal-event
1 parent e4dae52 commit 1f945d6

File tree

10 files changed

+37
-29
lines changed

10 files changed

+37
-29
lines changed

Diff for: pkg/core/executors.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ func (e *Engine) executeTemplateWithTargets(ctx context.Context, template *templ
107107
currentInfo.Unlock()
108108

109109
// Skip if the host has had errors
110-
if e.executerOpts.HostErrorsCache != nil && e.executerOpts.HostErrorsCache.Check(contextargs.NewWithMetaInput(ctx, scannedValue)) {
110+
if e.executerOpts.HostErrorsCache != nil && e.executerOpts.HostErrorsCache.Check(e.executerOpts.ProtocolType.String(), contextargs.NewWithMetaInput(ctx, scannedValue)) {
111111
return true
112112
}
113113

Diff for: pkg/core/workflow_execute.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ func (e *Engine) runWorkflowStep(template *workflows.WorkflowTemplate, ctx *scan
9898
}
9999
if err != nil {
100100
if w.Options.HostErrorsCache != nil {
101-
w.Options.HostErrorsCache.MarkFailed(ctx.Input, err)
101+
w.Options.HostErrorsCache.MarkFailed(w.Options.ProtocolType.String(), ctx.Input, err)
102102
}
103103
if len(template.Executers) == 1 {
104104
mainErr = err

Diff for: pkg/protocols/common/hosterrorscache/hosterrorscache.go

+11-9
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@ import (
2020
// CacheInterface defines the signature of the hosterrorscache so that
2121
// users of Nuclei as embedded lib may implement their own cache
2222
type CacheInterface interface {
23-
SetVerbose(verbose bool) // log verbosely
24-
Close() // close the cache
25-
Check(ctx *contextargs.Context) bool // return true if the host should be skipped
26-
MarkFailed(ctx *contextargs.Context, err error) // record a failure (and cause) for the host
23+
SetVerbose(verbose bool) // log verbosely
24+
Close() // close the cache
25+
Check(protoType string, ctx *contextargs.Context) bool // return true if the host should be skipped
26+
MarkFailed(protoType string, ctx *contextargs.Context, err error) // record a failure (and cause) for the host
2727
}
2828

2929
var (
@@ -115,7 +115,7 @@ func (c *Cache) NormalizeCacheValue(value string) string {
115115
// - URL: https?:// type
116116
// - Host:port type
117117
// - host type
118-
func (c *Cache) Check(ctx *contextargs.Context) bool {
118+
func (c *Cache) Check(protoType string, ctx *contextargs.Context) bool {
119119
finalValue := c.GetKeyFromContext(ctx, nil)
120120

121121
existingCacheItem, err := c.failedTargets.GetIFPresent(finalValue)
@@ -138,8 +138,8 @@ func (c *Cache) Check(ctx *contextargs.Context) bool {
138138
}
139139

140140
// MarkFailed marks a host as failed previously
141-
func (c *Cache) MarkFailed(ctx *contextargs.Context, err error) {
142-
if !c.checkError(err) {
141+
func (c *Cache) MarkFailed(protoType string, ctx *contextargs.Context, err error) {
142+
if !c.checkError(protoType, err) {
143143
return
144144
}
145145
finalValue := c.GetKeyFromContext(ctx, err)
@@ -186,11 +186,13 @@ var reCheckError = regexp.MustCompile(`(no address found for host|could not reso
186186
// added to the host skipping table.
187187
// it first parses error and extracts the cause and checks for blacklisted
188188
// or common errors that should be skipped
189-
func (c *Cache) checkError(err error) bool {
189+
func (c *Cache) checkError(protoType string, err error) bool {
190190
if err == nil {
191191
return false
192192
}
193-
193+
if protoType != "http" {
194+
return false
195+
}
194196
kind := errkit.GetErrorKind(err, nucleierr.ErrTemplateLogic)
195197
switch kind {
196198
case nucleierr.ErrTemplateLogic:

Diff for: pkg/protocols/common/hosterrorscache/hosterrorscache_test.go

+13-9
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,16 @@ import (
1111
"github.com/stretchr/testify/require"
1212
)
1313

14+
const (
15+
protoType = "http"
16+
)
17+
1418
func TestCacheCheck(t *testing.T) {
1519
cache := New(3, DefaultMaxHostsCount, nil)
1620

1721
for i := 0; i < 100; i++ {
18-
cache.MarkFailed(newCtxArgs("test"), fmt.Errorf("could not resolve host"))
19-
got := cache.Check(newCtxArgs("test"))
22+
cache.MarkFailed(protoType, newCtxArgs("test"), fmt.Errorf("could not resolve host"))
23+
got := cache.Check(protoType, newCtxArgs("test"))
2024
if i < 2 {
2125
// till 3 the host is not flagged to skip
2226
require.False(t, got)
@@ -26,16 +30,16 @@ func TestCacheCheck(t *testing.T) {
2630
}
2731
}
2832

29-
value := cache.Check(newCtxArgs("test"))
33+
value := cache.Check(protoType, newCtxArgs("test"))
3034
require.Equal(t, true, value, "could not get checked value")
3135
}
3236

3337
func TestTrackErrors(t *testing.T) {
3438
cache := New(3, DefaultMaxHostsCount, []string{"custom error"})
3539

3640
for i := 0; i < 100; i++ {
37-
cache.MarkFailed(newCtxArgs("custom"), fmt.Errorf("got: nested: custom error"))
38-
got := cache.Check(newCtxArgs("custom"))
41+
cache.MarkFailed(protoType, newCtxArgs("custom"), fmt.Errorf("got: nested: custom error"))
42+
got := cache.Check(protoType, newCtxArgs("custom"))
3943
if i < 2 {
4044
// till 3 the host is not flagged to skip
4145
require.False(t, got)
@@ -44,7 +48,7 @@ func TestTrackErrors(t *testing.T) {
4448
require.True(t, got)
4549
}
4650
}
47-
value := cache.Check(newCtxArgs("custom"))
51+
value := cache.Check(protoType, newCtxArgs("custom"))
4852
require.Equal(t, true, value, "could not get checked value")
4953
}
5054

@@ -86,7 +90,7 @@ func TestCacheMarkFailed(t *testing.T) {
8690

8791
for _, test := range tests {
8892
normalizedCacheValue := cache.GetKeyFromContext(newCtxArgs(test.host), nil)
89-
cache.MarkFailed(newCtxArgs(test.host), fmt.Errorf("no address found for host"))
93+
cache.MarkFailed(protoType, newCtxArgs(test.host), fmt.Errorf("no address found for host"))
9094
failedTarget, err := cache.failedTargets.Get(normalizedCacheValue)
9195
require.Nil(t, err)
9296
require.NotNil(t, failedTarget)
@@ -122,14 +126,14 @@ func TestCacheMarkFailedConcurrent(t *testing.T) {
122126
wg.Add(1)
123127
go func() {
124128
defer wg.Done()
125-
cache.MarkFailed(newCtxArgs(currentTest.host), fmt.Errorf("could not resolve host"))
129+
cache.MarkFailed(protoType, newCtxArgs(currentTest.host), fmt.Errorf("could not resolve host"))
126130
}()
127131
}
128132
}
129133
wg.Wait()
130134

131135
for _, test := range tests {
132-
require.True(t, cache.Check(newCtxArgs(test.host)))
136+
require.True(t, cache.Check(protoType, newCtxArgs(test.host)))
133137

134138
normalizedCacheValue := cache.NormalizeCacheValue(test.host)
135139
failedTarget, err := cache.failedTargets.Get(normalizedCacheValue)

Diff for: pkg/protocols/http/request.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -1177,14 +1177,14 @@ func (request *Request) markUnresponsiveAddress(input *contextargs.Context, err
11771177
return
11781178
}
11791179
if request.options.HostErrorsCache != nil {
1180-
request.options.HostErrorsCache.MarkFailed(input, err)
1180+
request.options.HostErrorsCache.MarkFailed(request.options.ProtocolType.String(), input, err)
11811181
}
11821182
}
11831183

11841184
// isUnresponsiveAddress checks if the error is a unreponsive based on its execution history
11851185
func (request *Request) isUnresponsiveAddress(input *contextargs.Context) bool {
11861186
if request.options.HostErrorsCache != nil {
1187-
return request.options.HostErrorsCache.Check(input)
1187+
return request.options.HostErrorsCache.Check(request.options.ProtocolType.String(), input)
11881188
}
11891189
return false
11901190
}

Diff for: pkg/protocols/http/request_fuzz.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ func (request *Request) executeAllFuzzingRules(input *contextargs.Context, value
161161
func (request *Request) executeGeneratedFuzzingRequest(gr fuzz.GeneratedRequest, input *contextargs.Context, callback protocols.OutputEventCallback) bool {
162162
hasInteractMatchers := interactsh.HasMatchers(request.CompiledOperators)
163163
hasInteractMarkers := len(gr.InteractURLs) > 0
164-
if request.options.HostErrorsCache != nil && request.options.HostErrorsCache.Check(input) {
164+
if request.options.HostErrorsCache != nil && request.options.HostErrorsCache.Check(request.options.ProtocolType.String(), input) {
165165
return false
166166
}
167167
request.options.RateLimitTake()
@@ -215,7 +215,7 @@ func (request *Request) executeGeneratedFuzzingRequest(gr fuzz.GeneratedRequest,
215215
}
216216
if requestErr != nil {
217217
if request.options.HostErrorsCache != nil {
218-
request.options.HostErrorsCache.MarkFailed(input, requestErr)
218+
request.options.HostErrorsCache.MarkFailed(request.options.ProtocolType.String(), input, requestErr)
219219
}
220220
gologger.Verbose().Msgf("[%s] Error occurred in request: %s\n", request.options.TemplateID, requestErr)
221221
}

Diff for: pkg/protocols/network/request.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -504,14 +504,14 @@ func (request *Request) markUnresponsiveAddress(input *contextargs.Context, err
504504
return
505505
}
506506
if request.options.HostErrorsCache != nil {
507-
request.options.HostErrorsCache.MarkFailed(input, err)
507+
request.options.HostErrorsCache.MarkFailed(request.options.ProtocolType.String(), input, err)
508508
}
509509
}
510510

511511
// isUnresponsiveAddress checks if the error is a unreponsive based on its execution history
512512
func (request *Request) isUnresponsiveAddress(input *contextargs.Context) bool {
513513
if request.options.HostErrorsCache != nil {
514-
return request.options.HostErrorsCache.Check(input)
514+
return request.options.HostErrorsCache.Check(request.options.ProtocolType.String(), input)
515515
}
516516
return false
517517
}

Diff for: pkg/templates/cluster.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ func (e *ClusterExecuter) Execute(ctx *scan.ScanContext) (bool, error) {
274274
}
275275
})
276276
if err != nil && e.options.HostErrorsCache != nil {
277-
e.options.HostErrorsCache.MarkFailed(ctx.Input, err)
277+
e.options.HostErrorsCache.MarkFailed(e.options.ProtocolType.String(), ctx.Input, err)
278278
}
279279
return results, err
280280
}
@@ -310,7 +310,7 @@ func (e *ClusterExecuter) ExecuteWithResults(ctx *scan.ScanContext) ([]*output.R
310310
}
311311

312312
if err != nil && e.options.HostErrorsCache != nil {
313-
e.options.HostErrorsCache.MarkFailed(ctx.Input, err)
313+
e.options.HostErrorsCache.MarkFailed(e.options.ProtocolType.String(), ctx.Input, err)
314314
}
315315
return scanCtx.GenerateResult(), err
316316
}

Diff for: pkg/tmplexec/exec.go

+2
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,9 @@ func (e *TemplateExecuter) Execute(ctx *scan.ScanContext) (bool, error) {
206206
ctx.LogError(errx)
207207

208208
if lastMatcherEvent != nil {
209+
lastMatcherEvent.Lock()
209210
lastMatcherEvent.InternalEvent["error"] = getErrorCause(ctx.GenerateErrorMessage())
211+
lastMatcherEvent.Unlock()
210212
writeFailureCallback(lastMatcherEvent, e.options.Options.MatcherStatus)
211213
}
212214

Diff for: pkg/tmplexec/generic/exec.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ func (g *Generic) ExecuteWithResults(ctx *scan.ScanContext) error {
8585
if err != nil {
8686
ctx.LogError(err)
8787
if g.options.HostErrorsCache != nil {
88-
g.options.HostErrorsCache.MarkFailed(ctx.Input, err)
88+
g.options.HostErrorsCache.MarkFailed(g.options.ProtocolType.String(), ctx.Input, err)
8989
}
9090
gologger.Warning().Msgf("[%s] Could not execute request for %s: %s\n", g.options.TemplateID, ctx.Input.MetaInput.PrettyPrint(), err)
9191
}

0 commit comments

Comments
 (0)