github.com/pachyderm/pachyderm@v1.13.4/src/server/worker/pipeline/transform/chain/chain.go (about)

     1  package chain
     2  
     3  import (
     4  	"context"
     5  	"reflect"
     6  	"sync"
     7  
     8  	"github.com/pachyderm/pachyderm/src/client/pkg/errors"
     9  	"github.com/pachyderm/pachyderm/src/server/worker/common"
    10  	"github.com/pachyderm/pachyderm/src/server/worker/datum"
    11  )
    12  
// DatumHasher is an interface to provide datum hashing without any external
// dependencies (such as on the pipelineInfo).
type DatumHasher interface {
	// Hash should essentially wrap the common.HashDatum function, but other
	// implementations may be useful in tests. The hash is used as the key for
	// tracking datums across jobs in the chain.
	Hash([]*common.Input) string
}
    20  
// JobData is an interface which is used as a key to refer to a job within the
// JobChain. It must provide a constructor for the datum iterator used by the
// chain to produce the JobDatumIterator. Jobs are compared by interface
// identity, so the same JobData value must be passed to all JobChain methods.
type JobData interface {
	// Iterator constructs the datum.Iterator associated with the job
	Iterator() (datum.Iterator, error)
}
    28  
// JobDatumIterator is the interface returned by the JobChain corresponding to a
// JobData. This acts similarly to a datum.Iterator, but has slightly different
// semantics. This iterator works in batches, although iteration is still
// performed one datum at a time, the batch sizes let the user know how many
// datums can be consumed without blocking on upstream jobs. This iterator does
// not support random access, although it can be reset if the user needs to
// reiterate over the datums.
type JobDatumIterator interface {
	// NextBatch blocks until the next batch of datums are available
	// (corresponding to an upstream job finishing in some way), and returns the
	// number of datums that can now be iterated through. A return of 0 with a
	// nil error means iteration is complete.
	NextBatch(context.Context) (int64, error)

	// NextDatum advances the iterator and returns the next available datum
	// along with its index in the underlying datum.Iterator. If no such datum
	// is immediately available, nil will be returned.
	NextDatum() ([]*common.Input, int64)

	// AdditiveOnly indicates if the job output can be merged with the parent
	// job's output commit. If this is true, the iterator will not provide all
	// datums for the output commit, but rather the set of datums that have been
	// added.
	AdditiveOnly() bool

	// DatumSet returns the set of datums that have been produced for this job. If
	// any datums have been recovered (see JobChain.RecoveredDatums), they will be
	// excluded from this set.
	DatumSet() DatumSet

	// MaxLen returns the length of the underlying datum.Iterator. This kinda
	// sucks but is necessary to know how many datums were skipped in the case of
	// AdditiveOnly=true.
	MaxLen() int64

	// Reset will reset the underlying data structures so that iteration can be
	// performed from the start again. There is no guarantee that the datums will
	// be provided in the same order, as some datums may no longer be blocked on
	// subsequent iterations.
	Reset()
}
    68  
// JobChain is an interface for coordinating concurrency between jobs. It
// tracks multiple jobs via their JobData interface, and provides a
// JobDatumIterator for them to safely process datums without worrying about
// work being duplicated or invalidated. Dependencies between jobs are based on
// the order in which they are added, so care should be taken to not introduce
// race conditions when starting jobs.
type JobChain interface {
	// Start adds a new job to the chain and returns the corresponding
	// JobDatumIterator
	Start(jd JobData) (JobDatumIterator, error)

	// RecoveredDatums indicates the set of recovered datums for the job. This can
	// be called multiple times.
	RecoveredDatums(jd JobData, recoveredDatums DatumSet) error

	// Succeed indicates that the job has finished successfully
	Succeed(jd JobData) error

	// Fail indicates that the job has finished unsuccessfully
	Fail(jd JobData) error
}
    90  
// DatumSet is a data structure used to track the set of datums in a job,
// mapping each datum hash to its number of occurrences. Multiple identical
// datums may be present in a job, so this is effectively a multiset.
type DatumSet map[string]int64
    95  
