github.com/arieschain/arieschain@v0.0.0-20191023063405-37c074544356/p2p/protocols/protocol.go (about) 1 /* 2 Package protocols is an extension to p2p. It offers a user friendly simple way to define 3 devp2p subprotocols by abstracting away code standardly shared by protocols. 4 5 * automate assigments of code indexes to messages 6 * automate RLP decoding/encoding based on reflecting 7 * provide the forever loop to read incoming messages 8 * standardise error handling related to communication 9 * standardised handshake negotiation 10 * TODO: automatic generation of wire protocol specification for peers 11 12 */ 13 package protocols 14 15 import ( 16 "context" 17 "fmt" 18 "reflect" 19 "sync" 20 21 "github.com/quickchainproject/quickchain/p2p" 22 ) 23 24 // error codes used by this protocol scheme 25 const ( 26 ErrMsgTooLong = iota 27 ErrDecode 28 ErrWrite 29 ErrInvalidMsgCode 30 ErrInvalidMsgType 31 ErrHandshake 32 ErrNoHandler 33 ErrHandler 34 ) 35 36 // error description strings associated with the codes 37 var errorToString = map[int]string{ 38 ErrMsgTooLong: "Message too long", 39 ErrDecode: "Invalid message (RLP error)", 40 ErrWrite: "Error sending message", 41 ErrInvalidMsgCode: "Invalid message code", 42 ErrInvalidMsgType: "Invalid message type", 43 ErrHandshake: "Handshake error", 44 ErrNoHandler: "No handler registered error", 45 ErrHandler: "Message handler error", 46 } 47 48 /* 49 Error implements the standard go error interface. 50 Use: 51 52 errorf(code, format, params ...interface{}) 53 54 Prints as: 55 56 <description>: <details> 57 58 where description is given by code in errorToString 59 and details is fmt.Sprintf(format, params...) 60 61 exported field Code can be checked 62 */ 63 type Error struct { 64 Code int 65 message string 66 format string 67 params []interface{} 68 } 69 70 func (e Error) Error() (message string) { 71 if len(e.message) == 0 { 72 name, ok := errorToString[e.Code] 73 if !ok { 74 panic("invalid message code") 75 } 76 e.message = name 77 if e.format != "" { 78 e.message += ": " + fmt.Sprintf(e.format, e.params...) 79 } 80 } 81 return e.message 82 } 83 84 func errorf(code int, format string, params ...interface{}) *Error { 85 return &Error{ 86 Code: code, 87 format: format, 88 params: params, 89 } 90 } 91 92 // Spec is a protocol specification including its name and version as well as 93 // the types of messages which are exchanged 94 type Spec struct { 95 // Name is the name of the protocol, often a three-letter word 96 Name string 97 98 // Version is the version number of the protocol 99 Version uint 100 101 // MaxMsgSize is the maximum accepted length of the message payload 102 MaxMsgSize uint32 103 104 // Messages is a list of message data types which this protocol uses, with 105 // each message type being sent with its array index as the code (so 106 // [&foo{}, &bar{}, &baz{}] would send foo, bar and baz with codes 107 // 0, 1 and 2 respectively) 108 // each message must have a single unique data type 109 Messages []interface{} 110 111 initOnce sync.Once 112 codes map[reflect.Type]uint64 113 types map[uint64]reflect.Type 114 } 115 116 func (s *Spec) init() { 117 s.initOnce.Do(func() { 118 s.codes = make(map[reflect.Type]uint64, len(s.Messages)) 119 s.types = make(map[uint64]reflect.Type, len(s.Messages)) 120 for i, msg := range s.Messages { 121 code := uint64(i) 122 typ := reflect.TypeOf(msg) 123 if typ.Kind() == reflect.Ptr { 124 typ = typ.Elem() 125 } 126 s.codes[typ] = code 127 s.types[code] = typ 128 } 129 }) 130 } 131 132 // Length returns the number of message types in the protocol 133 func (s *Spec) Length() uint64 { 134 return uint64(len(s.Messages)) 135 } 136 137 // GetCode returns the message code of a type, and boolean second argument is 138 // false if the message type is not found 139 func (s *Spec) GetCode(msg interface{}) (uint64, bool) { 140 s.init() 141 typ := reflect.TypeOf(msg) 142 if typ.Kind() == reflect.Ptr { 143 typ = typ.Elem() 144 } 145 code, ok := s.codes[typ] 146 return code, ok 147 } 148 149 // NewMsg construct a new message type given the code 150 func (s *Spec) NewMsg(code uint64) (interface{}, bool) { 151 s.init() 152 typ, ok := s.types[code] 153 if !ok { 154 return nil, false 155 } 156 return reflect.New(typ).Interface(), true 157 } 158 159 // Peer represents a remote peer or protocol instance that is running on a peer connection with 160 // a remote peer 161 type Peer struct { 162 *p2p.Peer // the p2p.Peer object representing the remote 163 rw p2p.MsgReadWriter // p2p.MsgReadWriter to send messages to and read messages from 164 spec *Spec 165 } 166 167 // NewPeer constructs a new peer 168 // this constructor is called by the p2p.Protocol#Run function 169 // the first two arguments are the arguments passed to p2p.Protocol.Run function 170 // the third argument is the Spec describing the protocol 171 func NewPeer(p *p2p.Peer, rw p2p.MsgReadWriter, spec *Spec) *Peer { 172 return &Peer{ 173 Peer: p, 174 rw: rw, 175 spec: spec, 176 } 177 } 178 179 // Run starts the forever loop that handles incoming messages 180 // called within the p2p.Protocol#Run function 181 // the handler argument is a function which is called for each message received 182 // from the remote peer, a returned error causes the loop to exit 183 // resulting in disconnection 184 func (p *Peer) Run(handler func(msg interface{}) error) error { 185 for { 186 if err := p.handleIncoming(handler); err != nil { 187 return err 188 } 189 } 190 } 191 192 // Drop disconnects a peer. 193 // TODO: may need to implement protocol drop only? don't want to kick off the peer 194 // if they are useful for other protocols 195 func (p *Peer) Drop(err error) { 196 p.Disconnect(p2p.DiscSubprotocolError) 197 } 198 199 // Send takes a message, encodes it in RLP, finds the right message code and sends the 200 // message off to the peer 201 // this low level call will be wrapped by libraries providing routed or broadcast sends 202 // but often just used to forward and push messages to directly connected peers 203 func (p *Peer) Send(msg interface{}) error { 204 code, found := p.spec.GetCode(msg) 205 if !found { 206 return errorf(ErrInvalidMsgType, "%v", code) 207 } 208 return p2p.Send(p.rw, code, msg) 209 } 210 211 // handleIncoming(code) 212 // is called each cycle of the main forever loop that dispatches incoming messages 213 // if this returns an error the loop returns and the peer is disconnected with the error 214 // this generic handler 215 // * checks message size, 216 // * checks for out-of-range message codes, 217 // * handles decoding with reflection, 218 // * call handlers as callbacks 219 func (p *Peer) handleIncoming(handle func(msg interface{}) error) error { 220 msg, err := p.rw.ReadMsg() 221 if err != nil { 222 return err 223 } 224 // make sure that the payload has been fully consumed 225 defer msg.Discard() 226 227 if msg.Size > p.spec.MaxMsgSize { 228 return errorf(ErrMsgTooLong, "%v > %v", msg.Size, p.spec.MaxMsgSize) 229 } 230 231 val, ok := p.spec.NewMsg(msg.Code) 232 if !ok { 233 return errorf(ErrInvalidMsgCode, "%v", msg.Code) 234 } 235 if err := msg.Decode(val); err != nil { 236 return errorf(ErrDecode, "<= %v: %v", msg, err) 237 } 238 239 // call the registered handler callbacks 240 // a registered callback take the decoded message as argument as an interface 241 // which the handler is supposed to cast to the appropriate type 242 // it is entirely safe not to check the cast in the handler since the handler is 243 // chosen based on the proper type in the first place 244 if err := handle(val); err != nil { 245 return errorf(ErrHandler, "(msg code %v): %v", msg.Code, err) 246 } 247 return nil 248 } 249 250 // Handshake negotiates a handshake on the peer connection 251 // * arguments 252 // * context 253 // * the local handshake to be sent to the remote peer 254 // * funcion to be called on the remote handshake (can be nil) 255 // * expects a remote handshake back of the same type 256 // * the dialing peer needs to send the handshake first and then waits for remote 257 // * the listening peer waits for the remote handshake and then sends it 258 // returns the remote handshake and an error 259 func (p *Peer) Handshake(ctx context.Context, hs interface{}, verify func(interface{}) error) (rhs interface{}, err error) { 260 if _, ok := p.spec.GetCode(hs); !ok { 261 return nil, errorf(ErrHandshake, "unknown handshake message type: %T", hs) 262 } 263 errc := make(chan error, 2) 264 handle := func(msg interface{}) error { 265 rhs = msg 266 if verify != nil { 267 return verify(rhs) 268 } 269 return nil 270 } 271 send := func() { errc <- p.Send(hs) } 272 receive := func() { errc <- p.handleIncoming(handle) } 273 274 go func() { 275 if p.Inbound() { 276 receive() 277 send() 278 } else { 279 send() 280 receive() 281 } 282 }() 283 284 for i := 0; i < 2; i++ { 285 select { 286 case err = <-errc: 287 case <-ctx.Done(): 288 err = ctx.Err() 289 } 290 if err != nil { 291 return nil, errorf(ErrHandshake, err.Error()) 292 } 293 } 294 return rhs, nil 295 }