github.com/grailbio/bigslice@v0.0.0-20230519005545-30c4c12152ad/exec/session_test.go (about) 1 // Copyright 2018 GRAIL, Inc. All rights reserved. 2 // Use of this source code is governed by the Apache 2.0 3 // license that can be found in the LICENSE file. 4 5 package exec 6 7 import ( 8 "context" 9 "reflect" 10 "sort" 11 "sync" 12 "sync/atomic" 13 "testing" 14 "time" 15 16 "github.com/grailbio/base/log" 17 "github.com/grailbio/bigmachine/testsystem" 18 "github.com/grailbio/bigslice" 19 "github.com/grailbio/bigslice/frame" 20 "github.com/grailbio/bigslice/sliceio" 21 "github.com/grailbio/testutil/assert" 22 "github.com/grailbio/testutil/h" 23 ) 24 25 func init() { 26 log.AddFlags() 27 } 28 29 func rangeSlice(i, j int) []int { 30 s := make([]int, j-i) 31 for k := range s { 32 s[k] = i + k 33 } 34 return s 35 } 36 37 func TestSessionIterative(t *testing.T) { 38 const ( 39 Nelem = 1000 40 Nshard = 5 41 Niter = 5 42 ) 43 var nvalues, nadd int 44 values := bigslice.Func(func() bigslice.Slice { 45 return bigslice.ReaderFunc(Nshard, func(shard int, n *int, out []int) (int, error) { 46 beg, end := shardRange(Nelem, Nshard, shard) 47 beg += *n 48 t.Logf("shard %d beg %d end %d n %d", shard, beg, end, *n) 49 if beg >= end { // empty or done 50 nvalues++ 51 return 0, sliceio.EOF 52 } 53 m := copy(out, rangeSlice(beg, end)) 54 *n += m 55 return m, nil 56 }) 57 }) 58 add := bigslice.Func(func(x int, slice bigslice.Slice) bigslice.Slice { 59 return bigslice.Map(slice, func(i int) int { 60 nadd++ 61 return i + x 62 }) 63 }) 64 var ( 65 ctx = context.Background() 66 nrun int 67 ) 68 testSession(t, func(t *testing.T, sess *Session) { 69 nrun++ 70 res, err := sess.Run(ctx, values) 71 if err != nil { 72 t.Fatal(err) 73 } 74 for i := 0; i < Niter; i++ { 75 res, err = sess.Run(ctx, add, i, res) 76 if err != nil { 77 t.Fatal(err) 78 } 79 } 80 var ( 81 scan = res.Scanner() 82 ints []int 83 x int 84 ) 85 defer scan.Close() 86 for scan.Scan(ctx, &x) { 87 ints = append(ints, x) 88 } 89 if err := scan.Err(); err != nil { 90 t.Fatal(err) 91 } 92 if got, want := ints, rangeSlice(10, 1010); !reflect.DeepEqual(got, want) { 93 t.Errorf("got %v, want %v", got, want) 94 } 95 }) 96 if got, want := nvalues, nrun*Nshard; got != want { 97 t.Errorf("got %v, want %v", got, want) 98 } 99 if got, want := nadd, nrun*Niter*1000; got != want { 100 t.Errorf("got %v, want %v", got, want) 101 } 102 } 103 104 func TestSessionReuse(t *testing.T) { 105 const N = 1000 106 input := bigslice.Func(func() bigslice.Slice { 107 return bigslice.Const(5, rangeSlice(0, 1000)) 108 }) 109 var nmap int64 110 mapper := bigslice.Func(func(slice bigslice.Slice) bigslice.Slice { 111 return bigslice.Map(slice, func(i int) (int, int, int) { 112 atomic.AddInt64(&nmap, 1) 113 return i, i, i 114 }) 115 }) 116 reducer := bigslice.Func(func(slice bigslice.Slice) bigslice.Slice { 117 slice = bigslice.Map(slice, func(x, y, z int) (int, int, int) { return 0, y / 2, z }) 118 slice = bigslice.Prefixed(slice, 2) 119 slice = bigslice.Reduce(slice, func(a, e int) int { return a + e }) 120 slice = bigslice.Map(slice, func(k1, k2, v int) (int, int) { return k2, v }) 121 return slice 122 }) 123 unmap := bigslice.Func(func(slice bigslice.Slice) bigslice.Slice { 124 return bigslice.Map(slice, func(x, y, z int) (int, int) { return x, y + z }) 125 }) 126 ctx := context.Background() 127 testSession(t, func(t *testing.T, sess *Session) { 128 atomic.StoreInt64(&nmap, 0) 129 input := sess.Must(ctx, input) 130 mapped := sess.Must(ctx, mapper, input) 131 var wg sync.WaitGroup 132 var reduced *Result 133 wg.Add(1) 134 go func() { 135 reduced = sess.Must(ctx, reducer, mapped) 136 wg.Done() 137 }() 138 unmapped := sess.Must(ctx, unmap, mapped) 139 wg.Wait() 140 // The map results were reused: 141 if got, want := atomic.LoadInt64(&nmap), int64(N); got != want { 142 t.Errorf("got %v, want %v", got, want) 143 } 144 // And we computed the correct results: 145 var ( 146 f = readFrame(t, reduced, N/2) 147 k = f.Interface(0).([]int) 148 v = f.Interface(1).([]int) 149 ) 150 for i := range k { 151 if got, want := v[i], k[i]*4+1; got != want { 152 t.Errorf("index %d: got %v, want %v", i, got, want) 153 } 154 } 155 156 f = readFrame(t, unmapped, N) 157 k = f.Interface(0).([]int) 158 v = f.Interface(1).([]int) 159 for i := range k { 160 if got, want := v[i], k[i]*2; got != want { 161 t.Errorf("index %d: got %v, want %v", i, got, want) 162 } 163 } 164 }) 165 } 166 167 // TestSessionFuncPanic verifies that the session survives a Func that panics 168 // on invocation. 169 func TestSessionFuncPanic(t *testing.T) { 170 panicker := bigslice.Func(func() bigslice.Slice { 171 panic("panic") 172 }) 173 nonPanicker := bigslice.Func(func() bigslice.Slice { 174 return bigslice.Const(1, []int{}) 175 }) 176 ctx := context.Background() 177 testSession(t, func(t *testing.T, sess *Session) { 178 assert.That(t, func() { _, _ = sess.Run(ctx, panicker) }, h.Panics(h.NotNil())) 179 _, err := sess.Run(ctx, nonPanicker) 180 if err != nil { 181 t.Errorf("session did not survive panic") 182 } 183 }) 184 } 185 186 // TestScanFaultTolerance verifies that result scanning is tolerant to machine 187 // failure. 188 func TestScanFaultTolerance(t *testing.T) { 189 if testing.Short() { 190 t.Skip("skipping test in short mode.") 191 } 192 const Nshard = 100 193 const N = Nshard * 10 * 1000 194 const Kills = 5 195 const KillInterval = N / (Kills + 1) 196 f := bigslice.Func(func() bigslice.Slice { 197 vs := make([]int, N) 198 for i := range vs { 199 vs[i] = i 200 } 201 return bigslice.Const(Nshard, vs) 202 }) 203 sys := testsystem.New() 204 sys.Machineprocs = 3 205 // Use short periods/timeouts so that this test runs in reasonable time. 206 sys.KeepalivePeriod = 1 * time.Second 207 sys.KeepaliveTimeout = 1 * time.Second 208 sys.KeepaliveRpcTimeout = 1 * time.Second 209 var ( 210 sess = Start(Bigmachine(sys), Parallelism(10)) 211 ctx = context.Background() 212 ) 213 result, err := sess.Run(ctx, f) 214 if err != nil { 215 t.Fatalf("run failed") 216 } 217 scanner := result.Scanner() 218 var ( 219 v int 220 vs []int 221 i int 222 ) 223 for scanner.Scan(ctx, &v) { 224 vs = append(vs, v) 225 i++ 226 if i%KillInterval == KillInterval-1 { 227 log.Printf("killing random machine") 228 sys.Kill(nil) 229 } 230 } 231 if err = scanner.Err(); err != nil { 232 t.Fatalf("scanner error:%v", err) 233 } 234 if got, want := len(vs), N; got != want { 235 t.Fatalf("got %v, want %v", got, want) 236 } 237 sort.Ints(vs) 238 for i := range vs { 239 if got, want := vs[i], i; got != want { 240 t.Fatalf("got %v, want %v", got, want) 241 } 242 } 243 if err = scanner.Err(); err != nil { 244 t.Fatalf("scanner error:%v", err) 245 } 246 } 247 248 // TestDiscard verifies that discarding a Result leaves its tasks TaskLost. 249 func TestDiscard(t *testing.T) { 250 const Nshard = 10 251 const N = Nshard * 100 252 f := bigslice.Func(func() bigslice.Slice { 253 vs := make([]int, N) 254 for i := range vs { 255 vs[i] = i 256 } 257 // We set up a computation with a Reduce to: 258 // - break the pipeline so all tasks materialize some results. 259 // - have a non-tree task graph to verify that traversal works 260 // correctly. 261 slice := bigslice.Const(Nshard, vs, vs) 262 slice = bigslice.Reduce(slice, func(int, int) int { return 0 }) 263 return slice 264 }) 265 testSession(t, func(t *testing.T, sess *Session) { 266 ctx := context.Background() 267 result, err := sess.Run(ctx, f) 268 if err != nil { 269 t.Fatal(err) 270 } 271 result.Discard(ctx) 272 _ = iterTasks(result.tasks, func(task *Task) error { 273 if got, want := task.State(), TaskLost; got != want { 274 t.Errorf("got %v, want %v", got, want) 275 } 276 return nil 277 }) 278 }) 279 } 280 281 var executors = map[string]Option{ 282 "Local": Local, 283 "Bigmachine.Test": Bigmachine(testsystem.New()), 284 } 285 286 func testSession(t *testing.T, run func(t *testing.T, sess *Session)) { 287 t.Helper() 288 for name, opt := range executors { 289 t.Run(name, func(t *testing.T) { 290 sess := Start(opt) 291 run(t, sess) 292 }) 293 } 294 } 295 296 // shardRange gives the range covered by a shard. 297 func shardRange(nelem, nshard, shard int) (beg, end int) { 298 elemsPerShard := (nelem + nshard - 1) / nshard 299 beg = elemsPerShard * shard 300 if beg >= nelem { 301 beg = 0 302 return 303 } 304 end = beg + elemsPerShard 305 if end > nelem { 306 end = nelem 307 } 308 return 309 } 310 311 func readFrame(t *testing.T, res *Result, n int) frame.Frame { 312 t.Helper() 313 f := frame.Make(res, n+1, n+1) 314 ctx := context.Background() 315 reader := res.open() 316 defer reader.Close() 317 m, err := sliceio.ReadFull(ctx, reader, f) 318 if err != sliceio.EOF { 319 t.Fatal(err) 320 } 321 if got, want := m, n; got != want { 322 t.Fatalf("got %v, want %v", got, want) 323 } 324 return f.Slice(0, n) 325 }