github.com/ethereum-optimism/optimism@v1.7.2/op-node/rollup/derive/span_batch_txs_test.go (about)

     1  package derive
     2  
     3  import (
     4  	"bytes"
     5  	"math/big"
     6  	"math/rand"
     7  	"testing"
     8  
     9  	"github.com/holiman/uint256"
    10  	"github.com/stretchr/testify/require"
    11  
    12  	"github.com/ethereum/go-ethereum/core/types"
    13  
    14  	"github.com/ethereum-optimism/optimism/op-service/testutils"
    15  )
    16  
    17  type txTypeTest struct {
    18  	name   string
    19  	mkTx   func(rng *rand.Rand, signer types.Signer) *types.Transaction
    20  	signer types.Signer
    21  }
    22  
    23  func TestSpanBatchTxsContractCreationBits(t *testing.T) {
    24  	rng := rand.New(rand.NewSource(0x1234567))
    25  	chainID := big.NewInt(rng.Int63n(1000))
    26  
    27  	rawSpanBatch := RandomRawSpanBatch(rng, chainID)
    28  	contractCreationBits := rawSpanBatch.txs.contractCreationBits
    29  	totalBlockTxCount := rawSpanBatch.txs.totalBlockTxCount
    30  
    31  	var sbt spanBatchTxs
    32  	sbt.contractCreationBits = contractCreationBits
    33  	sbt.totalBlockTxCount = totalBlockTxCount
    34  
    35  	var buf bytes.Buffer
    36  	err := sbt.encodeContractCreationBits(&buf)
    37  	require.NoError(t, err)
    38  
    39  	// contractCreationBit field is fixed length: single bit
    40  	contractCreationBitBufferLen := totalBlockTxCount / 8
    41  	if totalBlockTxCount%8 != 0 {
    42  		contractCreationBitBufferLen++
    43  	}
    44  	require.Equal(t, buf.Len(), int(contractCreationBitBufferLen))
    45  
    46  	result := buf.Bytes()
    47  	sbt.contractCreationBits = nil
    48  
    49  	r := bytes.NewReader(result)
    50  	err = sbt.decodeContractCreationBits(r)
    51  	require.NoError(t, err)
    52  
    53  	require.Equal(t, contractCreationBits, sbt.contractCreationBits)
    54  }
    55  
    56  func TestSpanBatchTxsContractCreationCount(t *testing.T) {
    57  	rng := rand.New(rand.NewSource(0x1337))
    58  	chainID := big.NewInt(rng.Int63n(1000))
    59  
    60  	rawSpanBatch := RandomRawSpanBatch(rng, chainID)
    61  
    62  	contractCreationBits := rawSpanBatch.txs.contractCreationBits
    63  	contractCreationCount, err := rawSpanBatch.txs.contractCreationCount()
    64  	require.NoError(t, err)
    65  	totalBlockTxCount := rawSpanBatch.txs.totalBlockTxCount
    66  
    67  	var sbt spanBatchTxs
    68  	sbt.contractCreationBits = contractCreationBits
    69  	sbt.totalBlockTxCount = totalBlockTxCount
    70  
    71  	var buf bytes.Buffer
    72  	err = sbt.encodeContractCreationBits(&buf)
    73  	require.NoError(t, err)
    74  
    75  	result := buf.Bytes()
    76  	sbt.contractCreationBits = nil
    77  
    78  	r := bytes.NewReader(result)
    79  	err = sbt.decodeContractCreationBits(r)
    80  	require.NoError(t, err)
    81  
    82  	contractCreationCount2, err := sbt.contractCreationCount()
    83  	require.NoError(t, err)
    84  
    85  	require.Equal(t, contractCreationCount, contractCreationCount2)
    86  }
    87  
    88  func TestSpanBatchTxsYParityBits(t *testing.T) {
    89  	rng := rand.New(rand.NewSource(0x7331))
    90  	chainID := big.NewInt(rng.Int63n(1000))
    91  
    92  	rawSpanBatch := RandomRawSpanBatch(rng, chainID)
    93  	yParityBits := rawSpanBatch.txs.yParityBits
    94  	totalBlockTxCount := rawSpanBatch.txs.totalBlockTxCount
    95  
    96  	var sbt spanBatchTxs
    97  	sbt.yParityBits = yParityBits
    98  	sbt.totalBlockTxCount = totalBlockTxCount
    99  
   100  	var buf bytes.Buffer
   101  	err := sbt.encodeYParityBits(&buf)
   102  	require.NoError(t, err)
   103  
   104  	// yParityBit field is fixed length: single bit
   105  	yParityBitBufferLen := totalBlockTxCount / 8
   106  	if totalBlockTxCount%8 != 0 {
   107  		yParityBitBufferLen++
   108  	}
   109  	require.Equal(t, buf.Len(), int(yParityBitBufferLen))
   110  
   111  	result := buf.Bytes()
   112  	sbt.yParityBits = nil
   113  
   114  	r := bytes.NewReader(result)
   115  	err = sbt.decodeYParityBits(r)
   116  	require.NoError(t, err)
   117  
   118  	require.Equal(t, yParityBits, sbt.yParityBits)
   119  }
   120  
   121  func TestSpanBatchTxsProtectedBits(t *testing.T) {
   122  	rng := rand.New(rand.NewSource(0x7331))
   123  	chainID := big.NewInt(rng.Int63n(1000))
   124  
   125  	rawSpanBatch := RandomRawSpanBatch(rng, chainID)
   126  	protectedBits := rawSpanBatch.txs.protectedBits
   127  	txTypes := rawSpanBatch.txs.txTypes
   128  	totalBlockTxCount := rawSpanBatch.txs.totalBlockTxCount
   129  	totalLegacyTxCount := rawSpanBatch.txs.totalLegacyTxCount
   130  
   131  	var sbt spanBatchTxs
   132  	sbt.protectedBits = protectedBits
   133  	sbt.totalBlockTxCount = totalBlockTxCount
   134  	sbt.txTypes = txTypes
   135  	sbt.totalLegacyTxCount = totalLegacyTxCount
   136  
   137  	var buf bytes.Buffer
   138  	err := sbt.encodeProtectedBits(&buf)
   139  	require.NoError(t, err)
   140  
   141  	// protectedBit field is fixed length: single bit
   142  	protectedBitBufferLen := totalLegacyTxCount / 8
   143  	require.NoError(t, err)
   144  	if totalLegacyTxCount%8 != 0 {
   145  		protectedBitBufferLen++
   146  	}
   147  	require.Equal(t, buf.Len(), int(protectedBitBufferLen))
   148  
   149  	result := buf.Bytes()
   150  	sbt.protectedBits = nil
   151  
   152  	r := bytes.NewReader(result)
   153  	err = sbt.decodeProtectedBits(r)
   154  	require.NoError(t, err)
   155  
   156  	require.Equal(t, protectedBits, sbt.protectedBits)
   157  }
   158  
   159  func TestSpanBatchTxsTxSigs(t *testing.T) {
   160  	rng := rand.New(rand.NewSource(0x73311337))
   161  	chainID := big.NewInt(rng.Int63n(1000))
   162  
   163  	rawSpanBatch := RandomRawSpanBatch(rng, chainID)
   164  	txSigs := rawSpanBatch.txs.txSigs
   165  	totalBlockTxCount := rawSpanBatch.txs.totalBlockTxCount
   166  
   167  	var sbt spanBatchTxs
   168  	sbt.totalBlockTxCount = totalBlockTxCount
   169  	sbt.txSigs = txSigs
   170  
   171  	var buf bytes.Buffer
   172  	err := sbt.encodeTxSigsRS(&buf)
   173  	require.NoError(t, err)
   174  
   175  	// txSig field is fixed length: 32 byte + 32 byte = 64 byte
   176  	require.Equal(t, buf.Len(), 64*int(totalBlockTxCount))
   177  
   178  	result := buf.Bytes()
   179  	sbt.txSigs = nil
   180  
   181  	r := bytes.NewReader(result)
   182  	err = sbt.decodeTxSigsRS(r)
   183  	require.NoError(t, err)
   184  
   185  	// v field is not set
   186  	for i := 0; i < int(totalBlockTxCount); i++ {
   187  		require.Equal(t, txSigs[i].r, sbt.txSigs[i].r)
   188  		require.Equal(t, txSigs[i].s, sbt.txSigs[i].s)
   189  	}
   190  }
   191  
   192  func TestSpanBatchTxsTxNonces(t *testing.T) {
   193  	rng := rand.New(rand.NewSource(0x123456))
   194  	chainID := big.NewInt(rng.Int63n(1000))
   195  
   196  	rawSpanBatch := RandomRawSpanBatch(rng, chainID)
   197  	txNonces := rawSpanBatch.txs.txNonces
   198  	totalBlockTxCount := rawSpanBatch.txs.totalBlockTxCount
   199  
   200  	var sbt spanBatchTxs
   201  	sbt.totalBlockTxCount = totalBlockTxCount
   202  	sbt.txNonces = txNonces
   203  
   204  	var buf bytes.Buffer
   205  	err := sbt.encodeTxNonces(&buf)
   206  	require.NoError(t, err)
   207  
   208  	result := buf.Bytes()
   209  	sbt.txNonces = nil
   210  
   211  	r := bytes.NewReader(result)
   212  	err = sbt.decodeTxNonces(r)
   213  	require.NoError(t, err)
   214  
   215  	require.Equal(t, txNonces, sbt.txNonces)
   216  }
   217  
   218  func TestSpanBatchTxsTxGases(t *testing.T) {
   219  	rng := rand.New(rand.NewSource(0x12345))
   220  	chainID := big.NewInt(rng.Int63n(1000))
   221  
   222  	rawSpanBatch := RandomRawSpanBatch(rng, chainID)
   223  	txGases := rawSpanBatch.txs.txGases
   224  	totalBlockTxCount := rawSpanBatch.txs.totalBlockTxCount
   225  
   226  	var sbt spanBatchTxs
   227  	sbt.totalBlockTxCount = totalBlockTxCount
   228  	sbt.txGases = txGases
   229  
   230  	var buf bytes.Buffer
   231  	err := sbt.encodeTxGases(&buf)
   232  	require.NoError(t, err)
   233  
   234  	result := buf.Bytes()
   235  	sbt.txGases = nil
   236  
   237  	r := bytes.NewReader(result)
   238  	err = sbt.decodeTxGases(r)
   239  	require.NoError(t, err)
   240  
   241  	require.Equal(t, txGases, sbt.txGases)
   242  }
   243  
   244  func TestSpanBatchTxsTxTos(t *testing.T) {
   245  	rng := rand.New(rand.NewSource(0x54321))
   246  	chainID := big.NewInt(rng.Int63n(1000))
   247  
   248  	rawSpanBatch := RandomRawSpanBatch(rng, chainID)
   249  	txTos := rawSpanBatch.txs.txTos
   250  	contractCreationBits := rawSpanBatch.txs.contractCreationBits
   251  	totalBlockTxCount := rawSpanBatch.txs.totalBlockTxCount
   252  
   253  	var sbt spanBatchTxs
   254  	sbt.txTos = txTos
   255  	// creation bits and block tx count must be se to decode tos
   256  	sbt.contractCreationBits = contractCreationBits
   257  	sbt.totalBlockTxCount = totalBlockTxCount
   258  
   259  	var buf bytes.Buffer
   260  	err := sbt.encodeTxTos(&buf)
   261  	require.NoError(t, err)
   262  
   263  	// to field is fixed length: 20 bytes
   264  	require.Equal(t, buf.Len(), 20*len(txTos))
   265  
   266  	result := buf.Bytes()
   267  	sbt.txTos = nil
   268  
   269  	r := bytes.NewReader(result)
   270  	err = sbt.decodeTxTos(r)
   271  	require.NoError(t, err)
   272  
   273  	require.Equal(t, txTos, sbt.txTos)
   274  }
   275  
   276  func TestSpanBatchTxsTxDatas(t *testing.T) {
   277  	rng := rand.New(rand.NewSource(0x1234))
   278  	chainID := big.NewInt(rng.Int63n(1000))
   279  
   280  	rawSpanBatch := RandomRawSpanBatch(rng, chainID)
   281  	txDatas := rawSpanBatch.txs.txDatas
   282  	txTypes := rawSpanBatch.txs.txTypes
   283  	totalBlockTxCount := rawSpanBatch.txs.totalBlockTxCount
   284  
   285  	var sbt spanBatchTxs
   286  	sbt.totalBlockTxCount = totalBlockTxCount
   287  
   288  	sbt.txDatas = txDatas
   289  
   290  	var buf bytes.Buffer
   291  	err := sbt.encodeTxDatas(&buf)
   292  	require.NoError(t, err)
   293  
   294  	result := buf.Bytes()
   295  	sbt.txDatas = nil
   296  	sbt.txTypes = nil
   297  
   298  	r := bytes.NewReader(result)
   299  	err = sbt.decodeTxDatas(r)
   300  	require.NoError(t, err)
   301  
   302  	require.Equal(t, txDatas, sbt.txDatas)
   303  	require.Equal(t, txTypes, sbt.txTypes)
   304  }
   305  
   306  func TestSpanBatchTxsRecoverV(t *testing.T) {
   307  	rng := rand.New(rand.NewSource(0x123))
   308  
   309  	chainID := big.NewInt(rng.Int63n(1000))
   310  	londonSigner := types.NewLondonSigner(chainID)
   311  	totalblockTxCount := 20 + rng.Intn(100)
   312  
   313  	cases := []txTypeTest{
   314  		{"unprotected legacy tx", testutils.RandomLegacyTx, types.HomesteadSigner{}},
   315  		{"legacy tx", testutils.RandomLegacyTx, londonSigner},
   316  		{"access list tx", testutils.RandomAccessListTx, londonSigner},
   317  		{"dynamic fee tx", testutils.RandomDynamicFeeTx, londonSigner},
   318  	}
   319  
   320  	for _, testCase := range cases {
   321  		t.Run(testCase.name, func(t *testing.T) {
   322  			var spanBatchTxs spanBatchTxs
   323  			var txTypes []int
   324  			var txSigs []spanBatchSignature
   325  			var originalVs []uint64
   326  			yParityBits := new(big.Int)
   327  			protectedBits := new(big.Int)
   328  			totalLegacyTxCount := 0
   329  			for idx := 0; idx < totalblockTxCount; idx++ {
   330  				tx := testCase.mkTx(rng, testCase.signer)
   331  				txType := tx.Type()
   332  				txTypes = append(txTypes, int(txType))
   333  				var txSig spanBatchSignature
   334  				v, r, s := tx.RawSignatureValues()
   335  				if txType == types.LegacyTxType {
   336  					protectedBit := uint(0)
   337  					if tx.Protected() {
   338  						protectedBit = uint(1)
   339  					}
   340  					protectedBits.SetBit(protectedBits, int(totalLegacyTxCount), protectedBit)
   341  					totalLegacyTxCount++
   342  				}
   343  				// Do not fill in txSig.V
   344  				txSig.r, _ = uint256.FromBig(r)
   345  				txSig.s, _ = uint256.FromBig(s)
   346  				txSigs = append(txSigs, txSig)
   347  				originalVs = append(originalVs, v.Uint64())
   348  				yParityBit, err := convertVToYParity(v.Uint64(), int(tx.Type()))
   349  				require.NoError(t, err)
   350  				yParityBits.SetBit(yParityBits, idx, yParityBit)
   351  			}
   352  
   353  			spanBatchTxs.yParityBits = yParityBits
   354  			spanBatchTxs.txSigs = txSigs
   355  			spanBatchTxs.txTypes = txTypes
   356  			spanBatchTxs.protectedBits = protectedBits
   357  			// recover txSig.v
   358  			err := spanBatchTxs.recoverV(chainID)
   359  			require.NoError(t, err)
   360  
   361  			var recoveredVs []uint64
   362  			for _, txSig := range spanBatchTxs.txSigs {
   363  				recoveredVs = append(recoveredVs, txSig.v)
   364  			}
   365  			require.Equal(t, originalVs, recoveredVs, "recovered v mismatch")
   366  		})
   367  	}
   368  }
   369  
   370  func TestSpanBatchTxsRoundTrip(t *testing.T) {
   371  	rng := rand.New(rand.NewSource(0x73311337))
   372  	chainID := big.NewInt(rng.Int63n(1000))
   373  
   374  	for i := 0; i < 4; i++ {
   375  		rawSpanBatch := RandomRawSpanBatch(rng, chainID)
   376  		sbt := rawSpanBatch.txs
   377  		totalBlockTxCount := sbt.totalBlockTxCount
   378  
   379  		var buf bytes.Buffer
   380  		err := sbt.encode(&buf)
   381  		require.NoError(t, err)
   382  
   383  		result := buf.Bytes()
   384  		r := bytes.NewReader(result)
   385  
   386  		var sbt2 spanBatchTxs
   387  		sbt2.totalBlockTxCount = totalBlockTxCount
   388  		err = sbt2.decode(r)
   389  		require.NoError(t, err)
   390  
   391  		err = sbt2.recoverV(chainID)
   392  		require.NoError(t, err)
   393  
   394  		require.Equal(t, sbt, &sbt2)
   395  	}
   396  }
   397  
   398  func TestSpanBatchTxsRoundTripFullTxs(t *testing.T) {
   399  	rng := rand.New(rand.NewSource(0x13377331))
   400  	chainID := big.NewInt(rng.Int63n(1000))
   401  	londonSigner := types.NewLondonSigner(chainID)
   402  
   403  	cases := []txTypeTest{
   404  		{"unprotected legacy tx", testutils.RandomLegacyTx, types.HomesteadSigner{}},
   405  		{"legacy tx", testutils.RandomLegacyTx, londonSigner},
   406  		{"access list tx", testutils.RandomAccessListTx, londonSigner},
   407  		{"dynamic fee tx", testutils.RandomDynamicFeeTx, londonSigner},
   408  	}
   409  
   410  	for _, testCase := range cases {
   411  		t.Run(testCase.name, func(t *testing.T) {
   412  			for i := 0; i < 4; i++ {
   413  				totalblockTxCounts := uint64(1 + rng.Int()&0xFF)
   414  				var txs [][]byte
   415  				for i := 0; i < int(totalblockTxCounts); i++ {
   416  					tx := testCase.mkTx(rng, testCase.signer)
   417  					rawTx, err := tx.MarshalBinary()
   418  					require.NoError(t, err)
   419  					txs = append(txs, rawTx)
   420  				}
   421  				sbt, err := newSpanBatchTxs(txs, chainID)
   422  				require.NoError(t, err)
   423  
   424  				txs2, err := sbt.fullTxs(chainID)
   425  				require.NoError(t, err)
   426  
   427  				require.Equal(t, txs, txs2)
   428  			}
   429  		})
   430  	}
   431  }
   432  
   433  func TestSpanBatchTxsRecoverVInvalidTxType(t *testing.T) {
   434  	rng := rand.New(rand.NewSource(0x321))
   435  	chainID := big.NewInt(rng.Int63n(1000))
   436  
   437  	var sbt spanBatchTxs
   438  
   439  	sbt.txTypes = []int{types.DepositTxType}
   440  	sbt.txSigs = []spanBatchSignature{{v: 0, r: nil, s: nil}}
   441  	sbt.yParityBits = new(big.Int)
   442  	sbt.protectedBits = new(big.Int)
   443  
   444  	err := sbt.recoverV(chainID)
   445  	require.ErrorContains(t, err, "invalid tx type")
   446  }
   447  
   448  func TestSpanBatchTxsFullTxNotEnoughTxTos(t *testing.T) {
   449  	rng := rand.New(rand.NewSource(0x13572468))
   450  	chainID := big.NewInt(rng.Int63n(1000))
   451  	londonSigner := types.NewLondonSigner(chainID)
   452  
   453  	cases := []txTypeTest{
   454  		{"unprotected legacy tx", testutils.RandomLegacyTx, types.HomesteadSigner{}},
   455  		{"legacy tx", testutils.RandomLegacyTx, londonSigner},
   456  		{"access list tx", testutils.RandomAccessListTx, londonSigner},
   457  		{"dynamic fee tx", testutils.RandomDynamicFeeTx, londonSigner},
   458  	}
   459  
   460  	for _, testCase := range cases {
   461  		t.Run(testCase.name, func(t *testing.T) {
   462  			totalblockTxCounts := uint64(1 + rng.Int()&0xFF)
   463  			var txs [][]byte
   464  			for i := 0; i < int(totalblockTxCounts); i++ {
   465  				tx := testCase.mkTx(rng, testCase.signer)
   466  				rawTx, err := tx.MarshalBinary()
   467  				require.NoError(t, err)
   468  				txs = append(txs, rawTx)
   469  			}
   470  			sbt, err := newSpanBatchTxs(txs, chainID)
   471  			require.NoError(t, err)
   472  
   473  			// drop single to field
   474  			sbt.txTos = sbt.txTos[:len(sbt.txTos)-2]
   475  
   476  			_, err = sbt.fullTxs(chainID)
   477  			require.EqualError(t, err, "tx to not enough")
   478  		})
   479  	}
   480  }
   481  
   482  func TestSpanBatchTxsMaxContractCreationBitsLength(t *testing.T) {
   483  	var sbt spanBatchTxs
   484  	sbt.totalBlockTxCount = 0xFFFFFFFFFFFFFFFF
   485  
   486  	r := bytes.NewReader([]byte{})
   487  	err := sbt.decodeContractCreationBits(r)
   488  	require.ErrorIs(t, err, ErrTooBigSpanBatchSize)
   489  }
   490  
   491  func TestSpanBatchTxsMaxYParityBitsLength(t *testing.T) {
   492  	var sb RawSpanBatch
   493  	sb.blockCount = 0xFFFFFFFFFFFFFFFF
   494  
   495  	r := bytes.NewReader([]byte{})
   496  	err := sb.decodeOriginBits(r)
   497  	require.ErrorIs(t, err, ErrTooBigSpanBatchSize)
   498  }
   499  
   500  func TestSpanBatchTxsMaxProtectedBitsLength(t *testing.T) {
   501  	var sb RawSpanBatch
   502  	sb.txs = &spanBatchTxs{}
   503  	sb.txs.totalLegacyTxCount = 0xFFFFFFFFFFFFFFFF
   504  
   505  	r := bytes.NewReader([]byte{})
   506  	err := sb.txs.decodeProtectedBits(r)
   507  	require.ErrorIs(t, err, ErrTooBigSpanBatchSize)
   508  }