code.vegaprotocol.io/vega@v0.79.0/datanode/entities/aggregated_balance.go (about)

     1  // Copyright (C) 2023 Gobalsky Labs Limited
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    15  
    16  package entities
    17  
    18  import (
    19  	"encoding/json"
    20  	"fmt"
    21  	"time"
    22  
    23  	"code.vegaprotocol.io/vega/core/types"
    24  	v2 "code.vegaprotocol.io/vega/protos/data-node/api/v2"
    25  
    26  	"github.com/shopspring/decimal"
    27  )
    28  
    29  // AggregatedBalance represents the the summed balance of a bunch of accounts at a given
    30  // time. VegaTime and Balance will always be set. The others will be nil unless when
    31  // querying, you requested grouping by one of the corresponding fields.
    32  type AggregatedBalance struct {
    33  	VegaTime  time.Time
    34  	Balance   decimal.Decimal
    35  	AccountID *AccountID
    36  	PartyID   *PartyID
    37  	AssetID   *AssetID
    38  	MarketID  *MarketID
    39  	Type      *types.AccountType
    40  }
    41  
    42  // NewAggregatedBalanceFromValues returns a new AggregatedBalance from a list of values as returned
    43  // from pgx.rows.values().
    44  // - vegaTime is assumed to be first
    45  // - then any extra fields listed in 'fields' in order (usually as as result of grouping)
    46  // - then finally the balance itself.
    47  func AggregatedBalanceScan(fields []AccountField, rows interface {
    48  	Next() bool
    49  	Values() ([]any, error)
    50  },
    51  ) ([]AggregatedBalance, error) {
    52  	// Iterate through the result set
    53  	balances := []AggregatedBalance{}
    54  	for rows.Next() {
    55  		var ok bool
    56  		bal := AggregatedBalance{}
    57  		values, err := rows.Values()
    58  		if err != nil {
    59  			return nil, err
    60  		}
    61  
    62  		bal.VegaTime, ok = values[0].(time.Time)
    63  		if !ok {
    64  			return nil, fmt.Errorf("unable to cast to time.Time: %v", values[0])
    65  		}
    66  
    67  		for i, field := range fields {
    68  			if field == AccountFieldType {
    69  				intAccountType, ok := values[i+1].(int32)
    70  				if !ok {
    71  					return nil, fmt.Errorf("unable to cast to integer account type: %v", values[i])
    72  				}
    73  				accountType := types.AccountType(intAccountType)
    74  				bal.Type = &accountType
    75  				continue
    76  			}
    77  
    78  			idBytes, ok := values[i+1].([]byte)
    79  			if !ok {
    80  				return nil, fmt.Errorf("unable to cast to []byte: %v", values[i])
    81  			}
    82  
    83  			switch field {
    84  			case AccountFieldID:
    85  				var id AccountID
    86  				id.SetBytes(idBytes)
    87  				bal.AccountID = &id
    88  			case AccountFieldPartyID:
    89  				var id PartyID
    90  				id.SetBytes(idBytes)
    91  				bal.PartyID = &id
    92  			case AccountFieldAssetID:
    93  				var id AssetID
    94  				id.SetBytes(idBytes)
    95  				bal.AssetID = &id
    96  			case AccountFieldMarketID:
    97  				var id MarketID
    98  				id.SetBytes(idBytes)
    99  				bal.MarketID = &id
   100  			default:
   101  				return nil, fmt.Errorf("invalid field: %v", field)
   102  			}
   103  		}
   104  		lastValue := values[len(values)-1]
   105  
   106  		if bal.Balance, ok = lastValue.(decimal.Decimal); !ok {
   107  			return nil, fmt.Errorf("unable to cast to decimal %v", lastValue)
   108  		}
   109  
   110  		balances = append(balances, bal)
   111  	}
   112  
   113  	return balances, nil
   114  }
   115  
   116  func (balance *AggregatedBalance) ToProto() *v2.AggregatedBalance {
   117  	var partyID, assetID, marketID *string
   118  
   119  	if balance.PartyID != nil {
   120  		pid := balance.PartyID.String()
   121  		partyID = &pid
   122  	}
   123  
   124  	if balance.AssetID != nil {
   125  		aid := balance.AssetID.String()
   126  		assetID = &aid
   127  	}
   128  
   129  	if balance.MarketID != nil {
   130  		mid := balance.MarketID.String()
   131  		marketID = &mid
   132  	}
   133  
   134  	return &v2.AggregatedBalance{
   135  		Timestamp:   balance.VegaTime.UnixNano(),
   136  		Balance:     balance.Balance.String(),
   137  		PartyId:     partyID,
   138  		AssetId:     assetID,
   139  		MarketId:    marketID,
   140  		AccountType: balance.Type,
   141  	}
   142  }
   143  
   144  func (balance AggregatedBalance) Cursor() *Cursor {
   145  	return NewCursor(AggregatedBalanceCursor{
   146  		VegaTime:  balance.VegaTime,
   147  		AccountID: balance.AccountID,
   148  		PartyID:   balance.PartyID,
   149  		AssetID:   balance.AssetID,
   150  		MarketID:  balance.MarketID,
   151  		Type:      balance.Type,
   152  	}.String())
   153  }
   154  
   155  func (balance AggregatedBalance) ToProtoEdge(_ ...any) (*v2.AggregatedBalanceEdge, error) {
   156  	return &v2.AggregatedBalanceEdge{
   157  		Node:   balance.ToProto(),
   158  		Cursor: balance.Cursor().Encode(),
   159  	}, nil
   160  }
   161  
   162  type AggregatedBalanceCursor struct {
   163  	VegaTime  time.Time `json:"vega_time"`
   164  	AccountID *AccountID
   165  	PartyID   *PartyID
   166  	AssetID   *AssetID
   167  	MarketID  *MarketID
   168  	Type      *types.AccountType
   169  }
   170  
   171  func (c AggregatedBalanceCursor) String() string {
   172  	bs, err := json.Marshal(c)
   173  	if err != nil {
   174  		panic(fmt.Errorf("could not marshal aggregate balance cursor: %w", err))
   175  	}
   176  	return string(bs)
   177  }
   178  
   179  func (c *AggregatedBalanceCursor) Parse(cursorString string) error {
   180  	if cursorString == "" {
   181  		return nil
   182  	}
   183  	return json.Unmarshal([]byte(cursorString), c)
   184  }