github.com/neatio-net/neatio@v1.7.3-0.20231114194659-f4d7a2226baa/network/p2p/message.go (about)

     1  package p2p
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"io/ioutil"
     9  	"sync/atomic"
    10  	"time"
    11  
    12  	"github.com/neatio-net/neatio/network/p2p/discover"
    13  	"github.com/neatio-net/neatio/utilities/event"
    14  	"github.com/neatio-net/neatio/utilities/rlp"
    15  )
    16  
// Msg is a single protocol message: a numeric code plus a payload that is
// read through Payload.
//
// Size is the payload length in bytes. Decode passes it to rlp.NewStream as
// the input limit, and MsgPipeRW.WriteMsg uses it to decide whether (and when)
// the payload has been consumed by the receiver.
type Msg struct {
	Code uint64
	// Size is the payload length in bytes.
	Size    uint32
	Payload io.Reader
	// NOTE(review): ReceivedAt is not set anywhere in this file; presumably
	// the transport stamps it when the message arrives — confirm.
	ReceivedAt time.Time
}
    23  
    24  func (msg Msg) Decode(val interface{}) error {
    25  	s := rlp.NewStream(msg.Payload, uint64(msg.Size))
    26  	if err := s.Decode(val); err != nil {
    27  		return newPeerError(errInvalidMsg, "(code %x) (size %d) %v", msg.Code, msg.Size, err)
    28  	}
    29  	return nil
    30  }
    31  
    32  func (msg Msg) String() string {
    33  	return fmt.Sprintf("msg #%v (%v bytes)", msg.Code, msg.Size)
    34  }
    35  
    36  func (msg Msg) Discard() error {
    37  	_, err := io.Copy(ioutil.Discard, msg.Payload)
    38  	return err
    39  }
    40  
// MsgReader is implemented by anything that can deliver protocol messages,
// e.g. MsgPipeRW in this file.
type MsgReader interface {
	// ReadMsg returns the next message. The returned Msg's Payload must be
	// consumed before the next call where the implementation requires it
	// (MsgPipeRW's writer blocks until the payload has been read).
	ReadMsg() (Msg, error)
}

// MsgWriter is implemented by anything that can send protocol messages.
type MsgWriter interface {
	// WriteMsg sends a message. Whether it blocks until the payload is
	// consumed is implementation-specific (MsgPipeRW.WriteMsg does).
	WriteMsg(Msg) error
}

// MsgReadWriter combines MsgReader and MsgWriter.
type MsgReadWriter interface {
	MsgReader
	MsgWriter
}
    53  
    54  func Send(w MsgWriter, msgcode uint64, data interface{}) error {
    55  	size, r, err := rlp.EncodeToReader(data)
    56  	if err != nil {
    57  		return err
    58  	}
    59  	return w.WriteMsg(Msg{Code: msgcode, Size: uint32(size), Payload: r})
    60  }
    61  
    62  func SendItems(w MsgWriter, msgcode uint64, elems ...interface{}) error {
    63  	return Send(w, msgcode, elems)
    64  }
    65  
    66  type eofSignal struct {
    67  	wrapped io.Reader
    68  	count   uint32
    69  	eof     chan<- struct{}
    70  }
    71  
    72  func (r *eofSignal) Read(buf []byte) (int, error) {
    73  	if r.count == 0 {
    74  		if r.eof != nil {
    75  			r.eof <- struct{}{}
    76  			r.eof = nil
    77  		}
    78  		return 0, io.EOF
    79  	}
    80  
    81  	max := len(buf)
    82  	if int(r.count) < len(buf) {
    83  		max = int(r.count)
    84  	}
    85  	n, err := r.wrapped.Read(buf[:max])
    86  	r.count -= uint32(n)
    87  	if (err != nil || r.count == 0) && r.eof != nil {
    88  		r.eof <- struct{}{}
    89  		r.eof = nil
    90  	}
    91  	return n, err
    92  }
    93  
    94  func MsgPipe() (*MsgPipeRW, *MsgPipeRW) {
    95  	var (
    96  		c1, c2  = make(chan Msg), make(chan Msg)
    97  		closing = make(chan struct{})
    98  		closed  = new(int32)
    99  		rw1     = &MsgPipeRW{c1, c2, closing, closed}
   100  		rw2     = &MsgPipeRW{c2, c1, closing, closed}
   101  	)
   102  	return rw1, rw2
   103  }
   104  
// ErrPipeClosed is returned by MsgPipeRW.ReadMsg and MsgPipeRW.WriteMsg once
// either end of the pipe has been closed.
var ErrPipeClosed = errors.New("p2p: read or write on closed message pipe")

