github.com/vipernet-xyz/tendermint-core@v0.32.0/p2p/conn/secret_connection_test.go (about)

     1  package conn
     2  
     3  import (
     4  	"bufio"
     5  	"encoding/hex"
     6  	"flag"
     7  	"fmt"
     8  	"io"
     9  	"log"
    10  	"os"
    11  	"path/filepath"
    12  	"strconv"
    13  	"strings"
    14  	"sync"
    15  	"testing"
    16  
    17  	"github.com/stretchr/testify/assert"
    18  	"github.com/stretchr/testify/require"
    19  
    20  	"github.com/tendermint/tendermint/crypto"
    21  	"github.com/tendermint/tendermint/crypto/ed25519"
    22  	"github.com/tendermint/tendermint/crypto/secp256k1"
    23  	"github.com/tendermint/tendermint/libs/async"
    24  	tmos "github.com/tendermint/tendermint/libs/os"
    25  	tmrand "github.com/tendermint/tendermint/libs/rand"
    26  )
    27  
    28  type kvstoreConn struct {
    29  	*io.PipeReader
    30  	*io.PipeWriter
    31  }
    32  
    33  func (drw kvstoreConn) Close() (err error) {
    34  	err2 := drw.PipeWriter.CloseWithError(io.EOF)
    35  	err1 := drw.PipeReader.Close()
    36  	if err2 != nil {
    37  		return err
    38  	}
    39  	return err1
    40  }
    41  
    42  // Each returned ReadWriteCloser is akin to a net.Connection
    43  func makeKVStoreConnPair() (fooConn, barConn kvstoreConn) {
    44  	barReader, fooWriter := io.Pipe()
    45  	fooReader, barWriter := io.Pipe()
    46  	return kvstoreConn{fooReader, fooWriter}, kvstoreConn{barReader, barWriter}
    47  }
    48  
    49  func makeSecretConnPair(tb testing.TB) (fooSecConn, barSecConn *SecretConnection) {
    50  
    51  	var fooConn, barConn = makeKVStoreConnPair()
    52  	var fooPrvKey = ed25519.GenPrivKey()
    53  	var fooPubKey = fooPrvKey.PubKey()
    54  	var barPrvKey = ed25519.GenPrivKey()
    55  	var barPubKey = barPrvKey.PubKey()
    56  
    57  	// Make connections from both sides in parallel.
    58  	var trs, ok = async.Parallel(
    59  		func(_ int) (val interface{}, abort bool, err error) {
    60  			fooSecConn, err = MakeSecretConnection(fooConn, fooPrvKey)
    61  			if err != nil {
    62  				tb.Errorf("failed to establish SecretConnection for foo: %v", err)
    63  				return nil, true, err
    64  			}
    65  			remotePubBytes := fooSecConn.RemotePubKey()
    66  			if !remotePubBytes.Equals(barPubKey) {
    67  				err = fmt.Errorf("unexpected fooSecConn.RemotePubKey.  Expected %v, got %v",
    68  					barPubKey, fooSecConn.RemotePubKey())
    69  				tb.Error(err)
    70  				return nil, false, err
    71  			}
    72  			return nil, false, nil
    73  		},
    74  		func(_ int) (val interface{}, abort bool, err error) {
    75  			barSecConn, err = MakeSecretConnection(barConn, barPrvKey)
    76  			if barSecConn == nil {
    77  				tb.Errorf("failed to establish SecretConnection for bar: %v", err)
    78  				return nil, true, err
    79  			}
    80  			remotePubBytes := barSecConn.RemotePubKey()
    81  			if !remotePubBytes.Equals(fooPubKey) {
    82  				err = fmt.Errorf("unexpected barSecConn.RemotePubKey.  Expected %v, got %v",
    83  					fooPubKey, barSecConn.RemotePubKey())
    84  				tb.Error(err)
    85  				return nil, false, nil
    86  			}
    87  			return nil, false, nil
    88  		},
    89  	)
    90  
    91  	require.Nil(tb, trs.FirstError())
    92  	require.True(tb, ok, "Unexpected task abortion")
    93  
    94  	return fooSecConn, barSecConn
    95  }
    96  
    97  func TestSecretConnectionHandshake(t *testing.T) {
    98  	fooSecConn, barSecConn := makeSecretConnPair(t)
    99  	if err := fooSecConn.Close(); err != nil {
   100  		t.Error(err)
   101  	}
   102  	if err := barSecConn.Close(); err != nil {
   103  		t.Error(err)
   104  	}
   105  }
   106  
   107  func TestConcurrentWrite(t *testing.T) {
   108  	fooSecConn, barSecConn := makeSecretConnPair(t)
   109  	fooWriteText := tmrand.Str(dataMaxSize)
   110  
   111  	// write from two routines.
   112  	// should be safe from race according to net.Conn:
   113  	// https://golang.org/pkg/net/#Conn
   114  	n := 100
   115  	wg := new(sync.WaitGroup)
   116  	wg.Add(3)
   117  	go writeLots(t, wg, fooSecConn, fooWriteText, n)
   118  	go writeLots(t, wg, fooSecConn, fooWriteText, n)
   119  
   120  	// Consume reads from bar's reader
   121  	readLots(t, wg, barSecConn, n*2)
   122  	wg.Wait()
   123  
   124  	if err := fooSecConn.Close(); err != nil {
   125  		t.Error(err)
   126  	}
   127  }
   128  
   129  func TestConcurrentRead(t *testing.T) {
   130  	fooSecConn, barSecConn := makeSecretConnPair(t)
   131  	fooWriteText := tmrand.Str(dataMaxSize)
   132  	n := 100
   133  
   134  	// read from two routines.
   135  	// should be safe from race according to net.Conn:
   136  	// https://golang.org/pkg/net/#Conn
   137  	wg := new(sync.WaitGroup)
   138  	wg.Add(3)
   139  	go readLots(t, wg, fooSecConn, n/2)
   140  	go readLots(t, wg, fooSecConn, n/2)
   141  
   142  	// write to bar
   143  	writeLots(t, wg, barSecConn, fooWriteText, n)
   144  	wg.Wait()
   145  
   146  	if err := fooSecConn.Close(); err != nil {
   147  		t.Error(err)
   148  	}
   149  }
   150  
   151  func writeLots(t *testing.T, wg *sync.WaitGroup, conn io.Writer, txt string, n int) {
   152  	defer wg.Done()
   153  	for i := 0; i < n; i++ {
   154  		_, err := conn.Write([]byte(txt))
   155  		if err != nil {
   156  			t.Errorf("failed to write to fooSecConn: %v", err)
   157  			return
   158  		}
   159  	}
   160  }
   161  
   162  func readLots(t *testing.T, wg *sync.WaitGroup, conn io.Reader, n int) {
   163  	readBuffer := make([]byte, dataMaxSize)
   164  	for i := 0; i < n; i++ {
   165  		_, err := conn.Read(readBuffer)
   166  		assert.NoError(t, err)
   167  	}
   168  	wg.Done()
   169  }
   170  
   171  func TestSecretConnectionReadWrite(t *testing.T) {
   172  	fooConn, barConn := makeKVStoreConnPair()
   173  	fooWrites, barWrites := []string{}, []string{}
   174  	fooReads, barReads := []string{}, []string{}
   175  
   176  	// Pre-generate the things to write (for foo & bar)
   177  	for i := 0; i < 100; i++ {
   178  		fooWrites = append(fooWrites, tmrand.Str((tmrand.Int()%(dataMaxSize*5))+1))
   179  		barWrites = append(barWrites, tmrand.Str((tmrand.Int()%(dataMaxSize*5))+1))
   180  	}
   181  
   182  	// A helper that will run with (fooConn, fooWrites, fooReads) and vice versa
   183  	genNodeRunner := func(id string, nodeConn kvstoreConn, nodeWrites []string, nodeReads *[]string) async.Task {
   184  		return func(_ int) (interface{}, bool, error) {
   185  			// Initiate cryptographic private key and secret connection trhough nodeConn.
   186  			nodePrvKey := ed25519.GenPrivKey()
   187  			nodeSecretConn, err := MakeSecretConnection(nodeConn, nodePrvKey)
   188  			if err != nil {
   189  				t.Errorf("failed to establish SecretConnection for node: %v", err)
   190  				return nil, true, err
   191  			}
   192  			// In parallel, handle some reads and writes.
   193  			var trs, ok = async.Parallel(
   194  				func(_ int) (interface{}, bool, error) {
   195  					// Node writes:
   196  					for _, nodeWrite := range nodeWrites {
   197  						n, err := nodeSecretConn.Write([]byte(nodeWrite))
   198  						if err != nil {
   199  							t.Errorf("failed to write to nodeSecretConn: %v", err)
   200  							return nil, true, err
   201  						}
   202  						if n != len(nodeWrite) {
   203  							err = fmt.Errorf("failed to write all bytes. Expected %v, wrote %v", len(nodeWrite), n)
   204  							t.Error(err)
   205  							return nil, true, err
   206  						}
   207  					}
   208  					if err := nodeConn.PipeWriter.Close(); err != nil {
   209  						t.Error(err)
   210  						return nil, true, err
   211  					}
   212  					return nil, false, nil
   213  				},
   214  				func(_ int) (interface{}, bool, error) {
   215  					// Node reads:
   216  					readBuffer := make([]byte, dataMaxSize)
   217  					for {
   218  						n, err := nodeSecretConn.Read(readBuffer)
   219  						if err == io.EOF {
   220  							if err := nodeConn.PipeReader.Close(); err != nil {
   221  								t.Error(err)
   222  								return nil, true, err
   223  							}
   224  							return nil, false, nil
   225  						} else if err != nil {
   226  							t.Errorf("failed to read from nodeSecretConn: %v", err)
   227  							return nil, true, err
   228  						}
   229  						*nodeReads = append(*nodeReads, string(readBuffer[:n]))
   230  					}
   231  				},
   232  			)
   233  			assert.True(t, ok, "Unexpected task abortion")
   234  
   235  			// If error:
   236  			if trs.FirstError() != nil {
   237  				return nil, true, trs.FirstError()
   238  			}
   239  
   240  			// Otherwise:
   241  			return nil, false, nil
   242  		}
   243  	}
   244  
   245  	// Run foo & bar in parallel
   246  	var trs, ok = async.Parallel(
   247  		genNodeRunner("foo", fooConn, fooWrites, &fooReads),
   248  		genNodeRunner("bar", barConn, barWrites, &barReads),
   249  	)
   250  	require.Nil(t, trs.FirstError())
   251  	require.True(t, ok, "unexpected task abortion")
   252  
   253  	// A helper to ensure that the writes and reads match.
   254  	// Additionally, small writes (<= dataMaxSize) must be atomically read.
   255  	compareWritesReads := func(writes []string, reads []string) {
   256  		for {
   257  			// Pop next write & corresponding reads
   258  			var read, write string = "", writes[0]
   259  			var readCount = 0
   260  			for _, readChunk := range reads {
   261  				read += readChunk
   262  				readCount++
   263  				if len(write) <= len(read) {
   264  					break
   265  				}
   266  				if len(write) <= dataMaxSize {
   267  					break // atomicity of small writes
   268  				}
   269  			}
   270  			// Compare
   271  			if write != read {
   272  				t.Errorf("expected to read %X, got %X", write, read)
   273  			}
   274  			// Iterate
   275  			writes = writes[1:]
   276  			reads = reads[readCount:]
   277  			if len(writes) == 0 {
   278  				break
   279  			}
   280  		}
   281  	}
   282  
   283  	compareWritesReads(fooWrites, barReads)
   284  	compareWritesReads(barWrites, fooReads)
   285  
   286  }
   287  
   288  // Run go test -update from within this module
   289  // to update the golden test vector file
   290  var update = flag.Bool("update", false, "update .golden files")
   291  
   292  func TestDeriveSecretsAndChallengeGolden(t *testing.T) {
   293  	goldenFilepath := filepath.Join("testdata", t.Name()+".golden")
   294  	if *update {
   295  		t.Logf("Updating golden test vector file %s", goldenFilepath)
   296  		data := createGoldenTestVectors(t)
   297  		tmos.WriteFile(goldenFilepath, []byte(data), 0644)
   298  	}
   299  	f, err := os.Open(goldenFilepath)
   300  	if err != nil {
   301  		log.Fatal(err)
   302  	}
   303  	defer f.Close()
   304  	scanner := bufio.NewScanner(f)
   305  	for scanner.Scan() {
   306  		line := scanner.Text()
   307  		params := strings.Split(line, ",")
   308  		randSecretVector, err := hex.DecodeString(params[0])
   309  		require.Nil(t, err)
   310  		randSecret := new([32]byte)
   311  		copy((*randSecret)[:], randSecretVector)
   312  		locIsLeast, err := strconv.ParseBool(params[1])
   313  		require.Nil(t, err)
   314  		expectedRecvSecret, err := hex.DecodeString(params[2])
   315  		require.Nil(t, err)
   316  		expectedSendSecret, err := hex.DecodeString(params[3])
   317  		require.Nil(t, err)
   318  		expectedChallenge, err := hex.DecodeString(params[4])
   319  		require.Nil(t, err)
   320  
   321  		recvSecret, sendSecret, challenge := deriveSecretAndChallenge(randSecret, locIsLeast)
   322  		require.Equal(t, expectedRecvSecret, (*recvSecret)[:], "Recv Secrets aren't equal")
   323  		require.Equal(t, expectedSendSecret, (*sendSecret)[:], "Send Secrets aren't equal")
   324  		require.Equal(t, expectedChallenge, (*challenge)[:], "challenges aren't equal")
   325  	}
   326  }
   327  
   328  type privKeyWithNilPubKey struct {
   329  	orig crypto.PrivKey
   330  }
   331  
   332  func (pk privKeyWithNilPubKey) Bytes() []byte                   { return pk.orig.Bytes() }
   333  func (pk privKeyWithNilPubKey) Sign(msg []byte) ([]byte, error) { return pk.orig.Sign(msg) }
   334  func (pk privKeyWithNilPubKey) PubKey() crypto.PubKey           { return nil }
   335  func (pk privKeyWithNilPubKey) Equals(pk2 crypto.PrivKey) bool  { return pk.orig.Equals(pk2) }
   336  
   337  func TestNilPubkey(t *testing.T) {
   338  	var fooConn, barConn = makeKVStoreConnPair()
   339  	var fooPrvKey = ed25519.GenPrivKey()
   340  	var barPrvKey = privKeyWithNilPubKey{ed25519.GenPrivKey()}
   341  
   342  	go func() {
   343  		_, err := MakeSecretConnection(barConn, barPrvKey)
   344  		assert.NoError(t, err)
   345  	}()
   346  
   347  	assert.NotPanics(t, func() {
   348  		_, err := MakeSecretConnection(fooConn, fooPrvKey)
   349  		if assert.Error(t, err) {
   350  			assert.Equal(t, "expected ed25519 pubkey, got <nil>", err.Error())
   351  		}
   352  	})
   353  }
   354  
   355  func TestNonEd25519Pubkey(t *testing.T) {
   356  	var fooConn, barConn = makeKVStoreConnPair()
   357  	var fooPrvKey = ed25519.GenPrivKey()
   358  	var barPrvKey = secp256k1.GenPrivKey()
   359  
   360  	go func() {
   361  		_, err := MakeSecretConnection(barConn, barPrvKey)
   362  		assert.NoError(t, err)
   363  	}()
   364  
   365  	assert.NotPanics(t, func() {
   366  		_, err := MakeSecretConnection(fooConn, fooPrvKey)
   367  		if assert.Error(t, err) {
   368  			assert.Equal(t, "expected ed25519 pubkey, got secp256k1.PubKeySecp256k1", err.Error())
   369  		}
   370  	})
   371  }
   372  
   373  // Creates the data for a test vector file.
   374  // The file format is:
   375  // Hex(diffie_hellman_secret), loc_is_least, Hex(recvSecret), Hex(sendSecret), Hex(challenge)
   376  func createGoldenTestVectors(t *testing.T) string {
   377  	data := ""
   378  	for i := 0; i < 32; i++ {
   379  		randSecretVector := tmrand.Bytes(32)
   380  		randSecret := new([32]byte)
   381  		copy((*randSecret)[:], randSecretVector)
   382  		data += hex.EncodeToString((*randSecret)[:]) + ","
   383  		locIsLeast := tmrand.Bool()
   384  		data += strconv.FormatBool(locIsLeast) + ","
   385  		recvSecret, sendSecret, challenge := deriveSecretAndChallenge(randSecret, locIsLeast)
   386  		data += hex.EncodeToString((*recvSecret)[:]) + ","
   387  		data += hex.EncodeToString((*sendSecret)[:]) + ","
   388  		data += hex.EncodeToString((*challenge)[:]) + "\n"
   389  	}
   390  	return data
   391  }
   392  
   393  func BenchmarkWriteSecretConnection(b *testing.B) {
   394  	b.StopTimer()
   395  	b.ReportAllocs()
   396  	fooSecConn, barSecConn := makeSecretConnPair(b)
   397  	randomMsgSizes := []int{
   398  		dataMaxSize / 10,
   399  		dataMaxSize / 3,
   400  		dataMaxSize / 2,
   401  		dataMaxSize,
   402  		dataMaxSize * 3 / 2,
   403  		dataMaxSize * 2,
   404  		dataMaxSize * 7 / 2,
   405  	}
   406  	fooWriteBytes := make([][]byte, 0, len(randomMsgSizes))
   407  	for _, size := range randomMsgSizes {
   408  		fooWriteBytes = append(fooWriteBytes, tmrand.Bytes(size))
   409  	}
   410  	// Consume reads from bar's reader
   411  	go func() {
   412  		readBuffer := make([]byte, dataMaxSize)
   413  		for {
   414  			_, err := barSecConn.Read(readBuffer)
   415  			if err == io.EOF {
   416  				return
   417  			} else if err != nil {
   418  				b.Errorf("failed to read from barSecConn: %v", err)
   419  				return
   420  			}
   421  		}
   422  	}()
   423  
   424  	b.StartTimer()
   425  	for i := 0; i < b.N; i++ {
   426  		idx := tmrand.Intn(len(fooWriteBytes))
   427  		_, err := fooSecConn.Write(fooWriteBytes[idx])
   428  		if err != nil {
   429  			b.Errorf("failed to write to fooSecConn: %v", err)
   430  			return
   431  		}
   432  	}
   433  	b.StopTimer()
   434  
   435  	if err := fooSecConn.Close(); err != nil {
   436  		b.Error(err)
   437  	}
   438  	//barSecConn.Close() race condition
   439  }
   440  
   441  func BenchmarkReadSecretConnection(b *testing.B) {
   442  	b.StopTimer()
   443  	b.ReportAllocs()
   444  	fooSecConn, barSecConn := makeSecretConnPair(b)
   445  	randomMsgSizes := []int{
   446  		dataMaxSize / 10,
   447  		dataMaxSize / 3,
   448  		dataMaxSize / 2,
   449  		dataMaxSize,
   450  		dataMaxSize * 3 / 2,
   451  		dataMaxSize * 2,
   452  		dataMaxSize * 7 / 2,
   453  	}
   454  	fooWriteBytes := make([][]byte, 0, len(randomMsgSizes))
   455  	for _, size := range randomMsgSizes {
   456  		fooWriteBytes = append(fooWriteBytes, tmrand.Bytes(size))
   457  	}
   458  	go func() {
   459  		for i := 0; i < b.N; i++ {
   460  			idx := tmrand.Intn(len(fooWriteBytes))
   461  			_, err := fooSecConn.Write(fooWriteBytes[idx])
   462  			if err != nil {
   463  				b.Errorf("failed to write to fooSecConn: %v, %v,%v", err, i, b.N)
   464  				return
   465  			}
   466  		}
   467  	}()
   468  
   469  	b.StartTimer()
   470  	for i := 0; i < b.N; i++ {
   471  		readBuffer := make([]byte, dataMaxSize)
   472  		_, err := barSecConn.Read(readBuffer)
   473  
   474  		if err == io.EOF {
   475  			return
   476  		} else if err != nil {
   477  			b.Fatalf("Failed to read from barSecConn: %v", err)
   478  		}
   479  	}
   480  	b.StopTimer()
   481  }