github.com/devwanda/aphelion-staking@v0.33.9/p2p/conn/secret_connection_test.go (about)

     1  package conn
     2  
     3  import (
     4  	"bufio"
     5  	"encoding/hex"
     6  	"flag"
     7  	"fmt"
     8  	"io"
     9  	"log"
    10  	"os"
    11  	"path/filepath"
    12  	"strconv"
    13  	"strings"
    14  	"sync"
    15  	"testing"
    16  
    17  	"github.com/stretchr/testify/assert"
    18  	"github.com/stretchr/testify/require"
    19  
    20  	"github.com/devwanda/aphelion-staking/crypto"
    21  	"github.com/devwanda/aphelion-staking/crypto/ed25519"
    22  	"github.com/devwanda/aphelion-staking/crypto/secp256k1"
    23  	"github.com/devwanda/aphelion-staking/libs/async"
    24  	tmos "github.com/devwanda/aphelion-staking/libs/os"
    25  	tmrand "github.com/devwanda/aphelion-staking/libs/rand"
    26  )
    27  
    28  type kvstoreConn struct {
    29  	*io.PipeReader
    30  	*io.PipeWriter
    31  }
    32  
    33  func (drw kvstoreConn) Close() (err error) {
    34  	err2 := drw.PipeWriter.CloseWithError(io.EOF)
    35  	err1 := drw.PipeReader.Close()
    36  	if err2 != nil {
    37  		return err
    38  	}
    39  	return err1
    40  }
    41  
    42  // Each returned ReadWriteCloser is akin to a net.Connection
    43  func makeKVStoreConnPair() (fooConn, barConn kvstoreConn) {
    44  	barReader, fooWriter := io.Pipe()
    45  	fooReader, barWriter := io.Pipe()
    46  	return kvstoreConn{fooReader, fooWriter}, kvstoreConn{barReader, barWriter}
    47  }
    48  
    49  func makeSecretConnPair(tb testing.TB) (fooSecConn, barSecConn *SecretConnection) {
    50  
    51  	var fooConn, barConn = makeKVStoreConnPair()
    52  	var fooPrvKey = ed25519.GenPrivKey()
    53  	var fooPubKey = fooPrvKey.PubKey()
    54  	var barPrvKey = ed25519.GenPrivKey()
    55  	var barPubKey = barPrvKey.PubKey()
    56  
    57  	// Make connections from both sides in parallel.
    58  	var trs, ok = async.Parallel(
    59  		func(_ int) (val interface{}, abort bool, err error) {
    60  			fooSecConn, err = MakeSecretConnection(fooConn, fooPrvKey)
    61  			if err != nil {
    62  				tb.Errorf("failed to establish SecretConnection for foo: %v", err)
    63  				return nil, true, err
    64  			}
    65  			remotePubBytes := fooSecConn.RemotePubKey()
    66  			if !remotePubBytes.Equals(barPubKey) {
    67  				err = fmt.Errorf("unexpected fooSecConn.RemotePubKey.  Expected %v, got %v",
    68  					barPubKey, fooSecConn.RemotePubKey())
    69  				tb.Error(err)
    70  				return nil, false, err
    71  			}
    72  			return nil, false, nil
    73  		},
    74  		func(_ int) (val interface{}, abort bool, err error) {
    75  			barSecConn, err = MakeSecretConnection(barConn, barPrvKey)
    76  			if barSecConn == nil {
    77  				tb.Errorf("failed to establish SecretConnection for bar: %v", err)
    78  				return nil, true, err
    79  			}
    80  			remotePubBytes := barSecConn.RemotePubKey()
    81  			if !remotePubBytes.Equals(fooPubKey) {
    82  				err = fmt.Errorf("unexpected barSecConn.RemotePubKey.  Expected %v, got %v",
    83  					fooPubKey, barSecConn.RemotePubKey())
    84  				tb.Error(err)
    85  				return nil, false, nil
    86  			}
    87  			return nil, false, nil
    88  		},
    89  	)
    90  
    91  	require.Nil(tb, trs.FirstError())
    92  	require.True(tb, ok, "Unexpected task abortion")
    93  
    94  	return fooSecConn, barSecConn
    95  }
    96  
    97  func TestSecretConnectionHandshake(t *testing.T) {
    98  	fooSecConn, barSecConn := makeSecretConnPair(t)
    99  	if err := fooSecConn.Close(); err != nil {
   100  		t.Error(err)
   101  	}
   102  	if err := barSecConn.Close(); err != nil {
   103  		t.Error(err)
   104  	}
   105  }
   106  
   107  func TestConcurrentWrite(t *testing.T) {
   108  	fooSecConn, barSecConn := makeSecretConnPair(t)
   109  	fooWriteText := tmrand.Str(dataMaxSize)
   110  
   111  	// write from two routines.
   112  	// should be safe from race according to net.Conn:
   113  	// https://golang.org/pkg/net/#Conn
   114  	n := 100
   115  	wg := new(sync.WaitGroup)
   116  	wg.Add(3)
   117  	go writeLots(t, wg, fooSecConn, fooWriteText, n)
   118  	go writeLots(t, wg, fooSecConn, fooWriteText, n)
   119  
   120  	// Consume reads from bar's reader
   121  	readLots(t, wg, barSecConn, n*2)
   122  	wg.Wait()
   123  
   124  	if err := fooSecConn.Close(); err != nil {
   125  		t.Error(err)
   126  	}
   127  }
   128  
   129  func TestConcurrentRead(t *testing.T) {
   130  	fooSecConn, barSecConn := makeSecretConnPair(t)
   131  	fooWriteText := tmrand.Str(dataMaxSize)
   132  	n := 100
   133  
   134  	// read from two routines.
   135  	// should be safe from race according to net.Conn:
   136  	// https://golang.org/pkg/net/#Conn
   137  	wg := new(sync.WaitGroup)
   138  	wg.Add(3)
   139  	go readLots(t, wg, fooSecConn, n/2)
   140  	go readLots(t, wg, fooSecConn, n/2)
   141  
   142  	// write to bar
   143  	writeLots(t, wg, barSecConn, fooWriteText, n)
   144  	wg.Wait()
   145  
   146  	if err := fooSecConn.Close(); err != nil {
   147  		t.Error(err)
   148  	}
   149  }
   150  
   151  func writeLots(t *testing.T, wg *sync.WaitGroup, conn io.Writer, txt string, n int) {
   152  	defer wg.Done()
   153  	for i := 0; i < n; i++ {
   154  		_, err := conn.Write([]byte(txt))
   155  		if err != nil {
   156  			t.Errorf("failed to write to fooSecConn: %v", err)
   157  			return
   158  		}
   159  	}
   160  }
   161  
   162  func readLots(t *testing.T, wg *sync.WaitGroup, conn io.Reader, n int) {
   163  	readBuffer := make([]byte, dataMaxSize)
   164  	for i := 0; i < n; i++ {
   165  		_, err := conn.Read(readBuffer)
   166  		assert.NoError(t, err)
   167  	}
   168  	wg.Done()
   169  }
   170  
   171  func TestSecretConnectionReadWrite(t *testing.T) {
   172  	fooConn, barConn := makeKVStoreConnPair()
   173  	fooWrites, barWrites := []string{}, []string{}
   174  	fooReads, barReads := []string{}, []string{}
   175  
   176  	// Pre-generate the things to write (for foo & bar)
   177  	for i := 0; i < 100; i++ {
   178  		fooWrites = append(fooWrites, tmrand.Str((tmrand.Int()%(dataMaxSize*5))+1))
   179  		barWrites = append(barWrites, tmrand.Str((tmrand.Int()%(dataMaxSize*5))+1))
   180  	}
   181  
   182  	// A helper that will run with (fooConn, fooWrites, fooReads) and vice versa
   183  	genNodeRunner := func(id string, nodeConn kvstoreConn, nodeWrites []string, nodeReads *[]string) async.Task {
   184  		return func(_ int) (interface{}, bool, error) {
   185  			// Initiate cryptographic private key and secret connection trhough nodeConn.
   186  			nodePrvKey := ed25519.GenPrivKey()
   187  			nodeSecretConn, err := MakeSecretConnection(nodeConn, nodePrvKey)
   188  			if err != nil {
   189  				t.Errorf("failed to establish SecretConnection for node: %v", err)
   190  				return nil, true, err
   191  			}
   192  			// In parallel, handle some reads and writes.
   193  			var trs, ok = async.Parallel(
   194  				func(_ int) (interface{}, bool, error) {
   195  					// Node writes:
   196  					for _, nodeWrite := range nodeWrites {
   197  						n, err := nodeSecretConn.Write([]byte(nodeWrite))
   198  						if err != nil {
   199  							t.Errorf("failed to write to nodeSecretConn: %v", err)
   200  							return nil, true, err
   201  						}
   202  						if n != len(nodeWrite) {
   203  							err = fmt.Errorf("failed to write all bytes. Expected %v, wrote %v", len(nodeWrite), n)
   204  							t.Error(err)
   205  							return nil, true, err
   206  						}
   207  					}
   208  					if err := nodeConn.PipeWriter.Close(); err != nil {
   209  						t.Error(err)
   210  						return nil, true, err
   211  					}
   212  					return nil, false, nil
   213  				},
   214  				func(_ int) (interface{}, bool, error) {
   215  					// Node reads:
   216  					readBuffer := make([]byte, dataMaxSize)
   217  					for {
   218  						n, err := nodeSecretConn.Read(readBuffer)
   219  						if err == io.EOF {
   220  							if err := nodeConn.PipeReader.Close(); err != nil {
   221  								t.Error(err)
   222  								return nil, true, err
   223  							}
   224  							return nil, false, nil
   225  						} else if err != nil {
   226  							t.Errorf("failed to read from nodeSecretConn: %v", err)
   227  							return nil, true, err
   228  						}
   229  						*nodeReads = append(*nodeReads, string(readBuffer[:n]))
   230  					}
   231  				},
   232  			)
   233  			assert.True(t, ok, "Unexpected task abortion")
   234  
   235  			// If error:
   236  			if trs.FirstError() != nil {
   237  				return nil, true, trs.FirstError()
   238  			}
   239  
   240  			// Otherwise:
   241  			return nil, false, nil
   242  		}
   243  	}
   244  
   245  	// Run foo & bar in parallel
   246  	var trs, ok = async.Parallel(
   247  		genNodeRunner("foo", fooConn, fooWrites, &fooReads),
   248  		genNodeRunner("bar", barConn, barWrites, &barReads),
   249  	)
   250  	require.Nil(t, trs.FirstError())
   251  	require.True(t, ok, "unexpected task abortion")
   252  
   253  	// A helper to ensure that the writes and reads match.
   254  	// Additionally, small writes (<= dataMaxSize) must be atomically read.
   255  	compareWritesReads := func(writes []string, reads []string) {
   256  		for {
   257  			// Pop next write & corresponding reads
   258  			var read, write string = "", writes[0]
   259  			var readCount = 0
   260  			for _, readChunk := range reads {
   261  				read += readChunk
   262  				readCount++
   263  				if len(write) <= len(read) {
   264  					break
   265  				}
   266  				if len(write) <= dataMaxSize {
   267  					break // atomicity of small writes
   268  				}
   269  			}
   270  			// Compare
   271  			if write != read {
   272  				t.Errorf("expected to read %X, got %X", write, read)
   273  			}
   274  			// Iterate
   275  			writes = writes[1:]
   276  			reads = reads[readCount:]
   277  			if len(writes) == 0 {
   278  				break
   279  			}
   280  		}
   281  	}
   282  
   283  	compareWritesReads(fooWrites, barReads)
   284  	compareWritesReads(barWrites, fooReads)
   285  
   286  }
   287  
   288  // Run go test -update from within this module
   289  // to update the golden test vector file
   290  var update = flag.Bool("update", false, "update .golden files")
   291  
   292  func TestDeriveSecretsAndChallengeGolden(t *testing.T) {
   293  	goldenFilepath := filepath.Join("testdata", t.Name()+".golden")
   294  	if *update {
   295  		t.Logf("Updating golden test vector file %s", goldenFilepath)
   296  		data := createGoldenTestVectors(t)
   297  		tmos.WriteFile(goldenFilepath, []byte(data), 0644)
   298  	}
   299  	f, err := os.Open(goldenFilepath)
   300  	if err != nil {
   301  		log.Fatal(err)
   302  	}
   303  	defer f.Close()
   304  	scanner := bufio.NewScanner(f)
   305  	for scanner.Scan() {
   306  		line := scanner.Text()
   307  		params := strings.Split(line, ",")
   308  		randSecretVector, err := hex.DecodeString(params[0])
   309  		require.Nil(t, err)
   310  		randSecret := new([32]byte)
   311  		copy((*randSecret)[:], randSecretVector)
   312  		locIsLeast, err := strconv.ParseBool(params[1])
   313  		require.Nil(t, err)
   314  		expectedRecvSecret, err := hex.DecodeString(params[2])
   315  		require.Nil(t, err)
   316  		expectedSendSecret, err := hex.DecodeString(params[3])
   317  		require.Nil(t, err)
   318  
   319  		recvSecret, sendSecret := deriveSecrets(randSecret, locIsLeast)
   320  		require.Equal(t, expectedRecvSecret, (*recvSecret)[:], "Recv Secrets aren't equal")
   321  		require.Equal(t, expectedSendSecret, (*sendSecret)[:], "Send Secrets aren't equal")
   322  	}
   323  }
   324  
   325  type privKeyWithNilPubKey struct {
   326  	orig crypto.PrivKey
   327  }
   328  
   329  func (pk privKeyWithNilPubKey) Bytes() []byte                   { return pk.orig.Bytes() }
   330  func (pk privKeyWithNilPubKey) Sign(msg []byte) ([]byte, error) { return pk.orig.Sign(msg) }
   331  func (pk privKeyWithNilPubKey) PubKey() crypto.PubKey           { return nil }
   332  func (pk privKeyWithNilPubKey) Equals(pk2 crypto.PrivKey) bool  { return pk.orig.Equals(pk2) }
   333  
   334  func TestNilPubkey(t *testing.T) {
   335  	var fooConn, barConn = makeKVStoreConnPair()
   336  	var fooPrvKey = ed25519.GenPrivKey()
   337  	var barPrvKey = privKeyWithNilPubKey{ed25519.GenPrivKey()}
   338  
   339  	go func() {
   340  		_, err := MakeSecretConnection(barConn, barPrvKey)
   341  		assert.NoError(t, err)
   342  	}()
   343  
   344  	assert.NotPanics(t, func() {
   345  		_, err := MakeSecretConnection(fooConn, fooPrvKey)
   346  		if assert.Error(t, err) {
   347  			assert.Equal(t, "expected ed25519 pubkey, got <nil>", err.Error())
   348  		}
   349  	})
   350  }
   351  
   352  func TestNonEd25519Pubkey(t *testing.T) {
   353  	var fooConn, barConn = makeKVStoreConnPair()
   354  	var fooPrvKey = ed25519.GenPrivKey()
   355  	var barPrvKey = secp256k1.GenPrivKey()
   356  
   357  	go func() {
   358  		_, err := MakeSecretConnection(barConn, barPrvKey)
   359  		assert.NoError(t, err)
   360  	}()
   361  
   362  	assert.NotPanics(t, func() {
   363  		_, err := MakeSecretConnection(fooConn, fooPrvKey)
   364  		if assert.Error(t, err) {
   365  			assert.Equal(t, "expected ed25519 pubkey, got secp256k1.PubKeySecp256k1", err.Error())
   366  		}
   367  	})
   368  }
   369  
   370  // Creates the data for a test vector file.
   371  // The file format is:
   372  // Hex(diffie_hellman_secret), loc_is_least, Hex(recvSecret), Hex(sendSecret), Hex(challenge)
   373  func createGoldenTestVectors(t *testing.T) string {
   374  	data := ""
   375  	for i := 0; i < 32; i++ {
   376  		randSecretVector := tmrand.Bytes(32)
   377  		randSecret := new([32]byte)
   378  		copy((*randSecret)[:], randSecretVector)
   379  		data += hex.EncodeToString((*randSecret)[:]) + ","
   380  		locIsLeast := tmrand.Bool()
   381  		data += strconv.FormatBool(locIsLeast) + ","
   382  		recvSecret, sendSecret := deriveSecrets(randSecret, locIsLeast)
   383  		data += hex.EncodeToString((*recvSecret)[:]) + ","
   384  		data += hex.EncodeToString((*sendSecret)[:]) + ","
   385  	}
   386  	return data
   387  }
   388  
   389  func BenchmarkWriteSecretConnection(b *testing.B) {
   390  	b.StopTimer()
   391  	b.ReportAllocs()
   392  	fooSecConn, barSecConn := makeSecretConnPair(b)
   393  	randomMsgSizes := []int{
   394  		dataMaxSize / 10,
   395  		dataMaxSize / 3,
   396  		dataMaxSize / 2,
   397  		dataMaxSize,
   398  		dataMaxSize * 3 / 2,
   399  		dataMaxSize * 2,
   400  		dataMaxSize * 7 / 2,
   401  	}
   402  	fooWriteBytes := make([][]byte, 0, len(randomMsgSizes))
   403  	for _, size := range randomMsgSizes {
   404  		fooWriteBytes = append(fooWriteBytes, tmrand.Bytes(size))
   405  	}
   406  	// Consume reads from bar's reader
   407  	go func() {
   408  		readBuffer := make([]byte, dataMaxSize)
   409  		for {
   410  			_, err := barSecConn.Read(readBuffer)
   411  			if err == io.EOF {
   412  				return
   413  			} else if err != nil {
   414  				b.Errorf("failed to read from barSecConn: %v", err)
   415  				return
   416  			}
   417  		}
   418  	}()
   419  
   420  	b.StartTimer()
   421  	for i := 0; i < b.N; i++ {
   422  		idx := tmrand.Intn(len(fooWriteBytes))
   423  		_, err := fooSecConn.Write(fooWriteBytes[idx])
   424  		if err != nil {
   425  			b.Errorf("failed to write to fooSecConn: %v", err)
   426  			return
   427  		}
   428  	}
   429  	b.StopTimer()
   430  
   431  	if err := fooSecConn.Close(); err != nil {
   432  		b.Error(err)
   433  	}
   434  	//barSecConn.Close() race condition
   435  }
   436  
   437  func BenchmarkReadSecretConnection(b *testing.B) {
   438  	b.StopTimer()
   439  	b.ReportAllocs()
   440  	fooSecConn, barSecConn := makeSecretConnPair(b)
   441  	randomMsgSizes := []int{
   442  		dataMaxSize / 10,
   443  		dataMaxSize / 3,
   444  		dataMaxSize / 2,
   445  		dataMaxSize,
   446  		dataMaxSize * 3 / 2,
   447  		dataMaxSize * 2,
   448  		dataMaxSize * 7 / 2,
   449  	}
   450  	fooWriteBytes := make([][]byte, 0, len(randomMsgSizes))
   451  	for _, size := range randomMsgSizes {
   452  		fooWriteBytes = append(fooWriteBytes, tmrand.Bytes(size))
   453  	}
   454  	go func() {
   455  		for i := 0; i < b.N; i++ {
   456  			idx := tmrand.Intn(len(fooWriteBytes))
   457  			_, err := fooSecConn.Write(fooWriteBytes[idx])
   458  			if err != nil {
   459  				b.Errorf("failed to write to fooSecConn: %v, %v,%v", err, i, b.N)
   460  				return
   461  			}
   462  		}
   463  	}()
   464  
   465  	b.StartTimer()
   466  	for i := 0; i < b.N; i++ {
   467  		readBuffer := make([]byte, dataMaxSize)
   468  		_, err := barSecConn.Read(readBuffer)
   469  
   470  		if err == io.EOF {
   471  			return
   472  		} else if err != nil {
   473  			b.Fatalf("Failed to read from barSecConn: %v", err)
   474  		}
   475  	}
   476  	b.StopTimer()
   477  }