github.com/pachyderm/pachyderm@v1.13.4/src/server/worker/pipeline/transform/worker.go

package transform

import (
	"bytes"
	"context"
	"fmt"
	"io"
	"net/http"
	"os"
	"path"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/gogo/protobuf/jsonpb"
	"github.com/gogo/protobuf/types"
	"golang.org/x/sync/errgroup"

	"github.com/pachyderm/pachyderm/src/client"
	"github.com/pachyderm/pachyderm/src/client/limit"
	"github.com/pachyderm/pachyderm/src/client/pfs"
	"github.com/pachyderm/pachyderm/src/client/pkg/errors"
	"github.com/pachyderm/pachyderm/src/client/pkg/grpcutil"
	"github.com/pachyderm/pachyderm/src/client/pkg/pbutil"
	"github.com/pachyderm/pachyderm/src/client/pps"
	pfsserver "github.com/pachyderm/pachyderm/src/server/pfs/server"
	"github.com/pachyderm/pachyderm/src/server/pkg/backoff"
	"github.com/pachyderm/pachyderm/src/server/pkg/hashtree"
	"github.com/pachyderm/pachyderm/src/server/pkg/ppsutil"
	"github.com/pachyderm/pachyderm/src/server/pkg/uuid"
	"github.com/pachyderm/pachyderm/src/server/pkg/work"
	"github.com/pachyderm/pachyderm/src/server/worker/common"
	"github.com/pachyderm/pachyderm/src/server/worker/driver"
	"github.com/pachyderm/pachyderm/src/server/worker/logs"
	"github.com/pachyderm/pachyderm/src/server/worker/server"
)

var (
	errDatumRecovered = errors.New("the datum errored, and the error was handled successfully")
	statsTagSuffix    = "_stats"
)

// TODO: would be nice to have these have a deterministic ID rather than based
// off the subtask ID so we can shortcut processing if we get interrupted and
// restarted
func jobArtifactRecoveredDatums(jobID string, subtaskID string) string {
	return path.Join(jobArtifactPrefix(jobID), fmt.Sprintf("recovered-%s", subtaskID))
}

func jobArtifactChunkStats(jobID string, subtaskID string) string {
	return path.Join(jobArtifactPrefix(jobID), fmt.Sprintf("chunk-stats-%s", subtaskID))
}

func jobArtifactChunk(jobID string, subtaskID string) string {
	return path.Join(jobArtifactPrefix(jobID), fmt.Sprintf("chunk-%s", subtaskID))
}

func hashtreeChunkID(subtaskID string) string {
	return fmt.Sprintf("chunk-%s", subtaskID)
}

func plusDuration(x *types.Duration, y *types.Duration) (*types.Duration, error) {
	var xd time.Duration
	var yd time.Duration
	var err error
	if x != nil {
		xd, err = types.DurationFromProto(x)
		if err != nil {
			return nil, err
		}
	}
	if y != nil {
		yd, err = types.DurationFromProto(y)
		if err != nil {
			return nil, err
		}
	}
	return types.DurationProto(xd + yd), nil
}

// mergeStats merges y into x
func mergeStats(x, y *DatumStats) error {
	if yps := y.ProcessStats; yps != nil {
		var err error
		xps := x.ProcessStats
		if xps.DownloadTime, err = plusDuration(xps.DownloadTime, yps.DownloadTime); err != nil {
			return err
		}
		if xps.ProcessTime, err = plusDuration(xps.ProcessTime, yps.ProcessTime); err != nil {
			return err
		}
		if xps.UploadTime, err = plusDuration(xps.UploadTime, yps.UploadTime); err != nil {
			return err
		}
		xps.DownloadBytes += yps.DownloadBytes
		xps.UploadBytes += yps.UploadBytes
	}

	x.DatumsProcessed += y.DatumsProcessed
	x.DatumsSkipped += y.DatumsSkipped
	x.DatumsFailed += y.DatumsFailed
	x.DatumsRecovered += y.DatumsRecovered
	if x.FailedDatumID == "" {
		x.FailedDatumID = y.FailedDatumID
	}
	return nil
}

