github.com/neatlab/neatio@v1.7.3-0.20220425043230-d903e92fcc75/network/p2p/protocols/protocol.go (about)

     1  package protocols
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"reflect"
     7  	"sync"
     8  
     9  	"github.com/neatlab/neatio/network/p2p"
    10  )
    11  
    12  const (
    13  	ErrMsgTooLong = iota
    14  	ErrDecode
    15  	ErrWrite
    16  	ErrInvalidMsgCode
    17  	ErrInvalidMsgType
    18  	ErrHandshake
    19  	ErrNoHandler
    20  	ErrHandler
    21  )
    22  
    23  var errorToString = map[int]string{
    24  	ErrMsgTooLong:     "Message too long",
    25  	ErrDecode:         "Invalid message (RLP error)",
    26  	ErrWrite:          "Error sending message",
    27  	ErrInvalidMsgCode: "Invalid message code",
    28  	ErrInvalidMsgType: "Invalid message type",
    29  	ErrHandshake:      "Handshake error",
    30  	ErrNoHandler:      "No handler registered error",
    31  	ErrHandler:        "Message handler error",
    32  }
    33  
    34  type Error struct {
    35  	Code    int
    36  	message string
    37  	format  string
    38  	params  []interface{}
    39  }
    40  
    41  func (e Error) Error() (message string) {
    42  	if len(e.message) == 0 {
    43  		name, ok := errorToString[e.Code]
    44  		if !ok {
    45  			panic("invalid message code")
    46  		}
    47  		e.message = name
    48  		if e.format != "" {
    49  			e.message += ": " + fmt.Sprintf(e.format, e.params...)
    50  		}
    51  	}
    52  	return e.message
    53  }
    54  
    55  func errorf(code int, format string, params ...interface{}) *Error {
    56  	return &Error{
    57  		Code:   code,
    58  		format: format,
    59  		params: params,
    60  	}
    61  }
    62  
    63  type Spec struct {
    64  	Name string
    65  
    66  	Version uint
    67  
    68  	MaxMsgSize uint32
    69  
    70  	Messages []interface{}
    71  
    72  	initOnce sync.Once
    73  	codes    map[reflect.Type]uint64
    74  	types    map[uint64]reflect.Type
    75  }
    76  
    77  func (s *Spec) init() {
    78  	s.initOnce.Do(func() {
    79  		s.codes = make(map[reflect.Type]uint64, len(s.Messages))
    80  		s.types = make(map[uint64]reflect.Type, len(s.Messages))
    81  		for i, msg := range s.Messages {
    82  			code := uint64(i)
    83  			typ := reflect.TypeOf(msg)
    84  			if typ.Kind() == reflect.Ptr {
    85  				typ = typ.Elem()
    86  			}
    87  			s.codes[typ] = code
    88  			s.types[code] = typ
    89  		}
    90  	})
    91  }
    92  
    93  func (s *Spec) Length() uint64 {
    94  	return uint64(len(s.Messages))
    95  }
    96  
    97  func (s *Spec) GetCode(msg interface{}) (uint64, bool) {
    98  	s.init()
    99  	typ := reflect.TypeOf(msg)
   100  	if typ.Kind() == reflect.Ptr {
   101  		typ = typ.Elem()
   102  	}
   103  	code, ok := s.codes[typ]
   104  	return code, ok
   105  }
   106  
   107  func (s *Spec) NewMsg(code uint64) (interface{}, bool) {
   108  	s.init()
   109  	typ, ok := s.types[code]
   110  	if !ok {
   111  		return nil, false
   112  	}
   113  	return reflect.New(typ).Interface(), true
   114  }
   115  
   116  type Peer struct {
   117  	*p2p.Peer
   118  	rw   p2p.MsgReadWriter
   119  	spec *Spec
   120  }
   121  
   122  func NewPeer(p *p2p.Peer, rw p2p.MsgReadWriter, spec *Spec) *Peer {
   123  	return &Peer{
   124  		Peer: p,
   125  		rw:   rw,
   126  		spec: spec,
   127  	}
   128  }
   129  
   130  func (p *Peer) Run(handler func(msg interface{}) error) error {
   131  	for {
   132  		if err := p.handleIncoming(handler); err != nil {
   133  			return err
   134  		}
   135  	}
   136  }
   137  
   138  func (p *Peer) Drop(err error) {
   139  	p.Disconnect(p2p.DiscSubprotocolError)
   140  }
   141  
   142  func (p *Peer) Send(msg interface{}) error {
   143  	code, found := p.spec.GetCode(msg)
   144  	if !found {
   145  		return errorf(ErrInvalidMsgType, "%v", code)
   146  	}
   147  	return p2p.Send(p.rw, code, msg)
   148  }
   149  
   150  func (p *Peer) handleIncoming(handle func(msg interface{}) error) error {
   151  	msg, err := p.rw.ReadMsg()
   152  	if err != nil {
   153  		return err
   154  	}
   155  
   156  	defer msg.Discard()
   157  
   158  	if msg.Size > p.spec.MaxMsgSize {
   159  		return errorf(ErrMsgTooLong, "%v > %v", msg.Size, p.spec.MaxMsgSize)
   160  	}
   161  
   162  	val, ok := p.spec.NewMsg(msg.Code)
   163  	if !ok {
   164  		return errorf(ErrInvalidMsgCode, "%v", msg.Code)
   165  	}
   166  	if err := msg.Decode(val); err != nil {
   167  		return errorf(ErrDecode, "<= %v: %v", msg, err)
   168  	}
   169  
   170  	if err := handle(val); err != nil {
   171  		return errorf(ErrHandler, "(msg code %v): %v", msg.Code, err)
   172  	}
   173  	return nil
   174  }
   175  
   176  func (p *Peer) Handshake(ctx context.Context, hs interface{}, verify func(interface{}) error) (rhs interface{}, err error) {
   177  	if _, ok := p.spec.GetCode(hs); !ok {
   178  		return nil, errorf(ErrHandshake, "unknown handshake message type: %T", hs)
   179  	}
   180  	errc := make(chan error, 2)
   181  	handle := func(msg interface{}) error {
   182  		rhs = msg
   183  		if verify != nil {
   184  			return verify(rhs)
   185  		}
   186  		return nil
   187  	}
   188  	send := func() { errc <- p.Send(hs) }
   189  	receive := func() { errc <- p.handleIncoming(handle) }
   190  
   191  	go func() {
   192  		if p.Inbound() {
   193  			receive()
   194  			send()
   195  		} else {
   196  			send()
   197  			receive()
   198  		}
   199  	}()
   200  
   201  	for i := 0; i < 2; i++ {
   202  		select {
   203  		case err = <-errc:
   204  		case <-ctx.Done():
   205  			err = ctx.Err()
   206  		}
   207  		if err != nil {
   208  			return nil, errorf(ErrHandshake, err.Error())
   209  		}
   210  	}
   211  	return rhs, nil
   212  }