github.com/status-im/status-go@v1.1.0/services/wallet/router/router_test.go (about)

     1  package router
     2  
import (
	"context"
	"database/sql"
	"encoding/json"
	"sync"
	"testing"
	"time"

	"github.com/status-im/status-go/appdatabase"
	"github.com/status-im/status-go/params"
	"github.com/status-im/status-go/rpc"
	"github.com/status-im/status-go/services/wallet/responses"
	"github.com/status-im/status-go/services/wallet/router/pathprocessor"
	"github.com/status-im/status-go/services/wallet/router/routes"
	"github.com/status-im/status-go/signal"
	"github.com/status-im/status-go/t/helpers"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)
    22  
    23  func amountOptionEqual(a, b amountOption) bool {
    24  	return a.amount.Cmp(b.amount) == 0 && a.locked == b.locked
    25  }
    26  
    27  func contains(slice []amountOption, val amountOption) bool {
    28  	for _, item := range slice {
    29  		if amountOptionEqual(item, val) {
    30  			return true
    31  		}
    32  	}
    33  	return false
    34  }
    35  
    36  func amountOptionsMapsEqual(map1, map2 map[uint64][]amountOption) bool {
    37  	if len(map1) != len(map2) {
    38  		return false
    39  	}
    40  
    41  	for key, slice1 := range map1 {
    42  		slice2, ok := map2[key]
    43  		if !ok || len(slice1) != len(slice2) {
    44  			return false
    45  		}
    46  
    47  		for _, val1 := range slice1 {
    48  			if !contains(slice2, val1) {
    49  				return false
    50  			}
    51  		}
    52  
    53  		for _, val2 := range slice2 {
    54  			if !contains(slice1, val2) {
    55  				return false
    56  			}
    57  		}
    58  	}
    59  
    60  	return true
    61  }
    62  
    63  func assertPathsEqual(t *testing.T, expected, actual routes.Route) {
    64  	assert.Equal(t, len(expected), len(actual))
    65  	if len(expected) == 0 {
    66  		return
    67  	}
    68  
    69  	for _, c := range actual {
    70  		found := false
    71  		for _, expC := range expected {
    72  			if c.ProcessorName == expC.ProcessorName &&
    73  				c.FromChain.ChainID == expC.FromChain.ChainID &&
    74  				c.ToChain.ChainID == expC.ToChain.ChainID &&
    75  				c.ApprovalRequired == expC.ApprovalRequired &&
    76  				(expC.AmountOut == nil || c.AmountOut.ToInt().Cmp(expC.AmountOut.ToInt()) == 0) {
    77  				found = true
    78  				break
    79  			}
    80  		}
    81  
    82  		assert.True(t, found)
    83  	}
    84  }
    85  
    86  func setupTestNetworkDB(t *testing.T) (*sql.DB, func()) {
    87  	db, cleanup, err := helpers.SetupTestSQLDB(appdatabase.DbInitializer{}, "wallet-router-tests")
    88  	require.NoError(t, err)
    89  	return db, func() { require.NoError(t, cleanup()) }
    90  }
    91  
    92  func setupRouter(t *testing.T) (*Router, func()) {
    93  	db, cleanTmpDb := setupTestNetworkDB(t)
    94  
    95  	client, _ := rpc.NewClient(nil, 1, params.UpstreamRPCConfig{Enabled: false, URL: ""}, defaultNetworks, db, nil)
    96  
    97  	router := NewRouter(client, nil, nil, nil, nil, nil, nil, nil)
    98  
    99  	transfer := pathprocessor.NewTransferProcessor(nil, nil)
   100  	router.AddPathProcessor(transfer)
   101  
   102  	erc721Transfer := pathprocessor.NewERC721Processor(nil, nil)
   103  	router.AddPathProcessor(erc721Transfer)
   104  
   105  	erc1155Transfer := pathprocessor.NewERC1155Processor(nil, nil)
   106  	router.AddPathProcessor(erc1155Transfer)
   107  
   108  	hop := pathprocessor.NewHopBridgeProcessor(nil, nil, nil, nil)
   109  	router.AddPathProcessor(hop)
   110  
   111  	paraswap := pathprocessor.NewSwapParaswapProcessor(nil, nil, nil)
   112  	router.AddPathProcessor(paraswap)
   113  
   114  	ensRegister := pathprocessor.NewENSReleaseProcessor(nil, nil, nil)
   115  	router.AddPathProcessor(ensRegister)
   116  
   117  	ensRelease := pathprocessor.NewENSReleaseProcessor(nil, nil, nil)
   118  	router.AddPathProcessor(ensRelease)
   119  
   120  	ensPublicKey := pathprocessor.NewENSPublicKeyProcessor(nil, nil, nil)
   121  	router.AddPathProcessor(ensPublicKey)
   122  
   123  	buyStickers := pathprocessor.NewStickersBuyProcessor(nil, nil, nil)
   124  	router.AddPathProcessor(buyStickers)
   125  
   126  	return router, cleanTmpDb
   127  }
   128  
   129  type routerSuggestedRoutesEnvelope struct {
   130  	Type   string                          `json:"type"`
   131  	Routes responses.RouterSuggestedRoutes `json:"event"`
   132  }
   133  
   134  func setupSignalHandler(t *testing.T) (chan responses.RouterSuggestedRoutes, func()) {
   135  	suggestedRoutesCh := make(chan responses.RouterSuggestedRoutes)
   136  	signalHandler := signal.MobileSignalHandler(func(data []byte) {
   137  		var envelope signal.Envelope
   138  		err := json.Unmarshal(data, &envelope)
   139  		assert.NoError(t, err)
   140  		if envelope.Type == string(signal.SuggestedRoutes) {
   141  			var response routerSuggestedRoutesEnvelope
   142  			err := json.Unmarshal(data, &response)
   143  			assert.NoError(t, err)
   144  
   145  			suggestedRoutesCh <- response.Routes
   146  		}
   147  	})
   148  	signal.SetMobileSignalHandler(signalHandler)
   149  	t.Cleanup(signal.ResetMobileSignalHandler)
   150  
   151  	closeFn := func() {
   152  		close(suggestedRoutesCh)
   153  	}
   154  
   155  	return suggestedRoutesCh, closeFn
   156  }
   157  
   158  func TestRouter(t *testing.T) {
   159  	router, cleanTmpDb := setupRouter(t)
   160  	defer cleanTmpDb()
   161  
   162  	suggestedRoutesCh, closeSignalHandler := setupSignalHandler(t)
   163  	defer closeSignalHandler()
   164  
   165  	tests := getNormalTestParamsList()
   166  
   167  	// Test blocking endpoints
   168  	for _, tt := range tests {
   169  		t.Run(tt.name, func(t *testing.T) {
   170  			routes, err := router.SuggestedRoutes(context.Background(), tt.input)
   171  
   172  			if tt.expectedError != nil {
   173  				assert.Error(t, err)
   174  				assert.Equal(t, tt.expectedError.Error(), err.Error())
   175  				if routes == nil {
   176  					assert.Empty(t, tt.expectedCandidates)
   177  				} else {
   178  					assertPathsEqual(t, tt.expectedCandidates, routes.Candidates)
   179  				}
   180  			} else {
   181  				assert.NoError(t, err)
   182  				assertPathsEqual(t, tt.expectedCandidates, routes.Candidates)
   183  			}
   184  		})
   185  	}
   186  
   187  	// Test async endpoints
   188  	for _, tt := range tests {
   189  		router.SuggestedRoutesAsync(tt.input)
   190  
   191  		select {
   192  		case asyncRoutes := <-suggestedRoutesCh:
   193  			assert.Equal(t, tt.input.Uuid, asyncRoutes.Uuid)
   194  			assert.Equal(t, tt.expectedError, asyncRoutes.ErrorResponse)
   195  			assertPathsEqual(t, tt.expectedCandidates, asyncRoutes.Candidates)
   196  			break
   197  		case <-time.After(10 * time.Second):
   198  			t.FailNow()
   199  		}
   200  	}
   201  }
   202  
   203  func TestNoBalanceForTheBestRouteRouter(t *testing.T) {
   204  	router, cleanTmpDb := setupRouter(t)
   205  	defer cleanTmpDb()
   206  
   207  	suggestedRoutesCh, closeSignalHandler := setupSignalHandler(t)
   208  	defer closeSignalHandler()
   209  
   210  	tests := getNoBalanceTestParamsList()
   211  
   212  	// Test blocking endpoints
   213  	for _, tt := range tests {
   214  		t.Run(tt.name, func(t *testing.T) {
   215  
   216  			routes, err := router.SuggestedRoutes(context.Background(), tt.input)
   217  
   218  			if tt.expectedError != nil {
   219  				assert.Error(t, err)
   220  				assert.Equal(t, tt.expectedError.Error(), err.Error())
   221  				if tt.expectedError == ErrNoPositiveBalance {
   222  					assert.Nil(t, routes)
   223  				} else {
   224  					assert.NotNil(t, routes)
   225  					assertPathsEqual(t, tt.expectedCandidates, routes.Candidates)
   226  				}
   227  			} else {
   228  				assert.NoError(t, err)
   229  				assert.Equal(t, len(tt.expectedCandidates), len(routes.Candidates))
   230  				assert.Equal(t, len(tt.expectedBest), len(routes.Best))
   231  				assertPathsEqual(t, tt.expectedCandidates, routes.Candidates)
   232  				assertPathsEqual(t, tt.expectedBest, routes.Best)
   233  			}
   234  		})
   235  	}
   236  
   237  	// Test async endpoints
   238  	for _, tt := range tests {
   239  		t.Run(tt.name, func(t *testing.T) {
   240  
   241  			router.SuggestedRoutesAsync(tt.input)
   242  
   243  			select {
   244  			case asyncRoutes := <-suggestedRoutesCh:
   245  				assert.Equal(t, tt.input.Uuid, asyncRoutes.Uuid)
   246  				assert.Equal(t, tt.expectedError, asyncRoutes.ErrorResponse)
   247  				assertPathsEqual(t, tt.expectedCandidates, asyncRoutes.Candidates)
   248  				if tt.expectedError == nil {
   249  					assertPathsEqual(t, tt.expectedBest, asyncRoutes.Best)
   250  				}
   251  				break
   252  			case <-time.After(10 * time.Second):
   253  				t.FailNow()
   254  			}
   255  		})
   256  	}
   257  }
   258  
   259  func TestAmountOptions(t *testing.T) {
   260  	router, cleanTmpDb := setupRouter(t)
   261  	defer cleanTmpDb()
   262  
   263  	tests := getAmountOptionsTestParamsList()
   264  
   265  	for _, tt := range tests {
   266  		t.Run(tt.name, func(t *testing.T) {
   267  
   268  			selectedFromChains, _, err := router.getSelectedChains(tt.input)
   269  			assert.NoError(t, err)
   270  
   271  			router.SetTestBalanceMap(tt.input.TestParams.BalanceMap)
   272  			amountOptions, err := router.findOptionsForSendingAmount(tt.input, selectedFromChains)
   273  			assert.NoError(t, err)
   274  
   275  			assert.Equal(t, len(tt.expectedAmountOptions), len(amountOptions))
   276  			assert.True(t, amountOptionsMapsEqual(tt.expectedAmountOptions, amountOptions))
   277  		})
   278  	}
   279  }