// Worker handles a transform pipeline work subtask, then returns.
func Worker(driver driver.Driver, logger logs.TaggedLogger, subtask *work.Task, status *Status) (retErr error) {
	defer func() {
		err := retErr
		for err != nil {
			logger.Logf("error: %v", err)
			if st, ok := err.(errors.StackTracer); ok {
				logger.Logf("error stack: %+v", st.StackTrace())
			}
			err = errors.Unwrap(err)
		}
	}()

	// Handle 'process datum' tasks
	datumData, err := deserializeDatumData(subtask.Data)
	if err == nil {
		return status.withJob(datumData.JobID, func() error {
			logger = logger.WithJob(datumData.JobID)
			if err := logger.LogStep("datum task", func() error {
				return handleDatumTask(driver, logger, datumData, subtask.ID, status)
			}); err != nil {
				return err
			}

			subtask.Data, err = serializeDatumData(datumData)
			return err
		})
	}

	// Handle 'merge hashtrees' tasks
	mergeData, err := deserializeMergeData(subtask.Data)
	if err == nil {
		return status.withJob(mergeData.JobID, func() error {
			logger = logger.WithJob(mergeData.JobID)
			if err := logger.LogStep("merge task", func() error {
				return handleMergeTask(driver, logger, mergeData)
			}); err != nil {
				return err
			}

			subtask.Data, err = serializeMergeData(mergeData)
			return err
		})
	}

	return errors.New("worker task format unrecognized")
}

// forEachDatum reads the serialized list of datums stored at the given object
// and calls cb with each datum's index and inputs.
func forEachDatum(driver driver.Driver, object string, cb func(int64, []*common.Input) error) (retErr error) {
	reader, err := driver.PachClient().DirectObjReader(object)
	if err != nil {
		return errors.EnsureStack(err)
	}
	defer func() {
		if err := reader.Close(); err != nil && retErr == nil {
			retErr = errors.EnsureStack(err)
		}
	}()

	allDatums := &DatumInputsList{}
	protoReader := pbutil.NewReader(reader)
	if err := protoReader.Read(allDatums); err != nil {
		return err
	}

	for _, datum := range allDatums.Datums {
		if err := cb(datum.Index, datum.Inputs); err != nil {
			return err
		}
	}

	return nil
}

// uploadRecoveredDatums writes the hashes of recovered datums to the given
// object in object storage.
func uploadRecoveredDatums(driver driver.Driver, logger logs.TaggedLogger, recoveredDatums []string, object string) (retErr error) {
	return logger.LogStep("uploading recovered datums", func() error {
		message := &RecoveredDatums{Hashes: recoveredDatums}

		writer, err := driver.PachClient().DirectObjWriter(object)
		if err != nil {
			return errors.EnsureStack(err)
		}
		defer func() {
			if err := writer.Close(); err != nil && retErr == nil {
				retErr = errors.EnsureStack(err)
			}
		}()

		protoWriter := pbutil.NewWriter(writer)
		_, err = protoWriter.Write(message)
		return err
	})
}

// uploadChunk merges the subtask's cached datum hashtrees into a single chunk,
// caches the chunk locally under its chunk ID, and uploads it to the given object.
func uploadChunk(
	driver driver.Driver,
	logger logs.TaggedLogger,
	subtaskCache *hashtree.MergeCache,
	chunkCache *hashtree.MergeCache,
	object string,
	subtaskID string,
) (retErr error) {
	return logger.LogStep("uploading hashtree chunk", func() error {
		// Merge the datums for this job into a chunk
		buf := &bytes.Buffer{}
		if err := subtaskCache.Merge(hashtree.NewWriter(buf), nil, nil); err != nil {
			return err
		}

		chunkID := hashtreeChunkID(subtaskID)
		logger.Logf("merged hashtree cache into buffer, len: %d, chunkID: %s, object: %s", buf.Len(), chunkID, object)
		if err := chunkCache.Put(chunkID, bytes.NewBuffer(buf.Bytes())); err != nil {
			return err
		}

		// Upload the hashtree for this subtask to the given object
		writer, err := driver.PachClient().DirectObjWriter(object)
		if err != nil {
			return errors.EnsureStack(err)
		}
		defer func() {
			if err := writer.Close(); err != nil && retErr == nil {
				retErr = errors.EnsureStack(err)
			}
		}()

		_, err = writer.Write(buf.Bytes())
		return err
	})
}

