code.vegaprotocol.io/vega@v0.79.0/datanode/sqlstore/stake_linking.go (about)

     1  // Copyright (C) 2023 Gobalsky Labs Limited
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    15  
    16  package sqlstore
    17  
    18  import (
    19  	"context"
    20  	"errors"
    21  	"fmt"
    22  
    23  	"code.vegaprotocol.io/vega/datanode/entities"
    24  	"code.vegaprotocol.io/vega/datanode/metrics"
    25  	"code.vegaprotocol.io/vega/libs/num"
    26  	"code.vegaprotocol.io/vega/logging"
    27  	v2 "code.vegaprotocol.io/vega/protos/data-node/api/v2"
    28  
    29  	"github.com/georgysavva/scany/pgxscan"
    30  	"github.com/shopspring/decimal"
    31  )
    32  
    33  type StakeLinking struct {
    34  	*ConnectionSource
    35  }
    36  
    37  const (
    38  	sqlStakeLinkingColumns = `id, stake_linking_type, ethereum_timestamp, party_id, amount, stake_linking_status, finalized_at,
    39  foreign_tx_hash, foreign_block_height, foreign_block_time, log_index, ethereum_address, tx_hash, vega_time`
    40  )
    41  
    42  var stakeLinkingOrdering = TableOrdering{
    43  	ColumnOrdering{Name: "vega_time", Sorting: ASC},
    44  	ColumnOrdering{Name: "id", Sorting: ASC},
    45  }
    46  
    47  func NewStakeLinking(connectionSource *ConnectionSource) *StakeLinking {
    48  	return &StakeLinking{
    49  		ConnectionSource: connectionSource,
    50  	}
    51  }
    52  
    53  func (s *StakeLinking) Upsert(ctx context.Context, stake *entities.StakeLinking) error {
    54  	defer metrics.StartSQLQuery("StakeLinking", "Upsert")()
    55  	query := fmt.Sprintf(`insert into stake_linking (%s)
    56  values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
    57  on conflict (id, vega_time) do update
    58  set
    59  	stake_linking_type=EXCLUDED.stake_linking_type,
    60  	ethereum_timestamp=EXCLUDED.ethereum_timestamp,
    61  	party_id=EXCLUDED.party_id,
    62  	amount=EXCLUDED.amount,
    63  	stake_linking_status=EXCLUDED.stake_linking_status,
    64  	finalized_at=EXCLUDED.finalized_at,
    65  	foreign_tx_hash=EXCLUDED.foreign_tx_hash,
    66  	foreign_block_height=EXCLUDED.foreign_block_height,
    67  	foreign_block_time=EXCLUDED.foreign_block_time,
    68  	log_index=EXCLUDED.log_index,
    69  	ethereum_address=EXCLUDED.ethereum_address,
    70  	tx_hash=EXCLUDED.tx_hash
    71  	`, sqlStakeLinkingColumns)
    72  
    73  	if _, err := s.Exec(ctx, query, stake.ID, stake.StakeLinkingType, stake.EthereumTimestamp, stake.PartyID, stake.Amount,
    74  		stake.StakeLinkingStatus, stake.FinalizedAt, stake.ForeignTxHash, stake.ForeignBlockHeight, stake.ForeignBlockTime, stake.LogIndex,
    75  		stake.EthereumAddress, stake.TxHash, stake.VegaTime); err != nil {
    76  		return err
    77  	}
    78  
    79  	return nil
    80  }
    81  
    82  func (s *StakeLinking) GetStake(ctx context.Context, partyID entities.PartyID,
    83  	p entities.Pagination,
    84  ) (*num.Uint, []entities.StakeLinking, entities.PageInfo, error) {
    85  	switch pagination := p.(type) {
    86  	case entities.CursorPagination:
    87  		return s.getStakeWithCursorPagination(ctx, partyID, pagination)
    88  	default:
    89  		panic("unsupported pagination")
    90  	}
    91  }
    92  
    93  func (s *StakeLinking) getStakeWithCursorPagination(ctx context.Context, partyID entities.PartyID, pagination entities.CursorPagination) (
    94  	*num.Uint, []entities.StakeLinking, entities.PageInfo, error,
    95  ) {
    96  	var (
    97  		links    []entities.StakeLinking
    98  		pageInfo entities.PageInfo
    99  		err      error
   100  	)
   101  	// get the links from the database
   102  	query, bindVars := getStakeLinkingQuery(partyID)
   103  
   104  	query, bindVars, err = PaginateQuery[entities.StakeLinkingCursor](query, bindVars, stakeLinkingOrdering, pagination)
   105  	if err != nil {
   106  		return nil, nil, pageInfo, err
   107  	}
   108  	defer metrics.StartSQLQuery("StakeLinking", "GetStake")()
   109  
   110  	var bal *num.Uint
   111  
   112  	err = pgxscan.Select(ctx, s.ConnectionSource, &links, query, bindVars...)
   113  	if err != nil {
   114  		s.log.Errorf("could not retrieve links", logging.Error(err))
   115  		return bal, nil, pageInfo, err
   116  	}
   117  
   118  	links, pageInfo = entities.PageEntities[*v2.StakeLinkingEdge](links, pagination)
   119  
   120  	bal, err = s.calculateBalance(ctx, partyID)
   121  	if err != nil {
   122  		s.log.Errorf("cannot calculate balance", logging.Error(err))
   123  		return num.UintZero(), nil, pageInfo, err
   124  	}
   125  	return bal, links, pageInfo, nil
   126  }
   127  
   128  func getStakeLinkingQuery(partyID entities.PartyID) (string, []interface{}) {
   129  	var bindVars []interface{}
   130  
   131  	query := fmt.Sprintf(`select %s
   132  from stake_linking_current
   133  where party_id=%s`, sqlStakeLinkingColumns, nextBindVar(&bindVars, partyID))
   134  
   135  	return query, bindVars
   136  }
   137  
   138  func (s *StakeLinking) calculateBalance(ctx context.Context, partyID entities.PartyID) (*num.Uint, error) {
   139  	bal := num.UintZero()
   140  	var bindVars []interface{}
   141  
   142  	query := fmt.Sprintf(`select coalesce(sum(CASE stake_linking_type
   143      WHEN 'TYPE_LINK' THEN amount
   144      WHEN 'TYPE_UNLINK' THEN -amount
   145      ELSE 0
   146      END), 0)
   147      FROM stake_linking_current
   148  WHERE party_id = %s
   149    AND stake_linking_status = 'STATUS_ACCEPTED'
   150  `, nextBindVar(&bindVars, partyID))
   151  
   152  	var currentBalance decimal.Decimal
   153  	defer metrics.StartSQLQuery("StakeLinking", "calculateBalance")()
   154  	if err := pgxscan.Get(ctx, s.ConnectionSource, &currentBalance, query, bindVars...); err != nil {
   155  		return bal, err
   156  	}
   157  
   158  	if currentBalance.LessThan(decimal.Zero) {
   159  		return bal, errors.New("unlinked amount is greater than linked amount, potential missed events")
   160  	}
   161  
   162  	var overflowed bool
   163  	if bal, overflowed = num.UintFromDecimal(currentBalance); overflowed {
   164  		return num.UintZero(), fmt.Errorf("current balance is invalid: %s", currentBalance.String())
   165  	}
   166  
   167  	return bal, nil
   168  }