github.com/pachyderm/pachyderm@v1.13.4/src/server/worker/pipeline/transform/registry.go (about)

     1  package transform
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"fmt"
     7  	"io"
     8  	"math"
     9  	"path"
    10  	"strings"
    11  	"sync"
    12  	"time"
    13  
    14  	"github.com/gogo/protobuf/types"
    15  	"golang.org/x/sync/errgroup"
    16  
    17  	"github.com/pachyderm/pachyderm/src/client"
    18  	"github.com/pachyderm/pachyderm/src/client/limit"
    19  	"github.com/pachyderm/pachyderm/src/client/pfs"
    20  	"github.com/pachyderm/pachyderm/src/client/pkg/errors"
    21  	"github.com/pachyderm/pachyderm/src/client/pkg/pbutil"
    22  	"github.com/pachyderm/pachyderm/src/client/pps"
    23  	pfsserver "github.com/pachyderm/pachyderm/src/server/pfs"
    24  	"github.com/pachyderm/pachyderm/src/server/pkg/backoff"
    25  	col "github.com/pachyderm/pachyderm/src/server/pkg/collection"
    26  	"github.com/pachyderm/pachyderm/src/server/pkg/errutil"
    27  	"github.com/pachyderm/pachyderm/src/server/pkg/ppsutil"
    28  	"github.com/pachyderm/pachyderm/src/server/pkg/uuid"
    29  	"github.com/pachyderm/pachyderm/src/server/pkg/work"
    30  	ppsserver "github.com/pachyderm/pachyderm/src/server/pps"
    31  	"github.com/pachyderm/pachyderm/src/server/worker/common"
    32  	"github.com/pachyderm/pachyderm/src/server/worker/datum"
    33  	"github.com/pachyderm/pachyderm/src/server/worker/driver"
    34  	"github.com/pachyderm/pachyderm/src/server/worker/logs"
    35  	"github.com/pachyderm/pachyderm/src/server/worker/pipeline/transform/chain"
    36  )
    37  
const (
	// taskGranularity is the target number of subtasks per worker slot:
	// sendDatumTasks splits a job's datums into roughly
	// concurrency*taskGranularity subtasks so work balances across workers.
	taskGranularity = 4
)
    41  
    42  func jobArtifactPrefix(jobID string) string {
    43  	return path.Join("artifacts", fmt.Sprintf("job-%s", jobID))
    44  }
    45  
    46  func jobArtifactChunkDatumList(jobID string, subtaskID string) string {
    47  	return path.Join(jobArtifactPrefix(jobID), fmt.Sprintf("chunk-datum-list-%s", subtaskID))
    48  }
    49  
    50  func jobArtifactHashtrees(jobID string) string {
    51  	return path.Join(jobArtifactPrefix(jobID), "hashtrees")
    52  }
    53  
    54  func jobArtifactRecoveredObject(jobID string) string {
    55  	return path.Join(jobArtifactPrefix(jobID), "recovered")
    56  }
    57  
// pendingJob carries all per-job state while a job is being supervised and
// processed by this worker's master routine.
type pendingJob struct {
	driver          driver.Driver      // job-scoped driver (bound to the job's context)
	commitInfo      *pfs.CommitInfo    // the job's output commit
	statsCommitInfo *pfs.CommitInfo    // the job's stats commit, nil if stats are disabled
	cancel          context.CancelFunc // cancels the job's context
	logger          logs.TaggedLogger  // logger tagged with the job ID
	ji              *pps.JobInfo       // latest known job state
	jdit            chain.JobDatumIterator
	taskMaster      *work.Master // set once the task queue grants this job a master

	// These are filled in when the RUNNING phase completes, but may be re-fetched
	// from object storage.
	chunkHashtrees []*HashtreeInfo
	statsHashtrees []*HashtreeInfo
}
    73  
// registry tracks the jobs currently being handled by this worker's master,
// bounding how many run at once and chaining them so datums can be skipped.
type registry struct {
	driver      driver.Driver
	logger      logs.TaggedLogger
	taskQueue   *work.TaskQueue
	concurrency int64                    // expected number of workers, used to size tasks and the limiter
	limiter     limit.ConcurrencyLimiter // bounds concurrent pending jobs to `concurrency`
	jobChain    chain.JobChain           // lazily created by initializeJobChain
}
    82  
// hasher computes datum hashes from the pipeline name and salt; the job chain
// uses these hashes to decide which datums can be skipped between jobs.
type hasher struct {
	name string // pipeline name
	salt string // pipeline salt
}
    87  
    88  func (h *hasher) Hash(inputs []*common.Input) string {
    89  	return common.HashDatum(h.name, h.salt, inputs)
    90  }
    91  
    92  // Returns the registry or lazily instantiates it
    93  func newRegistry(
    94  	logger logs.TaggedLogger,
    95  	driver driver.Driver,
    96  ) (*registry, error) {
    97  	// Determine the maximum number of concurrent tasks we will allow
    98  	concurrency, err := driver.ExpectedNumWorkers()
    99  	if err != nil {
   100  		return nil, err
   101  	}
   102  
   103  	taskQueue, err := driver.NewTaskQueue()
   104  	if err != nil {
   105  		return nil, err
   106  	}
   107  
   108  	return &registry{
   109  		driver:      driver,
   110  		logger:      logger,
   111  		concurrency: concurrency,
   112  		taskQueue:   taskQueue,
   113  		limiter:     limit.New(int(concurrency)),
   114  		jobChain:    nil,
   115  	}, nil
   116  }
   117  
// Helper function for succeedJob and failJob, do not use directly
//
// finishJob moves a job to its terminal (or egressing) state and closes its
// output and stats commits in a single transaction. If the transaction fails
// with an error suggesting the commits/job were already modified elsewhere,
// it retries the same operations non-transactionally via recoverFinishedJob.
func finishJob(
	pipelineInfo *pps.PipelineInfo,
	pachClient *client.APIClient,
	jobInfo *pps.JobInfo,
	state pps.JobState,
	reason string,
	datums *pfs.Object,
	trees []*pfs.Object,
	size uint64,
	statsTrees []*pfs.Object,
	statsSize uint64,
) error {
	// Optimistically update the local state and reason - if any errors occur the
	// local state will be reloaded way up the stack
	jobInfo.State = state
	jobInfo.Reason = reason

	if _, err := pachClient.RunBatchInTransaction(func(builder *client.TransactionBuilder) error {
		if pipelineInfo.S3Out {
			// S3-out pipelines close the output commit without hashtrees.
			if err := builder.FinishCommit(jobInfo.OutputCommit.Repo.Name, jobInfo.OutputCommit.ID); err != nil {
				return err
			}
		} else {
			if jobInfo.StatsCommit != nil {
				// A nil statsTrees marks the stats commit as empty.
				if _, err := builder.PfsAPIClient.FinishCommit(pachClient.Ctx(), &pfs.FinishCommitRequest{
					Commit:    jobInfo.StatsCommit,
					Empty:     statsTrees == nil,
					Trees:     statsTrees,
					SizeBytes: statsSize,
				}); err != nil {
					return err
				}
			}

			// A nil trees marks the output commit as empty (failed/killed jobs).
			if _, err := builder.PfsAPIClient.FinishCommit(pachClient.Ctx(), &pfs.FinishCommitRequest{
				Commit:    jobInfo.OutputCommit,
				Empty:     trees == nil,
				Datums:    datums,
				Trees:     trees,
				SizeBytes: size,
			}); err != nil {
				return err
			}
		}

		return writeJobInfo(&builder.APIClient, jobInfo)
	}); err != nil {
		if pfsserver.IsCommitFinishedErr(err) || pfsserver.IsCommitNotFoundErr(err) || pfsserver.IsCommitDeletedErr(err) || ppsserver.IsJobFinishedErr(err) {
			// For certain types of errors, we want to reattempt these operations
			// outside of a transaction (in case the job or commits were affected by
			// some non-transactional code elsewhere, we can attempt to recover)
			return recoverFinishedJob(pipelineInfo, pachClient, jobInfo, state, reason, datums, trees, size, statsTrees, statsSize)
		}
		// For other types of errors, we want to fail the job supervision and let it
		// reattempt later
		return err
	}
	return nil
}
   178  