// checkS3Gateway waits until the sidecar S3 gateway service for this job
// responds to HTTP requests.
func checkS3Gateway(driver driver.Driver, logger logs.TaggedLogger) error {
	return backoff.RetryNotify(func() error {
		endpoint := fmt.Sprintf("http://%s:%s/",
			ppsutil.SidecarS3GatewayService(logger.JobID()),
			os.Getenv("S3GATEWAY_PORT"),
		)

		_, err := (&http.Client{Timeout: 5 * time.Second}).Get(endpoint)
		logger.Logf("checking s3 gateway service for job %q: %v", logger.JobID(), err)
		return err
	}, backoff.New60sBackOff(), func(err error, d time.Duration) error {
		logger.Logf("worker could not connect to s3 gateway for %q: %v", logger.JobID(), err)
		return nil
	})
	// TODO: `master` implementation fails the job here, we may need to do the same
	// We would need to load the jobInfo first for this:
	// }); err != nil {
	// 	reason := fmt.Sprintf("could not connect to s3 gateway for %q: %v", logger.JobID(), err)
	// 	logger.Logf("failing job with reason: %s", reason)
	// 	// NOTE: this is the only place a worker will reach over and change the job state, this should not generally be done.
	// 	return finishJob(driver.PipelineInfo(), driver.PachClient(), jobInfo, pps.JobState_JOB_FAILURE, reason, nil, nil, 0, nil, 0)
	// }
	// return nil
}

// handleDatumTask processes a 'process datum' subtask: it runs user code over
// each datum in the task, accumulates their stats, and uploads the subtask's
// hashtree chunk (and stats chunk, if enabled).
func handleDatumTask(driver driver.Driver, logger logs.TaggedLogger, data *DatumData, subtaskID string, status *Status) error {
	if ppsutil.ContainsS3Inputs(driver.PipelineInfo().Input) || driver.PipelineInfo().S3Out {
		if err := checkS3Gateway(driver, logger); err != nil {
			return err
		}
	}

	// TODO: check for existing tagged output files - continue with processing if any are missing
	return driver.WithDatumCache(func(datumCache *hashtree.MergeCache, statsCache *hashtree.MergeCache) error {
		logger.Logf("transform worker datum task: %v", data)
		limiter := limit.New(int(driver.PipelineInfo().MaxQueueSize))

		// statsMutex controls access to stats so that they can be safely merged
		statsMutex := &sync.Mutex{}
		recoveredDatums := []string{}
		data.Stats = &DatumStats{
			ProcessStats: &pps.ProcessStats{},
		}

		var queueSize, dataProcessed, dataRecovered int64
		// TODO: the status.GetStatus call may read the process stats without having a lock, is this ~ok?
		if err := logger.LogStep("processing datums", func() error {
			return status.withStats(data.Stats.ProcessStats, &queueSize, &dataProcessed, &dataRecovered, func() error {
				ctx, cancel := context.WithCancel(driver.PachClient().Ctx())
				defer cancel()

				eg, ctx := errgroup.WithContext(ctx)
				driver := driver.WithContext(ctx)
				if err := forEachDatum(driver, data.DatumsObject, func(index int64, inputs []*common.Input) error {
					limiter.Acquire()
					atomic.AddInt64(&queueSize, 1)
					eg.Go(func() error {
						defer limiter.Release()
						defer atomic.AddInt64(&queueSize, -1)

						// Construct a new logger here which will capture datum-specific
						// logs for object storage if stats are enabled.
						jobID := logger.JobID()
						logger, err := logs.NewLogger(driver.PipelineInfo(), driver.PachClient())
						if err != nil {
							return err
						}
						logger = logger.WithJob(jobID).WithData(inputs)

						// subStats is still valid even on an error, merge those in before proceeding
						subStats, subRecovered, err := processDatum(driver, logger, index, inputs, data.OutputCommit, datumCache, statsCache, status)

						statsMutex.Lock()
						defer statsMutex.Unlock()
						statsErr := mergeStats(data.Stats, subStats)
						if err != nil {
							return err
						}
						recoveredDatums = append(recoveredDatums, subRecovered...)
						if len(subRecovered) == 0 {
							atomic.AddInt64(&dataProcessed, 1)
						}
						atomic.AddInt64(&dataRecovered, int64(len(recoveredDatums)))
						return statsErr
					})
					return nil
				}); err != nil {
					cancel()
					eg.Wait()
					return err
				}

				return eg.Wait()
			})
		}); err != nil {
			return err
		}

		if data.Stats.DatumsFailed == 0 && !driver.PipelineInfo().S3Out {
			if len(recoveredDatums) > 0 {
				recoveredDatumsObject := jobArtifactRecoveredDatums(logger.JobID(), subtaskID)
				if err := uploadRecoveredDatums(driver, logger, recoveredDatums, recoveredDatumsObject); err != nil {
					return err
				}
				data.RecoveredDatumsObject = recoveredDatumsObject
			}

			chunkCache, err := driver.ChunkCaches().GetOrCreateCache(logger.JobID())
			if err != nil {
				return err
			}

			chunkObject := jobArtifactChunk(logger.JobID(), subtaskID)
			if err := uploadChunk(driver, logger, datumCache, chunkCache, chunkObject, subtaskID); err != nil {
				return err
			}

			data.ChunkHashtree = &HashtreeInfo{Address: os.Getenv(client.PPSWorkerIPEnv), Object: chunkObject, SubtaskID: subtaskID}
		}

		if driver.PipelineInfo().EnableStats {
			chunkStatsCache, err := driver.ChunkStatsCaches().GetOrCreateCache(logger.JobID())
			if err != nil {
				return err
			}

			chunkStatsObject := jobArtifactChunkStats(logger.JobID(), subtaskID)
			if err := uploadChunk(driver, logger, statsCache, chunkStatsCache, chunkStatsObject, subtaskID); err != nil {
				return err
			}
			data.StatsHashtree = &HashtreeInfo{Address: os.Getenv(client.PPSWorkerIPEnv), Object: chunkStatsObject, SubtaskID: subtaskID}
		}

		return nil
	})
}

