github.com/grailbio/bigslice@v0.0.0-20230519005545-30c4c12152ad/exec/slicestatus_test.go

// Copyright 2019 GRAIL, Inc. All rights reserved.
// Use of this source code is governed by the Apache 2.0
// license that can be found in the LICENSE file.

package exec

import (
	"context"
	"math/rand"
	"testing"

	"github.com/grailbio/bigslice"
)

// sample returns n task sets randomly chosen from tasks, without
// repetition; it panics if n exceeds len(tasks). This is used to build a
// task DAG for testing.
func sample(r *rand.Rand, tasks [][]*Task, n int) [][]*Task {
	// Copy tasks so that the swap-remove below does not clobber the
	// caller's slice.
	tasks = append([][]*Task{}, tasks...)
	samples := make([][]*Task, n)
	for i := range samples {
		j := r.Intn(len(tasks))
		samples[i] = tasks[j]
		tasks[j] = tasks[len(tasks)-1]
		tasks = tasks[:len(tasks)-1]
	}
	return samples
}

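// TestSampleDistinct is a minimal sanity check, added as an illustrative
// sketch, that sample draws task sets without repetition. The pool size
// and seed are arbitrary; each pool entry holds a single distinct task so
// that identity can be checked by pointer.
func TestSampleDistinct(t *testing.T) {
	r := rand.New(rand.NewSource(0))
	pool := make([][]*Task, 10)
	for i := range pool {
		pool[i] = []*Task{new(Task)}
	}
	seen := make(map[*Task]bool)
	for _, s := range sample(r, pool, len(pool)) {
		if seen[s[0]] {
			t.Errorf("task set sampled more than once")
		}
		seen[s[0]] = true
	}
}
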
// setUpDep makes each dst[i] a dependency of the corresponding src[i].
func setUpDep(src, dst []*Task) {
	if len(src) != len(dst) {
		panic("src, dst mismatch")
	}
	for i := range src {
		src[i].Deps = append(src[i].Deps, TaskDep{Head: dst[i]})
	}
}

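// TestSetUpDep pins down setUpDep's by-index pairing: src[i] gains dst[i]
// as a dependency. A small illustrative sketch, not exercised by the
// simulation below.
func TestSetUpDep(t *testing.T) {
	src := []*Task{new(Task), new(Task)}
	dst := []*Task{new(Task), new(Task)}
	setUpDep(src, dst)
	for i := range src {
		if got, want := src[i].Deps[0].Head, dst[i]; got != want {
			t.Errorf("got %v, want %v", got, want)
		}
	}
}
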
// makeFakeGraph builds a fake task graph, returning its roots.
func makeFakeGraph(r *rand.Rand, numSlices, numTasksPerSlice int) []*Task {
	// We build the graph by making some fake slices with fake tasks that
	// ostensibly compute them. We then build a DAG out of these
	// slices/tasks. This is our rough approximation of what happens during
	// compilation.
	slices := make([]bigslice.Slice, numSlices)
	for i := range slices {
		slices[i] = bigslice.Const(1, []int{})
	}
	tasks := make([][]*Task, numSlices)
	for i := range tasks {
		tasks[i] = make([]*Task, numTasksPerSlice)
		for j := range tasks[i] {
			tasks[i][j] = &Task{Slices: []bigslice.Slice{slices[i]}}
		}
	}
	// We build a DAG by setting up dependencies, only ever depending on
	// tasks with a greater index in the tasks slice.
	for i := range tasks {
		if i != len(tasks)-1 {
			// Each tasks[i] depends on tasks[i+1]. This guarantees that
			// every slice/task is used in the graph.
			setUpDep(tasks[i], tasks[i+1])
		}
		// We limit the number of dependencies to vaguely mimic reality and
		// prevent an explosion in the number of edges in the fake graph.
		maxDeps := len(tasks) - i - 2
		if maxDeps <= 0 {
			continue
		}
		if maxDeps > 3 {
			maxDeps = 3
		}
		numDeps := 1 + r.Intn(maxDeps)
		deps := sample(r, tasks[i+2:], numDeps)
		for _, dep := range deps {
			setUpDep(tasks[i], dep)
		}
	}
	return tasks[0]
}

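// TestMakeFakeGraphSize is an illustrative sketch of the invariant the
// chaining above establishes: every task is reachable from the roots, so
// walking the graph visits exactly numSlices*numTasksPerSlice tasks. It
// assumes iterTasks visits each task once, as the main test below also
// relies on; the sizes here are arbitrary small values.
func TestMakeFakeGraphSize(t *testing.T) {
	const numSlices, numTasksPerSlice = 8, 4
	r := rand.New(rand.NewSource(0))
	roots := makeFakeGraph(r, numSlices, numTasksPerSlice)
	var n int
	_ = iterTasks(roots, func(*Task) error {
		n++
		return nil
	})
	if got, want := n, numSlices*numTasksPerSlice; got != want {
		t.Errorf("got %v, want %v", got, want)
	}
}
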
// simulateTasks simulates tasks changing states, until ctx is done.
func simulateTasks(ctx context.Context, r *rand.Rand, tasks []*Task) {
	for {
		select {
		case <-ctx.Done():
			return
		default:
		}
		// Pick a task randomly, and update it to a random state.
		itask := r.Intn(len(tasks))
		currState := tasks[itask].State()
		nextState := TaskState(r.Intn(int(maxState)))
		// Spin until we get a new state. This is the only place that task
		// state is updated, so we don't have to worry about racing.
		for currState == nextState {
			nextState = TaskState(r.Intn(int(maxState)))
		}
		tasks[itask].Set(nextState)
	}
}

// TestMonitorSliceStatus verifies that monitorSliceStatus behaves
// reasonably with simulated tasks.
func TestMonitorSliceStatus(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping test in short mode")
	}
	const numSlices = 100
	const numTasksPerSlice = 10 * 1000
	// The number of status messages to consume for the test.
	const numStatuses = 10 * 1000 * 1000
	r := rand.New(rand.NewSource(123))
	// Set up simulation of slices and tasks.
	tasks := makeFakeGraph(r, numSlices, numTasksPerSlice)
	taskCounts := make(map[bigslice.Name]int32)
	var allTasks []*Task
	var numTasks int
	_ = iterTasks(tasks, func(t *Task) error {
		numTasks++
		allTasks = append(allTasks, t)
		for _, slice := range t.Slices {
			taskCounts[slice.Name()]++
		}
		return nil
	})
	ctx, cancel := context.WithCancel(context.Background())
	statusc := make(chan sliceStatus)
	go monitorSliceStatus(ctx, tasks, statusc)
	go simulateTasks(ctx, r, allTasks)
	// Remember the last sliceStatus we saw for a slice to count how often
	// it changes.
	lastStatuses := make(map[bigslice.Name]sliceStatus)
	// Count the number of status changes we see per slice to verify later.
	sliceChanges := make(map[bigslice.Name]int)
	for i := 0; i < numStatuses; i++ {
		s := <-statusc
		var total int32
		for _, count := range s.counts {
			total += count
		}
		if s != lastStatuses[s.sliceName] {
			sliceChanges[s.sliceName]++
		}
		lastStatuses[s.sliceName] = s
		if got, want := total < 1, false; got != want {
			// If we got a status message, there must exist at least one
			// task: the task that triggered the message.
			t.Errorf("got %v, want %v", got, want)
			continue
		}
		if got, want := numTasksPerSlice < total, false; got != want {
			// The slice status counts more tasks than were associated with
			// the slice.
			t.Errorf("got %v, want %v", got, want)
			continue
		}
		if i > (numSlices * numTasksPerSlice) {
			if got, want := numTasksPerSlice != total, false; got != want {
				// By now, we expect to have seen every task for every slice.
				t.Logf("slice counts total: %d, expected: %d", total, taskCounts[s.sliceName])
				t.Errorf("got %v, want %v", got, want)
				continue
			}
		}
	}
	cancel()
	var totalChanges int
	for _, changes := range sliceChanges {
		totalChanges += changes
		if got, want := changes < numStatuses/numSlices/2, false; got != want {
			// We expect some changes in the task state counts.
			t.Logf("num changes: %d", changes)
			t.Errorf("got %v, want %v", got, want)
		}
	}
}