github.com/pachyderm/pachyderm@v1.13.4/src/server/worker/pipeline/transform/worker.go

package transform

import (
	"bytes"
	"context"
	"fmt"
	"io"
	"net/http"
	"os"
	"path"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/gogo/protobuf/jsonpb"
	"github.com/gogo/protobuf/types"
	"golang.org/x/sync/errgroup"

	"github.com/pachyderm/pachyderm/src/client"
	"github.com/pachyderm/pachyderm/src/client/limit"
	"github.com/pachyderm/pachyderm/src/client/pfs"
	"github.com/pachyderm/pachyderm/src/client/pkg/errors"
	"github.com/pachyderm/pachyderm/src/client/pkg/grpcutil"
	"github.com/pachyderm/pachyderm/src/client/pkg/pbutil"
	"github.com/pachyderm/pachyderm/src/client/pps"
	pfsserver "github.com/pachyderm/pachyderm/src/server/pfs/server"
	"github.com/pachyderm/pachyderm/src/server/pkg/backoff"
	"github.com/pachyderm/pachyderm/src/server/pkg/hashtree"
	"github.com/pachyderm/pachyderm/src/server/pkg/ppsutil"
	"github.com/pachyderm/pachyderm/src/server/pkg/uuid"
	"github.com/pachyderm/pachyderm/src/server/pkg/work"
	"github.com/pachyderm/pachyderm/src/server/worker/common"
	"github.com/pachyderm/pachyderm/src/server/worker/driver"
	"github.com/pachyderm/pachyderm/src/server/worker/logs"
	"github.com/pachyderm/pachyderm/src/server/worker/server"
)

var (
	errDatumRecovered = errors.New("the datum errored, and the error was handled successfully")
	statsTagSuffix    = "_stats"
)

// TODO: it would be nice for these to have a deterministic ID rather than one
// based on the subtask ID, so we could shortcut processing if we get
// interrupted and restarted.
func jobArtifactRecoveredDatums(jobID string, subtaskID string) string {
	return path.Join(jobArtifactPrefix(jobID), fmt.Sprintf("recovered-%s", subtaskID))
}

func jobArtifactChunkStats(jobID string, subtaskID string) string {
	return path.Join(jobArtifactPrefix(jobID), fmt.Sprintf("chunk-stats-%s", subtaskID))
}

func jobArtifactChunk(jobID string, subtaskID string) string {
	return path.Join(jobArtifactPrefix(jobID), fmt.Sprintf("chunk-%s", subtaskID))
}

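// hashtreeChunkID returns the ID under which a subtask's hashtree chunk is
// stored in the worker's chunk caches.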
func hashtreeChunkID(subtaskID string) string {
	return fmt.Sprintf("chunk-%s", subtaskID)
}

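// plusDuration adds two protobuf Durations, treating a nil operand as zero,
// and returns the sum as a new protobuf Duration.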
func plusDuration(x *types.Duration, y *types.Duration) (*types.Duration, error) {
	var xd time.Duration
	var yd time.Duration
	var err error
	if x != nil {
		xd, err = types.DurationFromProto(x)
		if err != nil {
			return nil, err
		}
	}
	if y != nil {
		yd, err = types.DurationFromProto(y)
		if err != nil {
			return nil, err
		}
	}
	return types.DurationProto(xd + yd), nil
}

// mergeStats merges y into x
func mergeStats(x, y *DatumStats) error {
	if yps := y.ProcessStats; yps != nil {
		var err error
		xps := x.ProcessStats
		if xps.DownloadTime, err = plusDuration(xps.DownloadTime, yps.DownloadTime); err != nil {
			return err
		}
		if xps.ProcessTime, err = plusDuration(xps.ProcessTime, yps.ProcessTime); err != nil {
			return err
		}
		if xps.UploadTime, err = plusDuration(xps.UploadTime, yps.UploadTime); err != nil {
			return err
		}
		xps.DownloadBytes += yps.DownloadBytes
		xps.UploadBytes += yps.UploadBytes
	}

	x.DatumsProcessed += y.DatumsProcessed
	x.DatumsSkipped += y.DatumsSkipped
	x.DatumsFailed += y.DatumsFailed
	x.DatumsRecovered += y.DatumsRecovered
	if x.FailedDatumID == "" {
		x.FailedDatumID = y.FailedDatumID
	}
	return nil
}

// Worker handles a single transform pipeline work subtask (either a 'process
// datum' task or a 'merge hashtrees' task), then returns.
func Worker(driver driver.Driver, logger logs.TaggedLogger, subtask *work.Task, status *Status) (retErr error) {
	defer func() {
		err := retErr
		for err != nil {
			logger.Logf("error: %v", err)
			if st, ok := err.(errors.StackTracer); ok {
				logger.Logf("error stack: %+v", st.StackTrace())
			}
			err = errors.Unwrap(err)
		}
	}()

	// Handle 'process datum' tasks
	datumData, err := deserializeDatumData(subtask.Data)
	if err == nil {
		return status.withJob(datumData.JobID, func() error {
			logger = logger.WithJob(datumData.JobID)
			if err := logger.LogStep("datum task", func() error {
				return handleDatumTask(driver, logger, datumData, subtask.ID, status)
			}); err != nil {
				return err
			}

			subtask.Data, err = serializeDatumData(datumData)
			return err
		})
	}

	// Handle 'merge hashtrees' tasks
	mergeData, err := deserializeMergeData(subtask.Data)
	if err == nil {
		return status.withJob(mergeData.JobID, func() error {
			logger = logger.WithJob(mergeData.JobID)
			if err := logger.LogStep("merge task", func() error {
				return handleMergeTask(driver, logger, mergeData)
			}); err != nil {
				return err
			}

			subtask.Data, err = serializeMergeData(mergeData)
			return err
		})
	}

	return errors.New("worker task format unrecognized")
}

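// forEachDatum loads the serialized DatumInputsList stored at the given
// object and calls cb with the index and inputs of each datum in the list.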
func forEachDatum(driver driver.Driver, object string, cb func(int64, []*common.Input) error) (retErr error) {
	reader, err := driver.PachClient().DirectObjReader(object)
	if err != nil {
		return errors.EnsureStack(err)
	}
	defer func() {
		if err := reader.Close(); err != nil && retErr == nil {
			retErr = errors.EnsureStack(err)
		}
	}()

	allDatums := &DatumInputsList{}
	protoReader := pbutil.NewReader(reader)
	if err := protoReader.Read(allDatums); err != nil {
		return err
	}

	for _, datum := range allDatums.Datums {
		if err := cb(datum.Index, datum.Inputs); err != nil {
			return err
		}
	}

	return nil
}

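// uploadRecoveredDatums serializes the hashes of the recovered datums into a
// RecoveredDatums message and writes it to the given object in object storage.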
func uploadRecoveredDatums(driver driver.Driver, logger logs.TaggedLogger, recoveredDatums []string, object string) (retErr error) {
	return logger.LogStep("uploading recovered datums", func() error {
		message := &RecoveredDatums{Hashes: recoveredDatums}

		writer, err := driver.PachClient().DirectObjWriter(object)
		if err != nil {
			return errors.EnsureStack(err)
		}
		defer func() {
			if err := writer.Close(); err != nil && retErr == nil {
				retErr = errors.EnsureStack(err)
			}
		}()

		protoWriter := pbutil.NewWriter(writer)
		_, err = protoWriter.Write(message)
		return err
	})
}

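// uploadChunk merges the datum hashtrees in subtaskCache into a single chunk,
// caches the chunk in chunkCache under the subtask's chunk ID, and uploads it
// to the given object in object storage.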
func uploadChunk(
	driver driver.Driver,
	logger logs.TaggedLogger,
	subtaskCache *hashtree.MergeCache,
	chunkCache *hashtree.MergeCache,
	object string,
	subtaskID string,
) (retErr error) {
	return logger.LogStep("uploading hashtree chunk", func() error {
		// Merge the datums for this job into a chunk
		buf := &bytes.Buffer{}
		if err := subtaskCache.Merge(hashtree.NewWriter(buf), nil, nil); err != nil {
			return err
		}

		chunkID := hashtreeChunkID(subtaskID)
		logger.Logf("merged hashtree cache into buffer, len: %d, chunkID: %s, object: %s", buf.Len(), chunkID, object)
		if err := chunkCache.Put(chunkID, bytes.NewBuffer(buf.Bytes())); err != nil {
			return err
		}

		// Upload the hashtree for this subtask to the given object
		writer, err := driver.PachClient().DirectObjWriter(object)
		if err != nil {
			return errors.EnsureStack(err)
		}
		defer func() {
			if err := writer.Close(); err != nil && retErr == nil {
				retErr = errors.EnsureStack(err)
			}
		}()

		_, err = writer.Write(buf.Bytes())
		return err
	})
}

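// checkS3Gateway polls the job's sidecar s3 gateway service until it
// responds, retrying on backoff.New60sBackOff(). It only logs connection
// errors and currently never fails the job itself (see the TODO below).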
func checkS3Gateway(driver driver.Driver, logger logs.TaggedLogger) error {
	return backoff.RetryNotify(func() error {
		endpoint := fmt.Sprintf("http://%s:%s/",
			ppsutil.SidecarS3GatewayService(logger.JobID()),
			os.Getenv("S3GATEWAY_PORT"),
		)

		_, err := (&http.Client{Timeout: 5 * time.Second}).Get(endpoint)
		logger.Logf("checking s3 gateway service for job %q: %v", logger.JobID(), err)
		return err
	}, backoff.New60sBackOff(), func(err error, d time.Duration) error {
		logger.Logf("worker could not connect to s3 gateway for %q: %v", logger.JobID(), err)
		return nil
	})
	// TODO: the `master` implementation fails the job here; we may need to do
	// the same. We would need to load the jobInfo first for this:
	// }); err != nil {
	//   reason := fmt.Sprintf("could not connect to s3 gateway for %q: %v", logger.JobID(), err)
	//   logger.Logf("failing job with reason: %s", reason)
	//   // NOTE: this is the only place a worker will reach over and change the job state, this should not generally be done.
	//   return finishJob(driver.PipelineInfo(), driver.PachClient(), jobInfo, pps.JobState_JOB_FAILURE, reason, nil, nil, 0, nil, 0)
	// }
	// return nil
}

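// handleDatumTask runs a 'process datum' subtask. It processes each datum in
// the task's datums object through the user code (bounded by MaxQueueSize),
// merges the per-datum stats, then uploads the recovered datums, the merged
// hashtree chunk, and (if stats are enabled) the stats chunk, recording the
// resulting artifacts on data.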
func handleDatumTask(driver driver.Driver, logger logs.TaggedLogger, data *DatumData, subtaskID string, status *Status) error {
	if ppsutil.ContainsS3Inputs(driver.PipelineInfo().Input) || driver.PipelineInfo().S3Out {
		if err := checkS3Gateway(driver, logger); err != nil {
			return err
		}
	}

	// TODO: check for existing tagged output files - continue with processing if any are missing
	return driver.WithDatumCache(func(datumCache *hashtree.MergeCache, statsCache *hashtree.MergeCache) error {
		logger.Logf("transform worker datum task: %v", data)
		limiter := limit.New(int(driver.PipelineInfo().MaxQueueSize))

		// statsMutex controls access to stats so that they can be safely merged
		statsMutex := &sync.Mutex{}
		recoveredDatums := []string{}
		data.Stats = &DatumStats{
			ProcessStats: &pps.ProcessStats{},
		}

		var queueSize, dataProcessed, dataRecovered int64
		// TODO: the status.GetStatus call may read the process stats without holding a lock, is this ~ok?
		if err := logger.LogStep("processing datums", func() error {
			return status.withStats(data.Stats.ProcessStats, &queueSize, &dataProcessed, &dataRecovered, func() error {
				ctx, cancel := context.WithCancel(driver.PachClient().Ctx())
				defer cancel()

				eg, ctx := errgroup.WithContext(ctx)
				driver := driver.WithContext(ctx)
				if err := forEachDatum(driver, data.DatumsObject, func(index int64, inputs []*common.Input) error {
					limiter.Acquire()
					atomic.AddInt64(&queueSize, 1)
					eg.Go(func() error {
						defer limiter.Release()
						defer atomic.AddInt64(&queueSize, -1)

						// Construct a new logger here which will capture datum-specific
						// logs for object storage if stats are enabled.
						jobID := logger.JobID()
						logger, err := logs.NewLogger(driver.PipelineInfo(), driver.PachClient())
						if err != nil {
							return err
						}
						logger = logger.WithJob(jobID).WithData(inputs)

						// subStats is still valid even on an error, merge those in before proceeding
						subStats, subRecovered, err := processDatum(driver, logger, index, inputs, data.OutputCommit, datumCache, statsCache, status)

						statsMutex.Lock()
						defer statsMutex.Unlock()
						statsErr := mergeStats(data.Stats, subStats)
						if err != nil {
							return err
						}
						recoveredDatums = append(recoveredDatums, subRecovered...)
						if len(subRecovered) == 0 {
							atomic.AddInt64(&dataProcessed, 1)
						}
						// Count only the datums recovered by this call, not the running total
						atomic.AddInt64(&dataRecovered, int64(len(subRecovered)))
						return statsErr
					})
					return nil
				}); err != nil {
					cancel()
					eg.Wait()
					return err
				}

				return eg.Wait()
			})
		}); err != nil {
			return err
		}

		if data.Stats.DatumsFailed == 0 && !driver.PipelineInfo().S3Out {
			if len(recoveredDatums) > 0 {
				recoveredDatumsObject := jobArtifactRecoveredDatums(logger.JobID(), subtaskID)
				if err := uploadRecoveredDatums(driver, logger, recoveredDatums, recoveredDatumsObject); err != nil {
					return err
				}
				data.RecoveredDatumsObject = recoveredDatumsObject
			}

			chunkCache, err := driver.ChunkCaches().GetOrCreateCache(logger.JobID())
			if err != nil {
				return err
			}

			chunkObject := jobArtifactChunk(logger.JobID(), subtaskID)
			if err := uploadChunk(driver, logger, datumCache, chunkCache, chunkObject, subtaskID); err != nil {
				return err
			}

			data.ChunkHashtree = &HashtreeInfo{Address: os.Getenv(client.PPSWorkerIPEnv), Object: chunkObject, SubtaskID: subtaskID}
		}

		if driver.PipelineInfo().EnableStats {
			chunkStatsCache, err := driver.ChunkStatsCaches().GetOrCreateCache(logger.JobID())
			if err != nil {
				return err
			}

			chunkStatsObject := jobArtifactChunkStats(logger.JobID(), subtaskID)
			if err := uploadChunk(driver, logger, statsCache, chunkStatsCache, chunkStatsObject, subtaskID); err != nil {
				return err
			}
			data.StatsHashtree = &HashtreeInfo{Address: os.Getenv(client.PPSWorkerIPEnv), Object: chunkStatsObject, SubtaskID: subtaskID}
		}

		return nil
	})
}

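// processDatum runs the user code against a single datum and caches the
// resulting hashtree (and stats tree, if enabled) locally. If output for the
// datum's tag already exists and ReprocessSpec allows it, the datum is
// skipped. Failures are retried up to DatumTries times; on the final try the
// transform's error-handling command (if any) may recover the datum. It
// returns the datum's stats and the tags of any recovered datums.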
func processDatum(
	driver driver.Driver,
	logger logs.TaggedLogger,
	datumIndex int64,
	inputs []*common.Input,
	outputCommit *pfs.Commit,
	datumCache *hashtree.MergeCache,
	datumStatsCache *hashtree.MergeCache,
	status *Status,
) (_ *DatumStats, _ []string, retErr error) {
	recoveredDatums := []string{}
	stats := &DatumStats{}
	tag := common.HashDatum(driver.PipelineInfo().Pipeline.Name, driver.PipelineInfo().Salt, inputs)
	datumID := common.DatumID(inputs)

	if driver.PipelineInfo().ReprocessSpec != client.ReprocessSpecEveryJob {
		if _, err := driver.PachClient().InspectTag(driver.PachClient().Ctx(), client.NewTag(tag)); err == nil {
			buf := &bytes.Buffer{}
			if err := driver.PachClient().GetTag(tag, buf); err != nil {
				return stats, recoveredDatums, err
			}
			if err := datumCache.Put(uuid.NewWithoutDashes(), buf); err != nil {
				return stats, recoveredDatums, err
			}
			if driver.PipelineInfo().EnableStats {
				buf.Reset()
				if err := driver.PachClient().GetTag(tag+statsTagSuffix, buf); err != nil {
					// We are okay with not finding the stats hashtree. This allows users to
					// enable stats on a pipeline with pre-existing jobs.
					return stats, recoveredDatums, nil
				}
				if err := datumStatsCache.Put(uuid.NewWithoutDashes(), buf); err != nil {
					return stats, recoveredDatums, err
				}
			}
			stats.DatumsSkipped++
			return stats, recoveredDatums, nil
		}
	}

	statsRoot := path.Join("/", datumID)
	var inputTree, outputTree *hashtree.Ordered
	var statsTree *hashtree.Unordered
	if driver.PipelineInfo().EnableStats {
		inputTree = hashtree.NewOrdered(path.Join(statsRoot, "pfs"))
		outputTree = hashtree.NewOrdered(path.Join(statsRoot, "pfs", "out"))
		statsTree = hashtree.NewUnordered(statsRoot)
		// Write job id to stats tree
		statsTree.PutFile(fmt.Sprintf("job:%s", logger.JobID()), nil, 0)
		// Write index in datum factory to stats tree
		object, size, err := driver.PachClient().PutObject(strings.NewReader(fmt.Sprint(int(datumIndex))))
		if err != nil {
			return stats, recoveredDatums, err
		}
		objectInfo, err := driver.PachClient().InspectObject(object.Hash)
		if err != nil {
			return stats, recoveredDatums, err
		}
		h, err := pfs.DecodeHash(object.Hash)
		if err != nil {
			return stats, recoveredDatums, err
		}
		statsTree.PutFile("index", h, size, objectInfo.BlockRef)
		defer func() {
			logger.Logf("writing stats for datum: %s, current err: %v", tag, retErr)
			if err := writeStats(driver, logger, stats.ProcessStats, inputTree, outputTree, statsTree, tag, datumStatsCache); err != nil && retErr == nil {
				retErr = err
			}
		}()
	}

	var failures int64
	if err := backoff.RetryUntilCancel(driver.PachClient().Ctx(), func() error {
		var err error

		// WithData will download the inputs for this datum
		stats.ProcessStats, err = driver.WithData(inputs, inputTree, logger, func(dir string, processStats *pps.ProcessStats) error {

			// WithActiveData acquires a mutex so that we don't run this section concurrently
			if err := driver.WithActiveData(inputs, dir, func() error {
				ctx, cancel := context.WithCancel(driver.PachClient().Ctx())
				defer cancel()

				driver := driver.WithContext(ctx)

				return status.withDatum(inputs, cancel, func() error {
					env := driver.UserCodeEnv(logger.JobID(), outputCommit, inputs)
					if err := driver.RunUserCode(logger, env, processStats, driver.PipelineInfo().DatumTimeout); err != nil {
						if driver.PipelineInfo().Transform.ErrCmd != nil && failures == driver.PipelineInfo().DatumTries-1 {
							if err = driver.RunUserErrorHandlingCode(logger, env, processStats, driver.PipelineInfo().DatumTimeout); err != nil {
								return errors.Wrap(err, "RunUserErrorHandlingCode")
							}
							return errDatumRecovered
						}
						return err
					}
					return nil
				})
			}); err != nil {
				return err
			}

			if driver.PipelineInfo().S3Out {
				return nil // S3Out pipelines do not store data in worker hashtrees
			}

			hashtreeBytes, err := driver.UploadOutput(dir, tag, logger, inputs, processStats, outputTree)
			if err != nil {
				return err
			}

			// Cache datum hashtree locally
			return datumCache.Put(uuid.NewWithoutDashes(), bytes.NewReader(hashtreeBytes))
		})
		return err
	}, &backoff.ZeroBackOff{}, func(err error, d time.Duration) error {
		failures++
		if failures >= driver.PipelineInfo().DatumTries {
			logger.Logf("failed to process datum with error: %+v", err)
			if statsTree != nil {
				object, size, err := driver.PachClient().PutObject(strings.NewReader(err.Error()))
				if err != nil {
					logger.Errf("could not put error object: %s\n", err)
				} else {
					objectInfo, err := driver.PachClient().InspectObject(object.Hash)
					if err != nil {
						return err
					}
					h, err := pfs.DecodeHash(object.Hash)
					if err != nil {
						return err
					}
					statsTree.PutFile("failure", h, size, objectInfo.BlockRef)
				}
			}
			return err
		}
		// If stats is enabled, reset input and output tree on retry.
		if statsTree != nil {
			inputTree = hashtree.NewOrdered(path.Join(statsRoot, "pfs"))
			outputTree = hashtree.NewOrdered(path.Join(statsRoot, "pfs", "out"))
		}
		logger.Logf("failed processing datum: %v, retrying in %v", err, d)
		return nil
	}); errors.Is(err, errDatumRecovered) {
		// keep track of the recovered datums
		recoveredDatums = []string{tag}
		stats.DatumsRecovered++
	} else if err != nil {
		stats.FailedDatumID = datumID
		stats.DatumsFailed++
	} else {
		stats.DatumsProcessed++
	}
	return stats, recoveredDatums, nil
}

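// writeStats stores the datum's process stats and captured logs as objects,
// adds them to the stats tree, merges the input, output, and stats trees into
// a single hashtree, writes that tree to object storage under the datum's
// stats tag, and caches it locally.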
func writeStats(
	driver driver.Driver,
	logger logs.TaggedLogger,
	stats *pps.ProcessStats,
	inputTree *hashtree.Ordered,
	outputTree *hashtree.Ordered,
	statsTree *hashtree.Unordered,
	tag string,
	datumStatsCache *hashtree.MergeCache,
) (retErr error) {
	// Store stats and add stats file
	marshaler := &jsonpb.Marshaler{}
	statsString, err := marshaler.MarshalToString(stats)
	if err != nil {
		logger.Errf("could not serialize stats: %s\n", err)
		return err
	}
	object, size, err := driver.PachClient().PutObject(strings.NewReader(statsString))
	if err != nil {
		logger.Errf("could not put stats object: %s\n", err)
		return err
	}
	objectInfo, err := driver.PachClient().InspectObject(object.Hash)
	if err != nil {
		return err
	}
	h, err := pfs.DecodeHash(object.Hash)
	if err != nil {
		return err
	}
	statsTree.PutFile("stats", h, size, objectInfo.BlockRef)
	// Store logs and add logs file
	object, size, err = logger.Close()
	if err != nil {
		return err
	}
	if object != nil {
		objectInfo, err := driver.PachClient().InspectObject(object.Hash)
		if err != nil {
			return err
		}
		h, err := pfs.DecodeHash(object.Hash)
		if err != nil {
			return err
		}
		statsTree.PutFile("logs", h, size, objectInfo.BlockRef)
	}
	// Merge stats trees (input, output, stats) and write out
	inputBuf := &bytes.Buffer{}
	inputTree.Serialize(inputBuf)
	outputBuf := &bytes.Buffer{}
	outputTree.Serialize(outputBuf)
	statsBuf := &bytes.Buffer{}
	statsTree.Ordered().Serialize(statsBuf)
	// Merge datum stats hashtree
	buf := &bytes.Buffer{}
	if err := hashtree.Merge(hashtree.NewWriter(buf), []*hashtree.Reader{
		hashtree.NewReader(inputBuf, nil),
		hashtree.NewReader(outputBuf, nil),
		hashtree.NewReader(statsBuf, nil),
	}); err != nil {
		return err
	}
	// Write datum stats hashtree to object storage
	objW, err := driver.PachClient().PutObjectAsync([]*pfs.Tag{client.NewTag(tag + statsTagSuffix)})
	if err != nil {
		return err
	}
	defer func() {
		if err := objW.Close(); err != nil && retErr == nil {
			retErr = err
		}
	}()
	if _, err := objW.Write(buf.Bytes()); err != nil {
		return err
	}
	// Cache datum stats hashtree locally
	return datumStatsCache.Put(tag, bytes.NewReader(buf.Bytes()))
}

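// fetchChunkFromWorker streams one shard of a hashtree chunk from the worker
// at the given address using the GetChunk RPC.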
func fetchChunkFromWorker(driver driver.Driver, logger logs.TaggedLogger, address string, subtaskID string, shard int64, stats bool) (io.ReadCloser, error) {
	// TODO: cache cross-worker clients at the driver level
	client, err := server.NewClient(address)
	if err != nil {
		return nil, err
	}

	ctx, cancel := context.WithCancel(driver.PachClient().Ctx())
	getChunkClient, err := client.GetChunk(ctx, &server.GetChunkRequest{JobID: logger.JobID(), ChunkID: hashtreeChunkID(subtaskID), Shard: shard, Stats: stats})
	if err != nil {
		cancel()
		return nil, grpcutil.ScrubGRPC(err)
	}

	return grpcutil.NewStreamingBytesReader(getChunkClient, cancel), nil
}

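// fetchChunk puts a hashtree chunk into the cache, preferring to stream it
// from the worker that produced it and falling back to object storage if that
// fails (or if no worker address is known).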
func fetchChunk(driver driver.Driver, logger logs.TaggedLogger, cache *hashtree.MergeCache, chunkID string, info *HashtreeInfo, shard int64, stats bool) (retErr error) {
	if info.Address != "" {
		err := func() (retErr error) {
			reader, err := fetchChunkFromWorker(driver, logger, info.Address, info.SubtaskID, shard, stats)
			if err != nil {
				return err
			}
			defer func() {
				if err := reader.Close(); err != nil && retErr == nil {
					retErr = err
				}
			}()
			return cache.Put(chunkID, reader)
		}()
		if err == nil {
			return nil
		}
		logger.Logf("error when fetching cached chunk (%s) from worker (%s) - fetching from object store instead: %v", info.Object, info.Address, err)
	}

	reader, err := driver.PachClient().DirectObjReader(info.Object)
	if err != nil {
		return err
	}
	defer func() {
		if err := reader.Close(); err != nil && retErr == nil {
			retErr = err
		}
	}()
	return cache.Put(chunkID, reader)
}

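// handleMergeTask runs a 'merge hashtrees' subtask. It downloads the task's
// hashtree chunks into the appropriate cache (reusing chunks that are already
// cached and dropping stale ones), then merges them, along with the parent
// commit's tree if there is one, into a single tree for the task's shard,
// recording the resulting object and size on data.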
func handleMergeTask(driver driver.Driver, logger logs.TaggedLogger, data *MergeData) (retErr error) {
	var cache *hashtree.MergeCache
	var err error
	if data.Stats {
		cache, err = driver.ChunkStatsCaches().GetOrCreateCache(logger.JobID())
	} else {
		cache, err = driver.ChunkCaches().GetOrCreateCache(logger.JobID())
	}
	if err != nil {
		return err
	}

	var parentReader io.ReadCloser
	defer func() {
		if parentReader != nil {
			if err := parentReader.Close(); err != nil && retErr == nil {
				retErr = err
			}
		}
	}()

	if err := logger.LogStep("downloading hashtree chunks", func() error {
		eg, _ := errgroup.WithContext(driver.PachClient().Ctx())
		limiter := limit.New(20) // TODO: base this off of configuration

		cachedIDs := cache.Keys()
		usedIDs := make(map[string]struct{})
		var keptChunks, droppedChunks, downloadedChunks int

		for _, hashtreeInfo := range data.Hashtrees {
			chunkID := hashtreeChunkID(hashtreeInfo.SubtaskID)
			usedIDs[chunkID] = struct{}{}

			if !cache.Has(chunkID) {
				limiter.Acquire()
				hashtreeInfo := hashtreeInfo
				eg.Go(func() (retErr error) {
					defer limiter.Release()
					return errors.EnsureStack(fetchChunk(driver, logger, cache, chunkID, hashtreeInfo, data.Shard, data.Stats))
				})
				downloadedChunks++
			} else {
				keptChunks++
			}
		}

		// There may be cached trees from a failed run - drop them
		for _, id := range cachedIDs {
			if _, ok := usedIDs[id]; !ok {
				cache.Delete(id)
				droppedChunks++
			}
		}

		logger.Logf("all hashtree chunks accounted for: %d kept, %d dropped, %d downloading", keptChunks, droppedChunks, downloadedChunks)

		if data.Parent != nil {
			eg.Go(func() error {
				var err error
				parentReader, err = driver.PachClient().GetObjectReader(data.Parent.Hash)
				return errors.EnsureStack(err)
			})
		}

		return errors.EnsureStack(eg.Wait())
	}); err != nil {
		return err
	}

	return logger.LogStep("merging hashtree chunks", func() error {
		tree, size, err := merge(driver, parentReader, cache, data.Shard)
		if err != nil {
			return err
		}

		data.Tree = tree
		data.TreeSize = size
		return nil
	})
}

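// merge merges the cached hashtree chunks (and the parent tree, if any) for
// one shard into a new hashtree object, writes out the tree's index, and
// returns the object along with the tree's size.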
func merge(driver driver.Driver, parent io.Reader, cache *hashtree.MergeCache, shard int64) (*pfs.Object, uint64, error) {
	var tree *pfs.Object
	var size uint64
	if err := func() (retErr error) {
		objW, err := driver.PachClient().PutObjectAsync(nil)
		if err != nil {
			return errors.EnsureStack(err)
		}

		w := hashtree.NewWriter(objW)
		filter := hashtree.NewFilter(driver.NumShards(), shard)
		err = cache.Merge(w, parent, filter)
		size = w.Size()
		if err != nil {
			objW.Close()
			return errors.EnsureStack(err)
		}
		// Get object hash for hashtree
		if err := objW.Close(); err != nil {
			return errors.EnsureStack(err)
		}
		tree, err = objW.Object()
		if err != nil {
			return errors.EnsureStack(err)
		}
		// Get index and write it out
		indexData, err := w.Index()
		if err != nil {
			return errors.EnsureStack(err)
		}
		return writeIndex(driver, tree, indexData)
	}(); err != nil {
		return nil, 0, err
	}
	return tree, size, nil
}

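// writeIndex writes a merged hashtree's index data to object storage
// alongside the tree's block (at the block path plus hashtree.IndexPath).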
func writeIndex(driver driver.Driver, tree *pfs.Object, indexData []byte) (retErr error) {
	info, err := driver.PachClient().InspectObject(tree.Hash)
	if err != nil {
		return errors.EnsureStack(err)
	}
	path, err := pfsserver.BlockPathFromEnv(info.BlockRef.Block)
	if err != nil {
		return errors.EnsureStack(err)
	}
	indexWriter, err := driver.PachClient().DirectObjWriter(path + hashtree.IndexPath)
	if err != nil {
		return errors.EnsureStack(err)
	}
	defer func() {
		if err := indexWriter.Close(); err != nil && retErr == nil {
			retErr = errors.EnsureStack(err)
		}
	}()
	_, err = indexWriter.Write(indexData)
	return errors.EnsureStack(err)
}