// processDatum runs user code over a single datum, retrying on failure up to
// the pipeline's DatumTries, and returns the datum's stats along with the
// hashes of any recovered datums.
func processDatum(
	driver driver.Driver,
	logger logs.TaggedLogger,
	datumIndex int64,
	inputs []*common.Input,
	outputCommit *pfs.Commit,
	datumCache *hashtree.MergeCache,
	datumStatsCache *hashtree.MergeCache,
	status *Status,
) (_ *DatumStats, _ []string, retErr error) {
	recoveredDatums := []string{}
	stats := &DatumStats{}
	tag := common.HashDatum(driver.PipelineInfo().Pipeline.Name, driver.PipelineInfo().Salt, inputs)
	datumID := common.DatumID(inputs)

	if driver.PipelineInfo().ReprocessSpec != client.ReprocessSpecEveryJob {
		if _, err := driver.PachClient().InspectTag(driver.PachClient().Ctx(), client.NewTag(tag)); err == nil {
			buf := &bytes.Buffer{}
			if err := driver.PachClient().GetTag(tag, buf); err != nil {
				return stats, recoveredDatums, err
			}
			if err := datumCache.Put(uuid.NewWithoutDashes(), buf); err != nil {
				return stats, recoveredDatums, err
			}
			if driver.PipelineInfo().EnableStats {
				buf.Reset()
				if err := driver.PachClient().GetTag(tag+statsTagSuffix, buf); err != nil {
					// We are okay with not finding the stats hashtree. This allows users to
					// enable stats on a pipeline with pre-existing jobs.
					return stats, recoveredDatums, nil
				}
				if err := datumStatsCache.Put(uuid.NewWithoutDashes(), buf); err != nil {
					return stats, recoveredDatums, err
				}
			}
			stats.DatumsSkipped++
			return stats, recoveredDatums, nil
		}
	}

	statsRoot := path.Join("/", datumID)
	var inputTree, outputTree *hashtree.Ordered
	var statsTree *hashtree.Unordered
	if driver.PipelineInfo().EnableStats {
		inputTree = hashtree.NewOrdered(path.Join(statsRoot, "pfs"))
		outputTree = hashtree.NewOrdered(path.Join(statsRoot, "pfs", "out"))
		statsTree = hashtree.NewUnordered(statsRoot)
		// Write job id to stats tree
		statsTree.PutFile(fmt.Sprintf("job:%s", logger.JobID()), nil, 0)
		// Write index in datum factory to stats tree
		object, size, err := driver.PachClient().PutObject(strings.NewReader(fmt.Sprint(int(datumIndex))))
		if err != nil {
			return stats, recoveredDatums, err
		}
		objectInfo, err := driver.PachClient().InspectObject(object.Hash)
		if err != nil {
			return stats, recoveredDatums, err
		}
		h, err := pfs.DecodeHash(object.Hash)
		if err != nil {
			return stats, recoveredDatums, err
		}
		statsTree.PutFile("index", h, size, objectInfo.BlockRef)
		defer func() {
			logger.Logf("writing stats for datum: %s, current err: %v", tag, retErr)
			if err := writeStats(driver, logger, stats.ProcessStats, inputTree, outputTree, statsTree, tag, datumStatsCache); err != nil && retErr == nil {
				retErr = err
			}
		}()
	}

	var failures int64
	if err := backoff.RetryUntilCancel(driver.PachClient().Ctx(), func() error {
		var err error

		// WithData will download the inputs for this datum
		stats.ProcessStats, err = driver.WithData(inputs, inputTree, logger, func(dir string, processStats *pps.ProcessStats) error {

			// WithActiveData acquires a mutex so that we don't run this section concurrently
			if err := driver.WithActiveData(inputs, dir, func() error {
				ctx, cancel := context.WithCancel(driver.PachClient().Ctx())
				defer cancel()

				driver := driver.WithContext(ctx)

				return status.withDatum(inputs, cancel, func() error {
					env := driver.UserCodeEnv(logger.JobID(), outputCommit, inputs)
					if err := driver.RunUserCode(logger, env, processStats, driver.PipelineInfo().DatumTimeout); err != nil {
						if driver.PipelineInfo().Transform.ErrCmd != nil && failures == driver.PipelineInfo().DatumTries-1 {
							if err = driver.RunUserErrorHandlingCode(logger, env, processStats, driver.PipelineInfo().DatumTimeout); err != nil {
								return errors.Wrap(err, "RunUserErrorHandlingCode")
							}
							return errDatumRecovered
						}
						return err
					}
					return nil
				})
			}); err != nil {
				return err
			}

			if driver.PipelineInfo().S3Out {
				return nil // S3Out pipelines do not store data in worker hashtrees
			}

			hashtreeBytes, err := driver.UploadOutput(dir, tag, logger, inputs, processStats, outputTree)
			if err != nil {
				return err
			}

			// Cache datum hashtree locally
			return datumCache.Put(uuid.NewWithoutDashes(), bytes.NewReader(hashtreeBytes))
		})
		return err
	}, &backoff.ZeroBackOff{}, func(err error, d time.Duration) error {
		failures++
		if failures >= driver.PipelineInfo().DatumTries {
			logger.Logf("failed to process datum with error: %+v", err)
			if statsTree != nil {
				object, size, err := driver.PachClient().PutObject(strings.NewReader(err.Error()))
				if err != nil {
					logger.Errf("could not put error object: %s\n", err)
				} else {
					objectInfo, err := driver.PachClient().InspectObject(object.Hash)
					if err != nil {
						return err
					}
					h, err := pfs.DecodeHash(object.Hash)
					if err != nil {
						return err
					}
					statsTree.PutFile("failure", h, size, objectInfo.BlockRef)
				}
			}
			return err
		}
		// If stats is enabled, reset input and output tree on retry.
		if statsTree != nil {
			inputTree = hashtree.NewOrdered(path.Join(statsRoot, "pfs"))
			outputTree = hashtree.NewOrdered(path.Join(statsRoot, "pfs", "out"))
		}
		logger.Logf("failed processing datum: %v, retrying in %v", err, d)
		return nil
	}); errors.Is(err, errDatumRecovered) {
		// keep track of the recovered datums
		recoveredDatums = []string{tag}
		stats.DatumsRecovered++
	} else if err != nil {
		stats.FailedDatumID = datumID
		stats.DatumsFailed++
	} else {
		stats.DatumsProcessed++
	}
	return stats, recoveredDatums, nil
}

