github.com/celestiaorg/celestia-node@v0.15.0-beta.1/share/p2p/shrexnd/server.go (about)

     1  package shrexnd
     2  
     3  import (
     4  	"context"
     5  	"crypto/sha256"
     6  	"errors"
     7  	"fmt"
     8  	"time"
     9  
    10  	"github.com/libp2p/go-libp2p/core/host"
    11  	"github.com/libp2p/go-libp2p/core/network"
    12  	"github.com/libp2p/go-libp2p/core/protocol"
    13  	"go.uber.org/zap"
    14  
    15  	"github.com/celestiaorg/go-libp2p-messenger/serde"
    16  	nmt_pb "github.com/celestiaorg/nmt/pb"
    17  
    18  	"github.com/celestiaorg/celestia-node/share"
    19  	"github.com/celestiaorg/celestia-node/share/eds"
    20  	"github.com/celestiaorg/celestia-node/share/p2p"
    21  	pb "github.com/celestiaorg/celestia-node/share/p2p/shrexnd/pb"
    22  )
    23  
    24  // Server implements server side of shrex/nd protocol to serve namespaced share to remote
    25  // peers.
    26  type Server struct {
    27  	cancel context.CancelFunc
    28  
    29  	host       host.Host
    30  	protocolID protocol.ID
    31  
    32  	handler network.StreamHandler
    33  	store   *eds.Store
    34  
    35  	params     *Parameters
    36  	middleware *p2p.Middleware
    37  	metrics    *p2p.Metrics
    38  }
    39  
    40  // NewServer creates new Server
    41  func NewServer(params *Parameters, host host.Host, store *eds.Store) (*Server, error) {
    42  	if err := params.Validate(); err != nil {
    43  		return nil, fmt.Errorf("shrex-nd: server creation failed: %w", err)
    44  	}
    45  
    46  	srv := &Server{
    47  		store:      store,
    48  		host:       host,
    49  		params:     params,
    50  		protocolID: p2p.ProtocolID(params.NetworkID(), protocolString),
    51  		middleware: p2p.NewMiddleware(params.ConcurrencyLimit),
    52  	}
    53  
    54  	ctx, cancel := context.WithCancel(context.Background())
    55  	srv.cancel = cancel
    56  
    57  	srv.handler = srv.middleware.RateLimitHandler(srv.streamHandler(ctx))
    58  	return srv, nil
    59  }
    60  
    61  // Start starts the server
    62  func (srv *Server) Start(context.Context) error {
    63  	srv.host.SetStreamHandler(srv.protocolID, srv.handler)
    64  	return nil
    65  }
    66  
    67  // Stop stops the server
    68  func (srv *Server) Stop(context.Context) error {
    69  	srv.cancel()
    70  	srv.host.RemoveStreamHandler(srv.protocolID)
    71  	return nil
    72  }
    73  
    74  func (srv *Server) streamHandler(ctx context.Context) network.StreamHandler {
    75  	return func(s network.Stream) {
    76  		err := srv.handleNamespacedData(ctx, s)
    77  		if err != nil {
    78  			s.Reset() //nolint:errcheck
    79  			return
    80  		}
    81  		if err = s.Close(); err != nil {
    82  			log.Debugw("server: closing stream", "err", err)
    83  		}
    84  	}
    85  }
    86  
    87  // SetHandler sets server handler
    88  func (srv *Server) SetHandler(handler network.StreamHandler) {
    89  	srv.handler = handler
    90  }
    91  
    92  func (srv *Server) observeRateLimitedRequests() {
    93  	numRateLimited := srv.middleware.DrainCounter()
    94  	if numRateLimited > 0 {
    95  		srv.metrics.ObserveRequests(context.Background(), numRateLimited, p2p.StatusRateLimited)
    96  	}
    97  }
    98  
    99  func (srv *Server) handleNamespacedData(ctx context.Context, stream network.Stream) error {
   100  	logger := log.With("source", "server", "peer", stream.Conn().RemotePeer().String())
   101  	logger.Debug("handling nd request")
   102  
   103  	srv.observeRateLimitedRequests()
   104  	req, err := srv.readRequest(logger, stream)
   105  	if err != nil {
   106  		logger.Warnw("read request", "err", err)
   107  		srv.metrics.ObserveRequests(ctx, 1, p2p.StatusBadRequest)
   108  		return err
   109  	}
   110  
   111  	logger = logger.With("namespace", share.Namespace(req.Namespace).String(),
   112  		"hash", share.DataHash(req.RootHash).String())
   113  
   114  	ctx, cancel := context.WithTimeout(ctx, srv.params.HandleRequestTimeout)
   115  	defer cancel()
   116  
   117  	shares, status, err := srv.getNamespaceData(ctx, req.RootHash, req.Namespace)
   118  	if err != nil {
   119  		// server should respond with status regardless if there was an error getting data
   120  		sendErr := srv.respondStatus(ctx, logger, stream, status)
   121  		if sendErr != nil {
   122  			logger.Errorw("sending response", "err", sendErr)
   123  			srv.metrics.ObserveRequests(ctx, 1, p2p.StatusSendRespErr)
   124  		}
   125  		logger.Errorw("handling request", "err", err)
   126  		return errors.Join(err, sendErr)
   127  	}
   128  
   129  	err = srv.respondStatus(ctx, logger, stream, status)
   130  	if err != nil {
   131  		logger.Errorw("sending response", "err", err)
   132  		srv.metrics.ObserveRequests(ctx, 1, p2p.StatusSendRespErr)
   133  		return err
   134  	}
   135  
   136  	err = srv.sendNamespacedShares(shares, stream)
   137  	if err != nil {
   138  		logger.Errorw("send nd data", "err", err)
   139  		srv.metrics.ObserveRequests(ctx, 1, p2p.StatusSendRespErr)
   140  		return err
   141  	}
   142  	return nil
   143  }
   144  
   145  func (srv *Server) readRequest(
   146  	logger *zap.SugaredLogger,
   147  	stream network.Stream,
   148  ) (*pb.GetSharesByNamespaceRequest, error) {
   149  	err := stream.SetReadDeadline(time.Now().Add(srv.params.ServerReadTimeout))
   150  	if err != nil {
   151  		logger.Debugw("setting read deadline", "err", err)
   152  	}
   153  
   154  	var req pb.GetSharesByNamespaceRequest
   155  	_, err = serde.Read(stream, &req)
   156  	if err != nil {
   157  		return nil, fmt.Errorf("reading request: %w", err)
   158  
   159  	}
   160  
   161  	logger.Debugw("new request")
   162  	err = stream.CloseRead()
   163  	if err != nil {
   164  		logger.Debugw("closing read side of the stream", "err", err)
   165  	}
   166  
   167  	err = validateRequest(req)
   168  	if err != nil {
   169  		return nil, fmt.Errorf("invalid request: %w", err)
   170  	}
   171  	return &req, nil
   172  }
   173  
   174  func (srv *Server) getNamespaceData(ctx context.Context,
   175  	hash share.DataHash, namespace share.Namespace) (share.NamespacedShares, pb.StatusCode, error) {
   176  	dah, err := srv.store.GetDAH(ctx, hash)
   177  	if err != nil {
   178  		if errors.Is(err, eds.ErrNotFound) {
   179  			return nil, pb.StatusCode_NOT_FOUND, nil
   180  		}
   181  		return nil, pb.StatusCode_INTERNAL, fmt.Errorf("retrieving DAH: %w", err)
   182  	}
   183  
   184  	shares, err := eds.RetrieveNamespaceFromStore(ctx, srv.store, dah, namespace)
   185  	if err != nil {
   186  		return nil, pb.StatusCode_INTERNAL, fmt.Errorf("retrieving shares: %w", err)
   187  	}
   188  
   189  	return shares, pb.StatusCode_OK, nil
   190  }
   191  
   192  func (srv *Server) respondStatus(
   193  	ctx context.Context,
   194  	logger *zap.SugaredLogger,
   195  	stream network.Stream,
   196  	status pb.StatusCode,
   197  ) error {
   198  	srv.observeStatus(ctx, status)
   199  
   200  	err := stream.SetWriteDeadline(time.Now().Add(srv.params.ServerWriteTimeout))
   201  	if err != nil {
   202  		logger.Debugw("setting write deadline", "err", err)
   203  	}
   204  
   205  	_, err = serde.Write(stream, &pb.GetSharesByNamespaceStatusResponse{Status: status})
   206  	if err != nil {
   207  		return fmt.Errorf("writing response: %w", err)
   208  	}
   209  
   210  	return nil
   211  }
   212  
   213  // sendNamespacedShares encodes shares into proto messages and sends it to client
   214  func (srv *Server) sendNamespacedShares(shares share.NamespacedShares, stream network.Stream) error {
   215  	for _, row := range shares {
   216  		row := &pb.NamespaceRowResponse{
   217  			Shares: row.Shares,
   218  			Proof: &nmt_pb.Proof{
   219  				Start:                 int64(row.Proof.Start()),
   220  				End:                   int64(row.Proof.End()),
   221  				Nodes:                 row.Proof.Nodes(),
   222  				LeafHash:              row.Proof.LeafHash(),
   223  				IsMaxNamespaceIgnored: row.Proof.IsMaxNamespaceIDIgnored(),
   224  			},
   225  		}
   226  		_, err := serde.Write(stream, row)
   227  		if err != nil {
   228  			return fmt.Errorf("writing nd data to stream: %w", err)
   229  		}
   230  	}
   231  	return nil
   232  }
   233  
   234  func (srv *Server) observeStatus(ctx context.Context, status pb.StatusCode) {
   235  	switch {
   236  	case status == pb.StatusCode_OK:
   237  		srv.metrics.ObserveRequests(ctx, 1, p2p.StatusSuccess)
   238  	case status == pb.StatusCode_NOT_FOUND:
   239  		srv.metrics.ObserveRequests(ctx, 1, p2p.StatusNotFound)
   240  	case status == pb.StatusCode_INTERNAL:
   241  		srv.metrics.ObserveRequests(ctx, 1, p2p.StatusInternalErr)
   242  	}
   243  }
   244  
   245  // validateRequest checks correctness of the request
   246  func validateRequest(req pb.GetSharesByNamespaceRequest) error {
   247  	if err := share.Namespace(req.Namespace).ValidateForData(); err != nil {
   248  		return err
   249  	}
   250  	if len(req.RootHash) != sha256.Size {
   251  		return fmt.Errorf("incorrect root hash length: %v", len(req.RootHash))
   252  	}
   253  	return nil
   254  }