github.com/koko1123/flow-go-1@v0.29.6/network/p2p/dns/resolver_test.go (about)

     1  package dns
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"net"
     7  	"sync"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/stretchr/testify/mock"
    12  	"github.com/stretchr/testify/require"
    13  
    14  	"github.com/koko1123/flow-go-1/module/irrecoverable"
    15  	"github.com/koko1123/flow-go-1/module/mempool/herocache"
    16  	"github.com/koko1123/flow-go-1/module/metrics"
    17  	"github.com/koko1123/flow-go-1/network/mocknetwork"
    18  	"github.com/koko1123/flow-go-1/utils/unittest"
    19  	testnetwork "github.com/koko1123/flow-go-1/utils/unittest/network"
    20  )
    21  
// happyPath is the flag value the tests hand to the helpers in this file to select the success scenario.
// Passing !happyPath instead makes the mocked underlying resolver return an error and makes the query helpers
// expect an error from every lookup.
const happyPath = true
    23  
    24  // TestResolver_HappyPath evaluates once the request for a domain gets cached, the subsequent requests are going through the cache
    25  // instead of going through the underlying basic resolver, and hence through the network.
    26  func TestResolver_HappyPath(t *testing.T) {
    27  	basicResolver := mocknetwork.BasicResolver{}
    28  	dnsCache := herocache.NewDNSCache(
    29  		DefaultCacheSize,
    30  		unittest.Logger(),
    31  		metrics.NewNoopCollector(),
    32  		metrics.NewNoopCollector(),
    33  	)
    34  
    35  	resolver := NewResolver(
    36  		unittest.Logger(),
    37  		metrics.NewNoopCollector(),
    38  		dnsCache,
    39  		WithBasicResolver(&basicResolver))
    40  
    41  	cancelCtx, cancel := context.WithCancel(context.Background())
    42  	defer cancel()
    43  	ctx, _ := irrecoverable.WithSignaler(cancelCtx)
    44  	resolver.Start(ctx)
    45  	unittest.RequireCloseBefore(t, resolver.Ready(), 100*time.Millisecond, "could not start dns resolver on time")
    46  
    47  	size := 10 // 10 text and 10 ip domains.
    48  	times := 5 // each domain is queried 5 times.
    49  	txtTestCases := testnetwork.TxtLookupFixture(size)
    50  	ipTestCases := testnetwork.IpLookupFixture(size)
    51  
    52  	// each domain is resolved only once through the underlying resolver, and then is cached for subsequent times.
    53  	resolverWG := mockBasicResolverForDomains(t, &basicResolver, ipTestCases, txtTestCases, happyPath, 1)
    54  	queryWG := syncThenAsyncQuery(t, times, resolver, txtTestCases, ipTestCases, happyPath)
    55  
    56  	unittest.RequireReturnsBefore(t, resolverWG.Wait, 1*time.Second, "could not resolve all expected domains")
    57  	unittest.RequireReturnsBefore(t, queryWG.Wait, 1*time.Second, "could not perform all queries on time")
    58  	cancel()
    59  	unittest.RequireCloseBefore(t, resolver.Done(), 100*time.Millisecond, "could not stop dns resolver on time")
    60  }
    61  
    62  // TestResolver_CacheExpiry evaluates that cached dns entries get expired and underlying resolver gets called after their time-to-live is passed.
    63  func TestResolver_CacheExpiry(t *testing.T) {
    64  	basicResolver := mocknetwork.BasicResolver{}
    65  
    66  	dnsCache := herocache.NewDNSCache(
    67  		DefaultCacheSize,
    68  		unittest.Logger(),
    69  		metrics.NewNoopCollector(),
    70  		metrics.NewNoopCollector(),
    71  	)
    72  
    73  	resolver := NewResolver(
    74  		unittest.Logger(),
    75  		metrics.NewNoopCollector(),
    76  		dnsCache,
    77  		WithBasicResolver(&basicResolver),
    78  		WithTTL(1*time.Second)) // cache timeout set to 3 seconds for this test
    79  
    80  	cancelCtx, cancel := context.WithCancel(context.Background())
    81  	defer cancel()
    82  	ctx, _ := irrecoverable.WithSignaler(cancelCtx)
    83  	resolver.Start(ctx)
    84  	unittest.RequireCloseBefore(t, resolver.Ready(), 100*time.Millisecond, "could not start dns resolver on time")
    85  
    86  	size := 5  // we have 5 txt and 5 ip lookup test cases
    87  	times := 3 // each domain is queried for resolution 3 times
    88  	txtTestCases := testnetwork.TxtLookupFixture(size)
    89  	ipTestCase := testnetwork.IpLookupFixture(size)
    90  
    91  	// each domain gets resolved through underlying resolver twice: once initially, and once after expiry.
    92  	resolverWG := mockBasicResolverForDomains(t, &basicResolver, ipTestCase, txtTestCases, happyPath, 2)
    93  
    94  	// queries (5 + 5) cases * 3 = 30 queries.
    95  	queryWG := syncThenAsyncQuery(t, times, resolver, txtTestCases, ipTestCase, happyPath)
    96  	unittest.RequireReturnsBefore(t, queryWG.Wait, 1*time.Second, "could not perform all queries on time")
    97  
    98  	time.Sleep(2 * time.Second) // waits enough for cache to get invalidated
    99  
   100  	queryWG = syncThenAsyncQuery(t, times, resolver, txtTestCases, ipTestCase, happyPath)
   101  
   102  	unittest.RequireReturnsBefore(t, resolverWG.Wait, 3*time.Second, "could not resolve all expected domains")
   103  	unittest.RequireReturnsBefore(t, queryWG.Wait, 1*time.Second, "could not perform all queries on time")
   104  
   105  	cancel()
   106  	unittest.RequireCloseBefore(t, resolver.Done(), 2*time.Second, "could not stop dns resolver on time")
   107  }
   108  
   109  // TestResolver_Error evaluates that when the underlying resolver returns an error, the resolver itself does not cache the result.
   110  func TestResolver_Error(t *testing.T) {
   111  	basicResolver := mocknetwork.BasicResolver{}
   112  
   113  	dnsCache := herocache.NewDNSCache(
   114  		DefaultCacheSize,
   115  		unittest.Logger(),
   116  		metrics.NewNoopCollector(),
   117  		metrics.NewNoopCollector(),
   118  	)
   119  
   120  	resolver := NewResolver(
   121  		unittest.Logger(),
   122  		metrics.NewNoopCollector(),
   123  		dnsCache,
   124  		WithBasicResolver(&basicResolver))
   125  
   126  	cancelCtx, cancel := context.WithCancel(context.Background())
   127  	defer cancel()
   128  	ctx, _ := irrecoverable.WithSignaler(cancelCtx)
   129  	resolver.Start(ctx)
   130  	unittest.RequireCloseBefore(t, resolver.Ready(), 100*time.Millisecond, "could not start dns resolver on time")
   131  
   132  	// one test case for txt and one for ip
   133  	times := 5 // each test case tried 5 times
   134  	txtTestCases := testnetwork.TxtLookupFixture(1)
   135  	ipTestCase := testnetwork.IpLookupFixture(1)
   136  
   137  	// mocks underlying basic resolver invoked 5 times per domain and returns an error each time.
   138  	// this evaluates that upon returning an error, the result is not cached, so the next invocation again goes
   139  	// through the resolver.
   140  	resolverWG := mockBasicResolverForDomains(t, &basicResolver, ipTestCase, txtTestCases, !happyPath, times)
   141  	queryWG := syncThenAsyncQuery(t, times, resolver, txtTestCases, ipTestCase, !happyPath)
   142  
   143  	unittest.RequireReturnsBefore(t, resolverWG.Wait, 1*time.Second, "could not resolve all expected domains")
   144  	unittest.RequireReturnsBefore(t, queryWG.Wait, 1*time.Second, "could not perform all queries on time")
   145  	cancel()
   146  	unittest.RequireCloseBefore(t, resolver.Done(), 100*time.Millisecond, "could not stop dns resolver on time")
   147  
   148  	// since resolving hits an error, cache is invalidated.
   149  	ipSize, txtSize := resolver.c.dCache.Size()
   150  	require.Zero(t, ipSize)
   151  	require.Zero(t, txtSize)
   152  }
   153  
   154  // TestResolver_Expired_Invalidated evaluates that when resolver is queried for an expired entry, it returns the expired one, but queries async on the
   155  // network to refresh the cache. However, when the query hits an error, it invalidates the cache.
   156  func TestResolver_Expired_Invalidated(t *testing.T) {
   157  	basicResolver := mocknetwork.BasicResolver{}
   158  	dnsCache := herocache.NewDNSCache(
   159  		DefaultCacheSize,
   160  		unittest.Logger(),
   161  		metrics.NewNoopCollector(),
   162  		metrics.NewNoopCollector(),
   163  	)
   164  
   165  	resolver := NewResolver(
   166  		unittest.Logger(),
   167  		metrics.NewNoopCollector(),
   168  		dnsCache,
   169  		WithBasicResolver(&basicResolver),
   170  		WithTTL(1*time.Second)) // 1 second TTL for test
   171  
   172  	cancelCtx, cancel := context.WithCancel(context.Background())
   173  	defer cancel()
   174  	ctx, _ := irrecoverable.WithSignaler(cancelCtx)
   175  	resolver.Start(ctx)
   176  	unittest.RequireCloseBefore(t, resolver.Ready(), 100*time.Millisecond, "could not start dns resolver on time")
   177  
   178  	// one test case for txt and one for ip
   179  	txtTestCases := testnetwork.TxtLookupFixture(1)
   180  	ipTestCase := testnetwork.IpLookupFixture(1)
   181  
   182  	// mocks test cases cached in resolver and waits for their expiry
   183  	mockCacheForDomains(resolver, ipTestCase, txtTestCases)
   184  	time.Sleep(1 * time.Second)
   185  
   186  	// queries for an expired entry must return the expired entry but also fire an async update on it.
   187  	// though we mock async update to fail, so the cache should be invalidated literally.
   188  	// mocks underlying basic resolver invoked once per domain and returns an error on each domain
   189  	resolverWG := mockBasicResolverForDomains(t, &basicResolver, ipTestCase, txtTestCases, !happyPath, 1)
   190  	// queries are answered by cache, so resolver returning an error only invalidates the cache asynchronously for the first time.
   191  	queryWG := syncThenAsyncQuery(t, 1, resolver, txtTestCases, ipTestCase, happyPath)
   192  
   193  	unittest.RequireReturnsBefore(t, queryWG.Wait, 1*time.Second, "could not perform all queries on time")
   194  	unittest.RequireReturnsBefore(t, resolverWG.Wait, 1*time.Second, "could not resolve all expected domains")
   195  	cancel()
   196  	unittest.RequireCloseBefore(t, resolver.Done(), 100*time.Millisecond, "could not stop dns resolver on time")
   197  
   198  	// since resolving hits an error, cache is invalidated.
   199  	ipSize, txtSize := resolver.c.dCache.Size()
   200  	require.Zero(t, ipSize)
   201  	require.Zero(t, txtSize)
   202  }
   203  
   204  // syncThenAsyncQuery concurrently requests each test case for the specified number of times. The returned wait group will be released when
   205  // all queries have been resolved.
   206  func syncThenAsyncQuery(t *testing.T,
   207  	times int,
   208  	resolver *Resolver,
   209  	txtTestCases map[string]*testnetwork.TxtLookupTestCase,
   210  	ipTestCases map[string]*testnetwork.IpLookupTestCase,
   211  	happyPath bool) *sync.WaitGroup {
   212  
   213  	ctx := context.Background()
   214  	wg := &sync.WaitGroup{}
   215  	wg.Add(times * (len(txtTestCases) + len(ipTestCases)))
   216  
   217  	for _, txttc := range txtTestCases {
   218  		cacheAndQuery(t, func(domain string) (interface{}, error) {
   219  			return resolver.LookupTXT(ctx, domain)
   220  		}, txttc.Txt, txttc.Records, times, wg, happyPath)
   221  	}
   222  
   223  	for _, iptc := range ipTestCases {
   224  		cacheAndQuery(t, func(domain string) (interface{}, error) {
   225  			return resolver.LookupIPAddr(ctx, domain)
   226  		}, iptc.Domain, iptc.Result, times, wg, happyPath)
   227  	}
   228  
   229  	return wg
   230  }
   231  
   232  // cacheAndQuery makes a dns query for each of domains first so that the result gets cache, and then it performs
   233  // concurrent queries for each test case for the specified number of times. The wait group is released when all
   234  // queries resolved.
   235  func cacheAndQuery(t *testing.T,
   236  	resolver func(domain string) (interface{}, error),
   237  	domain string,
   238  	result interface{},
   239  	times int,
   240  	wg *sync.WaitGroup,
   241  	happyPath bool) {
   242  
   243  	firstCallDone := make(chan interface{})
   244  
   245  	for i := 0; i < times; i++ {
   246  		go func(index int) {
   247  			if index != 0 {
   248  				// other invocations (except first one) of each test
   249  				// wait for the first time to get through and
   250  				// cached and then go concurrently.
   251  				<-firstCallDone
   252  			}
   253  
   254  			addrs, err := resolver(domain)
   255  
   256  			if happyPath {
   257  				require.NoError(t, err)
   258  				require.ElementsMatch(t, addrs, result)
   259  			} else {
   260  				require.Error(t, err, domain)
   261  			}
   262  
   263  			if index == 0 {
   264  				close(firstCallDone) // now lets other invocations go
   265  			}
   266  
   267  			wg.Done()
   268  
   269  		}(i)
   270  	}
   271  }
   272  
   273  // mockBasicResolverForDomains mocks the resolver for the ip and txt lookup test cases, it makes sure that no domain is requested more than
   274  // the number of times specified.
   275  // Returned wait group is released when resolver is queried for `times * (len(ipLookupTestCases) + len(txtLookupTestCases))` times.
   276  func mockBasicResolverForDomains(t *testing.T,
   277  	resolver *mocknetwork.BasicResolver,
   278  	ipLookupTestCases map[string]*testnetwork.IpLookupTestCase,
   279  	txtLookupTestCases map[string]*testnetwork.TxtLookupTestCase,
   280  	happyPath bool,
   281  	times int) *sync.WaitGroup {
   282  
   283  	// keeping track of requested domains
   284  	ipRequested := make(map[string]int)
   285  	txtRequested := make(map[string]int)
   286  
   287  	wg := &sync.WaitGroup{}
   288  	wg.Add(times * (len(ipLookupTestCases) + len(txtLookupTestCases)))
   289  
   290  	mu := sync.Mutex{}
   291  	resolver.On("LookupIPAddr", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
   292  		mu.Lock()
   293  		defer mu.Unlock()
   294  
   295  		// method should be called on expected parameters
   296  		_, ok := args[0].(context.Context)
   297  		require.True(t, ok)
   298  
   299  		domain, ok := args[1].(string)
   300  		require.True(t, ok)
   301  
   302  		// requested domain should be expected.
   303  		_, ok = ipLookupTestCases[domain]
   304  		require.True(t, ok)
   305  
   306  		// requested domain should be only requested once through underlying resolver
   307  		count, ok := ipRequested[domain]
   308  		if !ok {
   309  			count = 0
   310  		}
   311  		count++
   312  
   313  		require.LessOrEqual(t, count, times, domain)
   314  		ipRequested[domain] = count
   315  
   316  		wg.Done()
   317  	}).Return(
   318  		func(ctx context.Context, domain string) []net.IPAddr {
   319  			if !happyPath {
   320  				return nil
   321  			}
   322  			return ipLookupTestCases[domain].Result
   323  		},
   324  		func(ctx context.Context, domain string) error {
   325  			if !happyPath {
   326  				return fmt.Errorf("error")
   327  			}
   328  			return nil
   329  		})
   330  
   331  	resolver.On("LookupTXT", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
   332  		mu.Lock()
   333  		defer mu.Unlock()
   334  
   335  		// method should be called on expected parameters
   336  		_, ok := args[0].(context.Context)
   337  		require.True(t, ok)
   338  
   339  		domain, ok := args[1].(string)
   340  		require.True(t, ok)
   341  
   342  		// requested domain should be expected.
   343  		_, ok = txtLookupTestCases[domain]
   344  		require.True(t, ok)
   345  
   346  		// requested domain should be only requested once through underlying resolver
   347  		count, ok := txtRequested[domain]
   348  		if !ok {
   349  			count = 0
   350  		}
   351  		count++
   352  		require.LessOrEqual(t, count, times, domain)
   353  		txtRequested[domain] = count
   354  
   355  		wg.Done()
   356  
   357  	}).Return(
   358  		func(ctx context.Context, domain string) []string {
   359  			if !happyPath {
   360  				return nil
   361  			}
   362  			return txtLookupTestCases[domain].Records
   363  		},
   364  		func(ctx context.Context, domain string) error {
   365  			if !happyPath {
   366  				return fmt.Errorf("error")
   367  			}
   368  			return nil
   369  		})
   370  
   371  	return wg
   372  }
   373  
   374  // mockCacheForDomains updates cache of resolver with the test cases.
   375  func mockCacheForDomains(resolver *Resolver,
   376  	ipLookupTestCases map[string]*testnetwork.IpLookupTestCase,
   377  	txtLookupTestCases map[string]*testnetwork.TxtLookupTestCase) {
   378  
   379  	for _, iptc := range ipLookupTestCases {
   380  		resolver.c.updateIPCache(iptc.Domain, iptc.Result)
   381  	}
   382  
   383  	for _, txttc := range txtLookupTestCases {
   384  		resolver.c.updateTXTCache(txttc.Txt, txttc.Records)
   385  	}
   386  }