// writeStats stores the datum's process stats and logs as objects, adds them to
// the stats hashtree, merges the input, output, and stats trees, writes the
// merged tree to object storage under the stats tag, and caches it locally.
func writeStats(
	driver driver.Driver,
	logger logs.TaggedLogger,
	stats *pps.ProcessStats,
	inputTree *hashtree.Ordered,
	outputTree *hashtree.Ordered,
	statsTree *hashtree.Unordered,
	tag string,
	datumStatsCache *hashtree.MergeCache,
) (retErr error) {
	// Store stats and add stats file
	marshaler := &jsonpb.Marshaler{}
	statsString, err := marshaler.MarshalToString(stats)
	if err != nil {
		logger.Errf("could not serialize stats: %s\n", err)
		return err
	}
	object, size, err := driver.PachClient().PutObject(strings.NewReader(statsString))
	if err != nil {
		logger.Errf("could not put stats object: %s\n", err)
		return err
	}
	objectInfo, err := driver.PachClient().InspectObject(object.Hash)
	if err != nil {
		return err
	}
	h, err := pfs.DecodeHash(object.Hash)
	if err != nil {
		return err
	}
	statsTree.PutFile("stats", h, size, objectInfo.BlockRef)
	// Store logs and add logs file
	object, size, err = logger.Close()
	if err != nil {
		return err
	}
	if object != nil {
		objectInfo, err := driver.PachClient().InspectObject(object.Hash)
		if err != nil {
			return err
		}
		h, err := pfs.DecodeHash(object.Hash)
		if err != nil {
			return err
		}
		statsTree.PutFile("logs", h, size, objectInfo.BlockRef)
	}
	// Merge stats trees (input, output, stats) and write out
	inputBuf := &bytes.Buffer{}
	inputTree.Serialize(inputBuf)
	outputBuf := &bytes.Buffer{}
	outputTree.Serialize(outputBuf)
	statsBuf := &bytes.Buffer{}
	statsTree.Ordered().Serialize(statsBuf)
	// Merge datum stats hashtree
	buf := &bytes.Buffer{}
	if err := hashtree.Merge(hashtree.NewWriter(buf), []*hashtree.Reader{
		hashtree.NewReader(inputBuf, nil),
		hashtree.NewReader(outputBuf, nil),
		hashtree.NewReader(statsBuf, nil),
	}); err != nil {
		return err
	}
	// Write datum stats hashtree to object storage
	objW, err := driver.PachClient().PutObjectAsync([]*pfs.Tag{client.NewTag(tag + statsTagSuffix)})
	if err != nil {
		return err
	}
	defer func() {
		if err := objW.Close(); err != nil && retErr == nil {
			retErr = err
		}
	}()
	if _, err := objW.Write(buf.Bytes()); err != nil {
		return err
	}
	// Cache datum stats hashtree locally
	return datumStatsCache.Put(tag, bytes.NewReader(buf.Bytes()))
}

