github.com/grailbio/base@v0.0.11/traverse/traverse_test.go (about) 1 // Copyright 2018 GRAIL, Inc. All rights reserved. 2 // Use of this source code is governed by the Apache-2.0 3 // license that can be found in the LICENSE file. 4 5 package traverse_test 6 7 import ( 8 "errors" 9 "fmt" 10 "math/rand" 11 "reflect" 12 "strings" 13 "sync" 14 "sync/atomic" 15 "testing" 16 "time" 17 18 "github.com/grailbio/base/traverse" 19 ) 20 21 func recovered(f func()) (v interface{}) { 22 defer func() { v = recover() }() 23 f() 24 return v 25 } 26 27 func TestTraverse(t *testing.T) { 28 list := make([]int, 5) 29 err := traverse.Each(5, func(i int) error { 30 list[i] += i 31 return nil 32 }) 33 if err != nil { 34 t.Fatal(err) 35 } 36 if got, want := list, []int{0, 1, 2, 3, 4}; !reflect.DeepEqual(got, want) { 37 t.Errorf("got %v, want %v", got, want) 38 } 39 expectedErr := errors.New("test error") 40 err = traverse.Each(5, func(i int) error { 41 if i == 3 { 42 return expectedErr 43 } 44 return nil 45 }) 46 if got, want := err, expectedErr; got != want { 47 t.Errorf("got %v want %v", got, want) 48 } 49 } 50 51 func TestTraverseLarge(t *testing.T) { 52 tests := []struct { 53 N int 54 Limit int 55 }{ 56 { 57 N: 1, 58 Limit: 1, 59 }, 60 { 61 N: 10, 62 Limit: 2, 63 }, 64 { 65 N: 2999999, 66 Limit: 5, 67 }, 68 { 69 N: 3000001, 70 Limit: 5, 71 }, 72 } 73 for testId, test := range tests { 74 data := make([]int32, test.N) 75 _ = traverse.Limit(test.Limit).Each(test.N, func(i int) error { 76 atomic.AddInt32(&data[i], 1) 77 return nil 78 }) 79 for i, d := range data { 80 if d != 1 { 81 t.Errorf("Test %d - Each. element %d is %d. Expected 1", testId, i, d) 82 break 83 } 84 } 85 86 data = make([]int32, test.N) 87 _ = traverse.Limit(test.Limit).Range(test.N, func(i, j int) error { 88 for k := i; k < j; k++ { 89 atomic.AddInt32(&data[k], 1) 90 } 91 return nil 92 }) 93 for i, d := range data { 94 if d != 1 { 95 t.Errorf("Test %d - Range. element %d is %d. Expected 1", testId, i, d) 96 break 97 } 98 } 99 100 // Emulate a sequential writer. 101 // The test still passes if LimitSequential is replaced with Limit, but it 102 // should take noticeably longer to execute. 103 // (Note that we can't just e.g. guard 'data' with a mutex. Just because 104 // tasks are launched in numerical order does not mean that they will be 105 // completed in numerical order.) 106 data = data[:0] 107 const cachelineSize = 64 108 var nextWriteIndex struct { 109 _ [cachelineSize - 8]byte 110 N int64 111 _ [cachelineSize - 8]byte 112 } 113 _ = traverse.LimitSequential(test.Limit).Each(test.N, func(i int) error { 114 time.Sleep(50 * time.Nanosecond) 115 for { 116 j := atomic.LoadInt64(&nextWriteIndex.N) 117 if int(j) == i { 118 break 119 } 120 } 121 data = append(data, int32(i)) 122 _ = atomic.AddInt64(&nextWriteIndex.N, 1) 123 return nil 124 }) 125 for i, d := range data { 126 if int(d) != i { 127 t.Errorf("Test %d - LimitSequential. element %d is %d. Expected %d", testId, i, d, i) 128 break 129 } 130 } 131 132 } 133 } 134 135 func TestRange(t *testing.T) { 136 const N = 5000 137 var ( 138 counts = make([]int64, N) 139 invocations int64 140 ) 141 var tr traverse.T 142 for i := 0; i < N; i++ { 143 tr.Limit = rand.Intn(N*2) + 1 144 err := tr.Range(N, func(start, end int) error { 145 if start < 0 || end > N || end < start { 146 return fmt.Errorf("invalid range [%d,%d)", start, end) 147 } 148 atomic.AddInt64(&invocations, 1) 149 for i := start; i < end; i++ { 150 atomic.AddInt64(&counts[i], 1) 151 } 152 return nil 153 }) 154 if err != nil { 155 t.Errorf("limit %d: %v", tr.Limit, err) 156 continue 157 } 158 expect := int64(tr.Limit) 159 if expect > N { 160 expect = N 161 } 162 if got, want := invocations, expect; got != want { 163 t.Errorf("got %v, want %v", got, want) 164 } 165 invocations = 0 166 for i := range counts { 167 if got, want := counts[i], int64(1); got != want { 168 t.Errorf("counts[%d,%d]: got %v, want %v", i, tr.Limit, got, want) 169 } 170 counts[i] = 0 171 } 172 } 173 } 174 175 func TestPanic(t *testing.T) { 176 expectedPanic := "panic in the disco!!" 177 f := func() { 178 _ = traverse.Each(5, func(i int) error { 179 if i == 3 { 180 panic(expectedPanic) 181 } 182 return nil 183 }) 184 } 185 v := recovered(f) 186 s, ok := v.(string) 187 if !ok { 188 t.Fatal("expected string") 189 } 190 if got, want := s, fmt.Sprintf("traverse child: %s", expectedPanic); !strings.HasPrefix(got, want) { 191 t.Errorf("got %q, want %q", got, want) 192 } 193 } 194 195 type testStatus struct { 196 queued, running, done int32 197 } 198 199 type testReporter struct { 200 mu sync.Mutex 201 statusHistory []testStatus 202 queued, running, done int32 203 } 204 205 func (r *testReporter) Init(n int) { 206 r.update(int32(n), 0, 0) 207 } 208 209 func (r *testReporter) Complete() {} 210 211 func (r *testReporter) Begin(i int) { 212 r.update(-1, 1, 0) 213 } 214 215 func (r *testReporter) End(i int) { 216 r.update(0, -1, 1) 217 } 218 219 func (r *testReporter) update(queued, running, done int32) { 220 r.mu.Lock() 221 defer r.mu.Unlock() 222 r.queued += queued 223 r.running += running 224 r.done += done 225 r.statusHistory = 226 append(r.statusHistory, testStatus{queued: r.queued, running: r.running, done: r.done}) 227 } 228 229 func TestReportingSingleJob(t *testing.T) { 230 reporter := new(testReporter) 231 232 tr := traverse.T{Reporter: reporter, Limit: 1} 233 _ = tr.Each(5, func(i int) error { return nil }) 234 235 expectedStatuses := []testStatus{ 236 testStatus{queued: 5, running: 0, done: 0}, 237 testStatus{queued: 4, running: 1, done: 0}, 238 testStatus{queued: 4, running: 0, done: 1}, 239 testStatus{queued: 3, running: 1, done: 1}, 240 testStatus{queued: 3, running: 0, done: 2}, 241 testStatus{queued: 2, running: 1, done: 2}, 242 testStatus{queued: 2, running: 0, done: 3}, 243 testStatus{queued: 1, running: 1, done: 3}, 244 testStatus{queued: 1, running: 0, done: 4}, 245 testStatus{queued: 0, running: 1, done: 4}, 246 testStatus{queued: 0, running: 0, done: 5}, 247 } 248 249 for i, status := range reporter.statusHistory { 250 if status != expectedStatuses[i] { 251 t.Errorf("Expected status %v, got status %v, full log %v", 252 expectedStatuses[i], status, reporter.statusHistory) 253 } 254 } 255 } 256 257 func TestReportingManyJobs(t *testing.T) { 258 reporter := new(testReporter) 259 260 numJobs := 50 261 numConcurrent := 5 262 263 tr := traverse.T{Limit: numConcurrent, Reporter: reporter} 264 _ = tr.Each(numJobs, func(i int) error { return nil }) 265 266 // first status should be all jobs queued 267 if (reporter.statusHistory[0] != testStatus{queued: int32(numJobs), running: 0, done: 0}) { 268 t.Errorf("First status should be all jobs queued, instead got %v", reporter.statusHistory[0]) 269 } 270 271 // last status should be all jobs done 272 numStatuses := len(reporter.statusHistory) 273 if (reporter.statusHistory[numStatuses-1] != testStatus{queued: 0, running: 0, done: int32(numJobs)}) { 274 t.Errorf("Last status should be all jobs done, instead got %v", reporter.statusHistory[numJobs-1]) 275 } 276 277 for i, status := range reporter.statusHistory { 278 if (status.queued + status.running + status.done) != int32(numJobs) { 279 t.Errorf("Total number of jobs is not equal to numJobs = %d - status: %v", numJobs, status) 280 } 281 282 if status.queued < 0 || status.running < 0 || status.done < 0 { 283 t.Errorf("Number of jobs can't be <0, status: %v", status) 284 } 285 286 if status.running > int32(numConcurrent) { 287 t.Errorf("Can't have more than %d jobs running, status: %v", numConcurrent, status) 288 } 289 290 if i > 0 { 291 previousStatus := reporter.statusHistory[i-1] 292 293 if status == previousStatus { 294 t.Errorf("Can't have the same status repeat - status: %v, previous status: %v", 295 status, previousStatus) 296 } 297 298 if status.queued > previousStatus.queued { 299 t.Errorf("Can't have queued jobs count increase - status: %v, previous status: %v", 300 status, previousStatus) 301 } 302 303 if status.done < previousStatus.done { 304 t.Errorf("Can't have done jobs count decrease - status: %v, previous status: %v", 305 status, previousStatus) 306 } 307 } 308 } 309 } 310 311 func BenchmarkDo(b *testing.B) { 312 for _, n := range []int{1, 1e6, 1e8} { 313 b.Run(fmt.Sprintf("n=%d", n), func(b *testing.B) { 314 for k := 0; k < b.N; k++ { 315 err := traverse.Parallel.Each(n, func(i int) error { 316 return nil 317 }) 318 if err != nil { 319 b.Error(err) 320 } 321 } 322 }) 323 } 324 } 325 326 //go:noinline 327 func fn(i int) error { 328 return nil 329 } 330 331 func BenchmarkInvoke(b *testing.B) { 332 for k := 0; k < b.N; k++ { 333 _ = fn(k) 334 } 335 }