// recoverFinishedJob performs job and output commit updates outside of a
// transaction in an attempt to get everything in a consistent state if they
// were modified non-transactionally elsewhere.
//
// Unlike finishJob, "already finished/deleted" style errors from each step are
// tolerated so that partially-completed prior attempts converge rather than
// fail.
func recoverFinishedJob(
	pipelineInfo *pps.PipelineInfo,
	pachClient *client.APIClient,
	jobInfo *pps.JobInfo,
	state pps.JobState,
	reason string,
	datums *pfs.Object,
	trees []*pfs.Object,
	size uint64,
	statsTrees []*pfs.Object,
	statsSize uint64,
) error {
	if pipelineInfo.S3Out {
		if err := pachClient.FinishCommit(jobInfo.OutputCommit.Repo.Name, jobInfo.OutputCommit.ID); err != nil {
			return err
		}
	} else {
		if jobInfo.StatsCommit != nil {
			if _, err := pachClient.PfsAPIClient.FinishCommit(pachClient.Ctx(), &pfs.FinishCommitRequest{
				Commit:    jobInfo.StatsCommit,
				Empty:     statsTrees == nil,
				Trees:     statsTrees,
				SizeBytes: statsSize,
			}); err != nil {
				// Ignore errors indicating the commit is already settled.
				if !pfsserver.IsCommitFinishedErr(err) && !pfsserver.IsCommitNotFoundErr(err) && !pfsserver.IsCommitDeletedErr(err) {
					return err
				}
			}
		}

		if _, err := pachClient.PfsAPIClient.FinishCommit(pachClient.Ctx(), &pfs.FinishCommitRequest{
			Commit:    jobInfo.OutputCommit,
			Empty:     trees == nil,
			Datums:    datums,
			Trees:     trees,
			SizeBytes: size,
		}); err != nil {
			// Ignore errors indicating the commit is already settled.
			if !pfsserver.IsCommitFinishedErr(err) && !pfsserver.IsCommitNotFoundErr(err) && !pfsserver.IsCommitDeletedErr(err) {
				return err
			}
		}
	}

	// A job that is already finished is acceptable here; anything else is fatal.
	if err := writeJobInfo(pachClient, jobInfo); err != nil {
		if !ppsserver.IsJobFinishedErr(err) {
			return err
		}
	}
	return nil
}
   232  
   233  // succeedJob will move a job to the successful state and propagate to any
   234  // dependent jobs in the jobChain.
   235  func (reg *registry) succeedJob(
   236  	pj *pendingJob,
   237  	trees []*pfs.Object,
   238  	size uint64,
   239  	statsTrees []*pfs.Object,
   240  	statsSize uint64,
   241  ) error {
   242  	datums, err := reg.storeJobDatums(pj)
   243  	if err != nil {
   244  		return err
   245  	}
   246  
   247  	var newState pps.JobState
   248  	if pj.ji.Egress == nil {
   249  		pj.logger.Logf("job successful, closing commits")
   250  		newState = pps.JobState_JOB_SUCCESS
   251  	} else {
   252  		pj.logger.Logf("job successful, advancing to egress")
   253  		newState = pps.JobState_JOB_EGRESSING
   254  	}
   255  
   256  	// Use the registry's driver so that the job's supervision goroutine cannot cancel us
   257  	if err := finishJob(reg.driver.PipelineInfo(), reg.driver.PachClient(), pj.ji, newState, "", datums, trees, size, statsTrees, statsSize); err != nil {
   258  		return err
   259  	}
   260  
   261  	if err := reg.cleanJobArtifacts(pj.ji.Job); err != nil {
   262  		return err
   263  	}
   264  
   265  	return reg.jobChain.Succeed(pj)
   266  }
   267  
   268  func (reg *registry) failJob(
   269  	pj *pendingJob,
   270  	reason string,
   271  	statsTrees []*pfs.Object,
   272  	statsSize uint64,
   273  ) error {
   274  	pj.logger.Logf("failing job with reason: %s", reason)
   275  
   276  	// Use the registry's driver so that the job's supervision goroutine cannot cancel us
   277  	if err := finishJob(reg.driver.PipelineInfo(), reg.driver.PachClient(), pj.ji, pps.JobState_JOB_FAILURE, reason, nil, nil, 0, statsTrees, statsSize); err != nil {
   278  		return err
   279  	}
   280  
   281  	if err := reg.cleanJobArtifacts(pj.ji.Job); err != nil {
   282  		return err
   283  	}
   284  
   285  	// Disregard job chain errors when failing the job - in case of egress, the
   286  	// pending job should already have been removed from the chain.
   287  	reg.jobChain.Fail(pj)
   288  	return nil
   289  }
   290  
   291  func (reg *registry) killJob(
   292  	pj *pendingJob,
   293  	reason string,
   294  ) error {
   295  	pj.logger.Logf("killing job with reason: %s", reason)
   296  
   297  	// Use the registry's driver so that the job's supervision goroutine cannot cancel us
   298  	if err := finishJob(reg.driver.PipelineInfo(), reg.driver.PachClient(), pj.ji, pps.JobState_JOB_KILLED, reason, nil, nil, 0, nil, 0); err != nil {
   299  		return err
   300  	}
   301  
   302  	if err := reg.cleanJobArtifacts(pj.ji.Job); err != nil {
   303  		return err
   304  	}
   305  
   306  	return reg.jobChain.Fail(pj)
   307  }
   308  
   309  func (reg *registry) cleanJobArtifacts(job *pps.Job) error {
   310  	reg.logger.WithJob(job.ID).Logf("Cleaning job artifacts")
   311  	prefix := jobArtifactPrefix(job.ID)
   312  	_, err := reg.driver.PachClient().DeleteObjDirect(
   313  		reg.driver.PachClient().Ctx(),
   314  		&pfs.DeleteObjDirectRequest{Prefix: prefix},
   315  	)
   316  	return err
   317  }
   318  
   319  func writeJobInfo(pachClient *client.APIClient, jobInfo *pps.JobInfo) error {
   320  	_, err := pachClient.PpsAPIClient.UpdateJobState(pachClient.Ctx(), &pps.UpdateJobStateRequest{
   321  		Job:           jobInfo.Job,
   322  		State:         jobInfo.State,
   323  		Reason:        jobInfo.Reason,
   324  		Restart:       jobInfo.Restart,
   325  		DataProcessed: jobInfo.DataProcessed,
   326  		DataSkipped:   jobInfo.DataSkipped,
   327  		DataTotal:     jobInfo.DataTotal,
   328  		DataFailed:    jobInfo.DataFailed,
   329  		DataRecovered: jobInfo.DataRecovered,
   330  		Stats:         jobInfo.Stats,
   331  	})
   332  	return err
   333  }
   334  
   335  func (pj *pendingJob) writeJobInfo() error {
   336  	pj.logger.Logf("updating job info, state: %s", pj.ji.State)
   337  	return writeJobInfo(pj.driver.PachClient(), pj.ji)
   338  }
   339  
   340  func (reg *registry) initializeJobChain(commitInfo *pfs.CommitInfo) error {
   341  	if reg.jobChain == nil {
   342  		// Get the most recent successful commit starting from the given commit
   343  		parentCommitInfo, err := reg.getParentCommitInfo(commitInfo)
   344  		if err != nil {
   345  			return err
   346  		}
   347  
   348  		var baseDatums chain.DatumSet
   349  		if parentCommitInfo != nil {
   350  			baseDatums, err = reg.getDatumSet(parentCommitInfo.Datums)
   351  			if err != nil {
   352  				return err
   353  			}
   354  		} else {
   355  			baseDatums = make(chain.DatumSet)
   356  		}
   357  
   358  		if reg.driver.PipelineInfo().ReprocessSpec == client.ReprocessSpecEveryJob || reg.driver.PipelineInfo().S3Out {
   359  			// When running a pipeline with S3Out (or with skipping disabled), we need
   360  			// to yield every datum for every job, use a no-skip job chain for this.
   361  			reg.jobChain = chain.NewNoSkipJobChain(
   362  				&hasher{
   363  					name: reg.driver.PipelineInfo().Pipeline.Name,
   364  					salt: reg.driver.PipelineInfo().Salt,
   365  				},
   366  			)
   367  		} else {
   368  			reg.jobChain = chain.NewJobChain(
   369  				&hasher{
   370  					name: reg.driver.PipelineInfo().Pipeline.Name,
   371  					salt: reg.driver.PipelineInfo().Salt,
   372  				},
   373  				baseDatums,
   374  			)
   375  		}
   376  	}
   377  
   378  	return nil
   379  }
   380  