// jobDatumIterator is the JobDatumIterator implementation returned by
// jobChain.Start. Mutable fields are protected by jc.mutex.
type jobDatumIterator struct {
	data JobData
	jc   *jobChain

	// TODO: lower memory consumption - all these datumsets might result in a
	// really large memory footprint. See if we can do a streaming interface to
	// replace these - will likely require the new storage layer, as additive-only
	// jobs need this stuff the most.
	yielding  DatumSet // Datums that may be yielded as the iterator progresses
	yielded   DatumSet // Datums that have been yielded
	allDatums DatumSet // All datum hashes from the datum iterator; nil once the job fails

	// ancestors holds the unfinished upstream jobs that block further
	// iteration (recomputed by recalculate).
	ancestors []*jobDatumIterator
	dit       datum.Iterator
	ditIndex  int // current scan position in dit; -1 before iteration starts

	finished     bool          // set by Succeed/Fail (under jc.mutex)
	additiveOnly bool          // see JobDatumIterator.AdditiveOnly
	done         chan struct{} // closed when the job finishes (success or failure)
}
   116  
// jobChain implements the JobChain interface. jobs[0] is always a synthetic,
// already-finished entry holding the base datum set (see NewJobChain).
type jobChain struct {
	mutex  sync.Mutex // guards jobs and the mutable state of each entry
	hasher DatumHasher
	jobs   []*jobDatumIterator
}
   122  
   123  // NewJobChain constructs a JobChain
   124  func NewJobChain(hasher DatumHasher, baseDatums DatumSet) JobChain {
   125  	jc := &jobChain{
   126  		hasher: hasher,
   127  	}
   128  
   129  	// Insert a dummy job representing the given base datum set
   130  	jdi := &jobDatumIterator{
   131  		data:      nil,
   132  		jc:        jc,
   133  		allDatums: baseDatums,
   134  		finished:  true,
   135  		done:      make(chan struct{}),
   136  		ditIndex:  -1,
   137  	}
   138  	close(jdi.done)
   139  
   140  	jc.jobs = []*jobDatumIterator{jdi}
   141  	return jc
   142  }
   143  
