github.com/projectdiscovery/nuclei/v2@v2.9.15/pkg/protocols/common/hosterrorscache/hosterrorscache_test.go (about) 1 package hosterrorscache 2 3 import ( 4 "fmt" 5 "sync" 6 "sync/atomic" 7 "testing" 8 9 "github.com/stretchr/testify/require" 10 ) 11 12 func TestCacheCheck(t *testing.T) { 13 cache := New(3, DefaultMaxHostsCount, nil) 14 15 for i := 0; i < 100; i++ { 16 cache.MarkFailed("test", fmt.Errorf("could not resolve host")) 17 got := cache.Check("test") 18 if i < 2 { 19 // till 3 the host is not flagged to skip 20 require.False(t, got) 21 } else { 22 // above 3 it must remain flagged to skip 23 require.True(t, got) 24 } 25 } 26 27 value := cache.Check("test") 28 require.Equal(t, true, value, "could not get checked value") 29 } 30 31 func TestTrackErrors(t *testing.T) { 32 cache := New(3, DefaultMaxHostsCount, []string{"custom error"}) 33 34 for i := 0; i < 100; i++ { 35 cache.MarkFailed("custom", fmt.Errorf("got: nested: custom error")) 36 got := cache.Check("custom") 37 if i < 2 { 38 // till 3 the host is not flagged to skip 39 require.False(t, got) 40 } else { 41 // above 3 it must remain flagged to skip 42 require.True(t, got) 43 } 44 } 45 value := cache.Check("custom") 46 require.Equal(t, true, value, "could not get checked value") 47 } 48 49 func TestCacheItemDo(t *testing.T) { 50 var ( 51 count int 52 item cacheItem 53 ) 54 55 wg := sync.WaitGroup{} 56 for i := 0; i < 100; i++ { 57 wg.Add(1) 58 go func() { 59 defer wg.Done() 60 item.Do(func() { 61 count++ 62 }) 63 }() 64 } 65 wg.Wait() 66 67 // ensures the increment happened only once regardless of the multiple call 68 require.Equal(t, count, 1) 69 } 70 71 func TestCacheMarkFailed(t *testing.T) { 72 cache := New(3, DefaultMaxHostsCount, nil) 73 74 tests := []struct { 75 host string 76 expected int 77 }{ 78 {"http://example.com:80", 1}, 79 {"example.com:80", 2}, 80 {"example.com", 1}, 81 } 82 83 for _, test := range tests { 84 normalizedCacheValue := cache.normalizeCacheValue(test.host) 85 cache.MarkFailed(test.host, fmt.Errorf("no address found for host")) 86 failedTarget, err := cache.failedTargets.Get(normalizedCacheValue) 87 require.Nil(t, err) 88 require.NotNil(t, failedTarget) 89 90 value, ok := failedTarget.(*cacheItem) 91 require.True(t, ok) 92 require.EqualValues(t, test.expected, value.errors.Load()) 93 } 94 } 95 96 func TestCacheMarkFailedConcurrent(t *testing.T) { 97 cache := New(3, DefaultMaxHostsCount, nil) 98 99 tests := []struct { 100 host string 101 expected int32 102 }{ 103 {"http://example.com:80", 200}, 104 {"example.com:80", 200}, 105 {"example.com", 100}, 106 } 107 108 // the cache is not atomic during items creation, so we pre-create them with counter to zero 109 for _, test := range tests { 110 normalizedValue := cache.normalizeCacheValue(test.host) 111 newItem := &cacheItem{errors: atomic.Int32{}} 112 newItem.errors.Store(0) 113 _ = cache.failedTargets.Set(normalizedValue, newItem) 114 } 115 116 wg := sync.WaitGroup{} 117 for _, test := range tests { 118 currentTest := test 119 for i := 0; i < 100; i++ { 120 wg.Add(1) 121 go func() { 122 defer wg.Done() 123 cache.MarkFailed(currentTest.host, fmt.Errorf("could not resolve host")) 124 }() 125 } 126 } 127 wg.Wait() 128 129 for _, test := range tests { 130 require.True(t, cache.Check(test.host)) 131 132 normalizedCacheValue := cache.normalizeCacheValue(test.host) 133 failedTarget, err := cache.failedTargets.Get(normalizedCacheValue) 134 require.Nil(t, err) 135 require.NotNil(t, failedTarget) 136 137 value, ok := failedTarget.(*cacheItem) 138 require.True(t, ok) 139 require.EqualValues(t, test.expected, value.errors.Load()) 140 } 141 }