// Generate a datum task (and split it up into subtasks) for the added datums
// in the pending job.
//
// Subtasks are sized so there are roughly reg.concurrency*taskGranularity of
// them, bounded by the job's ChunkSpec when one is set. Each subtask's datum
// list is written to object storage as a job artifact; the work.Task sent on
// `subtasks` references that object by name.
func (reg *registry) sendDatumTasks(ctx context.Context, pj *pendingJob, numDatums int64, subtasks chan<- *work.Task) error {
	chunkSpec := pj.ji.ChunkSpec
	if chunkSpec == nil {
		chunkSpec = &pps.ChunkSpec{}
	}

	maxDatumsPerTask := int64(chunkSpec.Number)
	maxBytesPerTask := int64(chunkSpec.SizeBytes)
	driver := pj.driver.WithContext(ctx)
	var numTasks int64
	if numDatums < reg.concurrency*taskGranularity {
		// Fewer datums than the concurrency target: one datum per task.
		numTasks = numDatums
	} else if maxDatumsPerTask > 0 && numDatums/maxDatumsPerTask > reg.concurrency*taskGranularity {
		// The ChunkSpec datum cap forces more tasks than the concurrency target.
		numTasks = numDatums / maxDatumsPerTask
	} else {
		numTasks = reg.concurrency * taskGranularity
	}
	datumsPerTask := int64(math.Ceil(float64(numDatums) / float64(numTasks)))

	datumsSize := int64(0)
	datums := []*DatumInputs{}

	// writeDatumsObject is a helper function to synchronously write out the
	// current list of datums to a temporary job artifact object.
	writeDatumsObject := func(objectName string) (retErr error) {
		writer, err := driver.PachClient().DirectObjWriter(objectName)
		if err != nil {
			return errors.EnsureStack(err)
		}
		defer func() {
			// Report a Close failure only if no earlier error is in flight.
			if err := writer.Close(); err != nil && retErr == nil {
				retErr = errors.EnsureStack(err)
			}
		}()

		protoWriter := pbutil.NewWriter(writer)
		_, err = protoWriter.Write(&DatumInputsList{Datums: datums})
		return err
	}

	// finishTask will finish the currently-writing object and append it to the
	// subtasks, then reset all the relevant variables
	finishTask := func() error {
		subtaskID := uuid.NewWithoutDashes()
		objectName := jobArtifactChunkDatumList(pj.ji.Job.ID, subtaskID)

		if err := writeDatumsObject(objectName); err != nil {
			return err
		}

		taskData, err := serializeDatumData(&DatumData{DatumsObject: objectName, OutputCommit: pj.ji.OutputCommit, JobID: pj.ji.Job.ID})
		if err != nil {
			return err
		}

		// Sending may block; bail out if the job's context is canceled first.
		select {
		case subtasks <- &work.Task{ID: subtaskID, Data: taskData}:
		case <-ctx.Done():
			return ctx.Err()
		}

		datumsSize = 0
		datums = []*DatumInputs{}
		return nil
	}

	// Build up chunks to be put into work tasks from the datum iterator
	for i := int64(0); i < numDatums; i++ {
		inputs, index := pj.jdit.NextDatum()
		if inputs == nil {
			return errors.New("job datum iterator returned nil inputs")
		}

		datums = append(datums, &DatumInputs{Inputs: inputs, Index: index})

		// If we have enough input bytes, finish the task
		if maxBytesPerTask != 0 {
			for _, input := range inputs {
				datumsSize += int64(input.FileInfo.SizeBytes)
			}
			if datumsSize >= maxBytesPerTask {
				if err := finishTask(); err != nil {
					return err
				}
			}
		}

		// If we hit the upper threshold for task size, finish the task
		if int64(len(datums)) >= datumsPerTask {
			if err := finishTask(); err != nil {
				return err
			}
		}
	}

	// Flush any remaining datums into a final (partial) task.
	if len(datums) > 0 {
		if err := finishTask(); err != nil {
			return err
		}
	}

	return nil
}
   486  
   487  func serializeDatumData(data *DatumData) (*types.Any, error) {
   488  	serialized, err := types.MarshalAny(data)
   489  	if err != nil {
   490  		return nil, err
   491  	}
   492  	return serialized, nil
   493  }
   494  
   495  func deserializeDatumData(any *types.Any) (*DatumData, error) {
   496  	data := &DatumData{}
   497  	if err := types.UnmarshalAny(any, data); err != nil {
   498  		return nil, err
   499  	}
   500  	return data, nil
   501  }
   502  
   503  func serializeMergeData(data *MergeData) (*types.Any, error) {
   504  	serialized, err := types.MarshalAny(data)
   505  	if err != nil {
   506  		return nil, err
   507  	}
   508  	return serialized, nil
   509  }
   510  
   511  func deserializeMergeData(any *types.Any) (*MergeData, error) {
   512  	data := &MergeData{}
   513  	if err := types.UnmarshalAny(any, data); err != nil {
   514  		return nil, err
   515  	}
   516  	return data, nil
   517  }
   518  
   519  func (reg *registry) getDatumSet(datumsObj *pfs.Object) (_ chain.DatumSet, retErr error) {
   520  	pachClient := reg.driver.PachClient()
   521  	if datumsObj == nil {
   522  		return nil, nil
   523  	}
   524  	r, err := pachClient.GetObjectReader(datumsObj.Hash)
   525  	if err != nil {
   526  		return nil, err
   527  	}
   528  	defer func() {
   529  		if err := r.Close(); err != nil && retErr != nil {
   530  			retErr = err
   531  		}
   532  	}()
   533  	pbr := pbutil.NewReader(r)
   534  	datums := make(chain.DatumSet)
   535  	for {
   536  		k, err := pbr.ReadBytes()
   537  		if err != nil {
   538  			if errors.Is(err, io.EOF) {
   539  				return datums, retErr
   540  			}
   541  			return nil, err
   542  		}
   543  		datums[string(k)]++
   544  	}
   545  }
   546  
