github.com/gnolang/gno@v0.0.0-20240520182011-228e9d0192ce/tm2/pkg/p2p/conn/secret_connection_test.go (about)

     1  package conn
     2  
     3  import (
     4  	"bufio"
     5  	"encoding/hex"
     6  	"errors"
     7  	"flag"
     8  	"fmt"
     9  	"io"
    10  	"log"
    11  	"net"
    12  	"os"
    13  	"path/filepath"
    14  	"strconv"
    15  	"strings"
    16  	"sync"
    17  	"testing"
    18  
    19  	"github.com/stretchr/testify/assert"
    20  	"github.com/stretchr/testify/require"
    21  
    22  	"github.com/gnolang/gno/tm2/pkg/async"
    23  	"github.com/gnolang/gno/tm2/pkg/crypto"
    24  	"github.com/gnolang/gno/tm2/pkg/crypto/ed25519"
    25  	"github.com/gnolang/gno/tm2/pkg/crypto/secp256k1"
    26  	osm "github.com/gnolang/gno/tm2/pkg/os"
    27  	"github.com/gnolang/gno/tm2/pkg/random"
    28  )
    29  
    30  type kvstoreConn struct {
    31  	*io.PipeReader
    32  	*io.PipeWriter
    33  }
    34  
    35  func (drw kvstoreConn) Close() (err error) {
    36  	err2 := drw.PipeWriter.CloseWithError(io.EOF)
    37  	err1 := drw.PipeReader.Close()
    38  	if err2 != nil {
    39  		return err
    40  	}
    41  	return err1
    42  }
    43  
    44  // Each returned ReadWriteCloser is akin to a net.Connection
    45  func makeKVStoreConnPair() (fooConn, barConn kvstoreConn) {
    46  	barReader, fooWriter := io.Pipe()
    47  	fooReader, barWriter := io.Pipe()
    48  	return kvstoreConn{fooReader, fooWriter}, kvstoreConn{barReader, barWriter}
    49  }
    50  
    51  func makeSecretConnPair(tb testing.TB) (fooSecConn, barSecConn *SecretConnection) {
    52  	tb.Helper()
    53  
    54  	fooConn, barConn := makeKVStoreConnPair()
    55  	fooPrvKey := ed25519.GenPrivKey()
    56  	fooPubKey := fooPrvKey.PubKey()
    57  	barPrvKey := ed25519.GenPrivKey()
    58  	barPubKey := barPrvKey.PubKey()
    59  
    60  	// Make connections from both sides in parallel.
    61  	trs, ok := async.Parallel(
    62  		func(_ int) (val interface{}, err error, abort bool) {
    63  			fooSecConn, err = MakeSecretConnection(fooConn, fooPrvKey)
    64  			if err != nil {
    65  				tb.Errorf("Failed to establish SecretConnection for foo: %v", err)
    66  				return nil, err, true
    67  			}
    68  			remotePubBytes := fooSecConn.RemotePubKey()
    69  			if !remotePubBytes.Equals(barPubKey) {
    70  				err = fmt.Errorf("Unexpected fooSecConn.RemotePubKey.  Expected %v, got %v",
    71  					barPubKey, fooSecConn.RemotePubKey())
    72  				tb.Error(err)
    73  				return nil, err, false
    74  			}
    75  			return nil, nil, false
    76  		},
    77  		func(_ int) (val interface{}, err error, abort bool) {
    78  			barSecConn, err = MakeSecretConnection(barConn, barPrvKey)
    79  			if barSecConn == nil {
    80  				tb.Errorf("Failed to establish SecretConnection for bar: %v", err)
    81  				return nil, err, true
    82  			}
    83  			remotePubBytes := barSecConn.RemotePubKey()
    84  			if !remotePubBytes.Equals(fooPubKey) {
    85  				err = fmt.Errorf("Unexpected barSecConn.RemotePubKey.  Expected %v, got %v",
    86  					fooPubKey, barSecConn.RemotePubKey())
    87  				tb.Error(err)
    88  				return nil, nil, false
    89  			}
    90  			return nil, nil, false
    91  		},
    92  	)
    93  
    94  	require.Nil(tb, trs.FirstError())
    95  	require.True(tb, ok, "Unexpected task abortion")
    96  
    97  	return fooSecConn, barSecConn
    98  }
    99  
   100  func TestSecretConnectionHandshake(t *testing.T) {
   101  	t.Parallel()
   102  
   103  	fooSecConn, barSecConn := makeSecretConnPair(t)
   104  	if err := fooSecConn.Close(); err != nil {
   105  		t.Error(err)
   106  	}
   107  	if err := barSecConn.Close(); err != nil {
   108  		t.Error(err)
   109  	}
   110  }
   111  
   112  // Test that shareEphPubKey rejects lower order public keys based on an
   113  // (incomplete) blacklist.
   114  func TestShareLowOrderPubkey(t *testing.T) {
   115  	t.Parallel()
   116  
   117  	fooConn, barConn := makeKVStoreConnPair()
   118  	defer fooConn.Close()
   119  	defer barConn.Close()
   120  	locEphPub, _ := genEphKeys()
   121  
   122  	// all blacklisted low order points:
   123  	for _, remLowOrderPubKey := range blacklist {
   124  		remLowOrderPubKey := remLowOrderPubKey
   125  		_, _ = async.Parallel(
   126  			func(_ int) (val interface{}, err error, abort bool) {
   127  				_, err = shareEphPubKey(fooConn, locEphPub)
   128  
   129  				require.Error(t, err)
   130  				require.Equal(t, err, ErrSmallOrderRemotePubKey)
   131  
   132  				return nil, nil, false
   133  			},
   134  			func(_ int) (val interface{}, err error, abort bool) {
   135  				readRemKey, err := shareEphPubKey(barConn, &remLowOrderPubKey)
   136  
   137  				require.NoError(t, err)
   138  				require.Equal(t, locEphPub, readRemKey)
   139  
   140  				return nil, nil, false
   141  			})
   142  	}
   143  }
   144  
   145  // Test that additionally that the Diffie-Hellman shared secret is non-zero.
   146  // The shared secret would be zero for lower order pub-keys (but tested against the blacklist only).
   147  func TestComputeDHFailsOnLowOrder(t *testing.T) {
   148  	t.Parallel()
   149  
   150  	_, locPrivKey := genEphKeys()
   151  	for _, remLowOrderPubKey := range blacklist {
   152  		remLowOrderPubKey := remLowOrderPubKey
   153  		shared, err := computeDHSecret(&remLowOrderPubKey, locPrivKey)
   154  		assert.Error(t, err)
   155  
   156  		assert.Equal(t, err, ErrSharedSecretIsZero)
   157  		assert.Empty(t, shared)
   158  	}
   159  }
   160  
   161  func TestConcurrentWrite(t *testing.T) {
   162  	t.Parallel()
   163  
   164  	fooSecConn, barSecConn := makeSecretConnPair(t)
   165  	fooWriteText := random.RandStr(dataMaxSize)
   166  
   167  	// write from two routines.
   168  	// should be safe from race according to net.Conn:
   169  	// https://golang.org/pkg/net/#Conn
   170  	n := 100
   171  	wg := new(sync.WaitGroup)
   172  	wg.Add(3)
   173  	go writeLots(t, wg, fooSecConn, fooWriteText, n)
   174  	go writeLots(t, wg, fooSecConn, fooWriteText, n)
   175  
   176  	// Consume reads from bar's reader
   177  	readLots(t, wg, barSecConn, n*2)
   178  	wg.Wait()
   179  
   180  	if err := fooSecConn.Close(); err != nil {
   181  		t.Error(err)
   182  	}
   183  }
   184  
   185  func TestConcurrentRead(t *testing.T) {
   186  	t.Parallel()
   187  
   188  	fooSecConn, barSecConn := makeSecretConnPair(t)
   189  	fooWriteText := random.RandStr(dataMaxSize)
   190  	n := 100
   191  
   192  	// read from two routines.
   193  	// should be safe from race according to net.Conn:
   194  	// https://golang.org/pkg/net/#Conn
   195  	wg := new(sync.WaitGroup)
   196  	wg.Add(3)
   197  	go readLots(t, wg, fooSecConn, n/2)
   198  	go readLots(t, wg, fooSecConn, n/2)
   199  
   200  	// write to bar
   201  	writeLots(t, wg, barSecConn, fooWriteText, n)
   202  	wg.Wait()
   203  
   204  	if err := fooSecConn.Close(); err != nil {
   205  		t.Error(err)
   206  	}
   207  }
   208  
   209  func writeLots(t *testing.T, wg *sync.WaitGroup, conn net.Conn, txt string, n int) {
   210  	t.Helper()
   211  
   212  	defer wg.Done()
   213  	for i := 0; i < n; i++ {
   214  		_, err := conn.Write([]byte(txt))
   215  		if err != nil {
   216  			t.Errorf("Failed to write to fooSecConn: %v", err)
   217  			return
   218  		}
   219  	}
   220  }
   221  
   222  func readLots(t *testing.T, wg *sync.WaitGroup, conn net.Conn, n int) {
   223  	t.Helper()
   224  
   225  	readBuffer := make([]byte, dataMaxSize)
   226  	for i := 0; i < n; i++ {
   227  		_, err := conn.Read(readBuffer)
   228  		assert.NoError(t, err)
   229  	}
   230  	wg.Done()
   231  }
   232  
   233  func TestSecretConnectionReadWrite(t *testing.T) {
   234  	t.Parallel()
   235  
   236  	fooConn, barConn := makeKVStoreConnPair()
   237  	fooWrites, barWrites := []string{}, []string{}
   238  	fooReads, barReads := []string{}, []string{}
   239  
   240  	// Pre-generate the things to write (for foo & bar)
   241  	for i := 0; i < 100; i++ {
   242  		fooWrites = append(fooWrites, random.RandStr((random.RandInt()%(dataMaxSize*5))+1))
   243  		barWrites = append(barWrites, random.RandStr((random.RandInt()%(dataMaxSize*5))+1))
   244  	}
   245  
   246  	// A helper that will run with (fooConn, fooWrites, fooReads) and vice versa
   247  	genNodeRunner := func(id string, nodeConn kvstoreConn, nodeWrites []string, nodeReads *[]string) async.Task {
   248  		return func(_ int) (interface{}, error, bool) {
   249  			// Initiate cryptographic private key and secret connection through nodeConn.
   250  			nodePrvKey := ed25519.GenPrivKey()
   251  			nodeSecretConn, err := MakeSecretConnection(nodeConn, nodePrvKey)
   252  			if err != nil {
   253  				t.Errorf("Failed to establish SecretConnection for node: %v", err)
   254  				return nil, err, true
   255  			}
   256  			// In parallel, handle some reads and writes.
   257  			trs, ok := async.Parallel(
   258  				func(_ int) (interface{}, error, bool) {
   259  					// Node writes:
   260  					for _, nodeWrite := range nodeWrites {
   261  						n, err := nodeSecretConn.Write([]byte(nodeWrite))
   262  						if err != nil {
   263  							t.Errorf("Failed to write to nodeSecretConn: %v", err)
   264  							return nil, err, true
   265  						}
   266  						if n != len(nodeWrite) {
   267  							err = fmt.Errorf("Failed to write all bytes. Expected %v, wrote %v", len(nodeWrite), n)
   268  							t.Error(err)
   269  							return nil, err, true
   270  						}
   271  					}
   272  					if err := nodeConn.PipeWriter.Close(); err != nil {
   273  						t.Error(err)
   274  						return nil, err, true
   275  					}
   276  					return nil, nil, false
   277  				},
   278  				func(_ int) (interface{}, error, bool) {
   279  					// Node reads:
   280  					readBuffer := make([]byte, dataMaxSize)
   281  					for {
   282  						n, err := nodeSecretConn.Read(readBuffer)
   283  						if errors.Is(err, io.EOF) {
   284  							if err := nodeConn.PipeReader.Close(); err != nil {
   285  								t.Error(err)
   286  								return nil, err, true
   287  							}
   288  							return nil, nil, false
   289  						} else if err != nil {
   290  							t.Errorf("Failed to read from nodeSecretConn: %v", err)
   291  							return nil, err, true
   292  						}
   293  						*nodeReads = append(*nodeReads, string(readBuffer[:n]))
   294  					}
   295  				},
   296  			)
   297  			assert.True(t, ok, "Unexpected task abortion")
   298  
   299  			// If error:
   300  			if trs.FirstError() != nil {
   301  				return nil, trs.FirstError(), true
   302  			}
   303  
   304  			// Otherwise:
   305  			return nil, nil, false
   306  		}
   307  	}
   308  
   309  	// Run foo & bar in parallel
   310  	trs, ok := async.Parallel(
   311  		genNodeRunner("foo", fooConn, fooWrites, &fooReads),
   312  		genNodeRunner("bar", barConn, barWrites, &barReads),
   313  	)
   314  	require.Nil(t, trs.FirstError())
   315  	require.True(t, ok, "unexpected task abortion")
   316  
   317  	// A helper to ensure that the writes and reads match.
   318  	// Additionally, small writes (<= dataMaxSize) must be atomically read.
   319  	compareWritesReads := func(writes []string, reads []string) {
   320  		for {
   321  			// Pop next write & corresponding reads
   322  			read, write := "", writes[0]
   323  			readCount := 0
   324  			for _, readChunk := range reads {
   325  				read += readChunk
   326  				readCount++
   327  				if len(write) <= len(read) {
   328  					break
   329  				}
   330  				if len(write) <= dataMaxSize {
   331  					break // atomicity of small writes
   332  				}
   333  			}
   334  			// Compare
   335  			if write != read {
   336  				t.Errorf("Expected to read %X, got %X", write, read)
   337  			}
   338  			// Iterate
   339  			writes = writes[1:]
   340  			reads = reads[readCount:]
   341  			if len(writes) == 0 {
   342  				break
   343  			}
   344  		}
   345  	}
   346  
   347  	compareWritesReads(fooWrites, barReads)
   348  	compareWritesReads(barWrites, fooReads)
   349  }
   350  
   351  // Run go test -update from within this module
   352  // to update the golden test vector file
   353  var update = flag.Bool("update", false, "update .golden files")
   354  
   355  func TestDeriveSecretsAndChallengeGolden(t *testing.T) {
   356  	t.Parallel()
   357  
   358  	goldenFilepath := filepath.Join("testdata", t.Name()+".golden")
   359  	if *update {
   360  		t.Logf("Updating golden test vector file %s", goldenFilepath)
   361  		data := createGoldenTestVectors()
   362  		osm.WriteFile(goldenFilepath, []byte(data), 0o644)
   363  	}
   364  	f, err := os.Open(goldenFilepath)
   365  	if err != nil {
   366  		log.Fatal(err)
   367  	}
   368  	defer f.Close()
   369  	scanner := bufio.NewScanner(f)
   370  	for scanner.Scan() {
   371  		line := scanner.Text()
   372  		params := strings.Split(line, ",")
   373  		randSecretVector, err := hex.DecodeString(params[0])
   374  		require.Nil(t, err)
   375  		randSecret := new([32]byte)
   376  		copy((*randSecret)[:], randSecretVector)
   377  		locIsLeast, err := strconv.ParseBool(params[1])
   378  		require.Nil(t, err)
   379  		expectedRecvSecret, err := hex.DecodeString(params[2])
   380  		require.Nil(t, err)
   381  		expectedSendSecret, err := hex.DecodeString(params[3])
   382  		require.Nil(t, err)
   383  		expectedChallenge, err := hex.DecodeString(params[4])
   384  		require.Nil(t, err)
   385  
   386  		recvSecret, sendSecret, challenge := deriveSecretAndChallenge(randSecret, locIsLeast)
   387  		require.Equal(t, expectedRecvSecret, (*recvSecret)[:], "Recv Secrets aren't equal")
   388  		require.Equal(t, expectedSendSecret, (*sendSecret)[:], "Send Secrets aren't equal")
   389  		require.Equal(t, expectedChallenge, (*challenge)[:], "challenges aren't equal")
   390  	}
   391  }
   392  
   393  type privKeyWithNilPubKey struct {
   394  	orig crypto.PrivKey
   395  }
   396  
   397  func (pk privKeyWithNilPubKey) Bytes() []byte                   { return pk.orig.Bytes() }
   398  func (pk privKeyWithNilPubKey) Sign(msg []byte) ([]byte, error) { return pk.orig.Sign(msg) }
   399  func (pk privKeyWithNilPubKey) PubKey() crypto.PubKey           { return nil }
   400  func (pk privKeyWithNilPubKey) Equals(pk2 crypto.PrivKey) bool  { return pk.orig.Equals(pk2) }
   401  
   402  func TestNilPubkey(t *testing.T) {
   403  	t.Parallel()
   404  
   405  	fooConn, barConn := makeKVStoreConnPair()
   406  	fooPrvKey := ed25519.GenPrivKey()
   407  	barPrvKey := privKeyWithNilPubKey{ed25519.GenPrivKey()}
   408  
   409  	go func() {
   410  		_, err := MakeSecretConnection(barConn, barPrvKey)
   411  		assert.NoError(t, err)
   412  	}()
   413  
   414  	assert.NotPanics(t, func() {
   415  		_, err := MakeSecretConnection(fooConn, fooPrvKey)
   416  		if assert.Error(t, err) {
   417  			assert.Equal(t, "expected ed25519 pubkey, got <nil>", err.Error())
   418  		}
   419  	})
   420  }
   421  
   422  func TestNonEd25519Pubkey(t *testing.T) {
   423  	t.Parallel()
   424  
   425  	fooConn, barConn := makeKVStoreConnPair()
   426  	fooPrvKey := ed25519.GenPrivKey()
   427  	barPrvKey := secp256k1.GenPrivKey()
   428  
   429  	go func() {
   430  		_, err := MakeSecretConnection(barConn, barPrvKey)
   431  		assert.NoError(t, err)
   432  	}()
   433  
   434  	assert.NotPanics(t, func() {
   435  		_, err := MakeSecretConnection(fooConn, fooPrvKey)
   436  		if assert.Error(t, err) {
   437  			assert.Equal(t, "expected ed25519 pubkey, got secp256k1.PubKeySecp256k1", err.Error())
   438  		}
   439  	})
   440  }
   441  
   442  // Creates the data for a test vector file.
   443  // The file format is:
   444  // Hex(diffie_hellman_secret), loc_is_least, Hex(recvSecret), Hex(sendSecret), Hex(challenge)
   445  func createGoldenTestVectors() string {
   446  	data := ""
   447  	for i := 0; i < 32; i++ {
   448  		randSecretVector := random.RandBytes(32)
   449  		randSecret := new([32]byte)
   450  		copy((*randSecret)[:], randSecretVector)
   451  		data += hex.EncodeToString((*randSecret)[:]) + ","
   452  		locIsLeast := random.RandBool()
   453  		data += strconv.FormatBool(locIsLeast) + ","
   454  		recvSecret, sendSecret, challenge := deriveSecretAndChallenge(randSecret, locIsLeast)
   455  		data += hex.EncodeToString((*recvSecret)[:]) + ","
   456  		data += hex.EncodeToString((*sendSecret)[:]) + ","
   457  		data += hex.EncodeToString((*challenge)[:]) + "\n"
   458  	}
   459  	return data
   460  }
   461  
   462  func BenchmarkWriteSecretConnection(b *testing.B) {
   463  	b.StopTimer()
   464  	b.ReportAllocs()
   465  	fooSecConn, barSecConn := makeSecretConnPair(b)
   466  	randomMsgSizes := []int{
   467  		dataMaxSize / 10,
   468  		dataMaxSize / 3,
   469  		dataMaxSize / 2,
   470  		dataMaxSize,
   471  		dataMaxSize * 3 / 2,
   472  		dataMaxSize * 2,
   473  		dataMaxSize * 7 / 2,
   474  	}
   475  	fooWriteBytes := make([][]byte, 0, len(randomMsgSizes))
   476  	for _, size := range randomMsgSizes {
   477  		fooWriteBytes = append(fooWriteBytes, random.RandBytes(size))
   478  	}
   479  	// Consume reads from bar's reader
   480  	go func() {
   481  		readBuffer := make([]byte, dataMaxSize)
   482  		for {
   483  			_, err := barSecConn.Read(readBuffer)
   484  			if errors.Is(err, io.EOF) {
   485  				return
   486  			} else if err != nil {
   487  				b.Errorf("Failed to read from barSecConn: %v", err)
   488  				return
   489  			}
   490  		}
   491  	}()
   492  
   493  	b.StartTimer()
   494  	for i := 0; i < b.N; i++ {
   495  		idx := random.RandIntn(len(fooWriteBytes))
   496  		_, err := fooSecConn.Write(fooWriteBytes[idx])
   497  		if err != nil {
   498  			b.Errorf("Failed to write to fooSecConn: %v", err)
   499  			return
   500  		}
   501  	}
   502  	b.StopTimer()
   503  
   504  	if err := fooSecConn.Close(); err != nil {
   505  		b.Error(err)
   506  	}
   507  	// barSecConn.Close() race condition
   508  }
   509  
   510  func BenchmarkReadSecretConnection(b *testing.B) {
   511  	b.StopTimer()
   512  	b.ReportAllocs()
   513  	fooSecConn, barSecConn := makeSecretConnPair(b)
   514  	randomMsgSizes := []int{
   515  		dataMaxSize / 10,
   516  		dataMaxSize / 3,
   517  		dataMaxSize / 2,
   518  		dataMaxSize,
   519  		dataMaxSize * 3 / 2,
   520  		dataMaxSize * 2,
   521  		dataMaxSize * 7 / 2,
   522  	}
   523  	fooWriteBytes := make([][]byte, 0, len(randomMsgSizes))
   524  	for _, size := range randomMsgSizes {
   525  		fooWriteBytes = append(fooWriteBytes, random.RandBytes(size))
   526  	}
   527  	go func() {
   528  		for i := 0; i < b.N; i++ {
   529  			idx := random.RandIntn(len(fooWriteBytes))
   530  			_, err := fooSecConn.Write(fooWriteBytes[idx])
   531  			if err != nil {
   532  				b.Errorf("Failed to write to fooSecConn: %v, %v,%v", err, i, b.N)
   533  				return
   534  			}
   535  		}
   536  	}()
   537  
   538  	b.StartTimer()
   539  	for i := 0; i < b.N; i++ {
   540  		readBuffer := make([]byte, dataMaxSize)
   541  		_, err := barSecConn.Read(readBuffer)
   542  
   543  		if errors.Is(err, io.EOF) {
   544  			return
   545  		} else if err != nil {
   546  			b.Fatalf("Failed to read from barSecConn: %v", err)
   547  		}
   548  	}
   549  	b.StopTimer()
   550  }