// recalculate is called whenever jdi.yielding is empty (either at init or when
// a blocking ancestor job has finished), to repopulate it. It rebuilds
// jdi.yielding (datums NextDatum may hand out), jdi.ancestors (the unfinished
// upstream jobs to wait on), and re-derives jdi.additiveOnly against the most
// recent non-failed ancestor. Callers must hold jc.mutex.
func (jdi *jobDatumIterator) recalculate(allAncestors []*jobDatumIterator) {
	jdi.ancestors = []*jobDatumIterator{}
	interestingAncestors := map[*jobDatumIterator]struct{}{}
	for hash, count := range jdi.allDatums {
		// Datums already yielded on a previous pass: re-queue only the
		// remaining occurrences and skip the ancestor-overlap check.
		if yieldedCount, ok := jdi.yielded[hash]; ok {
			if count-yieldedCount > 0 {
				jdi.yielding[hash] = count - yieldedCount
			}
			continue
		}

		safeToProcess := true
		// interestingAncestors should be _all_ unfinished previous jobs which have
		// _any_ datum overlap with this job
		for _, ancestor := range allAncestors {
			if !ancestor.finished {
				if _, ok := ancestor.allDatums[hash]; ok {
					interestingAncestors[ancestor] = struct{}{}
					safeToProcess = false
				}
			}
		}

		if safeToProcess {
			jdi.yielding[hash] = count
		}
	}

	// The parent is the most recent ancestor that did not fail (a failed job
	// has a nil allDatums - see Fail).
	var parentJob *jobDatumIterator
	for i := len(allAncestors) - 1; i >= 0; i-- {
		// Skip all failed jobs
		if allAncestors[i].allDatums != nil {
			parentJob = allAncestors[i]
			break
		}
	}

	// If this job is additive-only from the parent job, we should mark it now -
	// loop over parent datums to see if they are all present.
	// NOTE(review): this assumes parentJob is non-nil, i.e. allAncestors always
	// contains at least one non-failed entry (the base job from NewJobChain) -
	// confirm this holds for all callers.
	jdi.additiveOnly = true
	for hash, parentCount := range parentJob.allDatums {
		if count, ok := jdi.allDatums[hash]; !ok || count < parentCount {
			jdi.additiveOnly = false
			break
		}
	}

	if jdi.additiveOnly {
		// If this is additive-only, we only need to enqueue new datums (since the parent job)
		for hash, count := range jdi.yielding {
			if parentCount, ok := parentJob.allDatums[hash]; ok {
				if count == parentCount {
					delete(jdi.yielding, hash)
				} else {
					jdi.yielding[hash] = count - parentCount
				}
			}
		}
		// An additive-only job can only progress once its parent job has finished.
		// At that point it will re-evaluate what datums to process in case of a
		// failed job or recovered datums.
		if !parentJob.finished {
			jdi.ancestors = append(jdi.ancestors, parentJob)
		}
	} else {
		for ancestor := range interestingAncestors {
			jdi.ancestors = append(jdi.ancestors, ancestor)
		}
	}
}
   216  
   217  func (jc *jobChain) Start(jd JobData) (JobDatumIterator, error) {
   218  	dit, err := jd.Iterator()
   219  	if err != nil {
   220  		return nil, err
   221  	}
   222  
   223  	jdi := &jobDatumIterator{
   224  		data:      jd,
   225  		jc:        jc,
   226  		yielding:  make(DatumSet),
   227  		yielded:   make(DatumSet),
   228  		allDatums: make(DatumSet),
   229  		ancestors: []*jobDatumIterator{},
   230  		dit:       dit,
   231  		ditIndex:  -1,
   232  		done:      make(chan struct{}),
   233  	}
   234  
   235  	jdi.dit.Reset()
   236  	for jdi.dit.Next() {
   237  		inputs := jdi.dit.Datum()
   238  		hash := jc.hasher.Hash(inputs)
   239  		jdi.allDatums[hash]++
   240  	}
   241  	jdi.dit.Reset()
   242  
   243  	jc.mutex.Lock()
   244  	defer jc.mutex.Unlock()
   245  
   246  	jdi.recalculate(jc.jobs)
   247  
   248  	jc.jobs = append(jc.jobs, jdi)
   249  	return jdi, nil
   250  }
   251  
   252  func (jc *jobChain) indexOf(jd JobData) (int, error) {
   253  	for i, x := range jc.jobs {
   254  		if x.data == jd {
   255  			return i, nil
   256  		}
   257  	}
   258  	return 0, errors.New("job not found in job chain")
   259  }
   260  
   261  func (jc *jobChain) cleanFinishedJobs() {
   262  	for len(jc.jobs) > 1 && jc.jobs[1].finished {
   263  		if jc.jobs[1].allDatums != nil {
   264  			jc.jobs[0].allDatums = jc.jobs[1].allDatums
   265  		}
   266  		jc.jobs = append(jc.jobs[:1], jc.jobs[2:]...)
   267  	}
   268  }
   269  
   270  func (jc *jobChain) Fail(jd JobData) error {
   271  	jc.mutex.Lock()
   272  	defer jc.mutex.Unlock()
   273  
   274  	index, err := jc.indexOf(jd)
   275  	if err != nil {
   276  		return err
   277  	}
   278  
   279  	jdi := jc.jobs[index]
   280  
   281  	if jdi.finished {
   282  		return errors.New("cannot fail a job that is already finished")
   283  	}
   284  
   285  	jdi.allDatums = nil
   286  	jdi.finished = true
   287  	close(jdi.done)
   288  
   289  	jc.cleanFinishedJobs()
   290  
   291  	return nil
   292  }
   293  
   294  func (jc *jobChain) RecoveredDatums(jd JobData, recoveredDatums DatumSet) error {
   295  	jc.mutex.Lock()
   296  	defer jc.mutex.Unlock()
   297  
   298  	index, err := jc.indexOf(jd)
   299  	if err != nil {
   300  		return err
   301  	}
   302  
   303  	jdi := jc.jobs[index]
   304  
   305  	for hash := range recoveredDatums {
   306  		delete(jdi.allDatums, hash)
   307  	}
   308  
   309  	return nil
   310  }
   311  
   312  func (jc *jobChain) Succeed(jd JobData) error {
   313  	jc.mutex.Lock()
   314  	defer jc.mutex.Unlock()
   315  
   316  	index, err := jc.indexOf(jd)
   317  	if err != nil {
   318  		return err
   319  	}
   320  
   321  	jdi := jc.jobs[index]
   322  
   323  	if jdi.finished {
   324  		return errors.New("cannot succeed a job that is already finished")
   325  	}
   326  
   327  	if len(jdi.yielding) != 0 || len(jdi.ancestors) > 0 {
   328  		return errors.Errorf(
   329  			"cannot succeed a job with items remaining on the iterator: %d datums and %d ancestor jobs",
   330  			len(jdi.yielding), len(jdi.ancestors),
   331  		)
   332  	}
   333  
   334  	jdi.finished = true
   335  	jc.cleanFinishedJobs()
   336  	close(jdi.done)
   337  	return nil
   338  }
   339  
