github.com/projectdiscovery/nuclei/v2@v2.9.15/pkg/protocols/common/hosterrorscache/hosterrorscache_test.go (about)

     1  package hosterrorscache
     2  
     3  import (
     4  	"fmt"
     5  	"sync"
     6  	"sync/atomic"
     7  	"testing"
     8  
     9  	"github.com/stretchr/testify/require"
    10  )
    11  
    12  func TestCacheCheck(t *testing.T) {
    13  	cache := New(3, DefaultMaxHostsCount, nil)
    14  
    15  	for i := 0; i < 100; i++ {
    16  		cache.MarkFailed("test", fmt.Errorf("could not resolve host"))
    17  		got := cache.Check("test")
    18  		if i < 2 {
    19  			// till 3 the host is not flagged to skip
    20  			require.False(t, got)
    21  		} else {
    22  			// above 3 it must remain flagged to skip
    23  			require.True(t, got)
    24  		}
    25  	}
    26  
    27  	value := cache.Check("test")
    28  	require.Equal(t, true, value, "could not get checked value")
    29  }
    30  
    31  func TestTrackErrors(t *testing.T) {
    32  	cache := New(3, DefaultMaxHostsCount, []string{"custom error"})
    33  
    34  	for i := 0; i < 100; i++ {
    35  		cache.MarkFailed("custom", fmt.Errorf("got: nested: custom error"))
    36  		got := cache.Check("custom")
    37  		if i < 2 {
    38  			// till 3 the host is not flagged to skip
    39  			require.False(t, got)
    40  		} else {
    41  			// above 3 it must remain flagged to skip
    42  			require.True(t, got)
    43  		}
    44  	}
    45  	value := cache.Check("custom")
    46  	require.Equal(t, true, value, "could not get checked value")
    47  }
    48  
    49  func TestCacheItemDo(t *testing.T) {
    50  	var (
    51  		count int
    52  		item  cacheItem
    53  	)
    54  
    55  	wg := sync.WaitGroup{}
    56  	for i := 0; i < 100; i++ {
    57  		wg.Add(1)
    58  		go func() {
    59  			defer wg.Done()
    60  			item.Do(func() {
    61  				count++
    62  			})
    63  		}()
    64  	}
    65  	wg.Wait()
    66  
    67  	// ensures the increment happened only once regardless of the multiple call
    68  	require.Equal(t, count, 1)
    69  }
    70  
    71  func TestCacheMarkFailed(t *testing.T) {
    72  	cache := New(3, DefaultMaxHostsCount, nil)
    73  
    74  	tests := []struct {
    75  		host     string
    76  		expected int
    77  	}{
    78  		{"http://example.com:80", 1},
    79  		{"example.com:80", 2},
    80  		{"example.com", 1},
    81  	}
    82  
    83  	for _, test := range tests {
    84  		normalizedCacheValue := cache.normalizeCacheValue(test.host)
    85  		cache.MarkFailed(test.host, fmt.Errorf("no address found for host"))
    86  		failedTarget, err := cache.failedTargets.Get(normalizedCacheValue)
    87  		require.Nil(t, err)
    88  		require.NotNil(t, failedTarget)
    89  
    90  		value, ok := failedTarget.(*cacheItem)
    91  		require.True(t, ok)
    92  		require.EqualValues(t, test.expected, value.errors.Load())
    93  	}
    94  }
    95  
    96  func TestCacheMarkFailedConcurrent(t *testing.T) {
    97  	cache := New(3, DefaultMaxHostsCount, nil)
    98  
    99  	tests := []struct {
   100  		host     string
   101  		expected int32
   102  	}{
   103  		{"http://example.com:80", 200},
   104  		{"example.com:80", 200},
   105  		{"example.com", 100},
   106  	}
   107  
   108  	// the cache is not atomic during items creation, so we pre-create them with counter to zero
   109  	for _, test := range tests {
   110  		normalizedValue := cache.normalizeCacheValue(test.host)
   111  		newItem := &cacheItem{errors: atomic.Int32{}}
   112  		newItem.errors.Store(0)
   113  		_ = cache.failedTargets.Set(normalizedValue, newItem)
   114  	}
   115  
   116  	wg := sync.WaitGroup{}
   117  	for _, test := range tests {
   118  		currentTest := test
   119  		for i := 0; i < 100; i++ {
   120  			wg.Add(1)
   121  			go func() {
   122  				defer wg.Done()
   123  				cache.MarkFailed(currentTest.host, fmt.Errorf("could not resolve host"))
   124  			}()
   125  		}
   126  	}
   127  	wg.Wait()
   128  
   129  	for _, test := range tests {
   130  		require.True(t, cache.Check(test.host))
   131  
   132  		normalizedCacheValue := cache.normalizeCacheValue(test.host)
   133  		failedTarget, err := cache.failedTargets.Get(normalizedCacheValue)
   134  		require.Nil(t, err)
   135  		require.NotNil(t, failedTarget)
   136  
   137  		value, ok := failedTarget.(*cacheItem)
   138  		require.True(t, ok)
   139  		require.EqualValues(t, test.expected, value.errors.Load())
   140  	}
   141  }