github.com/Bytom/bytom@v1.1.2-0.20210127130405-ae40204c0b09/p2p/connection/secret_connection_test.go (about)

     1  package connection
     2  
     3  import (
     4  	"bytes"
     5  	"io"
     6  	"testing"
     7  
     8  	"github.com/tendermint/go-crypto"
     9  	cmn "github.com/tendermint/tmlibs/common"
    10  )
    11  
    12  type dummyConn struct {
    13  	*io.PipeReader
    14  	*io.PipeWriter
    15  }
    16  
    17  func (drw dummyConn) Close() (err error) {
    18  	err2 := drw.PipeWriter.CloseWithError(io.EOF)
    19  	err1 := drw.PipeReader.Close()
    20  	if err2 != nil {
    21  		return err
    22  	}
    23  	return err1
    24  }
    25  
    26  // Each returned ReadWriteCloser is akin to a net.Connection
    27  func makeDummyConnPair() (fooConn, barConn dummyConn) {
    28  	barReader, fooWriter := io.Pipe()
    29  	fooReader, barWriter := io.Pipe()
    30  	return dummyConn{fooReader, fooWriter}, dummyConn{barReader, barWriter}
    31  }
    32  
    33  func makeSecretConnPair(tb testing.TB) (fooSecConn, barSecConn *SecretConnection) {
    34  	fooConn, barConn := makeDummyConnPair()
    35  	fooPrvKey := crypto.GenPrivKeyEd25519()
    36  	fooPubKey := fooPrvKey.PubKey().Unwrap().(crypto.PubKeyEd25519)
    37  	barPrvKey := crypto.GenPrivKeyEd25519()
    38  	barPubKey := barPrvKey.PubKey().Unwrap().(crypto.PubKeyEd25519)
    39  
    40  	cmn.Parallel(
    41  		func() {
    42  			var err error
    43  			fooSecConn, err = MakeSecretConnection(fooConn, fooPrvKey)
    44  			if err != nil {
    45  				tb.Errorf("Failed to establish SecretConnection for foo: %v", err)
    46  				return
    47  			}
    48  			remotePubBytes := fooSecConn.RemotePubKey()
    49  			if !bytes.Equal(remotePubBytes[:], barPubKey[:]) {
    50  				tb.Errorf("Unexpected fooSecConn.RemotePubKey.  Expected %v, got %v",
    51  					barPubKey, fooSecConn.RemotePubKey())
    52  			}
    53  		},
    54  		func() {
    55  			var err error
    56  			barSecConn, err = MakeSecretConnection(barConn, barPrvKey)
    57  			if barSecConn == nil {
    58  				tb.Errorf("Failed to establish SecretConnection for bar: %v", err)
    59  				return
    60  			}
    61  			remotePubBytes := barSecConn.RemotePubKey()
    62  			if !bytes.Equal(remotePubBytes[:], fooPubKey[:]) {
    63  				tb.Errorf("Unexpected barSecConn.RemotePubKey.  Expected %v, got %v",
    64  					fooPubKey, barSecConn.RemotePubKey())
    65  			}
    66  		})
    67  
    68  	return
    69  }
    70  
    71  func TestSecretConnectionHandshake(t *testing.T) {
    72  	fooSecConn, barSecConn := makeSecretConnPair(t)
    73  	fooSecConn.Close()
    74  	barSecConn.Close()
    75  }
    76  
    77  func TestSecretConnectionReadWrite(t *testing.T) {
    78  	fooConn, barConn := makeDummyConnPair()
    79  	fooWrites, barWrites := []string{}, []string{}
    80  	fooReads, barReads := []string{}, []string{}
    81  
    82  	// Pre-generate the things to write (for foo & bar)
    83  	for i := 0; i < 100; i++ {
    84  		fooWrites = append(fooWrites, cmn.RandStr((cmn.RandInt()%(dataMaxSize*5))+1))
    85  		barWrites = append(barWrites, cmn.RandStr((cmn.RandInt()%(dataMaxSize*5))+1))
    86  	}
    87  
    88  	// A helper that will run with (fooConn, fooWrites, fooReads) and vice versa
    89  	genNodeRunner := func(nodeConn dummyConn, nodeWrites []string, nodeReads *[]string) func() {
    90  		return func() {
    91  			// Node handskae
    92  			nodePrvKey := crypto.GenPrivKeyEd25519()
    93  			nodeSecretConn, err := MakeSecretConnection(nodeConn, nodePrvKey)
    94  			if err != nil {
    95  				t.Errorf("Failed to establish SecretConnection for node: %v", err)
    96  				return
    97  			}
    98  			// In parallel, handle reads and writes
    99  			cmn.Parallel(
   100  				func() {
   101  					// Node writes
   102  					for _, nodeWrite := range nodeWrites {
   103  						n, err := nodeSecretConn.Write([]byte(nodeWrite))
   104  						if err != nil {
   105  							t.Errorf("Failed to write to nodeSecretConn: %v", err)
   106  							return
   107  						}
   108  						if n != len(nodeWrite) {
   109  							t.Errorf("Failed to write all bytes. Expected %v, wrote %v", len(nodeWrite), n)
   110  							return
   111  						}
   112  					}
   113  					nodeConn.PipeWriter.Close()
   114  				},
   115  				func() {
   116  					// Node reads
   117  					readBuffer := make([]byte, dataMaxSize)
   118  					for {
   119  						n, err := nodeSecretConn.Read(readBuffer)
   120  						if err == io.EOF {
   121  							return
   122  						} else if err != nil {
   123  							t.Errorf("Failed to read from nodeSecretConn: %v", err)
   124  							return
   125  						}
   126  						*nodeReads = append(*nodeReads, string(readBuffer[:n]))
   127  					}
   128  					nodeConn.PipeReader.Close()
   129  				})
   130  		}
   131  	}
   132  
   133  	// Run foo & bar in parallel
   134  	cmn.Parallel(
   135  		genNodeRunner(fooConn, fooWrites, &fooReads),
   136  		genNodeRunner(barConn, barWrites, &barReads),
   137  	)
   138  
   139  	// A helper to ensure that the writes and reads match.
   140  	// Additionally, small writes (<= dataMaxSize) must be atomically read.
   141  	compareWritesReads := func(writes []string, reads []string) {
   142  		for {
   143  			// Pop next write & corresponding reads
   144  			var read, write string = "", writes[0]
   145  			var readCount = 0
   146  			for _, readChunk := range reads {
   147  				read += readChunk
   148  				readCount++
   149  				if len(write) <= len(read) {
   150  					break
   151  				}
   152  				if len(write) <= dataMaxSize {
   153  					break // atomicity of small writes
   154  				}
   155  			}
   156  			// Compare
   157  			if write != read {
   158  				t.Errorf("Expected to read %X, got %X", write, read)
   159  			}
   160  			// Iterate
   161  			writes = writes[1:]
   162  			reads = reads[readCount:]
   163  			if len(writes) == 0 {
   164  				break
   165  			}
   166  		}
   167  	}
   168  
   169  	compareWritesReads(fooWrites, barReads)
   170  	compareWritesReads(barWrites, fooReads)
   171  
   172  }
   173  
   174  func BenchmarkSecretConnection(b *testing.B) {
   175  	b.StopTimer()
   176  	fooSecConn, barSecConn := makeSecretConnPair(b)
   177  	fooWriteText := cmn.RandStr(dataMaxSize)
   178  	// Consume reads from bar's reader
   179  	go func() {
   180  		readBuffer := make([]byte, dataMaxSize)
   181  		for {
   182  			_, err := barSecConn.Read(readBuffer)
   183  			if err == io.EOF {
   184  				return
   185  			} else if err != nil {
   186  				b.Fatalf("Failed to read from barSecConn: %v", err)
   187  			}
   188  		}
   189  	}()
   190  
   191  	b.StartTimer()
   192  	for i := 0; i < b.N; i++ {
   193  		_, err := fooSecConn.Write([]byte(fooWriteText))
   194  		if err != nil {
   195  			b.Fatalf("Failed to write to fooSecConn: %v", err)
   196  		}
   197  	}
   198  	b.StopTimer()
   199  
   200  	fooSecConn.Close()
   201  	//barSecConn.Close() race condition
   202  }