github.com/oskarth/go-ethereum@v1.6.8-0.20191013093314-dac24a9d3494/p2p/protocols/protocol.go (about)

     1  // Copyright 2017 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  /*
    18  Package protocols is an extension to p2p. It offers a user-friendly, simple way to define
    19  devp2p subprotocols by abstracting away code commonly shared by protocols:
    20  
    21  * automate assignments of code indexes to messages
    22  * automate RLP decoding/encoding based on reflection
    23  * provide the forever loop to read incoming messages
    24  * standardise error handling related to communication
    25  * standardised handshake negotiation
    26  * TODO: automatic generation of wire protocol specification for peers
    27  
    28  */
    29  package protocols
    30  
    31  import (
    32  	"bufio"
    33  	"bytes"
    34  	"context"
    35  	"fmt"
    36  	"io"
    37  	"reflect"
    38  	"sync"
    39  	"time"
    40  
    41  	"github.com/ethereum/go-ethereum/log"
    42  	"github.com/ethereum/go-ethereum/metrics"
    43  	"github.com/ethereum/go-ethereum/p2p"
    44  	"github.com/ethereum/go-ethereum/rlp"
    45  	"github.com/ethereum/go-ethereum/swarm/spancontext"
    46  	"github.com/ethereum/go-ethereum/swarm/tracing"
    47  	opentracing "github.com/opentracing/opentracing-go"
    48  )
    49  
    50  // error codes used by this  protocol scheme
    51  const (
    52  	ErrMsgTooLong = iota
    53  	ErrDecode
    54  	ErrWrite
    55  	ErrInvalidMsgCode
    56  	ErrInvalidMsgType
    57  	ErrHandshake
    58  	ErrNoHandler
    59  	ErrHandler
    60  )
    61  
    62  // error description strings associated with the codes
    63  var errorToString = map[int]string{
    64  	ErrMsgTooLong:     "Message too long",
    65  	ErrDecode:         "Invalid message (RLP error)",
    66  	ErrWrite:          "Error sending message",
    67  	ErrInvalidMsgCode: "Invalid message code",
    68  	ErrInvalidMsgType: "Invalid message type",
    69  	ErrHandshake:      "Handshake error",
    70  	ErrNoHandler:      "No handler registered error",
    71  	ErrHandler:        "Message handler error",
    72  }
    73  
    74  /*
    75  Error implements the standard go error interface.
    76  Use:
    77  
    78    errorf(code, format, params ...interface{})
    79  
    80  Prints as:
    81  
    82   <description>: <details>
    83  
    84  where description is given by code in errorToString
    85  and details is fmt.Sprintf(format, params...)
    86  
    87  exported field Code can be checked
    88  */
    89  type Error struct {
    90  	Code    int
    91  	message string
    92  	format  string
    93  	params  []interface{}
    94  }
    95  
    96  func (e Error) Error() (message string) {
    97  	if len(e.message) == 0 {
    98  		name, ok := errorToString[e.Code]
    99  		if !ok {
   100  			panic("invalid message code")
   101  		}
   102  		e.message = name
   103  		if e.format != "" {
   104  			e.message += ": " + fmt.Sprintf(e.format, e.params...)
   105  		}
   106  	}
   107  	return e.message
   108  }
   109  
   110  func errorf(code int, format string, params ...interface{}) *Error {
   111  	return &Error{
   112  		Code:   code,
   113  		format: format,
   114  		params: params,
   115  	}
   116  }
   117  
   118  // WrappedMsg is used to propagate marshalled context alongside message payloads
   119  type WrappedMsg struct {
   120  	Context []byte
   121  	Size    uint32
   122  	Payload []byte
   123  }
   124  
   125  //For accounting, the design is to allow the Spec to describe which and how its messages are priced
   126  //To access this functionality, we provide a Hook interface which will call accounting methods
   127  //NOTE: there could be more such (horizontal) hooks in the future
   128  type Hook interface {
   129  	//A hook for sending messages
   130  	Send(peer *Peer, size uint32, msg interface{}) error
   131  	//A hook for receiving messages
   132  	Receive(peer *Peer, size uint32, msg interface{}) error
   133  }
   134  
   135  // Spec is a protocol specification including its name and version as well as
   136  // the types of messages which are exchanged
   137  type Spec struct {
   138  	// Name is the name of the protocol, often a three-letter word
   139  	Name string
   140  
   141  	// Version is the version number of the protocol
   142  	Version uint
   143  
   144  	// MaxMsgSize is the maximum accepted length of the message payload
   145  	MaxMsgSize uint32
   146  
   147  	// Messages is a list of message data types which this protocol uses, with
   148  	// each message type being sent with its array index as the code (so
   149  	// [&foo{}, &bar{}, &baz{}] would send foo, bar and baz with codes
   150  	// 0, 1 and 2 respectively)
   151  	// each message must have a single unique data type
   152  	Messages []interface{}
   153  
   154  	//hook for accounting (could be extended to multiple hooks in the future)
   155  	Hook Hook
   156  
   157  	initOnce sync.Once
   158  	codes    map[reflect.Type]uint64
   159  	types    map[uint64]reflect.Type
   160  }
   161  
   162  func (s *Spec) init() {
   163  	s.initOnce.Do(func() {
   164  		s.codes = make(map[reflect.Type]uint64, len(s.Messages))
   165  		s.types = make(map[uint64]reflect.Type, len(s.Messages))
   166  		for i, msg := range s.Messages {
   167  			code := uint64(i)
   168  			typ := reflect.TypeOf(msg)
   169  			if typ.Kind() == reflect.Ptr {
   170  				typ = typ.Elem()
   171  			}
   172  			s.codes[typ] = code
   173  			s.types[code] = typ
   174  		}
   175  	})
   176  }
   177  
   178  // Length returns the number of message types in the protocol
   179  func (s *Spec) Length() uint64 {
   180  	return uint64(len(s.Messages))
   181  }
   182  
   183  // GetCode returns the message code of a type, and boolean second argument is
   184  // false if the message type is not found
   185  func (s *Spec) GetCode(msg interface{}) (uint64, bool) {
   186  	s.init()
   187  	typ := reflect.TypeOf(msg)
   188  	if typ.Kind() == reflect.Ptr {
   189  		typ = typ.Elem()
   190  	}
   191  	code, ok := s.codes[typ]
   192  	return code, ok
   193  }
   194  
   195  // NewMsg construct a new message type given the code
   196  func (s *Spec) NewMsg(code uint64) (interface{}, bool) {
   197  	s.init()
   198  	typ, ok := s.types[code]
   199  	if !ok {
   200  		return nil, false
   201  	}
   202  	return reflect.New(typ).Interface(), true
   203  }
   204  
   205  // Peer represents a remote peer or protocol instance that is running on a peer connection with
   206  // a remote peer
   207  type Peer struct {
   208  	*p2p.Peer                   // the p2p.Peer object representing the remote
   209  	rw        p2p.MsgReadWriter // p2p.MsgReadWriter to send messages to and read messages from
   210  	spec      *Spec
   211  }
   212  
   213  // NewPeer constructs a new peer
   214  // this constructor is called by the p2p.Protocol#Run function
   215  // the first two arguments are the arguments passed to p2p.Protocol.Run function
   216  // the third argument is the Spec describing the protocol
   217  func NewPeer(p *p2p.Peer, rw p2p.MsgReadWriter, spec *Spec) *Peer {
   218  	return &Peer{
   219  		Peer: p,
   220  		rw:   rw,
   221  		spec: spec,
   222  	}
   223  }
   224  
   225  // Run starts the forever loop that handles incoming messages
   226  // called within the p2p.Protocol#Run function
   227  // the handler argument is a function which is called for each message received
   228  // from the remote peer, a returned error causes the loop to exit
   229  // resulting in disconnection
   230  func (p *Peer) Run(handler func(ctx context.Context, msg interface{}) error) error {
   231  	for {
   232  		if err := p.handleIncoming(handler); err != nil {
   233  			if err != io.EOF {
   234  				metrics.GetOrRegisterCounter("peer.handleincoming.error", nil).Inc(1)
   235  				log.Error("peer.handleIncoming", "err", err)
   236  			}
   237  
   238  			return err
   239  		}
   240  	}
   241  }
   242  
   243  // Drop disconnects a peer.
   244  // TODO: may need to implement protocol drop only? don't want to kick off the peer
   245  // if they are useful for other protocols
   246  func (p *Peer) Drop(err error) {
   247  	p.Disconnect(p2p.DiscSubprotocolError)
   248  }
   249  
   250  // Send takes a message, encodes it in RLP, finds the right message code and sends the
   251  // message off to the peer
   252  // this low level call will be wrapped by libraries providing routed or broadcast sends
   253  // but often just used to forward and push messages to directly connected peers
   254  func (p *Peer) Send(ctx context.Context, msg interface{}) error {
   255  	defer metrics.GetOrRegisterResettingTimer("peer.send_t", nil).UpdateSince(time.Now())
   256  	metrics.GetOrRegisterCounter("peer.send", nil).Inc(1)
   257  
   258  	var b bytes.Buffer
   259  	if tracing.Enabled {
   260  		writer := bufio.NewWriter(&b)
   261  
   262  		tracer := opentracing.GlobalTracer()
   263  
   264  		sctx := spancontext.FromContext(ctx)
   265  
   266  		if sctx != nil {
   267  			err := tracer.Inject(
   268  				sctx,
   269  				opentracing.Binary,
   270  				writer)
   271  			if err != nil {
   272  				return err
   273  			}
   274  		}
   275  
   276  		writer.Flush()
   277  	}
   278  
   279  	r, err := rlp.EncodeToBytes(msg)
   280  	if err != nil {
   281  		return err
   282  	}
   283  
   284  	wmsg := WrappedMsg{
   285  		Context: b.Bytes(),
   286  		Size:    uint32(len(r)),
   287  		Payload: r,
   288  	}
   289  
   290  	//if the accounting hook is set, call it
   291  	if p.spec.Hook != nil {
   292  		err := p.spec.Hook.Send(p, wmsg.Size, msg)
   293  		if err != nil {
   294  			p.Drop(err)
   295  			return err
   296  		}
   297  	}
   298  
   299  	code, found := p.spec.GetCode(msg)
   300  	if !found {
   301  		return errorf(ErrInvalidMsgType, "%v", code)
   302  	}
   303  	return p2p.Send(p.rw, code, wmsg)
   304  }
   305  
   306  // handleIncoming(code)
   307  // is called each cycle of the main forever loop that dispatches incoming messages
   308  // if this returns an error the loop returns and the peer is disconnected with the error
   309  // this generic handler
   310  // * checks message size,
   311  // * checks for out-of-range message codes,
   312  // * handles decoding with reflection,
   313  // * call handlers as callbacks
   314  func (p *Peer) handleIncoming(handle func(ctx context.Context, msg interface{}) error) error {
   315  	msg, err := p.rw.ReadMsg()
   316  	if err != nil {
   317  		return err
   318  	}
   319  	// make sure that the payload has been fully consumed
   320  	defer msg.Discard()
   321  
   322  	if msg.Size > p.spec.MaxMsgSize {
   323  		return errorf(ErrMsgTooLong, "%v > %v", msg.Size, p.spec.MaxMsgSize)
   324  	}
   325  
   326  	// unmarshal wrapped msg, which might contain context
   327  	var wmsg WrappedMsg
   328  	err = msg.Decode(&wmsg)
   329  	if err != nil {
   330  		log.Error(err.Error())
   331  		return err
   332  	}
   333  
   334  	ctx := context.Background()
   335  
   336  	// if tracing is enabled and the context coming within the request is
   337  	// not empty, try to unmarshal it
   338  	if tracing.Enabled && len(wmsg.Context) > 0 {
   339  		var sctx opentracing.SpanContext
   340  
   341  		tracer := opentracing.GlobalTracer()
   342  		sctx, err = tracer.Extract(
   343  			opentracing.Binary,
   344  			bytes.NewReader(wmsg.Context))
   345  		if err != nil {
   346  			log.Error(err.Error())
   347  			return err
   348  		}
   349  
   350  		ctx = spancontext.WithContext(ctx, sctx)
   351  	}
   352  
   353  	val, ok := p.spec.NewMsg(msg.Code)
   354  	if !ok {
   355  		return errorf(ErrInvalidMsgCode, "%v", msg.Code)
   356  	}
   357  	if err := rlp.DecodeBytes(wmsg.Payload, val); err != nil {
   358  		return errorf(ErrDecode, "<= %v: %v", msg, err)
   359  	}
   360  
   361  	//if the accounting hook is set, call it
   362  	if p.spec.Hook != nil {
   363  		err := p.spec.Hook.Receive(p, wmsg.Size, val)
   364  		if err != nil {
   365  			return err
   366  		}
   367  	}
   368  
   369  	// call the registered handler callbacks
   370  	// a registered callback take the decoded message as argument as an interface
   371  	// which the handler is supposed to cast to the appropriate type
   372  	// it is entirely safe not to check the cast in the handler since the handler is
   373  	// chosen based on the proper type in the first place
   374  	if err := handle(ctx, val); err != nil {
   375  		return errorf(ErrHandler, "(msg code %v): %v", msg.Code, err)
   376  	}
   377  	return nil
   378  }
   379  
   380  // Handshake negotiates a handshake on the peer connection
   381  // * arguments
   382  //   * context
   383  //   * the local handshake to be sent to the remote peer
   384  //   * funcion to be called on the remote handshake (can be nil)
   385  // * expects a remote handshake back of the same type
   386  // * the dialing peer needs to send the handshake first and then waits for remote
   387  // * the listening peer waits for the remote handshake and then sends it
   388  // returns the remote handshake and an error
   389  func (p *Peer) Handshake(ctx context.Context, hs interface{}, verify func(interface{}) error) (rhs interface{}, err error) {
   390  	if _, ok := p.spec.GetCode(hs); !ok {
   391  		return nil, errorf(ErrHandshake, "unknown handshake message type: %T", hs)
   392  	}
   393  	errc := make(chan error, 2)
   394  	handle := func(ctx context.Context, msg interface{}) error {
   395  		rhs = msg
   396  		if verify != nil {
   397  			return verify(rhs)
   398  		}
   399  		return nil
   400  	}
   401  	send := func() { errc <- p.Send(ctx, hs) }
   402  	receive := func() { errc <- p.handleIncoming(handle) }
   403  
   404  	go func() {
   405  		if p.Inbound() {
   406  			receive()
   407  			send()
   408  		} else {
   409  			send()
   410  			receive()
   411  		}
   412  	}()
   413  
   414  	for i := 0; i < 2; i++ {
   415  		select {
   416  		case err = <-errc:
   417  		case <-ctx.Done():
   418  			err = ctx.Err()
   419  		}
   420  		if err != nil {
   421  			return nil, errorf(ErrHandshake, err.Error())
   422  		}
   423  	}
   424  	return rhs, nil
   425  }