// Walk from the given commit back to a successfully completed commit so we can
// get the initial state of datumsBase in the registry.
//
// Any unfinished ancestor commit encountered is finished as empty so the
// traversal can make progress. Returns (nil, nil) when no finished ancestor
// with output hashtrees exists.
func (reg *registry) getParentCommitInfo(commitInfo *pfs.CommitInfo) (*pfs.CommitInfo, error) {
	pachClient := reg.driver.PachClient()
	// Walk up the commit chain to find a successfully finished commit
	for commitInfo.ParentCommit != nil {
		parentCommitInfo, err := pachClient.PfsAPIClient.InspectCommit(pachClient.Ctx(),
			&pfs.InspectCommitRequest{
				Commit: commitInfo.ParentCommit,
			})
		if err != nil {
			return nil, err
		}
		// If the parent commit isn't finished, then finish it and continue the traversal.
		// If the parent commit is finished and has output, then return it.
		if parentCommitInfo.Finished == nil {
			// A concurrent finish is fine - ignore "already finished" errors.
			if _, err := pachClient.PfsAPIClient.FinishCommit(pachClient.Ctx(), &pfs.FinishCommitRequest{
				Commit: parentCommitInfo.Commit,
				Empty:  true,
			}); err != nil && !pfsserver.IsCommitFinishedErr(err) {
				return nil, err
			}
		} else if parentCommitInfo.Trees != nil {
			return parentCommitInfo, nil
		}
		commitInfo = parentCommitInfo
	}
	return nil, nil
}
   576  
   577  // ensureJob loads an existing job for the given commit in the pipeline, or
   578  // creates it if there is none. If more than one such job exists, an error will
   579  // be generated.
   580  func (reg *registry) ensureJob(
   581  	commitInfo *pfs.CommitInfo,
   582  	statsCommit *pfs.Commit,
   583  ) (*pps.JobInfo, error) {
   584  	pachClient := reg.driver.PachClient()
   585  
   586  	// Check if a job was previously created for this commit. If not, make one
   587  	jobInfos, err := pachClient.ListJob("", nil, commitInfo.Commit, -1, true)
   588  	if err != nil {
   589  		return nil, err
   590  	}
   591  	if len(jobInfos) > 1 {
   592  		return nil, errors.Errorf("multiple jobs found for commit: %s/%s", commitInfo.Commit.Repo.Name, commitInfo.Commit.ID)
   593  	} else if len(jobInfos) < 1 {
   594  		job, err := pachClient.CreateJob(reg.driver.PipelineInfo().Pipeline.Name, commitInfo.Commit, statsCommit)
   595  		if err != nil {
   596  			return nil, err
   597  		}
   598  		reg.logger.Logf("created new job %q for output commit %q", job.ID, commitInfo.Commit.ID)
   599  		// get jobInfo to look up spec commit, pipeline version, etc (if this
   600  		// worker is stale and about to be killed, the new job may have a newer
   601  		// pipeline version than the master. Or if the commit is stale, it may
   602  		// have an older pipeline version than the master)
   603  		return pachClient.InspectJob(job.ID, false)
   604  	}
   605  
   606  	// get latest job state
   607  	reg.logger.Logf("found existing job %q for output commit %q", jobInfos[0].Job.ID, commitInfo.Commit.ID)
   608  	return pachClient.InspectJob(jobInfos[0].Job.ID, false)
   609  }
   610  
