github.com/status-im/status-go@v1.1.0/rpc/network/network.go (about)

     1  package network
     2  
import (
	"bytes"
	"database/sql"
	"errors"
	"fmt"
	"strings"

	"github.com/status-im/status-go/multiaccounts/accounts"
	"github.com/status-im/status-go/params"
)
    11  
    12  var SepoliaChainIDs = []uint64{11155111, 421614, 11155420}
    13  
    14  var GoerliChainIDs = []uint64{5, 421613, 420}
    15  
    16  type CombinedNetwork struct {
    17  	Prod *params.Network
    18  	Test *params.Network
    19  }
    20  
    21  const baseQuery = "SELECT chain_id, chain_name, rpc_url, original_rpc_url, fallback_url, original_fallback_url, block_explorer_url, icon_url, native_currency_name, native_currency_symbol, native_currency_decimals, is_test, layer, enabled, chain_color, short_name, related_chain_id FROM networks"
    22  
    23  func newNetworksQuery() *networksQuery {
    24  	buf := bytes.NewBuffer(nil)
    25  	buf.WriteString(baseQuery)
    26  	return &networksQuery{buf: buf}
    27  }
    28  
    29  type networksQuery struct {
    30  	buf   *bytes.Buffer
    31  	args  []interface{}
    32  	added bool
    33  }
    34  
    35  func (nq *networksQuery) andOrWhere() {
    36  	if nq.added {
    37  		nq.buf.WriteString(" AND")
    38  	} else {
    39  		nq.buf.WriteString(" WHERE")
    40  	}
    41  }
    42  
    43  func (nq *networksQuery) filterEnabled(enabled bool) *networksQuery {
    44  	nq.andOrWhere()
    45  	nq.added = true
    46  	nq.buf.WriteString(" enabled = ?")
    47  	nq.args = append(nq.args, enabled)
    48  	return nq
    49  }
    50  
    51  func (nq *networksQuery) filterChainID(chainID uint64) *networksQuery {
    52  	nq.andOrWhere()
    53  	nq.added = true
    54  	nq.buf.WriteString(" chain_id = ?")
    55  	nq.args = append(nq.args, chainID)
    56  	return nq
    57  }
    58  
    59  func (nq *networksQuery) exec(db *sql.DB) ([]*params.Network, error) {
    60  	rows, err := db.Query(nq.buf.String(), nq.args...)
    61  	if err != nil {
    62  		return nil, err
    63  	}
    64  	var res []*params.Network
    65  	defer rows.Close()
    66  	for rows.Next() {
    67  		network := params.Network{}
    68  		err := rows.Scan(
    69  			&network.ChainID, &network.ChainName, &network.RPCURL, &network.OriginalRPCURL, &network.FallbackURL, &network.OriginalFallbackURL,
    70  			&network.BlockExplorerURL, &network.IconURL, &network.NativeCurrencyName, &network.NativeCurrencySymbol,
    71  			&network.NativeCurrencyDecimals, &network.IsTest, &network.Layer, &network.Enabled, &network.ChainColor, &network.ShortName,
    72  			&network.RelatedChainID,
    73  		)
    74  		if err != nil {
    75  			return nil, err
    76  		}
    77  
    78  		res = append(res, &network)
    79  	}
    80  
    81  	return res, err
    82  }
    83  
    84  type ManagerInterface interface {
    85  	Get(onlyEnabled bool) ([]*params.Network, error)
    86  	GetAll() ([]*params.Network, error)
    87  	Find(chainID uint64) *params.Network
    88  	GetConfiguredNetworks() []params.Network
    89  	GetTestNetworksEnabled() (bool, error)
    90  }
    91  
