code.vegaprotocol.io/vega@v0.79.0/datanode/sqlstore/transfers.go (about)

     1  // Copyright (C) 2023 Gobalsky Labs Limited
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    15  
    16  package sqlstore
    17  
    18  import (
    19  	"context"
    20  	"errors"
    21  	"fmt"
    22  	"strings"
    23  
    24  	"code.vegaprotocol.io/vega/datanode/entities"
    25  	"code.vegaprotocol.io/vega/datanode/metrics"
    26  	v2 "code.vegaprotocol.io/vega/protos/data-node/api/v2"
    27  	"code.vegaprotocol.io/vega/protos/vega"
    28  
    29  	"github.com/georgysavva/scany/pgxscan"
    30  	"github.com/jackc/pgx/v4"
    31  )
    32  
// transfersOrdering is the column ordering handed to PaginateQuery when
// listing transfers: ascending vega_time first, then ascending id.
var transfersOrdering = TableOrdering{
	ColumnOrdering{Name: "vega_time", Sorting: ASC},
	ColumnOrdering{Name: "id", Sorting: ASC},
}
    37  
// Transfers is the SQL store for transfers, transfer fees and transfer fee
// discounts. It embeds the *ConnectionSource through which every statement in
// this file is executed.
type Transfers struct {
	*ConnectionSource
}
    41  