// startJob begins supervising and processing the job associated with the
// given output commit. It acquires a slot from the registry's limiter, loads
// or creates the job, short-circuits jobs that are already terminal or belong
// to a different pipeline version, then launches two goroutines (via an
// errgroup): one that supervises the output commit and one that runs the
// job's datum-processing state machine under the task queue. A final
// goroutine waits for both, releases the limiter slot, and performs egress
// if the job reached the EGRESSING state.
func (reg *registry) startJob(commitInfo *pfs.CommitInfo, statsCommit *pfs.Commit) error {
	if err := reg.initializeJobChain(commitInfo); err != nil {
		return err
	}

	var asyncEg *errgroup.Group
	reg.limiter.Acquire()

	defer func() {
		if asyncEg == nil {
			// The async errgroup never got started, so give up the limiter lock
			reg.limiter.Release()
		}
	}()

	jobInfo, err := reg.ensureJob(commitInfo, statsCommit)
	if err != nil {
		return err
	}

	var statsCommitInfo *pfs.CommitInfo
	if statsCommit != nil {
		statsCommitInfo, err = reg.driver.PachClient().InspectCommit(statsCommit.Repo.Name, statsCommit.ID)
		if err != nil {
			return err
		}
	}

	// The job gets its own cancelable context derived from the registry's.
	jobCtx, cancel := context.WithCancel(reg.driver.PachClient().Ctx())
	driver := reg.driver.WithContext(jobCtx)

	// Build the pending job to send out to workers - this will block if we have
	// too many already
	pj := &pendingJob{
		driver:          driver,
		commitInfo:      commitInfo,
		statsCommitInfo: statsCommitInfo,
		logger:          reg.logger.WithJob(jobInfo.Job.ID),
		ji:              jobInfo,
		cancel:          cancel,
	}

	switch {
	case ppsutil.IsTerminal(jobInfo.State):
		// Make sure the output commits are closed
		if err := recoverFinishedJob(pj.driver.PipelineInfo(), pj.driver.PachClient(), pj.ji, jobInfo.State, "", nil, nil, 0, nil, 0); err != nil {
			return err
		}
		// ignore finished jobs (e.g. old pipeline & already killed)
		return nil
	case jobInfo.PipelineVersion < reg.driver.PipelineInfo().Version:
		// kill unfinished jobs from old pipelines (should generally be cleaned
		// up by PPS master, but the PPS master can fail, and if these jobs
		// aren't killed, future jobs will hang indefinitely waiting for their
		// parents to finish)
		pj.ji.State = pps.JobState_JOB_KILLED
		pj.ji.Reason = "pipeline has been updated"
		if err := pj.writeJobInfo(); err != nil {
			if !ppsserver.IsJobFinishedErr(err) {
				return errors.Wrap(err, "failed to kill stale job")
			}
		}
		return nil
	case jobInfo.PipelineVersion > reg.driver.PipelineInfo().Version:
		return errors.Errorf("job %s's version (%d) greater than pipeline's "+
			"version (%d), this should automatically resolve when the worker "+
			"is updated", jobInfo.Job.ID, jobInfo.PipelineVersion, reg.driver.PipelineInfo().Version)
	}

	// Inputs must be ready before we can construct a datum iterator, so do this
	// synchronously to ensure correct order in the jobChain.
	if pj.ji.State == pps.JobState_JOB_STARTING {
		if err := pj.logger.LogStep("waiting for job inputs", func() error {
			return reg.processJobStarting(pj)
		}); err != nil {
			return err
		}
	}

	// Precompute the remaining time until the job's timeout, if one is set.
	var afterTime time.Duration
	if pj.ji.JobTimeout != nil {
		startTime, err := types.TimestampFromProto(pj.ji.Started)
		if err != nil {
			return err
		}
		timeout, err := types.DurationFromProto(pj.ji.JobTimeout)
		if err != nil {
			return err
		}
		afterTime = time.Until(startTime.Add(timeout))
	}

	// Rebind the job's driver to the errgroup's context so a failure in either
	// goroutine cancels the other.
	asyncEg, jobCtx = errgroup.WithContext(pj.driver.PachClient().Ctx())
	pj.driver = reg.driver.WithContext(jobCtx)

	// Use a separate context for egress - it should not be canceled when the
	// output commit is closed.
	egressCtx, egressCancel := context.WithCancel(reg.driver.PachClient().Ctx())

	// If the job is already in egressing, we need to skip the job chain
	if pj.ji.State != pps.JobState_JOB_EGRESSING {
		pj.jdit, err = reg.jobChain.Start(pj)
		if err != nil {
			return err
		}
		pj.ji.DataTotal = pj.jdit.MaxLen()
		if err := pj.writeJobInfo(); err != nil {
			return err
		}
	}

	// Supervision goroutine: watches the output commit and enforces the timeout.
	asyncEg.Go(func() error {
		defer pj.cancel()

		if pj.ji.JobTimeout != nil {
			pj.logger.Logf("cancelling job at: %+v", afterTime)
			timer := time.AfterFunc(afterTime, func() {
				reg.killJob(pj, "job timed out")

				// We cancel egress after the timeout, but we don't cancel egress if the
				// job's output commit is closed - that is both how jobs complete and
				// how they are killed by user action, but there's no way to distinguish
				// the two at the moment.
				egressCancel()
			})
			defer timer.Stop()
		}

		// We don't cancel the errgroup if the supervise fails because the job will
		// be canceled anyway, and this makes it easier to filter out spurious
		// errors when waiting on the errgroup.
		backoff.RetryUntilCancel(pj.driver.PachClient().Ctx(), func() error {
			return reg.superviseJob(pj)
		}, backoff.NewInfiniteBackOff(), func(err error, d time.Duration) error {
			pj.logger.Logf("error in superviseJob: %v, retrying in %+v", err, d)
			return nil
		})

		return nil
	})

	// Processing goroutine: runs the job state machine once the task queue
	// grants this job a master. The locked mutex is used as a completion
	// signal: RunTask's callback unlocks it when done, and the second Lock
	// below blocks the errgroup until then.
	asyncEg.Go(func() error {
		defer pj.cancel()
		mutex := &sync.Mutex{}
		mutex.Lock()
		defer mutex.Unlock()

		// This runs the callback asynchronously, but we want to block the errgroup until it completes
		if err := reg.taskQueue.RunTask(pj.driver.PachClient().Ctx(), func(master *work.Master) {
			defer mutex.Unlock()
			pj.taskMaster = master

			backoff.RetryUntilCancel(pj.driver.PachClient().Ctx(), func() error {
				// Step the job state machine until it errors or terminates
				// (processJob returns ErrBreak once the job reaches a
				// terminal state).
				var err error
				for err == nil {
					err = reg.processJob(pj)
				}
				if errors.Is(err, errutil.ErrBreak) {
					return nil
				}
				return err
			}, backoff.NewInfiniteBackOff(), func(err error, d time.Duration) error {
				pj.logger.Logf("processJob error: %v, retrying in %v", err, d)
				// Log the stack trace of each wrapped error for debugging.
				for err != nil {
					if st, ok := err.(errors.StackTracer); ok {
						pj.logger.Logf("error stack: %+v", st.StackTrace())
					}
					err = errors.Unwrap(err)
				}

				pj.jdit.Reset()

				// Get job state, increment restarts, write job state
				pj.ji, err = pj.driver.PachClient().InspectJob(pj.ji.Job.ID, false)
				if err != nil {
					return err
				}

				pj.ji.Restart++
				if err := pj.writeJobInfo(); err != nil {
					pj.logger.Logf("error incrementing restart count for job (%s): %v", pj.ji.Job.ID, err)
				}

				// Reload the job's commitInfo as it may have changed
				pj.commitInfo, err = reg.driver.PachClient().InspectCommit(pj.commitInfo.Commit.Repo.Name, pj.commitInfo.Commit.ID)
				if err != nil {
					return err
				}

				if statsCommit != nil {
					pj.statsCommitInfo, err = reg.driver.PachClient().InspectCommit(statsCommit.Repo.Name, statsCommit.ID)
					if err != nil {
						return err
					}
				}

				return nil
			})
			pj.logger.Logf("master done running processJobs")
		}); err != nil {
			return err
		}

		// This should block until the callback has completed
		mutex.Lock()
		return nil
	})

	// Completion goroutine: waits for both workers, releases the limiter slot,
	// and runs egress (on its own context) if the job reached EGRESSING.
	go func() {
		defer reg.limiter.Release()

		// Make sure the job has been removed from the job chain, ignore any errors
		defer reg.jobChain.Fail(pj)

		if err := asyncEg.Wait(); err != nil {
			pj.logger.Logf("fatal job error: %v", err)
		}

		if pj.ji.State == pps.JobState_JOB_EGRESSING {
			// Set up the driver for the egress context, which has different cancel conditions
			pj.driver = reg.driver.WithContext(egressCtx)

			// If egress fails, there isn't much we can do - the output commit is
			// already done, so the job is 'complete' regardless.
			pj.logger.LogStep("egressing job data", func() error {
				return reg.processJobEgress(pj)
			})
		}
	}()

	return nil
}
   843  
// superviseJob watches for the output commit closing and cancels the job, or
// deletes it if the output commit is removed.
//
// It blocks (via BlockState: FINISHED) until the output commit finishes, then
// kills the job if the commit closed without output. If the commit is gone
// entirely, the job itself is deleted from etcd.
func (reg *registry) superviseJob(pj *pendingJob) error {
	commitInfo, err := pj.driver.PachClient().PfsAPIClient.InspectCommit(pj.driver.PachClient().Ctx(),
		&pfs.InspectCommitRequest{
			Commit:     pj.ji.OutputCommit,
			BlockState: pfs.CommitState_FINISHED,
		})
	if err != nil {
		if pfsserver.IsCommitNotFoundErr(err) || pfsserver.IsCommitDeletedErr(err) {
			defer pj.cancel() // whether we return error or nil, job is done

			// Stop the job and clean up any job state in the registry
			if err := reg.killJob(pj, "output commit missing"); err != nil {
				return err
			}

			// Output commit was deleted. Delete job as well
			if _, err := pj.driver.NewSTM(func(stm col.STM) error {
				// Delete the job if no other worker has deleted it yet
				jobPtr := &pps.EtcdJobInfo{}
				if err := pj.driver.Jobs().ReadWrite(stm).Get(pj.ji.Job.ID, jobPtr); err != nil {
					return err
				}
				return pj.driver.DeleteJob(stm, jobPtr)
			}); err != nil && !col.IsErrNotFound(err) {
				return err
			}
			return nil
		}
		return err
	}
	// commitInfo.Trees is set by non-S3-output jobs, while commitInfo.Tree is
	// set by S3-output jobs
	// TODO: why don't we cancel in all cases?
	if commitInfo.Trees == nil && commitInfo.Tree == nil {
		defer pj.cancel() // whether job state update succeeds or not, job is done
		return reg.killJob(pj, "output commit closed")
	}
	return nil
}
   885  
   886  func (reg *registry) processJob(pj *pendingJob) error {
   887  	state := pj.ji.State
   888  	switch {
   889  	case ppsutil.IsTerminal(state):
   890  		return errutil.ErrBreak
   891  	case state == pps.JobState_JOB_STARTING:
   892  		return errors.New("job should have been moved out of the STARTING state before processJob")
   893  	case state == pps.JobState_JOB_RUNNING:
   894  		return pj.logger.LogStep("processing job datums", func() error {
   895  			return reg.processJobRunning(pj)
   896  		})
   897  	case state == pps.JobState_JOB_MERGING:
   898  		return pj.logger.LogStep("merging job hashtrees", func() error {
   899  			return reg.processJobMerging(pj)
   900  		})
   901  	case state == pps.JobState_JOB_EGRESSING:
   902  		return errutil.ErrBreak
   903  	}
   904  	return errors.Errorf("unknown job state: %v", state)
   905  }
   906  
   907  func (reg *registry) processJobStarting(pj *pendingJob) error {
   908  	// block until job inputs are ready
   909  	failed, err := failedInputs(pj.driver.PachClient(), pj.ji)
   910  	if err != nil {
   911  		return err
   912  	}
   913  
   914  	if len(failed) > 0 {
   915  		reason := fmt.Sprintf("inputs failed: %s", strings.Join(failed, ", "))
   916  		return reg.failJob(pj, reason, nil, 0)
   917  	}
   918  
   919  	if pj.driver.PipelineInfo().S3Out && pj.commitInfo.ParentCommit != nil {
   920  		// We don't want S3-out pipelines to merge datum output with the parent
   921  		// commit, so we create a PutFile record to delete "/". Doing it before
   922  		// we move the job to the RUNNING state ensures that:
   923  		// 1) workers can't process datums unless DeleteFile("/") has run
   924  		// 2) DeleteFile("/") won't run after work has started
   925  		if err := pj.driver.PachClient().DeleteFile(
   926  			pj.commitInfo.Commit.Repo.Name,
   927  			pj.commitInfo.Commit.ID,
   928  			"/",
   929  		); err != nil {
   930  			return errors.Wrap(err, "couldn't prepare output commit for S3-out job")
   931  		}
   932  	}
   933  
   934  	pj.ji.State = pps.JobState_JOB_RUNNING
   935  	return nil
   936  }
   937  
   938  // Iterator fulfills the chain.JobData interface for pendingJob
   939  func (pj *pendingJob) Iterator() (datum.Iterator, error) {
   940  	var dit datum.Iterator
   941  	err := pj.logger.LogStep("constructing datum iterator", func() (err error) {
   942  		dit, err = datum.NewIterator(pj.driver.PachClient(), pj.ji.Input)
   943  		return
   944  	})
   945  	return dit, err
   946  }
   947  