// TODO: iteration should return a chunk of 'known' new datums before other
// datums (to optimize for distributing processing across workers). This should
// still be true even after resetting the iterator.

// NextBatch blocks until a batch of datums is available (jdi.yielding is
// non-empty) and returns the total count of datums in the batch. A return of
// (0, nil) means there is nothing left to yield and no ancestors remain.
// While yielding is empty, it waits - via reflect.Select over each ancestor's
// done channel plus ctx.Done() - for any ancestor job to finish, then
// recalculates the yielding set under the chain mutex.
func (jdi *jobDatumIterator) NextBatch(ctx context.Context) (int64, error) {
	for len(jdi.yielding) == 0 {
		if len(jdi.ancestors) == 0 {
			return 0, nil
		}

		// Wait on an ancestor job
		cases := make([]reflect.SelectCase, 0, len(jdi.ancestors)+1)
		for _, x := range jdi.ancestors {
			cases = append(cases, reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(x.done)})
		}
		// The context's done channel is always the last case, so its index
		// identifies cancellation below.
		cases = append(cases, reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ctx.Done())})

		// Wait for an ancestor job to finish, then remove it from our dependencies
		selectIndex, _, _ := reflect.Select(cases)
		if selectIndex == len(cases)-1 {
			return 0, ctx.Err()
		}

		if err := func() error {
			jdi.jc.mutex.Lock()
			defer jdi.jc.mutex.Unlock()

			if jdi.finished {
				return errors.New("stopping datum iteration because job failed")
			}

			index, err := jdi.jc.indexOf(jdi.data)
			if err != nil {
				return err
			}

			// Rebuild yielding/ancestors against only the jobs that precede
			// this one in the chain.
			jdi.recalculate(jdi.jc.jobs[:index])
			return nil
		}(); err != nil {
			return 0, err
		}

		// Restart the scan position so NextDatum re-walks the underlying
		// iterator against the refreshed yielding set.
		jdi.ditIndex = -1
	}

	// Batch size is the total multiplicity of all currently-yieldable datums.
	batchSize := int64(0)
	for _, count := range jdi.yielding {
		batchSize += count
	}

	return batchSize, nil
}
   391  
   392  func (jdi *jobDatumIterator) NextDatum() ([]*common.Input, int64) {
   393  	jdi.ditIndex++
   394  	for jdi.ditIndex < jdi.dit.Len() {
   395  		inputs := jdi.dit.DatumN(jdi.ditIndex)
   396  		hash := jdi.jc.hasher.Hash(inputs)
   397  		if count, ok := jdi.yielding[hash]; ok {
   398  			if count == 1 {
   399  				delete(jdi.yielding, hash)
   400  			} else {
   401  				jdi.yielding[hash]--
   402  			}
   403  			jdi.yielded[hash]++
   404  			return inputs, int64(jdi.ditIndex)
   405  		}
   406  		jdi.ditIndex++
   407  	}
   408  
   409  	return nil, 0
   410  }
   411  
   412  func (jdi *jobDatumIterator) Reset() {
   413  	jdi.ditIndex = -1
   414  	for hash, count := range jdi.yielded {
   415  		delete(jdi.yielded, hash)
   416  		jdi.yielding[hash] += count
   417  	}
   418  }
   419  
// MaxLen returns the length of the underlying datum.Iterator (see the
// JobDatumIterator interface for why this is needed).
func (jdi *jobDatumIterator) MaxLen() int64 {
	return int64(jdi.dit.Len())
}
   423  
// DatumSet returns the set of datums produced for this job, excluding any
// datums removed via JobChain.RecoveredDatums. The returned map is not copied.
func (jdi *jobDatumIterator) DatumSet() DatumSet {
	return jdi.allDatums
}
   427  
// AdditiveOnly reports whether this job's output can be merged with the
// parent job's output commit (computed in recalculate).
func (jdi *jobDatumIterator) AdditiveOnly() bool {
	return jdi.additiveOnly
}