github.com/grailbio/base@v0.0.11/traverse/traverse_test.go (about)

     1  // Copyright 2018 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache-2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  package traverse_test
     6  
     7  import (
     8  	"errors"
     9  	"fmt"
    10  	"math/rand"
    11  	"reflect"
    12  	"strings"
    13  	"sync"
    14  	"sync/atomic"
    15  	"testing"
    16  	"time"
    17  
    18  	"github.com/grailbio/base/traverse"
    19  )
    20  
    21  func recovered(f func()) (v interface{}) {
    22  	defer func() { v = recover() }()
    23  	f()
    24  	return v
    25  }
    26  
    27  func TestTraverse(t *testing.T) {
    28  	list := make([]int, 5)
    29  	err := traverse.Each(5, func(i int) error {
    30  		list[i] += i
    31  		return nil
    32  	})
    33  	if err != nil {
    34  		t.Fatal(err)
    35  	}
    36  	if got, want := list, []int{0, 1, 2, 3, 4}; !reflect.DeepEqual(got, want) {
    37  		t.Errorf("got %v, want %v", got, want)
    38  	}
    39  	expectedErr := errors.New("test error")
    40  	err = traverse.Each(5, func(i int) error {
    41  		if i == 3 {
    42  			return expectedErr
    43  		}
    44  		return nil
    45  	})
    46  	if got, want := err, expectedErr; got != want {
    47  		t.Errorf("got %v want %v", got, want)
    48  	}
    49  }
    50  
    51  func TestTraverseLarge(t *testing.T) {
    52  	tests := []struct {
    53  		N     int
    54  		Limit int
    55  	}{
    56  		{
    57  			N:     1,
    58  			Limit: 1,
    59  		},
    60  		{
    61  			N:     10,
    62  			Limit: 2,
    63  		},
    64  		{
    65  			N:     2999999,
    66  			Limit: 5,
    67  		},
    68  		{
    69  			N:     3000001,
    70  			Limit: 5,
    71  		},
    72  	}
    73  	for testId, test := range tests {
    74  		data := make([]int32, test.N)
    75  		_ = traverse.Limit(test.Limit).Each(test.N, func(i int) error {
    76  			atomic.AddInt32(&data[i], 1)
    77  			return nil
    78  		})
    79  		for i, d := range data {
    80  			if d != 1 {
    81  				t.Errorf("Test %d - Each. element %d is %d.  Expected 1", testId, i, d)
    82  				break
    83  			}
    84  		}
    85  
    86  		data = make([]int32, test.N)
    87  		_ = traverse.Limit(test.Limit).Range(test.N, func(i, j int) error {
    88  			for k := i; k < j; k++ {
    89  				atomic.AddInt32(&data[k], 1)
    90  			}
    91  			return nil
    92  		})
    93  		for i, d := range data {
    94  			if d != 1 {
    95  				t.Errorf("Test %d - Range. element %d is %d.  Expected 1", testId, i, d)
    96  				break
    97  			}
    98  		}
    99  
   100  		// Emulate a sequential writer.
   101  		// The test still passes if LimitSequential is replaced with Limit, but it
   102  		// should take noticeably longer to execute.
   103  		// (Note that we can't just e.g. guard 'data' with a mutex.  Just because
   104  		// tasks are launched in numerical order does not mean that they will be
   105  		// completed in numerical order.)
   106  		data = data[:0]
   107  		const cachelineSize = 64
   108  		var nextWriteIndex struct {
   109  			_ [cachelineSize - 8]byte
   110  			N int64
   111  			_ [cachelineSize - 8]byte
   112  		}
   113  		_ = traverse.LimitSequential(test.Limit).Each(test.N, func(i int) error {
   114  			time.Sleep(50 * time.Nanosecond)
   115  			for {
   116  				j := atomic.LoadInt64(&nextWriteIndex.N)
   117  				if int(j) == i {
   118  					break
   119  				}
   120  			}
   121  			data = append(data, int32(i))
   122  			_ = atomic.AddInt64(&nextWriteIndex.N, 1)
   123  			return nil
   124  		})
   125  		for i, d := range data {
   126  			if int(d) != i {
   127  				t.Errorf("Test %d - LimitSequential. element %d is %d.  Expected %d", testId, i, d, i)
   128  				break
   129  			}
   130  		}
   131  
   132  	}
   133  }
   134  
   135  func TestRange(t *testing.T) {
   136  	const N = 5000
   137  	var (
   138  		counts      = make([]int64, N)
   139  		invocations int64
   140  	)
   141  	var tr traverse.T
   142  	for i := 0; i < N; i++ {
   143  		tr.Limit = rand.Intn(N*2) + 1
   144  		err := tr.Range(N, func(start, end int) error {
   145  			if start < 0 || end > N || end < start {
   146  				return fmt.Errorf("invalid range [%d,%d)", start, end)
   147  			}
   148  			atomic.AddInt64(&invocations, 1)
   149  			for i := start; i < end; i++ {
   150  				atomic.AddInt64(&counts[i], 1)
   151  			}
   152  			return nil
   153  		})
   154  		if err != nil {
   155  			t.Errorf("limit %d: %v", tr.Limit, err)
   156  			continue
   157  		}
   158  		expect := int64(tr.Limit)
   159  		if expect > N {
   160  			expect = N
   161  		}
   162  		if got, want := invocations, expect; got != want {
   163  			t.Errorf("got %v, want %v", got, want)
   164  		}
   165  		invocations = 0
   166  		for i := range counts {
   167  			if got, want := counts[i], int64(1); got != want {
   168  				t.Errorf("counts[%d,%d]: got %v, want %v", i, tr.Limit, got, want)
   169  			}
   170  			counts[i] = 0
   171  		}
   172  	}
   173  }
   174  
   175  func TestPanic(t *testing.T) {
   176  	expectedPanic := "panic in the disco!!"
   177  	f := func() {
   178  		_ = traverse.Each(5, func(i int) error {
   179  			if i == 3 {
   180  				panic(expectedPanic)
   181  			}
   182  			return nil
   183  		})
   184  	}
   185  	v := recovered(f)
   186  	s, ok := v.(string)
   187  	if !ok {
   188  		t.Fatal("expected string")
   189  	}
   190  	if got, want := s, fmt.Sprintf("traverse child: %s", expectedPanic); !strings.HasPrefix(got, want) {
   191  		t.Errorf("got %q, want %q", got, want)
   192  	}
   193  }
   194  
   195  type testStatus struct {
   196  	queued, running, done int32
   197  }
   198  
   199  type testReporter struct {
   200  	mu                    sync.Mutex
   201  	statusHistory         []testStatus
   202  	queued, running, done int32
   203  }
   204  
   205  func (r *testReporter) Init(n int) {
   206  	r.update(int32(n), 0, 0)
   207  }
   208  
   209  func (r *testReporter) Complete() {}
   210  
   211  func (r *testReporter) Begin(i int) {
   212  	r.update(-1, 1, 0)
   213  }
   214  
   215  func (r *testReporter) End(i int) {
   216  	r.update(0, -1, 1)
   217  }
   218  
   219  func (r *testReporter) update(queued, running, done int32) {
   220  	r.mu.Lock()
   221  	defer r.mu.Unlock()
   222  	r.queued += queued
   223  	r.running += running
   224  	r.done += done
   225  	r.statusHistory =
   226  		append(r.statusHistory, testStatus{queued: r.queued, running: r.running, done: r.done})
   227  }
   228  
   229  func TestReportingSingleJob(t *testing.T) {
   230  	reporter := new(testReporter)
   231  
   232  	tr := traverse.T{Reporter: reporter, Limit: 1}
   233  	_ = tr.Each(5, func(i int) error { return nil })
   234  
   235  	expectedStatuses := []testStatus{
   236  		testStatus{queued: 5, running: 0, done: 0},
   237  		testStatus{queued: 4, running: 1, done: 0},
   238  		testStatus{queued: 4, running: 0, done: 1},
   239  		testStatus{queued: 3, running: 1, done: 1},
   240  		testStatus{queued: 3, running: 0, done: 2},
   241  		testStatus{queued: 2, running: 1, done: 2},
   242  		testStatus{queued: 2, running: 0, done: 3},
   243  		testStatus{queued: 1, running: 1, done: 3},
   244  		testStatus{queued: 1, running: 0, done: 4},
   245  		testStatus{queued: 0, running: 1, done: 4},
   246  		testStatus{queued: 0, running: 0, done: 5},
   247  	}
   248  
   249  	for i, status := range reporter.statusHistory {
   250  		if status != expectedStatuses[i] {
   251  			t.Errorf("Expected status %v, got status %v, full log %v",
   252  				expectedStatuses[i], status, reporter.statusHistory)
   253  		}
   254  	}
   255  }
   256  
   257  func TestReportingManyJobs(t *testing.T) {
   258  	reporter := new(testReporter)
   259  
   260  	numJobs := 50
   261  	numConcurrent := 5
   262  
   263  	tr := traverse.T{Limit: numConcurrent, Reporter: reporter}
   264  	_ = tr.Each(numJobs, func(i int) error { return nil })
   265  
   266  	// first status should be all jobs queued
   267  	if (reporter.statusHistory[0] != testStatus{queued: int32(numJobs), running: 0, done: 0}) {
   268  		t.Errorf("First status should be all jobs queued, instead got %v", reporter.statusHistory[0])
   269  	}
   270  
   271  	// last status should be all jobs done
   272  	numStatuses := len(reporter.statusHistory)
   273  	if (reporter.statusHistory[numStatuses-1] != testStatus{queued: 0, running: 0, done: int32(numJobs)}) {
   274  		t.Errorf("Last status should be all jobs done, instead got %v", reporter.statusHistory[numJobs-1])
   275  	}
   276  
   277  	for i, status := range reporter.statusHistory {
   278  		if (status.queued + status.running + status.done) != int32(numJobs) {
   279  			t.Errorf("Total number of jobs is not equal to numJobs = %d - status: %v", numJobs, status)
   280  		}
   281  
   282  		if status.queued < 0 || status.running < 0 || status.done < 0 {
   283  			t.Errorf("Number of jobs can't be <0, status: %v", status)
   284  		}
   285  
   286  		if status.running > int32(numConcurrent) {
   287  			t.Errorf("Can't have more than %d jobs running, status: %v", numConcurrent, status)
   288  		}
   289  
   290  		if i > 0 {
   291  			previousStatus := reporter.statusHistory[i-1]
   292  
   293  			if status == previousStatus {
   294  				t.Errorf("Can't have the same status repeat - status: %v, previous status: %v",
   295  					status, previousStatus)
   296  			}
   297  
   298  			if status.queued > previousStatus.queued {
   299  				t.Errorf("Can't have queued jobs count increase - status: %v, previous status: %v",
   300  					status, previousStatus)
   301  			}
   302  
   303  			if status.done < previousStatus.done {
   304  				t.Errorf("Can't have done jobs count decrease - status: %v, previous status: %v",
   305  					status, previousStatus)
   306  			}
   307  		}
   308  	}
   309  }
   310  
   311  func BenchmarkDo(b *testing.B) {
   312  	for _, n := range []int{1, 1e6, 1e8} {
   313  		b.Run(fmt.Sprintf("n=%d", n), func(b *testing.B) {
   314  			for k := 0; k < b.N; k++ {
   315  				err := traverse.Parallel.Each(n, func(i int) error {
   316  					return nil
   317  				})
   318  				if err != nil {
   319  					b.Error(err)
   320  				}
   321  			}
   322  		})
   323  	}
   324  }
   325  
   326  //go:noinline
   327  func fn(i int) error {
   328  	return nil
   329  }
   330  
   331  func BenchmarkInvoke(b *testing.B) {
   332  	for k := 0; k < b.N; k++ {
   333  		_ = fn(k)
   334  	}
   335  }