github.com/ethereum/go-ethereum@v1.16.1/cmd/devp2p/internal/ethtest/conn.go (about) 1 // Copyright 2023 The go-ethereum Authors 2 // This file is part of go-ethereum. 3 // 4 // go-ethereum is free software: you can redistribute it and/or modify 5 // it under the terms of the GNU General Public License as published by 6 // the Free Software Foundation, either version 3 of the License, or 7 // (at your option) any later version. 8 // 9 // go-ethereum is distributed in the hope that it will be useful, 10 // but WITHOUT ANY WARRANTY; without even the implied warranty of 11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 // GNU General Public License for more details. 13 // 14 // You should have received a copy of the GNU General Public License 15 // along with go-ethereum. If not, see <http://www.gnu.org/licenses/>. 16 17 package ethtest 18 19 import ( 20 "crypto/ecdsa" 21 "errors" 22 "fmt" 23 "net" 24 "reflect" 25 "time" 26 27 "github.com/davecgh/go-spew/spew" 28 "github.com/ethereum/go-ethereum/crypto" 29 "github.com/ethereum/go-ethereum/eth/protocols/eth" 30 "github.com/ethereum/go-ethereum/eth/protocols/snap" 31 "github.com/ethereum/go-ethereum/p2p" 32 "github.com/ethereum/go-ethereum/p2p/rlpx" 33 "github.com/ethereum/go-ethereum/rlp" 34 ) 35 36 var ( 37 pretty = spew.ConfigState{ 38 Indent: " ", 39 DisableCapacities: true, 40 DisablePointerAddresses: true, 41 SortKeys: true, 42 } 43 timeout = 2 * time.Second 44 ) 45 46 // dial attempts to dial the given node and perform a handshake, returning the 47 // created Conn if successful. 48 func (s *Suite) dial() (*Conn, error) { 49 key, _ := crypto.GenerateKey() 50 return s.dialAs(key) 51 } 52 53 // dialAs attempts to dial a given node and perform a handshake using the given 54 // private key. 55 func (s *Suite) dialAs(key *ecdsa.PrivateKey) (*Conn, error) { 56 tcpEndpoint, _ := s.Dest.TCPEndpoint() 57 fd, err := net.Dial("tcp", tcpEndpoint.String()) 58 if err != nil { 59 return nil, err 60 } 61 conn := Conn{Conn: rlpx.NewConn(fd, s.Dest.Pubkey())} 62 conn.ourKey = key 63 _, err = conn.Handshake(conn.ourKey) 64 if err != nil { 65 conn.Close() 66 return nil, err 67 } 68 conn.caps = []p2p.Cap{ 69 {Name: "eth", Version: 69}, 70 } 71 conn.ourHighestProtoVersion = 69 72 return &conn, nil 73 } 74 75 // dialSnap creates a connection with snap/1 capability. 76 func (s *Suite) dialSnap() (*Conn, error) { 77 conn, err := s.dial() 78 if err != nil { 79 return nil, fmt.Errorf("dial failed: %v", err) 80 } 81 conn.caps = append(conn.caps, p2p.Cap{Name: "snap", Version: 1}) 82 conn.ourHighestSnapProtoVersion = 1 83 return conn, nil 84 } 85 86 // Conn represents an individual connection with a peer 87 type Conn struct { 88 *rlpx.Conn 89 ourKey *ecdsa.PrivateKey 90 negotiatedProtoVersion uint 91 negotiatedSnapProtoVersion uint 92 ourHighestProtoVersion uint 93 ourHighestSnapProtoVersion uint 94 caps []p2p.Cap 95 } 96 97 // Read reads a packet from the connection. 98 func (c *Conn) Read() (uint64, []byte, error) { 99 c.SetReadDeadline(time.Now().Add(timeout)) 100 code, data, _, err := c.Conn.Read() 101 if err != nil { 102 return 0, nil, err 103 } 104 return code, data, nil 105 } 106 107 // ReadMsg attempts to read a devp2p message with a specific code. 108 func (c *Conn) ReadMsg(proto Proto, code uint64, msg any) error { 109 c.SetReadDeadline(time.Now().Add(timeout)) 110 for { 111 got, data, err := c.Read() 112 if err != nil { 113 return err 114 } 115 if protoOffset(proto)+code == got { 116 return rlp.DecodeBytes(data, msg) 117 } 118 } 119 } 120 121 // Write writes a eth packet to the connection. 122 func (c *Conn) Write(proto Proto, code uint64, msg any) error { 123 c.SetWriteDeadline(time.Now().Add(timeout)) 124 payload, err := rlp.EncodeToBytes(msg) 125 if err != nil { 126 return err 127 } 128 _, err = c.Conn.Write(protoOffset(proto)+code, payload) 129 return err 130 } 131 132 var errDisc error = fmt.Errorf("disconnect") 133 134 // ReadEth reads an Eth sub-protocol wire message. 135 func (c *Conn) ReadEth() (any, error) { 136 c.SetReadDeadline(time.Now().Add(timeout)) 137 for { 138 code, data, _, err := c.Conn.Read() 139 if code == discMsg { 140 return nil, errDisc 141 } 142 if err != nil { 143 return nil, err 144 } 145 if code == pingMsg { 146 c.Write(baseProto, pongMsg, []byte{}) 147 continue 148 } 149 if getProto(code) != ethProto { 150 // Read until eth message. 151 continue 152 } 153 code -= baseProtoLen 154 155 var msg any 156 switch int(code) { 157 case eth.StatusMsg: 158 msg = new(eth.StatusPacket69) 159 case eth.GetBlockHeadersMsg: 160 msg = new(eth.GetBlockHeadersPacket) 161 case eth.BlockHeadersMsg: 162 msg = new(eth.BlockHeadersPacket) 163 case eth.GetBlockBodiesMsg: 164 msg = new(eth.GetBlockBodiesPacket) 165 case eth.BlockBodiesMsg: 166 msg = new(eth.BlockBodiesPacket) 167 case eth.NewBlockMsg: 168 msg = new(eth.NewBlockPacket) 169 case eth.NewBlockHashesMsg: 170 msg = new(eth.NewBlockHashesPacket) 171 case eth.TransactionsMsg: 172 msg = new(eth.TransactionsPacket) 173 case eth.NewPooledTransactionHashesMsg: 174 msg = new(eth.NewPooledTransactionHashesPacket) 175 case eth.GetPooledTransactionsMsg: 176 msg = new(eth.GetPooledTransactionsPacket) 177 case eth.PooledTransactionsMsg: 178 msg = new(eth.PooledTransactionsPacket) 179 default: 180 panic(fmt.Sprintf("unhandled eth msg code %d", code)) 181 } 182 if err := rlp.DecodeBytes(data, msg); err != nil { 183 return nil, fmt.Errorf("unable to decode eth msg: %v", err) 184 } 185 return msg, nil 186 } 187 } 188 189 // ReadSnap reads a snap/1 response with the given id from the connection. 190 func (c *Conn) ReadSnap() (any, error) { 191 c.SetReadDeadline(time.Now().Add(timeout)) 192 for { 193 code, data, _, err := c.Conn.Read() 194 if err != nil { 195 return nil, err 196 } 197 if getProto(code) != snapProto { 198 // Read until snap message. 199 continue 200 } 201 code -= baseProtoLen + ethProtoLen 202 203 var msg any 204 switch int(code) { 205 case snap.GetAccountRangeMsg: 206 msg = new(snap.GetAccountRangePacket) 207 case snap.AccountRangeMsg: 208 msg = new(snap.AccountRangePacket) 209 case snap.GetStorageRangesMsg: 210 msg = new(snap.GetStorageRangesPacket) 211 case snap.StorageRangesMsg: 212 msg = new(snap.StorageRangesPacket) 213 case snap.GetByteCodesMsg: 214 msg = new(snap.GetByteCodesPacket) 215 case snap.ByteCodesMsg: 216 msg = new(snap.ByteCodesPacket) 217 case snap.GetTrieNodesMsg: 218 msg = new(snap.GetTrieNodesPacket) 219 case snap.TrieNodesMsg: 220 msg = new(snap.TrieNodesPacket) 221 default: 222 panic(fmt.Errorf("unhandled snap code: %d", code)) 223 } 224 if err := rlp.DecodeBytes(data, msg); err != nil { 225 return nil, fmt.Errorf("could not rlp decode message: %v", err) 226 } 227 return msg, nil 228 } 229 } 230 231 // dialAndPeer creates a peer connection and runs the handshake. 232 func (s *Suite) dialAndPeer(status *eth.StatusPacket69) (*Conn, error) { 233 c, err := s.dial() 234 if err != nil { 235 return nil, err 236 } 237 if err = c.peer(s.chain, status); err != nil { 238 c.Close() 239 } 240 return c, err 241 } 242 243 // peer performs both the protocol handshake and the status message 244 // exchange with the node in order to peer with it. 245 func (c *Conn) peer(chain *Chain, status *eth.StatusPacket69) error { 246 if err := c.handshake(); err != nil { 247 return fmt.Errorf("handshake failed: %v", err) 248 } 249 if err := c.statusExchange(chain, status); err != nil { 250 return fmt.Errorf("status exchange failed: %v", err) 251 } 252 return nil 253 } 254 255 // handshake performs a protocol handshake with the node. 256 func (c *Conn) handshake() error { 257 // Write hello to client. 258 pub0 := crypto.FromECDSAPub(&c.ourKey.PublicKey)[1:] 259 ourHandshake := &protoHandshake{ 260 Version: 5, 261 Caps: c.caps, 262 ID: pub0, 263 } 264 if err := c.Write(baseProto, handshakeMsg, ourHandshake); err != nil { 265 return fmt.Errorf("write to connection failed: %v", err) 266 } 267 // Read hello from client. 268 code, data, err := c.Read() 269 if err != nil { 270 return fmt.Errorf("erroring reading handshake: %v", err) 271 } 272 switch code { 273 case handshakeMsg: 274 msg := new(protoHandshake) 275 if err := rlp.DecodeBytes(data, &msg); err != nil { 276 return fmt.Errorf("error decoding handshake msg: %v", err) 277 } 278 // Set snappy if version is at least 5. 279 if msg.Version >= 5 { 280 c.SetSnappy(true) 281 } 282 c.negotiateEthProtocol(msg.Caps) 283 if c.negotiatedProtoVersion == 0 { 284 return fmt.Errorf("could not negotiate eth protocol (remote caps: %v, local eth version: %v)", msg.Caps, c.ourHighestProtoVersion) 285 } 286 // If we require snap, verify that it was negotiated. 287 if c.ourHighestSnapProtoVersion != c.negotiatedSnapProtoVersion { 288 return fmt.Errorf("could not negotiate snap protocol (remote caps: %v, local snap version: %v)", msg.Caps, c.ourHighestSnapProtoVersion) 289 } 290 return nil 291 default: 292 return fmt.Errorf("bad handshake: got msg code %d", code) 293 } 294 } 295 296 // negotiateEthProtocol sets the Conn's eth protocol version to highest 297 // advertised capability from peer. 298 func (c *Conn) negotiateEthProtocol(caps []p2p.Cap) { 299 var highestEthVersion uint 300 var highestSnapVersion uint 301 for _, capability := range caps { 302 switch capability.Name { 303 case "eth": 304 if capability.Version > highestEthVersion && capability.Version <= c.ourHighestProtoVersion { 305 highestEthVersion = capability.Version 306 } 307 case "snap": 308 if capability.Version > highestSnapVersion && capability.Version <= c.ourHighestSnapProtoVersion { 309 highestSnapVersion = capability.Version 310 } 311 } 312 } 313 c.negotiatedProtoVersion = highestEthVersion 314 c.negotiatedSnapProtoVersion = highestSnapVersion 315 } 316 317 // statusExchange performs a `Status` message exchange with the given node. 318 func (c *Conn) statusExchange(chain *Chain, status *eth.StatusPacket69) error { 319 loop: 320 for { 321 code, data, err := c.Read() 322 if err != nil { 323 return fmt.Errorf("failed to read from connection: %w", err) 324 } 325 switch code { 326 case eth.StatusMsg + protoOffset(ethProto): 327 msg := new(eth.StatusPacket69) 328 if err := rlp.DecodeBytes(data, &msg); err != nil { 329 return fmt.Errorf("error decoding status packet: %w", err) 330 } 331 if have, want := msg.LatestBlock, chain.blocks[chain.Len()-1].NumberU64(); have != want { 332 return fmt.Errorf("wrong head block in status, want: %d, have %d", 333 want, have) 334 } 335 if have, want := msg.LatestBlockHash, chain.blocks[chain.Len()-1].Hash(); have != want { 336 return fmt.Errorf("wrong head block in status, want: %#x (block %d) have %#x", 337 want, chain.blocks[chain.Len()-1].NumberU64(), have) 338 } 339 if have, want := msg.ForkID, chain.ForkID(); !reflect.DeepEqual(have, want) { 340 return fmt.Errorf("wrong fork ID in status: have %v, want %v", have, want) 341 } 342 if have, want := msg.ProtocolVersion, c.ourHighestProtoVersion; have != uint32(want) { 343 return fmt.Errorf("wrong protocol version: have %v, want %v", have, want) 344 } 345 break loop 346 case discMsg: 347 var msg []p2p.DiscReason 348 if rlp.DecodeBytes(data, &msg); len(msg) == 0 { 349 return errors.New("invalid disconnect message") 350 } 351 return fmt.Errorf("disconnect received: %v", pretty.Sdump(msg)) 352 case pingMsg: 353 // TODO (renaynay): in the future, this should be an error 354 // (PINGs should not be a response upon fresh connection) 355 c.Write(baseProto, pongMsg, nil) 356 default: 357 return fmt.Errorf("bad status message: code %d", code) 358 } 359 } 360 // make sure eth protocol version is set for negotiation 361 if c.negotiatedProtoVersion == 0 { 362 return errors.New("eth protocol version must be set in Conn") 363 } 364 if status == nil { 365 // default status message 366 status = ð.StatusPacket69{ 367 ProtocolVersion: uint32(c.negotiatedProtoVersion), 368 NetworkID: chain.config.ChainID.Uint64(), 369 Genesis: chain.blocks[0].Hash(), 370 ForkID: chain.ForkID(), 371 EarliestBlock: 0, 372 LatestBlock: chain.blocks[chain.Len()-1].NumberU64(), 373 LatestBlockHash: chain.blocks[chain.Len()-1].Hash(), 374 } 375 } 376 if err := c.Write(ethProto, eth.StatusMsg, status); err != nil { 377 return fmt.Errorf("write to connection failed: %v", err) 378 } 379 return nil 380 }