github.com/vipernet-xyz/tm@v0.34.24/p2p/conn/secret_connection_test.go (about)

     1  package conn
     2  
     3  import (
     4  	"bufio"
     5  	"encoding/hex"
     6  	"flag"
     7  	"fmt"
     8  	"io"
     9  	"log"
    10  	"os"
    11  	"path/filepath"
    12  	"strconv"
    13  	"strings"
    14  	"sync"
    15  	"testing"
    16  
    17  	"github.com/stretchr/testify/assert"
    18  	"github.com/stretchr/testify/require"
    19  
    20  	"github.com/vipernet-xyz/tm/crypto"
    21  	"github.com/vipernet-xyz/tm/crypto/ed25519"
    22  	"github.com/vipernet-xyz/tm/crypto/sr25519"
    23  	"github.com/vipernet-xyz/tm/libs/async"
    24  	tmos "github.com/vipernet-xyz/tm/libs/os"
    25  	tmrand "github.com/vipernet-xyz/tm/libs/rand"
    26  )
    27  
// update is the -update test flag. Run `go test -update` from within this
// module to regenerate the golden test vector file before it is verified.
var update = flag.Bool("update", false, "update .golden files")
    31  
    32  type kvstoreConn struct {
    33  	*io.PipeReader
    34  	*io.PipeWriter
    35  }
    36  
    37  func (drw kvstoreConn) Close() (err error) {
    38  	err2 := drw.PipeWriter.CloseWithError(io.EOF)
    39  	err1 := drw.PipeReader.Close()
    40  	if err2 != nil {
    41  		return err
    42  	}
    43  	return err1
    44  }
    45  
    46  type privKeyWithNilPubKey struct {
    47  	orig crypto.PrivKey
    48  }
    49  
    50  func (pk privKeyWithNilPubKey) Bytes() []byte                   { return pk.orig.Bytes() }
    51  func (pk privKeyWithNilPubKey) Sign(msg []byte) ([]byte, error) { return pk.orig.Sign(msg) }
    52  func (pk privKeyWithNilPubKey) PubKey() crypto.PubKey           { return nil }
    53  func (pk privKeyWithNilPubKey) Equals(pk2 crypto.PrivKey) bool  { return pk.orig.Equals(pk2) }
    54  func (pk privKeyWithNilPubKey) Type() string                    { return "privKeyWithNilPubKey" }
    55  
    56  func TestSecretConnectionHandshake(t *testing.T) {
    57  	fooSecConn, barSecConn := makeSecretConnPair(t)
    58  	if err := fooSecConn.Close(); err != nil {
    59  		t.Error(err)
    60  	}
    61  	if err := barSecConn.Close(); err != nil {
    62  		t.Error(err)
    63  	}
    64  }
    65  
    66  func TestConcurrentWrite(t *testing.T) {
    67  	fooSecConn, barSecConn := makeSecretConnPair(t)
    68  	fooWriteText := tmrand.Str(dataMaxSize)
    69  
    70  	// write from two routines.
    71  	// should be safe from race according to net.Conn:
    72  	// https://golang.org/pkg/net/#Conn
    73  	n := 100
    74  	wg := new(sync.WaitGroup)
    75  	wg.Add(3)
    76  	go writeLots(t, wg, fooSecConn, fooWriteText, n)
    77  	go writeLots(t, wg, fooSecConn, fooWriteText, n)
    78  
    79  	// Consume reads from bar's reader
    80  	readLots(t, wg, barSecConn, n*2)
    81  	wg.Wait()
    82  
    83  	if err := fooSecConn.Close(); err != nil {
    84  		t.Error(err)
    85  	}
    86  }
    87  
    88  func TestConcurrentRead(t *testing.T) {
    89  	fooSecConn, barSecConn := makeSecretConnPair(t)
    90  	fooWriteText := tmrand.Str(dataMaxSize)
    91  	n := 100
    92  
    93  	// read from two routines.
    94  	// should be safe from race according to net.Conn:
    95  	// https://golang.org/pkg/net/#Conn
    96  	wg := new(sync.WaitGroup)
    97  	wg.Add(3)
    98  	go readLots(t, wg, fooSecConn, n/2)
    99  	go readLots(t, wg, fooSecConn, n/2)
   100  
   101  	// write to bar
   102  	writeLots(t, wg, barSecConn, fooWriteText, n)
   103  	wg.Wait()
   104  
   105  	if err := fooSecConn.Close(); err != nil {
   106  		t.Error(err)
   107  	}
   108  }
   109  
