github.com/neatio-net/neatio@v1.7.3-0.20231114194659-f4d7a2226baa/network/p2p/protocols/protocol.go (about) 1 package protocols 2 3 import ( 4 "context" 5 "fmt" 6 "reflect" 7 "sync" 8 9 "github.com/neatio-net/neatio/network/p2p" 10 ) 11 12 const ( 13 ErrMsgTooLong = iota 14 ErrDecode 15 ErrWrite 16 ErrInvalidMsgCode 17 ErrInvalidMsgType 18 ErrHandshake 19 ErrNoHandler 20 ErrHandler 21 ) 22 23 var errorToString = map[int]string{ 24 ErrMsgTooLong: "Message too long", 25 ErrDecode: "Invalid message (RLP error)", 26 ErrWrite: "Error sending message", 27 ErrInvalidMsgCode: "Invalid message code", 28 ErrInvalidMsgType: "Invalid message type", 29 ErrHandshake: "Handshake error", 30 ErrNoHandler: "No handler registered error", 31 ErrHandler: "Message handler error", 32 } 33 34 type Error struct { 35 Code int 36 message string 37 format string 38 params []interface{} 39 } 40 41 func (e Error) Error() (message string) { 42 if len(e.message) == 0 { 43 name, ok := errorToString[e.Code] 44 if !ok { 45 panic("invalid message code") 46 } 47 e.message = name 48 if e.format != "" { 49 e.message += ": " + fmt.Sprintf(e.format, e.params...) 50 } 51 } 52 return e.message 53 } 54 55 func errorf(code int, format string, params ...interface{}) *Error { 56 return &Error{ 57 Code: code, 58 format: format, 59 params: params, 60 } 61 } 62 63 type Spec struct { 64 Name string 65 66 Version uint 67 68 MaxMsgSize uint32 69 70 Messages []interface{} 71 72 initOnce sync.Once 73 codes map[reflect.Type]uint64 74 types map[uint64]reflect.Type 75 } 76 77 func (s *Spec) init() { 78 s.initOnce.Do(func() { 79 s.codes = make(map[reflect.Type]uint64, len(s.Messages)) 80 s.types = make(map[uint64]reflect.Type, len(s.Messages)) 81 for i, msg := range s.Messages { 82 code := uint64(i) 83 typ := reflect.TypeOf(msg) 84 if typ.Kind() == reflect.Ptr { 85 typ = typ.Elem() 86 } 87 s.codes[typ] = code 88 s.types[code] = typ 89 } 90 }) 91 } 92 93 func (s *Spec) Length() uint64 { 94 return uint64(len(s.Messages)) 95 } 96 97 func (s *Spec) GetCode(msg interface{}) (uint64, bool) { 98 s.init() 99 typ := reflect.TypeOf(msg) 100 if typ.Kind() == reflect.Ptr { 101 typ = typ.Elem() 102 } 103 code, ok := s.codes[typ] 104 return code, ok 105 } 106 107 func (s *Spec) NewMsg(code uint64) (interface{}, bool) { 108 s.init() 109 typ, ok := s.types[code] 110 if !ok { 111 return nil, false 112 } 113 return reflect.New(typ).Interface(), true 114 } 115 116 type Peer struct { 117 *p2p.Peer 118 rw p2p.MsgReadWriter 119 spec *Spec 120 } 121 122 func NewPeer(p *p2p.Peer, rw p2p.MsgReadWriter, spec *Spec) *Peer { 123 return &Peer{ 124 Peer: p, 125 rw: rw, 126 spec: spec, 127 } 128 } 129 130 func (p *Peer) Run(handler func(msg interface{}) error) error { 131 for { 132 if err := p.handleIncoming(handler); err != nil { 133 return err 134 } 135 } 136 } 137 138 func (p *Peer) Drop(err error) { 139 p.Disconnect(p2p.DiscSubprotocolError) 140 } 141 142 func (p *Peer) Send(msg interface{}) error { 143 code, found := p.spec.GetCode(msg) 144 if !found { 145 return errorf(ErrInvalidMsgType, "%v", code) 146 } 147 return p2p.Send(p.rw, code, msg) 148 } 149 150 func (p *Peer) handleIncoming(handle func(msg interface{}) error) error { 151 msg, err := p.rw.ReadMsg() 152 if err != nil { 153 return err 154 } 155 156 defer msg.Discard() 157 158 if msg.Size > p.spec.MaxMsgSize { 159 return errorf(ErrMsgTooLong, "%v > %v", msg.Size, p.spec.MaxMsgSize) 160 } 161 162 val, ok := p.spec.NewMsg(msg.Code) 163 if !ok { 164 return errorf(ErrInvalidMsgCode, "%v", msg.Code) 165 } 166 if err := msg.Decode(val); err != nil { 167 return errorf(ErrDecode, "<= %v: %v", msg, err) 168 } 169 170 if err := handle(val); err != nil { 171 return errorf(ErrHandler, "(msg code %v): %v", msg.Code, err) 172 } 173 return nil 174 } 175 176 func (p *Peer) Handshake(ctx context.Context, hs interface{}, verify func(interface{}) error) (rhs interface{}, err error) { 177 if _, ok := p.spec.GetCode(hs); !ok { 178 return nil, errorf(ErrHandshake, "unknown handshake message type: %T", hs) 179 } 180 errc := make(chan error, 2) 181 handle := func(msg interface{}) error { 182 rhs = msg 183 if verify != nil { 184 return verify(rhs) 185 } 186 return nil 187 } 188 send := func() { errc <- p.Send(hs) } 189 receive := func() { errc <- p.handleIncoming(handle) } 190 191 go func() { 192 if p.Inbound() { 193 receive() 194 send() 195 } else { 196 send() 197 receive() 198 } 199 }() 200 201 for i := 0; i < 2; i++ { 202 select { 203 case err = <-errc: 204 case <-ctx.Done(): 205 err = ctx.Err() 206 } 207 if err != nil { 208 return nil, errorf(ErrHandshake, err.Error()) 209 } 210 } 211 return rhs, nil 212 }