github.com/DFWallet/tendermint-cosmos@v0.0.2/p2p/conn/secret_connection_test.go (about)

     1  package conn
     2  
     3  import (
     4  	"bufio"
     5  	"encoding/hex"
     6  	"flag"
     7  	"fmt"
     8  	"io"
     9  	"log"
    10  	"os"
    11  	"path/filepath"
    12  	"strconv"
    13  	"strings"
    14  	"sync"
    15  	"testing"
    16  
    17  	"github.com/stretchr/testify/assert"
    18  	"github.com/stretchr/testify/require"
    19  
    20  	"github.com/DFWallet/tendermint-cosmos/crypto"
    21  	"github.com/DFWallet/tendermint-cosmos/crypto/ed25519"
    22  	"github.com/DFWallet/tendermint-cosmos/crypto/sr25519"
    23  	"github.com/DFWallet/tendermint-cosmos/libs/async"
    24  	tmos "github.com/DFWallet/tendermint-cosmos/libs/os"
    25  	tmrand "github.com/DFWallet/tendermint-cosmos/libs/rand"
    26  )
    27  
    28  // Run go test -update from within this module
    29  // to update the golden test vector file
    30  var update = flag.Bool("update", false, "update .golden files")
    31  
    32  type kvstoreConn struct {
    33  	*io.PipeReader
    34  	*io.PipeWriter
    35  }
    36  
    37  func (drw kvstoreConn) Close() (err error) {
    38  	err2 := drw.PipeWriter.CloseWithError(io.EOF)
    39  	err1 := drw.PipeReader.Close()
    40  	if err2 != nil {
    41  		return err
    42  	}
    43  	return err1
    44  }
    45  
    46  type privKeyWithNilPubKey struct {
    47  	orig crypto.PrivKey
    48  }
    49  
    50  func (pk privKeyWithNilPubKey) Bytes() []byte                   { return pk.orig.Bytes() }
    51  func (pk privKeyWithNilPubKey) Sign(msg []byte) ([]byte, error) { return pk.orig.Sign(msg) }
    52  func (pk privKeyWithNilPubKey) PubKey() crypto.PubKey           { return nil }
    53  func (pk privKeyWithNilPubKey) Equals(pk2 crypto.PrivKey) bool  { return pk.orig.Equals(pk2) }
    54  func (pk privKeyWithNilPubKey) Type() string                    { return "privKeyWithNilPubKey" }
    55  
    56  func TestSecretConnectionHandshake(t *testing.T) {
    57  	fooSecConn, barSecConn := makeSecretConnPair(t)
    58  	if err := fooSecConn.Close(); err != nil {
    59  		t.Error(err)
    60  	}
    61  	if err := barSecConn.Close(); err != nil {
    62  		t.Error(err)
    63  	}
    64  }
    65  
    66  func TestConcurrentWrite(t *testing.T) {
    67  	fooSecConn, barSecConn := makeSecretConnPair(t)
    68  	fooWriteText := tmrand.Str(dataMaxSize)
    69  
    70  	// write from two routines.
    71  	// should be safe from race according to net.Conn:
    72  	// https://golang.org/pkg/net/#Conn
    73  	n := 100
    74  	wg := new(sync.WaitGroup)
    75  	wg.Add(3)
    76  	go writeLots(t, wg, fooSecConn, fooWriteText, n)
    77  	go writeLots(t, wg, fooSecConn, fooWriteText, n)
    78  
    79  	// Consume reads from bar's reader
    80  	readLots(t, wg, barSecConn, n*2)
    81  	wg.Wait()
    82  
    83  	if err := fooSecConn.Close(); err != nil {
    84  		t.Error(err)
    85  	}
    86  }
    87  
    88  func TestConcurrentRead(t *testing.T) {
    89  	fooSecConn, barSecConn := makeSecretConnPair(t)
    90  	fooWriteText := tmrand.Str(dataMaxSize)
    91  	n := 100
    92  
    93  	// read from two routines.
    94  	// should be safe from race according to net.Conn:
    95  	// https://golang.org/pkg/net/#Conn
    96  	wg := new(sync.WaitGroup)
    97  	wg.Add(3)
    98  	go readLots(t, wg, fooSecConn, n/2)
    99  	go readLots(t, wg, fooSecConn, n/2)
   100  
   101  	// write to bar
   102  	writeLots(t, wg, barSecConn, fooWriteText, n)
   103  	wg.Wait()
   104  
   105  	if err := fooSecConn.Close(); err != nil {
   106  		t.Error(err)
   107  	}
   108  }
   109  
   110  func TestSecretConnectionReadWrite(t *testing.T) {
   111  	fooConn, barConn := makeKVStoreConnPair()
   112  	fooWrites, barWrites := []string{}, []string{}
   113  	fooReads, barReads := []string{}, []string{}
   114  
   115  	// Pre-generate the things to write (for foo & bar)
   116  	for i := 0; i < 100; i++ {
   117  		fooWrites = append(fooWrites, tmrand.Str((tmrand.Int()%(dataMaxSize*5))+1))
   118  		barWrites = append(barWrites, tmrand.Str((tmrand.Int()%(dataMaxSize*5))+1))
   119  	}
   120  
   121  	// A helper that will run with (fooConn, fooWrites, fooReads) and vice versa
   122  	genNodeRunner := func(id string, nodeConn kvstoreConn, nodeWrites []string, nodeReads *[]string) async.Task {
   123  		return func(_ int) (interface{}, bool, error) {
   124  			// Initiate cryptographic private key and secret connection trhough nodeConn.
   125  			nodePrvKey := ed25519.GenPrivKey()
   126  			nodeSecretConn, err := MakeSecretConnection(nodeConn, nodePrvKey)
   127  			if err != nil {
   128  				t.Errorf("failed to establish SecretConnection for node: %v", err)
   129  				return nil, true, err
   130  			}
   131  			// In parallel, handle some reads and writes.
   132  			var trs, ok = async.Parallel(
   133  				func(_ int) (interface{}, bool, error) {
   134  					// Node writes:
   135  					for _, nodeWrite := range nodeWrites {
   136  						n, err := nodeSecretConn.Write([]byte(nodeWrite))
   137  						if err != nil {
   138  							t.Errorf("failed to write to nodeSecretConn: %v", err)
   139  							return nil, true, err
   140  						}
   141  						if n != len(nodeWrite) {
   142  							err = fmt.Errorf("failed to write all bytes. Expected %v, wrote %v", len(nodeWrite), n)
   143  							t.Error(err)
   144  							return nil, true, err
   145  						}
   146  					}
   147  					if err := nodeConn.PipeWriter.Close(); err != nil {
   148  						t.Error(err)
   149  						return nil, true, err
   150  					}
   151  					return nil, false, nil
   152  				},
   153  				func(_ int) (interface{}, bool, error) {
   154  					// Node reads:
   155  					readBuffer := make([]byte, dataMaxSize)
   156  					for {
   157  						n, err := nodeSecretConn.Read(readBuffer)
   158  						if err == io.EOF {
   159  							if err := nodeConn.PipeReader.Close(); err != nil {
   160  								t.Error(err)
   161  								return nil, true, err
   162  							}
   163  							return nil, false, nil
   164  						} else if err != nil {
   165  							t.Errorf("failed to read from nodeSecretConn: %v", err)
   166  							return nil, true, err
   167  						}
   168  						*nodeReads = append(*nodeReads, string(readBuffer[:n]))
   169  					}
   170  				},
   171  			)
   172  			assert.True(t, ok, "Unexpected task abortion")
   173  
   174  			// If error:
   175  			if trs.FirstError() != nil {
   176  				return nil, true, trs.FirstError()
   177  			}
   178  
   179  			// Otherwise:
   180  			return nil, false, nil
   181  		}
   182  	}
   183  
   184  	// Run foo & bar in parallel
   185  	var trs, ok = async.Parallel(
   186  		genNodeRunner("foo", fooConn, fooWrites, &fooReads),
   187  		genNodeRunner("bar", barConn, barWrites, &barReads),
   188  	)
   189  	require.Nil(t, trs.FirstError())
   190  	require.True(t, ok, "unexpected task abortion")
   191  
   192  	// A helper to ensure that the writes and reads match.
   193  	// Additionally, small writes (<= dataMaxSize) must be atomically read.
   194  	compareWritesReads := func(writes []string, reads []string) {
   195  		for {
   196  			// Pop next write & corresponding reads
   197  			var read, write string = "", writes[0]
   198  			var readCount = 0
   199  			for _, readChunk := range reads {
   200  				read += readChunk
   201  				readCount++
   202  				if len(write) <= len(read) {
   203  					break
   204  				}
   205  				if len(write) <= dataMaxSize {
   206  					break // atomicity of small writes
   207  				}
   208  			}
   209  			// Compare
   210  			if write != read {
   211  				t.Errorf("expected to read %X, got %X", write, read)
   212  			}
   213  			// Iterate
   214  			writes = writes[1:]
   215  			reads = reads[readCount:]
   216  			if len(writes) == 0 {
   217  				break
   218  			}
   219  		}
   220  	}
   221  
   222  	compareWritesReads(fooWrites, barReads)
   223  	compareWritesReads(barWrites, fooReads)
   224  }
   225  
   226  func TestDeriveSecretsAndChallengeGolden(t *testing.T) {
   227  	goldenFilepath := filepath.Join("testdata", t.Name()+".golden")
   228  	if *update {
   229  		t.Logf("Updating golden test vector file %s", goldenFilepath)
   230  		data := createGoldenTestVectors(t)
   231  		err := tmos.WriteFile(goldenFilepath, []byte(data), 0644)
   232  		require.NoError(t, err)
   233  	}
   234  	f, err := os.Open(goldenFilepath)
   235  	if err != nil {
   236  		log.Fatal(err)
   237  	}
   238  	defer f.Close()
   239  	scanner := bufio.NewScanner(f)
   240  	for scanner.Scan() {
   241  		line := scanner.Text()
   242  		params := strings.Split(line, ",")
   243  		randSecretVector, err := hex.DecodeString(params[0])
   244  		require.Nil(t, err)
   245  		randSecret := new([32]byte)
   246  		copy((*randSecret)[:], randSecretVector)
   247  		locIsLeast, err := strconv.ParseBool(params[1])
   248  		require.Nil(t, err)
   249  		expectedRecvSecret, err := hex.DecodeString(params[2])
   250  		require.Nil(t, err)
   251  		expectedSendSecret, err := hex.DecodeString(params[3])
   252  		require.Nil(t, err)
   253  
   254  		recvSecret, sendSecret := deriveSecrets(randSecret, locIsLeast)
   255  		require.Equal(t, expectedRecvSecret, (*recvSecret)[:], "Recv Secrets aren't equal")
   256  		require.Equal(t, expectedSendSecret, (*sendSecret)[:], "Send Secrets aren't equal")
   257  	}
   258  }
   259  
   260  func TestNilPubkey(t *testing.T) {
   261  	var fooConn, barConn = makeKVStoreConnPair()
   262  	defer fooConn.Close()
   263  	defer barConn.Close()
   264  	var fooPrvKey = ed25519.GenPrivKey()
   265  	var barPrvKey = privKeyWithNilPubKey{ed25519.GenPrivKey()}
   266  
   267  	go MakeSecretConnection(fooConn, fooPrvKey) //nolint:errcheck // ignore for tests
   268  
   269  	_, err := MakeSecretConnection(barConn, barPrvKey)
   270  	require.Error(t, err)
   271  	assert.Equal(t, "toproto: key type <nil> is not supported", err.Error())
   272  }
   273  
   274  func TestNonEd25519Pubkey(t *testing.T) {
   275  	var fooConn, barConn = makeKVStoreConnPair()
   276  	defer fooConn.Close()
   277  	defer barConn.Close()
   278  	var fooPrvKey = ed25519.GenPrivKey()
   279  	var barPrvKey = sr25519.GenPrivKey()
   280  
   281  	go MakeSecretConnection(fooConn, fooPrvKey) //nolint:errcheck // ignore for tests
   282  
   283  	_, err := MakeSecretConnection(barConn, barPrvKey)
   284  	require.Error(t, err)
   285  	assert.Contains(t, err.Error(), "is not supported")
   286  }
   287  
   288  func writeLots(t *testing.T, wg *sync.WaitGroup, conn io.Writer, txt string, n int) {
   289  	defer wg.Done()
   290  	for i := 0; i < n; i++ {
   291  		_, err := conn.Write([]byte(txt))
   292  		if err != nil {
   293  			t.Errorf("failed to write to fooSecConn: %v", err)
   294  			return
   295  		}
   296  	}
   297  }
   298  
   299  func readLots(t *testing.T, wg *sync.WaitGroup, conn io.Reader, n int) {
   300  	readBuffer := make([]byte, dataMaxSize)
   301  	for i := 0; i < n; i++ {
   302  		_, err := conn.Read(readBuffer)
   303  		assert.NoError(t, err)
   304  	}
   305  	wg.Done()
   306  }
   307  
   308  // Creates the data for a test vector file.
   309  // The file format is:
   310  // Hex(diffie_hellman_secret), loc_is_least, Hex(recvSecret), Hex(sendSecret), Hex(challenge)
   311  func createGoldenTestVectors(t *testing.T) string {
   312  	data := ""
   313  	for i := 0; i < 32; i++ {
   314  		randSecretVector := tmrand.Bytes(32)
   315  		randSecret := new([32]byte)
   316  		copy((*randSecret)[:], randSecretVector)
   317  		data += hex.EncodeToString((*randSecret)[:]) + ","
   318  		locIsLeast := tmrand.Bool()
   319  		data += strconv.FormatBool(locIsLeast) + ","
   320  		recvSecret, sendSecret := deriveSecrets(randSecret, locIsLeast)
   321  		data += hex.EncodeToString((*recvSecret)[:]) + ","
   322  		data += hex.EncodeToString((*sendSecret)[:]) + ","
   323  	}
   324  	return data
   325  }
   326  
   327  // Each returned ReadWriteCloser is akin to a net.Connection
   328  func makeKVStoreConnPair() (fooConn, barConn kvstoreConn) {
   329  	barReader, fooWriter := io.Pipe()
   330  	fooReader, barWriter := io.Pipe()
   331  	return kvstoreConn{fooReader, fooWriter}, kvstoreConn{barReader, barWriter}
   332  }
   333  
   334  func makeSecretConnPair(tb testing.TB) (fooSecConn, barSecConn *SecretConnection) {
   335  	var (
   336  		fooConn, barConn = makeKVStoreConnPair()
   337  		fooPrvKey        = ed25519.GenPrivKey()
   338  		fooPubKey        = fooPrvKey.PubKey()
   339  		barPrvKey        = ed25519.GenPrivKey()
   340  		barPubKey        = barPrvKey.PubKey()
   341  	)
   342  
   343  	// Make connections from both sides in parallel.
   344  	var trs, ok = async.Parallel(
   345  		func(_ int) (val interface{}, abort bool, err error) {
   346  			fooSecConn, err = MakeSecretConnection(fooConn, fooPrvKey)
   347  			if err != nil {
   348  				tb.Errorf("failed to establish SecretConnection for foo: %v", err)
   349  				return nil, true, err
   350  			}
   351  			remotePubBytes := fooSecConn.RemotePubKey()
   352  			if !remotePubBytes.Equals(barPubKey) {
   353  				err = fmt.Errorf("unexpected fooSecConn.RemotePubKey.  Expected %v, got %v",
   354  					barPubKey, fooSecConn.RemotePubKey())
   355  				tb.Error(err)
   356  				return nil, true, err
   357  			}
   358  			return nil, false, nil
   359  		},
   360  		func(_ int) (val interface{}, abort bool, err error) {
   361  			barSecConn, err = MakeSecretConnection(barConn, barPrvKey)
   362  			if barSecConn == nil {
   363  				tb.Errorf("failed to establish SecretConnection for bar: %v", err)
   364  				return nil, true, err
   365  			}
   366  			remotePubBytes := barSecConn.RemotePubKey()
   367  			if !remotePubBytes.Equals(fooPubKey) {
   368  				err = fmt.Errorf("unexpected barSecConn.RemotePubKey.  Expected %v, got %v",
   369  					fooPubKey, barSecConn.RemotePubKey())
   370  				tb.Error(err)
   371  				return nil, true, err
   372  			}
   373  			return nil, false, nil
   374  		},
   375  	)
   376  
   377  	require.Nil(tb, trs.FirstError())
   378  	require.True(tb, ok, "Unexpected task abortion")
   379  
   380  	return fooSecConn, barSecConn
   381  }
   382  
   383  // Benchmarks
   384  
   385  func BenchmarkWriteSecretConnection(b *testing.B) {
   386  	b.StopTimer()
   387  	b.ReportAllocs()
   388  	fooSecConn, barSecConn := makeSecretConnPair(b)
   389  	randomMsgSizes := []int{
   390  		dataMaxSize / 10,
   391  		dataMaxSize / 3,
   392  		dataMaxSize / 2,
   393  		dataMaxSize,
   394  		dataMaxSize * 3 / 2,
   395  		dataMaxSize * 2,
   396  		dataMaxSize * 7 / 2,
   397  	}
   398  	fooWriteBytes := make([][]byte, 0, len(randomMsgSizes))
   399  	for _, size := range randomMsgSizes {
   400  		fooWriteBytes = append(fooWriteBytes, tmrand.Bytes(size))
   401  	}
   402  	// Consume reads from bar's reader
   403  	go func() {
   404  		readBuffer := make([]byte, dataMaxSize)
   405  		for {
   406  			_, err := barSecConn.Read(readBuffer)
   407  			if err == io.EOF {
   408  				return
   409  			} else if err != nil {
   410  				b.Errorf("failed to read from barSecConn: %v", err)
   411  				return
   412  			}
   413  		}
   414  	}()
   415  
   416  	b.StartTimer()
   417  	for i := 0; i < b.N; i++ {
   418  		idx := tmrand.Intn(len(fooWriteBytes))
   419  		_, err := fooSecConn.Write(fooWriteBytes[idx])
   420  		if err != nil {
   421  			b.Errorf("failed to write to fooSecConn: %v", err)
   422  			return
   423  		}
   424  	}
   425  	b.StopTimer()
   426  
   427  	if err := fooSecConn.Close(); err != nil {
   428  		b.Error(err)
   429  	}
   430  	// barSecConn.Close() race condition
   431  }
   432  
   433  func BenchmarkReadSecretConnection(b *testing.B) {
   434  	b.StopTimer()
   435  	b.ReportAllocs()
   436  	fooSecConn, barSecConn := makeSecretConnPair(b)
   437  	randomMsgSizes := []int{
   438  		dataMaxSize / 10,
   439  		dataMaxSize / 3,
   440  		dataMaxSize / 2,
   441  		dataMaxSize,
   442  		dataMaxSize * 3 / 2,
   443  		dataMaxSize * 2,
   444  		dataMaxSize * 7 / 2,
   445  	}
   446  	fooWriteBytes := make([][]byte, 0, len(randomMsgSizes))
   447  	for _, size := range randomMsgSizes {
   448  		fooWriteBytes = append(fooWriteBytes, tmrand.Bytes(size))
   449  	}
   450  	go func() {
   451  		for i := 0; i < b.N; i++ {
   452  			idx := tmrand.Intn(len(fooWriteBytes))
   453  			_, err := fooSecConn.Write(fooWriteBytes[idx])
   454  			if err != nil {
   455  				b.Errorf("failed to write to fooSecConn: %v, %v,%v", err, i, b.N)
   456  				return
   457  			}
   458  		}
   459  	}()
   460  
   461  	b.StartTimer()
   462  	for i := 0; i < b.N; i++ {
   463  		readBuffer := make([]byte, dataMaxSize)
   464  		_, err := barSecConn.Read(readBuffer)
   465  
   466  		if err == io.EOF {
   467  			return
   468  		} else if err != nil {
   469  			b.Fatalf("Failed to read from barSecConn: %v", err)
   470  		}
   471  	}
   472  	b.StopTimer()
   473  }