code.vegaprotocol.io/vega@v0.79.0/datanode/sqlstore/balances.go (about)

     1  // Copyright (C) 2023 Gobalsky Labs Limited
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    15  
    16  package sqlstore
    17  
    18  import (
    19  	"context"
    20  	"fmt"
    21  	"strings"
    22  
    23  	"code.vegaprotocol.io/vega/datanode/entities"
    24  	"code.vegaprotocol.io/vega/datanode/metrics"
    25  	v2 "code.vegaprotocol.io/vega/protos/data-node/api/v2"
    26  )
    27  
    28  type Balances struct {
    29  	*ConnectionSource
    30  	batcher MapBatcher[entities.AccountBalanceKey, entities.AccountBalance]
    31  }
    32  
    33  func NewBalances(connectionSource *ConnectionSource) *Balances {
    34  	b := &Balances{
    35  		ConnectionSource: connectionSource,
    36  		batcher: NewMapBatcher[entities.AccountBalanceKey, entities.AccountBalance](
    37  			"balances",
    38  			entities.BalanceColumns),
    39  	}
    40  	return b
    41  }
    42  
    43  func (bs *Balances) Flush(ctx context.Context) ([]entities.AccountBalance, error) {
    44  	defer metrics.StartSQLQuery("Balances", "Flush")()
    45  	return bs.batcher.Flush(ctx, bs.ConnectionSource)
    46  }
    47  
    48  // Add inserts a row to the balance table. If there's already a balance for this
    49  // (account, block time) update it to match with the one supplied.
    50  func (bs *Balances) Add(b entities.AccountBalance) error {
    51  	bs.batcher.Add(b)
    52  	return nil
    53  }
    54  
    55  func (bs *Balances) Query(ctx context.Context, filter entities.AccountFilter, dateRange entities.DateRange,
    56  	pagination entities.CursorPagination,
    57  ) (*[]entities.AggregatedBalance, entities.PageInfo, error) {
    58  	var pageInfo entities.PageInfo
    59  	accountsQ, args, err := filterAccountsQuery(filter, false)
    60  	if err != nil {
    61  		return nil, pageInfo, err
    62  	}
    63  
    64  	predicates := []string{}
    65  	if dateRange.Start != nil {
    66  		predicate := fmt.Sprintf("vega_time >= %s", nextBindVar(&args, *dateRange.Start))
    67  		predicates = append(predicates, predicate)
    68  	}
    69  
    70  	if dateRange.End != nil {
    71  		predicate := fmt.Sprintf("vega_time < %s", nextBindVar(&args, *dateRange.End))
    72  		predicates = append(predicates, predicate)
    73  	}
    74  
    75  	whereClause := ""
    76  	if len(predicates) > 0 {
    77  		whereClause = fmt.Sprintf("WHERE %s", strings.Join(predicates, " AND "))
    78  	}
    79  
    80  	query := fmt.Sprintf(`
    81      WITH a AS(%s)
    82      SELECT b.vega_time,
    83          a.asset_id,
    84          a.party_id,
    85          a.market_id,
    86          a.type,
    87          b.balance
    88      FROM balances b JOIN a ON b.account_id = a.id
    89  	%s`, accountsQ, whereClause)
    90  
    91  	ordering := TableOrdering{
    92  		ColumnOrdering{Name: "vega_time", Sorting: ASC},
    93  		ColumnOrdering{Name: "account_id", Sorting: ASC},
    94  	}
    95  
    96  	query, args, err = PaginateQuery[entities.AggregatedBalanceCursor](query, args, ordering, pagination)
    97  	if err != nil {
    98  		return nil, pageInfo, err
    99  	}
   100  
   101  	defer metrics.StartSQLQuery("Balances", "Query")()
   102  	rows, err := bs.ConnectionSource.Query(ctx, query, args...)
   103  	if err != nil {
   104  		return nil, pageInfo, fmt.Errorf("querying balances: %w", err)
   105  	}
   106  	defer rows.Close()
   107  
   108  	groupBy := []entities.AccountField{
   109  		entities.AccountFieldAssetID,
   110  		entities.AccountFieldPartyID,
   111  		entities.AccountFieldMarketID,
   112  		entities.AccountFieldType,
   113  	}
   114  
   115  	balances, err := entities.AggregatedBalanceScan(groupBy, rows)
   116  	if err != nil {
   117  		return nil, pageInfo, err
   118  	}
   119  
   120  	balances, pageInfo = entities.PageEntities[*v2.AggregatedBalanceEdge](balances, pagination)
   121  
   122  	return &balances, pageInfo, nil
   123  }