github.com/grailbio/bigslice@v0.0.0-20230519005545-30c4c12152ad/exec/session_test.go (about)

     1  // Copyright 2018 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache 2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  package exec
     6  
     7  import (
     8  	"context"
     9  	"reflect"
    10  	"sort"
    11  	"sync"
    12  	"sync/atomic"
    13  	"testing"
    14  	"time"
    15  
    16  	"github.com/grailbio/base/log"
    17  	"github.com/grailbio/bigmachine/testsystem"
    18  	"github.com/grailbio/bigslice"
    19  	"github.com/grailbio/bigslice/frame"
    20  	"github.com/grailbio/bigslice/sliceio"
    21  	"github.com/grailbio/testutil/assert"
    22  	"github.com/grailbio/testutil/h"
    23  )
    24  
// init calls log.AddFlags (github.com/grailbio/base/log) when the test binary
// is initialized, before any test runs.
// NOTE(review): presumably this registers the logging command-line flags so
// they can be passed to the test binary; confirm against that package's
// documentation.
func init() {
	log.AddFlags()
}
    28  
    29  func rangeSlice(i, j int) []int {
    30  	s := make([]int, j-i)
    31  	for k := range s {
    32  		s[k] = i + k
    33  	}
    34  	return s
    35  }
    36  
    37  func TestSessionIterative(t *testing.T) {
    38  	const (
    39  		Nelem  = 1000
    40  		Nshard = 5
    41  		Niter  = 5
    42  	)
    43  	var nvalues, nadd int
    44  	values := bigslice.Func(func() bigslice.Slice {
    45  		return bigslice.ReaderFunc(Nshard, func(shard int, n *int, out []int) (int, error) {
    46  			beg, end := shardRange(Nelem, Nshard, shard)
    47  			beg += *n
    48  			t.Logf("shard %d beg %d end %d n %d", shard, beg, end, *n)
    49  			if beg >= end { // empty or done
    50  				nvalues++
    51  				return 0, sliceio.EOF
    52  			}
    53  			m := copy(out, rangeSlice(beg, end))
    54  			*n += m
    55  			return m, nil
    56  		})
    57  	})
    58  	add := bigslice.Func(func(x int, slice bigslice.Slice) bigslice.Slice {
    59  		return bigslice.Map(slice, func(i int) int {
    60  			nadd++
    61  			return i + x
    62  		})
    63  	})
    64  	var (
    65  		ctx  = context.Background()
    66  		nrun int
    67  	)
    68  	testSession(t, func(t *testing.T, sess *Session) {
    69  		nrun++
    70  		res, err := sess.Run(ctx, values)
    71  		if err != nil {
    72  			t.Fatal(err)
    73  		}
    74  		for i := 0; i < Niter; i++ {
    75  			res, err = sess.Run(ctx, add, i, res)
    76  			if err != nil {
    77  				t.Fatal(err)
    78  			}
    79  		}
    80  		var (
    81  			scan = res.Scanner()
    82  			ints []int
    83  			x    int
    84  		)
    85  		defer scan.Close()
    86  		for scan.Scan(ctx, &x) {
    87  			ints = append(ints, x)
    88  		}
    89  		if err := scan.Err(); err != nil {
    90  			t.Fatal(err)
    91  		}
    92  		if got, want := ints, rangeSlice(10, 1010); !reflect.DeepEqual(got, want) {
    93  			t.Errorf("got %v, want %v", got, want)
    94  		}
    95  	})
    96  	if got, want := nvalues, nrun*Nshard; got != want {
    97  		t.Errorf("got %v, want %v", got, want)
    98  	}
    99  	if got, want := nadd, nrun*Niter*1000; got != want {
   100  		t.Errorf("got %v, want %v", got, want)
   101  	}
   102  }
   103  
   104  func TestSessionReuse(t *testing.T) {
   105  	const N = 1000
   106  	input := bigslice.Func(func() bigslice.Slice {
   107  		return bigslice.Const(5, rangeSlice(0, 1000))
   108  	})
   109  	var nmap int64
   110  	mapper := bigslice.Func(func(slice bigslice.Slice) bigslice.Slice {
   111  		return bigslice.Map(slice, func(i int) (int, int, int) {
   112  			atomic.AddInt64(&nmap, 1)
   113  			return i, i, i
   114  		})
   115  	})
   116  	reducer := bigslice.Func(func(slice bigslice.Slice) bigslice.Slice {
   117  		slice = bigslice.Map(slice, func(x, y, z int) (int, int, int) { return 0, y / 2, z })
   118  		slice = bigslice.Prefixed(slice, 2)
   119  		slice = bigslice.Reduce(slice, func(a, e int) int { return a + e })
   120  		slice = bigslice.Map(slice, func(k1, k2, v int) (int, int) { return k2, v })
   121  		return slice
   122  	})
   123  	unmap := bigslice.Func(func(slice bigslice.Slice) bigslice.Slice {
   124  		return bigslice.Map(slice, func(x, y, z int) (int, int) { return x, y + z })
   125  	})
   126  	ctx := context.Background()
   127  	testSession(t, func(t *testing.T, sess *Session) {
   128  		atomic.StoreInt64(&nmap, 0)
   129  		input := sess.Must(ctx, input)
   130  		mapped := sess.Must(ctx, mapper, input)
   131  		var wg sync.WaitGroup
   132  		var reduced *Result
   133  		wg.Add(1)
   134  		go func() {
   135  			reduced = sess.Must(ctx, reducer, mapped)
   136  			wg.Done()
   137  		}()
   138  		unmapped := sess.Must(ctx, unmap, mapped)
   139  		wg.Wait()
   140  		// The map results were reused:
   141  		if got, want := atomic.LoadInt64(&nmap), int64(N); got != want {
   142  			t.Errorf("got %v, want %v", got, want)
   143  		}
   144  		// And we computed the correct results:
   145  		var (
   146  			f = readFrame(t, reduced, N/2)
   147  			k = f.Interface(0).([]int)
   148  			v = f.Interface(1).([]int)
   149  		)
   150  		for i := range k {
   151  			if got, want := v[i], k[i]*4+1; got != want {
   152  				t.Errorf("index %d: got %v, want %v", i, got, want)
   153  			}
   154  		}
   155  
   156  		f = readFrame(t, unmapped, N)
   157  		k = f.Interface(0).([]int)
   158  		v = f.Interface(1).([]int)
   159  		for i := range k {
   160  			if got, want := v[i], k[i]*2; got != want {
   161  				t.Errorf("index %d: got %v, want %v", i, got, want)
   162  			}
   163  		}
   164  	})
   165  }
   166  
   167  // TestSessionFuncPanic verifies that the session survives a Func that panics
   168  // on invocation.
   169  func TestSessionFuncPanic(t *testing.T) {
   170  	panicker := bigslice.Func(func() bigslice.Slice {
   171  		panic("panic")
   172  	})
   173  	nonPanicker := bigslice.Func(func() bigslice.Slice {
   174  		return bigslice.Const(1, []int{})
   175  	})
   176  	ctx := context.Background()
   177  	testSession(t, func(t *testing.T, sess *Session) {
   178  		assert.That(t, func() { _, _ = sess.Run(ctx, panicker) }, h.Panics(h.NotNil()))
   179  		_, err := sess.Run(ctx, nonPanicker)
   180  		if err != nil {
   181  			t.Errorf("session did not survive panic")
   182  		}
   183  	})
   184  }
   185  
   186  // TestScanFaultTolerance verifies that result scanning is tolerant to machine
   187  // failure.
   188  func TestScanFaultTolerance(t *testing.T) {
   189  	if testing.Short() {
   190  		t.Skip("skipping test in short mode.")
   191  	}
   192  	const Nshard = 100
   193  	const N = Nshard * 10 * 1000
   194  	const Kills = 5
   195  	const KillInterval = N / (Kills + 1)
   196  	f := bigslice.Func(func() bigslice.Slice {
   197  		vs := make([]int, N)
   198  		for i := range vs {
   199  			vs[i] = i
   200  		}
   201  		return bigslice.Const(Nshard, vs)
   202  	})
   203  	sys := testsystem.New()
   204  	sys.Machineprocs = 3
   205  	// Use short periods/timeouts so that this test runs in reasonable time.
   206  	sys.KeepalivePeriod = 1 * time.Second
   207  	sys.KeepaliveTimeout = 1 * time.Second
   208  	sys.KeepaliveRpcTimeout = 1 * time.Second
   209  	var (
   210  		sess = Start(Bigmachine(sys), Parallelism(10))
   211  		ctx  = context.Background()
   212  	)
   213  	result, err := sess.Run(ctx, f)
   214  	if err != nil {
   215  		t.Fatalf("run failed")
   216  	}
   217  	scanner := result.Scanner()
   218  	var (
   219  		v  int
   220  		vs []int
   221  		i  int
   222  	)
   223  	for scanner.Scan(ctx, &v) {
   224  		vs = append(vs, v)
   225  		i++
   226  		if i%KillInterval == KillInterval-1 {
   227  			log.Printf("killing random machine")
   228  			sys.Kill(nil)
   229  		}
   230  	}
   231  	if err = scanner.Err(); err != nil {
   232  		t.Fatalf("scanner error:%v", err)
   233  	}
   234  	if got, want := len(vs), N; got != want {
   235  		t.Fatalf("got %v, want %v", got, want)
   236  	}
   237  	sort.Ints(vs)
   238  	for i := range vs {
   239  		if got, want := vs[i], i; got != want {
   240  			t.Fatalf("got %v, want %v", got, want)
   241  		}
   242  	}
   243  	if err = scanner.Err(); err != nil {
   244  		t.Fatalf("scanner error:%v", err)
   245  	}
   246  }
   247  
   248  // TestDiscard verifies that discarding a Result leaves its tasks TaskLost.
   249  func TestDiscard(t *testing.T) {
   250  	const Nshard = 10
   251  	const N = Nshard * 100
   252  	f := bigslice.Func(func() bigslice.Slice {
   253  		vs := make([]int, N)
   254  		for i := range vs {
   255  			vs[i] = i
   256  		}
   257  		// We set up a computation with a Reduce to:
   258  		// - break the pipeline so all tasks materialize some results.
   259  		// - have a non-tree task graph to verify that traversal works
   260  		//   correctly.
   261  		slice := bigslice.Const(Nshard, vs, vs)
   262  		slice = bigslice.Reduce(slice, func(int, int) int { return 0 })
   263  		return slice
   264  	})
   265  	testSession(t, func(t *testing.T, sess *Session) {
   266  		ctx := context.Background()
   267  		result, err := sess.Run(ctx, f)
   268  		if err != nil {
   269  			t.Fatal(err)
   270  		}
   271  		result.Discard(ctx)
   272  		_ = iterTasks(result.tasks, func(task *Task) error {
   273  			if got, want := task.State(), TaskLost; got != want {
   274  				t.Errorf("got %v, want %v", got, want)
   275  			}
   276  			return nil
   277  		})
   278  	})
   279  }
   280  
   281  var executors = map[string]Option{
   282  	"Local":           Local,
   283  	"Bigmachine.Test": Bigmachine(testsystem.New()),
   284  }
   285  
   286  func testSession(t *testing.T, run func(t *testing.T, sess *Session)) {
   287  	t.Helper()
   288  	for name, opt := range executors {
   289  		t.Run(name, func(t *testing.T) {
   290  			sess := Start(opt)
   291  			run(t, sess)
   292  		})
   293  	}
   294  }
   295  
   296  // shardRange gives the range covered by a shard.
   297  func shardRange(nelem, nshard, shard int) (beg, end int) {
   298  	elemsPerShard := (nelem + nshard - 1) / nshard
   299  	beg = elemsPerShard * shard
   300  	if beg >= nelem {
   301  		beg = 0
   302  		return
   303  	}
   304  	end = beg + elemsPerShard
   305  	if end > nelem {
   306  		end = nelem
   307  	}
   308  	return
   309  }
   310  
   311  func readFrame(t *testing.T, res *Result, n int) frame.Frame {
   312  	t.Helper()
   313  	f := frame.Make(res, n+1, n+1)
   314  	ctx := context.Background()
   315  	reader := res.open()
   316  	defer reader.Close()
   317  	m, err := sliceio.ReadFull(ctx, reader, f)
   318  	if err != sliceio.EOF {
   319  		t.Fatal(err)
   320  	}
   321  	if got, want := m, n; got != want {
   322  		t.Fatalf("got %v, want %v", got, want)
   323  	}
   324  	return f.Slice(0, n)
   325  }