github.com/status-im/status-go@v1.1.0/transactions/testhelpers.go (about)

     1  package transactions
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"math/big"
     7  	"testing"
     8  
     9  	eth "github.com/ethereum/go-ethereum/common"
    10  	"github.com/ethereum/go-ethereum/core/types"
    11  	"github.com/ethereum/go-ethereum/rpc"
    12  	"github.com/status-im/status-go/rpc/chain"
    13  	mock_client "github.com/status-im/status-go/rpc/chain/mock/client"
    14  	"github.com/status-im/status-go/services/wallet/bigint"
    15  	"github.com/status-im/status-go/services/wallet/common"
    16  
    17  	"github.com/stretchr/testify/mock"
    18  	"github.com/stretchr/testify/require"
    19  )
    20  
    21  type MockETHClient struct {
    22  	mock.Mock
    23  }
    24  
    25  var _ chain.BatchCallClient = (*MockETHClient)(nil)
    26  
    27  func (m *MockETHClient) BatchCallContext(ctx context.Context, b []rpc.BatchElem) error {
    28  	args := m.Called(ctx, b)
    29  	return args.Error(0)
    30  }
    31  
    32  type MockChainClient struct {
    33  	mock.Mock
    34  	mock_client.MockClientInterface
    35  
    36  	Clients map[common.ChainID]*MockETHClient
    37  }
    38  
    39  var _ chain.ClientInterface = (*MockChainClient)(nil)
    40  
    41  func NewMockChainClient() *MockChainClient {
    42  	return &MockChainClient{
    43  		Clients: make(map[common.ChainID]*MockETHClient),
    44  	}
    45  }
    46  
    47  func (m *MockChainClient) SetAvailableClients(chainIDs []common.ChainID) *MockChainClient {
    48  	for _, chainID := range chainIDs {
    49  		if _, ok := m.Clients[chainID]; !ok {
    50  			m.Clients[chainID] = new(MockETHClient)
    51  		}
    52  	}
    53  	return m
    54  }
    55  
    56  func (m *MockChainClient) AbstractEthClient(chainID common.ChainID) (chain.BatchCallClient, error) {
    57  	if _, ok := m.Clients[chainID]; !ok {
    58  		panic(fmt.Sprintf("no mock client for chainID %d", chainID))
    59  	}
    60  	return m.Clients[chainID], nil
    61  }
    62  
    63  func GenerateTestPendingTransactions(start int, count int) []PendingTransaction {
    64  	if count > 127 {
    65  		panic("can't generate more than 127 distinct transactions")
    66  	}
    67  
    68  	txs := make([]PendingTransaction, count)
    69  	for i := start; i < count; i++ {
    70  		txs[i] = PendingTransaction{
    71  			Hash:           eth.HexToHash(fmt.Sprintf("0x1%d", i)),
    72  			From:           eth.HexToAddress(fmt.Sprintf("0x2%d", i)),
    73  			To:             eth.HexToAddress(fmt.Sprintf("0x3%d", i)),
    74  			Type:           RegisterENS,
    75  			AdditionalData: "someuser.stateofus.eth",
    76  			Value:          bigint.BigInt{Int: big.NewInt(int64(i))},
    77  			GasLimit:       bigint.BigInt{Int: big.NewInt(21000)},
    78  			GasPrice:       bigint.BigInt{Int: big.NewInt(int64(i))},
    79  			ChainID:        777,
    80  			Status:         new(TxStatus),
    81  			AutoDelete:     new(bool),
    82  			Symbol:         "ETH",
    83  			Timestamp:      uint64(i),
    84  		}
    85  		*txs[i].Status = Pending  // set to pending by default
    86  		*txs[i].AutoDelete = true // set to true by default
    87  	}
    88  	return txs
    89  }
    90  
    91  // groupSliceInMap groups a slice of S into a map[K][]N using the getKeyValue function to extract the key and new value for each entry
    92  func groupSliceInMap[S any, K comparable, N any](s []S, getKeyValue func(entry S, i int) (K, N)) map[K][]N {
    93  	m := make(map[K][]N)
    94  	for i, x := range s {
    95  		k, v := getKeyValue(x, i)
    96  		m[k] = append(m[k], v)
    97  	}
    98  	return m
    99  }
   100  
   101  func keysInMap[K comparable, V any](m map[K]V) (res []K) {
   102  	if len(m) > 0 {
   103  		res = make([]K, 0, len(m))
   104  	}
   105  
   106  	for k := range m {
   107  		res = append(res, k)
   108  	}
   109  	return
   110  }
   111  
   112  type TestTxSummary struct {
   113  	failStatus  bool
   114  	DontConfirm bool
   115  	// Timestamp will be used to mock the Timestamp if greater than 0
   116  	Timestamp int
   117  }
   118  
   119  type summaryTxPair struct {
   120  	summary  TestTxSummary
   121  	tx       PendingTransaction
   122  	answered bool
   123  }
   124  
   125  func MockTestTransactions(t *testing.T, chainClient *MockChainClient, testTxs []TestTxSummary) []PendingTransaction {
   126  	genTxs := GenerateTestPendingTransactions(0, len(testTxs))
   127  	for i, tx := range testTxs {
   128  		if tx.Timestamp > 0 {
   129  			genTxs[i].Timestamp = uint64(tx.Timestamp)
   130  		}
   131  	}
   132  
   133  	grouped := groupSliceInMap(genTxs, func(tx PendingTransaction, i int) (common.ChainID, summaryTxPair) {
   134  		return tx.ChainID, summaryTxPair{
   135  			summary: testTxs[i],
   136  			tx:      tx,
   137  		}
   138  	})
   139  
   140  	chains := keysInMap(grouped)
   141  	chainClient.SetAvailableClients(chains)
   142  
   143  	for chainID, chainSummaries := range grouped {
   144  		// Mock the one call to getTransactionReceipt
   145  		// It is expected that pending transactions manager will call out of order, therefore match based on hash
   146  		cl := chainClient.Clients[chainID]
   147  		call := cl.On("BatchCallContext", mock.Anything, mock.MatchedBy(func(b []rpc.BatchElem) bool {
   148  			if len(b) > len(chainSummaries) {
   149  				return false
   150  			}
   151  			for i := range b {
   152  				for _, localSummary := range chainSummaries {
   153  					// to satisfy gosec: C601 checks
   154  					sum := localSummary
   155  					tx := &sum.tx
   156  					if sum.answered {
   157  						continue
   158  					}
   159  					require.Equal(t, GetTransactionReceiptRPCName, b[i].Method)
   160  					if tx.Hash == b[i].Args[0].(eth.Hash) {
   161  						sum.answered = true
   162  						return true
   163  					}
   164  				}
   165  			}
   166  			return false
   167  		})).Return(nil)
   168  
   169  		call.Run(func(args mock.Arguments) {
   170  			elems := args.Get(1).([]rpc.BatchElem)
   171  			for i := range elems {
   172  				receiptWrapper, ok := elems[i].Result.(*nullableReceipt)
   173  				require.True(t, ok)
   174  				require.NotNil(t, receiptWrapper)
   175  				// Simulate parsing of eth_getTransactionReceipt response
   176  				for _, localSum := range chainSummaries {
   177  					// to satisfy gosec: C601 checks
   178  					sum := localSum
   179  					tx := &sum.tx
   180  					if tx.Hash == elems[i].Args[0].(eth.Hash) {
   181  						if !sum.summary.DontConfirm {
   182  							status := types.ReceiptStatusSuccessful
   183  							if sum.summary.failStatus {
   184  								status = types.ReceiptStatusFailed
   185  							}
   186  
   187  							receiptWrapper.Receipt = &types.Receipt{
   188  								BlockNumber: new(big.Int).SetUint64(1),
   189  								Status:      status,
   190  							}
   191  						}
   192  					}
   193  				}
   194  			}
   195  		})
   196  	}
   197  	return genTxs
   198  }