github.com/celestiaorg/celestia-node@v0.15.0-beta.1/share/p2p/shrexeds/client.go (about)

     1  package shrexeds
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"net"
     9  	"time"
    10  
    11  	"github.com/libp2p/go-libp2p/core/host"
    12  	"github.com/libp2p/go-libp2p/core/network"
    13  	"github.com/libp2p/go-libp2p/core/peer"
    14  	"github.com/libp2p/go-libp2p/core/protocol"
    15  
    16  	"github.com/celestiaorg/go-libp2p-messenger/serde"
    17  	"github.com/celestiaorg/rsmt2d"
    18  
    19  	"github.com/celestiaorg/celestia-node/share"
    20  	"github.com/celestiaorg/celestia-node/share/eds"
    21  	"github.com/celestiaorg/celestia-node/share/p2p"
    22  	pb "github.com/celestiaorg/celestia-node/share/p2p/shrexeds/pb"
    23  )
    24  
// Client is responsible for requesting EDSs for blocksync over the ShrEx/EDS protocol.
// Construct it with NewClient, which validates the parameters and derives the protocol ID.
type Client struct {
	// params holds the protocol timeouts and network ID; validated in NewClient.
	params     *Parameters
	// protocolID is the libp2p protocol used when opening request streams; derived in
	// NewClient from params.NetworkID().
	protocolID protocol.ID
	// host is the libp2p host used to open outbound streams to peers.
	host       host.Host

	// metrics records request outcomes. NewClient does not set this field, so it may be nil.
	// NOTE(review): ObserveRequests is called unconditionally, so it is presumably
	// nil-receiver safe — confirm in package p2p.
	metrics *p2p.Metrics
}
    33  
    34  // NewClient creates a new ShrEx/EDS client.
    35  func NewClient(params *Parameters, host host.Host) (*Client, error) {
    36  	if err := params.Validate(); err != nil {
    37  		return nil, fmt.Errorf("shrex-eds: client creation failed: %w", err)
    38  	}
    39  
    40  	return &Client{
    41  		params:     params,
    42  		host:       host,
    43  		protocolID: p2p.ProtocolID(params.NetworkID(), protocolString),
    44  	}, nil
    45  }
    46  
    47  // RequestEDS requests the ODS from the given peers and returns the EDS upon success.
    48  func (c *Client) RequestEDS(
    49  	ctx context.Context,
    50  	dataHash share.DataHash,
    51  	peer peer.ID,
    52  ) (*rsmt2d.ExtendedDataSquare, error) {
    53  	eds, err := c.doRequest(ctx, dataHash, peer)
    54  	if err == nil {
    55  		return eds, nil
    56  	}
    57  	log.Debugw("client: eds request to peer failed", "peer", peer.String(), "hash", dataHash.String(), "error", err)
    58  	if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
    59  		c.metrics.ObserveRequests(ctx, 1, p2p.StatusTimeout)
    60  		return nil, err
    61  	}
    62  	// some net.Errors also mean the context deadline was exceeded, but yamux/mocknet do not
    63  	// unwrap to a ctx err
    64  	var ne net.Error
    65  	if errors.As(err, &ne) && ne.Timeout() {
    66  		if deadline, _ := ctx.Deadline(); deadline.Before(time.Now()) {
    67  			c.metrics.ObserveRequests(ctx, 1, p2p.StatusTimeout)
    68  			return nil, context.DeadlineExceeded
    69  		}
    70  	}
    71  	if err != p2p.ErrNotFound {
    72  		log.Warnw("client: eds request to peer failed",
    73  			"peer", peer.String(),
    74  			"hash", dataHash.String(),
    75  			"err", err)
    76  	}
    77  
    78  	return nil, err
    79  }
    80  
    81  func (c *Client) doRequest(
    82  	ctx context.Context,
    83  	dataHash share.DataHash,
    84  	to peer.ID,
    85  ) (*rsmt2d.ExtendedDataSquare, error) {
    86  	streamOpenCtx, cancel := context.WithTimeout(ctx, c.params.ServerReadTimeout)
    87  	defer cancel()
    88  	stream, err := c.host.NewStream(streamOpenCtx, to, c.protocolID)
    89  	if err != nil {
    90  		return nil, fmt.Errorf("failed to open stream: %w", err)
    91  	}
    92  	defer stream.Close()
    93  
    94  	c.setStreamDeadlines(ctx, stream)
    95  
    96  	req := &pb.EDSRequest{Hash: dataHash}
    97  
    98  	// request ODS
    99  	log.Debugw("client: requesting ods", "hash", dataHash.String(), "peer", to.String())
   100  	_, err = serde.Write(stream, req)
   101  	if err != nil {
   102  		stream.Reset() //nolint:errcheck
   103  		return nil, fmt.Errorf("failed to write request to stream: %w", err)
   104  	}
   105  	err = stream.CloseWrite()
   106  	if err != nil {
   107  		log.Debugw("client: error closing write", "err", err)
   108  	}
   109  
   110  	// read and parse status from peer
   111  	resp := new(pb.EDSResponse)
   112  	err = stream.SetReadDeadline(time.Now().Add(c.params.ServerReadTimeout))
   113  	if err != nil {
   114  		log.Debugw("client: failed to set read deadline for reading status", "err", err)
   115  	}
   116  	_, err = serde.Read(stream, resp)
   117  	if err != nil {
   118  		// server closes the stream here if we are rate limited
   119  		if errors.Is(err, io.EOF) {
   120  			c.metrics.ObserveRequests(ctx, 1, p2p.StatusRateLimited)
   121  			return nil, p2p.ErrNotFound
   122  		}
   123  		stream.Reset() //nolint:errcheck
   124  		return nil, fmt.Errorf("failed to read status from stream: %w", err)
   125  	}
   126  
   127  	switch resp.Status {
   128  	case pb.Status_OK:
   129  		// reset stream deadlines to original values, since read deadline was changed during status read
   130  		c.setStreamDeadlines(ctx, stream)
   131  		// use header and ODS bytes to construct EDS and verify it against dataHash
   132  		eds, err := eds.ReadEDS(ctx, stream, dataHash)
   133  		if err != nil {
   134  			return nil, fmt.Errorf("failed to read eds from ods bytes: %w", err)
   135  		}
   136  		c.metrics.ObserveRequests(ctx, 1, p2p.StatusSuccess)
   137  		return eds, nil
   138  	case pb.Status_NOT_FOUND:
   139  		c.metrics.ObserveRequests(ctx, 1, p2p.StatusNotFound)
   140  		return nil, p2p.ErrNotFound
   141  	case pb.Status_INVALID:
   142  		log.Debug("client: invalid request")
   143  		fallthrough
   144  	case pb.Status_INTERNAL:
   145  		fallthrough
   146  	default:
   147  		c.metrics.ObserveRequests(ctx, 1, p2p.StatusInternalErr)
   148  		return nil, p2p.ErrInvalidResponse
   149  	}
   150  }
   151  
   152  func (c *Client) setStreamDeadlines(ctx context.Context, stream network.Stream) {
   153  	// set read/write deadline to use context deadline if it exists
   154  	if dl, ok := ctx.Deadline(); ok {
   155  		err := stream.SetDeadline(dl)
   156  		if err == nil {
   157  			return
   158  		}
   159  		log.Debugw("client: setting deadline: %s", "err", err)
   160  	}
   161  
   162  	// if deadline not set, client read deadline defaults to server write deadline
   163  	if c.params.ServerWriteTimeout != 0 {
   164  		err := stream.SetReadDeadline(time.Now().Add(c.params.ServerWriteTimeout))
   165  		if err != nil {
   166  			log.Debugw("client: setting read deadline", "err", err)
   167  		}
   168  	}
   169  
   170  	// if deadline not set, client write deadline defaults to server read deadline
   171  	if c.params.ServerReadTimeout != 0 {
   172  		err := stream.SetWriteDeadline(time.Now().Add(c.params.ServerReadTimeout))
   173  		if err != nil {
   174  			log.Debugw("client: setting write deadline", "err", err)
   175  		}
   176  	}
   177  }