github.com/arieschain/arieschain@v0.0.0-20191023063405-37c074544356/p2p/protocols/protocol.go (about)

     1  /*
     2  Package protocols is an extension to p2p. It offers a user friendly simple way to define
     3  devp2p subprotocols by abstracting away code standardly shared by protocols.
     4  
     5  * automate assignments of code indexes to messages
     6  * automate RLP decoding/encoding based on reflecting
     7  * provide the forever loop to read incoming messages
     8  * standardise error handling related to communication
     9  * standardised handshake negotiation
    10  * TODO: automatic generation of wire protocol specification for peers
    11  
    12  */
    13  package protocols
    14  
    15  import (
    16  	"context"
    17  	"fmt"
    18  	"reflect"
    19  	"sync"
    20  
    21  	"github.com/quickchainproject/quickchain/p2p"
    22  )
    23  
    24  // error codes used by this  protocol scheme
    25  const (
    26  	ErrMsgTooLong = iota
    27  	ErrDecode
    28  	ErrWrite
    29  	ErrInvalidMsgCode
    30  	ErrInvalidMsgType
    31  	ErrHandshake
    32  	ErrNoHandler
    33  	ErrHandler
    34  )
    35  
    36  // error description strings associated with the codes
    37  var errorToString = map[int]string{
    38  	ErrMsgTooLong:     "Message too long",
    39  	ErrDecode:         "Invalid message (RLP error)",
    40  	ErrWrite:          "Error sending message",
    41  	ErrInvalidMsgCode: "Invalid message code",
    42  	ErrInvalidMsgType: "Invalid message type",
    43  	ErrHandshake:      "Handshake error",
    44  	ErrNoHandler:      "No handler registered error",
    45  	ErrHandler:        "Message handler error",
    46  }
    47  
    48  /*
    49  Error implements the standard go error interface.
    50  Use:
    51  
    52    errorf(code, format, params ...interface{})
    53  
    54  Prints as:
    55  
    56   <description>: <details>
    57  
    58  where description is given by code in errorToString
    59  and details is fmt.Sprintf(format, params...)
    60  
    61  exported field Code can be checked
    62  */
    63  type Error struct {
    64  	Code    int
    65  	message string
    66  	format  string
    67  	params  []interface{}
    68  }
    69  
    70  func (e Error) Error() (message string) {
    71  	if len(e.message) == 0 {
    72  		name, ok := errorToString[e.Code]
    73  		if !ok {
    74  			panic("invalid message code")
    75  		}
    76  		e.message = name
    77  		if e.format != "" {
    78  			e.message += ": " + fmt.Sprintf(e.format, e.params...)
    79  		}
    80  	}
    81  	return e.message
    82  }
    83  
    84  func errorf(code int, format string, params ...interface{}) *Error {
    85  	return &Error{
    86  		Code:   code,
    87  		format: format,
    88  		params: params,
    89  	}
    90  }
    91  
    92  // Spec is a protocol specification including its name and version as well as
    93  // the types of messages which are exchanged
    94  type Spec struct {
    95  	// Name is the name of the protocol, often a three-letter word
    96  	Name string
    97  
    98  	// Version is the version number of the protocol
    99  	Version uint
   100  
   101  	// MaxMsgSize is the maximum accepted length of the message payload
   102  	MaxMsgSize uint32
   103  
   104  	// Messages is a list of message data types which this protocol uses, with
   105  	// each message type being sent with its array index as the code (so
   106  	// [&foo{}, &bar{}, &baz{}] would send foo, bar and baz with codes
   107  	// 0, 1 and 2 respectively)
   108  	// each message must have a single unique data type
   109  	Messages []interface{}
   110  
   111  	initOnce sync.Once
   112  	codes    map[reflect.Type]uint64
   113  	types    map[uint64]reflect.Type
   114  }
   115  
   116  func (s *Spec) init() {
   117  	s.initOnce.Do(func() {
   118  		s.codes = make(map[reflect.Type]uint64, len(s.Messages))
   119  		s.types = make(map[uint64]reflect.Type, len(s.Messages))
   120  		for i, msg := range s.Messages {
   121  			code := uint64(i)
   122  			typ := reflect.TypeOf(msg)
   123  			if typ.Kind() == reflect.Ptr {
   124  				typ = typ.Elem()
   125  			}
   126  			s.codes[typ] = code
   127  			s.types[code] = typ
   128  		}
   129  	})
   130  }
   131  
   132  // Length returns the number of message types in the protocol
   133  func (s *Spec) Length() uint64 {
   134  	return uint64(len(s.Messages))
   135  }
   136  
   137  // GetCode returns the message code of a type, and boolean second argument is
   138  // false if the message type is not found
   139  func (s *Spec) GetCode(msg interface{}) (uint64, bool) {
   140  	s.init()
   141  	typ := reflect.TypeOf(msg)
   142  	if typ.Kind() == reflect.Ptr {
   143  		typ = typ.Elem()
   144  	}
   145  	code, ok := s.codes[typ]
   146  	return code, ok
   147  }
   148  
   149  // NewMsg construct a new message type given the code
   150  func (s *Spec) NewMsg(code uint64) (interface{}, bool) {
   151  	s.init()
   152  	typ, ok := s.types[code]
   153  	if !ok {
   154  		return nil, false
   155  	}
   156  	return reflect.New(typ).Interface(), true
   157  }
   158  
   159  // Peer represents a remote peer or protocol instance that is running on a peer connection with
   160  // a remote peer
   161  type Peer struct {
   162  	*p2p.Peer                   // the p2p.Peer object representing the remote
   163  	rw        p2p.MsgReadWriter // p2p.MsgReadWriter to send messages to and read messages from
   164  	spec      *Spec
   165  }
   166  
   167  // NewPeer constructs a new peer
   168  // this constructor is called by the p2p.Protocol#Run function
   169  // the first two arguments are the arguments passed to p2p.Protocol.Run function
   170  // the third argument is the Spec describing the protocol
   171  func NewPeer(p *p2p.Peer, rw p2p.MsgReadWriter, spec *Spec) *Peer {
   172  	return &Peer{
   173  		Peer: p,
   174  		rw:   rw,
   175  		spec: spec,
   176  	}
   177  }
   178  
   179  // Run starts the forever loop that handles incoming messages
   180  // called within the p2p.Protocol#Run function
   181  // the handler argument is a function which is called for each message received
   182  // from the remote peer, a returned error causes the loop to exit
   183  // resulting in disconnection
   184  func (p *Peer) Run(handler func(msg interface{}) error) error {
   185  	for {
   186  		if err := p.handleIncoming(handler); err != nil {
   187  			return err
   188  		}
   189  	}
   190  }
   191  
   192  // Drop disconnects a peer.
   193  // TODO: may need to implement protocol drop only? don't want to kick off the peer
   194  // if they are useful for other protocols
   195  func (p *Peer) Drop(err error) {
   196  	p.Disconnect(p2p.DiscSubprotocolError)
   197  }
   198  
   199  // Send takes a message, encodes it in RLP, finds the right message code and sends the
   200  // message off to the peer
   201  // this low level call will be wrapped by libraries providing routed or broadcast sends
   202  // but often just used to forward and push messages to directly connected peers
   203  func (p *Peer) Send(msg interface{}) error {
   204  	code, found := p.spec.GetCode(msg)
   205  	if !found {
   206  		return errorf(ErrInvalidMsgType, "%v", code)
   207  	}
   208  	return p2p.Send(p.rw, code, msg)
   209  }
   210  
   211  // handleIncoming(code)
   212  // is called each cycle of the main forever loop that dispatches incoming messages
   213  // if this returns an error the loop returns and the peer is disconnected with the error
   214  // this generic handler
   215  // * checks message size,
   216  // * checks for out-of-range message codes,
   217  // * handles decoding with reflection,
   218  // * call handlers as callbacks
   219  func (p *Peer) handleIncoming(handle func(msg interface{}) error) error {
   220  	msg, err := p.rw.ReadMsg()
   221  	if err != nil {
   222  		return err
   223  	}
   224  	// make sure that the payload has been fully consumed
   225  	defer msg.Discard()
   226  
   227  	if msg.Size > p.spec.MaxMsgSize {
   228  		return errorf(ErrMsgTooLong, "%v > %v", msg.Size, p.spec.MaxMsgSize)
   229  	}
   230  
   231  	val, ok := p.spec.NewMsg(msg.Code)
   232  	if !ok {
   233  		return errorf(ErrInvalidMsgCode, "%v", msg.Code)
   234  	}
   235  	if err := msg.Decode(val); err != nil {
   236  		return errorf(ErrDecode, "<= %v: %v", msg, err)
   237  	}
   238  
   239  	// call the registered handler callbacks
   240  	// a registered callback take the decoded message as argument as an interface
   241  	// which the handler is supposed to cast to the appropriate type
   242  	// it is entirely safe not to check the cast in the handler since the handler is
   243  	// chosen based on the proper type in the first place
   244  	if err := handle(val); err != nil {
   245  		return errorf(ErrHandler, "(msg code %v): %v", msg.Code, err)
   246  	}
   247  	return nil
   248  }
   249  
   250  // Handshake negotiates a handshake on the peer connection
   251  // * arguments
   252  //   * context
   253  //   * the local handshake to be sent to the remote peer
   254  //   * funcion to be called on the remote handshake (can be nil)
   255  // * expects a remote handshake back of the same type
   256  // * the dialing peer needs to send the handshake first and then waits for remote
   257  // * the listening peer waits for the remote handshake and then sends it
   258  // returns the remote handshake and an error
   259  func (p *Peer) Handshake(ctx context.Context, hs interface{}, verify func(interface{}) error) (rhs interface{}, err error) {
   260  	if _, ok := p.spec.GetCode(hs); !ok {
   261  		return nil, errorf(ErrHandshake, "unknown handshake message type: %T", hs)
   262  	}
   263  	errc := make(chan error, 2)
   264  	handle := func(msg interface{}) error {
   265  		rhs = msg
   266  		if verify != nil {
   267  			return verify(rhs)
   268  		}
   269  		return nil
   270  	}
   271  	send := func() { errc <- p.Send(hs) }
   272  	receive := func() { errc <- p.handleIncoming(handle) }
   273  
   274  	go func() {
   275  		if p.Inbound() {
   276  			receive()
   277  			send()
   278  		} else {
   279  			send()
   280  			receive()
   281  		}
   282  	}()
   283  
   284  	for i := 0; i < 2; i++ {
   285  		select {
   286  		case err = <-errc:
   287  		case <-ctx.Done():
   288  			err = ctx.Err()
   289  		}
   290  		if err != nil {
   291  			return nil, errorf(ErrHandshake, err.Error())
   292  		}
   293  	}
   294  	return rhs, nil
   295  }