github.com/status-im/status-go@v1.1.0/services/wallet/transfer/query.go (about)

     1  package transfer
     2  
     3  import (
     4  	"bytes"
     5  	"database/sql"
     6  	"fmt"
     7  	"math/big"
     8  
     9  	"github.com/ethereum/go-ethereum/common"
    10  	"github.com/ethereum/go-ethereum/core/types"
    11  	"github.com/status-im/status-go/services/wallet/bigint"
    12  	w_common "github.com/status-im/status-go/services/wallet/common"
    13  )
    14  
    15  const baseTransfersQuery = "SELECT hash, type, blk_hash, blk_number, timestamp, address, tx, sender, receipt, log, network_id, base_gas_fee, COALESCE(multi_transaction_id, 0) %s FROM transfers"
    16  const preloadedTransfersQuery = "SELECT hash, type, address, log, token_id, amount_padded128hex FROM transfers"
    17  
// transfersQuery incrementally builds a SQL statement together with the
// positional arguments for its "?" placeholders. Filter*/Sort*/Limit methods
// append text to buf and values to args, and return the receiver for chaining.
type transfersQuery struct {
	// buf holds the SQL text written so far.
	buf        *bytes.Buffer
	// args holds the placeholder values, appended in the same order as the
	// corresponding "?" markers are written to buf.
	args       []interface{}
	// whereAdded is set once the first condition has been started; later
	// conditions are then prefixed with a separator (" AND"/" OR") instead of
	// " WHERE".
	whereAdded bool
	// subQuery is true for queries created by newSubQuery: the leading " WHERE"
	// is not written, because addSubQuery embeds the text in parentheses inside
	// a parent query.
	subQuery   bool
}
    24  
    25  func newTransfersQuery() *transfersQuery {
    26  	newQuery := newEmptyQuery()
    27  	transfersQueryString := fmt.Sprintf(baseTransfersQuery, "")
    28  	newQuery.buf.WriteString(transfersQueryString)
    29  	return newQuery
    30  }
    31  
    32  func newTransfersQueryForPreloadedTransactions() *transfersQuery {
    33  	newQuery := newEmptyQuery()
    34  	newQuery.buf.WriteString(preloadedTransfersQuery)
    35  	return newQuery
    36  }
    37  
    38  func newSubQuery() *transfersQuery {
    39  	newQuery := newEmptyQuery()
    40  	newQuery.subQuery = true
    41  	return newQuery
    42  }
    43  
    44  func newEmptyQuery() *transfersQuery {
    45  	buf := bytes.NewBuffer(nil)
    46  	return &transfersQuery{buf: buf}
    47  }
    48  
    49  func (q *transfersQuery) addWhereSeparator(separator SeparatorType) {
    50  	if !q.whereAdded {
    51  		if !q.subQuery {
    52  			q.buf.WriteString(" WHERE")
    53  		}
    54  		q.whereAdded = true
    55  	} else if separator == OrSeparator {
    56  		q.buf.WriteString(" OR")
    57  	} else if separator == AndSeparator {
    58  		q.buf.WriteString(" AND")
    59  	} else if separator != NoSeparator {
    60  		panic("Unknown separator. Need to handle current SeparatorType value")
    61  	}
    62  }
    63  
    64  type SeparatorType int
    65  
    66  // Beware: please update addWhereSeparator if changing this enum
    67  const (
    68  	NoSeparator SeparatorType = iota + 1
    69  	OrSeparator
    70  	AndSeparator
    71  )
    72  
    73  // addSubQuery adds where clause formed as: WHERE/<separator> (<subQuery>)
    74  func (q *transfersQuery) addSubQuery(subQuery *transfersQuery, separator SeparatorType) *transfersQuery {
    75  	q.addWhereSeparator(separator)
    76  	q.buf.WriteString(" (")
    77  	q.buf.Write(subQuery.buf.Bytes())
    78  	q.buf.WriteString(")")
    79  	q.args = append(q.args, subQuery.args...)
    80  	return q
    81  }
    82  
    83  func (q *transfersQuery) FilterStart(start *big.Int) *transfersQuery {
    84  	if start != nil {
    85  		q.addWhereSeparator(AndSeparator)
    86  		q.buf.WriteString(" blk_number >= ?")
    87  		q.args = append(q.args, (*bigint.SQLBigInt)(start))
    88  	}
    89  	return q
    90  }
    91  
    92  func (q *transfersQuery) FilterEnd(end *big.Int) *transfersQuery {
    93  	if end != nil {
    94  		q.addWhereSeparator(AndSeparator)
    95  		q.buf.WriteString(" blk_number <= ?")
    96  		q.args = append(q.args, (*bigint.SQLBigInt)(end))
    97  	}
    98  	return q
    99  }
   100  
   101  func (q *transfersQuery) FilterLoaded(loaded int) *transfersQuery {
   102  	q.addWhereSeparator(AndSeparator)
   103  	q.buf.WriteString(" loaded = ? ")
   104  	q.args = append(q.args, loaded)
   105  
   106  	return q
   107  }
   108  
   109  func (q *transfersQuery) FilterNetwork(network uint64) *transfersQuery {
   110  	q.addWhereSeparator(AndSeparator)
   111  	q.buf.WriteString(" network_id = ?")
   112  	q.args = append(q.args, network)
   113  	return q
   114  }
   115  
   116  func (q *transfersQuery) FilterAddress(address common.Address) *transfersQuery {
   117  	q.addWhereSeparator(AndSeparator)
   118  	q.buf.WriteString(" address = ?")
   119  	q.args = append(q.args, address)
   120  	return q
   121  }
   122  
   123  func (q *transfersQuery) FilterTransactionID(hash common.Hash) *transfersQuery {
   124  	q.addWhereSeparator(AndSeparator)
   125  	q.buf.WriteString(" hash = ?")
   126  	q.args = append(q.args, hash)
   127  	return q
   128  }
   129  
   130  func (q *transfersQuery) FilterTransactionHash(hash common.Hash) *transfersQuery {
   131  	q.addWhereSeparator(AndSeparator)
   132  	q.buf.WriteString(" tx_hash = ?")
   133  	q.args = append(q.args, hash)
   134  	return q
   135  }
   136  
   137  func (q *transfersQuery) FilterBlockHash(blockHash common.Hash) *transfersQuery {
   138  	q.addWhereSeparator(AndSeparator)
   139  	q.buf.WriteString(" blk_hash = ?")
   140  	q.args = append(q.args, blockHash)
   141  	return q
   142  }
   143  
   144  func (q *transfersQuery) FilterBlockNumber(blockNumber *big.Int) *transfersQuery {
   145  	q.addWhereSeparator(AndSeparator)
   146  	q.buf.WriteString(" blk_number = ?")
   147  	q.args = append(q.args, (*bigint.SQLBigInt)(blockNumber))
   148  	return q
   149  }
   150  
   151  func ascendingString(ascending bool) string {
   152  	if ascending {
   153  		return "ASC"
   154  	}
   155  	return "DESC"
   156  }
   157  
   158  func (q *transfersQuery) SortByBlockNumberAndHash() *transfersQuery {
   159  	q.buf.WriteString(" ORDER BY blk_number DESC, hash ASC ")
   160  	return q
   161  }
   162  
   163  func (q *transfersQuery) SortByTimestamp(ascending bool) *transfersQuery {
   164  	q.buf.WriteString(fmt.Sprintf(" ORDER BY timestamp %s ", ascendingString(ascending)))
   165  	return q
   166  }
   167  
   168  func (q *transfersQuery) Limit(pageSize int64) *transfersQuery {
   169  	q.buf.WriteString(" LIMIT ?")
   170  	q.args = append(q.args, pageSize)
   171  	return q
   172  }
   173  
   174  func (q *transfersQuery) FilterType(txType w_common.Type) *transfersQuery {
   175  	q.addWhereSeparator(AndSeparator)
   176  	q.buf.WriteString(" type = ?")
   177  	q.args = append(q.args, txType)
   178  	return q
   179  }
   180  
   181  func (q *transfersQuery) FilterTokenAddress(address common.Address) *transfersQuery {
   182  	q.addWhereSeparator(AndSeparator)
   183  	q.buf.WriteString(" token_address = ?")
   184  	q.args = append(q.args, address)
   185  	return q
   186  }
   187  
   188  func (q *transfersQuery) FilterTokenID(tokenID *big.Int) *transfersQuery {
   189  	q.addWhereSeparator(AndSeparator)
   190  	q.buf.WriteString(" token_id = ?")
   191  	q.args = append(q.args, (*bigint.SQLBigIntBytes)(tokenID))
   192  	return q
   193  }
   194  
   195  func (q *transfersQuery) String() string {
   196  	return q.buf.String()
   197  }
   198  
   199  func (q *transfersQuery) Args() []interface{} {
   200  	return q.args
   201  }
   202  
   203  func (q *transfersQuery) TransferScan(rows *sql.Rows) (rst []Transfer, err error) {
   204  	for rows.Next() {
   205  		transfer := Transfer{
   206  			BlockNumber: &big.Int{},
   207  			Transaction: &types.Transaction{},
   208  			Receipt:     &types.Receipt{},
   209  			Log:         &types.Log{},
   210  		}
   211  		err = rows.Scan(
   212  			&transfer.ID, &transfer.Type, &transfer.BlockHash,
   213  			(*bigint.SQLBigInt)(transfer.BlockNumber), &transfer.Timestamp, &transfer.Address,
   214  			&JSONBlob{transfer.Transaction}, &transfer.From, &JSONBlob{transfer.Receipt}, &JSONBlob{transfer.Log}, &transfer.NetworkID, &transfer.BaseGasFees, &transfer.MultiTransactionID)
   215  		if err != nil {
   216  			return nil, err
   217  		}
   218  		rst = append(rst, transfer)
   219  	}
   220  
   221  	return rst, nil
   222  }
   223  
   224  func (q *transfersQuery) PreloadedTransactionScan(rows *sql.Rows) (rst []*PreloadedTransaction, err error) {
   225  	transfers := make([]Transfer, 0)
   226  	for rows.Next() {
   227  		transfer := Transfer{
   228  			Log: &types.Log{},
   229  		}
   230  		tokenValue := sql.NullString{}
   231  		tokenID := sql.RawBytes{}
   232  		err = rows.Scan(
   233  			&transfer.ID, &transfer.Type,
   234  			&transfer.Address,
   235  			&JSONBlob{transfer.Log},
   236  			&tokenID, &tokenValue)
   237  
   238  		if len(tokenID) > 0 {
   239  			transfer.TokenID = new(big.Int).SetBytes(tokenID)
   240  		}
   241  
   242  		if tokenValue.Valid {
   243  			var ok bool
   244  			transfer.TokenValue, ok = new(big.Int).SetString(tokenValue.String, 16)
   245  			if !ok {
   246  				panic("failed to parse token value")
   247  			}
   248  		}
   249  
   250  		if err != nil {
   251  			return nil, err
   252  		}
   253  		transfers = append(transfers, transfer)
   254  	}
   255  
   256  	rst = make([]*PreloadedTransaction, 0, len(transfers))
   257  
   258  	for _, transfer := range transfers {
   259  		preloadedTransaction := &PreloadedTransaction{
   260  			ID:      transfer.ID,
   261  			Type:    transfer.Type,
   262  			Address: transfer.Address,
   263  			Log:     transfer.Log,
   264  			TokenID: transfer.TokenID,
   265  			Value:   transfer.TokenValue,
   266  		}
   267  
   268  		rst = append(rst, preloadedTransaction)
   269  	}
   270  
   271  	return rst, nil
   272  }