github.com/lazyledger/lazyledger-core@v0.35.0-dev.0.20210613111200-4c651f053571/p2p/conn/secret_connection_test.go (about)

     1  package conn
     2  
     3  import (
     4  	"bufio"
     5  	"encoding/hex"
     6  	"flag"
     7  	"fmt"
     8  	"io"
     9  	"io/ioutil"
    10  	"log"
    11  	mrand "math/rand"
    12  	"os"
    13  	"path/filepath"
    14  	"strconv"
    15  	"strings"
    16  	"sync"
    17  	"testing"
    18  
    19  	"github.com/stretchr/testify/assert"
    20  	"github.com/stretchr/testify/require"
    21  
    22  	"github.com/lazyledger/lazyledger-core/crypto"
    23  	"github.com/lazyledger/lazyledger-core/crypto/ed25519"
    24  	"github.com/lazyledger/lazyledger-core/crypto/sr25519"
    25  	"github.com/lazyledger/lazyledger-core/libs/async"
    26  	tmrand "github.com/lazyledger/lazyledger-core/libs/rand"
    27  )
    28  
    29  // Run go test -update from within this module
    30  // to update the golden test vector file
    31  var update = flag.Bool("update", false, "update .golden files")
    32  
    33  type kvstoreConn struct {
    34  	*io.PipeReader
    35  	*io.PipeWriter
    36  }
    37  
    38  func (drw kvstoreConn) Close() (err error) {
    39  	err2 := drw.PipeWriter.CloseWithError(io.EOF)
    40  	err1 := drw.PipeReader.Close()
    41  	if err2 != nil {
    42  		return err
    43  	}
    44  	return err1
    45  }
    46  
    47  type privKeyWithNilPubKey struct {
    48  	orig crypto.PrivKey
    49  }
    50  
    51  func (pk privKeyWithNilPubKey) Bytes() []byte                   { return pk.orig.Bytes() }
    52  func (pk privKeyWithNilPubKey) Sign(msg []byte) ([]byte, error) { return pk.orig.Sign(msg) }
    53  func (pk privKeyWithNilPubKey) PubKey() crypto.PubKey           { return nil }
    54  func (pk privKeyWithNilPubKey) Equals(pk2 crypto.PrivKey) bool  { return pk.orig.Equals(pk2) }
    55  func (pk privKeyWithNilPubKey) Type() string                    { return "privKeyWithNilPubKey" }
    56  
    57  func TestSecretConnectionHandshake(t *testing.T) {
    58  	fooSecConn, barSecConn := makeSecretConnPair(t)
    59  	if err := fooSecConn.Close(); err != nil {
    60  		t.Error(err)
    61  	}
    62  	if err := barSecConn.Close(); err != nil {
    63  		t.Error(err)
    64  	}
    65  }
    66  
    67  func TestConcurrentWrite(t *testing.T) {
    68  	fooSecConn, barSecConn := makeSecretConnPair(t)
    69  	fooWriteText := tmrand.Str(dataMaxSize)
    70  
    71  	// write from two routines.
    72  	// should be safe from race according to net.Conn:
    73  	// https://golang.org/pkg/net/#Conn
    74  	n := 100
    75  	wg := new(sync.WaitGroup)
    76  	wg.Add(3)
    77  	go writeLots(t, wg, fooSecConn, fooWriteText, n)
    78  	go writeLots(t, wg, fooSecConn, fooWriteText, n)
    79  
    80  	// Consume reads from bar's reader
    81  	readLots(t, wg, barSecConn, n*2)
    82  	wg.Wait()
    83  
    84  	if err := fooSecConn.Close(); err != nil {
    85  		t.Error(err)
    86  	}
    87  }
    88  
    89  func TestConcurrentRead(t *testing.T) {
    90  	fooSecConn, barSecConn := makeSecretConnPair(t)
    91  	fooWriteText := tmrand.Str(dataMaxSize)
    92  	n := 100
    93  
    94  	// read from two routines.
    95  	// should be safe from race according to net.Conn:
    96  	// https://golang.org/pkg/net/#Conn
    97  	wg := new(sync.WaitGroup)
    98  	wg.Add(3)
    99  	go readLots(t, wg, fooSecConn, n/2)
   100  	go readLots(t, wg, fooSecConn, n/2)
   101  
   102  	// write to bar
   103  	writeLots(t, wg, barSecConn, fooWriteText, n)
   104  	wg.Wait()
   105  
   106  	if err := fooSecConn.Close(); err != nil {
   107  		t.Error(err)
   108  	}
   109  }
   110  
   111  func TestSecretConnectionReadWrite(t *testing.T) {
   112  	fooConn, barConn := makeKVStoreConnPair()
   113  	fooWrites, barWrites := []string{}, []string{}
   114  	fooReads, barReads := []string{}, []string{}
   115  
   116  	// Pre-generate the things to write (for foo & bar)
   117  	for i := 0; i < 100; i++ {
   118  		fooWrites = append(fooWrites, tmrand.Str((mrand.Int()%(dataMaxSize*5))+1))
   119  		barWrites = append(barWrites, tmrand.Str((mrand.Int()%(dataMaxSize*5))+1))
   120  	}
   121  
   122  	// A helper that will run with (fooConn, fooWrites, fooReads) and vice versa
   123  	genNodeRunner := func(id string, nodeConn kvstoreConn, nodeWrites []string, nodeReads *[]string) async.Task {
   124  		return func(_ int) (interface{}, bool, error) {
   125  			// Initiate cryptographic private key and secret connection trhough nodeConn.
   126  			nodePrvKey := ed25519.GenPrivKey()
   127  			nodeSecretConn, err := MakeSecretConnection(nodeConn, nodePrvKey)
   128  			if err != nil {
   129  				t.Errorf("failed to establish SecretConnection for node: %v", err)
   130  				return nil, true, err
   131  			}
   132  			// In parallel, handle some reads and writes.
   133  			var trs, ok = async.Parallel(
   134  				func(_ int) (interface{}, bool, error) {
   135  					// Node writes:
   136  					for _, nodeWrite := range nodeWrites {
   137  						n, err := nodeSecretConn.Write([]byte(nodeWrite))
   138  						if err != nil {
   139  							t.Errorf("failed to write to nodeSecretConn: %v", err)
   140  							return nil, true, err
   141  						}
   142  						if n != len(nodeWrite) {
   143  							err = fmt.Errorf("failed to write all bytes. Expected %v, wrote %v", len(nodeWrite), n)
   144  							t.Error(err)
   145  							return nil, true, err
   146  						}
   147  					}
   148  					if err := nodeConn.PipeWriter.Close(); err != nil {
   149  						t.Error(err)
   150  						return nil, true, err
   151  					}
   152  					return nil, false, nil
   153  				},
   154  				func(_ int) (interface{}, bool, error) {
   155  					// Node reads:
   156  					readBuffer := make([]byte, dataMaxSize)
   157  					for {
   158  						n, err := nodeSecretConn.Read(readBuffer)
   159  						if err == io.EOF {
   160  							if err := nodeConn.PipeReader.Close(); err != nil {
   161  								t.Error(err)
   162  								return nil, true, err
   163  							}
   164  							return nil, false, nil
   165  						} else if err != nil {
   166  							t.Errorf("failed to read from nodeSecretConn: %v", err)
   167  							return nil, true, err
   168  						}
   169  						*nodeReads = append(*nodeReads, string(readBuffer[:n]))
   170  					}
   171  				},
   172  			)
   173  			assert.True(t, ok, "Unexpected task abortion")
   174  
   175  			// If error:
   176  			if trs.FirstError() != nil {
   177  				return nil, true, trs.FirstError()
   178  			}
   179  
   180  			// Otherwise:
   181  			return nil, false, nil
   182  		}
   183  	}
   184  
   185  	// Run foo & bar in parallel
   186  	var trs, ok = async.Parallel(
   187  		genNodeRunner("foo", fooConn, fooWrites, &fooReads),
   188  		genNodeRunner("bar", barConn, barWrites, &barReads),
   189  	)
   190  	require.Nil(t, trs.FirstError())
   191  	require.True(t, ok, "unexpected task abortion")
   192  
   193  	// A helper to ensure that the writes and reads match.
   194  	// Additionally, small writes (<= dataMaxSize) must be atomically read.
   195  	compareWritesReads := func(writes []string, reads []string) {
   196  		for {
   197  			// Pop next write & corresponding reads
   198  			var read, write string = "", writes[0]
   199  			var readCount = 0
   200  			for _, readChunk := range reads {
   201  				read += readChunk
   202  				readCount++
   203  				if len(write) <= len(read) {
   204  					break
   205  				}
   206  				if len(write) <= dataMaxSize {
   207  					break // atomicity of small writes
   208  				}
   209  			}
   210  			// Compare
   211  			if write != read {
   212  				t.Errorf("expected to read %X, got %X", write, read)
   213  			}
   214  			// Iterate
   215  			writes = writes[1:]
   216  			reads = reads[readCount:]
   217  			if len(writes) == 0 {
   218  				break
   219  			}
   220  		}
   221  	}
   222  
   223  	compareWritesReads(fooWrites, barReads)
   224  	compareWritesReads(barWrites, fooReads)
   225  }
   226  
   227  func TestDeriveSecretsAndChallengeGolden(t *testing.T) {
   228  	goldenFilepath := filepath.Join("testdata", t.Name()+".golden")
   229  	if *update {
   230  		t.Logf("Updating golden test vector file %s", goldenFilepath)
   231  		data := createGoldenTestVectors(t)
   232  		require.NoError(t, ioutil.WriteFile(goldenFilepath, []byte(data), 0644))
   233  	}
   234  	f, err := os.Open(goldenFilepath)
   235  	if err != nil {
   236  		log.Fatal(err)
   237  	}
   238  	t.Cleanup(closeAll(t, f))
   239  	scanner := bufio.NewScanner(f)
   240  	for scanner.Scan() {
   241  		line := scanner.Text()
   242  		params := strings.Split(line, ",")
   243  		randSecretVector, err := hex.DecodeString(params[0])
   244  		require.Nil(t, err)
   245  		randSecret := new([32]byte)
   246  		copy((*randSecret)[:], randSecretVector)
   247  		locIsLeast, err := strconv.ParseBool(params[1])
   248  		require.Nil(t, err)
   249  		expectedRecvSecret, err := hex.DecodeString(params[2])
   250  		require.Nil(t, err)
   251  		expectedSendSecret, err := hex.DecodeString(params[3])
   252  		require.Nil(t, err)
   253  
   254  		recvSecret, sendSecret := deriveSecrets(randSecret, locIsLeast)
   255  		require.Equal(t, expectedRecvSecret, (*recvSecret)[:], "Recv Secrets aren't equal")
   256  		require.Equal(t, expectedSendSecret, (*sendSecret)[:], "Send Secrets aren't equal")
   257  	}
   258  }
   259  
   260  func TestNilPubkey(t *testing.T) {
   261  	var fooConn, barConn = makeKVStoreConnPair()
   262  	t.Cleanup(closeAll(t, fooConn, barConn))
   263  	var fooPrvKey = ed25519.GenPrivKey()
   264  	var barPrvKey = privKeyWithNilPubKey{ed25519.GenPrivKey()}
   265  
   266  	go MakeSecretConnection(fooConn, fooPrvKey) //nolint:errcheck // ignore for tests
   267  
   268  	_, err := MakeSecretConnection(barConn, barPrvKey)
   269  	require.Error(t, err)
   270  	assert.Equal(t, "toproto: key type <nil> is not supported", err.Error())
   271  }
   272  
   273  func TestNonEd25519Pubkey(t *testing.T) {
   274  	var fooConn, barConn = makeKVStoreConnPair()
   275  	t.Cleanup(closeAll(t, fooConn, barConn))
   276  
   277  	var fooPrvKey = ed25519.GenPrivKey()
   278  	var barPrvKey = sr25519.GenPrivKey()
   279  
   280  	go MakeSecretConnection(fooConn, fooPrvKey) //nolint:errcheck // ignore for tests
   281  
   282  	_, err := MakeSecretConnection(barConn, barPrvKey)
   283  	require.Error(t, err)
   284  	assert.Contains(t, err.Error(), "is not supported")
   285  }
   286  
   287  func writeLots(t *testing.T, wg *sync.WaitGroup, conn io.Writer, txt string, n int) {
   288  	defer wg.Done()
   289  	for i := 0; i < n; i++ {
   290  		_, err := conn.Write([]byte(txt))
   291  		if err != nil {
   292  			t.Errorf("failed to write to fooSecConn: %v", err)
   293  			return
   294  		}
   295  	}
   296  }
   297  
   298  func readLots(t *testing.T, wg *sync.WaitGroup, conn io.Reader, n int) {
   299  	readBuffer := make([]byte, dataMaxSize)
   300  	for i := 0; i < n; i++ {
   301  		_, err := conn.Read(readBuffer)
   302  		assert.NoError(t, err)
   303  	}
   304  	wg.Done()
   305  }
   306  
   307  // Creates the data for a test vector file.
   308  // The file format is:
   309  // Hex(diffie_hellman_secret), loc_is_least, Hex(recvSecret), Hex(sendSecret), Hex(challenge)
   310  func createGoldenTestVectors(t *testing.T) string {
   311  	data := ""
   312  	for i := 0; i < 32; i++ {
   313  		randSecretVector := tmrand.Bytes(32)
   314  		randSecret := new([32]byte)
   315  		copy((*randSecret)[:], randSecretVector)
   316  		data += hex.EncodeToString((*randSecret)[:]) + ","
   317  		locIsLeast := mrand.Int63()%2 == 0
   318  		data += strconv.FormatBool(locIsLeast) + ","
   319  		recvSecret, sendSecret := deriveSecrets(randSecret, locIsLeast)
   320  		data += hex.EncodeToString((*recvSecret)[:]) + ","
   321  		data += hex.EncodeToString((*sendSecret)[:]) + ","
   322  	}
   323  	return data
   324  }
   325  
   326  // Each returned ReadWriteCloser is akin to a net.Connection
   327  func makeKVStoreConnPair() (fooConn, barConn kvstoreConn) {
   328  	barReader, fooWriter := io.Pipe()
   329  	fooReader, barWriter := io.Pipe()
   330  	return kvstoreConn{fooReader, fooWriter}, kvstoreConn{barReader, barWriter}
   331  }
   332  
   333  func makeSecretConnPair(tb testing.TB) (fooSecConn, barSecConn *SecretConnection) {
   334  	var (
   335  		fooConn, barConn = makeKVStoreConnPair()
   336  		fooPrvKey        = ed25519.GenPrivKey()
   337  		fooPubKey        = fooPrvKey.PubKey()
   338  		barPrvKey        = ed25519.GenPrivKey()
   339  		barPubKey        = barPrvKey.PubKey()
   340  	)
   341  
   342  	// Make connections from both sides in parallel.
   343  	var trs, ok = async.Parallel(
   344  		func(_ int) (val interface{}, abort bool, err error) {
   345  			fooSecConn, err = MakeSecretConnection(fooConn, fooPrvKey)
   346  			if err != nil {
   347  				tb.Errorf("failed to establish SecretConnection for foo: %v", err)
   348  				return nil, true, err
   349  			}
   350  			remotePubBytes := fooSecConn.RemotePubKey()
   351  			if !remotePubBytes.Equals(barPubKey) {
   352  				err = fmt.Errorf("unexpected fooSecConn.RemotePubKey.  Expected %v, got %v",
   353  					barPubKey, fooSecConn.RemotePubKey())
   354  				tb.Error(err)
   355  				return nil, true, err
   356  			}
   357  			return nil, false, nil
   358  		},
   359  		func(_ int) (val interface{}, abort bool, err error) {
   360  			barSecConn, err = MakeSecretConnection(barConn, barPrvKey)
   361  			if barSecConn == nil {
   362  				tb.Errorf("failed to establish SecretConnection for bar: %v", err)
   363  				return nil, true, err
   364  			}
   365  			remotePubBytes := barSecConn.RemotePubKey()
   366  			if !remotePubBytes.Equals(fooPubKey) {
   367  				err = fmt.Errorf("unexpected barSecConn.RemotePubKey.  Expected %v, got %v",
   368  					fooPubKey, barSecConn.RemotePubKey())
   369  				tb.Error(err)
   370  				return nil, true, err
   371  			}
   372  			return nil, false, nil
   373  		},
   374  	)
   375  
   376  	require.Nil(tb, trs.FirstError())
   377  	require.True(tb, ok, "Unexpected task abortion")
   378  
   379  	return fooSecConn, barSecConn
   380  }
   381  
   382  // Benchmarks
   383  
   384  func BenchmarkWriteSecretConnection(b *testing.B) {
   385  	b.StopTimer()
   386  	b.ReportAllocs()
   387  	fooSecConn, barSecConn := makeSecretConnPair(b)
   388  	randomMsgSizes := []int{
   389  		dataMaxSize / 10,
   390  		dataMaxSize / 3,
   391  		dataMaxSize / 2,
   392  		dataMaxSize,
   393  		dataMaxSize * 3 / 2,
   394  		dataMaxSize * 2,
   395  		dataMaxSize * 7 / 2,
   396  	}
   397  	fooWriteBytes := make([][]byte, 0, len(randomMsgSizes))
   398  	for _, size := range randomMsgSizes {
   399  		fooWriteBytes = append(fooWriteBytes, tmrand.Bytes(size))
   400  	}
   401  	// Consume reads from bar's reader
   402  	go func() {
   403  		readBuffer := make([]byte, dataMaxSize)
   404  		for {
   405  			_, err := barSecConn.Read(readBuffer)
   406  			if err == io.EOF {
   407  				return
   408  			} else if err != nil {
   409  				b.Errorf("failed to read from barSecConn: %v", err)
   410  				return
   411  			}
   412  		}
   413  	}()
   414  
   415  	b.StartTimer()
   416  	for i := 0; i < b.N; i++ {
   417  		idx := mrand.Intn(len(fooWriteBytes))
   418  		_, err := fooSecConn.Write(fooWriteBytes[idx])
   419  		if err != nil {
   420  			b.Errorf("failed to write to fooSecConn: %v", err)
   421  			return
   422  		}
   423  	}
   424  	b.StopTimer()
   425  
   426  	if err := fooSecConn.Close(); err != nil {
   427  		b.Error(err)
   428  	}
   429  	// barSecConn.Close() race condition
   430  }
   431  
   432  func BenchmarkReadSecretConnection(b *testing.B) {
   433  	b.StopTimer()
   434  	b.ReportAllocs()
   435  	fooSecConn, barSecConn := makeSecretConnPair(b)
   436  	randomMsgSizes := []int{
   437  		dataMaxSize / 10,
   438  		dataMaxSize / 3,
   439  		dataMaxSize / 2,
   440  		dataMaxSize,
   441  		dataMaxSize * 3 / 2,
   442  		dataMaxSize * 2,
   443  		dataMaxSize * 7 / 2,
   444  	}
   445  	fooWriteBytes := make([][]byte, 0, len(randomMsgSizes))
   446  	for _, size := range randomMsgSizes {
   447  		fooWriteBytes = append(fooWriteBytes, tmrand.Bytes(size))
   448  	}
   449  	go func() {
   450  		for i := 0; i < b.N; i++ {
   451  			idx := mrand.Intn(len(fooWriteBytes))
   452  			_, err := fooSecConn.Write(fooWriteBytes[idx])
   453  			if err != nil {
   454  				b.Errorf("failed to write to fooSecConn: %v, %v,%v", err, i, b.N)
   455  				return
   456  			}
   457  		}
   458  	}()
   459  
   460  	b.StartTimer()
   461  	for i := 0; i < b.N; i++ {
   462  		readBuffer := make([]byte, dataMaxSize)
   463  		_, err := barSecConn.Read(readBuffer)
   464  
   465  		if err == io.EOF {
   466  			return
   467  		} else if err != nil {
   468  			b.Fatalf("Failed to read from barSecConn: %v", err)
   469  		}
   470  	}
   471  	b.StopTimer()
   472  }