github.com/ethereum/go-ethereum@v1.16.1/p2p/transport_test.go (about) 1 // Copyright 2020 The go-ethereum Authors 2 // This file is part of the go-ethereum library. 3 // 4 // The go-ethereum library is free software: you can redistribute it and/or modify 5 // it under the terms of the GNU Lesser General Public License as published by 6 // the Free Software Foundation, either version 3 of the License, or 7 // (at your option) any later version. 8 // 9 // The go-ethereum library is distributed in the hope that it will be useful, 10 // but WITHOUT ANY WARRANTY; without even the implied warranty of 11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 // GNU Lesser General Public License for more details. 13 // 14 // You should have received a copy of the GNU Lesser General Public License 15 // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. 16 17 package p2p 18 19 import ( 20 "errors" 21 "reflect" 22 "sync" 23 "testing" 24 25 "github.com/davecgh/go-spew/spew" 26 "github.com/ethereum/go-ethereum/crypto" 27 "github.com/ethereum/go-ethereum/p2p/pipes" 28 ) 29 30 func TestProtocolHandshake(t *testing.T) { 31 var ( 32 prv0, _ = crypto.GenerateKey() 33 pub0 = crypto.FromECDSAPub(&prv0.PublicKey)[1:] 34 hs0 = &protoHandshake{Version: 3, ID: pub0, Caps: []Cap{{"a", 0}, {"b", 2}}} 35 36 prv1, _ = crypto.GenerateKey() 37 pub1 = crypto.FromECDSAPub(&prv1.PublicKey)[1:] 38 hs1 = &protoHandshake{Version: 3, ID: pub1, Caps: []Cap{{"c", 1}, {"d", 3}}} 39 40 wg sync.WaitGroup 41 ) 42 43 fd0, fd1, err := pipes.TCPPipe() 44 if err != nil { 45 t.Fatal(err) 46 } 47 48 wg.Add(2) 49 go func() { 50 defer wg.Done() 51 defer fd0.Close() 52 frame := newRLPX(fd0, &prv1.PublicKey) 53 rpubkey, err := frame.doEncHandshake(prv0) 54 if err != nil { 55 t.Errorf("dial side enc handshake failed: %v", err) 56 return 57 } 58 if !reflect.DeepEqual(rpubkey, &prv1.PublicKey) { 59 t.Errorf("dial side remote pubkey mismatch: got %v, want %v", rpubkey, &prv1.PublicKey) 60 return 61 } 62 63 phs, err := frame.doProtoHandshake(hs0) 64 if err != nil { 65 t.Errorf("dial side proto handshake error: %v", err) 66 return 67 } 68 phs.Rest = nil 69 if !reflect.DeepEqual(phs, hs1) { 70 t.Errorf("dial side proto handshake mismatch:\ngot: %s\nwant: %s\n", spew.Sdump(phs), spew.Sdump(hs1)) 71 return 72 } 73 frame.close(DiscQuitting) 74 }() 75 go func() { 76 defer wg.Done() 77 defer fd1.Close() 78 rlpx := newRLPX(fd1, nil) 79 rpubkey, err := rlpx.doEncHandshake(prv1) 80 if err != nil { 81 t.Errorf("listen side enc handshake failed: %v", err) 82 return 83 } 84 if !reflect.DeepEqual(rpubkey, &prv0.PublicKey) { 85 t.Errorf("listen side remote pubkey mismatch: got %v, want %v", rpubkey, &prv0.PublicKey) 86 return 87 } 88 89 phs, err := rlpx.doProtoHandshake(hs1) 90 if err != nil { 91 t.Errorf("listen side proto handshake error: %v", err) 92 return 93 } 94 phs.Rest = nil 95 if !reflect.DeepEqual(phs, hs0) { 96 t.Errorf("listen side proto handshake mismatch:\ngot: %s\nwant: %s\n", spew.Sdump(phs), spew.Sdump(hs0)) 97 return 98 } 99 100 if err := ExpectMsg(rlpx, discMsg, []any{DiscQuitting}); err != nil { 101 t.Errorf("error receiving disconnect: %v", err) 102 } 103 }() 104 wg.Wait() 105 } 106 107 func TestProtocolHandshakeErrors(t *testing.T) { 108 tests := []struct { 109 code uint64 110 msg interface{} 111 err error 112 }{ 113 { 114 code: discMsg, 115 msg: []any{DiscQuitting}, 116 err: DiscQuitting, 117 }, 118 { 119 // legacy disconnect encoding as byte array 120 code: discMsg, 121 msg: []byte{byte(DiscQuitting)}, 122 err: DiscQuitting, 123 }, 124 { 125 code: 0x989898, 126 msg: []byte{1}, 127 err: errors.New("expected handshake, got 989898"), 128 }, 129 { 130 code: handshakeMsg, 131 msg: make([]byte, baseProtocolMaxMsgSize+2), 132 err: errors.New("message too big"), 133 }, 134 { 135 code: handshakeMsg, 136 msg: []byte{1, 2, 3}, 137 err: newPeerError(errInvalidMsg, "(code 0) (size 4) rlp: expected input list for p2p.protoHandshake"), 138 }, 139 { 140 code: handshakeMsg, 141 msg: &protoHandshake{Version: 3}, 142 err: DiscInvalidIdentity, 143 }, 144 } 145 146 for i, test := range tests { 147 p1, p2 := MsgPipe() 148 go Send(p1, test.code, test.msg) 149 _, err := readProtocolHandshake(p2) 150 if !reflect.DeepEqual(err, test.err) { 151 t.Errorf("test %d: error mismatch: got %q, want %q", i, err, test.err) 152 } 153 } 154 }