github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/sync/atomicptrmap/atomicptrmap_test.go (about)

     1  // Copyright 2020 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package atomicptrmap
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"math/rand"
    21  	"reflect"
    22  	"runtime"
    23  	"testing"
    24  	"time"
    25  
    26  	"github.com/SagerNet/gvisor/pkg/sync"
    27  )
    28  
    29  func TestConsistencyWithGoMap(t *testing.T) {
    30  	const maxKey = 16
    31  	var vals [4]*testValue
    32  	for i := 1; /* leave vals[0] nil */ i < len(vals); i++ {
    33  		vals[i] = new(testValue)
    34  	}
    35  	var (
    36  		m   = make(map[int64]*testValue)
    37  		apm testAtomicPtrMap
    38  	)
    39  	for i := 0; i < 100000; i++ {
    40  		// Apply a random operation to both m and apm and expect them to have
    41  		// the same result. Bias toward CompareAndSwap, which has the most
    42  		// cases; bias away from Range and RangeRepeatable, which are
    43  		// relatively expensive.
    44  		switch rand.Intn(10) {
    45  		case 0, 1: // Load
    46  			key := rand.Int63n(maxKey)
    47  			want := m[key]
    48  			got := apm.Load(key)
    49  			t.Logf("Load(%d) = %p", key, got)
    50  			if got != want {
    51  				t.Fatalf("got %p, wanted %p", got, want)
    52  			}
    53  		case 2, 3: // Swap
    54  			key := rand.Int63n(maxKey)
    55  			val := vals[rand.Intn(len(vals))]
    56  			want := m[key]
    57  			if val != nil {
    58  				m[key] = val
    59  			} else {
    60  				delete(m, key)
    61  			}
    62  			got := apm.Swap(key, val)
    63  			t.Logf("Swap(%d, %p) = %p", key, val, got)
    64  			if got != want {
    65  				t.Fatalf("got %p, wanted %p", got, want)
    66  			}
    67  		case 4, 5, 6, 7: // CompareAndSwap
    68  			key := rand.Int63n(maxKey)
    69  			oldVal := vals[rand.Intn(len(vals))]
    70  			newVal := vals[rand.Intn(len(vals))]
    71  			want := m[key]
    72  			if want == oldVal {
    73  				if newVal != nil {
    74  					m[key] = newVal
    75  				} else {
    76  					delete(m, key)
    77  				}
    78  			}
    79  			got := apm.CompareAndSwap(key, oldVal, newVal)
    80  			t.Logf("CompareAndSwap(%d, %p, %p) = %p", key, oldVal, newVal, got)
    81  			if got != want {
    82  				t.Fatalf("got %p, wanted %p", got, want)
    83  			}
    84  		case 8: // Range
    85  			got := make(map[int64]*testValue)
    86  			var (
    87  				haveDup = false
    88  				dup     int64
    89  			)
    90  			apm.Range(func(key int64, val *testValue) bool {
    91  				if _, ok := got[key]; ok && !haveDup {
    92  					haveDup = true
    93  					dup = key
    94  				}
    95  				got[key] = val
    96  				return true
    97  			})
    98  			t.Logf("Range() = %v", got)
    99  			if !reflect.DeepEqual(got, m) {
   100  				t.Fatalf("got %v, wanted %v", got, m)
   101  			}
   102  			if haveDup {
   103  				t.Fatalf("got duplicate key %d", dup)
   104  			}
   105  		case 9: // RangeRepeatable
   106  			got := make(map[int64]*testValue)
   107  			apm.RangeRepeatable(func(key int64, val *testValue) bool {
   108  				got[key] = val
   109  				return true
   110  			})
   111  			t.Logf("RangeRepeatable() = %v", got)
   112  			if !reflect.DeepEqual(got, m) {
   113  				t.Fatalf("got %v, wanted %v", got, m)
   114  			}
   115  		}
   116  	}
   117  }
   118  
   119  func TestConcurrentHeterogeneous(t *testing.T) {
   120  	ctx, cancel := context.WithCancel(context.Background())
   121  	var (
   122  		apm testAtomicPtrMap
   123  		wg  sync.WaitGroup
   124  	)
   125  	defer func() {
   126  		cancel()
   127  		wg.Wait()
   128  	}()
   129  
   130  	possibleKeyValuePairs := make(map[int64]map[*testValue]struct{})
   131  	addKeyValuePair := func(key int64, val *testValue) {
   132  		values := possibleKeyValuePairs[key]
   133  		if values == nil {
   134  			values = make(map[*testValue]struct{})
   135  			possibleKeyValuePairs[key] = values
   136  		}
   137  		values[val] = struct{}{}
   138  	}
   139  
   140  	const numValuesPerKey = 4
   141  
   142  	// These goroutines use keys not used by any other goroutine.
   143  	const numPrivateKeys = 3
   144  	for i := 0; i < numPrivateKeys; i++ {
   145  		key := int64(i)
   146  		var vals [numValuesPerKey]*testValue
   147  		for i := 1; /* leave vals[0] nil */ i < len(vals); i++ {
   148  			val := new(testValue)
   149  			vals[i] = val
   150  			addKeyValuePair(key, val)
   151  		}
   152  		wg.Add(1)
   153  		go func() {
   154  			defer wg.Done()
   155  			r := rand.New(rand.NewSource(rand.Int63()))
   156  			var stored *testValue
   157  			for ctx.Err() == nil {
   158  				switch r.Intn(4) {
   159  				case 0:
   160  					got := apm.Load(key)
   161  					if got != stored {
   162  						t.Errorf("Load(%d): got %p, wanted %p", key, got, stored)
   163  						return
   164  					}
   165  				case 1:
   166  					val := vals[r.Intn(len(vals))]
   167  					want := stored
   168  					stored = val
   169  					got := apm.Swap(key, val)
   170  					if got != want {
   171  						t.Errorf("Swap(%d, %p): got %p, wanted %p", key, val, got, want)
   172  						return
   173  					}
   174  				case 2, 3:
   175  					oldVal := vals[r.Intn(len(vals))]
   176  					newVal := vals[r.Intn(len(vals))]
   177  					want := stored
   178  					if stored == oldVal {
   179  						stored = newVal
   180  					}
   181  					got := apm.CompareAndSwap(key, oldVal, newVal)
   182  					if got != want {
   183  						t.Errorf("CompareAndSwap(%d, %p, %p): got %p, wanted %p", key, oldVal, newVal, got, want)
   184  						return
   185  					}
   186  				}
   187  			}
   188  		}()
   189  	}
   190  
   191  	// These goroutines share a small set of keys.
   192  	const numSharedKeys = 2
   193  	var (
   194  		sharedKeys      [numSharedKeys]int64
   195  		sharedValues    = make(map[int64][]*testValue)
   196  		sharedValuesSet = make(map[int64]map[*testValue]struct{})
   197  	)
   198  	for i := range sharedKeys {
   199  		key := int64(numPrivateKeys + i)
   200  		sharedKeys[i] = key
   201  		vals := make([]*testValue, numValuesPerKey)
   202  		valsSet := make(map[*testValue]struct{})
   203  		for j := range vals {
   204  			val := new(testValue)
   205  			vals[j] = val
   206  			valsSet[val] = struct{}{}
   207  			addKeyValuePair(key, val)
   208  		}
   209  		sharedValues[key] = vals
   210  		sharedValuesSet[key] = valsSet
   211  	}
   212  	randSharedValue := func(r *rand.Rand, key int64) *testValue {
   213  		vals := sharedValues[key]
   214  		return vals[r.Intn(len(vals))]
   215  	}
   216  	for i := 0; i < 3; i++ {
   217  		wg.Add(1)
   218  		go func() {
   219  			defer wg.Done()
   220  			r := rand.New(rand.NewSource(rand.Int63()))
   221  			for ctx.Err() == nil {
   222  				keyIndex := r.Intn(len(sharedKeys))
   223  				key := sharedKeys[keyIndex]
   224  				var (
   225  					op  string
   226  					got *testValue
   227  				)
   228  				switch r.Intn(4) {
   229  				case 0:
   230  					op = "Load"
   231  					got = apm.Load(key)
   232  				case 1:
   233  					op = "Swap"
   234  					got = apm.Swap(key, randSharedValue(r, key))
   235  				case 2, 3:
   236  					op = "CompareAndSwap"
   237  					got = apm.CompareAndSwap(key, randSharedValue(r, key), randSharedValue(r, key))
   238  				}
   239  				if got != nil {
   240  					valsSet := sharedValuesSet[key]
   241  					if _, ok := valsSet[got]; !ok {
   242  						t.Errorf("%s: got key %d, value %p; expected value in %v", op, key, got, valsSet)
   243  						return
   244  					}
   245  				}
   246  			}
   247  		}()
   248  	}
   249  
   250  	// This goroutine repeatedly searches for unused keys.
   251  	wg.Add(1)
   252  	go func() {
   253  		defer wg.Done()
   254  		r := rand.New(rand.NewSource(rand.Int63()))
   255  		for ctx.Err() == nil {
   256  			key := -1 - r.Int63()
   257  			if got := apm.Load(key); got != nil {
   258  				t.Errorf("Load(%d): got %p, wanted nil", key, got)
   259  			}
   260  		}
   261  	}()
   262  
   263  	// This goroutine repeatedly calls RangeRepeatable() and checks that each
   264  	// key corresponds to an expected value.
   265  	wg.Add(1)
   266  	go func() {
   267  		defer wg.Done()
   268  		abort := false
   269  		for !abort && ctx.Err() == nil {
   270  			apm.RangeRepeatable(func(key int64, val *testValue) bool {
   271  				values, ok := possibleKeyValuePairs[key]
   272  				if !ok {
   273  					t.Errorf("RangeRepeatable: got invalid key %d", key)
   274  					abort = true
   275  					return false
   276  				}
   277  				if _, ok := values[val]; !ok {
   278  					t.Errorf("RangeRepeatable: got key %d, value %p; expected one of %v", key, val, values)
   279  					abort = true
   280  					return false
   281  				}
   282  				return true
   283  			})
   284  		}
   285  	}()
   286  
   287  	// Finally, the main goroutine spins for the length of the test calling
   288  	// Range() and checking that each key that it observes is unique and
   289  	// corresponds to an expected value.
   290  	seenKeys := make(map[int64]struct{})
   291  	const testDuration = 5 * time.Second
   292  	end := time.Now().Add(testDuration)
   293  	abort := false
   294  	for time.Now().Before(end) {
   295  		apm.Range(func(key int64, val *testValue) bool {
   296  			values, ok := possibleKeyValuePairs[key]
   297  			if !ok {
   298  				t.Errorf("Range: got invalid key %d", key)
   299  				abort = true
   300  				return false
   301  			}
   302  			if _, ok := values[val]; !ok {
   303  				t.Errorf("Range: got key %d, value %p; expected one of %v", key, val, values)
   304  				abort = true
   305  				return false
   306  			}
   307  			if _, ok := seenKeys[key]; ok {
   308  				t.Errorf("Range: got duplicate key %d", key)
   309  				abort = true
   310  				return false
   311  			}
   312  			seenKeys[key] = struct{}{}
   313  			return true
   314  		})
   315  		if abort {
   316  			break
   317  		}
   318  		for k := range seenKeys {
   319  			delete(seenKeys, k)
   320  		}
   321  	}
   322  }
   323  
   324  type benchmarkableMap interface {
   325  	Load(key int64) *testValue
   326  	Store(key int64, val *testValue)
   327  	LoadOrStore(key int64, val *testValue) (*testValue, bool)
   328  	Delete(key int64)
   329  }
   330  
   331  // rwMutexMap implements benchmarkableMap for a RWMutex-protected Go map.
   332  type rwMutexMap struct {
   333  	mu sync.RWMutex
   334  	m  map[int64]*testValue
   335  }
   336  
   337  func (m *rwMutexMap) Load(key int64) *testValue {
   338  	m.mu.RLock()
   339  	defer m.mu.RUnlock()
   340  	return m.m[key]
   341  }
   342  
   343  func (m *rwMutexMap) Store(key int64, val *testValue) {
   344  	m.mu.Lock()
   345  	defer m.mu.Unlock()
   346  	if m.m == nil {
   347  		m.m = make(map[int64]*testValue)
   348  	}
   349  	m.m[key] = val
   350  }
   351  
   352  func (m *rwMutexMap) LoadOrStore(key int64, val *testValue) (*testValue, bool) {
   353  	m.mu.Lock()
   354  	defer m.mu.Unlock()
   355  	if m.m == nil {
   356  		m.m = make(map[int64]*testValue)
   357  	}
   358  	if oldVal, ok := m.m[key]; ok {
   359  		return oldVal, true
   360  	}
   361  	m.m[key] = val
   362  	return val, false
   363  }
   364  
   365  func (m *rwMutexMap) Delete(key int64) {
   366  	m.mu.Lock()
   367  	defer m.mu.Unlock()
   368  	delete(m.m, key)
   369  }
   370  
   371  // syncMap implements benchmarkableMap for a sync.Map.
   372  type syncMap struct {
   373  	m sync.Map
   374  }
   375  
   376  func (m *syncMap) Load(key int64) *testValue {
   377  	val, ok := m.m.Load(key)
   378  	if !ok {
   379  		return nil
   380  	}
   381  	return val.(*testValue)
   382  }
   383  
   384  func (m *syncMap) Store(key int64, val *testValue) {
   385  	m.m.Store(key, val)
   386  }
   387  
   388  func (m *syncMap) LoadOrStore(key int64, val *testValue) (*testValue, bool) {
   389  	actual, loaded := m.m.LoadOrStore(key, val)
   390  	return actual.(*testValue), loaded
   391  }
   392  
   393  func (m *syncMap) Delete(key int64) {
   394  	m.m.Delete(key)
   395  }
   396  
   397  // benchmarkableAtomicPtrMap implements benchmarkableMap for testAtomicPtrMap.
   398  type benchmarkableAtomicPtrMap struct {
   399  	m testAtomicPtrMap
   400  }
   401  
   402  func (m *benchmarkableAtomicPtrMap) Load(key int64) *testValue {
   403  	return m.m.Load(key)
   404  }
   405  
   406  func (m *benchmarkableAtomicPtrMap) Store(key int64, val *testValue) {
   407  	m.m.Store(key, val)
   408  }
   409  
   410  func (m *benchmarkableAtomicPtrMap) LoadOrStore(key int64, val *testValue) (*testValue, bool) {
   411  	if prev := m.m.CompareAndSwap(key, nil, val); prev != nil {
   412  		return prev, true
   413  	}
   414  	return val, false
   415  }
   416  
   417  func (m *benchmarkableAtomicPtrMap) Delete(key int64) {
   418  	m.m.Store(key, nil)
   419  }
   420  
   421  // benchmarkableAtomicPtrMapSharded implements benchmarkableMap for testAtomicPtrMapSharded.
   422  type benchmarkableAtomicPtrMapSharded struct {
   423  	m testAtomicPtrMapSharded
   424  }
   425  
   426  func (m *benchmarkableAtomicPtrMapSharded) Load(key int64) *testValue {
   427  	return m.m.Load(key)
   428  }
   429  
   430  func (m *benchmarkableAtomicPtrMapSharded) Store(key int64, val *testValue) {
   431  	m.m.Store(key, val)
   432  }
   433  
   434  func (m *benchmarkableAtomicPtrMapSharded) LoadOrStore(key int64, val *testValue) (*testValue, bool) {
   435  	if prev := m.m.CompareAndSwap(key, nil, val); prev != nil {
   436  		return prev, true
   437  	}
   438  	return val, false
   439  }
   440  
   441  func (m *benchmarkableAtomicPtrMapSharded) Delete(key int64) {
   442  	m.m.Store(key, nil)
   443  }
   444  
   445  var mapImpls = [...]struct {
   446  	name string
   447  	ctor func() benchmarkableMap
   448  }{
   449  	{
   450  		name: "RWMutexMap",
   451  		ctor: func() benchmarkableMap {
   452  			return new(rwMutexMap)
   453  		},
   454  	},
   455  	{
   456  		name: "SyncMap",
   457  		ctor: func() benchmarkableMap {
   458  			return new(syncMap)
   459  		},
   460  	},
   461  	{
   462  		name: "AtomicPtrMap",
   463  		ctor: func() benchmarkableMap {
   464  			return new(benchmarkableAtomicPtrMap)
   465  		},
   466  	},
   467  	{
   468  		name: "AtomicPtrMapSharded",
   469  		ctor: func() benchmarkableMap {
   470  			return new(benchmarkableAtomicPtrMapSharded)
   471  		},
   472  	},
   473  }
   474  
   475  func benchmarkStoreDelete(b *testing.B, mapCtor func() benchmarkableMap) {
   476  	m := mapCtor()
   477  	val := &testValue{}
   478  	for i := 0; i < b.N; i++ {
   479  		m.Store(int64(i), val)
   480  	}
   481  	for i := 0; i < b.N; i++ {
   482  		m.Delete(int64(i))
   483  	}
   484  }
   485  
   486  func BenchmarkStoreDelete(b *testing.B) {
   487  	for _, mapImpl := range mapImpls {
   488  		b.Run(mapImpl.name, func(b *testing.B) {
   489  			benchmarkStoreDelete(b, mapImpl.ctor)
   490  		})
   491  	}
   492  }
   493  
   494  func benchmarkLoadOrStoreDelete(b *testing.B, mapCtor func() benchmarkableMap) {
   495  	m := mapCtor()
   496  	val := &testValue{}
   497  	for i := 0; i < b.N; i++ {
   498  		m.LoadOrStore(int64(i), val)
   499  	}
   500  	for i := 0; i < b.N; i++ {
   501  		m.Delete(int64(i))
   502  	}
   503  }
   504  
   505  func BenchmarkLoadOrStoreDelete(b *testing.B) {
   506  	for _, mapImpl := range mapImpls {
   507  		b.Run(mapImpl.name, func(b *testing.B) {
   508  			benchmarkLoadOrStoreDelete(b, mapImpl.ctor)
   509  		})
   510  	}
   511  }
   512  
   513  func benchmarkLookupPositive(b *testing.B, mapCtor func() benchmarkableMap) {
   514  	m := mapCtor()
   515  	val := &testValue{}
   516  	for i := 0; i < b.N; i++ {
   517  		m.Store(int64(i), val)
   518  	}
   519  	b.ResetTimer()
   520  	for i := 0; i < b.N; i++ {
   521  		m.Load(int64(i))
   522  	}
   523  }
   524  
   525  func BenchmarkLookupPositive(b *testing.B) {
   526  	for _, mapImpl := range mapImpls {
   527  		b.Run(mapImpl.name, func(b *testing.B) {
   528  			benchmarkLookupPositive(b, mapImpl.ctor)
   529  		})
   530  	}
   531  }
   532  
   533  func benchmarkLookupNegative(b *testing.B, mapCtor func() benchmarkableMap) {
   534  	m := mapCtor()
   535  	val := &testValue{}
   536  	for i := 0; i < b.N; i++ {
   537  		m.Store(int64(i), val)
   538  	}
   539  	b.ResetTimer()
   540  	for i := 0; i < b.N; i++ {
   541  		m.Load(int64(-1 - i))
   542  	}
   543  }
   544  
   545  func BenchmarkLookupNegative(b *testing.B) {
   546  	for _, mapImpl := range mapImpls {
   547  		b.Run(mapImpl.name, func(b *testing.B) {
   548  			benchmarkLookupNegative(b, mapImpl.ctor)
   549  		})
   550  	}
   551  }
   552  
   553  type benchmarkConcurrentOptions struct {
   554  	// loadsPerMutationPair is the number of map lookups between each
   555  	// insertion/deletion pair.
   556  	loadsPerMutationPair int
   557  
   558  	// If changeKeys is true, the keys used by each goroutine change between
   559  	// iterations of the test.
   560  	changeKeys bool
   561  }
   562  
   563  func benchmarkConcurrent(b *testing.B, mapCtor func() benchmarkableMap, opts benchmarkConcurrentOptions) {
   564  	var (
   565  		started sync.WaitGroup
   566  		workers sync.WaitGroup
   567  	)
   568  	started.Add(1)
   569  
   570  	m := mapCtor()
   571  	val := &testValue{}
   572  	// Insert a large number of unused elements into the map so that used
   573  	// elements are distributed throughout memory.
   574  	for i := 0; i < 10000; i++ {
   575  		m.Store(int64(-1-i), val)
   576  	}
   577  	// n := ceil(b.N / (opts.loadsPerMutationPair + 2))
   578  	n := (b.N + opts.loadsPerMutationPair + 1) / (opts.loadsPerMutationPair + 2)
   579  	for i, procs := 0, runtime.GOMAXPROCS(0); i < procs; i++ {
   580  		workerID := i
   581  		workers.Add(1)
   582  		go func() {
   583  			defer workers.Done()
   584  			started.Wait()
   585  			for i := 0; i < n; i++ {
   586  				var key int64
   587  				if opts.changeKeys {
   588  					key = int64(workerID*n + i)
   589  				} else {
   590  					key = int64(workerID)
   591  				}
   592  				m.LoadOrStore(key, val)
   593  				for j := 0; j < opts.loadsPerMutationPair; j++ {
   594  					m.Load(key)
   595  				}
   596  				m.Delete(key)
   597  			}
   598  		}()
   599  	}
   600  
   601  	b.ResetTimer()
   602  	started.Done()
   603  	workers.Wait()
   604  }
   605  
   606  func BenchmarkConcurrent(b *testing.B) {
   607  	changeKeysChoices := [...]struct {
   608  		name string
   609  		val  bool
   610  	}{
   611  		{"FixedKeys", false},
   612  		{"ChangingKeys", true},
   613  	}
   614  	writePcts := [...]struct {
   615  		name                 string
   616  		loadsPerMutationPair int
   617  	}{
   618  		{"1PercentWrites", 198},
   619  		{"10PercentWrites", 18},
   620  		{"50PercentWrites", 2},
   621  	}
   622  	for _, changeKeys := range changeKeysChoices {
   623  		for _, writePct := range writePcts {
   624  			for _, mapImpl := range mapImpls {
   625  				name := fmt.Sprintf("%s_%s_%s", changeKeys.name, writePct.name, mapImpl.name)
   626  				b.Run(name, func(b *testing.B) {
   627  					benchmarkConcurrent(b, mapImpl.ctor, benchmarkConcurrentOptions{
   628  						loadsPerMutationPair: writePct.loadsPerMutationPair,
   629  						changeKeys:           changeKeys.val,
   630  					})
   631  				})
   632  			}
   633  		}
   634  	}
   635  }