// processJobRunning drives the RUNNING state: it streams datum tasks to the
// workers, aggregates their results (stats, hashtrees, recovered datums),
// and then either finishes the job (S3-out) or stores the artifacts and
// advances the job to MERGING.
func (reg *registry) processJobRunning(pj *pendingJob) error {
	pj.logger.Logf("processJobRunning creating task channel")
	// Buffered so the producer goroutine can stay a few tasks ahead of the
	// task master consuming them.
	subtasks := make(chan *work.Task, 10)

	eg, ctx := errgroup.WithContext(reg.driver.PachClient().Ctx())

	// Spawn a goroutine to emit tasks on the datum task channel
	eg.Go(func() error {
		defer close(subtasks)
		return pj.logger.LogStep("collecting datums for tasks", func() error {
			for {
				numDatums, err := pj.jdit.NextBatch(ctx)
				if err != nil {
					return err
				}
				// A zero-sized batch means the iterator is exhausted.
				if numDatums == 0 {
					return nil
				}

				if err := reg.sendDatumTasks(ctx, pj, numDatums, subtasks); err != nil {
					return err
				}
			}
		})
	})

	// Aggregated results from all datum subtasks. mutex guards everything
	// below it, since subtask-completion callbacks may run concurrently.
	mutex := &sync.Mutex{}
	stats := &DatumStats{ProcessStats: &pps.ProcessStats{}}
	chunkHashtrees := []*HashtreeInfo{}
	statsHashtrees := []*HashtreeInfo{}
	recoveredObjects := []string{}

	// Run subtasks until we are done
	eg.Go(func() error {
		return pj.logger.LogStep("running datum tasks", func() error {
			return pj.taskMaster.RunSubtasksChan(
				subtasks,
				func(ctx context.Context, taskInfo *work.TaskInfo) error {
					if taskInfo.State == work.State_FAILURE {
						return errors.Errorf("datum task failed: %s", taskInfo.Reason)
					}

					data, err := deserializeDatumData(taskInfo.Task.Data)
					if err != nil {
						return err
					}

					mutex.Lock()
					defer mutex.Unlock()

					mergeStats(stats, data.Stats)

					if data.ChunkHashtree != nil {
						chunkHashtrees = append(chunkHashtrees, data.ChunkHashtree)
					}
					if data.StatsHashtree != nil {
						statsHashtrees = append(statsHashtrees, data.StatsHashtree)
					}
					if data.RecoveredDatumsObject != "" {
						recoveredObjects = append(recoveredObjects, data.RecoveredDatumsObject)
					}
					// propagate the stats to etcd
					pj.saveJobStats(stats)
					return pj.writeJobInfo()
				},
			)
		})
	})

	err := eg.Wait()
	if err != nil {
		// If there was no failed datum, we can reattempt later
		return errors.Wrap(err, "process datum error")
	}

	if stats.FailedDatumID != "" {
		// A datum failed, but we still may need to merge stats - discard chunk hashtrees
		chunkHashtrees = []*HashtreeInfo{}
	}

	// S3Out pipelines don't use hashtrees, so skip over the MERGING state - this
	// will go to EGRESSING, if applicable.
	if pj.driver.PipelineInfo().S3Out {
		if stats.FailedDatumID != "" {
			return reg.failJob(pj, "datum failed", nil, 0)
		}
		pj.logger.Logf("processJobRunning succeeding s3out job, total stats: %v", stats)
		return reg.succeedJob(pj, nil, 0, nil, 0)
	}

	// Write the hashtrees list and recovered datums list to object storage
	if err := pj.storeHashtreeInfos(chunkHashtrees, statsHashtrees); err != nil {
		return err
	}
	if err := pj.storeRecoveredDatums(recoveredObjects); err != nil {
		return err
	}

	pj.logger.Logf("processJobRunning updating task to merging, total stats: %v", stats)
	pj.ji.State = pps.JobState_JOB_MERGING
	pj.finalizeJobStats()
	return pj.writeJobInfo()
}
  1051  
  1052  func (pj *pendingJob) saveJobStats(stats *DatumStats) {
  1053  	// Any unaccounted-for datums were skipped in the job datum iterator
  1054  	pj.ji.DataSkipped = stats.DatumsSkipped
  1055  	pj.ji.DataProcessed = stats.DatumsProcessed
  1056  	pj.ji.DataFailed = stats.DatumsFailed
  1057  	pj.ji.DataRecovered = stats.DatumsRecovered
  1058  	pj.ji.DataTotal = int64(pj.jdit.MaxLen())
  1059  	pj.ji.Stats = stats.ProcessStats
  1060  }
  1061  
  1062  func (pj *pendingJob) finalizeJobStats() {
  1063  	pj.ji.DataSkipped = int64(pj.jdit.MaxLen()) - pj.ji.DataProcessed - pj.ji.DataFailed - pj.ji.DataRecovered
  1064  }
  1065  
  1066  func (pj *pendingJob) storeHashtreeInfos(chunks []*HashtreeInfo, stats []*HashtreeInfo) (retErr error) {
  1067  	pj.chunkHashtrees = chunks
  1068  	pj.statsHashtrees = stats
  1069  
  1070  	objects := &HashtreeObjects{ChunkObjects: []string{}, StatsObjects: []string{}}
  1071  	for _, info := range chunks {
  1072  		objects.ChunkObjects = append(objects.ChunkObjects, info.Object)
  1073  	}
  1074  	for _, info := range stats {
  1075  		objects.StatsObjects = append(objects.StatsObjects, info.Object)
  1076  	}
  1077  
  1078  	writer, err := pj.driver.PachClient().DirectObjWriter(jobArtifactHashtrees(pj.ji.Job.ID))
  1079  	if err != nil {
  1080  		return errors.EnsureStack(err)
  1081  	}
  1082  	defer func() {
  1083  		if err := writer.Close(); err != nil && retErr == nil {
  1084  			retErr = errors.EnsureStack(err)
  1085  		}
  1086  	}()
  1087  
  1088  	pbw := pbutil.NewWriter(writer)
  1089  	_, err = pbw.Write(objects)
  1090  	return err
  1091  }
  1092  
  1093  func (pj *pendingJob) storeRecoveredDatums(recoveredObjects []string) (retErr error) {
  1094  	if len(recoveredObjects) == 0 {
  1095  		return nil
  1096  	}
  1097  
  1098  	objects := &RecoveredDatumObjects{Objects: recoveredObjects}
  1099  	writer, err := pj.driver.PachClient().DirectObjWriter(jobArtifactRecoveredObject(pj.ji.Job.ID))
  1100  	if err != nil {
  1101  		return errors.EnsureStack(err)
  1102  	}
  1103  	defer func() {
  1104  		if err := writer.Close(); err != nil && retErr == nil {
  1105  			retErr = errors.EnsureStack(err)
  1106  		}
  1107  	}()
  1108  
  1109  	pbw := pbutil.NewWriter(writer)
  1110  	_, err = pbw.Write(objects)
  1111  	return err
  1112  }
  1113  
  1114  func (pj *pendingJob) initializeHashtrees() (retErr error) {
  1115  	if pj.chunkHashtrees == nil {
  1116  		// We are picking up an old job and don't have the hashtrees generated by
  1117  		// the 'running' state, load them from object storage
  1118  		reader, err := pj.driver.PachClient().DirectObjReader(jobArtifactHashtrees(pj.ji.Job.ID))
  1119  		if err != nil {
  1120  			return errors.EnsureStack(err)
  1121  		}
  1122  		defer func() {
  1123  			if err := reader.Close(); err != nil && retErr == nil {
  1124  				retErr = errors.EnsureStack(err)
  1125  			}
  1126  		}()
  1127  
  1128  		hashtreeObjects := &HashtreeObjects{}
  1129  		protoReader := pbutil.NewReader(reader)
  1130  		if err := protoReader.Read(hashtreeObjects); err != nil {
  1131  			return err
  1132  		}
  1133  
  1134  		pj.chunkHashtrees = []*HashtreeInfo{}
  1135  		for _, object := range hashtreeObjects.ChunkObjects {
  1136  			pj.chunkHashtrees = append(pj.chunkHashtrees, &HashtreeInfo{Object: object})
  1137  		}
  1138  
  1139  		pj.statsHashtrees = []*HashtreeInfo{}
  1140  		for _, object := range hashtreeObjects.StatsObjects {
  1141  			pj.statsHashtrees = append(pj.statsHashtrees, &HashtreeInfo{Object: object})
  1142  		}
  1143  
  1144  		return nil
  1145  	}
  1146  	return nil
  1147  }
  1148  
