github.com/celestiaorg/celestia-node@v0.15.0-beta.1/share/p2p/shrexeds/server.go (about)

     1  package shrexeds
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"time"
     9  
    10  	"github.com/libp2p/go-libp2p/core/host"
    11  	"github.com/libp2p/go-libp2p/core/network"
    12  	"github.com/libp2p/go-libp2p/core/protocol"
    13  	"go.uber.org/zap"
    14  
    15  	"github.com/celestiaorg/go-libp2p-messenger/serde"
    16  
    17  	"github.com/celestiaorg/celestia-node/share"
    18  	"github.com/celestiaorg/celestia-node/share/eds"
    19  	"github.com/celestiaorg/celestia-node/share/p2p"
    20  	p2p_pb "github.com/celestiaorg/celestia-node/share/p2p/shrexeds/pb"
    21  )
    22  
    23  // Server is responsible for serving ODSs for blocksync over the ShrEx/EDS protocol.
    24  type Server struct {
    25  	ctx    context.Context
    26  	cancel context.CancelFunc
    27  
    28  	host       host.Host
    29  	protocolID protocol.ID
    30  
    31  	store *eds.Store
    32  
    33  	params     *Parameters
    34  	middleware *p2p.Middleware
    35  	metrics    *p2p.Metrics
    36  }
    37  
    38  // NewServer creates a new ShrEx/EDS server.
    39  func NewServer(params *Parameters, host host.Host, store *eds.Store) (*Server, error) {
    40  	if err := params.Validate(); err != nil {
    41  		return nil, fmt.Errorf("shrex-eds: server creation failed: %w", err)
    42  	}
    43  
    44  	return &Server{
    45  		host:       host,
    46  		store:      store,
    47  		protocolID: p2p.ProtocolID(params.NetworkID(), protocolString),
    48  		params:     params,
    49  		middleware: p2p.NewMiddleware(params.ConcurrencyLimit),
    50  	}, nil
    51  }
    52  
    53  func (s *Server) Start(context.Context) error {
    54  	s.ctx, s.cancel = context.WithCancel(context.Background())
    55  	s.host.SetStreamHandler(s.protocolID, s.middleware.RateLimitHandler(s.handleStream))
    56  	return nil
    57  }
    58  
    59  func (s *Server) Stop(context.Context) error {
    60  	defer s.cancel()
    61  	s.host.RemoveStreamHandler(s.protocolID)
    62  	return nil
    63  }
    64  
    65  func (s *Server) observeRateLimitedRequests() {
    66  	numRateLimited := s.middleware.DrainCounter()
    67  	if numRateLimited > 0 {
    68  		s.metrics.ObserveRequests(context.Background(), numRateLimited, p2p.StatusRateLimited)
    69  	}
    70  }
    71  
    72  func (s *Server) handleStream(stream network.Stream) {
    73  	logger := log.With("peer", stream.Conn().RemotePeer().String())
    74  	logger.Debug("server: handling eds request")
    75  
    76  	s.observeRateLimitedRequests()
    77  
    78  	// read request from stream to get the dataHash for store lookup
    79  	req, err := s.readRequest(logger, stream)
    80  	if err != nil {
    81  		logger.Warnw("server: reading request from stream", "err", err)
    82  		stream.Reset() //nolint:errcheck
    83  		return
    84  	}
    85  
    86  	// ensure the requested dataHash is a valid root
    87  	hash := share.DataHash(req.Hash)
    88  	err = hash.Validate()
    89  	if err != nil {
    90  		logger.Warnw("server: invalid request", "err", err)
    91  		stream.Reset() //nolint:errcheck
    92  		return
    93  	}
    94  	logger = logger.With("hash", hash.String())
    95  
    96  	ctx, cancel := context.WithTimeout(s.ctx, s.params.HandleRequestTimeout)
    97  	defer cancel()
    98  
    99  	// determine whether the EDS is available in our store
   100  	// we do not close the reader, so that other requests will not need to re-open the file.
   101  	// closing is handled by the LRU cache.
   102  	edsReader, err := s.store.GetCAR(ctx, hash)
   103  	var status p2p_pb.Status
   104  	switch {
   105  	case err == nil:
   106  		defer func() {
   107  			if err := edsReader.Close(); err != nil {
   108  				log.Warnw("closing car reader", "err", err)
   109  			}
   110  		}()
   111  		status = p2p_pb.Status_OK
   112  	case errors.Is(err, eds.ErrNotFound):
   113  		logger.Warnw("server: request hash not found")
   114  		s.metrics.ObserveRequests(ctx, 1, p2p.StatusNotFound)
   115  		status = p2p_pb.Status_NOT_FOUND
   116  	case err != nil:
   117  		logger.Errorw("server: get CAR", "err", err)
   118  		status = p2p_pb.Status_INTERNAL
   119  	}
   120  
   121  	// inform the client of our status
   122  	err = s.writeStatus(logger, status, stream)
   123  	if err != nil {
   124  		logger.Warnw("server: writing status to stream", "err", err)
   125  		stream.Reset() //nolint:errcheck
   126  		return
   127  	}
   128  	// if we cannot serve the EDS, we are already done
   129  	if status != p2p_pb.Status_OK {
   130  		err = stream.Close()
   131  		if err != nil {
   132  			logger.Debugw("server: closing stream", "err", err)
   133  		}
   134  		return
   135  	}
   136  
   137  	// start streaming the ODS to the client
   138  	err = s.writeODS(logger, edsReader, stream)
   139  	if err != nil {
   140  		logger.Warnw("server: writing ods to stream", "err", err)
   141  		stream.Reset() //nolint:errcheck
   142  		return
   143  	}
   144  
   145  	s.metrics.ObserveRequests(ctx, 1, p2p.StatusSuccess)
   146  	err = stream.Close()
   147  	if err != nil {
   148  		logger.Debugw("server: closing stream", "err", err)
   149  	}
   150  }
   151  
   152  func (s *Server) readRequest(logger *zap.SugaredLogger, stream network.Stream) (*p2p_pb.EDSRequest, error) {
   153  	err := stream.SetReadDeadline(time.Now().Add(s.params.ServerReadTimeout))
   154  	if err != nil {
   155  		logger.Debugw("server: set read deadline", "err", err)
   156  	}
   157  
   158  	req := new(p2p_pb.EDSRequest)
   159  	_, err = serde.Read(stream, req)
   160  	if err != nil {
   161  		return nil, err
   162  	}
   163  	err = stream.CloseRead()
   164  	if err != nil {
   165  		logger.Debugw("server: closing read", "err", err)
   166  	}
   167  
   168  	return req, nil
   169  }
   170  
   171  func (s *Server) writeStatus(logger *zap.SugaredLogger, status p2p_pb.Status, stream network.Stream) error {
   172  	err := stream.SetWriteDeadline(time.Now().Add(s.params.ServerWriteTimeout))
   173  	if err != nil {
   174  		logger.Debugw("server: set write deadline", "err", err)
   175  	}
   176  
   177  	resp := &p2p_pb.EDSResponse{Status: status}
   178  	_, err = serde.Write(stream, resp)
   179  	return err
   180  }
   181  
   182  func (s *Server) writeODS(logger *zap.SugaredLogger, edsReader io.Reader, stream network.Stream) error {
   183  	err := stream.SetWriteDeadline(time.Now().Add(s.params.ServerWriteTimeout))
   184  	if err != nil {
   185  		logger.Debugw("server: set read deadline", "err", err)
   186  	}
   187  
   188  	odsReader, err := eds.ODSReader(edsReader)
   189  	if err != nil {
   190  		return fmt.Errorf("creating ODS reader: %w", err)
   191  	}
   192  	buf := make([]byte, s.params.BufferSize)
   193  	_, err = io.CopyBuffer(stream, odsReader, buf)
   194  	if err != nil {
   195  		return fmt.Errorf("writing ODS bytes: %w", err)
   196  	}
   197  
   198  	return nil
   199  }