github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/store/datas/pull/puller.go

// Copyright 2019 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package pull

import (
	"context"
	"errors"
	"fmt"
	"io"
	"log"
	"math"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"golang.org/x/sync/errgroup"

	"github.com/dolthub/dolt/go/libraries/doltcore/dconfig"
	"github.com/dolthub/dolt/go/store/chunks"
	"github.com/dolthub/dolt/go/store/hash"
	"github.com/dolthub/dolt/go/store/nbs"
)

// ErrDBUpToDate is the error returned from NewPuller in the event that there is no work to do.
var ErrDBUpToDate = errors.New("the database does not need to be pulled as it's already up to date")

// ErrIncompatibleSourceChunkStore is the error returned from NewPuller in
// the event that the source ChunkStore does not implement `NBSCompressedChunkStore`.
var ErrIncompatibleSourceChunkStore = errors.New("the chunk store of the source database does not implement NBSCompressedChunkStore")
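
// WalkAddrs walks the chunk addresses referenced by a chunk, invoking the
// callback once per address. (The callback's boolean argument is ignored by
// Puller's own use of it in Pull below.)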
type WalkAddrs func(chunks.Chunk, func(hash.Hash, bool) error) error

// Puller is used to sync data between two databases.
type Puller struct {
	waf WalkAddrs

	srcChunkStore nbs.NBSCompressedChunkStore
	sinkDBCS      chunks.ChunkStore
	hashes        hash.HashSet

	wr *PullTableFileWriter
	rd nbs.ChunkFetcher

	pushLog *log.Logger

	statsCh chan Stats
	stats   *stats
}

// NewPuller creates a new Puller instance to do the syncing. If the sink
// database already has all of the requested hashes, NewPuller returns
// ErrDBUpToDate and there is nothing to pull.
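//
// A minimal usage sketch (hypothetical: ctx, tempDir, srcCS, sinkCS,
// walkAddrs and root are assumed to be supplied by the caller; they are not
// part of this package):
//
//	p, err := NewPuller(ctx, tempDir, 256*1024, srcCS, sinkCS, walkAddrs, []hash.Hash{root}, nil)
//	if errors.Is(err, ErrDBUpToDate) {
//		return nil // nothing to do
//	} else if err != nil {
//		return err
//	}
//	return p.Pull(ctx)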
func NewPuller(
	ctx context.Context,
	tempDir string,
	chunksPerTF int,
	srcCS, sinkCS chunks.ChunkStore,
	walkAddrs WalkAddrs,
	hashes []hash.Hash,
	statsCh chan Stats,
) (*Puller, error) {
	// Sanity check: every requested hash must be present in the source.
	hs := hash.NewHashSet(hashes...)
	missing, err := srcCS.HasMany(ctx, hs)
	if err != nil {
		return nil, err
	}
	if missing.Size() != 0 {
		return nil, fmt.Errorf("%d requested hashes not found in source chunk store", missing.Size())
	}

	hs = hash.NewHashSet(hashes...)
	missing, err = sinkCS.HasMany(ctx, hs)
	if err != nil {
		return nil, err
	}
	if missing.Size() == 0 {
		return nil, ErrDBUpToDate
	}

	if srcCS.Version() != sinkCS.Version() {
		return nil, fmt.Errorf("cannot pull from src to sink; src version is %v and sink version is %v", srcCS.Version(), sinkCS.Version())
	}

	srcChunkStore, ok := srcCS.(nbs.NBSCompressedChunkStore)
	if !ok {
		return nil, ErrIncompatibleSourceChunkStore
	}

	wr := NewPullTableFileWriter(ctx, PullTableFileWriterConfig{
		ConcurrentUploads:    2,
		ChunksPerFile:        chunksPerTF,
		MaximumBufferedFiles: 8,
		TempDir:              tempDir,
		DestStore:            sinkCS.(chunks.TableFileStore),
	})

	rd := GetChunkFetcher(ctx, srcChunkStore)

	var pushLogger *log.Logger
	if dbg, ok := os.LookupEnv(dconfig.EnvPushLog); ok && strings.ToLower(dbg) == "true" {
		logFilePath := filepath.Join(tempDir, "push.log")
		f, err := os.OpenFile(logFilePath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, os.ModePerm)

		if err == nil {
			pushLogger = log.New(f, "", log.Lmicroseconds)
		}
	}

	p := &Puller{
		waf:           walkAddrs,
		srcChunkStore: srcChunkStore,
		sinkDBCS:      sinkCS,
		hashes:        hash.NewHashSet(hashes...),
		wr:            wr,
		rd:            rd,
		pushLog:       pushLogger,
		statsCh:       statsCh,
		stats: &stats{
			wrStatsGetter: wr.GetStats,
		},
	}

	if lcs, ok := sinkCS.(chunks.LoggingChunkStore); ok {
		lcs.SetLogger(p)
	}

	return p, nil
}

// Logf writes a formatted message to the push log, if push logging was
// enabled via the environment.
func (p *Puller) Logf(fmt string, args ...interface{}) {
	if p.pushLog != nil {
		p.pushLog.Printf(fmt, args...)
	}
}
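
// readable is a source of file contents that can be opened for reading and
// removed once it is no longer needed.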
type readable interface {
	Reader() (io.ReadCloser, error)
	Remove() error
}
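
// tempTblFile describes a temporary table file staged for upload: its id, a
// readable for its contents, and chunk-count, length and hash metadata.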
type tempTblFile struct {
	id          string
	read        readable
	numChunks   int
	chunksLen   uint64
	contentLen  uint64
	contentHash []byte
}
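
// countingReader wraps an io.ReadCloser, atomically adding the number of
// bytes read to cnt.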
type countingReader struct {
	io.ReadCloser
	cnt *uint64
}

func (c countingReader) Read(p []byte) (int, error) {
	n, err := c.ReadCloser.Read(p)
	atomic.AddUint64(c.cnt, uint64(n))
	return n, err
}
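
// emitStats starts two goroutines: a sampler that computes smoothed
// bytes-per-second rates from the writer stats and fetch counters, and a
// reporter that sends a Stats snapshot on ch once per second. The returned
// cancel function stops both goroutines after a final snapshot is sent.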
func emitStats(s *stats, ch chan Stats) (cancel func()) {
	done := make(chan struct{})
	var wg sync.WaitGroup
	wg.Add(2)
	cancel = func() {
		close(done)
		wg.Wait()
	}

	go func() {
		defer wg.Done()
		sampleduration := 100 * time.Millisecond
		samplesinsec := uint64((1 * time.Second) / sampleduration)
		weight := 0.1
		ticker := time.NewTicker(sampleduration)
		defer ticker.Stop()
		var lastSendBytes, lastFetchedBytes uint64
		for {
			select {
			case <-ticker.C:
				wrStats := s.wrStatsGetter()
				newSendBytes := wrStats.FinishedSendBytes
				newFetchedBytes := atomic.LoadUint64(&s.fetchedSourceBytes)
				sendBytesDiff := newSendBytes - lastSendBytes
				fetchedBytesDiff := newFetchedBytes - lastFetchedBytes

				newSendBPS := float64(sendBytesDiff * samplesinsec)
				newFetchedBPS := float64(fetchedBytesDiff * samplesinsec)
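
				// Exponentially weighted moving average of each rate:
				// smoothed = cur + weight*(new-cur), an EWMA with alpha =
				// 0.1, seeded with the raw sample while cur is still zero.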
				curSendBPS := math.Float64frombits(atomic.LoadUint64(&s.sendBytesPerSec))
				curFetchedBPS := math.Float64frombits(atomic.LoadUint64(&s.fetchedSourceBytesPerSec))

				smoothedSendBPS := newSendBPS
				if curSendBPS != 0 {
					smoothedSendBPS = curSendBPS + weight*(newSendBPS-curSendBPS)
				}

				smoothedFetchBPS := newFetchedBPS
				if curFetchedBPS != 0 {
					smoothedFetchBPS = curFetchedBPS + weight*(newFetchedBPS-curFetchedBPS)
				}

				if smoothedSendBPS < 1 {
					smoothedSendBPS = 0
				}
				if smoothedFetchBPS < 1 {
					smoothedFetchBPS = 0
				}

				atomic.StoreUint64(&s.sendBytesPerSec, math.Float64bits(smoothedSendBPS))
				atomic.StoreUint64(&s.fetchedSourceBytesPerSec, math.Float64bits(smoothedFetchBPS))

				lastSendBytes = newSendBytes
				lastFetchedBytes = newFetchedBytes
			case <-done:
				return
			}
		}
	}()

	go func() {
		defer wg.Done()
		updateduration := 1 * time.Second
		ticker := time.NewTicker(updateduration)
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				ch <- s.read()
			case <-done:
				ch <- s.read()
				return
			}
		}
	}()

	return cancel
}
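
// stats accumulates pull progress counters. The *PerSec fields hold float64
// values stored via math.Float64bits so that they can be read and written
// with the sync/atomic package.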
type stats struct {
	sendBytesPerSec uint64

	totalSourceChunks        uint64
	fetchedSourceChunks      uint64
	fetchedSourceBytes       uint64
	fetchedSourceBytesPerSec uint64

	sendBytesPerSecF          float64
	fetchedSourceBytesPerSecF float64

	wrStatsGetter func() PullTableFileWriterStats
}
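
// Stats is a point-in-time snapshot of pull progress, delivered on the
// stats channel passed to NewPuller.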
type Stats struct {
	FinishedSendBytes uint64
	BufferedSendBytes uint64
	SendBytesPerSec   float64

	TotalSourceChunks        uint64
	FetchedSourceChunks      uint64
	FetchedSourceBytes       uint64
	FetchedSourceBytesPerSec float64
}
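
// A caller might consume the stats channel like this (hypothetical sketch;
// statsCh is owned by the caller and must only be closed after Pull has
// returned, since the reporter goroutine sends a final snapshot on cancel):
//
//	statsCh := make(chan Stats)
//	go func() {
//		for s := range statsCh {
//			log.Printf("sent %d bytes (%.0f bytes/sec)", s.FinishedSendBytes, s.SendBytesPerSec)
//		}
//	}()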

func (s *stats) read() Stats {
	wrStats := s.wrStatsGetter()

	var ret Stats
	ret.FinishedSendBytes = wrStats.FinishedSendBytes
	ret.BufferedSendBytes = wrStats.BufferedSendBytes
	ret.SendBytesPerSec = math.Float64frombits(atomic.LoadUint64(&s.sendBytesPerSec))
	ret.TotalSourceChunks = atomic.LoadUint64(&s.totalSourceChunks)
	ret.FetchedSourceChunks = atomic.LoadUint64(&s.fetchedSourceChunks)
	ret.FetchedSourceBytes = atomic.LoadUint64(&s.fetchedSourceBytes)
	ret.FetchedSourceBytesPerSec = math.Float64frombits(atomic.LoadUint64(&s.fetchedSourceBytesPerSec))
	return ret
}

// Pull executes the sync operation, fetching all chunks reachable from
// p.hashes that are missing from the sink and writing them to table files
// in the sink.
func (p *Puller) Pull(ctx context.Context) error {
	if p.statsCh != nil {
		c := emitStats(p.stats, p.statsCh)
		defer c()
	}

	eg, ctx := errgroup.WithContext(ctx)

	const batchSize = 64 * 1024
	tracker := NewPullChunkTracker(ctx, p.hashes, TrackerConfig{
		BatchSize: batchSize,
		HasManyer: p.sinkDBCS,
	})
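
	// The pipeline: the tracker hands out batches of chunk addresses that
	// the sink is missing; the first goroutine below requests those batches
	// from the source fetcher, while the second receives the fetched
	// chunks, walks them for further addresses to track, and writes them to
	// table files in the sink.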

	// One thread calls ChunkFetcher.Get on each batch.
	eg.Go(func() error {
		for {
			toFetch, hasMore, err := tracker.GetChunksToFetch()
			if err != nil {
				return err
			}
			if !hasMore {
				return p.rd.CloseSend()
			}

			atomic.AddUint64(&p.stats.totalSourceChunks, uint64(len(toFetch)))
			err = p.rd.Get(ctx, toFetch)
			if err != nil {
				return err
			}
		}
	})

	// One thread reads the received chunks, walks their addresses and writes them to the table file.
	eg.Go(func() error {
		for {
			cChk, err := p.rd.Recv(ctx)
			if err == io.EOF {
				// This means the requesting thread
				// successfully saw all chunk addresses and
				// called CloseSend and that all requested
				// chunks were successfully delivered to this
				// thread. Calling wr.Close() here will block
				// on uploading any table files and will write
				// the new table files to the destination's
				// manifest.
				return p.wr.Close()
			}
			if err != nil {
				return err
			}
			if len(cChk.FullCompressedChunk) == 0 {
				return errors.New("failed to get all chunks")
			}

			atomic.AddUint64(&p.stats.fetchedSourceBytes, uint64(len(cChk.FullCompressedChunk)))
			atomic.AddUint64(&p.stats.fetchedSourceChunks, uint64(1))

			chnk, err := cChk.ToChunk()
			if err != nil {
				return err
			}
			err = p.waf(chnk, func(h hash.Hash, _ bool) error {
				tracker.Seen(h)
				return nil
			})
			if err != nil {
				return err
			}
			tracker.TickProcessed()

			err = p.wr.AddCompressedChunk(ctx, cChk)
			if err != nil {
				return err
			}
		}
	})

	// Always close the reader outside of the errgroup threads above.
	// Closing the reader will cause Get() to start returning errors, and
	// we don't need to deliver that error to the Get thread. Both threads
	// should exit and return any errors they encounter, after which the
	// errgroup will report the error.
	wErr := eg.Wait()
	rErr := p.rd.Close()
	return errors.Join(wErr, rErr)
}