// fetchChunkFromWorker streams a hashtree chunk shard directly from the worker
// that produced it.
func fetchChunkFromWorker(driver driver.Driver, logger logs.TaggedLogger, address string, subtaskID string, shard int64, stats bool) (io.ReadCloser, error) {
	// TODO: cache cross-worker clients at the driver level
	client, err := server.NewClient(address)
	if err != nil {
		return nil, err
	}

	ctx, cancel := context.WithCancel(driver.PachClient().Ctx())
	getChunkClient, err := client.GetChunk(ctx, &server.GetChunkRequest{JobID: logger.JobID(), ChunkID: hashtreeChunkID(subtaskID), Shard: shard, Stats: stats})
	if err != nil {
		cancel()
		return nil, grpcutil.ScrubGRPC(err)
	}

	return grpcutil.NewStreamingBytesReader(getChunkClient, cancel), nil
}

// fetchChunk loads a hashtree chunk into the local merge cache, trying the
// originating worker first and falling back to the object store.
func fetchChunk(driver driver.Driver, logger logs.TaggedLogger, cache *hashtree.MergeCache, chunkID string, info *HashtreeInfo, shard int64, stats bool) (retErr error) {
	if info.Address != "" {
		err := func() (retErr error) {
			reader, err := fetchChunkFromWorker(driver, logger, info.Address, info.SubtaskID, shard, stats)
			if err != nil {
				return err
			}
			defer func() {
				if err := reader.Close(); retErr == nil {
					retErr = err
				}
			}()
			return cache.Put(chunkID, reader)
		}()
		if err == nil {
			return nil
		}
		logger.Logf("error when fetching cached chunk (%s) from worker (%s) - fetching from object store instead: %v", info.Object, info.Address, err)
	}

	reader, err := driver.PachClient().DirectObjReader(info.Object)
	if err != nil {
		return err
	}
	defer func() {
		if err := reader.Close(); retErr == nil {
			retErr = err
		}
	}()
	return cache.Put(chunkID, reader)
}

