github.com/bytom/bytom@v1.1.2-0.20221014091027-bbcba3df6075/p2p/connection/secret_connection_test.go (about) 1 package connection 2 3 import ( 4 "bytes" 5 "fmt" 6 "io" 7 "testing" 8 9 cmn "github.com/tendermint/tmlibs/common" 10 11 "github.com/bytom/bytom/crypto/ed25519/chainkd" 12 ) 13 14 type dummyConn struct { 15 *io.PipeReader 16 *io.PipeWriter 17 } 18 19 func (drw dummyConn) Close() (err error) { 20 err2 := drw.PipeWriter.CloseWithError(io.EOF) 21 err1 := drw.PipeReader.Close() 22 if err2 != nil { 23 return err 24 } 25 return err1 26 } 27 28 // Each returned ReadWriteCloser is akin to a net.Connection 29 func makeDummyConnPair() (fooConn, barConn dummyConn) { 30 barReader, fooWriter := io.Pipe() 31 fooReader, barWriter := io.Pipe() 32 return dummyConn{fooReader, fooWriter}, dummyConn{barReader, barWriter} 33 } 34 35 func makeSecretConnPair(tb testing.TB) (fooSecConn, barSecConn *SecretConnection) { 36 fooConn, barConn := makeDummyConnPair() 37 fooPrvKey, _ := chainkd.NewXPrv(nil) 38 fooPubKey := fooPrvKey.XPub() 39 barPrvKey, _ := chainkd.NewXPrv(nil) 40 barPubKey := barPrvKey.XPub() 41 42 fooSecConnTask := func(i int) (val interface{}, err error, abort bool) { 43 fooSecConn, err = MakeSecretConnection(fooConn, fooPrvKey) 44 if err != nil { 45 return nil, err, false 46 } 47 48 remotePubBytes := fooSecConn.RemotePubKey() 49 if !bytes.Equal(remotePubBytes[:], barPubKey[:]) { 50 return nil, fmt.Errorf("Unexpected fooSecConn.RemotePubKey. Expected %v, got %v", barPubKey, remotePubBytes), false 51 } 52 53 return nil, nil, false 54 } 55 56 barSecConnTask := func(i int) (val interface{}, err error, about bool) { 57 barSecConn, err = MakeSecretConnection(barConn, barPrvKey) 58 if err != nil { 59 return nil, err, false 60 } 61 62 remotePubBytes := barSecConn.RemotePubKey() 63 if !bytes.Equal(remotePubBytes[:], fooPubKey[:]) { 64 return nil, fmt.Errorf("Unexpected barSecConn.RemotePubKey. Expected %v, got %v", fooPubKey, remotePubBytes), false 65 } 66 return nil, nil, false 67 } 68 69 _, ok := cmn.Parallel(fooSecConnTask, barSecConnTask) 70 if !ok { 71 tb.Errorf("Parallel task run failed") 72 } 73 74 return 75 } 76 77 func TestSecretConnectionHandshake(t *testing.T) { 78 fooSecConn, barSecConn := makeSecretConnPair(t) 79 fooSecConn.Close() 80 barSecConn.Close() 81 } 82 83 func TestSecretConnectionReadWrite(t *testing.T) { 84 fooConn, barConn := makeDummyConnPair() 85 fooWrites, barWrites := []string{}, []string{} 86 fooReads, barReads := []string{}, []string{} 87 88 // Pre-generate the things to write (for foo & bar) 89 for i := 0; i < 100; i++ { 90 fooWrites = append(fooWrites, cmn.RandStr((cmn.RandInt()%(dataMaxSize*5))+1)) 91 barWrites = append(barWrites, cmn.RandStr((cmn.RandInt()%(dataMaxSize*5))+1)) 92 } 93 94 // A helper that will run with (fooConn, fooWrites, fooReads) and vice versa 95 genNodeRunner := func(nodeConn dummyConn, nodeWrites []string, nodeReads *[]string) func(int) (interface{}, error, bool) { 96 return func(i int) (val interface{}, err error, about bool) { 97 // Node handshake 98 nodePrvKey, _ := chainkd.NewXPrv(nil) 99 nodeSecretConn, err := MakeSecretConnection(nodeConn, nodePrvKey) 100 if err != nil { 101 return nil, err, false 102 } 103 104 nodeWriteTask := func(i int) (val interface{}, err error, about bool) { 105 // Node writes 106 for _, nodeWrite := range nodeWrites { 107 n, err := nodeSecretConn.Write([]byte(nodeWrite)) 108 if err != nil { 109 t.Errorf("Failed to write to nodeSecretConn: %v", err) 110 return nil, err, false 111 } 112 if n != len(nodeWrite) { 113 t.Errorf("Failed to write all bytes. Expected %v, wrote %v", len(nodeWrite), n) 114 115 return nil, err, false 116 } 117 } 118 nodeConn.PipeWriter.Close() 119 return nil, nil, false 120 } 121 122 nodeReadsTask := func(i int) (val interface{}, err error, about bool) { 123 // Node reads 124 defer nodeConn.PipeReader.Close() 125 readBuffer := make([]byte, dataMaxSize) 126 for { 127 n, err := nodeSecretConn.Read(readBuffer) 128 if err == io.EOF { 129 return nil, nil, false 130 } else if err != nil { 131 return nil, err, false 132 } 133 *nodeReads = append(*nodeReads, string(readBuffer[:n])) 134 } 135 } 136 137 // In parallel, handle reads and writes 138 trs, ok := cmn.Parallel(nodeWriteTask, nodeReadsTask) 139 if !ok { 140 t.Errorf("Parallel task run failed") 141 } 142 for i := 0; i < 2; i++ { 143 res, ok := trs.LatestResult(i) 144 if !ok { 145 t.Errorf("Task %d did not complete", i) 146 } 147 148 if res.Error != nil { 149 t.Errorf("Task %d should not has error but god %v", i, res.Error) 150 } 151 } 152 return 153 } 154 } 155 // Run foo & bar in parallel 156 cmn.Parallel( 157 genNodeRunner(fooConn, fooWrites, &fooReads), 158 genNodeRunner(barConn, barWrites, &barReads), 159 ) 160 161 // A helper to ensure that the writes and reads match. 162 // Additionally, small writes (<= dataMaxSize) must be atomically read. 163 compareWritesReads := func(writes []string, reads []string) { 164 for { 165 // Pop next write & corresponding reads 166 var read, write string = "", writes[0] 167 var readCount = 0 168 for _, readChunk := range reads { 169 read += readChunk 170 readCount++ 171 if len(write) <= len(read) { 172 break 173 } 174 if len(write) <= dataMaxSize { 175 break // atomicity of small writes 176 } 177 } 178 // Compare 179 if write != read { 180 t.Errorf("Expected to read %X, got %X", write, read) 181 } 182 // Iterate 183 writes = writes[1:] 184 reads = reads[readCount:] 185 if len(writes) == 0 { 186 break 187 } 188 } 189 } 190 191 compareWritesReads(fooWrites, barReads) 192 compareWritesReads(barWrites, fooReads) 193 194 } 195 196 func BenchmarkSecretConnection(b *testing.B) { 197 b.StopTimer() 198 fooSecConn, barSecConn := makeSecretConnPair(b) 199 fooWriteText := cmn.RandStr(dataMaxSize) 200 // Consume reads from bar's reader 201 go func() { 202 readBuffer := make([]byte, dataMaxSize) 203 for { 204 _, err := barSecConn.Read(readBuffer) 205 if err == io.EOF { 206 return 207 } else if err != nil { 208 b.Fatalf("Failed to read from barSecConn: %v", err) 209 } 210 } 211 }() 212 213 b.StartTimer() 214 for i := 0; i < b.N; i++ { 215 _, err := fooSecConn.Write([]byte(fooWriteText)) 216 if err != nil { 217 b.Fatalf("Failed to write to fooSecConn: %v", err) 218 } 219 } 220 b.StopTimer() 221 222 fooSecConn.Close() 223 //barSecConn.Close() race condition 224 }