github.com/status-im/status-go@v1.1.0/services/connector/commands/switch_ethereum_chain.go (about)

     1  package commands
     2  
     3  import (
     4  	"database/sql"
     5  	"errors"
     6  	"slices"
     7  	"strconv"
     8  
     9  	"github.com/status-im/status-go/services/connector/chainutils"
    10  	persistence "github.com/status-im/status-go/services/connector/database"
    11  	walletCommon "github.com/status-im/status-go/services/wallet/common"
    12  	"github.com/status-im/status-go/signal"
    13  )
    14  
// Sentinel errors declared for the switch-Ethereum-chain command.
//
//   - ErrNoActiveNetworks: no active networks are available.
//   - ErrUnsupportedNetwork: the requested chain ID is not a "0x"-prefixed hex
//     string, or it is not among the supported chain IDs.
//   - ErrNoChainIDParamsFound: the request parameters carry no chain ID.
var (
	ErrNoActiveNetworks     = errors.New("no active networks")
	ErrUnsupportedNetwork   = errors.New("unsupported network")
	ErrNoChainIDParamsFound = errors.New("no chain id in params found")
)
    21  
// SwitchEthereumChainCommand changes the chain ID stored for a dApp and
// announces the change through a connector signal (see Execute).
type SwitchEthereumChainCommand struct {
	NetworkManager NetworkManagerInterface // handed to chainutils.GetSupportedChainIDs
	Db             *sql.DB                 // handle used to select and upsert the dApp record
}
    26  
    27  func hexStringToUint64(s string) (uint64, error) {
    28  	if len(s) > 2 && s[:2] == "0x" {
    29  		value, err := strconv.ParseUint(s[2:], 16, 64)
    30  		if err != nil {
    31  			return 0, err
    32  		}
    33  		return value, nil
    34  	}
    35  	return 0, ErrUnsupportedNetwork
    36  }
    37  
    38  func (r *RPCRequest) getChainID() (uint64, error) {
    39  	if r.Params == nil || len(r.Params) == 0 {
    40  		return 0, ErrEmptyRPCParams
    41  	}
    42  
    43  	chainIds := r.Params[0].(map[string]interface{})
    44  
    45  	for _, chainId := range chainIds {
    46  		return hexStringToUint64(chainId.(string))
    47  	}
    48  
    49  	return 0, nil
    50  }
    51  
// getSupportedChainIDs returns the result of chainutils.GetSupportedChainIDs
// for the command's NetworkManager, passing its values and error through
// unchanged.
func (c *SwitchEthereumChainCommand) getSupportedChainIDs() ([]uint64, error) {
	return chainutils.GetSupportedChainIDs(c.NetworkManager)
}
    55  
    56  func (c *SwitchEthereumChainCommand) Execute(request RPCRequest) (interface{}, error) {
    57  	err := request.Validate()
    58  	if err != nil {
    59  		return "", err
    60  	}
    61  
    62  	requestedChainID, err := request.getChainID()
    63  	if err != nil {
    64  		return "", err
    65  	}
    66  
    67  	chainIDs, err := c.getSupportedChainIDs()
    68  	if err != nil {
    69  		return "", err
    70  	}
    71  
    72  	if !slices.Contains(chainIDs, requestedChainID) {
    73  		return "", ErrUnsupportedNetwork
    74  	}
    75  
    76  	dApp, err := persistence.SelectDAppByUrl(c.Db, request.URL)
    77  	if err != nil {
    78  		return "", err
    79  	}
    80  
    81  	if dApp == nil {
    82  		return "", ErrDAppIsNotPermittedByUser
    83  	}
    84  
    85  	dApp.ChainID = requestedChainID
    86  
    87  	err = persistence.UpsertDApp(c.Db, dApp)
    88  	if err != nil {
    89  		return "", err
    90  	}
    91  
    92  	chainId, err := chainutils.GetHexChainID(walletCommon.ChainID(dApp.ChainID).String())
    93  	if err != nil {
    94  		return "", err
    95  	}
    96  
    97  	signal.SendConnectorDAppChainIdSwitched(signal.ConnectorDAppChainIdSwitchedSignal{
    98  		URL:     request.URL,
    99  		ChainId: chainId,
   100  	})
   101  
   102  	return chainId, nil
   103  }