github.com/evdatsion/aphelion-dpos-bft@v0.32.1/p2p/conn/secret_connection_test.go (about)

     1  package conn
     2  
     3  import (
     4  	"bufio"
     5  	"encoding/hex"
     6  	"flag"
     7  	"fmt"
     8  	"io"
     9  	"log"
    10  	"net"
    11  	"os"
    12  	"path/filepath"
    13  	"strconv"
    14  	"strings"
    15  	"sync"
    16  	"testing"
    17  
    18  	"github.com/stretchr/testify/assert"
    19  	"github.com/stretchr/testify/require"
    20  	"github.com/evdatsion/aphelion-dpos-bft/crypto/ed25519"
    21  	cmn "github.com/evdatsion/aphelion-dpos-bft/libs/common"
    22  )
    23  
    24  type kvstoreConn struct {
    25  	*io.PipeReader
    26  	*io.PipeWriter
    27  }
    28  
    29  func (drw kvstoreConn) Close() (err error) {
    30  	err2 := drw.PipeWriter.CloseWithError(io.EOF)
    31  	err1 := drw.PipeReader.Close()
    32  	if err2 != nil {
    33  		return err
    34  	}
    35  	return err1
    36  }
    37  
    38  // Each returned ReadWriteCloser is akin to a net.Connection
    39  func makeKVStoreConnPair() (fooConn, barConn kvstoreConn) {
    40  	barReader, fooWriter := io.Pipe()
    41  	fooReader, barWriter := io.Pipe()
    42  	return kvstoreConn{fooReader, fooWriter}, kvstoreConn{barReader, barWriter}
    43  }
    44  
    45  func makeSecretConnPair(tb testing.TB) (fooSecConn, barSecConn *SecretConnection) {
    46  
    47  	var fooConn, barConn = makeKVStoreConnPair()
    48  	var fooPrvKey = ed25519.GenPrivKey()
    49  	var fooPubKey = fooPrvKey.PubKey()
    50  	var barPrvKey = ed25519.GenPrivKey()
    51  	var barPubKey = barPrvKey.PubKey()
    52  
    53  	// Make connections from both sides in parallel.
    54  	var trs, ok = cmn.Parallel(
    55  		func(_ int) (val interface{}, err error, abort bool) {
    56  			fooSecConn, err = MakeSecretConnection(fooConn, fooPrvKey)
    57  			if err != nil {
    58  				tb.Errorf("Failed to establish SecretConnection for foo: %v", err)
    59  				return nil, err, true
    60  			}
    61  			remotePubBytes := fooSecConn.RemotePubKey()
    62  			if !remotePubBytes.Equals(barPubKey) {
    63  				err = fmt.Errorf("Unexpected fooSecConn.RemotePubKey.  Expected %v, got %v",
    64  					barPubKey, fooSecConn.RemotePubKey())
    65  				tb.Error(err)
    66  				return nil, err, false
    67  			}
    68  			return nil, nil, false
    69  		},
    70  		func(_ int) (val interface{}, err error, abort bool) {
    71  			barSecConn, err = MakeSecretConnection(barConn, barPrvKey)
    72  			if barSecConn == nil {
    73  				tb.Errorf("Failed to establish SecretConnection for bar: %v", err)
    74  				return nil, err, true
    75  			}
    76  			remotePubBytes := barSecConn.RemotePubKey()
    77  			if !remotePubBytes.Equals(fooPubKey) {
    78  				err = fmt.Errorf("Unexpected barSecConn.RemotePubKey.  Expected %v, got %v",
    79  					fooPubKey, barSecConn.RemotePubKey())
    80  				tb.Error(err)
    81  				return nil, nil, false
    82  			}
    83  			return nil, nil, false
    84  		},
    85  	)
    86  
    87  	require.Nil(tb, trs.FirstError())
    88  	require.True(tb, ok, "Unexpected task abortion")
    89  
    90  	return
    91  }
    92  
    93  func TestSecretConnectionHandshake(t *testing.T) {
    94  	fooSecConn, barSecConn := makeSecretConnPair(t)
    95  	if err := fooSecConn.Close(); err != nil {
    96  		t.Error(err)
    97  	}
    98  	if err := barSecConn.Close(); err != nil {
    99  		t.Error(err)
   100  	}
   101  }
   102  
   103  // Test that shareEphPubKey rejects lower order public keys based on an
   104  // (incomplete) blacklist.
   105  func TestShareLowOrderPubkey(t *testing.T) {
   106  	var fooConn, barConn = makeKVStoreConnPair()
   107  	defer fooConn.Close()
   108  	defer barConn.Close()
   109  	locEphPub, _ := genEphKeys()
   110  
   111  	// all blacklisted low order points:
   112  	for _, remLowOrderPubKey := range blacklist {
   113  		_, _ = cmn.Parallel(
   114  			func(_ int) (val interface{}, err error, abort bool) {
   115  				_, err = shareEphPubKey(fooConn, locEphPub)
   116  
   117  				require.Error(t, err)
   118  				require.Equal(t, err, ErrSmallOrderRemotePubKey)
   119  
   120  				return nil, nil, false
   121  			},
   122  			func(_ int) (val interface{}, err error, abort bool) {
   123  				readRemKey, err := shareEphPubKey(barConn, &remLowOrderPubKey)
   124  
   125  				require.NoError(t, err)
   126  				require.Equal(t, locEphPub, readRemKey)
   127  
   128  				return nil, nil, false
   129  			})
   130  	}
   131  }
   132  
   133  // Test that additionally that the Diffie-Hellman shared secret is non-zero.
   134  // The shared secret would be zero for lower order pub-keys (but tested against the blacklist only).
   135  func TestComputeDHFailsOnLowOrder(t *testing.T) {
   136  	_, locPrivKey := genEphKeys()
   137  	for _, remLowOrderPubKey := range blacklist {
   138  		shared, err := computeDHSecret(&remLowOrderPubKey, locPrivKey)
   139  		assert.Error(t, err)
   140  
   141  		assert.Equal(t, err, ErrSharedSecretIsZero)
   142  		assert.Empty(t, shared)
   143  	}
   144  }
   145  
   146  func TestConcurrentWrite(t *testing.T) {
   147  	fooSecConn, barSecConn := makeSecretConnPair(t)
   148  	fooWriteText := cmn.RandStr(dataMaxSize)
   149  
   150  	// write from two routines.
   151  	// should be safe from race according to net.Conn:
   152  	// https://golang.org/pkg/net/#Conn
   153  	n := 100
   154  	wg := new(sync.WaitGroup)
   155  	wg.Add(3)
   156  	go writeLots(t, wg, fooSecConn, fooWriteText, n)
   157  	go writeLots(t, wg, fooSecConn, fooWriteText, n)
   158  
   159  	// Consume reads from bar's reader
   160  	readLots(t, wg, barSecConn, n*2)
   161  	wg.Wait()
   162  
   163  	if err := fooSecConn.Close(); err != nil {
   164  		t.Error(err)
   165  	}
   166  }
   167  
   168  func TestConcurrentRead(t *testing.T) {
   169  	fooSecConn, barSecConn := makeSecretConnPair(t)
   170  	fooWriteText := cmn.RandStr(dataMaxSize)
   171  	n := 100
   172  
   173  	// read from two routines.
   174  	// should be safe from race according to net.Conn:
   175  	// https://golang.org/pkg/net/#Conn
   176  	wg := new(sync.WaitGroup)
   177  	wg.Add(3)
   178  	go readLots(t, wg, fooSecConn, n/2)
   179  	go readLots(t, wg, fooSecConn, n/2)
   180  
   181  	// write to bar
   182  	writeLots(t, wg, barSecConn, fooWriteText, n)
   183  	wg.Wait()
   184  
   185  	if err := fooSecConn.Close(); err != nil {
   186  		t.Error(err)
   187  	}
   188  }
   189  
   190  func writeLots(t *testing.T, wg *sync.WaitGroup, conn net.Conn, txt string, n int) {
   191  	defer wg.Done()
   192  	for i := 0; i < n; i++ {
   193  		_, err := conn.Write([]byte(txt))
   194  		if err != nil {
   195  			t.Fatalf("Failed to write to fooSecConn: %v", err)
   196  		}
   197  	}
   198  }
   199  
   200  func readLots(t *testing.T, wg *sync.WaitGroup, conn net.Conn, n int) {
   201  	readBuffer := make([]byte, dataMaxSize)
   202  	for i := 0; i < n; i++ {
   203  		_, err := conn.Read(readBuffer)
   204  		assert.NoError(t, err)
   205  	}
   206  	wg.Done()
   207  }
   208  
   209  func TestSecretConnectionReadWrite(t *testing.T) {
   210  	fooConn, barConn := makeKVStoreConnPair()
   211  	fooWrites, barWrites := []string{}, []string{}
   212  	fooReads, barReads := []string{}, []string{}
   213  
   214  	// Pre-generate the things to write (for foo & bar)
   215  	for i := 0; i < 100; i++ {
   216  		fooWrites = append(fooWrites, cmn.RandStr((cmn.RandInt()%(dataMaxSize*5))+1))
   217  		barWrites = append(barWrites, cmn.RandStr((cmn.RandInt()%(dataMaxSize*5))+1))
   218  	}
   219  
   220  	// A helper that will run with (fooConn, fooWrites, fooReads) and vice versa
   221  	genNodeRunner := func(id string, nodeConn kvstoreConn, nodeWrites []string, nodeReads *[]string) cmn.Task {
   222  		return func(_ int) (interface{}, error, bool) {
   223  			// Initiate cryptographic private key and secret connection trhough nodeConn.
   224  			nodePrvKey := ed25519.GenPrivKey()
   225  			nodeSecretConn, err := MakeSecretConnection(nodeConn, nodePrvKey)
   226  			if err != nil {
   227  				t.Errorf("Failed to establish SecretConnection for node: %v", err)
   228  				return nil, err, true
   229  			}
   230  			// In parallel, handle some reads and writes.
   231  			var trs, ok = cmn.Parallel(
   232  				func(_ int) (interface{}, error, bool) {
   233  					// Node writes:
   234  					for _, nodeWrite := range nodeWrites {
   235  						n, err := nodeSecretConn.Write([]byte(nodeWrite))
   236  						if err != nil {
   237  							t.Errorf("Failed to write to nodeSecretConn: %v", err)
   238  							return nil, err, true
   239  						}
   240  						if n != len(nodeWrite) {
   241  							err = fmt.Errorf("Failed to write all bytes. Expected %v, wrote %v", len(nodeWrite), n)
   242  							t.Error(err)
   243  							return nil, err, true
   244  						}
   245  					}
   246  					if err := nodeConn.PipeWriter.Close(); err != nil {
   247  						t.Error(err)
   248  						return nil, err, true
   249  					}
   250  					return nil, nil, false
   251  				},
   252  				func(_ int) (interface{}, error, bool) {
   253  					// Node reads:
   254  					readBuffer := make([]byte, dataMaxSize)
   255  					for {
   256  						n, err := nodeSecretConn.Read(readBuffer)
   257  						if err == io.EOF {
   258  							if err := nodeConn.PipeReader.Close(); err != nil {
   259  								t.Error(err)
   260  								return nil, err, true
   261  							}
   262  							return nil, nil, false
   263  						} else if err != nil {
   264  							t.Errorf("Failed to read from nodeSecretConn: %v", err)
   265  							return nil, err, true
   266  						}
   267  						*nodeReads = append(*nodeReads, string(readBuffer[:n]))
   268  					}
   269  				},
   270  			)
   271  			assert.True(t, ok, "Unexpected task abortion")
   272  
   273  			// If error:
   274  			if trs.FirstError() != nil {
   275  				return nil, trs.FirstError(), true
   276  			}
   277  
   278  			// Otherwise:
   279  			return nil, nil, false
   280  		}
   281  	}
   282  
   283  	// Run foo & bar in parallel
   284  	var trs, ok = cmn.Parallel(
   285  		genNodeRunner("foo", fooConn, fooWrites, &fooReads),
   286  		genNodeRunner("bar", barConn, barWrites, &barReads),
   287  	)
   288  	require.Nil(t, trs.FirstError())
   289  	require.True(t, ok, "unexpected task abortion")
   290  
   291  	// A helper to ensure that the writes and reads match.
   292  	// Additionally, small writes (<= dataMaxSize) must be atomically read.
   293  	compareWritesReads := func(writes []string, reads []string) {
   294  		for {
   295  			// Pop next write & corresponding reads
   296  			var read, write string = "", writes[0]
   297  			var readCount = 0
   298  			for _, readChunk := range reads {
   299  				read += readChunk
   300  				readCount++
   301  				if len(write) <= len(read) {
   302  					break
   303  				}
   304  				if len(write) <= dataMaxSize {
   305  					break // atomicity of small writes
   306  				}
   307  			}
   308  			// Compare
   309  			if write != read {
   310  				t.Errorf("Expected to read %X, got %X", write, read)
   311  			}
   312  			// Iterate
   313  			writes = writes[1:]
   314  			reads = reads[readCount:]
   315  			if len(writes) == 0 {
   316  				break
   317  			}
   318  		}
   319  	}
   320  
   321  	compareWritesReads(fooWrites, barReads)
   322  	compareWritesReads(barWrites, fooReads)
   323  
   324  }
   325  
   326  // Run go test -update from within this module
   327  // to update the golden test vector file
   328  var update = flag.Bool("update", false, "update .golden files")
   329  
   330  func TestDeriveSecretsAndChallengeGolden(t *testing.T) {
   331  	goldenFilepath := filepath.Join("testdata", t.Name()+".golden")
   332  	if *update {
   333  		t.Logf("Updating golden test vector file %s", goldenFilepath)
   334  		data := createGoldenTestVectors(t)
   335  		cmn.WriteFile(goldenFilepath, []byte(data), 0644)
   336  	}
   337  	f, err := os.Open(goldenFilepath)
   338  	if err != nil {
   339  		log.Fatal(err)
   340  	}
   341  	defer f.Close()
   342  	scanner := bufio.NewScanner(f)
   343  	for scanner.Scan() {
   344  		line := scanner.Text()
   345  		params := strings.Split(line, ",")
   346  		randSecretVector, err := hex.DecodeString(params[0])
   347  		require.Nil(t, err)
   348  		randSecret := new([32]byte)
   349  		copy((*randSecret)[:], randSecretVector)
   350  		locIsLeast, err := strconv.ParseBool(params[1])
   351  		require.Nil(t, err)
   352  		expectedRecvSecret, err := hex.DecodeString(params[2])
   353  		require.Nil(t, err)
   354  		expectedSendSecret, err := hex.DecodeString(params[3])
   355  		require.Nil(t, err)
   356  		expectedChallenge, err := hex.DecodeString(params[4])
   357  		require.Nil(t, err)
   358  
   359  		recvSecret, sendSecret, challenge := deriveSecretAndChallenge(randSecret, locIsLeast)
   360  		require.Equal(t, expectedRecvSecret, (*recvSecret)[:], "Recv Secrets aren't equal")
   361  		require.Equal(t, expectedSendSecret, (*sendSecret)[:], "Send Secrets aren't equal")
   362  		require.Equal(t, expectedChallenge, (*challenge)[:], "challenges aren't equal")
   363  	}
   364  }
   365  
   366  // Creates the data for a test vector file.
   367  // The file format is:
   368  // Hex(diffie_hellman_secret), loc_is_least, Hex(recvSecret), Hex(sendSecret), Hex(challenge)
   369  func createGoldenTestVectors(t *testing.T) string {
   370  	data := ""
   371  	for i := 0; i < 32; i++ {
   372  		randSecretVector := cmn.RandBytes(32)
   373  		randSecret := new([32]byte)
   374  		copy((*randSecret)[:], randSecretVector)
   375  		data += hex.EncodeToString((*randSecret)[:]) + ","
   376  		locIsLeast := cmn.RandBool()
   377  		data += strconv.FormatBool(locIsLeast) + ","
   378  		recvSecret, sendSecret, challenge := deriveSecretAndChallenge(randSecret, locIsLeast)
   379  		data += hex.EncodeToString((*recvSecret)[:]) + ","
   380  		data += hex.EncodeToString((*sendSecret)[:]) + ","
   381  		data += hex.EncodeToString((*challenge)[:]) + "\n"
   382  	}
   383  	return data
   384  }
   385  
   386  func BenchmarkSecretConnection(b *testing.B) {
   387  	b.StopTimer()
   388  	fooSecConn, barSecConn := makeSecretConnPair(b)
   389  	fooWriteText := cmn.RandStr(dataMaxSize)
   390  	// Consume reads from bar's reader
   391  	go func() {
   392  		readBuffer := make([]byte, dataMaxSize)
   393  		for {
   394  			_, err := barSecConn.Read(readBuffer)
   395  			if err == io.EOF {
   396  				return
   397  			} else if err != nil {
   398  				b.Fatalf("Failed to read from barSecConn: %v", err)
   399  			}
   400  		}
   401  	}()
   402  
   403  	b.StartTimer()
   404  	for i := 0; i < b.N; i++ {
   405  		_, err := fooSecConn.Write([]byte(fooWriteText))
   406  		if err != nil {
   407  			b.Fatalf("Failed to write to fooSecConn: %v", err)
   408  		}
   409  	}
   410  	b.StopTimer()
   411  
   412  	if err := fooSecConn.Close(); err != nil {
   413  		b.Error(err)
   414  	}
   415  	//barSecConn.Close() race condition
   416  }