gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/atomicbitops/atomicbitops_test.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // +checkalignedignore
    16  package atomicbitops
    17  
    18  import (
    19  	"fmt"
    20  	"math"
    21  	"runtime"
    22  	"testing"
    23  
    24  	"gvisor.dev/gvisor/pkg/sync"
    25  )
    26  
    27  const iterations = 100
    28  
    29  func detectRaces32(val, target uint32, fn func(*Uint32, uint32)) bool {
    30  	runtime.GOMAXPROCS(100)
    31  	for n := 0; n < iterations; n++ {
    32  		x := FromUint32(val)
    33  		var wg sync.WaitGroup
    34  		for i := uint32(0); i < 32; i++ {
    35  			wg.Add(1)
    36  			go func(a *Uint32, i uint32) {
    37  				defer wg.Done()
    38  				fn(a, uint32(1<<i))
    39  			}(&x, i)
    40  		}
    41  		wg.Wait()
    42  		if x != FromUint32(target) {
    43  			return true
    44  		}
    45  	}
    46  	return false
    47  }
    48  
    49  func detectRaces64(val, target uint64, fn func(*Uint64, uint64)) bool {
    50  	runtime.GOMAXPROCS(100)
    51  	for n := 0; n < iterations; n++ {
    52  		x := FromUint64(val)
    53  		var wg sync.WaitGroup
    54  		for i := uint64(0); i < 64; i++ {
    55  			wg.Add(1)
    56  			go func(a *Uint64, i uint64) {
    57  				defer wg.Done()
    58  				fn(a, uint64(1<<i))
    59  			}(&x, i)
    60  		}
    61  		wg.Wait()
    62  		if x != FromUint64(target) {
    63  			return true
    64  		}
    65  	}
    66  	return false
    67  }
    68  
    69  func TestOrUint32(t *testing.T) {
    70  	if detectRaces32(0x0, 0xffffffff, OrUint32) {
    71  		t.Error("Data race detected!")
    72  	}
    73  }
    74  
    75  func TestAndUint32(t *testing.T) {
    76  	if detectRaces32(0xf0f0f0f0, 0x00000000, AndUint32) {
    77  		t.Error("Data race detected!")
    78  	}
    79  }
    80  
    81  func TestXorUint32(t *testing.T) {
    82  	if detectRaces32(0xf0f0f0f0, 0x0f0f0f0f, XorUint32) {
    83  		t.Error("Data race detected!")
    84  	}
    85  }
    86  
    87  func TestOrUint64(t *testing.T) {
    88  	if detectRaces64(0x0, 0xffffffffffffffff, OrUint64) {
    89  		t.Error("Data race detected!")
    90  	}
    91  }
    92  
    93  func TestAndUint64(t *testing.T) {
    94  	if detectRaces64(0xf0f0f0f0f0f0f0f0, 0x0, AndUint64) {
    95  		t.Error("Data race detected!")
    96  	}
    97  }
    98  
    99  func TestXorUint64(t *testing.T) {
   100  	if detectRaces64(0xf0f0f0f0f0f0f0f0, 0x0f0f0f0f0f0f0f0f, XorUint64) {
   101  		t.Error("Data race detected!")
   102  	}
   103  }
   104  
   105  func TestCompareAndSwapUint32(t *testing.T) {
   106  	tests := []struct {
   107  		name string
   108  		prev uint32
   109  		old  uint32
   110  		new  uint32
   111  		next uint32
   112  	}{
   113  		{
   114  			name: "Successful compare-and-swap with prev == new",
   115  			prev: 10,
   116  			old:  10,
   117  			new:  10,
   118  			next: 10,
   119  		},
   120  		{
   121  			name: "Successful compare-and-swap with prev != new",
   122  			prev: 20,
   123  			old:  20,
   124  			new:  22,
   125  			next: 22,
   126  		},
   127  		{
   128  			name: "Failed compare-and-swap with prev == new",
   129  			prev: 31,
   130  			old:  30,
   131  			new:  31,
   132  			next: 31,
   133  		},
   134  		{
   135  			name: "Failed compare-and-swap with prev != new",
   136  			prev: 41,
   137  			old:  40,
   138  			new:  42,
   139  			next: 41,
   140  		},
   141  	}
   142  	for _, test := range tests {
   143  		val := FromUint32(test.prev)
   144  		prev := CompareAndSwapUint32(&val, test.old, test.new)
   145  		if got, want := prev, test.prev; got != want {
   146  			t.Errorf("%s: incorrect returned previous value: got %d, expected %d", test.name, got, want)
   147  		}
   148  		if got, want := val.Load(), test.next; got != want {
   149  			t.Errorf("%s: incorrect value stored in val: got %d, expected %d", test.name, got, want)
   150  		}
   151  	}
   152  }
   153  
   154  func TestCompareAndSwapUint64(t *testing.T) {
   155  	tests := []struct {
   156  		name string
   157  		prev uint64
   158  		old  uint64
   159  		new  uint64
   160  		next uint64
   161  	}{
   162  		{
   163  			name: "Successful compare-and-swap with prev == new",
   164  			prev: 0x100000000,
   165  			old:  0x100000000,
   166  			new:  0x100000000,
   167  			next: 0x100000000,
   168  		},
   169  		{
   170  			name: "Successful compare-and-swap with prev != new",
   171  			prev: 0x200000000,
   172  			old:  0x200000000,
   173  			new:  0x200000002,
   174  			next: 0x200000002,
   175  		},
   176  		{
   177  			name: "Failed compare-and-swap with prev == new",
   178  			prev: 0x300000001,
   179  			old:  0x300000000,
   180  			new:  0x300000001,
   181  			next: 0x300000001,
   182  		},
   183  		{
   184  			name: "Failed compare-and-swap with prev != new",
   185  			prev: 0x400000001,
   186  			old:  0x400000000,
   187  			new:  0x400000002,
   188  			next: 0x400000001,
   189  		},
   190  	}
   191  	for _, test := range tests {
   192  		val := FromUint64(test.prev)
   193  		prev := CompareAndSwapUint64(&val, test.old, test.new)
   194  		if got, want := prev, test.prev; got != want {
   195  			t.Errorf("%s: incorrect returned previous value: got %d, expected %d", test.name, got, want)
   196  		}
   197  		if got, want := val.Load(), test.next; got != want {
   198  			t.Errorf("%s: incorrect value stored in val: got %d, expected %d", test.name, got, want)
   199  		}
   200  	}
   201  }
   202  
   203  var interestingFloats = []float64{
   204  	0.0,
   205  	1.0,
   206  	0.1,
   207  	2.1,
   208  	-1.0,
   209  	-0.1,
   210  	-2.1,
   211  	math.MaxFloat64,
   212  	-math.MaxFloat64,
   213  	math.SmallestNonzeroFloat64,
   214  	-math.SmallestNonzeroFloat64,
   215  	math.Inf(1),
   216  	math.Inf(-1),
   217  	math.NaN(),
   218  }
   219  
   220  // equalOrBothNaN returns true if a == b or if a and b are both NaN.
   221  func equalOrBothNaN(a, b float64) bool {
   222  	return a == b || (math.IsNaN(a) && math.IsNaN(b))
   223  }
   224  
   225  // getInterestingFloatPermutations returns a list of `num`-sized permutations
   226  // of the floating-point values in `interestingFloats`.
   227  func getInterestingFloatPermutations(num int) [][]float64 {
   228  	permutations := make([][]float64, 0, len(interestingFloats))
   229  	for _, f := range interestingFloats {
   230  		permutations = append(permutations, []float64{f})
   231  	}
   232  	for i := 1; i < num; i++ {
   233  		oldPermutations := permutations
   234  		permutations = make([][]float64, 0, len(permutations)*len(interestingFloats))
   235  		for _, oldPermutation := range oldPermutations {
   236  			for _, f := range interestingFloats {
   237  				alreadyInPermutation := false
   238  				for _, f2 := range oldPermutation {
   239  					if equalOrBothNaN(f, f2) {
   240  						alreadyInPermutation = true
   241  						break
   242  					}
   243  				}
   244  				if alreadyInPermutation {
   245  					continue
   246  				}
   247  				permutations = append(permutations, append(oldPermutation, f))
   248  			}
   249  		}
   250  
   251  	}
   252  	return permutations
   253  }
   254  
   255  func TestCompareAndSwapFloat64(t *testing.T) {
   256  	for _, floats := range getInterestingFloatPermutations(3) {
   257  		a, b, c := floats[0], floats[1], floats[2]
   258  		t.Run(fmt.Sprintf("a=%v b=%v c=%v", a, b, c), func(t *testing.T) {
   259  			tests := []struct {
   260  				name string
   261  				prev float64
   262  				old  float64
   263  				new  float64
   264  				next float64
   265  			}{
   266  				{
   267  					name: "Successful compare-and-swap with prev == new",
   268  					prev: a,
   269  					old:  a,
   270  					new:  a,
   271  					next: a,
   272  				},
   273  				{
   274  					name: "Successful compare-and-swap with prev != new",
   275  					prev: a,
   276  					old:  a,
   277  					new:  b,
   278  					next: b,
   279  				},
   280  				{
   281  					name: "Failed compare-and-swap with prev == new",
   282  					prev: a,
   283  					old:  b,
   284  					new:  a,
   285  					next: a,
   286  				},
   287  				{
   288  					name: "Failed compare-and-swap with prev != new",
   289  					prev: a,
   290  					old:  b,
   291  					new:  c,
   292  					next: a,
   293  				},
   294  			}
   295  			for _, test := range tests {
   296  				t.Run(test.name, func(t *testing.T) {
   297  					val := FromFloat64(test.prev)
   298  					success := val.CompareAndSwap(test.old, test.new)
   299  					wantSuccess := equalOrBothNaN(test.prev, test.old) && equalOrBothNaN(test.new, test.next)
   300  					if success != wantSuccess {
   301  						t.Errorf("incorrect success value: got %v, expected %v", success, wantSuccess)
   302  					}
   303  					if got, want := val.Load(), test.next; !equalOrBothNaN(got, want) {
   304  						t.Errorf("incorrect value stored in val: got %v, expected %v", got, want)
   305  					}
   306  				})
   307  			}
   308  		})
   309  	}
   310  }
   311  
   312  func TestAddFloat64(t *testing.T) {
   313  	runtime.GOMAXPROCS(100)
   314  	for _, floats := range getInterestingFloatPermutations(3) {
   315  		a, b, c := floats[0], floats[1], floats[2]
   316  		// This test computes the outcome of adding `b` and `c` to `a`.
   317  		// Because floating point numbers lose precision with each operation,
   318  		// it is not always the case that a + b + c = a + c + b.
   319  		// Therefore, it computes both a + b + c and a + c + b, and verifies that
   320  		// adding Float64s in that order works exactly, while Float64s to which
   321  		// `b` and `c` are added in separate goroutines may end up at either
   322  		// `a + b + c` or `a + c + b`.
   323  		testName := fmt.Sprintf("a=%v b=%v c=%v", a, b, c)
   324  		for i := 0; i < iterations; i++ {
   325  			fCanonical := a
   326  			fCanonicalReverse := a
   327  			fLinear := FromFloat64(a)
   328  			fLinearReverse := FromFloat64(a)
   329  			fParallel1 := FromFloat64(a)
   330  			fParallel2 := FromFloat64(a)
   331  			var wg sync.WaitGroup
   332  			spawn := func(f func()) {
   333  				wg.Add(1)
   334  				go func() {
   335  					defer wg.Done()
   336  					f()
   337  				}()
   338  			}
   339  			spawn(func() {
   340  				fCanonical += b
   341  				fCanonical += c
   342  			})
   343  			spawn(func() {
   344  				fCanonicalReverse += c
   345  				fCanonicalReverse += b
   346  			})
   347  			spawn(func() {
   348  				fLinear.Add(b)
   349  				fLinear.Add(c)
   350  			})
   351  			spawn(func() {
   352  				fLinearReverse.Add(c)
   353  				fLinearReverse.Add(b)
   354  			})
   355  			spawn(func() {
   356  				fParallel1.Add(b)
   357  			})
   358  			spawn(func() {
   359  				fParallel2.Add(c)
   360  			})
   361  			spawn(func() {
   362  				fParallel1.Add(c)
   363  			})
   364  			spawn(func() {
   365  				fParallel2.Add(b)
   366  			})
   367  			wg.Wait()
   368  			for _, f := range []struct {
   369  				name string
   370  				val  float64
   371  				want []float64
   372  			}{
   373  				{"linear", fLinear.Load(), []float64{fCanonical}},
   374  				{"linear reverse", fLinearReverse.Load(), []float64{fCanonicalReverse}},
   375  				{"parallel 1", fParallel1.Load(), []float64{fCanonical, fCanonicalReverse}},
   376  				{"parallel 2", fParallel2.Load(), []float64{fCanonical, fCanonicalReverse}},
   377  			} {
   378  				found := false
   379  				for _, want := range f.want {
   380  					if equalOrBothNaN(f.val, want) {
   381  						found = true
   382  						break
   383  					}
   384  				}
   385  				if !found {
   386  					t.Errorf("%s: %s was not equal to expected result: %v not in %v", testName, f.name, f.val, f.want)
   387  				}
   388  			}
   389  		}
   390  	}
   391  }