// ListTransfersFilters holds the optional filters applied when listing
// transfers. A nil field adds no condition to the query; every non-nil field
// adds one condition and all conditions are combined with "and".
type ListTransfersFilters struct {
	// FromEpoch keeps transfers whose start_epoch or end_epoch is >= this value.
	FromEpoch *uint64
	// ToEpoch keeps transfers whose start_epoch or end_epoch is <= this value.
	ToEpoch *uint64
	// Scope keeps transfers with a non-null dispatch_strategy; for the
	// individual and team scopes the strategy must additionally contain the
	// 'individual_scope' or 'team_scope' key respectively.
	Scope *entities.TransferScope
	// Status keeps transfers whose status equals this value.
	Status *entities.TransferStatus
	// GameID keeps transfers whose game_id equals this value.
	GameID *entities.GameID
	// FromAccountType keeps transfers whose sending account has this type.
	FromAccountType *vega.AccountType
	// ToAccountType keeps transfers whose receiving account has this type.
	ToAccountType *vega.AccountType
}
    51  
    52  func NewTransfers(connectionSource *ConnectionSource) *Transfers {
    53  	return &Transfers{
    54  		ConnectionSource: connectionSource,
    55  	}
    56  }
    57  
    58  func (t *Transfers) Upsert(ctx context.Context, transfer *entities.Transfer) error {
    59  	defer metrics.StartSQLQuery("Transfers", "Upsert")()
    60  	query := `INSERT INTO transfers(
    61  				id,
    62  				tx_hash,
    63  				vega_time,
    64  				from_account_id,
    65  				to_account_id,
    66  				asset_id,
    67  				amount,
    68  				reference,
    69  				status,
    70  				transfer_type,
    71  				deliver_on,
    72  				start_epoch,
    73  				end_epoch,
    74  				factor,
    75  				dispatch_strategy,
    76  				reason,
    77  				game_id
    78  			)
    79  					VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17)
    80  					ON CONFLICT (id, vega_time) DO UPDATE
    81  					SET
    82  				from_account_id=excluded.from_account_id,
    83  				to_account_id=excluded.to_account_id,
    84  				asset_id=excluded.asset_id,
    85  				amount=excluded.amount,
    86  				reference=excluded.reference,
    87  				status=excluded.status,
    88  				transfer_type=excluded.transfer_type,
    89  				deliver_on=excluded.deliver_on,
    90  				start_epoch=excluded.start_epoch,
    91  				end_epoch=excluded.end_epoch,
    92  				factor=excluded.factor,
    93  				dispatch_strategy=excluded.dispatch_strategy,
    94  				reason=excluded.reason,
    95  				tx_hash=excluded.tx_hash,
    96  				game_id=excluded.game_id
    97  				;`
    98  
    99  	if _, err := t.Exec(ctx, query, transfer.ID, transfer.TxHash, transfer.VegaTime, transfer.FromAccountID, transfer.ToAccountID,
   100  		transfer.AssetID, transfer.Amount, transfer.Reference, transfer.Status, transfer.TransferType,
   101  		transfer.DeliverOn, transfer.StartEpoch, transfer.EndEpoch, transfer.Factor, transfer.DispatchStrategy, transfer.Reason, transfer.GameID); err != nil {
   102  		return fmt.Errorf("could not insert transfer into database: %w", err)
   103  	}
   104  
   105  	return nil
   106  }
   107  
   108  func (t *Transfers) UpsertFees(ctx context.Context, tf *entities.TransferFees) error {
   109  	defer metrics.StartSQLQuery("Transfers", "UpsertFees")()
   110  	query := `INSERT INTO  transfer_fees(
   111  				transfer_id,
   112  				amount,
   113  				epoch_seq,
   114  				vega_time,
   115  				discount_applied
   116  			) VALUES ($1, $2, $3, $4, $5) ON CONFLICT (vega_time, transfer_id) DO NOTHING;` // conflicts may occur on checkpoint restore.
   117  	if _, err := t.Exec(ctx, query, tf.TransferID, tf.Amount, tf.EpochSeq, tf.VegaTime, tf.DiscountApplied); err != nil {
   118  		return err
   119  	}
   120  	return nil
   121  }
   122  
   123  func (t *Transfers) GetTransfersToOrFromParty(ctx context.Context, pagination entities.CursorPagination, filters ListTransfersFilters, partyID entities.PartyID) ([]entities.TransferDetails, entities.PageInfo, error) {
   124  	defer metrics.StartSQLQuery("Transfers", "GetTransfersToOrFromParty")()
   125  
   126  	where := []string{
   127  		"(from_account_id in (select id from accounts where accounts.party_id=$1) or to_account_id in (select id from accounts where accounts.party_id=$1))",
   128  	}
   129  
   130  	transfers, pageInfo, err := t.getCurrentTransfers(ctx, pagination, filters, where, []any{partyID})
   131  	if err != nil {
   132  		return nil, entities.PageInfo{}, fmt.Errorf("could not get transfers to or from party: %w", err)
   133  	}
   134  
   135  	details, err := t.getTransferDetails(ctx, transfers)
   136  	if err != nil {
   137  		return nil, entities.PageInfo{}, err
   138  	}
   139  
   140  	return details, pageInfo, nil
   141  }
   142  
   143  func (t *Transfers) GetTransfersFromParty(ctx context.Context, pagination entities.CursorPagination, filters ListTransfersFilters, partyID entities.PartyID) ([]entities.TransferDetails, entities.PageInfo, error) {
   144  	defer metrics.StartSQLQuery("Transfers", "GetTransfersFromParty")()
   145  
   146  	where := []string{
   147  		"from_account_id in (select id from accounts where accounts.party_id=$1)",
   148  	}
   149  
   150  	transfers, pageInfo, err := t.getCurrentTransfers(ctx, pagination, filters, where, []any{partyID})
   151  	if err != nil {
   152  		return nil, entities.PageInfo{}, fmt.Errorf("could not get transfers from party: %w", err)
   153  	}
   154  	details, err := t.getTransferDetails(ctx, transfers)
   155  	if err != nil {
   156  		return nil, entities.PageInfo{}, err
   157  	}
   158  
   159  	return details, pageInfo, nil
   160  }
   161  
   162  func (t *Transfers) GetTransfersToParty(ctx context.Context, pagination entities.CursorPagination, filters ListTransfersFilters, partyID entities.PartyID) ([]entities.TransferDetails, entities.PageInfo, error) {
   163  	defer metrics.StartSQLQuery("Transfers", "GetTransfersToParty")()
   164  
   165  	where := []string{
   166  		"to_account_id in (select id from accounts where accounts.party_id=$1)",
   167  	}
   168  
   169  	transfers, pageInfo, err := t.getCurrentTransfers(ctx, pagination, filters, where, []any{partyID})
   170  	if err != nil {
   171  		return nil, entities.PageInfo{}, fmt.Errorf("could not get transfers to party: %w", err)
   172  	}
   173  
   174  	details, err := t.getTransferDetails(ctx, transfers)
   175  	if err != nil {
   176  		return nil, entities.PageInfo{}, err
   177  	}
   178  
   179  	return details, pageInfo, nil
   180  }
   181  
   182  func (t *Transfers) GetAll(ctx context.Context, pagination entities.CursorPagination, filters ListTransfersFilters) ([]entities.TransferDetails, entities.PageInfo, error) {
   183  	defer metrics.StartSQLQuery("Transfers", "GetAll")()
   184  
   185  	transfers, pageInfo, err := t.getCurrentTransfers(ctx, pagination, filters, nil, nil)
   186  	if err != nil {
   187  		return nil, entities.PageInfo{}, err
   188  	}
   189  
   190  	details, err := t.getTransferDetails(ctx, transfers)
   191  	if err != nil {
   192  		return nil, entities.PageInfo{}, err
   193  	}
   194  	return details, pageInfo, nil
   195  }
   196  
   197  func (t *Transfers) GetByTxHash(ctx context.Context, txHash entities.TxHash) ([]entities.Transfer, error) {
   198  	defer metrics.StartSQLQuery("Transfers", "GetByTxHash")()
   199  
   200  	var transfers []entities.Transfer
   201  	query := "SELECT * FROM transfers WHERE tx_hash = $1 ORDER BY id"
   202  
   203  	if err := pgxscan.Select(ctx, t.ConnectionSource, &transfers, query, txHash); err != nil {
   204  		return nil, fmt.Errorf("could not get transfers by transaction hash: %w", err)
   205  	}
   206  	return transfers, nil
   207  }
   208  
   209  func (t *Transfers) GetByID(ctx context.Context, id string) (entities.TransferDetails, error) {
   210  	var tr entities.Transfer
   211  	query := `SELECT * FROM transfers_current WHERE id=$1`
   212  
   213  	if err := pgxscan.Get(ctx, t.ConnectionSource, &tr, query, entities.TransferID(id)); err != nil {
   214  		return entities.TransferDetails{}, t.wrapE(err)
   215  	}
   216  
   217  	details, err := t.getTransferDetails(ctx, []entities.Transfer{tr})
   218  	if err != nil || len(details) == 0 {
   219  		return entities.TransferDetails{}, err
   220  	}
   221  
   222  	return details[0], nil
   223  }
   224  
   225  func (t *Transfers) GetAllRewards(ctx context.Context, pagination entities.CursorPagination, filters ListTransfersFilters) ([]entities.TransferDetails, entities.PageInfo, error) {
   226  	defer metrics.StartSQLQuery("Transfers", "GetAllRewards")()
   227  
   228  	args := []any{entities.Recurring, entities.GovernanceRecurring}
   229  
   230  	transfers, pageInfo, err := t.getRecurringTransfers(ctx, pagination, filters, []string{}, args)
   231  	if err != nil {
   232  		return nil, entities.PageInfo{}, fmt.Errorf("could not get recurring transfers: %w", err)
   233  	}
   234  
   235  	details, err := t.getTransferDetails(ctx, transfers)
   236  	if err != nil {
   237  		return nil, entities.PageInfo{}, err
   238  	}
   239  
   240  	return details, pageInfo, nil
   241  }
   242  
   243  func (t *Transfers) GetRewardTransfersFromParty(ctx context.Context, pagination entities.CursorPagination, filters ListTransfersFilters, partyID entities.PartyID) ([]entities.TransferDetails, entities.PageInfo, error) {
   244  	defer metrics.StartSQLQuery("Transfers", "GetRewardTransfersFromParty")()
   245  
   246  	where := []string{
   247  		"from_account_id IN (SELECT id FROM accounts WHERE accounts.party_id = $3)",
   248  	}
   249  
   250  	args := []any{entities.Recurring, entities.GovernanceRecurring, partyID}
   251  
   252  	transfers, pageInfo, err := t.getRecurringTransfers(ctx, pagination, filters, where, args)
   253  	if err != nil {
   254  		return nil, entities.PageInfo{}, fmt.Errorf("could not get recurring transfers: %w", err)
   255  	}
   256  
   257  	details, err := t.getTransferDetails(ctx, transfers)
   258  	if err != nil {
   259  		return nil, entities.PageInfo{}, err
   260  	}
   261  
   262  	return details, pageInfo, nil
   263  }
   264  
   265  func (t *Transfers) UpsertFeesDiscount(ctx context.Context, tfd *entities.TransferFeesDiscount) error {
   266  	defer metrics.StartSQLQuery("Transfers", "UpsertFeesDiscount")()
   267  	query := `INSERT INTO transfer_fees_discount(
   268  				party_id,
   269  				asset_id,
   270  				amount,
   271  				epoch_seq,
   272  				vega_time
   273  			) VALUES ($1, $2, $3, $4, $5) ON CONFLICT (vega_time, party_id, asset_id) DO NOTHING ;` // conflicts may occur on checkpoint restore.
   274  	if _, err := t.Exec(ctx, query, tfd.PartyID, tfd.AssetID, tfd.Amount, tfd.EpochSeq, tfd.VegaTime); err != nil {
   275  		return err
   276  	}
   277  	return nil
   278  }
   279  
   280  func (t *Transfers) GetCurrentTransferFeeDiscount(
   281  	ctx context.Context,
   282  	partyID entities.PartyID,
   283  	assetID entities.AssetID,
   284  ) (*entities.TransferFeesDiscount, error) {
   285  	defer metrics.StartSQLQuery("Transfers", "GetCurrentTransferFeeDiscount")()
   286  
   287  	var tfd entities.TransferFeesDiscount
   288  	query := `SELECT * FROM transfer_fees_discount
   289  		WHERE party_id = $1 AND asset_id = $2
   290  		ORDER BY vega_time DESC LIMIT 1`
   291  
   292  	if err := pgxscan.Get(ctx, t.ConnectionSource, &tfd, query, partyID, assetID); err != nil {
   293  		return &entities.TransferFeesDiscount{}, t.wrapE(err)
   294  	}
   295  
   296  	return &tfd, nil
   297  }
   298  
   299  func (t *Transfers) getCurrentTransfers(ctx context.Context, pagination entities.CursorPagination, filters ListTransfersFilters, where []string, args []any) ([]entities.Transfer, entities.PageInfo, error) {
   300  	whereStr, args := t.buildWhereClause(filters, where, args)
   301  	query := `WITH current_transfers as (
   302  	SELECT tc.*, af.type as from_account_type, at.type as to_account_type
   303  	FROM transfers_current tc
   304  	JOIN accounts af on tc. from_account_id = af.id
   305  	JOIN accounts at on tc.to_account_id = at.id
   306  )
   307  	SELECT id, tx_hash, vega_time, from_account_id, to_account_id, asset_id, amount, reference, status, transfer_type, deliver_on,
   308  		start_epoch, end_epoch, factor, dispatch_strategy, reason, game_id
   309  	FROM current_transfers
   310  ` + whereStr
   311  
   312  	return t.selectTransfers(ctx, pagination, query, args)
   313  }
   314  
   315  func (t *Transfers) getRecurringTransfers(ctx context.Context, pagination entities.CursorPagination, filters ListTransfersFilters, where []string, args []any) ([]entities.Transfer, entities.PageInfo, error) {
   316  	whereStr, args := t.buildWhereClause(filters, where, args)
   317  
   318  	query := `WITH recurring_transfers AS (
   319  		SELECT tc.*, af.type as from_account_type, at.type as to_account_type FROM transfers_current as tc
   320  		JOIN accounts as at on tc.to_account_id = at.id
   321  		JOIN accounts as af on tc.from_account_id = af.id
   322  		WHERE transfer_type IN ($1, $2)
   323  		AND at.type = 12 OR (jsonb_typeof(tc.dispatch_strategy) != 'null' AND dispatch_strategy->>'metric' <> '0')
   324  )
   325  SELECT id, tx_hash, vega_time, from_account_id, to_account_id, asset_id, amount, reference, status, transfer_type, deliver_on,
   326  	start_epoch, end_epoch, factor, dispatch_strategy, reason, game_id
   327  FROM recurring_transfers
   328  ` + whereStr
   329  
   330  	return t.selectTransfers(ctx, pagination, query, args)
   331  }
   332  
   333  func (t *Transfers) buildWhereClause(filters ListTransfersFilters, where []string, args []any) (string, []any) {
   334  	if filters.Scope != nil {
   335  		where = append(where, "jsonb_typeof(dispatch_strategy) != 'null'")
   336  		switch *filters.Scope {
   337  		case entities.TransferScopeIndividual:
   338  			where = append(where, "dispatch_strategy ? 'individual_scope'")
   339  		case entities.TransferScopeTeam:
   340  			where = append(where, "dispatch_strategy ? 'team_scope'")
   341  		}
   342  	}
   343  
   344  	if filters.Status != nil {
   345  		where = append(where, fmt.Sprintf("status = %s", nextBindVar(&args, *filters.Status)))
   346  	}
   347  
   348  	if filters.FromEpoch != nil {
   349  		where = append(where, fmt.Sprintf("(start_epoch >= %s or end_epoch >= %s)",
   350  			nextBindVar(&args, *filters.FromEpoch),
   351  			nextBindVar(&args, *filters.FromEpoch),
   352  		))
   353  	}
   354  
   355  	if filters.ToEpoch != nil {
   356  		where = append(where, fmt.Sprintf("(start_epoch <= %s or end_epoch <= %s)",
   357  			nextBindVar(&args, *filters.ToEpoch),
   358  			nextBindVar(&args, *filters.ToEpoch),
   359  		))
   360  	}
   361  
   362  	if filters.GameID != nil {
   363  		where = append(where, fmt.Sprintf("game_id = %s", nextBindVar(&args, *filters.GameID)))
   364  	}
   365  
   366  	if filters.FromAccountType != nil {
   367  		where = append(where, fmt.Sprintf("from_account_type = %s", nextBindVar(&args, *filters.FromAccountType)))
   368  	}
   369  
   370  	if filters.ToAccountType != nil {
   371  		where = append(where, fmt.Sprintf("to_account_type = %s", nextBindVar(&args, *filters.ToAccountType)))
   372  	}
   373  
   374  	whereStr := ""
   375  	if len(where) > 0 {
   376  		whereStr = "where " + strings.Join(where, " and ")
   377  	} else {
   378  		whereStr = "where 1=1" // required because there is a where clause in the subquery and without a where clause in the main query, the pagination will break
   379  	}
   380  	return whereStr, args
   381  }
   382  
   383  func (t *Transfers) selectTransfers(ctx context.Context, pagination entities.CursorPagination, query string, args []any) ([]entities.Transfer, entities.PageInfo, error) {
   384  	query, args, err := PaginateQuery[entities.TransferCursor](query, args, transfersOrdering, pagination)
   385  	if err != nil {
   386  		return nil, entities.PageInfo{}, err
   387  	}
   388  
   389  	var transfers []entities.Transfer
   390  	err = pgxscan.Select(ctx, t.ConnectionSource, &transfers, query, args...)
   391  	if err != nil {
   392  		return nil, entities.PageInfo{}, fmt.Errorf("could not get transfers: %w", err)
   393  	}
   394  
   395  	transfers, pageInfo := entities.PageEntities[*v2.TransferEdge](transfers, pagination)
   396  
   397  	return transfers, pageInfo, nil
   398  }
   399  
   400  func (t *Transfers) getTransferDetails(ctx context.Context, transfers []entities.Transfer) ([]entities.TransferDetails, error) {
   401  	details := make([]entities.TransferDetails, 0, len(transfers))
   402  	query := `SELECT * FROM transfer_fees WHERE transfer_id = $1`
   403  	for _, tr := range transfers {
   404  		detail := entities.TransferDetails{
   405  			Transfer: tr,
   406  		}
   407  		rows, err := t.Query(ctx, query, tr.ID)
   408  		if errors.Is(err, pgx.ErrNoRows) {
   409  			details = append(details, detail)
   410  			if rows != nil {
   411  				rows.Close()
   412  			}
   413  			continue
   414  		}
   415  		if err != nil {
   416  			return nil, t.wrapE(err)
   417  		}
   418  		if err := pgxscan.ScanAll(&detail.Fees, rows); err != nil {
   419  			return nil, t.wrapE(err)
   420  		}
   421  		rows.Close()
   422  		details = append(details, detail)
   423  	}
   424  	return details, nil
   425  }