// Manager stores networks in the networks table of db and reconciles them
// with a configured set of networks (see Init).
//
// Fields:
//   - db: database holding the networks table.
//   - configuredNetworks: the slice last passed to Init; used as the source of
//     DefaultRPCURL / DefaultFallbackURL for stored networks.
//   - accountsDB: account settings read by Get (Goerli toggle) and
//     GetTestNetworksEnabled.
type Manager struct {
	db                 *sql.DB
	configuredNetworks []params.Network
	accountsDB         *accounts.Database
}
    97  
    98  func NewManager(db *sql.DB) *Manager {
    99  	accountsDB, err := accounts.NewDB(db)
   100  	if err != nil {
   101  		return nil
   102  	}
   103  	return &Manager{
   104  		db:         db,
   105  		accountsDB: accountsDB,
   106  	}
   107  }
   108  
   109  func find(chainID uint64, networks []params.Network) int {
   110  	for i := range networks {
   111  		if networks[i].ChainID == chainID {
   112  			return i
   113  		}
   114  	}
   115  	return -1
   116  }
   117  
   118  func (nm *Manager) Init(networks []params.Network) error {
   119  	if networks == nil {
   120  		return nil
   121  	}
   122  	nm.configuredNetworks = networks
   123  
   124  	var errors string
   125  	currentNetworks, _ := nm.Get(false)
   126  
   127  	// Delete networks which are not supported any more
   128  	for i := range currentNetworks {
   129  		if find(currentNetworks[i].ChainID, networks) == -1 {
   130  			err := nm.Delete(currentNetworks[i].ChainID)
   131  			if err != nil {
   132  				errors += fmt.Sprintf("error deleting network with ChainID: %d, %s", currentNetworks[i].ChainID, err.Error())
   133  			}
   134  		}
   135  	}
   136  
   137  	// Add new networks and update related chain id for the old ones
   138  	for i := range networks {
   139  		found := false
   140  		networks[i].OriginalRPCURL = networks[i].RPCURL
   141  		networks[i].OriginalFallbackURL = networks[i].FallbackURL
   142  
   143  		for j := range currentNetworks {
   144  			if currentNetworks[j].ChainID == networks[i].ChainID {
   145  				found = true
   146  				if currentNetworks[j].RelatedChainID != networks[i].RelatedChainID {
   147  					// Update fallback_url if it's different
   148  					err := nm.UpdateRelatedChainID(currentNetworks[j].ChainID, networks[i].RelatedChainID)
   149  					if err != nil {
   150  						errors += fmt.Sprintf("error updating network fallback_url for ChainID: %d, %s", currentNetworks[j].ChainID, err.Error())
   151  					}
   152  				}
   153  
   154  				if networks[i].OriginalRPCURL != currentNetworks[j].OriginalRPCURL && currentNetworks[j].RPCURL == currentNetworks[j].OriginalRPCURL {
   155  					err := nm.updateRPCURL(networks[i].ChainID, networks[i].OriginalRPCURL)
   156  					if err != nil {
   157  						errors += fmt.Sprintf("error updating rpc url for ChainID: %d, %s", currentNetworks[j].ChainID, err.Error())
   158  					}
   159  				}
   160  
   161  				if networks[i].OriginalFallbackURL != currentNetworks[j].OriginalFallbackURL && currentNetworks[j].FallbackURL == currentNetworks[j].OriginalFallbackURL {
   162  					err := nm.updateFallbackURL(networks[i].ChainID, networks[i].OriginalFallbackURL)
   163  					if err != nil {
   164  						errors += fmt.Sprintf("error updating rpc url for ChainID: %d, %s", currentNetworks[j].ChainID, err.Error())
   165  					}
   166  				}
   167  
   168  				err := nm.updateOriginalURLs(networks[i].ChainID, networks[i].OriginalRPCURL, networks[i].OriginalFallbackURL)
   169  				if err != nil {
   170  					errors += fmt.Sprintf("error updating network original url for ChainID: %d, %s", currentNetworks[j].ChainID, err.Error())
   171  				}
   172  
   173  				break
   174  			}
   175  		}
   176  
   177  		if !found {
   178  			// Insert new network
   179  			err := nm.Upsert(&networks[i])
   180  			if err != nil {
   181  				errors += fmt.Sprintf("error inserting network with ChainID: %d, %s", networks[i].ChainID, err.Error())
   182  			}
   183  		}
   184  	}
   185  
   186  	if len(errors) > 0 {
   187  		return fmt.Errorf(errors)
   188  	}
   189  
   190  	return nil
   191  }
   192  
   193  func (nm *Manager) Upsert(network *params.Network) error {
   194  	_, err := nm.db.Exec(
   195  		"INSERT OR REPLACE INTO networks (chain_id, chain_name, rpc_url, original_rpc_url, fallback_url, original_fallback_url, block_explorer_url, icon_url, native_currency_name, native_currency_symbol, native_currency_decimals, is_test, layer, enabled, chain_color, short_name, related_chain_id) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
   196  		network.ChainID, network.ChainName, network.RPCURL, network.OriginalRPCURL, network.FallbackURL, network.OriginalFallbackURL, network.BlockExplorerURL, network.IconURL,
   197  		network.NativeCurrencyName, network.NativeCurrencySymbol, network.NativeCurrencyDecimals,
   198  		network.IsTest, network.Layer, network.Enabled, network.ChainColor, network.ShortName,
   199  		network.RelatedChainID,
   200  	)
   201  	return err
   202  }
   203  
   204  func (nm *Manager) Delete(chainID uint64) error {
   205  	_, err := nm.db.Exec("DELETE FROM networks WHERE chain_id = ?", chainID)
   206  	return err
   207  }
   208  
   209  func (nm *Manager) UpdateRelatedChainID(chainID uint64, relatedChainID uint64) error {
   210  	_, err := nm.db.Exec(`UPDATE networks SET related_chain_id = ? WHERE chain_id = ?`, relatedChainID, chainID)
   211  	return err
   212  }
   213  
   214  func (nm *Manager) updateRPCURL(chainID uint64, rpcURL string) error {
   215  	_, err := nm.db.Exec(`UPDATE networks SET rpc_url = ? WHERE chain_id = ?`, rpcURL, chainID)
   216  	return err
   217  }
   218  
   219  func (nm *Manager) updateFallbackURL(chainID uint64, fallbackURL string) error {
   220  	_, err := nm.db.Exec(`UPDATE networks SET fallback_url = ? WHERE chain_id = ?`, fallbackURL, chainID)
   221  	return err
   222  }
   223  
   224  func (nm *Manager) updateOriginalURLs(chainID uint64, originalRPCURL, OriginalFallbackURL string) error {
   225  	_, err := nm.db.Exec(`UPDATE networks SET original_rpc_url = ?, original_fallback_url = ?  WHERE chain_id = ?`, originalRPCURL, OriginalFallbackURL, chainID)
   226  	return err
   227  }
   228  
   229  func (nm *Manager) Find(chainID uint64) *params.Network {
   230  	networks, err := newNetworksQuery().filterChainID(chainID).exec(nm.db)
   231  	if len(networks) != 1 || err != nil {
   232  		return nil
   233  	}
   234  	setDefaultRPCURL(networks, nm.configuredNetworks)
   235  	return networks[0]
   236  }
   237  
   238  func (nm *Manager) GetAll() ([]*params.Network, error) {
   239  	query := newNetworksQuery()
   240  	networks, err := query.exec(nm.db)
   241  	setDefaultRPCURL(networks, nm.configuredNetworks)
   242  	return networks, err
   243  }
   244  
   245  func (nm *Manager) Get(onlyEnabled bool) ([]*params.Network, error) {
   246  	isGoerliEnabled, err := nm.accountsDB.GetIsGoerliEnabled()
   247  	if err != nil {
   248  		return nil, err
   249  	}
   250  
   251  	query := newNetworksQuery()
   252  	if onlyEnabled {
   253  		query.filterEnabled(true)
   254  	}
   255  
   256  	networks, err := query.exec(nm.db)
   257  	if err != nil {
   258  		return nil, err
   259  	}
   260  
   261  	var results []*params.Network
   262  	for _, network := range networks {
   263  		if isGoerliEnabled {
   264  			found := false
   265  			for _, chainID := range SepoliaChainIDs {
   266  				if network.ChainID == chainID {
   267  					found = true
   268  					break
   269  				}
   270  			}
   271  			if found {
   272  				continue
   273  			}
   274  		}
   275  
   276  		if !isGoerliEnabled {
   277  			found := false
   278  
   279  			for _, chainID := range GoerliChainIDs {
   280  				if network.ChainID == chainID {
   281  					found = true
   282  					break
   283  				}
   284  			}
   285  			if found {
   286  				continue
   287  			}
   288  		}
   289  
   290  		configuredNetwork, err := findNetwork(nm.configuredNetworks, network.ChainID)
   291  		if err != nil {
   292  			addDefaultRPCURL(network, configuredNetwork)
   293  		}
   294  
   295  		results = append(results, network)
   296  	}
   297  
   298  	return results, nil
   299  }
   300  
   301  func (nm *Manager) GetCombinedNetworks() ([]*CombinedNetwork, error) {
   302  	networks, err := nm.Get(false)
   303  	if err != nil {
   304  		return nil, err
   305  	}
   306  	var combinedNetworks []*CombinedNetwork
   307  	for _, network := range networks {
   308  		found := false
   309  		for _, n := range combinedNetworks {
   310  			if (n.Test != nil && (network.ChainID == n.Test.RelatedChainID || n.Test.ChainID == network.RelatedChainID)) || (n.Prod != nil && (network.ChainID == n.Prod.RelatedChainID || n.Prod.ChainID == network.RelatedChainID)) {
   311  				found = true
   312  				if network.IsTest {
   313  					n.Test = network
   314  					break
   315  				} else {
   316  					n.Prod = network
   317  					break
   318  				}
   319  			}
   320  		}
   321  
   322  		if found {
   323  			continue
   324  		}
   325  
   326  		newCombined := &CombinedNetwork{}
   327  		if network.IsTest {
   328  			newCombined.Test = network
   329  		} else {
   330  			newCombined.Prod = network
   331  		}
   332  		combinedNetworks = append(combinedNetworks, newCombined)
   333  	}
   334  
   335  	return combinedNetworks, nil
   336  }
   337  
   338  func (nm *Manager) GetConfiguredNetworks() []params.Network {
   339  	return nm.configuredNetworks
   340  }
   341  
   342  func (nm *Manager) GetTestNetworksEnabled() (result bool, err error) {
   343  	return nm.accountsDB.GetTestNetworksEnabled()
   344  }
   345  
   346  // Returns all networks for active mode (test/prod) and in case of test mode,
   347  // returns either Goerli or Sepolia networks based on the value of isGoerliEnabled
   348  func (nm *Manager) GetActiveNetworks() ([]*params.Network, error) {
   349  	areTestNetworksEnabled, err := nm.GetTestNetworksEnabled()
   350  	if err != nil {
   351  		return nil, err
   352  	}
   353  
   354  	networks, err := nm.Get(false)
   355  	if err != nil {
   356  		return nil, err
   357  	}
   358  	availableNetworks := make([]*params.Network, 0)
   359  	for _, network := range networks {
   360  		if network.IsTest != areTestNetworksEnabled {
   361  			continue
   362  		}
   363  		availableNetworks = append(availableNetworks, network)
   364  	}
   365  
   366  	return availableNetworks, nil
   367  }
   368  
   369  func findNetwork(networks []params.Network, chainID uint64) (params.Network, error) {
   370  	for _, network := range networks {
   371  		if network.ChainID == chainID {
   372  			return network, nil
   373  		}
   374  	}
   375  	return params.Network{}, fmt.Errorf("network not found")
   376  }
   377  
   378  func addDefaultRPCURL(target *params.Network, source params.Network) {
   379  	target.DefaultRPCURL = source.DefaultRPCURL
   380  	target.DefaultFallbackURL = source.DefaultFallbackURL
   381  }
   382  
   383  func setDefaultRPCURL(target []*params.Network, source []params.Network) {
   384  	for i := range target {
   385  		for j := range source {
   386  			if target[i].ChainID == source[j].ChainID {
   387  				addDefaultRPCURL(target[i], source[j])
   388  				break
   389  			}
   390  		}
   391  	}
   392  }