code.vegaprotocol.io/vega@v0.79.0/datanode/sqlstore/amm_pool.go (about)

     1  // Copyright (C) 2023 Gobalsky Labs Limited
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    15  
    16  package sqlstore
    17  
    18  import (
    19  	"context"
    20  	"fmt"
    21  	"strings"
    22  
    23  	"code.vegaprotocol.io/vega/datanode/entities"
    24  	"code.vegaprotocol.io/vega/datanode/metrics"
    25  	"code.vegaprotocol.io/vega/libs/ptr"
    26  	v2 "code.vegaprotocol.io/vega/protos/data-node/api/v2"
    27  
    28  	"github.com/georgysavva/scany/pgxscan"
    29  )
    30  
    31  type AMMPools struct {
    32  	*ConnectionSource
    33  }
    34  
    35  var (
    36  	ammPoolsOrdering = TableOrdering{
    37  		ColumnOrdering{Name: "created_at", Sorting: ASC},
    38  		ColumnOrdering{Name: "party_id", Sorting: DESC},
    39  		ColumnOrdering{Name: "amm_party_id", Sorting: DESC},
    40  		ColumnOrdering{Name: "market_id", Sorting: DESC},
    41  		ColumnOrdering{Name: "id", Sorting: DESC},
    42  	}
    43  
    44  	activeStates = []entities.AMMStatus{entities.AMMStatusActive, entities.AMMStatusReduceOnly}
    45  )
    46  
    47  func NewAMMPools(connectionSource *ConnectionSource) *AMMPools {
    48  	return &AMMPools{
    49  		ConnectionSource: connectionSource,
    50  	}
    51  }
    52  
    53  func (p *AMMPools) Upsert(ctx context.Context, pool entities.AMMPool) error {
    54  	defer metrics.StartSQLQuery("AMMs", "UpsertAMM")
    55  	if _, err := p.ConnectionSource.Exec(ctx, `
    56  insert into amms(party_id, market_id, id, amm_party_id,
    57  commitment, status, status_reason, 	parameters_base,
    58  parameters_lower_bound, parameters_upper_bound,
    59  parameters_leverage_at_lower_bound, parameters_leverage_at_upper_bound,
    60  created_at, last_updated, proposed_fee,
    61  lower_virtual_liquidity, lower_theoretical_position,
    62  upper_virtual_liquidity, upper_theoretical_position, data_source_id,
    63  minimum_price_change_trigger)values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21)
    64  on conflict (party_id, market_id, id, amm_party_id) do update set
    65  	commitment=excluded.commitment,
    66  	status=excluded.status,
    67  	status_reason=excluded.status_reason,
    68  	parameters_base=excluded.parameters_base,
    69  	parameters_lower_bound=excluded.parameters_lower_bound,
    70  	parameters_upper_bound=excluded.parameters_upper_bound,
    71  	parameters_leverage_at_lower_bound=excluded.parameters_leverage_at_lower_bound,
    72  	parameters_leverage_at_upper_bound=excluded.parameters_leverage_at_upper_bound,
    73  	last_updated=excluded.last_updated,
    74  	proposed_fee=excluded.proposed_fee,
    75  	lower_virtual_liquidity=excluded.lower_virtual_liquidity,
    76  	lower_theoretical_position=excluded.lower_theoretical_position,
    77  	upper_virtual_liquidity=excluded.upper_virtual_liquidity,
    78  	upper_theoretical_position=excluded.upper_theoretical_position,
    79  	data_source_id=excluded.data_source_id,
    80  	minimum_price_change_trigger=excluded.minimum_price_change_trigger;`,
    81  		pool.PartyID,
    82  		pool.MarketID,
    83  		pool.ID,
    84  		pool.AmmPartyID,
    85  		pool.Commitment,
    86  		pool.Status,
    87  		pool.StatusReason,
    88  		pool.ParametersBase,
    89  		pool.ParametersLowerBound,
    90  		pool.ParametersUpperBound,
    91  		pool.ParametersLeverageAtLowerBound,
    92  		pool.ParametersLeverageAtUpperBound,
    93  		pool.CreatedAt,
    94  		pool.LastUpdated,
    95  		pool.ProposedFee,
    96  		pool.LowerVirtualLiquidity,
    97  		pool.LowerTheoreticalPosition,
    98  		pool.UpperVirtualLiquidity,
    99  		pool.UpperTheoreticalPosition,
   100  		pool.DataSourceID,
   101  		pool.MinimumPriceChangeTrigger,
   102  	); err != nil {
   103  		return fmt.Errorf("could not upsert AMM Pool: %w", err)
   104  	}
   105  
   106  	return nil
   107  }
   108  
   109  func listByFields(ctx context.Context, connection Connection, fields map[string]entities.AMMFilterType, liveOnly bool, pagination entities.CursorPagination) ([]entities.AMMPool, entities.PageInfo, error) {
   110  	var (
   111  		pools       []entities.AMMPool
   112  		pageInfo    entities.PageInfo
   113  		whereClause string
   114  	)
   115  	where := make([]string, 0, len(fields))
   116  	args := make([]any, 0, len(fields))
   117  	for field, val := range fields {
   118  		var clause string
   119  		clause, args = val.Where(&field, nextBindVar, args...)
   120  		where = append(where, clause)
   121  	}
   122  
   123  	if liveOnly {
   124  		where = append(where, liveOnlyClause(&args))
   125  	}
   126  
   127  	whereClause = strings.Join(where, " AND ")
   128  	query := fmt.Sprintf(`SELECT * FROM amms WHERE %s`, whereClause)
   129  	query, args, err := PaginateQuery[entities.AMMPoolCursor](query, args, ammPoolsOrdering, pagination)
   130  	if err != nil {
   131  		return nil, pageInfo, err
   132  	}
   133  
   134  	if err := pgxscan.Select(ctx, connection, &pools, query, args...); err != nil {
   135  		return nil, pageInfo, fmt.Errorf("could not list AMM Pools: %w", err)
   136  	}
   137  
   138  	pools, pageInfo = entities.PageEntities(pools, pagination)
   139  	return pools, pageInfo, nil
   140  }
   141  
   142  func listBy[T entities.AMMPoolsFilter](ctx context.Context, connection Connection, fieldName string, filter T, liveOnly bool, pagination entities.CursorPagination) ([]entities.AMMPool, entities.PageInfo, error) {
   143  	var (
   144  		pools       []entities.AMMPool
   145  		pageInfo    entities.PageInfo
   146  		args        []interface{}
   147  		whereClause string
   148  	)
   149  	whereClause, args = filter.Where(&fieldName, nextBindVar, args...)
   150  
   151  	if liveOnly {
   152  		whereClause += " AND " + liveOnlyClause(&args)
   153  	}
   154  
   155  	query := fmt.Sprintf(`SELECT * FROM amms WHERE %s`, whereClause)
   156  	query, args, err := PaginateQuery[entities.AMMPoolCursor](query, args, ammPoolsOrdering, pagination)
   157  	if err != nil {
   158  		return nil, pageInfo, err
   159  	}
   160  	if err := pgxscan.Select(ctx, connection, &pools, query, args...); err != nil {
   161  		return nil, pageInfo, fmt.Errorf("could not list AMM Pools: %w", err)
   162  	}
   163  
   164  	pools, pageInfo = entities.PageEntities[*v2.AMMEdge](pools, pagination)
   165  	return pools, pageInfo, nil
   166  }
   167  
   168  func (p *AMMPools) ListByMarket(ctx context.Context, marketID entities.MarketID, liveOnly bool, pagination entities.CursorPagination) ([]entities.AMMPool, entities.PageInfo, error) {
   169  	defer metrics.StartSQLQuery("AMMs", "ListByMarket")
   170  	return listBy(ctx, p.ConnectionSource, "market_id", &marketID, liveOnly, pagination)
   171  }
   172  
   173  func (p *AMMPools) ListByParty(ctx context.Context, partyID entities.PartyID, liveOnly bool, pagination entities.CursorPagination) ([]entities.AMMPool, entities.PageInfo, error) {
   174  	defer metrics.StartSQLQuery("AMMs", "ListByParty")
   175  
   176  	return listBy(ctx, p.ConnectionSource, "party_id", &partyID, liveOnly, pagination)
   177  }
   178  
   179  func (p *AMMPools) GetSubKeysForParties(ctx context.Context, partyIDs []string, marketIDs []string) ([]string, error) {
   180  	if len(partyIDs) == 0 {
   181  		return nil, nil
   182  	}
   183  	parties := strings.Builder{}
   184  	args := make([]any, 0, len(partyIDs)+len(marketIDs))
   185  	query := `SELECT amm_party_id FROM amms WHERE `
   186  	for i, party := range partyIDs {
   187  		if i > 0 {
   188  			parties.WriteString(",")
   189  		}
   190  		parties.WriteString(nextBindVar(&args, ptr.From(entities.PartyID(party))))
   191  	}
   192  	query = fmt.Sprintf(`%s party_id IN (%s)`, query, parties.String())
   193  	if len(marketIDs) > 0 {
   194  		markets := strings.Builder{}
   195  		for i, mID := range marketIDs {
   196  			if i > 0 {
   197  				markets.WriteString(",")
   198  			}
   199  
   200  			markets.WriteString(nextBindVar(&args, ptr.From(entities.MarketID(mID))))
   201  		}
   202  		query = fmt.Sprintf("%s AND market_id IN (%s)", query, markets.String())
   203  	}
   204  
   205  	subKeys := []entities.PartyID{}
   206  	if err := pgxscan.Select(ctx, p.ConnectionSource, &subKeys, query, args...); err != nil {
   207  		return nil, err
   208  	}
   209  
   210  	res := make([]string, 0, len(subKeys))
   211  	for _, k := range subKeys {
   212  		res = append(res, k.String())
   213  	}
   214  	return res, nil
   215  }
   216  
   217  func (p *AMMPools) ListByPool(ctx context.Context, poolID entities.AMMPoolID, liveOnly bool, pagination entities.CursorPagination) ([]entities.AMMPool, entities.PageInfo, error) {
   218  	defer metrics.StartSQLQuery("AMMs", "ListByPool")
   219  	return listBy(ctx, p.ConnectionSource, "id", &poolID, liveOnly, pagination)
   220  }
   221  
   222  func (p *AMMPools) ListBySubAccount(ctx context.Context, ammPartyID entities.PartyID, liveOnly bool, pagination entities.CursorPagination) ([]entities.AMMPool, entities.PageInfo, error) {
   223  	defer metrics.StartSQLQuery("AMMs", "ListByAMMParty")
   224  	return listBy(ctx, p.ConnectionSource, "amm_party_id", &ammPartyID, liveOnly, pagination)
   225  }
   226  
   227  func (p *AMMPools) ListByStatus(ctx context.Context, status entities.AMMStatus, pagination entities.CursorPagination) ([]entities.AMMPool, entities.PageInfo, error) {
   228  	defer metrics.StartSQLQuery("AMMs", "ListByStatus")
   229  	return listBy(ctx, p.ConnectionSource, "status", &status, false, pagination)
   230  }
   231  
   232  func (p *AMMPools) ListByPartyMarketStatus(ctx context.Context, party *entities.PartyID, market *entities.MarketID, status *entities.AMMStatus, liveOnly bool, pagination entities.CursorPagination) ([]entities.AMMPool, entities.PageInfo, error) {
   233  	defer metrics.StartSQLQuery("AMMs", "ListByPartyMarketStatus")
   234  	fields := make(map[string]entities.AMMFilterType, 3)
   235  	if party != nil {
   236  		fields["party_id"] = party
   237  	}
   238  	if market != nil {
   239  		fields["market_id"] = market
   240  	}
   241  	if status != nil {
   242  		fields["status"] = status
   243  	}
   244  	return listByFields(ctx, p.ConnectionSource, fields, liveOnly, pagination)
   245  }
   246  
   247  func (p *AMMPools) ListAll(ctx context.Context, liveOnly bool, pagination entities.CursorPagination) ([]entities.AMMPool, entities.PageInfo, error) {
   248  	defer metrics.StartSQLQuery("AMMs", "ListAll")
   249  	var (
   250  		pools    []entities.AMMPool
   251  		pageInfo entities.PageInfo
   252  		args     []interface{}
   253  	)
   254  	query := `SELECT * FROM amms`
   255  
   256  	if liveOnly {
   257  		query += ` WHERE ` + liveOnlyClause(&args)
   258  	}
   259  
   260  	query, args, err := PaginateQuery[entities.AMMPoolCursor](query, args, ammPoolsOrdering, pagination)
   261  	if err != nil {
   262  		return nil, pageInfo, err
   263  	}
   264  
   265  	if err := pgxscan.Select(ctx, p.ConnectionSource, &pools, query, args...); err != nil {
   266  		return nil, pageInfo, fmt.Errorf("could not list AMMs: %w", err)
   267  	}
   268  
   269  	pools, pageInfo = entities.PageEntities[*v2.AMMEdge](pools, pagination)
   270  	return pools, pageInfo, nil
   271  }
   272  
   273  func (p *AMMPools) ListActive(ctx context.Context) ([]entities.AMMPool, error) {
   274  	defer metrics.StartSQLQuery("AMMs", "ListAll")
   275  	var (
   276  		pools []entities.AMMPool
   277  		args  []interface{}
   278  	)
   279  
   280  	query := fmt.Sprintf(`SELECT * from amms WHERE %s`, liveOnlyClause(&args))
   281  	if err := pgxscan.Select(ctx, p.ConnectionSource, &pools, query, args...); err != nil {
   282  		return nil, fmt.Errorf("could not list active AMMs: %w", err)
   283  	}
   284  
   285  	return pools, nil
   286  }
   287  
   288  func liveOnlyClause(args *[]interface{}) string {
   289  	states := strings.Builder{}
   290  	for i, status := range activeStates {
   291  		if i > 0 {
   292  			states.WriteString(",")
   293  		}
   294  		states.WriteString(nextBindVar(args, status))
   295  	}
   296  	return fmt.Sprintf("status IN (%s)", states.String())
   297  }