// MsgPipeRW is one endpoint of an in-memory message pipe created by MsgPipe.
type MsgPipeRW struct {
	// w delivers messages written on this end to the other end's r.
	w       chan<- Msg
	// r receives messages written on the other end.
	r       <-chan Msg
	// closing is shared by both ends and closed (once) by Close, waking any
	// ReadMsg/WriteMsg blocked in a select on either end.
	closing chan struct{}
	// closed is shared by both ends; a non-zero value means Close was called.
	closed  *int32
}
   113  
   114  func (p *MsgPipeRW) WriteMsg(msg Msg) error {
   115  	if atomic.LoadInt32(p.closed) == 0 {
   116  		consumed := make(chan struct{}, 1)
   117  		msg.Payload = &eofSignal{msg.Payload, msg.Size, consumed}
   118  		select {
   119  		case p.w <- msg:
   120  			if msg.Size > 0 {
   121  
   122  				select {
   123  				case <-consumed:
   124  				case <-p.closing:
   125  				}
   126  			}
   127  			return nil
   128  		case <-p.closing:
   129  		}
   130  	}
   131  	return ErrPipeClosed
   132  }
   133  
   134  func (p *MsgPipeRW) ReadMsg() (Msg, error) {
   135  	if atomic.LoadInt32(p.closed) == 0 {
   136  		select {
   137  		case msg := <-p.r:
   138  			return msg, nil
   139  		case <-p.closing:
   140  		}
   141  	}
   142  	return Msg{}, ErrPipeClosed
   143  }
   144  
   145  func (p *MsgPipeRW) Close() error {
   146  	if atomic.AddInt32(p.closed, 1) != 1 {
   147  
   148  		atomic.StoreInt32(p.closed, 1)
   149  		return nil
   150  	}
   151  	close(p.closing)
   152  	return nil
   153  }
   154  
   155  func ExpectMsg(r MsgReader, code uint64, content interface{}) error {
   156  	msg, err := r.ReadMsg()
   157  	if err != nil {
   158  		return err
   159  	}
   160  	if msg.Code != code {
   161  		return fmt.Errorf("message code mismatch: got %d, expected %d", msg.Code, code)
   162  	}
   163  	if content == nil {
   164  		return msg.Discard()
   165  	} else {
   166  		contentEnc, err := rlp.EncodeToBytes(content)
   167  		if err != nil {
   168  			panic("content encode error: " + err.Error())
   169  		}
   170  		if int(msg.Size) != len(contentEnc) {
   171  			return fmt.Errorf("message size mismatch: got %d, want %d", msg.Size, len(contentEnc))
   172  		}
   173  		actualContent, err := ioutil.ReadAll(msg.Payload)
   174  		if err != nil {
   175  			return err
   176  		}
   177  		if !bytes.Equal(actualContent, contentEnc) {
   178  			return fmt.Errorf("message payload mismatch:\ngot:  %x\nwant: %x", actualContent, contentEnc)
   179  		}
   180  	}
   181  	return nil
   182  }
   183  
   184  type msgEventer struct {
   185  	MsgReadWriter
   186  
   187  	feed     *event.Feed
   188  	peerID   discover.NodeID
   189  	Protocol string
   190  }
   191  
   192  func newMsgEventer(rw MsgReadWriter, feed *event.Feed, peerID discover.NodeID, proto string) *msgEventer {
   193  	return &msgEventer{
   194  		MsgReadWriter: rw,
   195  		feed:          feed,
   196  		peerID:        peerID,
   197  		Protocol:      proto,
   198  	}
   199  }
   200  
   201  func (self *msgEventer) ReadMsg() (Msg, error) {
   202  	msg, err := self.MsgReadWriter.ReadMsg()
   203  	if err != nil {
   204  		return msg, err
   205  	}
   206  	self.feed.Send(&PeerEvent{
   207  		Type:     PeerEventTypeMsgRecv,
   208  		Peer:     self.peerID,
   209  		Protocol: self.Protocol,
   210  		MsgCode:  &msg.Code,
   211  		MsgSize:  &msg.Size,
   212  	})
   213  	return msg, nil
   214  }
   215  
   216  func (self *msgEventer) WriteMsg(msg Msg) error {
   217  	err := self.MsgReadWriter.WriteMsg(msg)
   218  	if err != nil {
   219  		return err
   220  	}
   221  	self.feed.Send(&PeerEvent{
   222  		Type:     PeerEventTypeMsgSend,
   223  		Peer:     self.peerID,
   224  		Protocol: self.Protocol,
   225  		MsgCode:  &msg.Code,
   226  		MsgSize:  &msg.Size,
   227  	})
   228  	return nil
   229  }
   230  
   231  func (self *msgEventer) Close() error {
   232  	if v, ok := self.MsgReadWriter.(io.Closer); ok {
   233  		return v.Close()
   234  	}
   235  	return nil
   236  }