// loadRecoveredDatums reads the recovered-datum hash lists that were written
// to object storage for this job and returns them as a DatumSet mapping each
// hash to its occurrence count. A missing recovered-datums artifact is not
// an error - it means no datums were recovered, so an empty set is returned.
func (pj *pendingJob) loadRecoveredDatums() (chain.DatumSet, error) {
	datumSet := make(chain.DatumSet)

	// We are picking up an old job and don't have the recovered datums generated by
	// the 'running' state, load them from object storage
	reader, err := pj.driver.PachClient().DirectObjReader(jobArtifactRecoveredObject(pj.ji.Job.ID))
	if err != nil {
		return nil, errors.EnsureStack(err)
	}

	recoveredObjects := &RecoveredDatumObjects{}
	protoReader := pbutil.NewReader(reader)
	if err := protoReader.Read(recoveredObjects); err != nil {
		// Close error intentionally ignored - the read error takes precedence.
		reader.Close()
		// A not-found error just means the job recovered no datums.
		if errutil.IsNotFoundError(err) {
			return datumSet, nil
		}
		return nil, err
	}

	if err := reader.Close(); err != nil {
		return nil, err
	}

	// NOTE(review): recoveredDatums is reused across loop iterations - this
	// assumes protoReader.Read overwrites Hashes rather than merging/appending
	// into the existing message; confirm against pbutil's Read semantics.
	recoveredDatums := &RecoveredDatums{}
	for _, object := range recoveredObjects.Objects {
		reader, err := pj.driver.PachClient().DirectObjReader(object)
		if err != nil {
			return nil, errors.EnsureStack(err)
		}

		protoReader := pbutil.NewReader(reader)
		if err := protoReader.Read(recoveredDatums); err != nil {
			reader.Close()
			return nil, err
		}

		// Closed explicitly (not deferred) so each iteration's reader is
		// released promptly.
		if err := reader.Close(); err != nil {
			return nil, errors.EnsureStack(err)
		}

		// Count every occurrence of each recovered datum hash.
		for _, hash := range recoveredDatums.Hashes {
			datumSet[hash]++
		}
	}

	return datumSet, nil
}
  1197  
  1198  func (reg *registry) makeMergeSubtasks(pj *pendingJob, commitInfo *pfs.CommitInfo, stats bool) ([]*work.Task, error) {
  1199  	hashtrees := pj.chunkHashtrees
  1200  	if stats {
  1201  		hashtrees = pj.statsHashtrees
  1202  	}
  1203  
  1204  	// For jobs that can base their hashtree off of the parent hashtree, fetch the
  1205  	// object information for the parent hashtrees
  1206  	// TODO: this is risky - there's no check that the parent the jobChain is
  1207  	// thinking of is the same one we find.  Add some extra guarantees here if we can.
  1208  	var parentHashtrees []*pfs.Object
  1209  	if pj.jdit.AdditiveOnly() {
  1210  		parentCommitInfo, err := reg.getParentCommitInfo(commitInfo)
  1211  		if err != nil {
  1212  			return nil, err
  1213  		}
  1214  
  1215  		if parentCommitInfo != nil {
  1216  			parentHashtrees = parentCommitInfo.Trees
  1217  			if len(parentHashtrees) != int(reg.driver.NumShards()) {
  1218  				return nil, errors.Errorf(
  1219  					"unexpected number of hashtrees between the parent commit (%d) and the pipeline spec (%d)",
  1220  					len(parentHashtrees), reg.driver.NumShards(),
  1221  				)
  1222  			}
  1223  		}
  1224  	}
  1225  
  1226  	if stats {
  1227  		if parentHashtrees != nil {
  1228  			pj.logger.Logf("merging %d stats hashtrees with parent hashtree across %d shards", len(hashtrees), reg.driver.NumShards())
  1229  		} else {
  1230  			pj.logger.Logf("merging %d stats hashtrees across %d shards", len(hashtrees), reg.driver.NumShards())
  1231  		}
  1232  	} else {
  1233  		if parentHashtrees != nil {
  1234  			pj.logger.Logf("merging %d hashtrees with parent hashtree across %d shards", len(hashtrees), reg.driver.NumShards())
  1235  		} else {
  1236  			pj.logger.Logf("merging %d hashtrees across %d shards", len(hashtrees), reg.driver.NumShards())
  1237  		}
  1238  	}
  1239  
  1240  	mergeSubtasks := []*work.Task{}
  1241  	for i := int64(0); i < reg.driver.NumShards(); i++ {
  1242  		mergeData := &MergeData{Hashtrees: hashtrees, Shard: i, JobID: pj.ji.Job.ID, Stats: stats}
  1243  
  1244  		if parentHashtrees != nil {
  1245  			mergeData.Parent = parentHashtrees[i]
  1246  		}
  1247  
  1248  		data, err := serializeMergeData(mergeData)
  1249  		if err != nil {
  1250  			return nil, errors.Wrap(err, "failed to serialize merge data")
  1251  		}
  1252  
  1253  		mergeSubtasks = append(mergeSubtasks, &work.Task{
  1254  			ID:   uuid.NewWithoutDashes(),
  1255  			Data: data,
  1256  		})
  1257  	}
  1258  
  1259  	return mergeSubtasks, nil
  1260  }
  1261  
