github.com/neatio-net/neatio@v1.7.3-0.20231114194659-f4d7a2226baa/network/p2p/message.go (about) 1 package p2p 2 3 import ( 4 "bytes" 5 "errors" 6 "fmt" 7 "io" 8 "io/ioutil" 9 "sync/atomic" 10 "time" 11 12 "github.com/neatio-net/neatio/network/p2p/discover" 13 "github.com/neatio-net/neatio/utilities/event" 14 "github.com/neatio-net/neatio/utilities/rlp" 15 ) 16 17 type Msg struct { 18 Code uint64 19 Size uint32 20 Payload io.Reader 21 ReceivedAt time.Time 22 } 23 24 func (msg Msg) Decode(val interface{}) error { 25 s := rlp.NewStream(msg.Payload, uint64(msg.Size)) 26 if err := s.Decode(val); err != nil { 27 return newPeerError(errInvalidMsg, "(code %x) (size %d) %v", msg.Code, msg.Size, err) 28 } 29 return nil 30 } 31 32 func (msg Msg) String() string { 33 return fmt.Sprintf("msg #%v (%v bytes)", msg.Code, msg.Size) 34 } 35 36 func (msg Msg) Discard() error { 37 _, err := io.Copy(ioutil.Discard, msg.Payload) 38 return err 39 } 40 41 type MsgReader interface { 42 ReadMsg() (Msg, error) 43 } 44 45 type MsgWriter interface { 46 WriteMsg(Msg) error 47 } 48 49 type MsgReadWriter interface { 50 MsgReader 51 MsgWriter 52 } 53 54 func Send(w MsgWriter, msgcode uint64, data interface{}) error { 55 size, r, err := rlp.EncodeToReader(data) 56 if err != nil { 57 return err 58 } 59 return w.WriteMsg(Msg{Code: msgcode, Size: uint32(size), Payload: r}) 60 } 61 62 func SendItems(w MsgWriter, msgcode uint64, elems ...interface{}) error { 63 return Send(w, msgcode, elems) 64 } 65 66 type eofSignal struct { 67 wrapped io.Reader 68 count uint32 69 eof chan<- struct{} 70 } 71 72 func (r *eofSignal) Read(buf []byte) (int, error) { 73 if r.count == 0 { 74 if r.eof != nil { 75 r.eof <- struct{}{} 76 r.eof = nil 77 } 78 return 0, io.EOF 79 } 80 81 max := len(buf) 82 if int(r.count) < len(buf) { 83 max = int(r.count) 84 } 85 n, err := r.wrapped.Read(buf[:max]) 86 r.count -= uint32(n) 87 if (err != nil || r.count == 0) && r.eof != nil { 88 r.eof <- struct{}{} 89 r.eof = nil 90 } 91 return n, err 92 } 93 94 func MsgPipe() (*MsgPipeRW, *MsgPipeRW) { 95 var ( 96 c1, c2 = make(chan Msg), make(chan Msg) 97 closing = make(chan struct{}) 98 closed = new(int32) 99 rw1 = &MsgPipeRW{c1, c2, closing, closed} 100 rw2 = &MsgPipeRW{c2, c1, closing, closed} 101 ) 102 return rw1, rw2 103 } 104 105 var ErrPipeClosed = errors.New("p2p: read or write on closed message pipe") 106 107 type MsgPipeRW struct { 108 w chan<- Msg 109 r <-chan Msg 110 closing chan struct{} 111 closed *int32 112 } 113 114 func (p *MsgPipeRW) WriteMsg(msg Msg) error { 115 if atomic.LoadInt32(p.closed) == 0 { 116 consumed := make(chan struct{}, 1) 117 msg.Payload = &eofSignal{msg.Payload, msg.Size, consumed} 118 select { 119 case p.w <- msg: 120 if msg.Size > 0 { 121 122 select { 123 case <-consumed: 124 case <-p.closing: 125 } 126 } 127 return nil 128 case <-p.closing: 129 } 130 } 131 return ErrPipeClosed 132 } 133 134 func (p *MsgPipeRW) ReadMsg() (Msg, error) { 135 if atomic.LoadInt32(p.closed) == 0 { 136 select { 137 case msg := <-p.r: 138 return msg, nil 139 case <-p.closing: 140 } 141 } 142 return Msg{}, ErrPipeClosed 143 } 144 145 func (p *MsgPipeRW) Close() error { 146 if atomic.AddInt32(p.closed, 1) != 1 { 147 148 atomic.StoreInt32(p.closed, 1) 149 return nil 150 } 151 close(p.closing) 152 return nil 153 } 154 155 func ExpectMsg(r MsgReader, code uint64, content interface{}) error { 156 msg, err := r.ReadMsg() 157 if err != nil { 158 return err 159 } 160 if msg.Code != code { 161 return fmt.Errorf("message code mismatch: got %d, expected %d", msg.Code, code) 162 } 163 if content == nil { 164 return msg.Discard() 165 } else { 166 contentEnc, err := rlp.EncodeToBytes(content) 167 if err != nil { 168 panic("content encode error: " + err.Error()) 169 } 170 if int(msg.Size) != len(contentEnc) { 171 return fmt.Errorf("message size mismatch: got %d, want %d", msg.Size, len(contentEnc)) 172 } 173 actualContent, err := ioutil.ReadAll(msg.Payload) 174 if err != nil { 175 return err 176 } 177 if !bytes.Equal(actualContent, contentEnc) { 178 return fmt.Errorf("message payload mismatch:\ngot: %x\nwant: %x", actualContent, contentEnc) 179 } 180 } 181 return nil 182 } 183 184 type msgEventer struct { 185 MsgReadWriter 186 187 feed *event.Feed 188 peerID discover.NodeID 189 Protocol string 190 } 191 192 func newMsgEventer(rw MsgReadWriter, feed *event.Feed, peerID discover.NodeID, proto string) *msgEventer { 193 return &msgEventer{ 194 MsgReadWriter: rw, 195 feed: feed, 196 peerID: peerID, 197 Protocol: proto, 198 } 199 } 200 201 func (self *msgEventer) ReadMsg() (Msg, error) { 202 msg, err := self.MsgReadWriter.ReadMsg() 203 if err != nil { 204 return msg, err 205 } 206 self.feed.Send(&PeerEvent{ 207 Type: PeerEventTypeMsgRecv, 208 Peer: self.peerID, 209 Protocol: self.Protocol, 210 MsgCode: &msg.Code, 211 MsgSize: &msg.Size, 212 }) 213 return msg, nil 214 } 215 216 func (self *msgEventer) WriteMsg(msg Msg) error { 217 err := self.MsgReadWriter.WriteMsg(msg) 218 if err != nil { 219 return err 220 } 221 self.feed.Send(&PeerEvent{ 222 Type: PeerEventTypeMsgSend, 223 Peer: self.peerID, 224 Protocol: self.Protocol, 225 MsgCode: &msg.Code, 226 MsgSize: &msg.Size, 227 }) 228 return nil 229 } 230 231 func (self *msgEventer) Close() error { 232 if v, ok := self.MsgReadWriter.(io.Closer); ok { 233 return v.Close() 234 } 235 return nil 236 }