go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/server/dsmapper/controller.go (about)

     1  // Copyright 2018 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package dsmapper
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"math"
    21  	"sync"
    22  	"time"
    23  
    24  	"golang.org/x/sync/errgroup"
    25  	"google.golang.org/protobuf/proto"
    26  
    27  	"go.chromium.org/luci/common/clock"
    28  	"go.chromium.org/luci/common/errors"
    29  	"go.chromium.org/luci/common/logging"
    30  	"go.chromium.org/luci/common/retry/transient"
    31  	"go.chromium.org/luci/common/sync/parallel"
    32  	"go.chromium.org/luci/gae/service/datastore"
    33  
    34  	"go.chromium.org/luci/server/dsmapper/dsmapperpb"
    35  	"go.chromium.org/luci/server/dsmapper/internal/splitter"
    36  	"go.chromium.org/luci/server/dsmapper/internal/tasks"
    37  	"go.chromium.org/luci/server/tq"
    38  
    39  	// Need this to enqueue tasks inside Datastore transactions.
    40  	_ "go.chromium.org/luci/server/tq/txn/datastore"
    41  )
    42  
    43  // ID identifies a mapper registered in the controller.
    44  //
    45  // It will be passed across processes, so all processes that execute mapper jobs
    46  // should register same mappers under same IDs.
    47  //
    48  // The safest approach is to keep mapper IDs in the app unique, e.g. do NOT
    49  // reuse them when adding new mappers or significantly changing existing ones.
    50  type ID string
    51  
    52  // Mapper applies some function to the given slice of entities, given by
    53  // their keys.
    54  //
    55  // May be called multiple times for same key (thus should be idempotent).
    56  //
    57  // Returning a transient error indicates that the processing of this batch of
    58  // keys should be retried (even if some keys were processed successfully).
    59  //
    60  // Returning a fatal error causes the entire shard (and eventually the entire
    61  // job) to be marked as failed. The processing of the failed shard stops right
    62  // away, but other shards are kept running until completion (or their own
    63  // failure).
    64  //
    65  // The function is called outside of any transactions, so it can start its own
    66  // if needed.
    67  type Mapper func(ctx context.Context, keys []*datastore.Key) error
    68  
    69  // Factory knows how to construct instances of Mapper.
    70  //
    71  // Factory is supplied by the users of the library and registered in the
    72  // controller via RegisterFactory call.
    73  //
    74  // It is used to get a mapper to process a set of pages within a shard. It takes
    75  // a Job (including its Config and Params) and a shard index, so it can prepare
    76  // the mapper for processing of this specific shard.
    77  //
    78  // Returning a transient error triggers an eventual retry. Returning a fatal
    79  // error causes the shard (eventually the entire job) to be marked as failed.
    80  type Factory func(ctx context.Context, j *Job, shardIdx int) (Mapper, error)
    81  
    82  // Controller is responsible for starting, progressing and finishing mapping
    83  // jobs.
    84  //
    85  // It should be treated as a global singleton object. Having more than one
    86  // controller in the production application is a bad idea (they'll collide with
    87  // each other since they use global datastore namespace). It's still useful
    88  // to instantiate multiple controllers in unit tests.
    89  type Controller struct {
    90  	// MapperQueue is a name of the Cloud Tasks queue to use for mapping jobs.
    91  	//
    92  	// This queue will perform all "heavy" tasks. It should be configured
    93  	// appropriately to allow desired number of shards to run in parallel.
    94  	//
    95  	// For example, if the largest submitted job is expected to have 128 shards,
    96  	// max_concurrent_requests setting of the mapper queue should be at least 128,
    97  	// otherwise some shards will be stalled waiting for others to finish
    98  	// (defeating the purpose of having large number of shards).
    99  	//
   100  	// If empty, "default" is used.
   101  	MapperQueue string
   102  
   103  	// ControlQueue is a name of the Cloud Tasks queue to use for control signals.
   104  	//
   105  	// This queue is used very lightly when starting and stopping jobs (roughly
   106  	// 2*Shards tasks overall per job). A default queue.yaml settings for such
   107  	// queue should be sufficient (unless you run a lot of different jobs at
   108  	// once).
   109  	//
   110  	// If empty, "default" is used.
   111  	ControlQueue string
   112  
   113  	m       sync.RWMutex
   114  	mappers map[ID]Factory
   115  	disp    *tq.Dispatcher
   116  }
   117  
   118  // Install registers task queue task handlers in the given task queue
   119  // dispatcher.
   120  //
   121  // This must be done before Controller is used.
   122  //
   123  // There can be at most one Controller installed into an instance of TQ
   124  // dispatcher. Installing more will cause panics.
   125  //
   126  // If you need multiple different controllers for some reason, create multiple
   127  // tq.Dispatchers (with different base URLs, so they don't conflict with each
   128  // other) and install them all into the router.
   129  func (ctl *Controller) Install(disp *tq.Dispatcher) {
   130  	ctl.m.Lock()
   131  	defer ctl.m.Unlock()
   132  
   133  	if ctl.disp != nil {
   134  		panic("mapper.Controller is already installed into a tq.Dispatcher")
   135  	}
   136  	ctl.disp = disp
   137  
   138  	controlQueue := ctl.ControlQueue
   139  	if controlQueue == "" {
   140  		controlQueue = "default"
   141  	}
   142  	mapperQueue := ctl.MapperQueue
   143  	if mapperQueue == "" {
   144  		mapperQueue = "default"
   145  	}
   146  
   147  	disp.RegisterTaskClass(tq.TaskClass{
   148  		ID:        "dsmapper-split-and-launch",
   149  		Prototype: &tasks.SplitAndLaunch{},
   150  		Kind:      tq.Transactional,
   151  		Queue:     controlQueue,
   152  		Handler:   ctl.splitAndLaunchHandler,
   153  		Quiet:     true,
   154  	})
   155  	disp.RegisterTaskClass(tq.TaskClass{
   156  		ID:        "dsmapper-fan-out-shards",
   157  		Prototype: &tasks.FanOutShards{},
   158  		Kind:      tq.Transactional,
   159  		Queue:     controlQueue,
   160  		Handler:   ctl.fanOutShardsHandler,
   161  		Quiet:     true,
   162  	})
   163  	disp.RegisterTaskClass(tq.TaskClass{
   164  		ID:        "dsmapper-process-shard",
   165  		Prototype: &tasks.ProcessShard{},
   166  		Kind:      tq.FollowsContext,
   167  		Queue:     mapperQueue,
   168  		Handler:   ctl.processShardHandler,
   169  		Quiet:     true,
   170  	})
   171  	disp.RegisterTaskClass(tq.TaskClass{
   172  		ID:        "dsmapper-request-job-state-update",
   173  		Prototype: &tasks.RequestJobStateUpdate{},
   174  		Kind:      tq.Transactional,
   175  		Queue:     controlQueue,
   176  		Handler:   ctl.requestJobStateUpdateHandler,
   177  		Quiet:     true,
   178  	})
   179  	disp.RegisterTaskClass(tq.TaskClass{
   180  		ID:        "dsmapper-update-job-state",
   181  		Prototype: &tasks.UpdateJobState{},
   182  		Kind:      tq.NonTransactional,
   183  		Queue:     controlQueue,
   184  		Handler:   ctl.updateJobStateHandler,
   185  		Quiet:     true,
   186  	})
   187  }
   188  
   189  // tq returns a dispatcher set in Install or panics if not set yet.
   190  //
   191  // Grabs the reader lock inside.
   192  func (ctl *Controller) tq() *tq.Dispatcher {
   193  	ctl.m.RLock()
   194  	defer ctl.m.RUnlock()
   195  	if ctl.disp == nil {
   196  		panic("mapper.Controller wasn't installed into tq.Dispatcher yet")
   197  	}
   198  	return ctl.disp
   199  }
   200  
   201  // RegisterFactory adds the given mapper factory to the internal registry.
   202  //
   203  // Intended to be used during init() time or early during the process
   204  // initialization. Panics if a factory with such ID has already been registered.
   205  //
   206  // The mapper ID will be used internally to identify which mapper a job should
   207  // be using. If a factory disappears while the job is running (e.g. if the
   208  // service binary is updated and new binary doesn't have the mapper registered
   209  // anymore), the job ends with a failure.
   210  func (ctl *Controller) RegisterFactory(id ID, m Factory) {
   211  	ctl.m.Lock()
   212  	defer ctl.m.Unlock()
   213  
   214  	if _, ok := ctl.mappers[id]; ok {
   215  		panic(fmt.Sprintf("mapper %q is already registered", id))
   216  	}
   217  
   218  	if ctl.mappers == nil {
   219  		ctl.mappers = make(map[ID]Factory, 1)
   220  	}
   221  	ctl.mappers[id] = m
   222  }
   223  
   224  // getFactory returns a registered mapper factory or an error.
   225  //
   226  // Grabs the reader lock inside. Can return only fatal errors.
   227  func (ctl *Controller) getFactory(id ID) (Factory, error) {
   228  	ctl.m.RLock()
   229  	defer ctl.m.RUnlock()
   230  	if m, ok := ctl.mappers[id]; ok {
   231  		return m, nil
   232  	}
   233  	return nil, errors.Reason("no mapper factory with ID %q registered", id).Err()
   234  }
   235  
   236  // initMapper instantiates a Mapper through a registered factory.
   237  //
   238  // May return fatal and transient errors.
   239  func (ctl *Controller) initMapper(ctx context.Context, j *Job, shardIdx int) (Mapper, error) {
   240  	f, err := ctl.getFactory(j.Config.Mapper)
   241  	if err != nil {
   242  		return nil, errors.Annotate(err, "when initializing mapper").Err()
   243  	}
   244  	m, err := f(ctx, j, shardIdx)
   245  	if err != nil {
   246  		return nil, errors.Annotate(err, "error from mapper factory %q", j.Config.Mapper).Err()
   247  	}
   248  	return m, nil
   249  }
   250  
   251  // LaunchJob launches a new mapping job, returning its ID (that can be used to
   252  // control it or query its status).
   253  //
   254  // Launches a datastore transaction inside.
   255  func (ctl *Controller) LaunchJob(ctx context.Context, j *JobConfig) (JobID, error) {
   256  	disp := ctl.tq()
   257  
   258  	if err := j.Validate(); err != nil {
   259  		return 0, errors.Annotate(err, "bad job config").Err()
   260  	}
   261  	if _, err := ctl.getFactory(j.Mapper); err != nil {
   262  		return 0, errors.Annotate(err, "bad job config").Err()
   263  	}
   264  
   265  	// Prepare and store the job entity, generate its key. Launch a tq task that
   266  	// subdivides the key space and launches individual shards. We do it
   267  	// asynchronously since this can be potentially slow (for large number of
   268  	// shards).
   269  	var job Job
   270  	err := runTxn(ctx, func(ctx context.Context) error {
   271  		now := clock.Now(ctx).UTC()
   272  		job = Job{
   273  			Config:  *j,
   274  			State:   dsmapperpb.State_STARTING,
   275  			Created: now,
   276  			Updated: now,
   277  		}
   278  		if err := datastore.Put(ctx, &job); err != nil {
   279  			return errors.Annotate(err, "failed to store Job entity").Tag(transient.Tag).Err()
   280  		}
   281  		return disp.AddTask(ctx, &tq.Task{
   282  			Title: fmt.Sprintf("split:job-%d", job.ID),
   283  			Payload: &tasks.SplitAndLaunch{
   284  				JobId: int64(job.ID),
   285  			},
   286  		})
   287  	})
   288  	if err != nil {
   289  		return 0, err
   290  	}
   291  	return job.ID, nil
   292  }
   293  
   294  // GetJob fetches a previously launched job given its ID.
   295  //
   296  // Returns ErrNoSuchJob if not found. All other possible errors are transient
   297  // and they are marked as such.
   298  func (ctl *Controller) GetJob(ctx context.Context, id JobID) (*Job, error) {
   299  	// Even though we could have made getJob public, we want to force API users
   300  	// to use Controller as a single facade.
   301  	return getJob(ctx, id)
   302  }
   303  
   304  // AbortJob aborts a job and returns its most recent state.
   305  //
   306  // Silently does nothing if the job is finished or already aborted.
   307  //
   308  // Returns ErrNoSuchJob is there's no such job at all. All other possible errors
   309  // are transient and they are marked as such.
   310  func (ctl *Controller) AbortJob(ctx context.Context, id JobID) (job *Job, err error) {
   311  	err = runTxn(ctx, func(ctx context.Context) error {
   312  		var err error
   313  		switch job, err = getJob(ctx, id); {
   314  		case err != nil:
   315  			return err
   316  		case isFinalState(job.State) || job.State == dsmapperpb.State_ABORTING:
   317  			return nil // nothing to abort, already done
   318  		case job.State == dsmapperpb.State_STARTING:
   319  			// Shards haven't been launched yet. Kill the job right away.
   320  			job.State = dsmapperpb.State_ABORTED
   321  		case job.State == dsmapperpb.State_RUNNING:
   322  			// Running shards will discover that the job is aborting and will
   323  			// eventually move into ABORTED state (notifying the job about it). Once
   324  			// all shards report they are done, the job itself will switch into
   325  			// ABORTED state.
   326  			job.State = dsmapperpb.State_ABORTING
   327  		}
   328  		job.Updated = clock.Now(ctx).UTC()
   329  		return errors.Annotate(datastore.Put(ctx, job), "failed to store Job entity").Tag(transient.Tag).Err()
   330  	})
   331  	if err != nil {
   332  		job = nil // don't return bogus data in case txn failed to land
   333  	}
   334  	return
   335  }
   336  
   337  ////////////////////////////////////////////////////////////////////////////////
   338  // Task queue tasks handlers.
   339  
   340  // errJobAborted is used internally as shard failure status when the job is
   341  // being aborted.
   342  //
   343  // It causes the shard to switch into ABORTED state instead of FAIL.
   344  var errJobAborted = errors.New("the job has been aborted")
   345  
   346  // splitAndLaunchHandler splits the job into shards and enqueues tasks that
   347  // process shards.
   348  func (ctl *Controller) splitAndLaunchHandler(ctx context.Context, payload proto.Message) error {
   349  	msg := payload.(*tasks.SplitAndLaunch)
   350  	now := clock.Now(ctx).UTC()
   351  
   352  	// Fetch job details. Make sure it isn't canceled and isn't running already.
   353  	job, err := getJobInState(ctx, JobID(msg.JobId), dsmapperpb.State_STARTING)
   354  	if err != nil || job == nil {
   355  		return errors.Annotate(err, "in SplitAndLaunch").Err()
   356  	}
   357  
   358  	// Figure out key ranges for shards. There may be fewer shards than requested
   359  	// if there are too few entities.
   360  	dq := job.Config.Query.ToDatastoreQuery()
   361  	ranges, err := splitter.SplitIntoRanges(ctx, dq, splitter.Params{
   362  		Shards:  job.Config.ShardCount,
   363  		Samples: 512, // should be enough for everyone...
   364  	})
   365  	if err != nil {
   366  		return errors.Annotate(err, "failed to split the query into shards").Tag(transient.Tag).Err()
   367  	}
   368  
   369  	// Create entities that hold shards state. Each one is in its own entity
   370  	// group, since the combined write rate to them is O(ShardCount), which can
   371  	// overcome limits of a single entity group.
   372  	shards := make([]*shard, len(ranges))
   373  	for idx, rng := range ranges {
   374  		shards[idx] = &shard{
   375  			JobID:         job.ID,
   376  			Index:         idx,
   377  			State:         dsmapperpb.State_STARTING,
   378  			Range:         rng,
   379  			ExpectedCount: -1,
   380  			Created:       now,
   381  			Updated:       now,
   382  		}
   383  	}
   384  
   385  	// Calculate number of entities in each shard to track shard processing
   386  	// progress. Note that this can be very slow if there are many entities.
   387  	if job.Config.TrackProgress {
   388  		logging.Infof(ctx, "Estimating the size of each shard...")
   389  		if err := fetchShardSizes(ctx, dq, shards); err != nil {
   390  			return errors.Annotate(err, "when estimating shard sizes").Err()
   391  		}
   392  	}
   393  
   394  	// We use auto-generated keys for shards to make sure crashed SplitAndLaunch
   395  	// task retries cleanly, even if the underlying key space we are mapping over
   396  	// changes between the retries (making a naive put using "<job-id>:<index>"
   397  	// key non-idempotent!).
   398  	logging.Infof(ctx, "Instantiating shards...")
   399  	if err := datastore.Put(ctx, shards); err != nil {
   400  		return errors.Annotate(err, "failed to store shards").Tag(transient.Tag).Err()
   401  	}
   402  
   403  	// Prepare shardList which is basically a manual fully consistent index for
   404  	// Job -> [Shard] relation. We can't use a regular index, since shards are all
   405  	// in different entity groups (see O(ShardCount) argument above).
   406  	//
   407  	// Log the resulting shards along the way.
   408  	shardsEnt := shardList{
   409  		Parent: datastore.KeyForObj(ctx, job),
   410  		Shards: make([]int64, len(shards)),
   411  	}
   412  	for idx, s := range shards {
   413  		shardsEnt.Shards[idx] = s.ID
   414  
   415  		l, r := "-inf", "+inf"
   416  		if s.Range.Start != nil {
   417  			l = s.Range.Start.String()
   418  		}
   419  		if s.Range.End != nil {
   420  			r = s.Range.End.String()
   421  		}
   422  		count := ""
   423  		if s.ExpectedCount != 0 {
   424  			count = fmt.Sprintf(" (%d entities)", s.ExpectedCount)
   425  		}
   426  		logging.Infof(ctx, "Shard #%d is %d: %s - %s%s", idx, s.ID, l, r, count)
   427  	}
   428  
   429  	// Transactionally associate shards with the job and launch the TQ task that
   430  	// kicks off the processing of each individual shard. We use an intermediary
   431  	// task for this since transactionally launching O(ShardCount) tasks hits TQ
   432  	// transaction limits.
   433  	//
   434  	// If SplitAndLaunch crashes before this transaction lands, there'll be some
   435  	// orphaned Shard entities, no big deal.
   436  	logging.Infof(ctx, "Updating the job and launching the fan out task...")
   437  	return runTxn(ctx, func(ctx context.Context) error {
   438  		job, err := getJobInState(ctx, JobID(msg.JobId), dsmapperpb.State_STARTING)
   439  		if err != nil || job == nil {
   440  			return errors.Annotate(err, "in SplitAndLaunch txn").Err()
   441  		}
   442  
   443  		job.State = dsmapperpb.State_RUNNING
   444  		job.Updated = now
   445  		if err := datastore.Put(ctx, job, &shardsEnt); err != nil {
   446  			return errors.Annotate(err,
   447  				"when storing Job %d and ShardList with %d shards", job.ID, len(shards),
   448  			).Tag(transient.Tag).Err()
   449  		}
   450  
   451  		return ctl.tq().AddTask(ctx, &tq.Task{
   452  			Title: fmt.Sprintf("fanout:job-%d", job.ID),
   453  			Payload: &tasks.FanOutShards{
   454  				JobId: int64(job.ID),
   455  			},
   456  		})
   457  	})
   458  }
   459  
   460  // fetchShardSizes makes a bunch of Count() queries to figure out size of each
   461  // shard.
   462  //
   463  // Updates ExpectedCount in-place.
   464  func fetchShardSizes(ctx context.Context, baseQ *datastore.Query, shards []*shard) error {
   465  	ctx, cancel := clock.WithTimeout(ctx, 10*time.Minute)
   466  	defer cancel()
   467  
   468  	err := parallel.WorkPool(32, func(tasks chan<- func() error) {
   469  		for _, sh := range shards {
   470  			sh := sh
   471  			tasks <- func() error {
   472  				n, err := datastore.CountBatch(ctx, 1024, sh.Range.Apply(baseQ))
   473  				if err == nil {
   474  					sh.ExpectedCount = n
   475  				}
   476  				return errors.Annotate(err, "for shard #%d", sh.Index).Err()
   477  			}
   478  		}
   479  	})
   480  
   481  	return transient.Tag.Apply(err)
   482  }
   483  
   484  // fanOutShardsHandler fetches a list of shards from the job and launches
   485  // named ProcessShard tasks, one per shard.
   486  func (ctl *Controller) fanOutShardsHandler(ctx context.Context, payload proto.Message) error {
   487  	msg := payload.(*tasks.FanOutShards)
   488  
   489  	// Make sure the job is still present. If it is aborted, we still need to
   490  	// launch the shards, so they notice they are being aborted. We could try
   491  	// to abort all shards right here and now, but it basically means implementing
   492  	// an alternative shard abort flow. Seems simpler just to let the regular flow
   493  	// to proceed.
   494  	job, err := getJobInState(ctx, JobID(msg.JobId), dsmapperpb.State_RUNNING, dsmapperpb.State_ABORTING)
   495  	if err != nil || job == nil {
   496  		return errors.Annotate(err, "in FanOutShards").Err()
   497  	}
   498  
   499  	// Grab the list of shards created in SplitAndLaunch. It must exist at this
   500  	// point, since the job is in Running state.
   501  	shardIDs, err := job.fetchShardIDs(ctx)
   502  	if err != nil {
   503  		return errors.Annotate(err, "in FanOutShards").Err()
   504  	}
   505  
   506  	// Enqueue a bunch of named ProcessShard tasks (one per shard) to actually
   507  	// launch shard processing. This is idempotent operation, so if FanOutShards
   508  	// crashes midway and later retried, nothing bad happens.
   509  	eg, ctx := errgroup.WithContext(ctx)
   510  	tq := ctl.tq()
   511  	for _, sid := range shardIDs {
   512  		task := makeProcessShardTask(job.ID, sid, 0, true)
   513  		eg.Go(func() error { return tq.AddTask(ctx, task) })
   514  	}
   515  	return eg.Wait()
   516  }
   517  
   518  // processShardHandler reads a bunch of entities (up to PageSize), and hands
   519  // them to the mapper.
   520  //
   521  // After doing this in a loop for 1 min, it checkpoints the state and reenqueues
   522  // itself to resume mapping in another instance of the task. This makes each
   523  // processing TQ task relatively small, so it doesn't eat a lot of memory, or
   524  // produces gigantic unreadable logs. It also makes TQ's "Pause queue" button
   525  // more handy.
   526  func (ctl *Controller) processShardHandler(ctx context.Context, payload proto.Message) error {
   527  	msg := payload.(*tasks.ProcessShard)
   528  
   529  	// Grab the shard. This returns (nil, nil) if this Task Queue task is stale
   530  	// (based on taskNum) and should be silently skipped.
   531  	sh, err := getActiveShard(ctx, msg.ShardId, msg.TaskNum)
   532  	if err != nil || sh == nil {
   533  		return errors.Annotate(err, "when fetching shard state").Err()
   534  	}
   535  	ctx = logging.SetField(ctx, "shardIdx", sh.Index)
   536  
   537  	logging.Infof(ctx,
   538  		"Resuming processing of the shard (launched %s ago)",
   539  		clock.Now(ctx).Sub(sh.Created))
   540  
   541  	// Grab the job config, make sure the job is still active.
   542  	job, err := getJobInState(ctx, JobID(msg.JobId), dsmapperpb.State_RUNNING, dsmapperpb.State_ABORTING)
   543  	if err != nil || job == nil {
   544  		return errors.Annotate(err, "in ProcessShard").Err()
   545  	}
   546  
   547  	// If the job is being killed, kill the shard as well. This will eventually
   548  	// notify the job about shard's completion. Once all shards are done, the
   549  	// job will switch into ABORTED state.
   550  	if job.State == dsmapperpb.State_ABORTING {
   551  		return ctl.finishShard(ctx, sh.ID, 0, errJobAborted)
   552  	}
   553  
   554  	// Prepare the mapper by giving the factory job parameters.
   555  	mapper, err := ctl.initMapper(ctx, job, sh.Index)
   556  	switch {
   557  	case transient.Tag.In(err):
   558  		return errors.Annotate(err, "transient error when instantiating a mapper").Err()
   559  	case err != nil:
   560  		// Kill the shard if the factory returns a fatal error.
   561  		return ctl.finishShard(ctx, sh.ID, 0, err)
   562  	}
   563  
   564  	baseQ := job.Config.Query.ToDatastoreQuery()
   565  	lastKey := sh.ResumeFrom
   566  	keys := make([]*datastore.Key, 0, job.Config.PageSize)
   567  
   568  	shardDone := false    // true when finished processing the shard
   569  	pageCount := 0        // how many pages processed successfully
   570  	itemCount := int64(0) // how many entities processed successfully
   571  
   572  	// A soft deadline when to checkpoint the progress and reenqueue the
   573  	// processing task. We never abort processing of a page midway (causes too
   574  	// many complications), so if the mapper is extremely slow, it may end up
   575  	// running longer than this deadline.
   576  	dur := time.Minute
   577  	if job.Config.TaskDuration > 0 {
   578  		dur = job.Config.TaskDuration
   579  	}
   580  	deadline := clock.Now(ctx).Add(dur)
   581  
   582  	// Optionally also put a limit on number of processed pages. Useful if the
   583  	// mapper is somehow leaking resources (not sure it is possible in Go, but
   584  	// it was definitely possible in Python).
   585  	pageCountLimit := math.MaxInt32
   586  	if job.Config.PagesPerTask > 0 {
   587  		pageCountLimit = job.Config.PagesPerTask
   588  	}
   589  
   590  	for clock.Now(ctx).Before(deadline) && pageCount < pageCountLimit {
   591  		rng := sh.Range
   592  		if lastKey != nil {
   593  			rng.Start = lastKey
   594  		}
   595  		if rng.IsEmpty() {
   596  			shardDone = true
   597  			break
   598  		}
   599  
   600  		// Fetch next batch of keys. Return an error to the outer scope where it
   601  		// eventually will bubble up to TQ (so the task is retried with exponential
   602  		// backoff).
   603  		logging.Infof(ctx, "Fetching the next batch...")
   604  		q := rng.Apply(baseQ).Limit(int32(job.Config.PageSize)).KeysOnly(true)
   605  		keys = keys[:0]
   606  		if err = datastore.GetAll(ctx, q, &keys); err != nil {
   607  			err = errors.Annotate(err, "when querying for keys").Tag(transient.Tag).Err()
   608  			break
   609  		}
   610  
   611  		// No results within the range? Processing of the shard is complete!
   612  		if len(keys) == 0 {
   613  			shardDone = true
   614  			break
   615  		}
   616  
   617  		// Let the mapper do its thing. Remember where to resume from.
   618  		logging.Infof(ctx,
   619  			"Processing %d entities: %s - %s",
   620  			len(keys),
   621  			keys[0].String(),
   622  			keys[len(keys)-1].String())
   623  		if err = mapper(ctx, keys); err != nil {
   624  			err = errors.Annotate(err, "while mapping %d keys", len(keys)).Err()
   625  			break
   626  		}
   627  		lastKey = keys[len(keys)-1]
   628  		pageCount++
   629  		itemCount += int64(len(keys))
   630  
   631  		// Note: at this point we might try to checkpoint the progress, but we must
   632  		// be careful not to exceed 1 transaction per second limit. Considering we
   633  		// also MUST checkpoint the progress at the end of the task, it is a bit
   634  		// tricky to guarantee no two checkpoints are closer than 1 sec. We can do
   635  		// silly things like sleep 1 sec before the last checkpoint, but they
   636  		// provide no guarantees.
   637  		//
   638  		// So instead we store the progress after the deadline is up. If the task
   639  		// crashes midway, up to 1 min of work will be retried. No big deal.
   640  	}
   641  
   642  	// We are done with the shard when either processed all its range or failed
   643  	// with a fatal error. finishShard would take care of notifying the parent
   644  	// job about the shard's completion.
   645  	if shardDone || (err != nil && !transient.Tag.In(err)) {
   646  		return ctl.finishShard(ctx, sh.ID, itemCount, err)
   647  	}
   648  
   649  	if lastKey != nil {
   650  		logging.Infof(ctx, "The shard processing will resume from %s", lastKey)
   651  	} else {
   652  		logging.Infof(ctx, "The shard processing will resume from scratch")
   653  	}
   654  
   655  	// If the shard isn't done and we made no progress at all, then we hit
   656  	// a transient error. Ask TQ to retry.
   657  	if pageCount == 0 {
   658  		return err
   659  	}
   660  
   661  	// Otherwise need to checkpoint the progress and either to retry this task
   662  	// (on transient errors, to get an exponential backoff from TQ), or start
   663  	// a new task.
   664  	txnErr := shardTxn(ctx, sh.ID, func(ctx context.Context, sh *shard) (bool, error) {
   665  		switch {
   666  		case sh.ProcessTaskNum != msg.TaskNum:
   667  			logging.Warningf(ctx, "Unexpected shard state: its ProcessTaskNum is %d != %d", sh.ProcessTaskNum, msg.TaskNum)
   668  			return false, nil // some other task is already running
   669  		case sh.ResumeFrom != nil && lastKey.Less(sh.ResumeFrom):
   670  			logging.Warningf(ctx, "Unexpected shard state: its ResumeFrom is %s >= %s", sh.ResumeFrom, lastKey)
   671  			return false, nil // someone already claimed to process further, let them proceed
   672  		}
   673  
   674  		sh.State = dsmapperpb.State_RUNNING
   675  		sh.ResumeFrom = lastKey
   676  		sh.ProcessedCount += itemCount
   677  
   678  		// If the processing failed, just store the progress, but do not start a
   679  		// new TQ task. Retry the current task instead (to get exponential backoff).
   680  		if err != nil {
   681  			return true, nil
   682  		}
   683  
   684  		// Otherwise launch a new task in the chain. This essentially "resets"
   685  		// the exponential backoff counter.
   686  		sh.ProcessTaskNum++
   687  		return true, ctl.tq().AddTask(ctx,
   688  			makeProcessShardTask(sh.JobID, sh.ID, sh.ProcessTaskNum, false))
   689  	})
   690  
   691  	switch {
   692  	case err != nil && txnErr == nil:
   693  		return err
   694  	case err == nil && txnErr != nil:
   695  		return errors.Annotate(txnErr, "when storing shard progress").Err()
   696  	case err != nil && txnErr != nil:
   697  		return errors.Annotate(txnErr, "when storing shard progress after a transient error (%s)", err).Err()
   698  	default: // (nil, nil)
   699  		return nil
   700  	}
   701  }
   702  
   703  // finishShard marks the shard as finished (with status based on shardErr) and
   704  // emits a task to update the parent job's status.
   705  func (ctl *Controller) finishShard(ctx context.Context, shardID, processedCount int64, shardErr error) error {
   706  	err := shardTxn(ctx, shardID, func(ctx context.Context, sh *shard) (save bool, err error) {
   707  		runtime := clock.Now(ctx).Sub(sh.Created)
   708  		switch {
   709  		case shardErr == errJobAborted:
   710  			logging.Warningf(ctx, "The job has been aborted, aborting the shard after it has been running %s", runtime)
   711  			sh.State = dsmapperpb.State_ABORTED
   712  			sh.Error = errJobAborted.Error()
   713  		case shardErr != nil:
   714  			logging.Errorf(ctx, "The shard processing failed in %s with error: %s", runtime, shardErr)
   715  			sh.State = dsmapperpb.State_FAIL
   716  			sh.Error = shardErr.Error()
   717  		default:
   718  			logging.Infof(ctx, "The shard processing finished successfully in %s", runtime)
   719  			sh.State = dsmapperpb.State_SUCCESS
   720  		}
   721  		sh.ProcessedCount += processedCount
   722  		return true, ctl.requestJobStateUpdate(ctx, sh.JobID, sh.ID)
   723  	})
   724  	return errors.Annotate(err, "when marking the shard as finished").Err()
   725  }
   726  
   727  // makeProcessShardTask creates a ProcessShard tq.Task.
   728  //
   729  // If 'named' is true, assigns it a name. Tasks are named based on their shard
   730  // IDs and an index in the chain of ProcessShard tasks (task number), so that
   731  // on retries we don't rekick already finished tasks.
   732  func makeProcessShardTask(job JobID, shardID, taskNum int64, named bool) *tq.Task {
   733  	// Note: strictly speaking including job ID in the task name is redundant,
   734  	// since shardID is already globally unique, but it doesn't hurt. Useful for
   735  	// debugging and when looking at logs and pending TQ tasks.
   736  	t := &tq.Task{
   737  		Title: fmt.Sprintf("map:job-%d-shard-%d-task-%d", job, shardID, taskNum),
   738  		Payload: &tasks.ProcessShard{
   739  			JobId:   int64(job),
   740  			ShardId: shardID,
   741  			TaskNum: taskNum,
   742  		},
   743  	}
   744  	if named {
   745  		t.DeduplicationKey = fmt.Sprintf("v1-%d-%d-%d", job, shardID, taskNum)
   746  	}
   747  	return t
   748  }
   749  
   750  // requestJobStateUpdate submits RequestJobStateUpdate task, which eventually
   751  // causes updateJobStateHandler to execute.
   752  func (ctl *Controller) requestJobStateUpdate(ctx context.Context, jobID JobID, shardID int64) error {
   753  	return ctl.tq().AddTask(ctx, &tq.Task{
   754  		Title: fmt.Sprintf("notify:job-%d-shard-%d", jobID, shardID),
   755  		Payload: &tasks.RequestJobStateUpdate{
   756  			JobId:   int64(jobID),
   757  			ShardId: shardID,
   758  		},
   759  	})
   760  }
   761  
   762  // requestJobStateUpdateHandler is called whenever state of some shard changes.
   763  //
   764  // It forwards this notification to the job (specifically updateJobStateHandler)
   765  // throttling the rate to ~0.5 QPS to avoid overwhelming job's entity group with
   766  // high write rate.
   767  func (ctl *Controller) requestJobStateUpdateHandler(ctx context.Context, payload proto.Message) error {
   768  	msg := payload.(*tasks.RequestJobStateUpdate)
   769  
   770  	// Throttle to once per 2 sec (and make sure it is always in the future). We
   771  	// rely here on a pretty good (< .5s maximum skew) clock sync on servers.
   772  	eta := clock.Now(ctx).Unix()
   773  	eta = (eta/2 + 1) * 2
   774  	dedupKey := fmt.Sprintf("update-job-state-v1:%d:%d", msg.JobId, eta)
   775  
   776  	err := ctl.tq().AddTask(ctx, &tq.Task{
   777  		DeduplicationKey: dedupKey,
   778  		Title:            fmt.Sprintf("update:job-%d", msg.JobId),
   779  		ETA:              time.Unix(eta, 0),
   780  		Payload:          &tasks.UpdateJobState{JobId: msg.JobId},
   781  	})
   782  	return errors.Annotate(err, "when adding UpdateJobState task").Err()
   783  }
   784  
   785  // updateJobStateHandler is called some time later after one or more shards have
   786  // changed state.
   787  //
   788  // It calculates overall job state based on the state of its shards.
   789  func (ctl *Controller) updateJobStateHandler(ctx context.Context, payload proto.Message) error {
   790  	msg := payload.(*tasks.UpdateJobState)
   791  
   792  	// Get the job and all its shards in their most recent state.
   793  	job, err := getJobInState(ctx, JobID(msg.JobId), dsmapperpb.State_RUNNING, dsmapperpb.State_ABORTING)
   794  	if err != nil || job == nil {
   795  		return errors.Annotate(err, "in UpdateJobState").Err()
   796  	}
   797  	shards, err := job.fetchShards(ctx)
   798  	if err != nil {
   799  		return errors.Annotate(err, "failed to fetch shards").Err()
   800  	}
   801  
   802  	// Switch the job into a final state only when all shards are done running.
   803  	perState := make(map[dsmapperpb.State]int, len(dsmapperpb.State_name))
   804  	finished := 0
   805  	for _, sh := range shards {
   806  		logging.Infof(ctx, "Shard #%d (%d) is in state %s", sh.Index, sh.ID, sh.State)
   807  		perState[sh.State]++
   808  		if isFinalState(sh.State) {
   809  			finished++
   810  		}
   811  	}
   812  	if finished != len(shards) {
   813  		return nil
   814  	}
   815  
   816  	jobState := dsmapperpb.State_SUCCESS
   817  	switch {
   818  	case perState[dsmapperpb.State_ABORTED] != 0:
   819  		jobState = dsmapperpb.State_ABORTED
   820  	case perState[dsmapperpb.State_FAIL] != 0:
   821  		jobState = dsmapperpb.State_FAIL
   822  	}
   823  
   824  	return runTxn(ctx, func(ctx context.Context) error {
   825  		job, err := getJobInState(ctx, JobID(msg.JobId), dsmapperpb.State_RUNNING, dsmapperpb.State_ABORTING)
   826  		if err != nil || job == nil {
   827  			return errors.Annotate(err, "in UpdateJobState txn").Err()
   828  		}
   829  
   830  		// Make sure an aborting job ends up in aborted state, even if all its
   831  		// shards manged to finish. It looks weird when an ABORTING job moves
   832  		// into e.g. SUCCESS state.
   833  		if job.State == dsmapperpb.State_ABORTING {
   834  			job.State = dsmapperpb.State_ABORTED
   835  		} else {
   836  			job.State = jobState
   837  		}
   838  		job.Updated = clock.Now(ctx).UTC()
   839  
   840  		runtime := job.Updated.Sub(job.Created)
   841  		switch job.State {
   842  		case dsmapperpb.State_SUCCESS:
   843  			logging.Infof(ctx, "The job finished successfully in %s", runtime)
   844  		case dsmapperpb.State_FAIL:
   845  			logging.Errorf(ctx, "The job finished with %d shards failing in %s", perState[dsmapperpb.State_FAIL], runtime)
   846  			for _, sh := range shards {
   847  				if sh.State == dsmapperpb.State_FAIL {
   848  					logging.Errorf(ctx, "Shard #%d (%d) error - %s", sh.Index, sh.ID, sh.Error)
   849  				}
   850  			}
   851  		case dsmapperpb.State_ABORTED:
   852  			logging.Warningf(ctx, "The job has been aborted after %s: %d shards succeeded, %d shards failed, %d shards aborted",
   853  				runtime, perState[dsmapperpb.State_SUCCESS], perState[dsmapperpb.State_FAIL], perState[dsmapperpb.State_ABORTED])
   854  		}
   855  
   856  		return transient.Tag.Apply(datastore.Put(ctx, job))
   857  	})
   858  }