github.com/aergoio/aergo@v1.3.1/p2p/handshakev2.go (about) 1 /* 2 * @file 3 * @copyright defined in aergo/LICENSE.txt 4 */ 5 6 package p2p 7 8 import ( 9 "context" 10 "encoding/binary" 11 "fmt" 12 "github.com/aergoio/aergo-lib/log" 13 "github.com/aergoio/aergo/p2p/p2pcommon" 14 "github.com/aergoio/aergo/p2p/p2putil" 15 "github.com/aergoio/aergo/types" 16 "io" 17 "time" 18 ) 19 20 // AcceptedInboundVersions is list of versions this aergosvr supports. The first is the best recommended version. 21 var AcceptedInboundVersions = []p2pcommon.P2PVersion{p2pcommon.P2PVersion032, p2pcommon.P2PVersion031, p2pcommon.P2PVersion030} 22 var AttemptingOutboundVersions = []p2pcommon.P2PVersion{p2pcommon.P2PVersion032, p2pcommon.P2PVersion031} 23 24 // baseWireHandshaker works to handshake to just connected peer, it detect chain networks 25 // and protocol versions, and then select InnerHandshaker for that protocol version. 26 type baseWireHandshaker struct { 27 pm p2pcommon.PeerManager 28 actor p2pcommon.ActorService 29 verM p2pcommon.VersionedManager 30 logger *log.Logger 31 peerID types.PeerID 32 // check if is it ad hoc 33 localChainID *types.ChainID 34 35 remoteStatus *types.Status 36 } 37 38 type InboundWireHandshaker struct { 39 baseWireHandshaker 40 } 41 42 func NewInboundHSHandler(pm p2pcommon.PeerManager, actor p2pcommon.ActorService, verManager p2pcommon.VersionedManager, log *log.Logger, chainID *types.ChainID, peerID types.PeerID) p2pcommon.HSHandler { 43 return &InboundWireHandshaker{baseWireHandshaker{pm: pm, actor: actor, verM:verManager, logger: log, localChainID: chainID, peerID: peerID}} 44 } 45 46 func (h *InboundWireHandshaker) Handle(s io.ReadWriteCloser, ttl time.Duration) (p2pcommon.MsgReadWriter, *types.Status, error) { 47 ctx, cancel := context.WithTimeout(context.Background(), ttl) 48 defer cancel() 49 return h.handleInboundPeer(ctx, s) 50 } 51 52 func (h *InboundWireHandshaker) handleInboundPeer(ctx context.Context, rwc io.ReadWriteCloser) (p2pcommon.MsgReadWriter, *types.Status, error) { 53 // wait initial hs message 54 hsReq, err := h.readWireHSRequest(rwc) 55 select { 56 case <-ctx.Done(): 57 return nil, nil, ctx.Err() 58 default: 59 // go on 60 } 61 if err != nil { 62 return h.writeErrAndReturn(err, p2pcommon.HSCodeWrongHSReq, rwc) 63 } 64 // check magic 65 if hsReq.Magic != p2pcommon.MAGICMain { 66 return h.writeErrAndReturn(fmt.Errorf("wrong magic %v",hsReq.Magic), p2pcommon.HSCodeWrongHSReq, rwc) 67 } 68 69 // continue to handshake with VersionedHandshaker 70 bestVer := h.verM.FindBestP2PVersion(hsReq.Versions) 71 if bestVer == p2pcommon.P2PVersionUnknown { 72 return h.writeErrAndReturn(fmt.Errorf("no matchied p2p version for %v", hsReq.Versions), p2pcommon.HSCodeNoMatchedVersion,rwc) 73 } else { 74 h.logger.Debug().Str(p2putil.LogPeerID, p2putil.ShortForm(h.peerID)).Str("version",bestVer.String()).Msg("Responding best p2p version") 75 resp := p2pcommon.HSHeadResp{hsReq.Magic, bestVer.Uint32()} 76 err = h.writeWireHSResponse(resp, rwc) 77 select { 78 case <-ctx.Done(): 79 return nil, nil, ctx.Err() 80 default: 81 // go on 82 } 83 if err != nil { 84 return nil, nil, err 85 } 86 } 87 innerHS, err := h.verM.GetVersionedHandshaker(bestVer, h.peerID, rwc) 88 if err != nil { 89 return nil, nil, err 90 } 91 status, err := innerHS.DoForInbound(ctx) 92 // send hs response 93 h.remoteStatus = status 94 return innerHS.GetMsgRW(), status, err 95 } 96 97 type OutboundWireHandshaker struct { 98 baseWireHandshaker 99 } 100 101 func NewOutboundHSHandler(pm p2pcommon.PeerManager, actor p2pcommon.ActorService, verManager p2pcommon.VersionedManager, log *log.Logger, chainID *types.ChainID, peerID types.PeerID) p2pcommon.HSHandler { 102 return &OutboundWireHandshaker{baseWireHandshaker{pm: pm, actor: actor, verM:verManager, logger: log, localChainID: chainID, peerID: peerID}} 103 } 104 105 func (h *OutboundWireHandshaker) Handle(s io.ReadWriteCloser, ttl time.Duration) (p2pcommon.MsgReadWriter, *types.Status, error) { 106 ctx, cancel := context.WithTimeout(context.Background(), ttl) 107 defer cancel() 108 return h.handleOutboundPeer(ctx, s) 109 } 110 111 func (h *OutboundWireHandshaker) handleOutboundPeer(ctx context.Context, rwc io.ReadWriteCloser) (p2pcommon.MsgReadWriter, *types.Status, error) { 112 // send initial hs message 113 versions := AttemptingOutboundVersions 114 115 hsHeader := p2pcommon.HSHeadReq{Magic: p2pcommon.MAGICMain, Versions: versions} 116 err := h.writeWireHSRequest(hsHeader, rwc) 117 select { 118 case <-ctx.Done(): 119 return nil, nil, ctx.Err() 120 default: 121 // go on 122 } 123 if err != nil { 124 return nil, nil, err 125 } 126 127 // read response 128 respHeader, err := h.readWireHSResp(rwc) 129 select { 130 case <-ctx.Done(): 131 return nil, nil, ctx.Err() 132 default: 133 // go on 134 } 135 if err != nil { 136 return nil, nil, err 137 } 138 // check response 139 if respHeader.Magic != hsHeader.Magic { 140 return nil, nil, fmt.Errorf("remote peer failed: %v", respHeader.RespCode) 141 } 142 bestVersion := p2pcommon.P2PVersion(respHeader.RespCode) 143 h.logger.Debug().Str(p2putil.LogPeerID, p2putil.ShortForm(h.peerID)).Str("version",bestVersion.String()).Msg("Responded best p2p version") 144 // continue to handshake with VersionedHandshaker 145 innerHS, err := h.verM.GetVersionedHandshaker(bestVersion, h.peerID, rwc) 146 if err != nil { 147 return nil, nil, err 148 } 149 status, err := innerHS.DoForOutbound(ctx) 150 h.remoteStatus = status 151 return innerHS.GetMsgRW(), status, err 152 } 153 154 func (h *baseWireHandshaker) writeWireHSRequest(hsHeader p2pcommon.HSHeadReq, wr io.Writer) (err error) { 155 bytes := hsHeader.Marshal() 156 sent, err := wr.Write(bytes) 157 if err != nil { 158 return 159 } 160 if sent != len(bytes) { 161 return fmt.Errorf("wrong sent size") 162 } 163 return 164 } 165 166 func (h *baseWireHandshaker) readWireHSRequest(rd io.Reader) (header p2pcommon.HSHeadReq, err error) { 167 buf := make([]byte, p2pcommon.HSMagicLength) 168 readn, err := p2putil.ReadToLen(rd, buf[:p2pcommon.HSMagicLength]) 169 if err != nil { 170 return 171 } 172 if readn != p2pcommon.HSMagicLength { 173 err = fmt.Errorf("transport error") 174 return 175 } 176 header.Magic = binary.BigEndian.Uint32(buf) 177 readn, err = p2putil.ReadToLen(rd, buf[:p2pcommon.HSVerCntLength]) 178 if err != nil { 179 return 180 } 181 if readn != p2pcommon.HSVerCntLength { 182 err = fmt.Errorf("transport error") 183 return 184 } 185 verCount := int(binary.BigEndian.Uint32(buf)) 186 if verCount <= 0 || verCount > p2pcommon.HSMaxVersionCnt { 187 err = fmt.Errorf("invalid version count: %d", verCount) 188 return 189 } 190 versions := make([]p2pcommon.P2PVersion, verCount) 191 for i := 0; i < verCount; i++ { 192 readn, err = p2putil.ReadToLen(rd, buf[:p2pcommon.HSVersionLength]) 193 if err != nil { 194 return 195 } 196 if readn != p2pcommon.HSVersionLength { 197 err = fmt.Errorf("transport error") 198 return 199 } 200 versions[i] = p2pcommon.P2PVersion(binary.BigEndian.Uint32(buf)) 201 } 202 header.Versions = versions 203 return 204 } 205 206 func (h *baseWireHandshaker) writeWireHSResponse(hsHeader p2pcommon.HSHeadResp, wr io.Writer) (err error) { 207 bytes := hsHeader.Marshal() 208 sent, err := wr.Write(bytes) 209 if err != nil { 210 return 211 } 212 if sent != len(bytes) { 213 return fmt.Errorf("wrong sent size") 214 } 215 return 216 } 217 218 func (h *baseWireHandshaker) writeErrAndReturn(err error, errCode uint32, wr io.Writer) (p2pcommon.MsgReadWriter, *types.Status, error) { 219 errResp := p2pcommon.HSHeadResp{p2pcommon.HSError, errCode} 220 _ = h.writeWireHSResponse(errResp, wr) 221 return nil, nil, err 222 } 223 func (h *baseWireHandshaker) readWireHSResp(rd io.Reader) (header p2pcommon.HSHeadResp, err error) { 224 bytebuf := make([]byte, p2pcommon.HSMagicLength) 225 readn, err := p2putil.ReadToLen(rd, bytebuf[:p2pcommon.HSMagicLength]) 226 if err != nil { 227 return 228 } 229 if readn != p2pcommon.HSMagicLength { 230 err = fmt.Errorf("transport error") 231 return 232 } 233 header.Magic = binary.BigEndian.Uint32(bytebuf) 234 readn, err = p2putil.ReadToLen(rd, bytebuf[:p2pcommon.HSVersionLength]) 235 if err != nil { 236 return 237 } 238 if readn != p2pcommon.HSVersionLength { 239 err = fmt.Errorf("transport error") 240 return 241 } 242 header.RespCode = binary.BigEndian.Uint32(bytebuf) 243 return 244 }