code.vegaprotocol.io/vega@v0.79.0/datanode/sqlstore/accounts.go (about)

     1  // Copyright (C) 2023 Gobalsky Labs Limited
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    15  
    16  package sqlstore
    17  
    18  import (
    19  	"context"
    20  	"crypto/sha256"
    21  	"encoding/hex"
    22  	"fmt"
    23  	"sync"
    24  
    25  	"code.vegaprotocol.io/vega/datanode/entities"
    26  	"code.vegaprotocol.io/vega/datanode/metrics"
    27  	v2 "code.vegaprotocol.io/vega/protos/data-node/api/v2"
    28  
    29  	"github.com/georgysavva/scany/pgxscan"
    30  	"github.com/jackc/pgx/v4"
    31  )
    32  
    33  var accountOrdering = TableOrdering{
    34  	ColumnOrdering{Name: "account_id", Sorting: ASC},
    35  }
    36  
    37  type Accounts struct {
    38  	*ConnectionSource
    39  	idToAccount map[entities.AccountID]entities.Account
    40  	cacheLock   sync.RWMutex
    41  }
    42  
    43  func NewAccounts(connectionSource *ConnectionSource) *Accounts {
    44  	a := &Accounts{
    45  		ConnectionSource: connectionSource,
    46  		idToAccount:      make(map[entities.AccountID]entities.Account),
    47  	}
    48  	return a
    49  }
    50  
    51  // Add inserts a row and updates supplied struct with autogenerated ID.
    52  func (as *Accounts) Add(ctx context.Context, a *entities.Account) error {
    53  	defer metrics.StartSQLQuery("Accounts", "Add")()
    54  
    55  	err := as.QueryRow(ctx,
    56  		`INSERT INTO accounts(id, party_id, asset_id, market_id, type, tx_hash, vega_time)
    57  		 VALUES ($1, $2, $3, $4, $5, $6, $7)
    58  		 RETURNING id`,
    59  		DeterministicIDFromAccount(a),
    60  		a.PartyID,
    61  		a.AssetID,
    62  		a.MarketID,
    63  		a.Type,
    64  		a.TxHash,
    65  		a.VegaTime).Scan(&a.ID)
    66  	return err
    67  }
    68  
    69  func (as *Accounts) GetByRawID(ctx context.Context, accountID string) (entities.Account, error) {
    70  	return as.GetByID(ctx, entities.AccountID(accountID))
    71  }
    72  
    73  func (as *Accounts) GetByID(ctx context.Context, accountID entities.AccountID) (entities.Account, error) {
    74  	if account, ok := as.getAccountFromCache(accountID); ok {
    75  		return account, nil
    76  	}
    77  
    78  	as.cacheLock.Lock()
    79  	defer as.cacheLock.Unlock()
    80  
    81  	// It's possible that in-between releasing the read lock and obtaining the write lock that the account has been
    82  	// added to cache, so we need to check here and return the cached account if that's the case.
    83  	if account, ok := as.idToAccount[accountID]; ok {
    84  		return account, nil
    85  	}
    86  
    87  	a := entities.Account{}
    88  	defer metrics.StartSQLQuery("Accounts", "GetByID")()
    89  
    90  	if err := pgxscan.Get(ctx, as.ConnectionSource, &a,
    91  		`SELECT id, party_id, asset_id, market_id, type, tx_hash, vega_time
    92  		 FROM accounts WHERE id=$1`,
    93  		accountID,
    94  	); err != nil {
    95  		return a, as.wrapE(err)
    96  	}
    97  
    98  	as.idToAccount[accountID] = a
    99  	return a, nil
   100  }
   101  
   102  func (as *Accounts) GetAll(ctx context.Context) ([]entities.Account, error) {
   103  	accounts := []entities.Account{}
   104  	defer metrics.StartSQLQuery("Accounts", "GetAll")()
   105  	err := pgxscan.Select(ctx, as.ConnectionSource, &accounts, `
   106  		SELECT id, party_id, asset_id, market_id, type, tx_hash, vega_time
   107  		FROM accounts`)
   108  	return accounts, err
   109  }
   110  
   111  func (as *Accounts) GetByTxHash(ctx context.Context, txHash entities.TxHash) ([]entities.Account, error) {
   112  	accounts := []entities.Account{}
   113  	defer metrics.StartSQLQuery("Accounts", "GetByTxHash")()
   114  
   115  	err := pgxscan.Select(
   116  		ctx,
   117  		as.ConnectionSource,
   118  		&accounts,
   119  		`SELECT id, party_id, asset_id, market_id, type, tx_hash, vega_time FROM accounts WHERE tx_hash=$1`,
   120  		txHash,
   121  	)
   122  	return accounts, err
   123  }
   124  
   125  // Obtain will either fetch or create an account in the database.
   126  // If an account with matching party/asset/market/type does not exist in the database, create one.
   127  // If an account already exists, fetch that one.
   128  // In either case, update the entities.Account object passed with an ID from the database.
   129  func (as *Accounts) Obtain(ctx context.Context, a *entities.Account) error {
   130  	accountID := DeterministicIDFromAccount(a)
   131  	if account, ok := as.getAccountFromCache(accountID); ok {
   132  		a.ID = account.ID
   133  		a.VegaTime = account.VegaTime
   134  		a.TxHash = account.TxHash
   135  		return nil
   136  	}
   137  
   138  	as.cacheLock.Lock()
   139  	defer as.cacheLock.Unlock()
   140  
   141  	// It's possible that in-between releasing the cache read lock and obtaining the cache write lock that the account has been
   142  	// added to the cache, so we need to check here and return the cached account if that's the case.
   143  	if account, ok := as.idToAccount[accountID]; ok {
   144  		a.ID = account.ID
   145  		a.VegaTime = account.VegaTime
   146  		a.TxHash = account.TxHash
   147  		return nil
   148  	}
   149  
   150  	insertQuery := `INSERT INTO accounts(id, party_id, asset_id, market_id, type, tx_hash, vega_time)
   151                             VALUES ($1, $2, $3, $4, $5, $6, $7)
   152                             ON CONFLICT (party_id, asset_id, market_id, type) DO NOTHING`
   153  
   154  	selectQuery := `SELECT id, party_id, asset_id, market_id, type, tx_hash, vega_time
   155  	                FROM accounts
   156  	                WHERE party_id=$1 AND asset_id=$2 AND market_id=$3 AND type=$4`
   157  
   158  	batch := pgx.Batch{}
   159  
   160  	batch.Queue(insertQuery, accountID, a.PartyID, a.AssetID, a.MarketID, a.Type, a.TxHash, a.VegaTime)
   161  	batch.Queue(selectQuery, a.PartyID, a.AssetID, a.MarketID, a.Type)
   162  	defer metrics.StartSQLQuery("Accounts", "Obtain")()
   163  	results := as.SendBatch(ctx, &batch)
   164  	defer results.Close()
   165  
   166  	if _, err := results.Exec(); err != nil {
   167  		return fmt.Errorf("inserting account: %w", err)
   168  	}
   169  
   170  	rows, err := results.Query()
   171  	if err != nil {
   172  		return fmt.Errorf("querying accounts: %w", err)
   173  	}
   174  
   175  	if err = pgxscan.ScanOne(a, rows); err != nil {
   176  		return fmt.Errorf("scanning account: %w", err)
   177  	}
   178  
   179  	as.idToAccount[accountID] = *a
   180  	return nil
   181  }
   182  
   183  func (as *Accounts) getAccountFromCache(id entities.AccountID) (entities.Account, bool) {
   184  	as.cacheLock.RLock()
   185  	defer as.cacheLock.RUnlock()
   186  
   187  	if account, ok := as.idToAccount[id]; ok {
   188  		return account, true
   189  	}
   190  	return entities.Account{}, false
   191  }
   192  
   193  func DeterministicIDFromAccount(a *entities.Account) entities.AccountID {
   194  	idAsBytes := sha256.Sum256([]byte(a.AssetID.String() + a.PartyID.String() + a.MarketID.String() + a.Type.String()))
   195  	accountID := hex.EncodeToString(idAsBytes[:])
   196  	return entities.AccountID(accountID)
   197  }
   198  
   199  func (as *Accounts) Query(ctx context.Context, filter entities.AccountFilter) ([]entities.Account, error) {
   200  	query, args, err := filterAccountsQuery(filter, true)
   201  	if err != nil {
   202  		return nil, err
   203  	}
   204  	accs := []entities.Account{}
   205  
   206  	defer metrics.StartSQLQuery("Accounts", "Query")()
   207  	rows, err := as.ConnectionSource.Query(ctx, query, args...)
   208  	if err != nil {
   209  		return accs, fmt.Errorf("querying accounts: %w", err)
   210  	}
   211  	defer rows.Close()
   212  
   213  	if err = pgxscan.ScanAll(&accs, rows); err != nil {
   214  		return accs, fmt.Errorf("scanning account: %w", err)
   215  	}
   216  
   217  	return accs, nil
   218  }
   219  
   220  func (as *Accounts) QueryBalances(ctx context.Context,
   221  	filter entities.AccountFilter,
   222  	pagination entities.CursorPagination,
   223  ) ([]entities.AccountBalance, entities.PageInfo, error) {
   224  	query, args, err := filterAccountBalancesQuery(filter)
   225  	if err != nil {
   226  		return nil, entities.PageInfo{}, fmt.Errorf("querying account balances: %w", err)
   227  	}
   228  
   229  	query, args, err = PaginateQuery[entities.AccountCursor](query, args, accountOrdering, pagination)
   230  	if err != nil {
   231  		return nil, entities.PageInfo{}, fmt.Errorf("querying account balances: %w", err)
   232  	}
   233  
   234  	defer metrics.StartSQLQuery("Accounts", "QueryBalances")()
   235  
   236  	accountBalances := make([]entities.AccountBalance, 0)
   237  	rows, err := as.ConnectionSource.Query(ctx, query, args...)
   238  	if err != nil {
   239  		return accountBalances, entities.PageInfo{}, fmt.Errorf("querying account balances: %w", err)
   240  	}
   241  	defer rows.Close()
   242  
   243  	if err = pgxscan.ScanAll(&accountBalances, rows); err != nil {
   244  		return accountBalances, entities.PageInfo{}, fmt.Errorf("parsing account balances: %w", err)
   245  	}
   246  
   247  	pagedAccountBalances, pageInfo := entities.PageEntities[*v2.AccountEdge](accountBalances, pagination)
   248  	return pagedAccountBalances, pageInfo, nil
   249  }
   250  
   251  func (as *Accounts) GetBalancesByTxHash(ctx context.Context, txHash entities.TxHash) ([]entities.AccountBalance, error) {
   252  	balances := []entities.AccountBalance{}
   253  	defer metrics.StartSQLQuery("Accounts", "GetBalancesByTxHash")()
   254  
   255  	err := pgxscan.Select(
   256  		ctx,
   257  		as.ConnectionSource,
   258  		&balances,
   259  		fmt.Sprintf("%s WHERE balances.tx_hash=$1", accountBalancesQuery()),
   260  		txHash,
   261  	)
   262  	return balances, err
   263  }