github.com/ethersphere/bee/v2@v2.2.0/pkg/p2p/streamtest/streamtest.go (about)

     1  // Copyright 2020 The Swarm Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package streamtest
     6  
     7  import (
     8  	"context"
     9  	"errors"
    10  	"io"
    11  	"sync"
    12  	"testing"
    13  	"time"
    14  
    15  	"github.com/ethersphere/bee/v2/pkg/p2p"
    16  	"github.com/ethersphere/bee/v2/pkg/spinlock"
    17  	"github.com/ethersphere/bee/v2/pkg/swarm"
    18  	ma "github.com/multiformats/go-multiaddr"
    19  )
    20  
    21  var (
    22  	ErrRecordsNotFound    = errors.New("records not found")
    23  	ErrStreamNotSupported = errors.New("stream not supported")
    24  	ErrStreamClosed       = errors.New("stream closed")
    25  
    26  	noopMiddleware = func(f p2p.HandlerFunc) p2p.HandlerFunc {
    27  		return f
    28  	}
    29  )
    30  
    31  type Recorder struct {
    32  	base               swarm.Address
    33  	fullNode           bool
    34  	records            map[string][]*Record
    35  	recordsMu          sync.Mutex
    36  	protocols          []p2p.ProtocolSpec
    37  	middlewares        []p2p.HandlerMiddleware
    38  	streamErr          func(swarm.Address, string, string, string) error
    39  	pingErr            func(ma.Multiaddr) (time.Duration, error)
    40  	protocolsWithPeers map[string]p2p.ProtocolSpec
    41  }
    42  
    43  func WithProtocols(protocols ...p2p.ProtocolSpec) Option {
    44  	return optionFunc(func(r *Recorder) {
    45  		r.protocols = append(r.protocols, protocols...)
    46  	})
    47  }
    48  
    49  func WithPeerProtocols(protocolsWithPeers map[string]p2p.ProtocolSpec) Option {
    50  	return optionFunc(func(r *Recorder) {
    51  		r.protocolsWithPeers = protocolsWithPeers
    52  	})
    53  }
    54  
    55  func WithMiddlewares(middlewares ...p2p.HandlerMiddleware) Option {
    56  	return optionFunc(func(r *Recorder) {
    57  		r.middlewares = append(r.middlewares, middlewares...)
    58  	})
    59  }
    60  
    61  func WithBaseAddr(a swarm.Address) Option {
    62  	return optionFunc(func(r *Recorder) {
    63  		r.base = a
    64  	})
    65  }
    66  
    67  func WithLightNode() Option {
    68  	return optionFunc(func(r *Recorder) {
    69  		r.fullNode = false
    70  	})
    71  }
    72  
    73  func WithStreamError(streamErr func(swarm.Address, string, string, string) error) Option {
    74  	return optionFunc(func(r *Recorder) {
    75  		r.streamErr = streamErr
    76  	})
    77  }
    78  
    79  func WithPingErr(pingErr func(ma.Multiaddr) (time.Duration, error)) Option {
    80  	return optionFunc(func(r *Recorder) {
    81  		r.pingErr = pingErr
    82  	})
    83  }
    84  
    85  func New(opts ...Option) *Recorder {
    86  	r := &Recorder{
    87  		records:  make(map[string][]*Record),
    88  		fullNode: true,
    89  	}
    90  
    91  	r.middlewares = append(r.middlewares, noopMiddleware)
    92  
    93  	for _, o := range opts {
    94  		o.apply(r)
    95  	}
    96  	return r
    97  }
    98  
    99  func (r *Recorder) SetProtocols(protocols ...p2p.ProtocolSpec) {
   100  	r.protocols = append(r.protocols, protocols...)
   101  }
   102  
   103  func (r *Recorder) NewStream(ctx context.Context, addr swarm.Address, h p2p.Headers, protocolName, protocolVersion, streamName string) (p2p.Stream, error) {
   104  	if r.streamErr != nil {
   105  		err := r.streamErr(addr, protocolName, protocolVersion, streamName)
   106  		if err != nil {
   107  			return nil, err
   108  		}
   109  	}
   110  
   111  	recordIn := newRecord()
   112  	recordOut := newRecord()
   113  	streamOut := newStream(recordIn, recordOut)
   114  	streamIn := newStream(recordOut, recordIn)
   115  
   116  	var handler p2p.HandlerFunc
   117  	var headler p2p.HeadlerFunc
   118  	peerHandlers, ok := r.protocolsWithPeers[addr.String()]
   119  	if !ok {
   120  		for _, p := range r.protocols {
   121  			if p.Name == protocolName && p.Version == protocolVersion {
   122  				peerHandlers = p
   123  			}
   124  		}
   125  	}
   126  	for _, s := range peerHandlers.StreamSpecs {
   127  		if s.Name == streamName {
   128  			handler = s.Handler
   129  			headler = s.Headler
   130  		}
   131  	}
   132  	if handler == nil {
   133  		return nil, ErrStreamNotSupported
   134  	}
   135  	for i := len(r.middlewares) - 1; i >= 0; i-- {
   136  		handler = r.middlewares[i](handler)
   137  	}
   138  	if headler != nil {
   139  		streamOut.headers = headler(h, addr)
   140  	}
   141  	record := &Record{in: recordIn, out: recordOut, done: make(chan struct{})}
   142  	go func() {
   143  		defer close(record.done)
   144  
   145  		// pass a new context to handler,
   146  		streamIn.responseHeaders = streamOut.headers
   147  		// do not cancel it with the client stream context
   148  		err := handler(context.Background(), p2p.Peer{Address: r.base, FullNode: r.fullNode}, streamIn)
   149  		if err != nil && !errors.Is(err, io.EOF) {
   150  			record.setErr(err)
   151  		}
   152  	}()
   153  
   154  	id := addr.String() + p2p.NewSwarmStreamName(protocolName, protocolVersion, streamName)
   155  
   156  	r.recordsMu.Lock()
   157  	defer r.recordsMu.Unlock()
   158  
   159  	r.records[id] = append(r.records[id], record)
   160  	return streamOut, nil
   161  }
   162  
   163  func (r *Recorder) Ping(ctx context.Context, addr ma.Multiaddr) (rtt time.Duration, err error) {
   164  	if r.pingErr != nil {
   165  		return r.pingErr(addr)
   166  	}
   167  	return rtt, err
   168  }
   169  
   170  func (r *Recorder) Records(addr swarm.Address, protocolName, protocolVersio, streamName string) ([]*Record, error) {
   171  	id := addr.String() + p2p.NewSwarmStreamName(protocolName, protocolVersio, streamName)
   172  
   173  	r.recordsMu.Lock()
   174  	defer r.recordsMu.Unlock()
   175  
   176  	records, ok := r.records[id]
   177  	if !ok {
   178  		return nil, ErrRecordsNotFound
   179  	}
   180  	// wait for all records goroutines to terminate
   181  	for _, r := range records {
   182  		<-r.done
   183  	}
   184  	return records, nil
   185  }
   186  
   187  // WaitRecords waits for some time for records to come into the recorder. If msgs is 0, the timeoutSec period is waited to verify
   188  // that _no_ messages arrive during this time period.
   189  func (r *Recorder) WaitRecords(t *testing.T, addr swarm.Address, proto, version, stream string, msgs, timeoutSec int) []*Record {
   190  	t.Helper()
   191  
   192  	var recs []*Record
   193  	err := spinlock.Wait(time.Second*time.Duration(timeoutSec), func() bool {
   194  		recs, _ = r.Records(addr, proto, version, stream)
   195  		if l := len(recs); l > msgs {
   196  			t.Fatalf("too many records. want %d got %d", msgs, l)
   197  		} else if msgs > 0 && l == msgs {
   198  			return true
   199  		}
   200  		return false
   201  		// we can be here if msgs == 0 && l == 0
   202  		// or msgs = x && l < x, both cases are fine
   203  		// and we should continue waiting
   204  	})
   205  	if err != nil && msgs > 0 {
   206  		t.Fatal("timed out while waiting for records")
   207  	}
   208  
   209  	return recs
   210  }
   211  
   212  type Record struct {
   213  	in    *record
   214  	out   *record
   215  	err   error
   216  	errMu sync.Mutex
   217  	done  chan struct{}
   218  }
   219  
   220  func (r *Record) In() []byte {
   221  	return r.in.bytes()
   222  }
   223  
   224  func (r *Record) Out() []byte {
   225  	return r.out.bytes()
   226  }
   227  
   228  func (r *Record) Err() error {
   229  	r.errMu.Lock()
   230  	defer r.errMu.Unlock()
   231  
   232  	return r.err
   233  }
   234  
   235  func (r *Record) setErr(err error) {
   236  	r.errMu.Lock()
   237  	defer r.errMu.Unlock()
   238  
   239  	r.err = err
   240  }
   241  
   242  type stream struct {
   243  	in              *record
   244  	out             *record
   245  	headers         p2p.Headers
   246  	responseHeaders p2p.Headers
   247  	closed          bool
   248  	lock            sync.Mutex
   249  }
   250  
   251  func newStream(in, out *record) *stream {
   252  	return &stream{in: in, out: out}
   253  }
   254  
   255  func (s *stream) Read(p []byte) (int, error) {
   256  	if s.Closed() {
   257  		return 0, ErrStreamClosed
   258  	}
   259  
   260  	return s.out.Read(p)
   261  }
   262  
   263  func (s *stream) Write(p []byte) (int, error) {
   264  	if s.Closed() {
   265  		return 0, ErrStreamClosed
   266  	}
   267  
   268  	return s.in.Write(p)
   269  }
   270  
   271  func (s *stream) Headers() p2p.Headers {
   272  	return s.headers
   273  }
   274  
   275  func (s *stream) ResponseHeaders() p2p.Headers {
   276  	return s.responseHeaders
   277  }
   278  
   279  func (s *stream) Close() error {
   280  	s.lock.Lock()
   281  	defer s.lock.Unlock()
   282  
   283  	if s.closed {
   284  		return ErrStreamClosed
   285  	}
   286  
   287  	s.closed = true
   288  	s.in.close()
   289  
   290  	return nil
   291  }
   292  
   293  func (s *stream) Closed() bool {
   294  	s.lock.Lock()
   295  	defer s.lock.Unlock()
   296  
   297  	return s.closed
   298  }
   299  
   300  func (s *stream) FullClose() error {
   301  	s.lock.Lock()
   302  	defer s.lock.Unlock()
   303  
   304  	if s.closed {
   305  		return ErrStreamClosed
   306  	}
   307  
   308  	s.closed = true
   309  	s.in.close()
   310  	s.out.close()
   311  
   312  	return nil
   313  }
   314  
   315  func (s *stream) Reset() (err error) {
   316  	return s.FullClose()
   317  }
   318  
   319  type record struct {
   320  	b        []byte
   321  	c        int
   322  	lock     sync.Mutex
   323  	dataSigC chan struct{}
   324  	closed   bool
   325  }
   326  
   327  func newRecord() *record {
   328  	return &record{
   329  		dataSigC: make(chan struct{}, 16),
   330  	}
   331  }
   332  
   333  func (r *record) Read(p []byte) (n int, err error) {
   334  	for r.c == r.bytesSize() {
   335  		_, ok := <-r.dataSigC
   336  		if !ok {
   337  			return 0, io.EOF
   338  		}
   339  	}
   340  
   341  	r.lock.Lock()
   342  	defer r.lock.Unlock()
   343  
   344  	end := r.c + len(p)
   345  	if end > len(r.b) {
   346  		end = len(r.b)
   347  	}
   348  	n = copy(p, r.b[r.c:end])
   349  	r.c += n
   350  
   351  	return n, nil
   352  }
   353  
   354  func (r *record) Write(p []byte) (int, error) {
   355  	r.lock.Lock()
   356  	defer r.lock.Unlock()
   357  
   358  	if r.closed {
   359  		return 0, ErrStreamClosed
   360  	}
   361  
   362  	r.b = append(r.b, p...)
   363  	r.dataSigC <- struct{}{}
   364  
   365  	return len(p), nil
   366  }
   367  
   368  func (r *record) close() {
   369  	r.lock.Lock()
   370  	defer r.lock.Unlock()
   371  
   372  	if r.closed {
   373  		return
   374  	}
   375  
   376  	r.closed = true
   377  	close(r.dataSigC)
   378  }
   379  
   380  func (r *record) bytes() []byte {
   381  	return r.b
   382  }
   383  
   384  func (r *record) bytesSize() int {
   385  	r.lock.Lock()
   386  	defer r.lock.Unlock()
   387  	return len(r.b)
   388  }
   389  
   390  type Option interface {
   391  	apply(*Recorder)
   392  }
   393  type optionFunc func(*Recorder)
   394  
   395  func (f optionFunc) apply(r *Recorder) { f(r) }
   396  
   397  var _ p2p.StreamerDisconnecter = (*RecorderDisconnecter)(nil)
   398  
   399  type RecorderDisconnecter struct {
   400  	*Recorder
   401  	disconnected map[string]struct{}
   402  	blocklisted  map[string]time.Duration
   403  	mu           sync.RWMutex
   404  }
   405  
   406  func NewRecorderDisconnecter(r *Recorder) *RecorderDisconnecter {
   407  	return &RecorderDisconnecter{
   408  		Recorder:     r,
   409  		disconnected: make(map[string]struct{}),
   410  		blocklisted:  make(map[string]time.Duration),
   411  	}
   412  }
   413  
   414  func (r *RecorderDisconnecter) Disconnect(overlay swarm.Address, _ string) error {
   415  	r.mu.Lock()
   416  	defer r.mu.Unlock()
   417  
   418  	r.disconnected[overlay.String()] = struct{}{}
   419  	return nil
   420  }
   421  
   422  func (r *RecorderDisconnecter) Blocklist(overlay swarm.Address, d time.Duration, _ string) error {
   423  	r.mu.Lock()
   424  	defer r.mu.Unlock()
   425  
   426  	r.blocklisted[overlay.String()] = d
   427  	return nil
   428  }
   429  
   430  func (r *RecorderDisconnecter) IsDisconnected(overlay swarm.Address) bool {
   431  	r.mu.RLock()
   432  	defer r.mu.RUnlock()
   433  
   434  	_, yes := r.disconnected[overlay.String()]
   435  	return yes
   436  }
   437  
   438  func (r *RecorderDisconnecter) IsBlocklisted(overlay swarm.Address) (bool, time.Duration) {
   439  	r.mu.RLock()
   440  	defer r.mu.RUnlock()
   441  
   442  	d, yes := r.blocklisted[overlay.String()]
   443  	return yes, d
   444  }
   445  
   446  // NetworkStatus implements p2p.NetworkStatuser interface.
   447  // It always returns p2p.NetworkStatusAvailable.
   448  func (r *RecorderDisconnecter) NetworkStatus() p2p.NetworkStatus {
   449  	return p2p.NetworkStatusAvailable
   450  }