github.1485827954.workers.dev/ethereum/go-ethereum@v1.14.3/cmd/devp2p/internal/ethtest/conn.go (about) 1 // Copyright 2023 The go-ethereum Authors 2 // This file is part of go-ethereum. 3 // 4 // go-ethereum is free software: you can redistribute it and/or modify 5 // it under the terms of the GNU General Public License as published by 6 // the Free Software Foundation, either version 3 of the License, or 7 // (at your option) any later version. 8 // 9 // go-ethereum is distributed in the hope that it will be useful, 10 // but WITHOUT ANY WARRANTY; without even the implied warranty of 11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 // GNU General Public License for more details. 13 // 14 // You should have received a copy of the GNU General Public License 15 // along with go-ethereum. If not, see <http://www.gnu.org/licenses/>. 16 17 package ethtest 18 19 import ( 20 "crypto/ecdsa" 21 "errors" 22 "fmt" 23 "net" 24 "reflect" 25 "time" 26 27 "github.com/davecgh/go-spew/spew" 28 "github.com/ethereum/go-ethereum/crypto" 29 "github.com/ethereum/go-ethereum/eth/protocols/eth" 30 "github.com/ethereum/go-ethereum/eth/protocols/snap" 31 "github.com/ethereum/go-ethereum/p2p" 32 "github.com/ethereum/go-ethereum/p2p/rlpx" 33 "github.com/ethereum/go-ethereum/rlp" 34 ) 35 36 var ( 37 pretty = spew.ConfigState{ 38 Indent: " ", 39 DisableCapacities: true, 40 DisablePointerAddresses: true, 41 SortKeys: true, 42 } 43 timeout = 2 * time.Second 44 ) 45 46 // dial attempts to dial the given node and perform a handshake, returning the 47 // created Conn if successful. 48 func (s *Suite) dial() (*Conn, error) { 49 key, _ := crypto.GenerateKey() 50 return s.dialAs(key) 51 } 52 53 // dialAs attempts to dial a given node and perform a handshake using the given 54 // private key. 55 func (s *Suite) dialAs(key *ecdsa.PrivateKey) (*Conn, error) { 56 fd, err := net.Dial("tcp", fmt.Sprintf("%v:%d", s.Dest.IP(), s.Dest.TCP())) 57 if err != nil { 58 return nil, err 59 } 60 conn := Conn{Conn: rlpx.NewConn(fd, s.Dest.Pubkey())} 61 conn.ourKey = key 62 _, err = conn.Handshake(conn.ourKey) 63 if err != nil { 64 conn.Close() 65 return nil, err 66 } 67 conn.caps = []p2p.Cap{ 68 {Name: "eth", Version: 67}, 69 {Name: "eth", Version: 68}, 70 } 71 conn.ourHighestProtoVersion = 68 72 return &conn, nil 73 } 74 75 // dialSnap creates a connection with snap/1 capability. 76 func (s *Suite) dialSnap() (*Conn, error) { 77 conn, err := s.dial() 78 if err != nil { 79 return nil, fmt.Errorf("dial failed: %v", err) 80 } 81 conn.caps = append(conn.caps, p2p.Cap{Name: "snap", Version: 1}) 82 conn.ourHighestSnapProtoVersion = 1 83 return conn, nil 84 } 85 86 // Conn represents an individual connection with a peer 87 type Conn struct { 88 *rlpx.Conn 89 ourKey *ecdsa.PrivateKey 90 negotiatedProtoVersion uint 91 negotiatedSnapProtoVersion uint 92 ourHighestProtoVersion uint 93 ourHighestSnapProtoVersion uint 94 caps []p2p.Cap 95 } 96 97 // Read reads a packet from the connection. 98 func (c *Conn) Read() (uint64, []byte, error) { 99 c.SetReadDeadline(time.Now().Add(timeout)) 100 code, data, _, err := c.Conn.Read() 101 if err != nil { 102 return 0, nil, err 103 } 104 return code, data, nil 105 } 106 107 // ReadMsg attempts to read a devp2p message with a specific code. 108 func (c *Conn) ReadMsg(proto Proto, code uint64, msg any) error { 109 c.SetReadDeadline(time.Now().Add(timeout)) 110 for { 111 got, data, err := c.Read() 112 if err != nil { 113 return err 114 } 115 if protoOffset(proto)+code == got { 116 return rlp.DecodeBytes(data, msg) 117 } 118 } 119 } 120 121 // Write writes a eth packet to the connection. 122 func (c *Conn) Write(proto Proto, code uint64, msg any) error { 123 c.SetWriteDeadline(time.Now().Add(timeout)) 124 payload, err := rlp.EncodeToBytes(msg) 125 if err != nil { 126 return err 127 } 128 _, err = c.Conn.Write(protoOffset(proto)+code, payload) 129 return err 130 } 131 132 // ReadEth reads an Eth sub-protocol wire message. 133 func (c *Conn) ReadEth() (any, error) { 134 c.SetReadDeadline(time.Now().Add(timeout)) 135 for { 136 code, data, _, err := c.Conn.Read() 137 if err != nil { 138 return nil, err 139 } 140 if code == pingMsg { 141 c.Write(baseProto, pongMsg, []byte{}) 142 continue 143 } 144 if getProto(code) != ethProto { 145 // Read until eth message. 146 continue 147 } 148 code -= baseProtoLen 149 150 var msg any 151 switch int(code) { 152 case eth.StatusMsg: 153 msg = new(eth.StatusPacket) 154 case eth.GetBlockHeadersMsg: 155 msg = new(eth.GetBlockHeadersPacket) 156 case eth.BlockHeadersMsg: 157 msg = new(eth.BlockHeadersPacket) 158 case eth.GetBlockBodiesMsg: 159 msg = new(eth.GetBlockBodiesPacket) 160 case eth.BlockBodiesMsg: 161 msg = new(eth.BlockBodiesPacket) 162 case eth.NewBlockMsg: 163 msg = new(eth.NewBlockPacket) 164 case eth.NewBlockHashesMsg: 165 msg = new(eth.NewBlockHashesPacket) 166 case eth.TransactionsMsg: 167 msg = new(eth.TransactionsPacket) 168 case eth.NewPooledTransactionHashesMsg: 169 msg = new(eth.NewPooledTransactionHashesPacket) 170 case eth.GetPooledTransactionsMsg: 171 msg = new(eth.GetPooledTransactionsPacket) 172 case eth.PooledTransactionsMsg: 173 msg = new(eth.PooledTransactionsPacket) 174 default: 175 panic(fmt.Sprintf("unhandled eth msg code %d", code)) 176 } 177 if err := rlp.DecodeBytes(data, msg); err != nil { 178 return nil, fmt.Errorf("unable to decode eth msg: %v", err) 179 } 180 return msg, nil 181 } 182 } 183 184 // ReadSnap reads a snap/1 response with the given id from the connection. 185 func (c *Conn) ReadSnap() (any, error) { 186 c.SetReadDeadline(time.Now().Add(timeout)) 187 for { 188 code, data, _, err := c.Conn.Read() 189 if err != nil { 190 return nil, err 191 } 192 if getProto(code) != snapProto { 193 // Read until snap message. 194 continue 195 } 196 code -= baseProtoLen + ethProtoLen 197 198 var msg any 199 switch int(code) { 200 case snap.GetAccountRangeMsg: 201 msg = new(snap.GetAccountRangePacket) 202 case snap.AccountRangeMsg: 203 msg = new(snap.AccountRangePacket) 204 case snap.GetStorageRangesMsg: 205 msg = new(snap.GetStorageRangesPacket) 206 case snap.StorageRangesMsg: 207 msg = new(snap.StorageRangesPacket) 208 case snap.GetByteCodesMsg: 209 msg = new(snap.GetByteCodesPacket) 210 case snap.ByteCodesMsg: 211 msg = new(snap.ByteCodesPacket) 212 case snap.GetTrieNodesMsg: 213 msg = new(snap.GetTrieNodesPacket) 214 case snap.TrieNodesMsg: 215 msg = new(snap.TrieNodesPacket) 216 default: 217 panic(fmt.Errorf("unhandled snap code: %d", code)) 218 } 219 if err := rlp.DecodeBytes(data, msg); err != nil { 220 return nil, fmt.Errorf("could not rlp decode message: %v", err) 221 } 222 return msg, nil 223 } 224 } 225 226 // peer performs both the protocol handshake and the status message 227 // exchange with the node in order to peer with it. 228 func (c *Conn) peer(chain *Chain, status *eth.StatusPacket) error { 229 if err := c.handshake(); err != nil { 230 return fmt.Errorf("handshake failed: %v", err) 231 } 232 if err := c.statusExchange(chain, status); err != nil { 233 return fmt.Errorf("status exchange failed: %v", err) 234 } 235 return nil 236 } 237 238 // handshake performs a protocol handshake with the node. 239 func (c *Conn) handshake() error { 240 // Write hello to client. 241 pub0 := crypto.FromECDSAPub(&c.ourKey.PublicKey)[1:] 242 ourHandshake := &protoHandshake{ 243 Version: 5, 244 Caps: c.caps, 245 ID: pub0, 246 } 247 if err := c.Write(baseProto, handshakeMsg, ourHandshake); err != nil { 248 return fmt.Errorf("write to connection failed: %v", err) 249 } 250 // Read hello from client. 251 code, data, err := c.Read() 252 if err != nil { 253 return fmt.Errorf("erroring reading handshake: %v", err) 254 } 255 switch code { 256 case handshakeMsg: 257 msg := new(protoHandshake) 258 if err := rlp.DecodeBytes(data, &msg); err != nil { 259 return fmt.Errorf("error decoding handshake msg: %v", err) 260 } 261 // Set snappy if version is at least 5. 262 if msg.Version >= 5 { 263 c.SetSnappy(true) 264 } 265 c.negotiateEthProtocol(msg.Caps) 266 if c.negotiatedProtoVersion == 0 { 267 return fmt.Errorf("could not negotiate eth protocol (remote caps: %v, local eth version: %v)", msg.Caps, c.ourHighestProtoVersion) 268 } 269 // If we require snap, verify that it was negotiated. 270 if c.ourHighestSnapProtoVersion != c.negotiatedSnapProtoVersion { 271 return fmt.Errorf("could not negotiate snap protocol (remote caps: %v, local snap version: %v)", msg.Caps, c.ourHighestSnapProtoVersion) 272 } 273 return nil 274 default: 275 return fmt.Errorf("bad handshake: got msg code %d", code) 276 } 277 } 278 279 // negotiateEthProtocol sets the Conn's eth protocol version to highest 280 // advertised capability from peer. 281 func (c *Conn) negotiateEthProtocol(caps []p2p.Cap) { 282 var highestEthVersion uint 283 var highestSnapVersion uint 284 for _, capability := range caps { 285 switch capability.Name { 286 case "eth": 287 if capability.Version > highestEthVersion && capability.Version <= c.ourHighestProtoVersion { 288 highestEthVersion = capability.Version 289 } 290 case "snap": 291 if capability.Version > highestSnapVersion && capability.Version <= c.ourHighestSnapProtoVersion { 292 highestSnapVersion = capability.Version 293 } 294 } 295 } 296 c.negotiatedProtoVersion = highestEthVersion 297 c.negotiatedSnapProtoVersion = highestSnapVersion 298 } 299 300 // statusExchange performs a `Status` message exchange with the given node. 301 func (c *Conn) statusExchange(chain *Chain, status *eth.StatusPacket) error { 302 loop: 303 for { 304 code, data, err := c.Read() 305 if err != nil { 306 return fmt.Errorf("failed to read from connection: %w", err) 307 } 308 switch code { 309 case eth.StatusMsg + protoOffset(ethProto): 310 msg := new(eth.StatusPacket) 311 if err := rlp.DecodeBytes(data, &msg); err != nil { 312 return fmt.Errorf("error decoding status packet: %w", err) 313 } 314 if have, want := msg.Head, chain.blocks[chain.Len()-1].Hash(); have != want { 315 return fmt.Errorf("wrong head block in status, want: %#x (block %d) have %#x", 316 want, chain.blocks[chain.Len()-1].NumberU64(), have) 317 } 318 if have, want := msg.TD.Cmp(chain.TD()), 0; have != want { 319 return fmt.Errorf("wrong TD in status: have %v want %v", have, want) 320 } 321 if have, want := msg.ForkID, chain.ForkID(); !reflect.DeepEqual(have, want) { 322 return fmt.Errorf("wrong fork ID in status: have %v, want %v", have, want) 323 } 324 if have, want := msg.ProtocolVersion, c.ourHighestProtoVersion; have != uint32(want) { 325 return fmt.Errorf("wrong protocol version: have %v, want %v", have, want) 326 } 327 break loop 328 case discMsg: 329 var msg []p2p.DiscReason 330 if rlp.DecodeBytes(data, &msg); len(msg) == 0 { 331 return errors.New("invalid disconnect message") 332 } 333 return fmt.Errorf("disconnect received: %v", pretty.Sdump(msg)) 334 case pingMsg: 335 // TODO (renaynay): in the future, this should be an error 336 // (PINGs should not be a response upon fresh connection) 337 c.Write(baseProto, pongMsg, nil) 338 default: 339 return fmt.Errorf("bad status message: code %d", code) 340 } 341 } 342 // make sure eth protocol version is set for negotiation 343 if c.negotiatedProtoVersion == 0 { 344 return errors.New("eth protocol version must be set in Conn") 345 } 346 if status == nil { 347 // default status message 348 status = ð.StatusPacket{ 349 ProtocolVersion: uint32(c.negotiatedProtoVersion), 350 NetworkID: chain.config.ChainID.Uint64(), 351 TD: chain.TD(), 352 Head: chain.blocks[chain.Len()-1].Hash(), 353 Genesis: chain.blocks[0].Hash(), 354 ForkID: chain.ForkID(), 355 } 356 } 357 if err := c.Write(ethProto, eth.StatusMsg, status); err != nil { 358 return fmt.Errorf("write to connection failed: %v", err) 359 } 360 return nil 361 }