github.com/grailbio/bigslice@v0.0.0-20230519005545-30c4c12152ad/exec/compile_test.go (about) 1 // Copyright 2019 GRAIL, Inc. All rights reserved. 2 // Use of this source code is governed by the Apache 2.0 3 // license that can be found in the LICENSE file. 4 5 package exec 6 7 import ( 8 "bytes" 9 "context" 10 "fmt" 11 "io/ioutil" 12 "sort" 13 "strings" 14 "testing" 15 16 "github.com/grailbio/bigslice" 17 "github.com/grailbio/bigslice/frame" 18 "github.com/grailbio/bigslice/internal/slicecache" 19 "github.com/grailbio/bigslice/slicefunc" 20 "github.com/grailbio/bigslice/sliceio" 21 ) 22 23 func TestCompile(t *testing.T) { 24 for _, c := range []struct { 25 name string 26 f func() bigslice.Slice 27 }{ 28 { 29 "trivial", 30 func() (slice bigslice.Slice) { 31 slice = bigslice.Const(3, []int{}) 32 return 33 }, 34 }, 35 { 36 "shuffle", 37 func() (slice bigslice.Slice) { 38 slice = bigslice.Const(3, []int{}, []float64{}) 39 slice = bigslice.Reduce(slice, func(v0, v1 float64) float64 { return v0 + v1 }) 40 return 41 }, 42 }, 43 { 44 // Branch where both branches pipeline with the subsequent maps. 45 "branch", 46 func() (slice bigslice.Slice) { 47 slice = bigslice.Const(3, []int{}) 48 slice = bigslice.Map(slice, func(i int) int { return i }) 49 slice0 := bigslice.Map(slice, func(i int) int { return i }) 50 slice1 := bigslice.Map(slice, func(i int) int { return i }) 51 slice = bigslice.Cogroup(slice0, slice1) 52 return 53 }, 54 }, 55 { 56 // Branch from a materialized slice, so the subsequent maps are not 57 // pipelined through the materialized tasks. 58 "branch-materialize", 59 func() (slice bigslice.Slice) { 60 slice = bigslice.Const(3, []int{}) 61 slice = bigslice.Map(slice, func(i int) int { return i }, bigslice.ExperimentalMaterialize) 62 slice0 := bigslice.Map(slice, func(i int) int { return i }) 63 slice1 := bigslice.Map(slice, func(i int) int { return i }) 64 slice = bigslice.Cogroup(slice0, slice1) 65 return 66 }, 67 }, 68 { 69 // Branch the const slice with a reduce, which introduces its own 70 // shuffle/combiner, so the const slice tasks cannot be reused. 71 "branch-shuffle", 72 func() (slice bigslice.Slice) { 73 slice = bigslice.Const(3, []int{}, []float64{}) 74 slice0 := bigslice.Reduce(slice, func(v0, v1 float64) float64 { return v0 + v1 }) 75 slice = bigslice.Cogroup(slice, slice0) 76 return 77 }, 78 }, 79 { 80 // Branch where each branch demands the same partition number from 81 // the branch point slice. In this case, the branch point tasks can 82 // be reused. 83 "branch-same-partitions", 84 func() (slice bigslice.Slice) { 85 slice = bigslice.Const(3, []int{}) 86 slice = bigslice.Map(slice, func(i int) int { return i }) 87 slice0 := bigslice.Reshard(slice, 2) 88 slice1 := bigslice.Reshard(slice, 2) 89 slice = bigslice.Cogroup(slice0, slice1) 90 return 91 }, 92 }, 93 { 94 // Branch where each branch demands different partition numbers from 95 // the branch point slice. In this case, the branch point tasks 96 // cannot be reused. 97 "branch-different-partitions", 98 func() (slice bigslice.Slice) { 99 slice = bigslice.Const(3, []int{}) 100 slice = bigslice.Map(slice, func(i int) int { return i }) 101 slice0 := bigslice.Reshard(slice, 1) 102 slice1 := bigslice.Reshard(slice, 2) 103 slice = bigslice.Cogroup(slice0, slice1) 104 return 105 }, 106 }, 107 } { 108 t.Run(c.name, func(t *testing.T) { 109 f := bigslice.Func(c.f) 110 inv := makeExecInvocation(f.Invocation("<unknown>")) 111 inv.Index = 1 112 slice := inv.Invoke() 113 tasks, err := compile(inv, slice, false) 114 if err != nil { 115 t.Fatalf("compilation failed") 116 } 117 _ = iterTasks(tasks, func(task *Task) error { 118 if task.Pragma == nil { 119 t.Errorf("%v has nil task.Pragma", task) 120 } 121 return nil 122 }) 123 g := makeGraph(tasks) 124 want, err := ioutil.ReadFile("testdata/" + c.name + ".graph") 125 if err != nil { 126 t.Fatalf("error reading graph: %v", err) 127 } 128 d := lineDiff(g.String(), string(want)) 129 if d != "" { 130 t.Errorf("differs from %s.graph:\n%s", c.name, d) 131 } 132 }) 133 } 134 } 135 136 // TestCompileEnv verifies that the compileEnv is used and behaves properly, 137 // specifically verifying that compilation correctly writes to writable 138 // environments and reads from non-writable environments. 139 func TestCompileEnv(t *testing.T) { 140 const Nshard = 8 141 142 // cachedShards is set up just before we invoke the Func. It represents the 143 // fake cache state from the perspective of that invocation. 144 var cachedShards []int 145 f := bigslice.Func(func() bigslice.Slice { 146 slice := bigslice.Const(Nshard, []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}) 147 // Break the pipeline, as we use this to detect for which compiled tasks 148 // compilation considered the cache valid. If the cache is valid, the 149 // compiled root task will have no dependencies. 150 slice = bigslice.Reshuffle(slice) 151 slice = fakeCache(slice, cachedShards) 152 return slice 153 }) 154 inv := makeExecInvocation(f.Invocation("<unknown>")) 155 inv.Index = 0 156 157 cachedShardsFrozen := []int{1, 4, 5} 158 cachedShards = cachedShardsFrozen 159 slice0 := inv.Invoke() 160 tasks, err := compile(inv, slice0, false) 161 if err != nil { 162 t.Fatalf("compilation failed") 163 } 164 for _, task := range tasks { 165 var cached bool 166 for _, shard := range cachedShardsFrozen { 167 if shard == task.Name.Shard { 168 cached = true 169 } 170 } 171 // Verify that the resulting tasks reflect the cache state. 172 if got, want := len(task.Deps) == 0, cached; got != want { 173 t.Errorf("got %v, want %v", got, want) 174 } 175 } 176 177 // Freeze the environment, and verify that compilation uses the environment 178 // and not the current cache state. 179 inv.Env.Freeze() 180 cachedShards = []int{2, 4, 7} // different cache state from above. 181 slice1 := inv.Invoke() 182 tasks, err = compile(inv, slice1, false) 183 if err != nil { 184 t.Fatalf("compilation failed") 185 } 186 for _, task := range tasks { 187 var cached bool 188 for _, shard := range cachedShardsFrozen { 189 if shard == task.Name.Shard { 190 cached = true 191 } 192 } 193 // Verify that the tasks are compiled according to the environment that 194 // reflects cachedShardsFrozen, and not the current cache state in 195 // cachedShards. 196 if got, want := len(task.Deps) == 0, cached; got != want { 197 t.Errorf("got %v, want %v", got, want) 198 } 199 } 200 } 201 202 // TestPipelinedCache verifies that cacheable slices that are pipelined for 203 // execution behave as we expect. 204 func TestPipelinedCache(t *testing.T) { 205 const Nshard = 8 206 f := bigslice.Func(func() bigslice.Slice { 207 slice := bigslice.Const(Nshard, []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}) 208 // Break the pipeline, as we use this to detect for which compiled tasks 209 // compilation considered the cache valid. If the cache is valid, the 210 // compiled root task will have no dependencies. 211 slice = bigslice.Reshuffle(slice) 212 id := func(i int) int { return i } 213 // These slices will be pipelined. We set it up with different shards 214 // cached in different slices, with some shards not cached at all. 215 // When we examine the resulting dependencies of the (pipelined) task, 216 // we should only see dependencies for shards without any cache, as 217 // only those need the upstream results. 218 slice = fakeCache(bigslice.Map(slice, id), []int{0, 2}) 219 slice = bigslice.Map(slice, id) 220 slice = fakeCache(bigslice.Map(slice, id), []int{5, 7}) 221 slice = bigslice.Map(slice, id) 222 return slice 223 }) 224 inv := makeExecInvocation(f.Invocation("<unknown>")) 225 inv.Index = 0 226 slice0 := inv.Invoke() 227 tasks, err := compile(inv, slice0, false) 228 if err != nil { 229 t.Fatalf("compilation failed") 230 } 231 // These are all the shards that we expect to be computable without 232 // dependencies, as some part of the (pipelined) computation is cached. 233 // This is the union of the shards cached in our fakeCache slices. 234 noDeps := []int{0, 2, 5, 7} 235 for _, task := range tasks { 236 var inNoDeps bool 237 for _, shard := range noDeps { 238 if shard == task.Name.Shard { 239 inNoDeps = true 240 } 241 } 242 // Verify that the resulting tasks reflect the cache state. 243 if got, want := len(task.Deps) == 0, inNoDeps; got != want { 244 t.Errorf("got %v, want %v", got, want) 245 } 246 // Invoke Do to verify that we can construct our pipelined computation. 247 // There have been bugs for which this call would panic. Note that 248 // this is somewhat fragile, as we assume that Do does not access the 249 // input readers, instead only composing readers to represent the 250 // pipeline. 251 task.Do([]sliceio.Reader{sliceio.EmptyReader{}}) 252 } 253 } 254 255 // makeGraph returns a graph representation of the task graph roots that is 256 // convenient for printing and comparing. We use this to verify (and debug) 257 // compilation results. 258 func makeGraph(roots []*Task) graph { 259 var ( 260 visited = make(map[*Task]bool) 261 g graph 262 walk func(tasks []*Task) 263 ) 264 walk = func(tasks []*Task) { 265 if len(tasks) == 0 { 266 return 267 } 268 for _, t := range tasks { 269 if visited[t] { 270 continue 271 } 272 visited[t] = true 273 g.nodes = append(g.nodes, t.Name.String()) 274 for _, d := range t.Deps { 275 for i := 0; i < d.NumTask(); i++ { 276 edge := edge{t.Name.String(), d.Task(i).Name.String()} 277 g.edges = append(g.edges, edge) 278 walk([]*Task{d.Task(i)}) 279 } 280 } 281 282 } 283 } 284 walk(roots) 285 g.Sort() 286 return g 287 } 288 289 type edge struct { 290 src string 291 dst string 292 } 293 294 type graph struct { 295 nodes []string 296 edges []edge 297 } 298 299 func (g graph) Sort() { 300 sort.Strings(g.nodes) 301 sort.Slice(g.edges, func(i, j int) bool { 302 if g.edges[i].src != g.edges[j].src { 303 return g.edges[i].src < g.edges[j].src 304 } 305 return g.edges[i].dst < g.edges[j].dst 306 }) 307 } 308 309 func (g graph) String() string { 310 var b bytes.Buffer 311 for _, n := range g.nodes { 312 fmt.Fprintf(&b, "%s\n", n) 313 } 314 for _, e := range g.edges { 315 fmt.Fprintf(&b, "%s -> %s\n", e.src, e.dst) 316 } 317 return b.String() 318 } 319 320 func lineDiff(lhs, rhs string) string { 321 lhsLines := strings.Split(lhs, "\n") 322 rhsLines := strings.Split(rhs, "\n") 323 324 // This is a vanilla Levenshtein distance implementation. 325 const ( 326 editNone = iota 327 editAdd 328 editDel 329 editRep 330 ) 331 type cell struct { 332 edit int 333 cost int 334 } 335 cells := make([][]cell, len(lhsLines)+1) 336 for i := range cells { 337 cells[i] = make([]cell, len(rhsLines)+1) 338 } 339 for i := 1; i < len(lhsLines)+1; i++ { 340 cells[i][0].edit = editDel 341 cells[i][0].cost = i 342 } 343 for j := 1; j < len(rhsLines)+1; j++ { 344 cells[0][j].edit = editAdd 345 cells[0][j].cost = j 346 } 347 for i := 1; i < len(lhsLines)+1; i++ { 348 for j := 1; j < len(rhsLines)+1; j++ { 349 if lhsLines[i-1] == rhsLines[j-1] { 350 cells[i][j].cost = cells[i-1][j-1].cost 351 continue 352 } 353 repCost := cells[i-1][j-1].cost + 1 354 minCost := repCost 355 delCost := cells[i-1][j].cost + 1 356 if delCost < minCost { 357 minCost = delCost 358 } 359 addCost := cells[i][j-1].cost + 1 360 if addCost < minCost { 361 minCost = addCost 362 } 363 cells[i][j].cost = minCost 364 switch minCost { 365 case repCost: 366 cells[i][j].edit = editRep 367 case addCost: 368 cells[i][j].edit = editAdd 369 case delCost: 370 cells[i][j].edit = editDel 371 } 372 } 373 } 374 var ( 375 d []string 376 differ bool 377 ) 378 for i, j := len(lhsLines), len(rhsLines); i > 0 || j > 0; { 379 switch cells[i][j].edit { 380 case editNone: 381 d = append(d, lhsLines[i-1]) 382 i-- 383 j-- 384 case editAdd: 385 d = append(d, "+ "+rhsLines[j-1]) 386 j-- 387 differ = true 388 case editDel: 389 d = append(d, "- "+lhsLines[i-1]) 390 i-- 391 differ = true 392 case editRep: 393 d = append(d, "+ "+rhsLines[j-1]) 394 d = append(d, "- "+lhsLines[i-1]) 395 i-- 396 j-- 397 differ = true 398 } 399 } 400 if !differ { 401 return "" 402 } 403 for i := len(d)/2 - 1; i >= 0; i-- { 404 opp := len(d) - 1 - i 405 d[i], d[opp] = d[opp], d[i] 406 } 407 var b bytes.Buffer 408 for _, dLine := range d { 409 b.WriteString(dLine + "\n") 410 } 411 return b.String() 412 } 413 414 type fakeShardCache struct { 415 cachedSet map[int]bool 416 } 417 418 func (c fakeShardCache) IsCached(shard int) bool { return c.cachedSet[shard] } 419 func (fakeShardCache) WritethroughReader(shard int, reader sliceio.Reader) sliceio.Reader { 420 return reader 421 } 422 func (fakeShardCache) CacheReader(shard int) sliceio.Reader { 423 return emptyReader{} 424 } 425 426 type emptyReader struct{} 427 428 func (emptyReader) Read(ctx context.Context, frame frame.Frame) (int, error) { 429 return 0, sliceio.EOF 430 } 431 432 type fakeCacheSlice struct { 433 name bigslice.Name 434 bigslice.Slice 435 cache slicecache.ShardCache 436 } 437 438 func (c *fakeCacheSlice) Name() bigslice.Name { return c.name } 439 func (c *fakeCacheSlice) NumDep() int { return 1 } 440 func (c *fakeCacheSlice) Dep(i int) bigslice.Dep { 441 return bigslice.Dep{ 442 Slice: c.Slice, 443 Shuffle: false, 444 Partitioner: nil, 445 Expand: false, 446 } 447 } 448 func (*fakeCacheSlice) Combiner() slicefunc.Func { return slicefunc.Nil } 449 func (c *fakeCacheSlice) Reader(shard int, deps []sliceio.Reader) sliceio.Reader { return deps[0] } 450 func (c *fakeCacheSlice) Cache() slicecache.ShardCache { return c.cache } 451 452 func fakeCache(slice bigslice.Slice, cachedShards []int) bigslice.Slice { 453 cachedSet := make(map[int]bool) 454 for _, shard := range cachedShards { 455 cachedSet[shard] = true 456 } 457 return &fakeCacheSlice{bigslice.MakeName("testcache"), slice, fakeShardCache{cachedSet}} 458 }