github.com/polevpn/netstack@v1.10.9/tcpip/stack/linkaddrcache_test.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package stack
    16  
    17  import (
    18  	"fmt"
    19  	"sync"
    20  	"sync/atomic"
    21  	"testing"
    22  	"time"
    23  
    24  	"github.com/polevpn/netstack/sleep"
    25  	"github.com/polevpn/netstack/tcpip"
    26  )
    27  
// testaddr pairs a network address with the link address the tests expect it
// to map to in the cache.
type testaddr struct {
	addr     tcpip.FullAddress
	linkAddr tcpip.LinkAddress
}
    32  
    33  var testAddrs = func() []testaddr {
    34  	var addrs []testaddr
    35  	for i := 0; i < 4*linkAddrCacheSize; i++ {
    36  		addr := fmt.Sprintf("Addr%06d", i)
    37  		addrs = append(addrs, testaddr{
    38  			addr:     tcpip.FullAddress{NIC: 1, Addr: tcpip.Address(addr)},
    39  			linkAddr: tcpip.LinkAddress("Link" + addr),
    40  		})
    41  	}
    42  	return addrs
    43  }()
    44  
// testLinkAddressResolver is a fake link address resolver for these tests.
// A "request" is answered by adding the matching testAddrs entry to cache
// after delay; addresses not in testAddrs never get a reply.
type testLinkAddressResolver struct {
	// cache receives the fake replies produced by fakeRequest.
	cache                *linkAddrCache
	// delay is how long LinkAddressRequest waits before replying.
	delay                time.Duration
	// onLinkAddressRequest, if non-nil, is called on every
	// LinkAddressRequest (tests use it to count requests).
	onLinkAddressRequest func()
}
    50  
    51  func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error {
    52  	time.AfterFunc(r.delay, func() { r.fakeRequest(addr) })
    53  	if f := r.onLinkAddressRequest; f != nil {
    54  		f()
    55  	}
    56  	return nil
    57  }
    58  
    59  func (r *testLinkAddressResolver) fakeRequest(addr tcpip.Address) {
    60  	for _, ta := range testAddrs {
    61  		if ta.addr.Addr == addr {
    62  			r.cache.add(ta.addr, ta.linkAddr)
    63  			break
    64  		}
    65  	}
    66  }
    67  
    68  func (*testLinkAddressResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
    69  	if addr == "broadcast" {
    70  		return "mac_broadcast", true
    71  	}
    72  	return "", false
    73  }
    74  
    75  func (*testLinkAddressResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
    76  	return 1
    77  }
    78  
    79  func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressResolver) (tcpip.LinkAddress, *tcpip.Error) {
    80  	w := sleep.Waker{}
    81  	s := sleep.Sleeper{}
    82  	s.AddWaker(&w, 123)
    83  	defer s.Done()
    84  
    85  	for {
    86  		if got, _, err := c.get(addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock {
    87  			return got, err
    88  		}
    89  		s.Fetch(true)
    90  	}
    91  }
    92  
    93  func TestCacheOverflow(t *testing.T) {
    94  	c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
    95  	for i := len(testAddrs) - 1; i >= 0; i-- {
    96  		e := testAddrs[i]
    97  		c.add(e.addr, e.linkAddr)
    98  		got, _, err := c.get(e.addr, nil, "", nil, nil)
    99  		if err != nil {
   100  			t.Errorf("insert %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err)
   101  		}
   102  		if got != e.linkAddr {
   103  			t.Errorf("insert %d, c.get(%q)=%q, want %q", i, string(e.addr.Addr), got, e.linkAddr)
   104  		}
   105  	}
   106  	// Expect to find at least half of the most recent entries.
   107  	for i := 0; i < linkAddrCacheSize/2; i++ {
   108  		e := testAddrs[i]
   109  		got, _, err := c.get(e.addr, nil, "", nil, nil)
   110  		if err != nil {
   111  			t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err)
   112  		}
   113  		if got != e.linkAddr {
   114  			t.Errorf("check %d, c.get(%q)=%q, want %q", i, string(e.addr.Addr), got, e.linkAddr)
   115  		}
   116  	}
   117  	// The earliest entries should no longer be in the cache.
   118  	for i := len(testAddrs) - 1; i >= len(testAddrs)-linkAddrCacheSize; i-- {
   119  		e := testAddrs[i]
   120  		if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
   121  			t.Errorf("check %d, c.get(%q), got error: %v, want: error ErrNoLinkAddress", i, string(e.addr.Addr), err)
   122  		}
   123  	}
   124  }
   125  
   126  func TestCacheConcurrent(t *testing.T) {
   127  	c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
   128  
   129  	var wg sync.WaitGroup
   130  	for r := 0; r < 16; r++ {
   131  		wg.Add(1)
   132  		go func() {
   133  			for _, e := range testAddrs {
   134  				c.add(e.addr, e.linkAddr)
   135  				c.get(e.addr, nil, "", nil, nil) // make work for gotsan
   136  			}
   137  			wg.Done()
   138  		}()
   139  	}
   140  	wg.Wait()
   141  
   142  	// All goroutines add in the same order and add more values than
   143  	// can fit in the cache, so our eviction strategy requires that
   144  	// the last entry be present and the first be missing.
   145  	e := testAddrs[len(testAddrs)-1]
   146  	got, _, err := c.get(e.addr, nil, "", nil, nil)
   147  	if err != nil {
   148  		t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
   149  	}
   150  	if got != e.linkAddr {
   151  		t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
   152  	}
   153  
   154  	e = testAddrs[0]
   155  	if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
   156  		t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
   157  	}
   158  }
   159  
   160  func TestCacheAgeLimit(t *testing.T) {
   161  	c := newLinkAddrCache(1*time.Millisecond, 1*time.Second, 3)
   162  	e := testAddrs[0]
   163  	c.add(e.addr, e.linkAddr)
   164  	time.Sleep(50 * time.Millisecond)
   165  	if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
   166  		t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
   167  	}
   168  }
   169  
   170  func TestCacheReplace(t *testing.T) {
   171  	c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
   172  	e := testAddrs[0]
   173  	l2 := e.linkAddr + "2"
   174  	c.add(e.addr, e.linkAddr)
   175  	got, _, err := c.get(e.addr, nil, "", nil, nil)
   176  	if err != nil {
   177  		t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
   178  	}
   179  	if got != e.linkAddr {
   180  		t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
   181  	}
   182  
   183  	c.add(e.addr, l2)
   184  	got, _, err = c.get(e.addr, nil, "", nil, nil)
   185  	if err != nil {
   186  		t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
   187  	}
   188  	if got != l2 {
   189  		t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, l2)
   190  	}
   191  }
   192  
   193  func TestCacheResolution(t *testing.T) {
   194  	c := newLinkAddrCache(1<<63-1, 250*time.Millisecond, 1)
   195  	linkRes := &testLinkAddressResolver{cache: c}
   196  	for i, ta := range testAddrs {
   197  		got, err := getBlocking(c, ta.addr, linkRes)
   198  		if err != nil {
   199  			t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(ta.addr.Addr), got, err)
   200  		}
   201  		if got != ta.linkAddr {
   202  			t.Errorf("check %d, c.get(%q)=%q, want %q", i, string(ta.addr.Addr), got, ta.linkAddr)
   203  		}
   204  	}
   205  
   206  	// Check that after resolved, address stays in the cache and never returns WouldBlock.
   207  	for i := 0; i < 10; i++ {
   208  		e := testAddrs[len(testAddrs)-1]
   209  		got, _, err := c.get(e.addr, linkRes, "", nil, nil)
   210  		if err != nil {
   211  			t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
   212  		}
   213  		if got != e.linkAddr {
   214  			t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
   215  		}
   216  	}
   217  }
   218  
   219  func TestCacheResolutionFailed(t *testing.T) {
   220  	c := newLinkAddrCache(1<<63-1, 10*time.Millisecond, 5)
   221  	linkRes := &testLinkAddressResolver{cache: c}
   222  
   223  	var requestCount uint32
   224  	linkRes.onLinkAddressRequest = func() {
   225  		atomic.AddUint32(&requestCount, 1)
   226  	}
   227  
   228  	// First, sanity check that resolution is working...
   229  	e := testAddrs[0]
   230  	got, err := getBlocking(c, e.addr, linkRes)
   231  	if err != nil {
   232  		t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
   233  	}
   234  	if got != e.linkAddr {
   235  		t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
   236  	}
   237  
   238  	before := atomic.LoadUint32(&requestCount)
   239  
   240  	e.addr.Addr += "2"
   241  	if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress {
   242  		t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
   243  	}
   244  
   245  	if got, want := int(atomic.LoadUint32(&requestCount)-before), c.resolutionAttempts; got != want {
   246  		t.Errorf("got link address request count = %d, want = %d", got, want)
   247  	}
   248  }
   249  
   250  func TestCacheResolutionTimeout(t *testing.T) {
   251  	resolverDelay := 500 * time.Millisecond
   252  	expiration := resolverDelay / 10
   253  	c := newLinkAddrCache(expiration, 1*time.Millisecond, 3)
   254  	linkRes := &testLinkAddressResolver{cache: c, delay: resolverDelay}
   255  
   256  	e := testAddrs[0]
   257  	if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress {
   258  		t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
   259  	}
   260  }
   261  
   262  // TestStaticResolution checks that static link addresses are resolved immediately and don't
   263  // send resolution requests.
   264  func TestStaticResolution(t *testing.T) {
   265  	c := newLinkAddrCache(1<<63-1, time.Millisecond, 1)
   266  	linkRes := &testLinkAddressResolver{cache: c, delay: time.Minute}
   267  
   268  	addr := tcpip.Address("broadcast")
   269  	want := tcpip.LinkAddress("mac_broadcast")
   270  	got, _, err := c.get(tcpip.FullAddress{Addr: addr}, linkRes, "", nil, nil)
   271  	if err != nil {
   272  		t.Errorf("c.get(%q)=%q, got error: %v", string(addr), string(got), err)
   273  	}
   274  	if got != want {
   275  		t.Errorf("c.get(%q)=%q, want %q", string(addr), string(got), string(want))
   276  	}
   277  }