// TestSecretConnectionReadWrite runs two nodes, foo and bar, over a pair of
// in-memory connections. Each node establishes a secret connection and then,
// in parallel, writes its pre-generated messages and reads until io.EOF.
// Afterwards each node's writes are compared against what the peer read.
func TestSecretConnectionReadWrite(t *testing.T) {
	fooConn, barConn := makeKVStoreConnPair()
	fooWrites, barWrites := []string{}, []string{}
	fooReads, barReads := []string{}, []string{}

	// Pre-generate the things to write (for foo & bar)
	// Each message has a random length in [1, dataMaxSize*5].
	for i := 0; i < 100; i++ {
		fooWrites = append(fooWrites, tmrand.Str((tmrand.Int()%(dataMaxSize*5))+1))
		barWrites = append(barWrites, tmrand.Str((tmrand.Int()%(dataMaxSize*5))+1))
	}

	// A helper that will run with (fooConn, fooWrites, fooReads) and vice versa
	// Note: the id parameter is not referenced inside the returned task.
	genNodeRunner := func(id string, nodeConn kvstoreConn, nodeWrites []string, nodeReads *[]string) async.Task {
		return func(_ int) (interface{}, bool, error) {
			// Initiate cryptographic private key and secret connection through nodeConn.
			nodePrvKey := ed25519.GenPrivKey()
			nodeSecretConn, err := MakeSecretConnection(nodeConn, nodePrvKey)
			if err != nil {
				t.Errorf("failed to establish SecretConnection for node: %v", err)
				return nil, true, err
			}
			// In parallel, handle some reads and writes.
			var trs, ok = async.Parallel(
				func(_ int) (interface{}, bool, error) {
					// Node writes:
					for _, nodeWrite := range nodeWrites {
						n, err := nodeSecretConn.Write([]byte(nodeWrite))
						if err != nil {
							t.Errorf("failed to write to nodeSecretConn: %v", err)
							return nil, true, err
						}
						if n != len(nodeWrite) {
							err = fmt.Errorf("failed to write all bytes. Expected %v, wrote %v", len(nodeWrite), n)
							t.Error(err)
							return nil, true, err
						}
					}
					// All messages are written; close only the write half of
					// the underlying pipe (presumably this is what lets the
					// peer's read loop see io.EOF — confirm against
					// SecretConnection.Read).
					if err := nodeConn.PipeWriter.Close(); err != nil {
						t.Error(err)
						return nil, true, err
					}
					return nil, false, nil
				},
				func(_ int) (interface{}, bool, error) {
					// Node reads:
					readBuffer := make([]byte, dataMaxSize)
					for {
						n, err := nodeSecretConn.Read(readBuffer)
						if err == io.EOF {
							// End of stream: close the read half and finish
							// this task successfully.
							if err := nodeConn.PipeReader.Close(); err != nil {
								t.Error(err)
								return nil, true, err
							}
							return nil, false, nil
						} else if err != nil {
							t.Errorf("failed to read from nodeSecretConn: %v", err)
							return nil, true, err
						}
						// Record each read as its own chunk; the comparison
						// below depends on the chunk boundaries.
						*nodeReads = append(*nodeReads, string(readBuffer[:n]))
					}
				},
			)
			assert.True(t, ok, "Unexpected task abortion")

			// If error:
			if trs.FirstError() != nil {
				return nil, true, trs.FirstError()
			}

			// Otherwise:
			return nil, false, nil
		}
	}

	// Run foo & bar in parallel
	var trs, ok = async.Parallel(
		genNodeRunner("foo", fooConn, fooWrites, &fooReads),
		genNodeRunner("bar", barConn, barWrites, &barReads),
	)
	require.Nil(t, trs.FirstError())
	require.True(t, ok, "unexpected task abortion")

	// A helper to ensure that the writes and reads match.
	// Additionally, small writes (<= dataMaxSize) must be atomically read.
	// It indexes writes[0] unconditionally, so writes must be non-empty.
	compareWritesReads := func(writes []string, reads []string) {
		for {
			// Pop next write & corresponding reads
			var read = ""
			var write = writes[0]
			var readCount = 0
			for _, readChunk := range reads {
				read += readChunk
				readCount++
				if len(write) <= len(read) {
					break
				}
				if len(write) <= dataMaxSize {
					break // atomicity of small writes
				}
			}
			// Compare
			if write != read {
				t.Errorf("expected to read %X, got %X", write, read)
			}
			// Iterate
			writes = writes[1:]
			reads = reads[readCount:]
			if len(writes) == 0 {
				break
			}
		}
	}

	compareWritesReads(fooWrites, barReads)
	compareWritesReads(barWrites, fooReads)
}
   226  
   227  func TestDeriveSecretsAndChallengeGolden(t *testing.T) {
   228  	goldenFilepath := filepath.Join("testdata", t.Name()+".golden")
   229  	if *update {
   230  		t.Logf("Updating golden test vector file %s", goldenFilepath)
   231  		data := createGoldenTestVectors(t)
   232  		err := tmos.WriteFile(goldenFilepath, []byte(data), 0644)
   233  		require.NoError(t, err)
   234  	}
   235  	f, err := os.Open(goldenFilepath)
   236  	if err != nil {
   237  		log.Fatal(err)
   238  	}
   239  	defer f.Close()
   240  	scanner := bufio.NewScanner(f)
   241  	for scanner.Scan() {
   242  		line := scanner.Text()
   243  		params := strings.Split(line, ",")
   244  		randSecretVector, err := hex.DecodeString(params[0])
   245  		require.Nil(t, err)
   246  		randSecret := new([32]byte)
   247  		copy((*randSecret)[:], randSecretVector)
   248  		locIsLeast, err := strconv.ParseBool(params[1])
   249  		require.Nil(t, err)
   250  		expectedRecvSecret, err := hex.DecodeString(params[2])
   251  		require.Nil(t, err)
   252  		expectedSendSecret, err := hex.DecodeString(params[3])
   253  		require.Nil(t, err)
   254  
   255  		recvSecret, sendSecret := deriveSecrets(randSecret, locIsLeast)
   256  		require.Equal(t, expectedRecvSecret, (*recvSecret)[:], "Recv Secrets aren't equal")
   257  		require.Equal(t, expectedSendSecret, (*sendSecret)[:], "Send Secrets aren't equal")
   258  	}
   259  }
   260  
   261  func TestNilPubkey(t *testing.T) {
   262  	var fooConn, barConn = makeKVStoreConnPair()
   263  	defer fooConn.Close()
   264  	defer barConn.Close()
   265  	var fooPrvKey = ed25519.GenPrivKey()
   266  	var barPrvKey = privKeyWithNilPubKey{ed25519.GenPrivKey()}
   267  
   268  	go MakeSecretConnection(fooConn, fooPrvKey) //nolint:errcheck // ignore for tests
   269  
   270  	_, err := MakeSecretConnection(barConn, barPrvKey)
   271  	require.Error(t, err)
   272  	assert.Equal(t, "toproto: key type <nil> is not supported", err.Error())
   273  }
   274  
   275  func TestNonEd25519Pubkey(t *testing.T) {
   276  	var fooConn, barConn = makeKVStoreConnPair()
   277  	defer fooConn.Close()
   278  	defer barConn.Close()
   279  	var fooPrvKey = ed25519.GenPrivKey()
   280  	var barPrvKey = sr25519.GenPrivKey()
   281  
   282  	go MakeSecretConnection(fooConn, fooPrvKey) //nolint:errcheck // ignore for tests
   283  
   284  	_, err := MakeSecretConnection(barConn, barPrvKey)
   285  	require.Error(t, err)
   286  	assert.Contains(t, err.Error(), "is not supported")
   287  }
   288  
   289  func writeLots(t *testing.T, wg *sync.WaitGroup, conn io.Writer, txt string, n int) {
   290  	defer wg.Done()
   291  	for i := 0; i < n; i++ {
   292  		_, err := conn.Write([]byte(txt))
   293  		if err != nil {
   294  			t.Errorf("failed to write to fooSecConn: %v", err)
   295  			return
   296  		}
   297  	}
   298  }
   299  
   300  func readLots(t *testing.T, wg *sync.WaitGroup, conn io.Reader, n int) {
   301  	readBuffer := make([]byte, dataMaxSize)
   302  	for i := 0; i < n; i++ {
   303  		_, err := conn.Read(readBuffer)
   304  		assert.NoError(t, err)
   305  	}
   306  	wg.Done()
   307  }
   308  
   309  // Creates the data for a test vector file.
   310  // The file format is:
   311  // Hex(diffie_hellman_secret), loc_is_least, Hex(recvSecret), Hex(sendSecret), Hex(challenge)
   312  func createGoldenTestVectors(t *testing.T) string {
   313  	data := ""
   314  	for i := 0; i < 32; i++ {
   315  		randSecretVector := tmrand.Bytes(32)
   316  		randSecret := new([32]byte)
   317  		copy((*randSecret)[:], randSecretVector)
   318  		data += hex.EncodeToString((*randSecret)[:]) + ","
   319  		locIsLeast := tmrand.Bool()
   320  		data += strconv.FormatBool(locIsLeast) + ","
   321  		recvSecret, sendSecret := deriveSecrets(randSecret, locIsLeast)
   322  		data += hex.EncodeToString((*recvSecret)[:]) + ","
   323  		data += hex.EncodeToString((*sendSecret)[:]) + ","
   324  	}
   325  	return data
   326  }
   327  
   328  // Each returned ReadWriteCloser is akin to a net.Connection
   329  func makeKVStoreConnPair() (fooConn, barConn kvstoreConn) {
   330  	barReader, fooWriter := io.Pipe()
   331  	fooReader, barWriter := io.Pipe()
   332  	return kvstoreConn{fooReader, fooWriter}, kvstoreConn{barReader, barWriter}
   333  }
   334  
   335  func makeSecretConnPair(tb testing.TB) (fooSecConn, barSecConn *SecretConnection) {
   336  	var (
   337  		fooConn, barConn = makeKVStoreConnPair()
   338  		fooPrvKey        = ed25519.GenPrivKey()
   339  		fooPubKey        = fooPrvKey.PubKey()
   340  		barPrvKey        = ed25519.GenPrivKey()
   341  		barPubKey        = barPrvKey.PubKey()
   342  	)
   343  
   344  	// Make connections from both sides in parallel.
   345  	var trs, ok = async.Parallel(
   346  		func(_ int) (val interface{}, abort bool, err error) {
   347  			fooSecConn, err = MakeSecretConnection(fooConn, fooPrvKey)
   348  			if err != nil {
   349  				tb.Errorf("failed to establish SecretConnection for foo: %v", err)
   350  				return nil, true, err
   351  			}
   352  			remotePubBytes := fooSecConn.RemotePubKey()
   353  			if !remotePubBytes.Equals(barPubKey) {
   354  				err = fmt.Errorf("unexpected fooSecConn.RemotePubKey.  Expected %v, got %v",
   355  					barPubKey, fooSecConn.RemotePubKey())
   356  				tb.Error(err)
   357  				return nil, true, err
   358  			}
   359  			return nil, false, nil
   360  		},
   361  		func(_ int) (val interface{}, abort bool, err error) {
   362  			barSecConn, err = MakeSecretConnection(barConn, barPrvKey)
   363  			if barSecConn == nil {
   364  				tb.Errorf("failed to establish SecretConnection for bar: %v", err)
   365  				return nil, true, err
   366  			}
   367  			remotePubBytes := barSecConn.RemotePubKey()
   368  			if !remotePubBytes.Equals(fooPubKey) {
   369  				err = fmt.Errorf("unexpected barSecConn.RemotePubKey.  Expected %v, got %v",
   370  					fooPubKey, barSecConn.RemotePubKey())
   371  				tb.Error(err)
   372  				return nil, true, err
   373  			}
   374  			return nil, false, nil
   375  		},
   376  	)
   377  
   378  	require.Nil(tb, trs.FirstError())
   379  	require.True(tb, ok, "Unexpected task abortion")
   380  
   381  	return fooSecConn, barSecConn
   382  }
   383  
   384  // Benchmarks
   385  
   386  func BenchmarkWriteSecretConnection(b *testing.B) {
   387  	b.StopTimer()
   388  	b.ReportAllocs()
   389  	fooSecConn, barSecConn := makeSecretConnPair(b)
   390  	randomMsgSizes := []int{
   391  		dataMaxSize / 10,
   392  		dataMaxSize / 3,
   393  		dataMaxSize / 2,
   394  		dataMaxSize,
   395  		dataMaxSize * 3 / 2,
   396  		dataMaxSize * 2,
   397  		dataMaxSize * 7 / 2,
   398  	}
   399  	fooWriteBytes := make([][]byte, 0, len(randomMsgSizes))
   400  	for _, size := range randomMsgSizes {
   401  		fooWriteBytes = append(fooWriteBytes, tmrand.Bytes(size))
   402  	}
   403  	// Consume reads from bar's reader
   404  	go func() {
   405  		readBuffer := make([]byte, dataMaxSize)
   406  		for {
   407  			_, err := barSecConn.Read(readBuffer)
   408  			if err == io.EOF {
   409  				return
   410  			} else if err != nil {
   411  				b.Errorf("failed to read from barSecConn: %v", err)
   412  				return
   413  			}
   414  		}
   415  	}()
   416  
   417  	b.StartTimer()
   418  	for i := 0; i < b.N; i++ {
   419  		idx := tmrand.Intn(len(fooWriteBytes))
   420  		_, err := fooSecConn.Write(fooWriteBytes[idx])
   421  		if err != nil {
   422  			b.Errorf("failed to write to fooSecConn: %v", err)
   423  			return
   424  		}
   425  	}
   426  	b.StopTimer()
   427  
   428  	if err := fooSecConn.Close(); err != nil {
   429  		b.Error(err)
   430  	}
   431  	// barSecConn.Close() race condition
   432  }
   433  
   434  func BenchmarkReadSecretConnection(b *testing.B) {
   435  	b.StopTimer()
   436  	b.ReportAllocs()
   437  	fooSecConn, barSecConn := makeSecretConnPair(b)
   438  	randomMsgSizes := []int{
   439  		dataMaxSize / 10,
   440  		dataMaxSize / 3,
   441  		dataMaxSize / 2,
   442  		dataMaxSize,
   443  		dataMaxSize * 3 / 2,
   444  		dataMaxSize * 2,
   445  		dataMaxSize * 7 / 2,
   446  	}
   447  	fooWriteBytes := make([][]byte, 0, len(randomMsgSizes))
   448  	for _, size := range randomMsgSizes {
   449  		fooWriteBytes = append(fooWriteBytes, tmrand.Bytes(size))
   450  	}
   451  	go func() {
   452  		for i := 0; i < b.N; i++ {
   453  			idx := tmrand.Intn(len(fooWriteBytes))
   454  			_, err := fooSecConn.Write(fooWriteBytes[idx])
   455  			if err != nil {
   456  				b.Errorf("failed to write to fooSecConn: %v, %v,%v", err, i, b.N)
   457  				return
   458  			}
   459  		}
   460  	}()
   461  
   462  	b.StartTimer()
   463  	for i := 0; i < b.N; i++ {
   464  		readBuffer := make([]byte, dataMaxSize)
   465  		_, err := barSecConn.Read(readBuffer)
   466  
   467  		if err == io.EOF {
   468  			return
   469  		} else if err != nil {
   470  			b.Fatalf("Failed to read from barSecConn: %v", err)
   471  		}
   472  	}
   473  	b.StopTimer()
   474  }