github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/store/datas/pull/pull_chunk_tracker.go (about)

     1  // Copyright 2024 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package pull
    16  
    17  import (
    18  	"context"
    19  	"errors"
    20  	"sync"
    21  
    22  	"github.com/dolthub/dolt/go/store/hash"
    23  )
    24  
    25  type HasManyer interface {
    26  	HasMany(context.Context, hash.HashSet) (hash.HashSet, error)
    27  }
    28  
    29  type TrackerConfig struct {
    30  	BatchSize int
    31  
    32  	HasManyer HasManyer
    33  }
    34  
    35  const hasManyThreadCount = 3
    36  
    37  // A PullChunkTracker keeps track of seen chunk addresses and returns every
    38  // seen chunk address which is not already in the destination database exactly
    39  // once. A Puller instantiantes one of these with the initial set of addresses
    40  // to pull, and repeatedly calls |GetChunksToFetch|. It passes in all
    41  // references it finds in the fetched chunks to |Seen|, and continues to call
    42  // |GetChunksToFetch| and deliver new addresses to |Seen| until
    43  // |GetChunksToFetch| returns |false| from its |more| return boolean.
    44  //
    45  // PullChunkTracker is able to call |HasMany| on the destination database in
    46  // parallel with other work the Puller does and abstracts out the logic for
    47  // keeping track of seen, unchecked and to pull hcunk addresses.
    48  type PullChunkTracker struct {
    49  	ctx  context.Context
    50  	seen hash.HashSet
    51  	cfg  TrackerConfig
    52  	wg   sync.WaitGroup
    53  
    54  	uncheckedCh chan hash.Hash
    55  	processedCh chan struct{}
    56  	reqCh       chan *trackerGetAbsentReq
    57  }
    58  
    59  func NewPullChunkTracker(ctx context.Context, initial hash.HashSet, cfg TrackerConfig) *PullChunkTracker {
    60  	ret := &PullChunkTracker{
    61  		ctx:         ctx,
    62  		seen:        make(hash.HashSet),
    63  		cfg:         cfg,
    64  		uncheckedCh: make(chan hash.Hash),
    65  		processedCh: make(chan struct{}),
    66  		reqCh:       make(chan *trackerGetAbsentReq),
    67  	}
    68  	ret.seen.InsertAll(initial)
    69  	ret.wg.Add(1)
    70  	go func() {
    71  		defer ret.wg.Done()
    72  		ret.reqRespThread(initial)
    73  	}()
    74  	return ret
    75  }
    76  
    77  func (t *PullChunkTracker) Seen(h hash.Hash) {
    78  	if !t.seen.Has(h) {
    79  		t.seen.Insert(h)
    80  		t.addUnchecked(h)
    81  	}
    82  }
    83  
    84  // Call this for every returned hash that has been successfully processed.
    85  //
    86  // GetChunksToFetch() requires a matching |TickProcessed| call for each
    87  // returned Hash before it will return |hasMany == false|.
    88  func (t *PullChunkTracker) TickProcessed() {
    89  	select {
    90  	case t.processedCh <- struct{}{}:
    91  	case <-t.ctx.Done():
    92  	}
    93  }
    94  
    95  func (t *PullChunkTracker) Close() {
    96  	close(t.uncheckedCh)
    97  	t.wg.Wait()
    98  }
    99  
   100  func (t *PullChunkTracker) addUnchecked(h hash.Hash) {
   101  	select {
   102  	case t.uncheckedCh <- h:
   103  	case <-t.ctx.Done():
   104  	}
   105  }
   106  
   107  func (t *PullChunkTracker) GetChunksToFetch() (hash.HashSet, bool, error) {
   108  	var req trackerGetAbsentReq
   109  	req.ready = make(chan struct{})
   110  
   111  	select {
   112  	case t.reqCh <- &req:
   113  	case <-t.ctx.Done():
   114  		return nil, false, context.Cause(t.ctx)
   115  	}
   116  
   117  	select {
   118  	case <-req.ready:
   119  	case <-t.ctx.Done():
   120  		return nil, false, context.Cause(t.ctx)
   121  	}
   122  
   123  	return req.hs, req.ok, req.err
   124  }
   125  
   126  // The main logic of the PullChunkTracker, receives requests from other threads
   127  // and responds to them.
   128  func (t *PullChunkTracker) reqRespThread(initial hash.HashSet) {
   129  	doneCh := make(chan struct{})
   130  	hasManyReqCh := make(chan trackerHasManyReq)
   131  	hasManyRespCh := make(chan trackerHasManyResp)
   132  
   133  	var wg sync.WaitGroup
   134  	wg.Add(hasManyThreadCount)
   135  
   136  	for i := 0; i < hasManyThreadCount; i++ {
   137  		go func() {
   138  			defer wg.Done()
   139  			hasManyThread(t.ctx, t.cfg.HasManyer, hasManyReqCh, hasManyRespCh, doneCh)
   140  		}()
   141  	}
   142  
   143  	defer func() {
   144  		close(doneCh)
   145  		wg.Wait()
   146  	}()
   147  
   148  	unchecked := make([]hash.HashSet, 0)
   149  	absent := make([]hash.HashSet, 0)
   150  
   151  	var err error
   152  	outstanding := 0
   153  	unprocessed := 0
   154  
   155  	if len(initial) > 0 {
   156  		unchecked = append(unchecked, initial)
   157  		outstanding += 1
   158  	}
   159  
   160  	for {
   161  		var thisReqCh = t.reqCh
   162  		if len(absent) == 0 && (outstanding != 0 || unprocessed != 0) {
   163  			// If we are waiting for a HasMany response and we don't currently have any
   164  			// absent addresses to return, block any absent requests.
   165  			thisReqCh = nil
   166  		}
   167  
   168  		var thisHasManyReqCh chan trackerHasManyReq
   169  		var hasManyReq trackerHasManyReq
   170  		if len(unchecked) > 0 {
   171  			hasManyReq.hs = unchecked[0]
   172  			thisHasManyReqCh = hasManyReqCh
   173  		}
   174  
   175  		select {
   176  		case h, ok := <-t.uncheckedCh:
   177  			if !ok {
   178  				return
   179  			}
   180  			if len(unchecked) == 0 || len(unchecked[len(unchecked)-1]) >= t.cfg.BatchSize {
   181  				outstanding += 1
   182  				unchecked = append(unchecked, make(hash.HashSet))
   183  			}
   184  			unchecked[len(unchecked)-1].Insert(h)
   185  		case resp := <-hasManyRespCh:
   186  			outstanding -= 1
   187  			if resp.err != nil {
   188  				err = errors.Join(err, resp.err)
   189  			} else if len(resp.hs) > 0 {
   190  				absent = append(absent, resp.hs)
   191  			}
   192  		case thisHasManyReqCh <- hasManyReq:
   193  			copy(unchecked[:], unchecked[1:])
   194  			if len(unchecked) > 1 {
   195  				unchecked[len(unchecked)-1] = nil
   196  			}
   197  			unchecked = unchecked[:len(unchecked)-1]
   198  		case <-t.processedCh:
   199  			unprocessed -= 1
   200  		case req := <-thisReqCh:
   201  			if err != nil {
   202  				req.err = err
   203  				close(req.ready)
   204  				err = nil
   205  			} else if len(absent) == 0 {
   206  				req.ok = false
   207  				close(req.ready)
   208  			} else {
   209  				req.ok = true
   210  				req.hs = absent[0]
   211  				var i int
   212  				for i = 1; i < len(absent); i++ {
   213  					l := len(absent[i])
   214  					if len(req.hs)+l < t.cfg.BatchSize {
   215  						req.hs.InsertAll(absent[i])
   216  					} else {
   217  						for h := range absent[i] {
   218  							if len(req.hs) >= t.cfg.BatchSize {
   219  								break
   220  							}
   221  							req.hs.Insert(h)
   222  							absent[i].Remove(h)
   223  						}
   224  						break
   225  					}
   226  				}
   227  				copy(absent[:], absent[i:])
   228  				for j := len(absent) - i; j < len(absent); j++ {
   229  					absent[j] = nil
   230  				}
   231  				absent = absent[:len(absent)-i]
   232  				unprocessed += len(req.hs)
   233  				close(req.ready)
   234  			}
   235  		case <-t.ctx.Done():
   236  			return
   237  		}
   238  	}
   239  }
   240  
   241  // Run by a PullChunkTracker, calls HasMany on a batch of addresses and delivers the results.
   242  func hasManyThread(ctx context.Context, hasManyer HasManyer, reqCh <-chan trackerHasManyReq, respCh chan<- trackerHasManyResp, doneCh <-chan struct{}) {
   243  	for {
   244  		select {
   245  		case req := <-reqCh:
   246  			hs, err := hasManyer.HasMany(ctx, req.hs)
   247  			if err != nil {
   248  				select {
   249  				case respCh <- trackerHasManyResp{err: err}:
   250  				case <-ctx.Done():
   251  					return
   252  				case <-doneCh:
   253  					return
   254  				}
   255  			} else {
   256  				select {
   257  				case respCh <- trackerHasManyResp{hs: hs}:
   258  				case <-ctx.Done():
   259  					return
   260  				case <-doneCh:
   261  					return
   262  				}
   263  			}
   264  		case <-doneCh:
   265  			return
   266  		case <-ctx.Done():
   267  			return
   268  		}
   269  	}
   270  }
   271  
   272  // Sent by the tracker thread to a HasMany thread, includes a batch of
   273  // addresses to HasMany. The response comes back to the tracker thread on a
   274  // separate channel as a |trackerHasManyResp|.
   275  type trackerHasManyReq struct {
   276  	hs hash.HashSet
   277  }
   278  
   279  // Sent by the HasMany thread back to the tracker thread.
   280  // If HasMany returned an error, it will be returned here.
   281  type trackerHasManyResp struct {
   282  	hs  hash.HashSet
   283  	err error
   284  }
   285  
   286  // Sent by a client calling |GetChunksToFetch| to the tracker thread. The
   287  // tracker thread will return a batch of chunk addresses that need to be
   288  // fetched from source and added to destination.
   289  //
   290  // This will block until HasMany requests are completed.
   291  //
   292  // If |ok| is |false|, then the Tracker is closing because every absent address
   293  // has been delivered.
   294  type trackerGetAbsentReq struct {
   295  	hs    hash.HashSet
   296  	err   error
   297  	ok    bool
   298  	ready chan struct{}
   299  }