// handleMergeTask processes a 'merge hashtrees' subtask: it downloads the
// job's hashtree chunks, merges the requested shard (optionally with the
// parent commit's tree), and records the resulting tree object and size.
func handleMergeTask(driver driver.Driver, logger logs.TaggedLogger, data *MergeData) (retErr error) {
	var cache *hashtree.MergeCache
	var err error
	if data.Stats {
		cache, err = driver.ChunkStatsCaches().GetOrCreateCache(logger.JobID())
	} else {
		cache, err = driver.ChunkCaches().GetOrCreateCache(logger.JobID())
	}
	if err != nil {
		return err
	}

	var parentReader io.ReadCloser
	defer func() {
		if parentReader != nil {
			if err := parentReader.Close(); retErr == nil {
				retErr = err
			}
		}
	}()

	if err := logger.LogStep("downloading hashtree chunks", func() error {
		eg, _ := errgroup.WithContext(driver.PachClient().Ctx())
		limiter := limit.New(20) // TODO: base this off of configuration

		cachedIDs := cache.Keys()
		usedIDs := make(map[string]struct{})
		var keptChunks, droppedChunks, downloadedChunks int

		for _, hashtreeInfo := range data.Hashtrees {
			chunkID := hashtreeChunkID(hashtreeInfo.SubtaskID)
			usedIDs[chunkID] = struct{}{}

			if !cache.Has(chunkID) {
				limiter.Acquire()
				hashtreeInfo := hashtreeInfo
				eg.Go(func() (retErr error) {
					defer limiter.Release()
					return errors.EnsureStack(fetchChunk(driver, logger, cache, chunkID, hashtreeInfo, data.Shard, data.Stats))
				})
				downloadedChunks++
			} else {
				keptChunks++
			}
		}

		// There may be cached trees from a failed run - drop them
		for _, id := range cachedIDs {
			if _, ok := usedIDs[id]; !ok {
				cache.Delete(id)
				droppedChunks++
			}
		}

		logger.Logf("all hashtree chunks accounted for: %d kept, %d dropped, %d downloading", keptChunks, droppedChunks, downloadedChunks)

		if data.Parent != nil {
			eg.Go(func() error {
				var err error
				parentReader, err = driver.PachClient().GetObjectReader(data.Parent.Hash)
				return errors.EnsureStack(err)
			})
		}

		return errors.EnsureStack(eg.Wait())
	}); err != nil {
		return err
	}

	return logger.LogStep("merging hashtree chunks", func() error {
		tree, size, err := merge(driver, parentReader, cache, data.Shard)
		if err != nil {
			return err
		}

		data.Tree = tree
		data.TreeSize = size
		return nil
	})
}

// merge writes the merged hashtree shard (the cached chunks plus an optional
// parent tree) to object storage and returns the resulting object and its size.
func merge(driver driver.Driver, parent io.Reader, cache *hashtree.MergeCache, shard int64) (*pfs.Object, uint64, error) {
	var tree *pfs.Object
	var size uint64
	if err := func() (retErr error) {
		objW, err := driver.PachClient().PutObjectAsync(nil)
		if err != nil {
			return errors.EnsureStack(err)
		}

		w := hashtree.NewWriter(objW)
		filter := hashtree.NewFilter(driver.NumShards(), shard)
		err = cache.Merge(w, parent, filter)
		size = w.Size()
		if err != nil {
			objW.Close()
			return errors.EnsureStack(err)
		}
		// Get object hash for hashtree
		if err := objW.Close(); err != nil {
			return errors.EnsureStack(err)
		}
		tree, err = objW.Object()
		if err != nil {
			return errors.EnsureStack(err)
		}
		// Get index and write it out
		indexData, err := w.Index()
		if err != nil {
			return errors.EnsureStack(err)
		}
		return writeIndex(driver, tree, indexData)
	}(); err != nil {
		return nil, 0, err
	}
	return tree, size, nil
}

// writeIndex writes the hashtree's index data to object storage alongside the
// tree's block.
func writeIndex(driver driver.Driver, tree *pfs.Object, indexData []byte) (retErr error) {
	info, err := driver.PachClient().InspectObject(tree.Hash)
	if err != nil {
		return errors.EnsureStack(err)
	}
	path, err := pfsserver.BlockPathFromEnv(info.BlockRef.Block)
	if err != nil {
		return errors.EnsureStack(err)
	}
	indexWriter, err := driver.PachClient().DirectObjWriter(path + hashtree.IndexPath)
	if err != nil {
		return errors.EnsureStack(err)
	}
	defer func() {
		if err := indexWriter.Close(); err != nil && retErr == nil {
			retErr = errors.EnsureStack(err)
		}
	}()
	_, err = indexWriter.Write(indexData)
	return errors.EnsureStack(err)
}