github.com/ethersphere/bee/v2@v2.2.0/pkg/pullsync/pullsync.go

     1  // Copyright 2020 The Swarm Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package pullsync provides the pullsync protocol
     6  // implementation.
     7  package pullsync
     8  
     9  import (
    10  	"context"
    11  	"encoding/hex"
    12  	"errors"
    13  	"fmt"
    14  	"io"
    15  	"math"
    16  	"sync/atomic"
    17  	"time"
    18  
    19  	"github.com/ethersphere/bee/v2/pkg/bitvector"
    20  	"github.com/ethersphere/bee/v2/pkg/cac"
    21  	"github.com/ethersphere/bee/v2/pkg/log"
    22  	"github.com/ethersphere/bee/v2/pkg/p2p"
    23  	"github.com/ethersphere/bee/v2/pkg/p2p/protobuf"
    24  	"github.com/ethersphere/bee/v2/pkg/postage"
    25  	"github.com/ethersphere/bee/v2/pkg/pullsync/pb"
    26  	"github.com/ethersphere/bee/v2/pkg/ratelimit"
    27  	"github.com/ethersphere/bee/v2/pkg/soc"
    28  	"github.com/ethersphere/bee/v2/pkg/storage"
    29  	"github.com/ethersphere/bee/v2/pkg/storer"
    30  	"github.com/ethersphere/bee/v2/pkg/swarm"
    31  	"resenje.org/singleflight"
    32  )
    33  
    34  // loggerName is the tree path name of the logger for this package.
    35  const loggerName = "pullsync"
    36  
    37  const (
    38  	protocolName     = "pullsync"
    39  	protocolVersion  = "1.4.0"
    40  	streamName       = "pullsync"
    41  	cursorStreamName = "cursors"
    42  )
    43  
    44  var (
    45  	ErrUnsolicitedChunk = errors.New("peer sent unsolicited chunk")
    46  )
    47  
    48  const (
    49  	MaxCursor                       = math.MaxUint64
    50  	DefaultMaxPage           uint64 = 250
    51  	pageTimeout                     = time.Second
    52  	makeOfferTimeout                = 15 * time.Minute
    53  	handleMaxChunksPerSecond        = 250
    54  	handleRequestsLimitRate         = time.Second / handleMaxChunksPerSecond // handle at most handleMaxChunksPerSecond (250) chunks per second per peer, i.e. one chunk every 4ms
    55  )
    56  
    57  // Interface is the PullSync interface.
    58  type Interface interface {
    59  	// Sync syncs a batch of chunks starting at a start BinID.
    60  	// It returns the BinID of the highest chunk that was synced from the given
    61  	// batch and the total number of chunks the downstream peer has sent.
    62  	Sync(ctx context.Context, peer swarm.Address, bin uint8, start uint64) (topmost uint64, count int, err error)
    63  	// GetCursors retrieves all cursors from a downstream peer.
    64  	GetCursors(ctx context.Context, peer swarm.Address) ([]uint64, uint64, error)
    65  }
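
// Illustrative sketch (not part of the original source): one way a caller
// could page through a single bin using this interface. The function name
// syncBin and the surrounding wiring are hypothetical; in bee the real
// driver is the puller package, which also persists the synced intervals.
//
//	func syncBin(ctx context.Context, ps Interface, peer swarm.Address, bin uint8) error {
//		cursors, _, err := ps.GetCursors(ctx, peer)
//		if err != nil {
//			return err
//		}
//		// page through the bin up to the peer's cursor for this bin
//		for start := uint64(1); start <= cursors[bin]; {
//			topmost, _, err := ps.Sync(ctx, peer, bin, start)
//			if err != nil {
//				return err
//			}
//			if topmost < start {
//				break // defensive: nothing newer was reported for this interval
//			}
//			start = topmost + 1 // continue after the highest synced BinID
//		}
//		return nil
//	}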
    66  
    67  type Syncer struct {
    68  	streamer       p2p.Streamer
    69  	metrics        metrics
    70  	logger         log.Logger
    71  	store          storer.Reserve
    72  	quit           chan struct{}
    73  	unwrap         func(swarm.Chunk)
    74  	validStamp     postage.ValidStampFn
    75  	intervalsSF    singleflight.Group[string, *collectAddrsResult]
    76  	syncInProgress atomic.Int32
    77  
    78  	maxPage uint64
    79  
    80  	limiter *ratelimit.Limiter
    81  
    82  	Interface
    83  	io.Closer
    84  }
    85  
    86  func New(
    87  	streamer p2p.Streamer,
    88  	store storer.Reserve,
    89  	unwrap func(swarm.Chunk),
    90  	validStamp postage.ValidStampFn,
    91  	logger log.Logger,
    92  	maxPage uint64,
    93  ) *Syncer {
    94  
    95  	return &Syncer{
    96  		streamer:   streamer,
    97  		store:      store,
    98  		metrics:    newMetrics(),
    99  		unwrap:     unwrap,
   100  		validStamp: validStamp,
   101  		logger:     logger.WithName(loggerName).Register(),
   102  		quit:       make(chan struct{}),
   103  		maxPage:    maxPage,
   104  		limiter:    ratelimit.New(handleRequestsLimitRate, int(maxPage)),
   105  	}
   106  }
   107  
   108  func (s *Syncer) Protocol() p2p.ProtocolSpec {
   109  	return p2p.ProtocolSpec{
   110  		Name:    protocolName,
   111  		Version: protocolVersion,
   112  		StreamSpecs: []p2p.StreamSpec{
   113  			{
   114  				Name:    streamName,
   115  				Handler: s.handler,
   116  			},
   117  			{
   118  				Name:    cursorStreamName,
   119  				Handler: s.cursorHandler,
   120  			},
   121  		},
   122  		DisconnectIn:  s.disconnect,
   123  		DisconnectOut: s.disconnect,
   124  	}
   125  }
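
// Illustrative sketch (not part of the original source): wiring the Syncer
// into a node, assuming a p2p service value that implements both p2p.Streamer
// and protocol registration (as bee's libp2p service does), with the remaining
// dependencies constructed elsewhere. Variable names are hypothetical.
//
//	syncer := pullsync.New(p2pService, reserve, unwrapFn, validStampFn, logger, pullsync.DefaultMaxPage)
//	if err := p2pService.AddProtocol(syncer.Protocol()); err != nil {
//		return fmt.Errorf("pullsync: add protocol: %w", err)
//	}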
   126  
   127  // handler handles an incoming request to sync an interval
   128  func (s *Syncer) handler(streamCtx context.Context, p p2p.Peer, stream p2p.Stream) (err error) {
   129  
   130  	select {
   131  	case <-s.quit:
   132  		return nil
   133  	default:
   134  		s.syncInProgress.Add(1)
   135  		defer s.syncInProgress.Add(-1)
   136  	}
   137  
   138  	r := protobuf.NewReader(stream)
   139  	defer func() {
   140  		if err != nil {
   141  			_ = stream.Reset()
   142  		} else {
   143  			_ = stream.FullClose()
   144  		}
   145  	}()
   146  
   147  	ctx, cancel := context.WithCancel(streamCtx)
   148  	defer cancel()
   149  
   150  	go func() {
   151  		select {
   152  		case <-s.quit:
   153  			cancel()
   154  		case <-ctx.Done():
   155  			return
   156  		}
   157  	}()
   158  
   159  	var rn pb.Get
   160  	if err := r.ReadMsgWithContext(ctx, &rn); err != nil {
   161  		return fmt.Errorf("read get range: %w", err)
   162  	}
   163  
   164  	// recreate the reader to allow the first one to be garbage collected
   165  	// before the makeOffer function call, to reduce the total memory allocated
   166  	// while makeOffer is executing (waiting for the new chunks)
   167  	w, r := protobuf.NewWriterAndReader(stream)
   168  
   169  	// make an offer to the upstream peer in return for the requested range
   170  	offer, err := s.makeOffer(ctx, rn)
   171  	if err != nil {
   172  		return fmt.Errorf("make offer: %w", err)
   173  	}
   174  
   175  	if err := w.WriteMsgWithContext(ctx, offer); err != nil {
   176  		return fmt.Errorf("write offer: %w", err)
   177  	}
   178  
   179  	// we don't have any hashes to offer in this range (the
   180  	// interval is empty). nothing more to do
   181  	if len(offer.Chunks) == 0 {
   182  		return nil
   183  	}
   184  
   185  	s.metrics.SentOffered.Add(float64(len(offer.Chunks)))
   186  
   187  	var want pb.Want
   188  	if err := r.ReadMsgWithContext(ctx, &want); err != nil {
   189  		return fmt.Errorf("read want: %w", err)
   190  	}
   191  
   192  	chs, err := s.processWant(ctx, offer, &want)
   193  	if err != nil {
   194  		return fmt.Errorf("process want: %w", err)
   195  	}
   196  
   197  	// slow down future requests
   198  	waitDur, err := s.limiter.Wait(streamCtx, p.Address.ByteString(), max(1, len(chs)))
   199  	if err != nil {
   200  		return fmt.Errorf("rate limiter: %w", err)
   201  	}
   202  	if waitDur > 0 {
   203  		s.logger.Debug("rate limited peer", "wait_duration", waitDur, "peer_address", p.Address)
   204  	}
   205  
   206  	for _, c := range chs {
   207  		var stamp []byte
   208  		if c.Stamp() != nil {
   209  			stamp, err = c.Stamp().MarshalBinary()
   210  			if err != nil {
   211  				return fmt.Errorf("serialise stamp: %w", err)
   212  			}
   213  		}
   214  
   215  		deliver := pb.Delivery{Address: c.Address().Bytes(), Data: c.Data(), Stamp: stamp}
   216  		if err := w.WriteMsgWithContext(ctx, &deliver); err != nil {
   217  			return fmt.Errorf("write delivery: %w", err)
   218  		}
   219  		s.metrics.Sent.Inc()
   220  	}
   221  
   222  	return nil
   223  }
   224  
   225  // Sync syncs a batch of chunks starting at a start BinID.
   226  // It returns the BinID of the highest chunk that was synced from the given
   227  // batch and the total number of chunks the downstream peer has sent.
   228  func (s *Syncer) Sync(ctx context.Context, peer swarm.Address, bin uint8, start uint64) (uint64, int, error) {
   229  
   230  	stream, err := s.streamer.NewStream(ctx, peer, nil, protocolName, protocolVersion, streamName)
   231  	if err != nil {
   232  		return 0, 0, fmt.Errorf("new stream: %w", err)
   233  	}
   234  	defer func() {
   235  		if err != nil {
   236  			_ = stream.Reset()
   237  			s.logger.Debug("error syncing peer", "peer_address", peer, "bin", bin, "start", start, "error", err)
   238  		} else {
   239  			stream.FullClose()
   240  		}
   241  	}()
   242  
   243  	w, r := protobuf.NewWriterAndReader(stream)
   244  
   245  	rangeMsg := &pb.Get{Bin: int32(bin), Start: start}
   246  	if err = w.WriteMsgWithContext(ctx, rangeMsg); err != nil {
   247  		return 0, 0, fmt.Errorf("write get range: %w", err)
   248  	}
   249  
   250  	var offer pb.Offer
   251  	if err = r.ReadMsgWithContext(ctx, &offer); err != nil {
   252  		return 0, 0, fmt.Errorf("read offer: %w", err)
   253  	}
   254  
   255  	// empty interval (no chunks present in interval).
   256  	// return the end of the requested range as topmost.
   257  	if len(offer.Chunks) == 0 {
   258  		return offer.Topmost, 0, nil
   259  	}
   260  
   261  	topmost := offer.Topmost
   262  
   263  	var (
   264  		bvLen      = len(offer.Chunks)
   265  		wantChunks = make(map[string]struct{}, bvLen)
   266  		ctr        = 0
   267  		have       bool
   268  	)
   269  
   270  	bv, err := bitvector.New(bvLen)
   271  	if err != nil {
   272  		return 0, 0, fmt.Errorf("new bitvector: %w", err)
   273  	}
   274  
   275  	for i := 0; i < len(offer.Chunks); i++ {
   276  
   277  		addr := offer.Chunks[i].Address
   278  		batchID := offer.Chunks[i].BatchID
   279  		stampHash := offer.Chunks[i].StampHash
   280  		if len(addr) != swarm.HashSize {
   281  			return 0, 0, fmt.Errorf("inconsistent hash length")
   282  		}
   283  
   284  		a := swarm.NewAddress(addr)
   285  		if a.Equal(swarm.ZeroAddress) {
   286  			// keep this debug log around to confirm we never see zero address hashes in offers
   287  			s.logger.Debug("syncer got a zero address hash on offer", "peer_address", peer)
   288  			continue
   289  		}
   290  		s.metrics.Offered.Inc()
   291  		if s.store.IsWithinStorageRadius(a) {
   292  			have, err = s.store.ReserveHas(a, batchID, stampHash)
   293  			if err != nil {
   294  				s.logger.Debug("storage has", "error", err)
   295  				return 0, 0, err
   296  			}
   297  
   298  			if !have {
   299  				wantChunks[a.ByteString()+string(batchID)+string(stampHash)] = struct{}{}
   300  				ctr++
   301  				s.metrics.Wanted.Inc()
   302  				bv.Set(i)
   303  			}
   304  		}
   305  	}
   306  
   307  	wantMsg := &pb.Want{BitVector: bv.Bytes()}
   308  	if err = w.WriteMsgWithContext(ctx, wantMsg); err != nil {
   309  		return 0, 0, fmt.Errorf("write want: %w", err)
   310  	}
   311  
   312  	chunksToPut := make([]swarm.Chunk, 0, ctr)
   313  
   314  	var chunkErr error
   315  	for ; ctr > 0; ctr-- {
   316  		var delivery pb.Delivery
   317  		if err = r.ReadMsgWithContext(ctx, &delivery); err != nil {
   318  			return 0, 0, errors.Join(chunkErr, fmt.Errorf("read delivery: %w", err))
   319  		}
   320  
   321  		addr := swarm.NewAddress(delivery.Address)
   322  		if addr.Equal(swarm.ZeroAddress) {
   323  			s.logger.Debug("received zero address chunk", "peer_address", peer)
   324  			s.metrics.ReceivedZeroAddress.Inc()
   325  			continue
   326  		}
   327  
   328  		newChunk := swarm.NewChunk(addr, delivery.Data)
   329  
   330  		stamp := new(postage.Stamp)
   331  		if err = stamp.UnmarshalBinary(delivery.Stamp); err != nil {
   332  			chunkErr = errors.Join(chunkErr, err)
   333  			continue
   334  		}
   335  		stampHash, err := stamp.Hash()
   336  		if err != nil {
   337  			chunkErr = errors.Join(chunkErr, err)
   338  			continue
   339  		}
   340  
   341  		wantChunkID := addr.ByteString() + string(stamp.BatchID()) + string(stampHash)
   342  		if _, ok := wantChunks[wantChunkID]; !ok {
   343  			s.logger.Debug("want chunks", "error", ErrUnsolicitedChunk, "peer_address", peer, "chunk_address", addr)
   344  			chunkErr = errors.Join(chunkErr, ErrUnsolicitedChunk)
   345  			continue
   346  		}
   347  
   348  		delete(wantChunks, wantChunkID)
   349  
   350  		chunk, err := s.validStamp(newChunk.WithStamp(stamp))
   351  		if err != nil {
   352  			s.logger.Debug("unverified stamp", "error", err, "peer_address", peer, "chunk_address", newChunk)
   353  			chunkErr = errors.Join(chunkErr, err)
   354  			continue
   355  		}
   356  
   357  		if cac.Valid(chunk) {
   358  			go s.unwrap(chunk)
   359  		} else if !soc.Valid(chunk) {
   360  			s.logger.Debug("invalid cac/soc chunk", "error", swarm.ErrInvalidChunk, "peer_address", peer, "chunk", chunk)
   361  			chunkErr = errors.Join(chunkErr, swarm.ErrInvalidChunk)
   362  			s.metrics.ReceivedInvalidChunk.Inc()
   363  			continue
   364  		}
   365  		chunksToPut = append(chunksToPut, chunk)
   366  	}
   367  
   368  	chunksPut := 0
   369  	if len(chunksToPut) > 0 {
   370  
   371  		s.metrics.Delivered.Add(float64(len(chunksToPut)))
   372  		s.metrics.LastReceived.WithLabelValues(fmt.Sprintf("%d", bin)).Add(float64(len(chunksToPut)))
   373  
   374  		for _, c := range chunksToPut {
   375  			if err := s.store.ReservePutter().Put(ctx, c); err != nil {
   376  				// in case of these errors, no new items are added to the storage, so it
   377  				// is safe to continue with the next chunk
   378  				if errors.Is(err, storage.ErrOverwriteNewerChunk) {
   379  					s.logger.Debug("overwrite newer chunk", "error", err, "peer_address", peer, "chunk", c)
   380  					chunkErr = errors.Join(chunkErr, err)
   381  					continue
   382  				}
   383  				return 0, 0, errors.Join(chunkErr, err)
   384  			}
   385  			chunksPut++
   386  		}
   387  	}
   388  
   389  	return topmost, chunksPut, chunkErr
   390  }
   391  
   392  // makeOffer tries to assemble an offer for a given requested interval.
   393  func (s *Syncer) makeOffer(ctx context.Context, rn pb.Get) (*pb.Offer, error) {
   394  
   395  	ctx, cancel := context.WithTimeout(ctx, makeOfferTimeout)
   396  	defer cancel()
   397  
   398  	addrs, top, err := s.collectAddrs(ctx, uint8(rn.Bin), rn.Start)
   399  	if err != nil {
   400  		return nil, err
   401  	}
   402  
   403  	o := new(pb.Offer)
   404  	o.Topmost = top
   405  	o.Chunks = make([]*pb.Chunk, 0, len(addrs))
   406  	for _, v := range addrs {
   407  		o.Chunks = append(o.Chunks, &pb.Chunk{Address: v.Address.Bytes(), BatchID: v.BatchID, StampHash: v.StampHash})
   408  	}
   409  	return o, nil
   410  }
   411  
   412  type collectAddrsResult struct {
   413  	chs     []*storer.BinC
   414  	topmost uint64
   415  }
   416  
   417  // collectAddrs collects chunk addresses at a bin starting at some start BinID until a limit is reached.
   418  // The function waits for an unbounded amount of time for the first chunk to arrive.
   419  // After the arrival of the first chunk, the subsequent chunks have a limited amount of time to arrive,
   420  // after which the function returns the collected slice of chunks.
   421  func (s *Syncer) collectAddrs(ctx context.Context, bin uint8, start uint64) ([]*storer.BinC, uint64, error) {
   422  	loggerV2 := s.logger.V(2).Register()
   423  
   424  	v, _, err := s.intervalsSF.Do(ctx, sfKey(bin, start), func(ctx context.Context) (*collectAddrsResult, error) {
   425  		var (
   426  			chs     []*storer.BinC
   427  			topmost uint64
   428  			timer   *time.Timer
   429  			timerC  <-chan time.Time
   430  		)
   431  		chC, unsub, errC := s.store.SubscribeBin(ctx, bin, start)
   432  		defer func() {
   433  			unsub()
   434  			if timer != nil {
   435  				timer.Stop()
   436  			}
   437  		}()
   438  
   439  		limit := s.maxPage
   440  
   441  	LOOP:
   442  		for limit > 0 {
   443  			select {
   444  			case c, ok := <-chC:
   445  				if !ok {
   446  					break LOOP // The stream has been closed.
   447  				}
   448  
   449  				chs = append(chs, &storer.BinC{Address: c.Address, BatchID: c.BatchID, StampHash: c.StampHash})
   450  				if c.BinID > topmost {
   451  					topmost = c.BinID
   452  				}
   453  				limit--
   454  				if timer == nil {
   455  					timer = time.NewTimer(pageTimeout)
   456  				} else {
   457  					if !timer.Stop() {
   458  						<-timer.C
   459  					}
   460  					timer.Reset(pageTimeout)
   461  				}
   462  				timerC = timer.C
   463  			case err := <-errC:
   464  				return nil, err
   465  			case <-ctx.Done():
   466  				return nil, ctx.Err()
   467  			case <-timerC:
   468  				loggerV2.Debug("batch timeout timer triggered")
   469  				// return batch if new chunks are not received after some time
   470  				break LOOP
   471  			}
   472  		}
   473  
   474  		return &collectAddrsResult{chs: chs, topmost: topmost}, nil
   475  	})
   476  	if err != nil {
   477  		return nil, 0, err
   478  	}
   479  	return v.chs, v.topmost, nil
   480  }
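
// Illustrative note (not part of the original source): timerC starts out nil,
// and a receive from a nil channel blocks forever, so the <-timerC case can
// never fire before the first chunk arrives — that is what gives the unbounded
// wait for the first chunk described above. Once a chunk arrives, timerC is
// set to the timer's channel and every subsequent wait is bounded by
// pageTimeout. The nil-channel behaviour in isolation:
//
//	var c <-chan time.Time // nil channel
//	select {
//	case <-c:
//		// never selected: receiving from a nil channel blocks forever
//	case <-time.After(time.Millisecond):
//		// always taken
//	}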
   481  
   482  // processWant compares a received Want to a sent Offer and returns
   483  // the appropriate chunks from the local store.
   484  func (s *Syncer) processWant(ctx context.Context, o *pb.Offer, w *pb.Want) ([]swarm.Chunk, error) {
   485  	bv, err := bitvector.NewFromBytes(w.BitVector, len(o.Chunks))
   486  	if err != nil {
   487  		return nil, err
   488  	}
   489  
   490  	chunks := make([]swarm.Chunk, 0, len(o.Chunks))
   491  	for i := 0; i < len(o.Chunks); i++ {
   492  		if bv.Get(i) {
   493  			ch := o.Chunks[i]
   494  			addr := swarm.NewAddress(ch.Address)
   495  			s.metrics.SentWanted.Inc()
   496  			c, err := s.store.ReserveGet(ctx, addr, ch.BatchID, ch.StampHash)
   497  			if err != nil {
   498  				s.logger.Debug("processing want: unable to find chunk", "chunk_address", addr, "batch_id", hex.EncodeToString(ch.BatchID))
   499  				chunks = append(chunks, swarm.NewChunk(swarm.ZeroAddress, nil))
   500  				s.metrics.MissingChunks.Inc()
   501  				continue
   502  			}
   503  			chunks = append(chunks, c)
   504  		}
   505  	}
   506  	return chunks, nil
   507  }
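
// Illustrative sketch (not part of the original source): the offer/want
// handshake encodes the downstream peer's selection as a bitvector over the
// offered chunk list, so bit i of the Want refers to offer.Chunks[i]. A
// minimal round trip with the same bitvector package, assuming an offer value
// of type *pb.Offer is in scope and with error handling omitted:
//
//	bv, _ := bitvector.New(len(offer.Chunks)) // downstream side
//	bv.Set(0)                                 // request the first offered chunk
//	want := &pb.Want{BitVector: bv.Bytes()}
//
//	sel, _ := bitvector.NewFromBytes(want.BitVector, len(offer.Chunks)) // upstream side
//	requested := sel.Get(0)                                             // true: chunk 0 was requested
//	_ = requested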
   508  
   509  func (s *Syncer) GetCursors(ctx context.Context, peer swarm.Address) (retr []uint64, epoch uint64, err error) {
   510  	loggerV2 := s.logger.V(2).Register()
   511  
   512  	stream, err := s.streamer.NewStream(ctx, peer, nil, protocolName, protocolVersion, cursorStreamName)
   513  	if err != nil {
   514  		return nil, 0, fmt.Errorf("new stream: %w", err)
   515  	}
   516  	loggerV2.Debug("getting cursors from peer", "peer_address", peer)
   517  	defer func() {
   518  		if err != nil {
   519  			_ = stream.Reset()
   520  			loggerV2.Debug("error getting cursors from peer", "peer_address", peer, "error", err)
   521  		} else {
   522  			stream.FullClose()
   523  		}
   524  	}()
   525  
   526  	w, r := protobuf.NewWriterAndReader(stream)
   527  	syn := &pb.Syn{}
   528  	if err = w.WriteMsgWithContext(ctx, syn); err != nil {
   529  		return nil, 0, fmt.Errorf("write syn: %w", err)
   530  	}
   531  
   532  	var ack pb.Ack
   533  	if err = r.ReadMsgWithContext(ctx, &ack); err != nil {
   534  		return nil, 0, fmt.Errorf("read ack: %w", err)
   535  	}
   536  
   537  	return ack.Cursors, ack.Epoch, nil
   538  }
   539  
   540  func (s *Syncer) cursorHandler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (err error) {
   541  	loggerV2 := s.logger.V(2).Register()
   542  
   543  	w, r := protobuf.NewWriterAndReader(stream)
   544  	loggerV2.Debug("peer wants cursors", "peer_address", p.Address)
   545  	defer func() {
   546  		if err != nil {
   547  			_ = stream.Reset()
   548  			loggerV2.Debug("error getting cursors for peer", "peer_address", p.Address, "error", err)
   549  		} else {
   550  			_ = stream.FullClose()
   551  		}
   552  	}()
   553  
   554  	var syn pb.Syn
   555  	if err := r.ReadMsgWithContext(ctx, &syn); err != nil {
   556  		return fmt.Errorf("read syn: %w", err)
   557  	}
   558  
   559  	var ack pb.Ack
   560  	ints, epoch, err := s.store.ReserveLastBinIDs()
   561  	if err != nil {
   562  		return err
   563  	}
   564  	ack.Cursors = ints
   565  	ack.Epoch = epoch
   566  	if err = w.WriteMsgWithContext(ctx, &ack); err != nil {
   567  		return fmt.Errorf("write ack: %w", err)
   568  	}
   569  
   570  	return nil
   571  }
   572  
   573  func (s *Syncer) disconnect(peer p2p.Peer) error {
   574  	s.limiter.Clear(peer.Address.ByteString())
   575  	return nil
   576  }
   577  
   578  func (s *Syncer) Close() error {
   579  	s.logger.Info("pull syncer shutting down")
   580  	close(s.quit)
   581  	cc := make(chan struct{})
   582  	go func() {
   583  		defer close(cc)
   584  		for {
   585  			if s.syncInProgress.Load() > 0 {
   586  				time.Sleep(100 * time.Millisecond)
   587  				continue
   588  			}
   589  			break
   590  		}
   591  	}()
   592  
   593  	select {
   594  	case <-cc:
   595  	case <-time.After(5 * time.Second):
   596  		s.logger.Warning("pull syncer shutting down with running goroutines")
   597  	}
   598  	return nil
   599  }
   600  
   601  // singleflight key for intervals
   602  func sfKey(bin uint8, start uint64) string {
   603  	return fmt.Sprintf("%d-%d", bin, start)
   604  }
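
// Illustrative note (not part of the original source): the key is just the
// bin and start concatenated, so concurrent handler calls requesting the same
// (bin, start) interval are collapsed by the singleflight group into a single
// SubscribeBin subscription in collectAddrs.
//
//	key := sfKey(3, 1024) // "3-1024"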