go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/server/dsmapper/controller.go

// Copyright 2018 The LUCI Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package dsmapper

import (
	"context"
	"fmt"
	"math"
	"sync"
	"time"

	"golang.org/x/sync/errgroup"
	"google.golang.org/protobuf/proto"

	"go.chromium.org/luci/common/clock"
	"go.chromium.org/luci/common/errors"
	"go.chromium.org/luci/common/logging"
	"go.chromium.org/luci/common/retry/transient"
	"go.chromium.org/luci/common/sync/parallel"
	"go.chromium.org/luci/gae/service/datastore"

	"go.chromium.org/luci/server/dsmapper/dsmapperpb"
	"go.chromium.org/luci/server/dsmapper/internal/splitter"
	"go.chromium.org/luci/server/dsmapper/internal/tasks"
	"go.chromium.org/luci/server/tq"

	// Need this to enqueue tasks inside Datastore transactions.
	_ "go.chromium.org/luci/server/tq/txn/datastore"
)

// ID identifies a mapper registered in the controller.
//
// It will be passed across processes, so all processes that execute mapper
// jobs should register the same mappers under the same IDs.
//
// The safest approach is to keep mapper IDs in the app unique, e.g. do NOT
// reuse them when adding new mappers or significantly changing existing ones.
type ID string

// Mapper applies some function to the given slice of entities, given by
// their keys.
//
// May be called multiple times for the same key (thus should be idempotent).
//
// Returning a transient error indicates that the processing of this batch of
// keys should be retried (even if some keys were processed successfully).
//
// Returning a fatal error causes the entire shard (and eventually the entire
// job) to be marked as failed. The processing of the failed shard stops right
// away, but other shards keep running until completion (or their own failure).
//
// The function is called outside of any transactions, so it can start its own
// if needed.
type Mapper func(ctx context.Context, keys []*datastore.Key) error

// Factory knows how to construct instances of Mapper.
//
// A Factory is supplied by users of the library and registered in the
// controller via a RegisterFactory call.
//
// It is used to get a mapper to process a set of pages within a shard. It
// takes a Job (including its Config and Params) and a shard index, so it can
// prepare the mapper for processing this specific shard.
//
// Returning a transient error triggers an eventual retry. Returning a fatal
// error causes the shard (and eventually the entire job) to be marked as
// failed.
type Factory func(ctx context.Context, j *Job, shardIdx int) (Mapper, error)
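
// Example (a non-authoritative sketch, not part of this package): a factory
// returning a mapper that applies some idempotent fix-up to every entity in
// its key range. The mapper ID "touch-entities" and the touchEntities helper
// are hypothetical.
//
//	const touchMapperID = dsmapper.ID("touch-entities")
//
//	func touchFactory(ctx context.Context, j *dsmapper.Job, shardIdx int) (dsmapper.Mapper, error) {
//		// Per-shard preparation (if any) can look at j.Config and shardIdx here.
//		return func(ctx context.Context, keys []*datastore.Key) error {
//			// Must be idempotent: the same keys may be delivered more than once.
//			return touchEntities(ctx, keys)
//		}, nil
//	}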

// Controller is responsible for starting, progressing and finishing mapping
// jobs.
//
// It should be treated as a global singleton object. Having more than one
// controller in a production application is a bad idea (they'll collide with
// each other since they use the global datastore namespace). It is still
// useful to instantiate multiple controllers in unit tests.
type Controller struct {
	// MapperQueue is the name of the Cloud Tasks queue to use for mapping jobs.
	//
	// This queue will perform all "heavy" tasks. It should be configured
	// appropriately to allow the desired number of shards to run in parallel.
	//
	// For example, if the largest submitted job is expected to have 128 shards,
	// the max_concurrent_requests setting of the mapper queue should be at
	// least 128, otherwise some shards will be stalled waiting for others to
	// finish (defeating the purpose of having a large number of shards).
	//
	// If empty, "default" is used.
	MapperQueue string

	// ControlQueue is the name of the Cloud Tasks queue to use for control
	// signals.
	//
	// This queue is used very lightly when starting and stopping jobs (roughly
	// 2*Shards tasks overall per job). The default queue.yaml settings for such
	// a queue should be sufficient (unless you run a lot of different jobs at
	// once).
	//
	// If empty, "default" is used.
	ControlQueue string

	m       sync.RWMutex
	mappers map[ID]Factory
	disp    *tq.Dispatcher
}

// Install registers task queue task handlers in the given task queue
// dispatcher.
//
// This must be done before the Controller is used.
//
// There can be at most one Controller installed into an instance of a TQ
// dispatcher. Installing more will cause panics.
//
// If you need multiple different controllers for some reason, create multiple
// tq.Dispatchers (with different base URLs, so they don't conflict with each
// other) and install them all into the router.
func (ctl *Controller) Install(disp *tq.Dispatcher) {
	ctl.m.Lock()
	defer ctl.m.Unlock()

	if ctl.disp != nil {
		panic("mapper.Controller is already installed into a tq.Dispatcher")
	}
	ctl.disp = disp

	controlQueue := ctl.ControlQueue
	if controlQueue == "" {
		controlQueue = "default"
	}
	mapperQueue := ctl.MapperQueue
	if mapperQueue == "" {
		mapperQueue = "default"
	}

	disp.RegisterTaskClass(tq.TaskClass{
		ID:        "dsmapper-split-and-launch",
		Prototype: &tasks.SplitAndLaunch{},
		Kind:      tq.Transactional,
		Queue:     controlQueue,
		Handler:   ctl.splitAndLaunchHandler,
		Quiet:     true,
	})
	disp.RegisterTaskClass(tq.TaskClass{
		ID:        "dsmapper-fan-out-shards",
		Prototype: &tasks.FanOutShards{},
		Kind:      tq.Transactional,
		Queue:     controlQueue,
		Handler:   ctl.fanOutShardsHandler,
		Quiet:     true,
	})
	disp.RegisterTaskClass(tq.TaskClass{
		ID:        "dsmapper-process-shard",
		Prototype: &tasks.ProcessShard{},
		Kind:      tq.FollowsContext,
		Queue:     mapperQueue,
		Handler:   ctl.processShardHandler,
		Quiet:     true,
	})
	disp.RegisterTaskClass(tq.TaskClass{
		ID:        "dsmapper-request-job-state-update",
		Prototype: &tasks.RequestJobStateUpdate{},
		Kind:      tq.Transactional,
		Queue:     controlQueue,
		Handler:   ctl.requestJobStateUpdateHandler,
		Quiet:     true,
	})
	disp.RegisterTaskClass(tq.TaskClass{
		ID:        "dsmapper-update-job-state",
		Prototype: &tasks.UpdateJobState{},
		Kind:      tq.NonTransactional,
		Queue:     controlQueue,
		Handler:   ctl.updateJobStateHandler,
		Quiet:     true,
	})
}

// tq returns the dispatcher set in Install or panics if not set yet.
//
// Grabs the reader lock inside.
func (ctl *Controller) tq() *tq.Dispatcher {
	ctl.m.RLock()
	defer ctl.m.RUnlock()
	if ctl.disp == nil {
		panic("mapper.Controller wasn't installed into tq.Dispatcher yet")
	}
	return ctl.disp
}

// RegisterFactory adds the given mapper factory to the internal registry.
//
// Intended to be used during init() time or early during process
// initialization. Panics if a factory with the same ID has already been
// registered.
//
// The mapper ID will be used internally to identify which mapper a job should
// be using. If a factory disappears while the job is running (e.g. if the
// service binary is updated and the new binary doesn't have the mapper
// registered anymore), the job ends with a failure.
func (ctl *Controller) RegisterFactory(id ID, m Factory) {
	ctl.m.Lock()
	defer ctl.m.Unlock()

	if _, ok := ctl.mappers[id]; ok {
		panic(fmt.Sprintf("mapper %q is already registered", id))
	}

	if ctl.mappers == nil {
		ctl.mappers = make(map[ID]Factory, 1)
	}
	ctl.mappers[id] = m
}

// getFactory returns a registered mapper factory or an error.
//
// Grabs the reader lock inside. Can return only fatal errors.
func (ctl *Controller) getFactory(id ID) (Factory, error) {
	ctl.m.RLock()
	defer ctl.m.RUnlock()
	if m, ok := ctl.mappers[id]; ok {
		return m, nil
	}
	return nil, errors.Reason("no mapper factory with ID %q registered", id).Err()
}
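
// Example (a sketch, assuming a process-wide controller and the hypothetical
// touchMapperID and touchFactory from the example above; the queue names are
// also assumptions): wiring the controller into a tq.Dispatcher and
// registering a mapper factory early during process initialization.
//
//	var mapperCtl = dsmapper.Controller{
//		MapperQueue:  "dsmapper-mappers",
//		ControlQueue: "dsmapper-control",
//	}
//
//	func installMappers(disp *tq.Dispatcher) {
//		mapperCtl.Install(disp)
//		mapperCtl.RegisterFactory(touchMapperID, touchFactory)
//	}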

// initMapper instantiates a Mapper through a registered factory.
//
// May return fatal and transient errors.
func (ctl *Controller) initMapper(ctx context.Context, j *Job, shardIdx int) (Mapper, error) {
	f, err := ctl.getFactory(j.Config.Mapper)
	if err != nil {
		return nil, errors.Annotate(err, "when initializing mapper").Err()
	}
	m, err := f(ctx, j, shardIdx)
	if err != nil {
		return nil, errors.Annotate(err, "error from mapper factory %q", j.Config.Mapper).Err()
	}
	return m, nil
}

// LaunchJob launches a new mapping job, returning its ID (which can be used
// to control it or query its status).
//
// Launches a datastore transaction inside.
func (ctl *Controller) LaunchJob(ctx context.Context, j *JobConfig) (JobID, error) {
	disp := ctl.tq()

	if err := j.Validate(); err != nil {
		return 0, errors.Annotate(err, "bad job config").Err()
	}
	if _, err := ctl.getFactory(j.Mapper); err != nil {
		return 0, errors.Annotate(err, "bad job config").Err()
	}

	// Prepare and store the Job entity, generating its key. Launch a TQ task
	// that subdivides the key space and launches individual shards. We do this
	// asynchronously since it can potentially be slow (for a large number of
	// shards).
	var job Job
	err := runTxn(ctx, func(ctx context.Context) error {
		now := clock.Now(ctx).UTC()
		job = Job{
			Config:  *j,
			State:   dsmapperpb.State_STARTING,
			Created: now,
			Updated: now,
		}
		if err := datastore.Put(ctx, &job); err != nil {
			return errors.Annotate(err, "failed to store Job entity").Tag(transient.Tag).Err()
		}
		return disp.AddTask(ctx, &tq.Task{
			Title: fmt.Sprintf("split:job-%d", job.ID),
			Payload: &tasks.SplitAndLaunch{
				JobId: int64(job.ID),
			},
		})
	})
	if err != nil {
		return 0, err
	}
	return job.ID, nil
}

// GetJob fetches a previously launched job given its ID.
//
// Returns ErrNoSuchJob if not found. All other possible errors are transient
// and they are marked as such.
func (ctl *Controller) GetJob(ctx context.Context, id JobID) (*Job, error) {
	// Even though we could have made getJob public, we want to force API users
	// to use Controller as a single facade.
	return getJob(ctx, id)
}
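
// Example (a sketch reusing the hypothetical names from the examples above;
// the query literal, shard count and page size are assumptions about a
// particular application, not requirements of this API): launching a job and
// checking on it later.
//
//	jobID, err := mapperCtl.LaunchJob(ctx, &dsmapper.JobConfig{
//		Mapper:     touchMapperID,
//		Query:      dsmapper.Query{Kind: "Account"},
//		ShardCount: 8,
//		PageSize:   256,
//	})
//	if err != nil {
//		return err
//	}
//
//	// Later, e.g. from an admin endpoint:
//	job, err := mapperCtl.GetJob(ctx, jobID)
//	if err != nil {
//		return err
//	}
//	logging.Infof(ctx, "Job %d is in state %s", jobID, job.State)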

// AbortJob aborts a job and returns its most recent state.
//
// Silently does nothing if the job is finished or already aborted.
//
// Returns ErrNoSuchJob if there's no such job at all. All other possible
// errors are transient and they are marked as such.
func (ctl *Controller) AbortJob(ctx context.Context, id JobID) (job *Job, err error) {
	err = runTxn(ctx, func(ctx context.Context) error {
		var err error
		switch job, err = getJob(ctx, id); {
		case err != nil:
			return err
		case isFinalState(job.State) || job.State == dsmapperpb.State_ABORTING:
			return nil // nothing to abort, already done
		case job.State == dsmapperpb.State_STARTING:
			// Shards haven't been launched yet. Kill the job right away.
			job.State = dsmapperpb.State_ABORTED
		case job.State == dsmapperpb.State_RUNNING:
			// Running shards will discover that the job is aborting and will
			// eventually move into ABORTED state (notifying the job about it). Once
			// all shards report they are done, the job itself will switch into
			// ABORTED state.
			job.State = dsmapperpb.State_ABORTING
		}
		job.Updated = clock.Now(ctx).UTC()
		return errors.Annotate(datastore.Put(ctx, job), "failed to store Job entity").Tag(transient.Tag).Err()
	})
	if err != nil {
		job = nil // don't return bogus data in case the txn failed to land
	}
	return
}

////////////////////////////////////////////////////////////////////////////////
// Task queue task handlers.

// errJobAborted is used internally as the shard failure status when the job is
// being aborted.
//
// It causes the shard to switch into ABORTED state instead of FAIL.
var errJobAborted = errors.New("the job has been aborted")

// splitAndLaunchHandler splits the job into shards and enqueues tasks that
// process shards.
func (ctl *Controller) splitAndLaunchHandler(ctx context.Context, payload proto.Message) error {
	msg := payload.(*tasks.SplitAndLaunch)
	now := clock.Now(ctx).UTC()

	// Fetch job details. Make sure it isn't canceled and isn't running already.
	job, err := getJobInState(ctx, JobID(msg.JobId), dsmapperpb.State_STARTING)
	if err != nil || job == nil {
		return errors.Annotate(err, "in SplitAndLaunch").Err()
	}

	// Figure out key ranges for shards. There may be fewer shards than requested
	// if there are too few entities.
	dq := job.Config.Query.ToDatastoreQuery()
	ranges, err := splitter.SplitIntoRanges(ctx, dq, splitter.Params{
		Shards:  job.Config.ShardCount,
		Samples: 512, // should be enough for everyone...
	})
	if err != nil {
		return errors.Annotate(err, "failed to split the query into shards").Tag(transient.Tag).Err()
	}

	// Create entities that hold shard state. Each one is in its own entity
	// group, since the combined write rate to them is O(ShardCount), which can
	// exceed the limits of a single entity group.
	shards := make([]*shard, len(ranges))
	for idx, rng := range ranges {
		shards[idx] = &shard{
			JobID:         job.ID,
			Index:         idx,
			State:         dsmapperpb.State_STARTING,
			Range:         rng,
			ExpectedCount: -1,
			Created:       now,
			Updated:       now,
		}
	}

	// Calculate the number of entities in each shard to track shard processing
	// progress. Note that this can be very slow if there are many entities.
	if job.Config.TrackProgress {
		logging.Infof(ctx, "Estimating the size of each shard...")
		if err := fetchShardSizes(ctx, dq, shards); err != nil {
			return errors.Annotate(err, "when estimating shard sizes").Err()
		}
	}

	// We use auto-generated keys for shards to make sure a crashed SplitAndLaunch
	// task retries cleanly, even if the underlying key space we are mapping over
	// changes between the retries (making a naive put using a "<job-id>:<index>"
	// key non-idempotent!).
	logging.Infof(ctx, "Instantiating shards...")
	if err := datastore.Put(ctx, shards); err != nil {
		return errors.Annotate(err, "failed to store shards").Tag(transient.Tag).Err()
	}

	// Prepare shardList, which is basically a manual fully consistent index for
	// the Job -> [Shard] relation. We can't use a regular index, since the
	// shards are all in different entity groups (see the O(ShardCount) argument
	// above).
	//
	// Log the resulting shards along the way.
	shardsEnt := shardList{
		Parent: datastore.KeyForObj(ctx, job),
		Shards: make([]int64, len(shards)),
	}
	for idx, s := range shards {
		shardsEnt.Shards[idx] = s.ID

		l, r := "-inf", "+inf"
		if s.Range.Start != nil {
			l = s.Range.Start.String()
		}
		if s.Range.End != nil {
			r = s.Range.End.String()
		}
		count := ""
		if s.ExpectedCount != -1 {
			count = fmt.Sprintf(" (%d entities)", s.ExpectedCount)
		}
		logging.Infof(ctx, "Shard #%d is %d: %s - %s%s", idx, s.ID, l, r, count)
	}

	// Transactionally associate shards with the job and launch the TQ task that
	// kicks off the processing of each individual shard. We use an intermediary
	// task for this since transactionally launching O(ShardCount) tasks hits TQ
	// transaction limits.
	//
	// If SplitAndLaunch crashes before this transaction lands, there'll be some
	// orphaned Shard entities, no big deal.
	logging.Infof(ctx, "Updating the job and launching the fan out task...")
	return runTxn(ctx, func(ctx context.Context) error {
		job, err := getJobInState(ctx, JobID(msg.JobId), dsmapperpb.State_STARTING)
		if err != nil || job == nil {
			return errors.Annotate(err, "in SplitAndLaunch txn").Err()
		}

		job.State = dsmapperpb.State_RUNNING
		job.Updated = now
		if err := datastore.Put(ctx, job, &shardsEnt); err != nil {
			return errors.Annotate(err,
				"when storing Job %d and ShardList with %d shards", job.ID, len(shards),
			).Tag(transient.Tag).Err()
		}

		return ctl.tq().AddTask(ctx, &tq.Task{
			Title: fmt.Sprintf("fanout:job-%d", job.ID),
			Payload: &tasks.FanOutShards{
				JobId: int64(job.ID),
			},
		})
	})
}

// fetchShardSizes makes a bunch of Count() queries to figure out the size of
// each shard.
//
// Updates ExpectedCount in-place.
func fetchShardSizes(ctx context.Context, baseQ *datastore.Query, shards []*shard) error {
	ctx, cancel := clock.WithTimeout(ctx, 10*time.Minute)
	defer cancel()

	err := parallel.WorkPool(32, func(tasks chan<- func() error) {
		for _, sh := range shards {
			sh := sh
			tasks <- func() error {
				n, err := datastore.CountBatch(ctx, 1024, sh.Range.Apply(baseQ))
				if err == nil {
					sh.ExpectedCount = n
				}
				return errors.Annotate(err, "for shard #%d", sh.Index).Err()
			}
		}
	})

	return transient.Tag.Apply(err)
}

// fanOutShardsHandler fetches the list of shards from the job and launches
// named ProcessShard tasks, one per shard.
func (ctl *Controller) fanOutShardsHandler(ctx context.Context, payload proto.Message) error {
	msg := payload.(*tasks.FanOutShards)

	// Make sure the job is still present. If it is aborted, we still need to
	// launch the shards, so they notice they are being aborted. We could try
	// to abort all shards right here and now, but that basically means
	// implementing an alternative shard abort flow. It seems simpler to just
	// let the regular flow proceed.
	job, err := getJobInState(ctx, JobID(msg.JobId), dsmapperpb.State_RUNNING, dsmapperpb.State_ABORTING)
	if err != nil || job == nil {
		return errors.Annotate(err, "in FanOutShards").Err()
	}

	// Grab the list of shards created in SplitAndLaunch. It must exist at this
	// point, since the job is in the RUNNING state.
	shardIDs, err := job.fetchShardIDs(ctx)
	if err != nil {
		return errors.Annotate(err, "in FanOutShards").Err()
	}

	// Enqueue a bunch of named ProcessShard tasks (one per shard) to actually
	// launch shard processing. This is an idempotent operation, so if
	// FanOutShards crashes midway and is later retried, nothing bad happens.
	eg, ctx := errgroup.WithContext(ctx)
	tq := ctl.tq()
	for _, sid := range shardIDs {
		task := makeProcessShardTask(job.ID, sid, 0, true)
		eg.Go(func() error { return tq.AddTask(ctx, task) })
	}
	return eg.Wait()
}

// processShardHandler reads a bunch of entities (up to PageSize), and hands
// them to the mapper.
//
// After doing this in a loop for 1 min, it checkpoints the state and reenqueues
// itself to resume mapping in another instance of the task. This keeps each
// processing TQ task relatively small, so it doesn't eat a lot of memory or
// produce gigantic unreadable logs. It also makes TQ's "Pause queue" button
// more handy.
func (ctl *Controller) processShardHandler(ctx context.Context, payload proto.Message) error {
	msg := payload.(*tasks.ProcessShard)

	// Grab the shard. This returns (nil, nil) if this Task Queue task is stale
	// (based on taskNum) and should be silently skipped.
	sh, err := getActiveShard(ctx, msg.ShardId, msg.TaskNum)
	if err != nil || sh == nil {
		return errors.Annotate(err, "when fetching shard state").Err()
	}
	ctx = logging.SetField(ctx, "shardIdx", sh.Index)

	logging.Infof(ctx,
		"Resuming processing of the shard (launched %s ago)",
		clock.Now(ctx).Sub(sh.Created))

	// Grab the job config, make sure the job is still active.
	job, err := getJobInState(ctx, JobID(msg.JobId), dsmapperpb.State_RUNNING, dsmapperpb.State_ABORTING)
	if err != nil || job == nil {
		return errors.Annotate(err, "in ProcessShard").Err()
	}

	// If the job is being killed, kill the shard as well. This will eventually
	// notify the job about the shard's completion. Once all shards are done, the
	// job will switch into ABORTED state.
	if job.State == dsmapperpb.State_ABORTING {
		return ctl.finishShard(ctx, sh.ID, 0, errJobAborted)
	}

	// Prepare the mapper by giving the factory the job parameters.
	mapper, err := ctl.initMapper(ctx, job, sh.Index)
	switch {
	case transient.Tag.In(err):
		return errors.Annotate(err, "transient error when instantiating a mapper").Err()
	case err != nil:
		// Kill the shard if the factory returns a fatal error.
		return ctl.finishShard(ctx, sh.ID, 0, err)
	}

	baseQ := job.Config.Query.ToDatastoreQuery()
	lastKey := sh.ResumeFrom
	keys := make([]*datastore.Key, 0, job.Config.PageSize)

	shardDone := false    // true when finished processing the shard
	pageCount := 0        // how many pages were processed successfully
	itemCount := int64(0) // how many entities were processed successfully

	// A soft deadline for when to checkpoint the progress and reenqueue the
	// processing task. We never abort processing of a page midway (it causes too
	// many complications), so if the mapper is extremely slow, it may end up
	// running longer than this deadline.
	dur := time.Minute
	if job.Config.TaskDuration > 0 {
		dur = job.Config.TaskDuration
	}
	deadline := clock.Now(ctx).Add(dur)

	// Optionally also put a limit on the number of processed pages. Useful if
	// the mapper is somehow leaking resources (not sure it is possible in Go,
	// but it was definitely possible in Python).
	pageCountLimit := math.MaxInt32
	if job.Config.PagesPerTask > 0 {
		pageCountLimit = job.Config.PagesPerTask
	}

	for clock.Now(ctx).Before(deadline) && pageCount < pageCountLimit {
		rng := sh.Range
		if lastKey != nil {
			rng.Start = lastKey
		}
		if rng.IsEmpty() {
			shardDone = true
			break
		}

		// Fetch the next batch of keys. Return an error to the outer scope where
		// it eventually bubbles up to TQ (so the task is retried with exponential
		// backoff).
		logging.Infof(ctx, "Fetching the next batch...")
		q := rng.Apply(baseQ).Limit(int32(job.Config.PageSize)).KeysOnly(true)
		keys = keys[:0]
		if err = datastore.GetAll(ctx, q, &keys); err != nil {
			err = errors.Annotate(err, "when querying for keys").Tag(transient.Tag).Err()
			break
		}

		// No results within the range? Processing of the shard is complete!
		if len(keys) == 0 {
			shardDone = true
			break
		}

		// Let the mapper do its thing. Remember where to resume from.
		logging.Infof(ctx,
			"Processing %d entities: %s - %s",
			len(keys),
			keys[0].String(),
			keys[len(keys)-1].String())
		if err = mapper(ctx, keys); err != nil {
			err = errors.Annotate(err, "while mapping %d keys", len(keys)).Err()
			break
		}
		lastKey = keys[len(keys)-1]
		pageCount++
		itemCount += int64(len(keys))

		// Note: at this point we might try to checkpoint the progress, but we must
		// be careful not to exceed the 1 transaction per second limit. Considering
		// we also MUST checkpoint the progress at the end of the task, it is a bit
		// tricky to guarantee no two checkpoints are closer than 1 sec. We can do
		// silly things like sleep 1 sec before the last checkpoint, but they
		// provide no guarantees.
		//
		// So instead we store the progress after the deadline is up. If the task
		// crashes midway, up to 1 min of work will be retried. No big deal.
	}

	// We are done with the shard when we either processed all of its range or
	// failed with a fatal error. finishShard takes care of notifying the parent
	// job about the shard's completion.
	if shardDone || (err != nil && !transient.Tag.In(err)) {
		return ctl.finishShard(ctx, sh.ID, itemCount, err)
	}

	if lastKey != nil {
		logging.Infof(ctx, "The shard processing will resume from %s", lastKey)
	} else {
		logging.Infof(ctx, "The shard processing will resume from scratch")
	}

	// If the shard isn't done and we made no progress at all, then we hit
	// a transient error. Ask TQ to retry.
	if pageCount == 0 {
		return err
	}

	// Otherwise we need to checkpoint the progress and either retry this task
	// (on transient errors, to get an exponential backoff from TQ), or start
	// a new task.
	txnErr := shardTxn(ctx, sh.ID, func(ctx context.Context, sh *shard) (bool, error) {
		switch {
		case sh.ProcessTaskNum != msg.TaskNum:
			logging.Warningf(ctx, "Unexpected shard state: its ProcessTaskNum is %d != %d", sh.ProcessTaskNum, msg.TaskNum)
			return false, nil // some other task is already running
		case sh.ResumeFrom != nil && lastKey.Less(sh.ResumeFrom):
			logging.Warningf(ctx, "Unexpected shard state: its ResumeFrom is %s >= %s", sh.ResumeFrom, lastKey)
			return false, nil // someone already claimed to have processed further, let them proceed
		}

		sh.State = dsmapperpb.State_RUNNING
		sh.ResumeFrom = lastKey
		sh.ProcessedCount += itemCount

		// If the processing failed, just store the progress, but do not start a
		// new TQ task. Retry the current task instead (to get exponential backoff).
		if err != nil {
			return true, nil
		}

		// Otherwise launch a new task in the chain. This essentially "resets"
		// the exponential backoff counter.
		sh.ProcessTaskNum++
		return true, ctl.tq().AddTask(ctx,
			makeProcessShardTask(sh.JobID, sh.ID, sh.ProcessTaskNum, false))
	})

	switch {
	case err != nil && txnErr == nil:
		return err
	case err == nil && txnErr != nil:
		return errors.Annotate(txnErr, "when storing shard progress").Err()
	case err != nil && txnErr != nil:
		return errors.Annotate(txnErr, "when storing shard progress after a transient error (%s)", err).Err()
	default: // (nil, nil)
		return nil
	}
}

// finishShard marks the shard as finished (with a status based on shardErr)
// and emits a task to update the parent job's status.
func (ctl *Controller) finishShard(ctx context.Context, shardID, processedCount int64, shardErr error) error {
	err := shardTxn(ctx, shardID, func(ctx context.Context, sh *shard) (save bool, err error) {
		runtime := clock.Now(ctx).Sub(sh.Created)
		switch {
		case shardErr == errJobAborted:
			logging.Warningf(ctx, "The job has been aborted, aborting the shard after it has been running %s", runtime)
			sh.State = dsmapperpb.State_ABORTED
			sh.Error = errJobAborted.Error()
		case shardErr != nil:
			logging.Errorf(ctx, "The shard processing failed in %s with error: %s", runtime, shardErr)
			sh.State = dsmapperpb.State_FAIL
			sh.Error = shardErr.Error()
		default:
			logging.Infof(ctx, "The shard processing finished successfully in %s", runtime)
			sh.State = dsmapperpb.State_SUCCESS
		}
		sh.ProcessedCount += processedCount
		return true, ctl.requestJobStateUpdate(ctx, sh.JobID, sh.ID)
	})
	return errors.Annotate(err, "when marking the shard as finished").Err()
}

// makeProcessShardTask creates a ProcessShard tq.Task.
//
// If 'named' is true, assigns it a name. Tasks are named based on their shard
// IDs and an index in the chain of ProcessShard tasks (the task number), so
// that on retries we don't rekick already finished tasks.
func makeProcessShardTask(job JobID, shardID, taskNum int64, named bool) *tq.Task {
	// Note: strictly speaking, including the job ID in the task name is
	// redundant, since shardID is already globally unique, but it doesn't hurt.
	// It is useful for debugging and when looking at logs and pending TQ tasks.
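	//
	// For example, with hypothetical IDs job=42, shardID=7, taskNum=3, the task
	// below gets the title "map:job-42-shard-7-task-3" and, when named, the
	// deduplication key "v1-42-7-3".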
	t := &tq.Task{
		Title: fmt.Sprintf("map:job-%d-shard-%d-task-%d", job, shardID, taskNum),
		Payload: &tasks.ProcessShard{
			JobId:   int64(job),
			ShardId: shardID,
			TaskNum: taskNum,
		},
	}
	if named {
		t.DeduplicationKey = fmt.Sprintf("v1-%d-%d-%d", job, shardID, taskNum)
	}
	return t
}

// requestJobStateUpdate submits a RequestJobStateUpdate task, which eventually
// causes updateJobStateHandler to execute.
func (ctl *Controller) requestJobStateUpdate(ctx context.Context, jobID JobID, shardID int64) error {
	return ctl.tq().AddTask(ctx, &tq.Task{
		Title: fmt.Sprintf("notify:job-%d-shard-%d", jobID, shardID),
		Payload: &tasks.RequestJobStateUpdate{
			JobId:   int64(jobID),
			ShardId: shardID,
		},
	})
}

// requestJobStateUpdateHandler is called whenever the state of some shard
// changes.
//
// It forwards this notification to the job (specifically updateJobStateHandler),
// throttling the rate to ~0.5 QPS to avoid overwhelming the job's entity group
// with a high write rate.
func (ctl *Controller) requestJobStateUpdateHandler(ctx context.Context, payload proto.Message) error {
	msg := payload.(*tasks.RequestJobStateUpdate)

	// Throttle to once per 2 sec by rounding the ETA up to the next even second
	// (which also makes sure it is always in the future). We rely here on a
	// pretty good (< 0.5s maximum skew) clock sync on servers.
	eta := clock.Now(ctx).Unix()
	eta = (eta/2 + 1) * 2
	dedupKey := fmt.Sprintf("update-job-state-v1:%d:%d", msg.JobId, eta)

	err := ctl.tq().AddTask(ctx, &tq.Task{
		DeduplicationKey: dedupKey,
		Title:            fmt.Sprintf("update:job-%d", msg.JobId),
		ETA:              time.Unix(eta, 0),
		Payload:          &tasks.UpdateJobState{JobId: msg.JobId},
	})
	return errors.Annotate(err, "when adding UpdateJobState task").Err()
}

// updateJobStateHandler is called some time after one or more shards have
// changed state.
//
// It calculates the overall job state based on the state of its shards.
func (ctl *Controller) updateJobStateHandler(ctx context.Context, payload proto.Message) error {
	msg := payload.(*tasks.UpdateJobState)

	// Get the job and all its shards in their most recent state.
	job, err := getJobInState(ctx, JobID(msg.JobId), dsmapperpb.State_RUNNING, dsmapperpb.State_ABORTING)
	if err != nil || job == nil {
		return errors.Annotate(err, "in UpdateJobState").Err()
	}
	shards, err := job.fetchShards(ctx)
	if err != nil {
		return errors.Annotate(err, "failed to fetch shards").Err()
	}

	// Switch the job into a final state only when all shards are done running.
	perState := make(map[dsmapperpb.State]int, len(dsmapperpb.State_name))
	finished := 0
	for _, sh := range shards {
		logging.Infof(ctx, "Shard #%d (%d) is in state %s", sh.Index, sh.ID, sh.State)
		perState[sh.State]++
		if isFinalState(sh.State) {
			finished++
		}
	}
	if finished != len(shards) {
		return nil
	}

	jobState := dsmapperpb.State_SUCCESS
	switch {
	case perState[dsmapperpb.State_ABORTED] != 0:
		jobState = dsmapperpb.State_ABORTED
	case perState[dsmapperpb.State_FAIL] != 0:
		jobState = dsmapperpb.State_FAIL
	}

	return runTxn(ctx, func(ctx context.Context) error {
		job, err := getJobInState(ctx, JobID(msg.JobId), dsmapperpb.State_RUNNING, dsmapperpb.State_ABORTING)
		if err != nil || job == nil {
			return errors.Annotate(err, "in UpdateJobState txn").Err()
		}

		// Make sure an aborting job ends up in ABORTED state, even if all its
		// shards managed to finish. It looks weird when an ABORTING job moves
		// into e.g. SUCCESS state.
		if job.State == dsmapperpb.State_ABORTING {
			job.State = dsmapperpb.State_ABORTED
		} else {
			job.State = jobState
		}
		job.Updated = clock.Now(ctx).UTC()

		runtime := job.Updated.Sub(job.Created)
		switch job.State {
		case dsmapperpb.State_SUCCESS:
			logging.Infof(ctx, "The job finished successfully in %s", runtime)
		case dsmapperpb.State_FAIL:
			logging.Errorf(ctx, "The job finished with %d shards failing in %s", perState[dsmapperpb.State_FAIL], runtime)
			for _, sh := range shards {
				if sh.State == dsmapperpb.State_FAIL {
					logging.Errorf(ctx, "Shard #%d (%d) error - %s", sh.Index, sh.ID, sh.Error)
				}
			}
		case dsmapperpb.State_ABORTED:
			logging.Warningf(ctx, "The job has been aborted after %s: %d shards succeeded, %d shards failed, %d shards aborted",
				runtime, perState[dsmapperpb.State_SUCCESS], perState[dsmapperpb.State_FAIL], perState[dsmapperpb.State_ABORTED])
		}

		return transient.Tag.Apply(datastore.Put(ctx, job))
	})
}
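
// Overview (an informational summary of the handlers above, not an API
// contract): a mapping job is driven by the following chain of TQ tasks.
//
//	LaunchJob                    -> enqueues SplitAndLaunch (control queue)
//	splitAndLaunchHandler        -> creates shard entities, enqueues FanOutShards
//	fanOutShardsHandler          -> enqueues one named ProcessShard task per shard (mapper queue)
//	processShardHandler          -> maps pages for ~TaskDuration, checkpoints progress and
//	                                reenqueues itself; finished shards go through finishShard
//	finishShard                  -> enqueues RequestJobStateUpdate
//	requestJobStateUpdateHandler -> throttles and enqueues UpdateJobState
//	updateJobStateHandler        -> aggregates shard states into the final job state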