code.vegaprotocol.io/vega@v0.79.0/datanode/sqlstore/stake_linking.go (about) 1 // Copyright (C) 2023 Gobalsky Labs Limited 2 // 3 // This program is free software: you can redistribute it and/or modify 4 // it under the terms of the GNU Affero General Public License as 5 // published by the Free Software Foundation, either version 3 of the 6 // License, or (at your option) any later version. 7 // 8 // This program is distributed in the hope that it will be useful, 9 // but WITHOUT ANY WARRANTY; without even the implied warranty of 10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 // GNU Affero General Public License for more details. 12 // 13 // You should have received a copy of the GNU Affero General Public License 14 // along with this program. If not, see <http://www.gnu.org/licenses/>. 15 16 package sqlstore 17 18 import ( 19 "context" 20 "errors" 21 "fmt" 22 23 "code.vegaprotocol.io/vega/datanode/entities" 24 "code.vegaprotocol.io/vega/datanode/metrics" 25 "code.vegaprotocol.io/vega/libs/num" 26 "code.vegaprotocol.io/vega/logging" 27 v2 "code.vegaprotocol.io/vega/protos/data-node/api/v2" 28 29 "github.com/georgysavva/scany/pgxscan" 30 "github.com/shopspring/decimal" 31 ) 32 33 type StakeLinking struct { 34 *ConnectionSource 35 } 36 37 const ( 38 sqlStakeLinkingColumns = `id, stake_linking_type, ethereum_timestamp, party_id, amount, stake_linking_status, finalized_at, 39 foreign_tx_hash, foreign_block_height, foreign_block_time, log_index, ethereum_address, tx_hash, vega_time` 40 ) 41 42 var stakeLinkingOrdering = TableOrdering{ 43 ColumnOrdering{Name: "vega_time", Sorting: ASC}, 44 ColumnOrdering{Name: "id", Sorting: ASC}, 45 } 46 47 func NewStakeLinking(connectionSource *ConnectionSource) *StakeLinking { 48 return &StakeLinking{ 49 ConnectionSource: connectionSource, 50 } 51 } 52 53 func (s *StakeLinking) Upsert(ctx context.Context, stake *entities.StakeLinking) error { 54 defer metrics.StartSQLQuery("StakeLinking", "Upsert")() 55 query := fmt.Sprintf(`insert into stake_linking (%s) 56 values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) 57 on conflict (id, vega_time) do update 58 set 59 stake_linking_type=EXCLUDED.stake_linking_type, 60 ethereum_timestamp=EXCLUDED.ethereum_timestamp, 61 party_id=EXCLUDED.party_id, 62 amount=EXCLUDED.amount, 63 stake_linking_status=EXCLUDED.stake_linking_status, 64 finalized_at=EXCLUDED.finalized_at, 65 foreign_tx_hash=EXCLUDED.foreign_tx_hash, 66 foreign_block_height=EXCLUDED.foreign_block_height, 67 foreign_block_time=EXCLUDED.foreign_block_time, 68 log_index=EXCLUDED.log_index, 69 ethereum_address=EXCLUDED.ethereum_address, 70 tx_hash=EXCLUDED.tx_hash 71 `, sqlStakeLinkingColumns) 72 73 if _, err := s.Exec(ctx, query, stake.ID, stake.StakeLinkingType, stake.EthereumTimestamp, stake.PartyID, stake.Amount, 74 stake.StakeLinkingStatus, stake.FinalizedAt, stake.ForeignTxHash, stake.ForeignBlockHeight, stake.ForeignBlockTime, stake.LogIndex, 75 stake.EthereumAddress, stake.TxHash, stake.VegaTime); err != nil { 76 return err 77 } 78 79 return nil 80 } 81 82 func (s *StakeLinking) GetStake(ctx context.Context, partyID entities.PartyID, 83 p entities.Pagination, 84 ) (*num.Uint, []entities.StakeLinking, entities.PageInfo, error) { 85 switch pagination := p.(type) { 86 case entities.CursorPagination: 87 return s.getStakeWithCursorPagination(ctx, partyID, pagination) 88 default: 89 panic("unsupported pagination") 90 } 91 } 92 93 func (s *StakeLinking) getStakeWithCursorPagination(ctx context.Context, partyID entities.PartyID, pagination entities.CursorPagination) ( 94 *num.Uint, []entities.StakeLinking, entities.PageInfo, error, 95 ) { 96 var ( 97 links []entities.StakeLinking 98 pageInfo entities.PageInfo 99 err error 100 ) 101 // get the links from the database 102 query, bindVars := getStakeLinkingQuery(partyID) 103 104 query, bindVars, err = PaginateQuery[entities.StakeLinkingCursor](query, bindVars, stakeLinkingOrdering, pagination) 105 if err != nil { 106 return nil, nil, pageInfo, err 107 } 108 defer metrics.StartSQLQuery("StakeLinking", "GetStake")() 109 110 var bal *num.Uint 111 112 err = pgxscan.Select(ctx, s.ConnectionSource, &links, query, bindVars...) 113 if err != nil { 114 s.log.Errorf("could not retrieve links", logging.Error(err)) 115 return bal, nil, pageInfo, err 116 } 117 118 links, pageInfo = entities.PageEntities[*v2.StakeLinkingEdge](links, pagination) 119 120 bal, err = s.calculateBalance(ctx, partyID) 121 if err != nil { 122 s.log.Errorf("cannot calculate balance", logging.Error(err)) 123 return num.UintZero(), nil, pageInfo, err 124 } 125 return bal, links, pageInfo, nil 126 } 127 128 func getStakeLinkingQuery(partyID entities.PartyID) (string, []interface{}) { 129 var bindVars []interface{} 130 131 query := fmt.Sprintf(`select %s 132 from stake_linking_current 133 where party_id=%s`, sqlStakeLinkingColumns, nextBindVar(&bindVars, partyID)) 134 135 return query, bindVars 136 } 137 138 func (s *StakeLinking) calculateBalance(ctx context.Context, partyID entities.PartyID) (*num.Uint, error) { 139 bal := num.UintZero() 140 var bindVars []interface{} 141 142 query := fmt.Sprintf(`select coalesce(sum(CASE stake_linking_type 143 WHEN 'TYPE_LINK' THEN amount 144 WHEN 'TYPE_UNLINK' THEN -amount 145 ELSE 0 146 END), 0) 147 FROM stake_linking_current 148 WHERE party_id = %s 149 AND stake_linking_status = 'STATUS_ACCEPTED' 150 `, nextBindVar(&bindVars, partyID)) 151 152 var currentBalance decimal.Decimal 153 defer metrics.StartSQLQuery("StakeLinking", "calculateBalance")() 154 if err := pgxscan.Get(ctx, s.ConnectionSource, ¤tBalance, query, bindVars...); err != nil { 155 return bal, err 156 } 157 158 if currentBalance.LessThan(decimal.Zero) { 159 return bal, errors.New("unlinked amount is greater than linked amount, potential missed events") 160 } 161 162 var overflowed bool 163 if bal, overflowed = num.UintFromDecimal(currentBalance); overflowed { 164 return num.UintZero(), fmt.Errorf("current balance is invalid: %s", currentBalance.String()) 165 } 166 167 return bal, nil 168 }