github.com/grailbio/bigslice@v0.0.0-20230519005545-30c4c12152ad/exec/task.go (about) 1 // Copyright 2018 GRAIL, Inc. All rights reserved. 2 // Use of this source code is governed by the Apache 2.0 3 // license that can be found in the LICENSE file. 4 5 package exec 6 7 import ( 8 "bytes" 9 "context" 10 "errors" 11 "fmt" 12 "io" 13 "sort" 14 "strings" 15 "sync" 16 "text/tabwriter" 17 18 "github.com/grailbio/base/status" 19 "github.com/grailbio/base/sync/ctxsync" 20 "github.com/grailbio/bigslice" 21 "github.com/grailbio/bigslice/metrics" 22 "github.com/grailbio/bigslice/slicefunc" 23 "github.com/grailbio/bigslice/sliceio" 24 "github.com/grailbio/bigslice/slicetype" 25 ) 26 27 func init() { 28 close(closedc) 29 } 30 31 // closedc is closed in init which can be used any time we just want a closed 32 // channel (i.e. a channel that is always ready and receives a zero value). 33 var closedc = make(chan struct{}) 34 35 // ErrTaskLost indicates that a Task was in TaskLost state. 36 var ErrTaskLost = errors.New("task was lost") 37 38 // TaskState represents the runtime state of a Task. TaskState 39 // values are defined so that their magnitudes correspond with 40 // task progression. 41 type TaskState int 42 43 const ( 44 // TaskInit is the initial state of a task. Tasks in state TaskInit 45 // have usually not yet been seen by an executor. 46 TaskInit TaskState = iota 47 48 // TaskWaiting indicates that a task has been scheduled for 49 // execution (it is runnable) but has not yet been allocated 50 // resources by the executor. 51 TaskWaiting 52 // TaskRunning is the state of a task that's currently being run or 53 // discarded. After a task is in state TaskRunning, it can only enter a 54 // larger-valued state. 55 TaskRunning 56 57 // TaskOk indicates that a task has successfully completed; 58 // the task's results are available to dependent tasks. 59 // 60 // All TaskState values greater than TaskOk indicate task 61 // errors. 62 TaskOk 63 64 // TaskErr indicates that the task experienced a failure while 65 // running. 66 TaskErr 67 // TaskLost indicates that the task was lost, usually because 68 // the machine to which the task was assigned failed. 69 TaskLost 70 71 maxState 72 ) 73 74 var states = [...]string{ 75 TaskInit: "INIT", 76 TaskWaiting: "WAITING", 77 TaskRunning: "RUNNING", 78 TaskOk: "OK", 79 TaskErr: "ERROR", 80 TaskLost: "LOST", 81 } 82 83 // String returns the task's state as an upper-case string. 84 func (s TaskState) String() string { 85 return states[s] 86 } 87 88 // A TaskDep describes a single dependency for a task. A dependency 89 // comprises one or more tasks and the partition number of the task 90 // set that must be read at run time. 91 type TaskDep struct { 92 // Head holds the underlying task that represents this dependency. 93 // For shuffle dependencies, that task is the head task of the 94 // phase, and the evaluator must expand the phase. 95 Head *Task 96 Partition int 97 98 // Expand indicates that the task's dependencies for a given 99 // partition should not be merged, but rather passed individually to 100 // the task implementation. 101 Expand bool 102 103 // CombineKey is an optional label that names the combination key to 104 // be used by this dependency. It is used to name a single combiner 105 // buffer from which is read a number of combined tasks. 106 // 107 // CombineKeys must be provided to tasks that contain combiners. 108 CombineKey string 109 } 110 111 // NumTask returns the number of tasks that are comprised by this dependency. 112 func (d TaskDep) NumTask() int { 113 if d.Head == nil { 114 return 0 115 } 116 if n := len(d.Head.Group); n > 0 { 117 return n 118 } 119 return 1 120 } 121 122 // Task returns the i'th task comprised by this dependency. 123 func (d TaskDep) Task(i int) *Task { 124 if i == 0 { 125 return d.Head 126 } 127 return d.Head.Group[i] 128 } 129 130 // A TaskName uniquely names a task by its constituent components. 131 // Tasks with 0 shards are taken to be combiner tasks: they are 132 // machine-local buffers of combiner outputs for some (non-overlapping) 133 // subset of shards for a task. 134 type TaskName struct { 135 // InvIndex is the index of the invocation for which the task was compiled. 136 InvIndex uint64 137 // Op is a unique string describing the operation that is provided 138 // by the task. 139 Op string 140 // Shard and NumShard describe the shard processed by this task 141 // and the total number of shards to be processed. 142 Shard, NumShard int 143 } 144 145 // String returns a canonical representation of the task name, 146 // formatted as: 147 // 148 // {n.Op}@{n.NumShard}:{n.Shard} 149 // {n.Op}_combiner 150 func (n TaskName) String() string { 151 if n.NumShard == 0 { 152 return n.Op + "_combiner" 153 } 154 return fmt.Sprintf("%s@%d:%d", n.Op, n.NumShard, n.Shard) 155 } 156 157 // IsCombiner returns whether the named task is a combiner task. 158 func (n TaskName) IsCombiner() bool { 159 return n.NumShard == 0 160 } 161 162 // TaskSubscriber is subscribed to a Task using Subscribe. It is then notified 163 // whenever the Task state changes. This is useful for efficiently observing the 164 // state changes of many tasks. 165 type TaskSubscriber struct { 166 sync.Mutex 167 cond *ctxsync.Cond 168 169 // tasks holds the set of tasks that has changed since the last call to 170 // Tasks. 171 tasks map[*Task]struct{} 172 } 173 174 // NewTaskSubscriber returns a new TaskSubscriber. It needs to be subscribed to 175 // a Task with Subscribe for it to be notified of task state changes. 176 func NewTaskSubscriber() *TaskSubscriber { 177 s := &TaskSubscriber{tasks: make(map[*Task]struct{})} 178 s.cond = ctxsync.NewCond(s) 179 return s 180 } 181 182 // Notify notifies s of a task whose state has changed. 183 func (s *TaskSubscriber) Notify(task *Task) { 184 s.Lock() 185 defer s.Unlock() 186 s.tasks[task] = struct{}{} 187 s.cond.Broadcast() 188 } 189 190 // Ready returns a channel that is closed if a subsequent call to Tasks will 191 // return a non-nil slice. 192 func (s *TaskSubscriber) Ready() <-chan struct{} { 193 s.Lock() 194 if len(s.tasks) > 0 { 195 s.Unlock() 196 return closedc 197 } 198 return s.cond.Done() 199 } 200 201 // Tasks returns the tasks whose state has changed since the last call to Tasks. 202 func (s *TaskSubscriber) Tasks() []*Task { 203 s.Lock() 204 defer s.Unlock() 205 tasks := make([]*Task, 0, len(s.tasks)) 206 for task := range s.tasks { 207 tasks = append(tasks, task) 208 } 209 s.tasks = make(map[*Task]struct{}) 210 return tasks 211 } 212 213 // A Task represents a concrete computational task. Tasks form graphs 214 // through dependencies; task graphs are compiled from slices. 215 // 216 // Tasks also maintain executor state, and are used to coordinate 217 // execution between concurrent evaluators and a single executor 218 // (which may be evaluating many tasks concurrently). Tasks thus 219 // embed a mutex for coordination and provide a context-aware 220 // conditional variable to coordinate runtime state changes. 221 type Task struct { 222 slicetype.Type 223 // Invocation is the task's invocation, i.e. the Func invocation 224 // from which this task was compiled. 225 Invocation execInvocation 226 // Name is the name of the task. Tasks are named uniquely inside each 227 // Bigslice session. 228 Name TaskName 229 // Do starts computation for this task, returning a reader that 230 // computes batches of values on demand. Do is invoked with readers 231 // for the task's dependencies. 232 Do func([]sliceio.Reader) sliceio.Reader 233 // Deps are the task's dependencies. See TaskDep for details. 234 Deps []TaskDep 235 236 // Partitioner is used to partition the task's output. It will only 237 // be called when NumPartition > 1. 238 Partitioner bigslice.Partitioner 239 // NumPartition is the number of partitions that are output by this task. 240 // If NumPartition > 1, then the task must also define a partitioner. 241 NumPartition int 242 243 // Combiner specifies an (optional) combiner to use for this task's output. 244 // If a Combiner is not Nil, CombineKey names the combine buffer used: 245 // each combine buffer contains combiner outputs from multiple tasks. 246 // If CombineKey is not set, then per-task buffers are used instead. 247 Combiner slicefunc.Func 248 CombineKey string 249 250 // Pragma comprises the pragmas of all slice operations that 251 // are pipelined into this task. 252 bigslice.Pragma 253 254 // Slices is the set of slices to which this task directly contributes. 255 Slices []bigslice.Slice 256 257 // Group stores an ordered list of peer tasks. If Group is nonempty, 258 // it is guaranteed that these sets of tasks constitute a shuffle 259 // dependency, and share a set of shuffle dependencies. This allows 260 // the evaluator to perform optimizations while tracking such 261 // dependencies. 262 Group []*Task 263 264 // Scopes is the metrics scope for this task. It is populated with the 265 // metrics produced during execution of this task. 266 Scope metrics.Scope 267 268 // subs is the set of subscribers to which this task will be sent whenever 269 // its state changes. 270 subs []*TaskSubscriber 271 272 // The following are used to coordinate runtime execution. 273 274 sync.Mutex 275 waitc chan struct{} 276 277 // State is the task's state. It is protected by the task's lock 278 // and state changes are also broadcast on the task's condition 279 // variable. 280 state TaskState 281 // Err is defines when state == TaskErr. 282 err error 283 284 // consecutiveLost is the number of times this task has been run and lost 285 // consecutively. See maxConsecutiveLost. 286 consecutiveLost int 287 288 // Status is a status object to which task status is reported. 289 Status *status.Task 290 } 291 292 // Phase returns the phase to which this task belongs. 293 func (t *Task) Phase() []*Task { 294 if len(t.Group) == 0 { 295 return []*Task{t} 296 } 297 return t.Group 298 } 299 300 // Head returns the head task of this task's phase. If the task does 301 // not belong to a phase, Head returns the task t. 302 func (t *Task) Head() *Task { 303 if len(t.Group) == 0 { 304 return t 305 } 306 return t.Group[0] 307 } 308 309 // String returns a short, human-readable string describing the 310 // task's state. 311 func (t *Task) String() string { 312 // We play fast-and-loose with concurrency here (we read state and 313 // err without holding the task's mutex) so that it is safe to call 314 // String even when the lock is held. 315 var b bytes.Buffer 316 fmt.Fprintf(&b, "task %s [%d] %s", t.Name, t.Invocation.Index, t.state) 317 if t.err != nil { 318 fmt.Fprintf(&b, ": %v", t.err) 319 } 320 return b.String() 321 } 322 323 // Set sets the task's state to the provided state and notifies 324 // any waiters. 325 func (t *Task) Set(state TaskState) { 326 t.Lock() 327 t.state = state 328 t.Broadcast() 329 t.Unlock() 330 } 331 332 // Error sets the task's state to TaskErr and its error to the 333 // provided error. Waiters are notified. 334 func (t *Task) Error(err error) { 335 t.Lock() 336 t.state = TaskErr 337 t.err = err 338 t.Status.Printf(err.Error()) 339 t.Broadcast() 340 t.Unlock() 341 } 342 343 // Errorf formats an error message using fmt.Errorf, sets the task's 344 // state to TaskErr and its err to the resulting error message. 345 func (t *Task) Errorf(format string, v ...interface{}) { 346 t.Error(fmt.Errorf(format, v...)) 347 } 348 349 // Err returns an error if the task's state is >= TaskErr. When the 350 // state is > TaskErr, Err returns an error describing the task's 351 // failed state, otherwise, t.err is returned. 352 func (t *Task) Err() error { 353 t.Lock() 354 defer t.Unlock() 355 switch t.state { 356 case TaskErr: 357 if t.err == nil { 358 panic("TaskErr without an err") 359 } 360 return t.err 361 case TaskLost: 362 return ErrTaskLost 363 } 364 if t.state >= TaskErr { 365 panic("unhandled state") 366 } 367 return nil 368 } 369 370 // State returns the task's current state. 371 func (t *Task) State() TaskState { 372 t.Lock() 373 state := t.state 374 t.Unlock() 375 return state 376 } 377 378 // Broadcast notifies waiters of a state change. Broadcast must only 379 // be called while the task's lock is held. 380 func (t *Task) Broadcast() { 381 if t.waitc != nil { 382 close(t.waitc) 383 t.waitc = nil 384 } 385 for _, sub := range t.subs { 386 sub.Notify(t) 387 } 388 } 389 390 // Wait returns after the next call to Broadcast, or if the context 391 // is complete. The task's lock must be held when calling Wait. 392 func (t *Task) Wait(ctx context.Context) error { 393 if t.waitc == nil { 394 t.waitc = make(chan struct{}) 395 } 396 waitc := t.waitc 397 t.Unlock() 398 var err error 399 select { 400 case <-waitc: 401 case <-ctx.Done(): 402 err = ctx.Err() 403 } 404 t.Lock() 405 return err 406 } 407 408 // WaitState returns when the task's state is at least the provided state, 409 // or else when the context is done. 410 func (t *Task) WaitState(ctx context.Context, state TaskState) (TaskState, error) { 411 t.Lock() 412 defer t.Unlock() 413 var err error 414 for t.state < state && err == nil { 415 err = t.Wait(ctx) 416 } 417 return t.state, err 418 } 419 420 // Subscribe subscribes s to be notified of any changes to t's state. If s has 421 // already been subscribed, no-op. 422 func (t *Task) Subscribe(s *TaskSubscriber) { 423 t.Lock() 424 defer t.Unlock() 425 for _, sub := range t.subs { 426 if s == sub { 427 // It is already registered. 428 return 429 } 430 } 431 t.subs = append(t.subs, s) 432 } 433 434 // Unsubscribe unsubscribes previously subscribe s. s will on longer receive 435 // task state change notifications. No-op if s was never subscribed. 436 func (t *Task) Unsubscribe(s *TaskSubscriber) { 437 t.Lock() 438 defer t.Unlock() 439 subs := t.subs[:0] 440 for _, sub := range t.subs { 441 if s == sub { 442 continue 443 } 444 subs = append(subs, sub) 445 } 446 t.subs = subs 447 } 448 449 // GraphString returns a schematic string of the task graph rooted at t. 450 func (t *Task) GraphString() string { 451 var b bytes.Buffer 452 t.WriteGraph(&b) 453 return b.String() 454 } 455 456 // WriteGraph writes a schematic string of the task graph rooted at t into w. 457 func (t *Task) WriteGraph(w io.Writer) { 458 var tw tabwriter.Writer 459 tw.Init(w, 4, 4, 1, ' ', 0) 460 fmt.Fprintln(&tw, "tasks:") 461 for _, task := range t.All() { 462 out := make([]string, task.NumOut()) 463 for i := range out { 464 out[i] = fmt.Sprint(task.Out(i)) 465 } 466 outstr := strings.Join(out, ",") 467 fmt.Fprintf(&tw, "\t%s\t%s\t%d [%s]\n", task.Name, outstr, task.NumPartition, task.State()) 468 } 469 tw.Flush() 470 fmt.Fprintln(&tw, "dependencies:") 471 t.writeDeps(&tw) 472 tw.Flush() 473 } 474 475 func (t *Task) writeDeps(w io.Writer) { 476 for _, dep := range t.Deps { 477 for i := 0; i < dep.NumTask(); i++ { 478 task := dep.Task(i) 479 fmt.Fprintf(w, "\t%s:\t%s[%d]\n", t.Name, task.Name, dep.Partition) 480 task.writeDeps(w) 481 } 482 } 483 } 484 485 // All returns all tasks reachable from t. The returned 486 // set of tasks is unique. 487 func (t *Task) All() []*Task { 488 all := make(map[*Task]bool) 489 t.all(all) 490 var tasks []*Task 491 for task := range all { 492 tasks = append(tasks, task) 493 } 494 sort.Slice(tasks, func(i, j int) bool { 495 return tasks[i].Name.String() < tasks[j].Name.String() 496 }) 497 return tasks 498 } 499 500 func (t *Task) all(tasks map[*Task]bool) { 501 if tasks[t] { 502 return 503 } 504 tasks[t] = true 505 for _, dep := range t.Deps { 506 for i := 0; i < dep.NumTask(); i++ { 507 dep.Task(i).all(tasks) 508 } 509 } 510 }