github.com/line/ostracon@v1.0.10-0.20230328032236-7f20145f065d/p2p/conn/secret_connection_test.go (about)

     1  package conn
     2  
     3  import (
     4  	"bufio"
     5  	"encoding/hex"
     6  	"flag"
     7  	"fmt"
     8  	"io"
     9  	"log"
    10  	"os"
    11  	"path/filepath"
    12  	"strconv"
    13  	"strings"
    14  	"sync"
    15  	"testing"
    16  
    17  	"github.com/stretchr/testify/assert"
    18  	"github.com/stretchr/testify/require"
    19  
    20  	"github.com/line/ostracon/crypto"
    21  	"github.com/line/ostracon/crypto/ed25519"
    22  	"github.com/line/ostracon/crypto/sr25519"
    23  	"github.com/line/ostracon/libs/async"
    24  	tmos "github.com/line/ostracon/libs/os"
    25  	tmrand "github.com/line/ostracon/libs/rand"
    26  )
    27  
    28  // Run go test -update from within this module
    29  // to update the golden test vector file
    30  var update = flag.Bool("update", false, "update .golden files")
    31  
    32  type kvstoreConn struct {
    33  	*io.PipeReader
    34  	*io.PipeWriter
    35  }
    36  
    37  func (drw kvstoreConn) Close() (err error) {
    38  	err2 := drw.PipeWriter.CloseWithError(io.EOF)
    39  	err1 := drw.PipeReader.Close()
    40  	if err2 != nil {
    41  		return err
    42  	}
    43  	return err1
    44  }
    45  
    46  type privKeyWithNilPubKey struct {
    47  	orig crypto.PrivKey
    48  }
    49  
    50  func (pk privKeyWithNilPubKey) Bytes() []byte                             { return pk.orig.Bytes() }
    51  func (pk privKeyWithNilPubKey) Sign(msg []byte) ([]byte, error)           { return pk.orig.Sign(msg) }
    52  func (pk privKeyWithNilPubKey) VRFProve(msg []byte) (crypto.Proof, error) { return nil, nil }
    53  func (pk privKeyWithNilPubKey) PubKey() crypto.PubKey                     { return nil }
    54  func (pk privKeyWithNilPubKey) Equals(pk2 crypto.PrivKey) bool            { return pk.orig.Equals(pk2) }
    55  func (pk privKeyWithNilPubKey) Type() string                              { return "privKeyWithNilPubKey" }
    56  
    57  func TestSecretConnectionHandshake(t *testing.T) {
    58  	fooSecConn, barSecConn := makeSecretConnPair(t)
    59  	if err := fooSecConn.Close(); err != nil {
    60  		t.Error(err)
    61  	}
    62  	if err := barSecConn.Close(); err != nil {
    63  		t.Error(err)
    64  	}
    65  }
    66  
    67  func TestConcurrentWrite(t *testing.T) {
    68  	fooSecConn, barSecConn := makeSecretConnPair(t)
    69  	fooWriteText := tmrand.Str(dataMaxSize)
    70  
    71  	// write from two routines.
    72  	// should be safe from race according to net.Conn:
    73  	// https://golang.org/pkg/net/#Conn
    74  	n := 100
    75  	wg := new(sync.WaitGroup)
    76  	wg.Add(3)
    77  	go writeLots(t, wg, fooSecConn, fooWriteText, n)
    78  	go writeLots(t, wg, fooSecConn, fooWriteText, n)
    79  
    80  	// Consume reads from bar's reader
    81  	readLots(t, wg, barSecConn, n*2)
    82  	wg.Wait()
    83  
    84  	if err := fooSecConn.Close(); err != nil {
    85  		t.Error(err)
    86  	}
    87  }
    88  
    89  func TestConcurrentRead(t *testing.T) {
    90  	fooSecConn, barSecConn := makeSecretConnPair(t)
    91  	fooWriteText := tmrand.Str(dataMaxSize)
    92  	n := 100
    93  
    94  	// read from two routines.
    95  	// should be safe from race according to net.Conn:
    96  	// https://golang.org/pkg/net/#Conn
    97  	wg := new(sync.WaitGroup)
    98  	wg.Add(3)
    99  	go readLots(t, wg, fooSecConn, n/2)
   100  	go readLots(t, wg, fooSecConn, n/2)
   101  
   102  	// write to bar
   103  	writeLots(t, wg, barSecConn, fooWriteText, n)
   104  	wg.Wait()
   105  
   106  	if err := fooSecConn.Close(); err != nil {
   107  		t.Error(err)
   108  	}
   109  }
   110  
   111  func TestSecretConnectionReadWrite(t *testing.T) {
   112  	fooConn, barConn := makeKVStoreConnPair()
   113  	fooWrites, barWrites := []string{}, []string{}
   114  	fooReads, barReads := []string{}, []string{}
   115  
   116  	// Pre-generate the things to write (for foo & bar)
   117  	for i := 0; i < 100; i++ {
   118  		fooWrites = append(fooWrites, tmrand.Str((tmrand.Int()%(dataMaxSize*5))+1))
   119  		barWrites = append(barWrites, tmrand.Str((tmrand.Int()%(dataMaxSize*5))+1))
   120  	}
   121  
   122  	// A helper that will run with (fooConn, fooWrites, fooReads) and vice versa
   123  	genNodeRunner := func(id string, nodeConn kvstoreConn, nodeWrites []string, nodeReads *[]string) async.Task {
   124  		return func(_ int) (interface{}, bool, error) {
   125  			// Initiate cryptographic private key and secret connection trhough nodeConn.
   126  			nodePrvKey := ed25519.GenPrivKey()
   127  			nodeSecretConn, err := MakeSecretConnection(nodeConn, nodePrvKey)
   128  			if err != nil {
   129  				t.Errorf("failed to establish SecretConnection for node: %v", err)
   130  				return nil, true, err
   131  			}
   132  			// In parallel, handle some reads and writes.
   133  			var trs, ok = async.Parallel(
   134  				func(_ int) (interface{}, bool, error) {
   135  					// Node writes:
   136  					for _, nodeWrite := range nodeWrites {
   137  						n, err := nodeSecretConn.Write([]byte(nodeWrite))
   138  						if err != nil {
   139  							t.Errorf("failed to write to nodeSecretConn: %v", err)
   140  							return nil, true, err
   141  						}
   142  						if n != len(nodeWrite) {
   143  							err = fmt.Errorf("failed to write all bytes. Expected %v, wrote %v", len(nodeWrite), n)
   144  							t.Error(err)
   145  							return nil, true, err
   146  						}
   147  					}
   148  					if err := nodeConn.PipeWriter.Close(); err != nil {
   149  						t.Error(err)
   150  						return nil, true, err
   151  					}
   152  					return nil, false, nil
   153  				},
   154  				func(_ int) (interface{}, bool, error) {
   155  					// Node reads:
   156  					readBuffer := make([]byte, dataMaxSize)
   157  					for {
   158  						n, err := nodeSecretConn.Read(readBuffer)
   159  						if err == io.EOF {
   160  							if err := nodeConn.PipeReader.Close(); err != nil {
   161  								t.Error(err)
   162  								return nil, true, err
   163  							}
   164  							return nil, false, nil
   165  						} else if err != nil {
   166  							t.Errorf("failed to read from nodeSecretConn: %v", err)
   167  							return nil, true, err
   168  						}
   169  						*nodeReads = append(*nodeReads, string(readBuffer[:n]))
   170  					}
   171  				},
   172  			)
   173  			assert.True(t, ok, "Unexpected task abortion")
   174  
   175  			// If error:
   176  			if trs.FirstError() != nil {
   177  				return nil, true, trs.FirstError()
   178  			}
   179  
   180  			// Otherwise:
   181  			return nil, false, nil
   182  		}
   183  	}
   184  
   185  	// Run foo & bar in parallel
   186  	var trs, ok = async.Parallel(
   187  		genNodeRunner("foo", fooConn, fooWrites, &fooReads),
   188  		genNodeRunner("bar", barConn, barWrites, &barReads),
   189  	)
   190  	require.Nil(t, trs.FirstError())
   191  	require.True(t, ok, "unexpected task abortion")
   192  
   193  	// A helper to ensure that the writes and reads match.
   194  	// Additionally, small writes (<= dataMaxSize) must be atomically read.
   195  	compareWritesReads := func(writes []string, reads []string) {
   196  		for {
   197  			// Pop next write & corresponding reads
   198  			var read = ""
   199  			var write = writes[0]
   200  			var readCount = 0
   201  			for _, readChunk := range reads {
   202  				read += readChunk
   203  				readCount++
   204  				if len(write) <= len(read) {
   205  					break
   206  				}
   207  				if len(write) <= dataMaxSize {
   208  					break // atomicity of small writes
   209  				}
   210  			}
   211  			// Compare
   212  			if write != read {
   213  				t.Errorf("expected to read %X, got %X", write, read)
   214  			}
   215  			// Iterate
   216  			writes = writes[1:]
   217  			reads = reads[readCount:]
   218  			if len(writes) == 0 {
   219  				break
   220  			}
   221  		}
   222  	}
   223  
   224  	compareWritesReads(fooWrites, barReads)
   225  	compareWritesReads(barWrites, fooReads)
   226  }
   227  
   228  func TestDeriveSecretsAndChallengeGolden(t *testing.T) {
   229  	goldenFilepath := filepath.Join("testdata", t.Name()+".golden")
   230  	if *update {
   231  		t.Logf("Updating golden test vector file %s", goldenFilepath)
   232  		data := createGoldenTestVectors(t)
   233  		err := tmos.WriteFile(goldenFilepath, []byte(data), 0644)
   234  		require.NoError(t, err)
   235  	}
   236  	f, err := os.Open(goldenFilepath)
   237  	if err != nil {
   238  		log.Fatal(err)
   239  	}
   240  	defer f.Close()
   241  	scanner := bufio.NewScanner(f)
   242  	for scanner.Scan() {
   243  		line := scanner.Text()
   244  		params := strings.Split(line, ",")
   245  		randSecretVector, err := hex.DecodeString(params[0])
   246  		require.Nil(t, err)
   247  		randSecret := new([32]byte)
   248  		copy((*randSecret)[:], randSecretVector)
   249  		locIsLeast, err := strconv.ParseBool(params[1])
   250  		require.Nil(t, err)
   251  		expectedRecvSecret, err := hex.DecodeString(params[2])
   252  		require.Nil(t, err)
   253  		expectedSendSecret, err := hex.DecodeString(params[3])
   254  		require.Nil(t, err)
   255  
   256  		recvSecret, sendSecret := deriveSecrets(randSecret, locIsLeast)
   257  		require.Equal(t, expectedRecvSecret, (*recvSecret)[:], "Recv Secrets aren't equal")
   258  		require.Equal(t, expectedSendSecret, (*sendSecret)[:], "Send Secrets aren't equal")
   259  	}
   260  }
   261  
   262  func TestNilPubkey(t *testing.T) {
   263  	var fooConn, barConn = makeKVStoreConnPair()
   264  	defer fooConn.Close()
   265  	defer barConn.Close()
   266  	var fooPrvKey = ed25519.GenPrivKey()
   267  	var barPrvKey = privKeyWithNilPubKey{ed25519.GenPrivKey()}
   268  
   269  	go MakeSecretConnection(fooConn, fooPrvKey) //nolint:errcheck // ignore for tests
   270  
   271  	_, err := MakeSecretConnection(barConn, barPrvKey)
   272  	require.Error(t, err)
   273  	assert.Equal(t, "toproto: key type <nil> is not supported", err.Error())
   274  }
   275  
   276  func TestNonEd25519Pubkey(t *testing.T) {
   277  	var fooConn, barConn = makeKVStoreConnPair()
   278  	defer fooConn.Close()
   279  	defer barConn.Close()
   280  	var fooPrvKey = ed25519.GenPrivKey()
   281  	var barPrvKey = sr25519.GenPrivKey()
   282  
   283  	go MakeSecretConnection(fooConn, fooPrvKey) //nolint:errcheck // ignore for tests
   284  
   285  	_, err := MakeSecretConnection(barConn, barPrvKey)
   286  	require.Error(t, err)
   287  	assert.Contains(t, err.Error(), "is not supported")
   288  }
   289  
   290  func writeLots(t *testing.T, wg *sync.WaitGroup, conn io.Writer, txt string, n int) {
   291  	defer wg.Done()
   292  	for i := 0; i < n; i++ {
   293  		_, err := conn.Write([]byte(txt))
   294  		if err != nil {
   295  			t.Errorf("failed to write to fooSecConn: %v", err)
   296  			return
   297  		}
   298  	}
   299  }
   300  
   301  func readLots(t *testing.T, wg *sync.WaitGroup, conn io.Reader, n int) {
   302  	readBuffer := make([]byte, dataMaxSize)
   303  	for i := 0; i < n; i++ {
   304  		_, err := conn.Read(readBuffer)
   305  		assert.NoError(t, err)
   306  	}
   307  	wg.Done()
   308  }
   309  
   310  // Creates the data for a test vector file.
   311  // The file format is:
   312  // Hex(diffie_hellman_secret), loc_is_least, Hex(recvSecret), Hex(sendSecret), Hex(challenge)
   313  func createGoldenTestVectors(t *testing.T) string {
   314  	data := ""
   315  	for i := 0; i < 32; i++ {
   316  		randSecretVector := tmrand.Bytes(32)
   317  		randSecret := new([32]byte)
   318  		copy((*randSecret)[:], randSecretVector)
   319  		data += hex.EncodeToString((*randSecret)[:]) + ","
   320  		locIsLeast := tmrand.Bool()
   321  		data += strconv.FormatBool(locIsLeast) + ","
   322  		recvSecret, sendSecret := deriveSecrets(randSecret, locIsLeast)
   323  		data += hex.EncodeToString((*recvSecret)[:]) + ","
   324  		data += hex.EncodeToString((*sendSecret)[:]) + ","
   325  	}
   326  	return data
   327  }
   328  
   329  // Each returned ReadWriteCloser is akin to a net.Connection
   330  func makeKVStoreConnPair() (fooConn, barConn kvstoreConn) {
   331  	barReader, fooWriter := io.Pipe()
   332  	fooReader, barWriter := io.Pipe()
   333  	return kvstoreConn{fooReader, fooWriter}, kvstoreConn{barReader, barWriter}
   334  }
   335  
   336  func makeSecretConnPair(tb testing.TB) (fooSecConn, barSecConn *SecretConnection) {
   337  	var (
   338  		fooConn, barConn = makeKVStoreConnPair()
   339  		fooPrvKey        = ed25519.GenPrivKey()
   340  		fooPubKey        = fooPrvKey.PubKey()
   341  		barPrvKey        = ed25519.GenPrivKey()
   342  		barPubKey        = barPrvKey.PubKey()
   343  	)
   344  
   345  	// Make connections from both sides in parallel.
   346  	var trs, ok = async.Parallel(
   347  		func(_ int) (val interface{}, abort bool, err error) {
   348  			fooSecConn, err = MakeSecretConnection(fooConn, fooPrvKey)
   349  			if err != nil {
   350  				tb.Errorf("failed to establish SecretConnection for foo: %v", err)
   351  				return nil, true, err
   352  			}
   353  			remotePubBytes := fooSecConn.RemotePubKey()
   354  			if !remotePubBytes.Equals(barPubKey) {
   355  				err = fmt.Errorf("unexpected fooSecConn.RemotePubKey.  Expected %v, got %v",
   356  					barPubKey, fooSecConn.RemotePubKey())
   357  				tb.Error(err)
   358  				return nil, true, err
   359  			}
   360  			return nil, false, nil
   361  		},
   362  		func(_ int) (val interface{}, abort bool, err error) {
   363  			barSecConn, err = MakeSecretConnection(barConn, barPrvKey)
   364  			if barSecConn == nil {
   365  				tb.Errorf("failed to establish SecretConnection for bar: %v", err)
   366  				return nil, true, err
   367  			}
   368  			remotePubBytes := barSecConn.RemotePubKey()
   369  			if !remotePubBytes.Equals(fooPubKey) {
   370  				err = fmt.Errorf("unexpected barSecConn.RemotePubKey.  Expected %v, got %v",
   371  					fooPubKey, barSecConn.RemotePubKey())
   372  				tb.Error(err)
   373  				return nil, true, err
   374  			}
   375  			return nil, false, nil
   376  		},
   377  	)
   378  
   379  	require.Nil(tb, trs.FirstError())
   380  	require.True(tb, ok, "Unexpected task abortion")
   381  
   382  	return fooSecConn, barSecConn
   383  }
   384  
   385  // Benchmarks
   386  
   387  func BenchmarkWriteSecretConnection(b *testing.B) {
   388  	b.StopTimer()
   389  	b.ReportAllocs()
   390  	fooSecConn, barSecConn := makeSecretConnPair(b)
   391  	randomMsgSizes := []int{
   392  		dataMaxSize / 10,
   393  		dataMaxSize / 3,
   394  		dataMaxSize / 2,
   395  		dataMaxSize,
   396  		dataMaxSize * 3 / 2,
   397  		dataMaxSize * 2,
   398  		dataMaxSize * 7 / 2,
   399  	}
   400  	fooWriteBytes := make([][]byte, 0, len(randomMsgSizes))
   401  	for _, size := range randomMsgSizes {
   402  		fooWriteBytes = append(fooWriteBytes, tmrand.Bytes(size))
   403  	}
   404  	// Consume reads from bar's reader
   405  	go func() {
   406  		readBuffer := make([]byte, dataMaxSize)
   407  		for {
   408  			_, err := barSecConn.Read(readBuffer)
   409  			if err == io.EOF {
   410  				return
   411  			} else if err != nil {
   412  				b.Errorf("failed to read from barSecConn: %v", err)
   413  				return
   414  			}
   415  		}
   416  	}()
   417  
   418  	b.StartTimer()
   419  	for i := 0; i < b.N; i++ {
   420  		idx := tmrand.Intn(len(fooWriteBytes))
   421  		_, err := fooSecConn.Write(fooWriteBytes[idx])
   422  		if err != nil {
   423  			b.Errorf("failed to write to fooSecConn: %v", err)
   424  			return
   425  		}
   426  	}
   427  	b.StopTimer()
   428  
   429  	if err := fooSecConn.Close(); err != nil {
   430  		b.Error(err)
   431  	}
   432  	// barSecConn.Close() race condition
   433  }
   434  
   435  func BenchmarkReadSecretConnection(b *testing.B) {
   436  	b.StopTimer()
   437  	b.ReportAllocs()
   438  	fooSecConn, barSecConn := makeSecretConnPair(b)
   439  	randomMsgSizes := []int{
   440  		dataMaxSize / 10,
   441  		dataMaxSize / 3,
   442  		dataMaxSize / 2,
   443  		dataMaxSize,
   444  		dataMaxSize * 3 / 2,
   445  		dataMaxSize * 2,
   446  		dataMaxSize * 7 / 2,
   447  	}
   448  	fooWriteBytes := make([][]byte, 0, len(randomMsgSizes))
   449  	for _, size := range randomMsgSizes {
   450  		fooWriteBytes = append(fooWriteBytes, tmrand.Bytes(size))
   451  	}
   452  	go func() {
   453  		for i := 0; i < b.N; i++ {
   454  			idx := tmrand.Intn(len(fooWriteBytes))
   455  			_, err := fooSecConn.Write(fooWriteBytes[idx])
   456  			if err != nil {
   457  				b.Errorf("failed to write to fooSecConn: %v, %v,%v", err, i, b.N)
   458  				return
   459  			}
   460  		}
   461  	}()
   462  
   463  	b.StartTimer()
   464  	for i := 0; i < b.N; i++ {
   465  		readBuffer := make([]byte, dataMaxSize)
   466  		_, err := barSecConn.Read(readBuffer)
   467  
   468  		if err == io.EOF {
   469  			return
   470  		} else if err != nil {
   471  			b.Fatalf("Failed to read from barSecConn: %v", err)
   472  		}
   473  	}
   474  	b.StopTimer()
   475  }