github.com/celestiaorg/celestia-node@v0.15.0-beta.1/share/p2p/shrexnd/server.go (about) 1 package shrexnd 2 3 import ( 4 "context" 5 "crypto/sha256" 6 "errors" 7 "fmt" 8 "time" 9 10 "github.com/libp2p/go-libp2p/core/host" 11 "github.com/libp2p/go-libp2p/core/network" 12 "github.com/libp2p/go-libp2p/core/protocol" 13 "go.uber.org/zap" 14 15 "github.com/celestiaorg/go-libp2p-messenger/serde" 16 nmt_pb "github.com/celestiaorg/nmt/pb" 17 18 "github.com/celestiaorg/celestia-node/share" 19 "github.com/celestiaorg/celestia-node/share/eds" 20 "github.com/celestiaorg/celestia-node/share/p2p" 21 pb "github.com/celestiaorg/celestia-node/share/p2p/shrexnd/pb" 22 ) 23 24 // Server implements server side of shrex/nd protocol to serve namespaced share to remote 25 // peers. 26 type Server struct { 27 cancel context.CancelFunc 28 29 host host.Host 30 protocolID protocol.ID 31 32 handler network.StreamHandler 33 store *eds.Store 34 35 params *Parameters 36 middleware *p2p.Middleware 37 metrics *p2p.Metrics 38 } 39 40 // NewServer creates new Server 41 func NewServer(params *Parameters, host host.Host, store *eds.Store) (*Server, error) { 42 if err := params.Validate(); err != nil { 43 return nil, fmt.Errorf("shrex-nd: server creation failed: %w", err) 44 } 45 46 srv := &Server{ 47 store: store, 48 host: host, 49 params: params, 50 protocolID: p2p.ProtocolID(params.NetworkID(), protocolString), 51 middleware: p2p.NewMiddleware(params.ConcurrencyLimit), 52 } 53 54 ctx, cancel := context.WithCancel(context.Background()) 55 srv.cancel = cancel 56 57 srv.handler = srv.middleware.RateLimitHandler(srv.streamHandler(ctx)) 58 return srv, nil 59 } 60 61 // Start starts the server 62 func (srv *Server) Start(context.Context) error { 63 srv.host.SetStreamHandler(srv.protocolID, srv.handler) 64 return nil 65 } 66 67 // Stop stops the server 68 func (srv *Server) Stop(context.Context) error { 69 srv.cancel() 70 srv.host.RemoveStreamHandler(srv.protocolID) 71 return nil 72 } 73 74 func (srv *Server) streamHandler(ctx context.Context) network.StreamHandler { 75 return func(s network.Stream) { 76 err := srv.handleNamespacedData(ctx, s) 77 if err != nil { 78 s.Reset() //nolint:errcheck 79 return 80 } 81 if err = s.Close(); err != nil { 82 log.Debugw("server: closing stream", "err", err) 83 } 84 } 85 } 86 87 // SetHandler sets server handler 88 func (srv *Server) SetHandler(handler network.StreamHandler) { 89 srv.handler = handler 90 } 91 92 func (srv *Server) observeRateLimitedRequests() { 93 numRateLimited := srv.middleware.DrainCounter() 94 if numRateLimited > 0 { 95 srv.metrics.ObserveRequests(context.Background(), numRateLimited, p2p.StatusRateLimited) 96 } 97 } 98 99 func (srv *Server) handleNamespacedData(ctx context.Context, stream network.Stream) error { 100 logger := log.With("source", "server", "peer", stream.Conn().RemotePeer().String()) 101 logger.Debug("handling nd request") 102 103 srv.observeRateLimitedRequests() 104 req, err := srv.readRequest(logger, stream) 105 if err != nil { 106 logger.Warnw("read request", "err", err) 107 srv.metrics.ObserveRequests(ctx, 1, p2p.StatusBadRequest) 108 return err 109 } 110 111 logger = logger.With("namespace", share.Namespace(req.Namespace).String(), 112 "hash", share.DataHash(req.RootHash).String()) 113 114 ctx, cancel := context.WithTimeout(ctx, srv.params.HandleRequestTimeout) 115 defer cancel() 116 117 shares, status, err := srv.getNamespaceData(ctx, req.RootHash, req.Namespace) 118 if err != nil { 119 // server should respond with status regardless if there was an error getting data 120 sendErr := srv.respondStatus(ctx, logger, stream, status) 121 if sendErr != nil { 122 logger.Errorw("sending response", "err", sendErr) 123 srv.metrics.ObserveRequests(ctx, 1, p2p.StatusSendRespErr) 124 } 125 logger.Errorw("handling request", "err", err) 126 return errors.Join(err, sendErr) 127 } 128 129 err = srv.respondStatus(ctx, logger, stream, status) 130 if err != nil { 131 logger.Errorw("sending response", "err", err) 132 srv.metrics.ObserveRequests(ctx, 1, p2p.StatusSendRespErr) 133 return err 134 } 135 136 err = srv.sendNamespacedShares(shares, stream) 137 if err != nil { 138 logger.Errorw("send nd data", "err", err) 139 srv.metrics.ObserveRequests(ctx, 1, p2p.StatusSendRespErr) 140 return err 141 } 142 return nil 143 } 144 145 func (srv *Server) readRequest( 146 logger *zap.SugaredLogger, 147 stream network.Stream, 148 ) (*pb.GetSharesByNamespaceRequest, error) { 149 err := stream.SetReadDeadline(time.Now().Add(srv.params.ServerReadTimeout)) 150 if err != nil { 151 logger.Debugw("setting read deadline", "err", err) 152 } 153 154 var req pb.GetSharesByNamespaceRequest 155 _, err = serde.Read(stream, &req) 156 if err != nil { 157 return nil, fmt.Errorf("reading request: %w", err) 158 159 } 160 161 logger.Debugw("new request") 162 err = stream.CloseRead() 163 if err != nil { 164 logger.Debugw("closing read side of the stream", "err", err) 165 } 166 167 err = validateRequest(req) 168 if err != nil { 169 return nil, fmt.Errorf("invalid request: %w", err) 170 } 171 return &req, nil 172 } 173 174 func (srv *Server) getNamespaceData(ctx context.Context, 175 hash share.DataHash, namespace share.Namespace) (share.NamespacedShares, pb.StatusCode, error) { 176 dah, err := srv.store.GetDAH(ctx, hash) 177 if err != nil { 178 if errors.Is(err, eds.ErrNotFound) { 179 return nil, pb.StatusCode_NOT_FOUND, nil 180 } 181 return nil, pb.StatusCode_INTERNAL, fmt.Errorf("retrieving DAH: %w", err) 182 } 183 184 shares, err := eds.RetrieveNamespaceFromStore(ctx, srv.store, dah, namespace) 185 if err != nil { 186 return nil, pb.StatusCode_INTERNAL, fmt.Errorf("retrieving shares: %w", err) 187 } 188 189 return shares, pb.StatusCode_OK, nil 190 } 191 192 func (srv *Server) respondStatus( 193 ctx context.Context, 194 logger *zap.SugaredLogger, 195 stream network.Stream, 196 status pb.StatusCode, 197 ) error { 198 srv.observeStatus(ctx, status) 199 200 err := stream.SetWriteDeadline(time.Now().Add(srv.params.ServerWriteTimeout)) 201 if err != nil { 202 logger.Debugw("setting write deadline", "err", err) 203 } 204 205 _, err = serde.Write(stream, &pb.GetSharesByNamespaceStatusResponse{Status: status}) 206 if err != nil { 207 return fmt.Errorf("writing response: %w", err) 208 } 209 210 return nil 211 } 212 213 // sendNamespacedShares encodes shares into proto messages and sends it to client 214 func (srv *Server) sendNamespacedShares(shares share.NamespacedShares, stream network.Stream) error { 215 for _, row := range shares { 216 row := &pb.NamespaceRowResponse{ 217 Shares: row.Shares, 218 Proof: &nmt_pb.Proof{ 219 Start: int64(row.Proof.Start()), 220 End: int64(row.Proof.End()), 221 Nodes: row.Proof.Nodes(), 222 LeafHash: row.Proof.LeafHash(), 223 IsMaxNamespaceIgnored: row.Proof.IsMaxNamespaceIDIgnored(), 224 }, 225 } 226 _, err := serde.Write(stream, row) 227 if err != nil { 228 return fmt.Errorf("writing nd data to stream: %w", err) 229 } 230 } 231 return nil 232 } 233 234 func (srv *Server) observeStatus(ctx context.Context, status pb.StatusCode) { 235 switch { 236 case status == pb.StatusCode_OK: 237 srv.metrics.ObserveRequests(ctx, 1, p2p.StatusSuccess) 238 case status == pb.StatusCode_NOT_FOUND: 239 srv.metrics.ObserveRequests(ctx, 1, p2p.StatusNotFound) 240 case status == pb.StatusCode_INTERNAL: 241 srv.metrics.ObserveRequests(ctx, 1, p2p.StatusInternalErr) 242 } 243 } 244 245 // validateRequest checks correctness of the request 246 func validateRequest(req pb.GetSharesByNamespaceRequest) error { 247 if err := share.Namespace(req.Namespace).ValidateForData(); err != nil { 248 return err 249 } 250 if len(req.RootHash) != sha256.Size { 251 return fmt.Errorf("incorrect root hash length: %v", len(req.RootHash)) 252 } 253 return nil 254 }