// processJobMerging drives the MERGING state: it fans out one merge task per
// shard for the job's chunk hashtrees (and stats hashtrees, when a stats
// commit exists), collects the merged trees, and then succeeds or fails the
// job depending on whether any datums failed.
func (reg *registry) processJobMerging(pj *pendingJob) error {
	// Make sure the hashtree lists are populated (reloaded from object
	// storage if this worker did not run the 'running' state).
	if err := pj.initializeHashtrees(); err != nil {
		return err
	}

	// mutex guards trees/size/statsTrees/statsSize, which are written from
	// concurrent subtask-completion callbacks.
	mutex := &sync.Mutex{}
	mergeSubtasks := []*work.Task{}

	// Only merge the output hashtrees when no datum failed; a failed job
	// still merges its stats trees below.
	if pj.ji.DataFailed == 0 {
		chunkMergeSubtasks, err := reg.makeMergeSubtasks(pj, pj.commitInfo, false)
		if err != nil {
			return err
		}
		mergeSubtasks = append(mergeSubtasks, chunkMergeSubtasks...)
	}

	if pj.statsCommitInfo != nil {
		statsMergeSubtasks, err := reg.makeMergeSubtasks(pj, pj.statsCommitInfo, true)
		if err != nil {
			return err
		}
		mergeSubtasks = append(mergeSubtasks, statsMergeSubtasks...)
	}

	// Result slots, one per shard; each completed merge task fills in its own.
	trees := make([]*pfs.Object, reg.driver.NumShards())
	size := uint64(0)
	statsTrees := make([]*pfs.Object, reg.driver.NumShards())
	statsSize := uint64(0)

	pj.logger.Logf("sending out %d merge tasks", len(mergeSubtasks))

	// Run merge subtasks and wait for them to complete
	if err := pj.taskMaster.RunSubtasks(
		mergeSubtasks,
		func(ctx context.Context, taskInfo *work.TaskInfo) error {
			if taskInfo.State == work.State_FAILURE {
				return errors.Errorf("merge task failed: %s", taskInfo.Reason)
			}

			data, err := deserializeMergeData(taskInfo.Task.Data)
			if err != nil {
				return err
			}

			if data.Tree == nil {
				return errors.Errorf("merge task for shard %d failed, no tree returned", data.Shard)
			}

			mutex.Lock()
			defer mutex.Unlock()

			// Route the merged tree into the stats or output slot for its shard.
			if data.Stats {
				statsTrees[data.Shard] = data.Tree
				statsSize += data.TreeSize
			} else {
				trees[data.Shard] = data.Tree
				size += data.TreeSize
			}
			return nil
		},
	); err != nil {
		// TODO: persist error to job?
		return errors.Wrap(err, "merge error")
	}

	pj.logger.Logf("merge results: %v trees (%d bytes), %v stats trees (%d bytes)", trees, size, statsTrees, statsSize)

	if pj.ji.DataFailed == 0 {
		if err := reg.succeedJob(pj, trees, size, statsTrees, statsSize); err != nil {
			return err
		}
	} else if err := reg.failJob(pj, "datum failed", statsTrees, statsSize); err != nil {
		return err
	}
	return nil
}
  1338  
  1339  func (reg *registry) storeJobDatums(pj *pendingJob) (*pfs.Object, error) {
  1340  	// Update recovered datums through the job chain, which will update its internal DatumSet
  1341  	recoveredDatums, err := pj.loadRecoveredDatums()
  1342  	if err != nil {
  1343  		return nil, err
  1344  	}
  1345  
  1346  	if err := reg.jobChain.RecoveredDatums(pj, recoveredDatums); err != nil {
  1347  		return nil, err
  1348  	}
  1349  
  1350  	// Write out the datums processed/skipped and merged for this job
  1351  	buf := &bytes.Buffer{}
  1352  	pbw := pbutil.NewWriter(buf)
  1353  	for hash, count := range pj.jdit.DatumSet() {
  1354  		for i := int64(0); i < count; i++ {
  1355  			if _, err := pbw.WriteBytes([]byte(hash)); err != nil {
  1356  				return nil, err
  1357  			}
  1358  		}
  1359  	}
  1360  	datums, _, err := pj.driver.PachClient().PutObject(buf)
  1361  	return datums, err
  1362  }
  1363  
  1364  func (reg *registry) processJobEgress(pj *pendingJob) error {
  1365  	if err := reg.egress(pj); err != nil {
  1366  		return reg.failJob(pj, fmt.Sprintf("egress error: %v", err), nil, 0)
  1367  	}
  1368  
  1369  	pj.ji.State = pps.JobState_JOB_SUCCESS
  1370  	return pj.writeJobInfo()
  1371  }
  1372  
  1373  func failedInputs(pachClient *client.APIClient, jobInfo *pps.JobInfo) ([]string, error) {
  1374  	var failed []string
  1375  	var vistErr error
  1376  	blockCommit := func(name string, commit *pfs.Commit) {
  1377  		ci, err := pachClient.PfsAPIClient.InspectCommit(pachClient.Ctx(),
  1378  			&pfs.InspectCommitRequest{
  1379  				Commit:     commit,
  1380  				BlockState: pfs.CommitState_FINISHED,
  1381  			})
  1382  		if err != nil {
  1383  			if vistErr == nil {
  1384  				vistErr = errors.Wrapf(err, "error blocking on commit %s/%s",
  1385  					commit.Repo.Name, commit.ID)
  1386  			}
  1387  			return
  1388  		}
  1389  		if ci.Tree == nil && ci.Trees == nil {
  1390  			failed = append(failed, name)
  1391  		}
  1392  	}
  1393  	pps.VisitInput(jobInfo.Input, func(input *pps.Input) {
  1394  		if input.Pfs != nil && input.Pfs.Commit != "" {
  1395  			blockCommit(input.Pfs.Name, client.NewCommit(input.Pfs.Repo, input.Pfs.Commit))
  1396  		}
  1397  		if input.Cron != nil && input.Cron.Commit != "" {
  1398  			blockCommit(input.Cron.Name, client.NewCommit(input.Cron.Repo, input.Cron.Commit))
  1399  		}
  1400  		if input.Git != nil && input.Git.Commit != "" {
  1401  			blockCommit(input.Git.Name, client.NewCommit(input.Git.Name, input.Git.Commit))
  1402  		}
  1403  	})
  1404  	return failed, vistErr
  1405  }
  1406  
  1407  func (reg *registry) egress(pj *pendingJob) error {
  1408  	var egressFailureCount int
  1409  	return backoff.RetryNotify(func() (retErr error) {
  1410  		if pj.ji.Egress != nil {
  1411  			return pj.logger.LogStep("egress upload", func() error {
  1412  				return pj.driver.Egress(pj.ji.OutputCommit, pj.ji.Egress.URL)
  1413  			})
  1414  		}
  1415  		return nil
  1416  	}, backoff.NewInfiniteBackOff(), func(err error, d time.Duration) error {
  1417  		egressFailureCount++
  1418  		if egressFailureCount > 3 {
  1419  			return err
  1420  		}
  1421  		pj.logger.Logf("egress failed: %v; retrying in %v", err, d)
  1422  		return nil
  1423  	})
  1424  }