code.vegaprotocol.io/vega@v0.79.0/datanode/service/market_depth_amm.go (about)

     1  // Copyright (C) 2023 Gobalsky Labs Limited
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    15  
    16  package service
    17  
    18  import (
    19  	"context"
    20  	"errors"
    21  	"time"
    22  
    23  	"code.vegaprotocol.io/vega/core/types"
    24  	"code.vegaprotocol.io/vega/datanode/entities"
    25  	vgcrypto "code.vegaprotocol.io/vega/libs/crypto"
    26  	"code.vegaprotocol.io/vega/libs/num"
    27  	"code.vegaprotocol.io/vega/logging"
    28  )
    29  
    30  var (
    31  	ErrNoAMMVolumeReference = errors.New("cannot find reference price to estimate AMM volume")
    32  	hundred                 = num.DecimalFromInt64(100)
    33  )
    34  
    35  // a version of entities.AMMPool that is less flat.
    36  type ammDefn struct {
    37  	partyID  string
    38  	lower    *curve
    39  	upper    *curve
    40  	position num.Decimal // signed position in Vega-space
    41  }
    42  
    43  type curve struct {
    44  	low       *num.Uint
    45  	high      *num.Uint
    46  	assetLow  *num.Uint
    47  	assetHigh *num.Uint
    48  	sqrtHigh  num.Decimal
    49  	sqrtLow   num.Decimal
    50  	pv        num.Decimal
    51  	l         num.Decimal
    52  	isLower   bool
    53  }
    54  
    55  type level struct {
    56  	price      *num.Uint
    57  	assetPrice *num.Uint
    58  	assetSqrt  num.Decimal
    59  	estimated  bool
    60  }
    61  
    62  func newLevel(price *num.Uint, estimated bool, priceFactor num.Decimal) *level {
    63  	assetPrice, _ := num.UintFromDecimal(price.ToDecimal().Mul(priceFactor))
    64  	return &level{
    65  		price:      price.Clone(),
    66  		assetPrice: assetPrice,
    67  		assetSqrt:  num.UintZero().Sqrt(assetPrice),
    68  		estimated:  estimated,
    69  	}
    70  }
    71  
    72  func (cu *curve) impliedPosition(sqrtPrice, sqrtHigh num.Decimal) num.Decimal {
    73  	// L * (sqrt(high) - sqrt(price))
    74  	numer := sqrtHigh.Sub(sqrtPrice).Mul(cu.l)
    75  
    76  	// sqrt(high) * sqrt(price)
    77  	denom := sqrtHigh.Mul(sqrtPrice)
    78  
    79  	// L * (sqrt(high) - sqrt(price)) / sqrt(high) * sqrt(price)
    80  	res := numer.Div(denom)
    81  
    82  	if cu.isLower {
    83  		return res
    84  	}
    85  
    86  	// if we are in the upper curve the position of 0 in "curve-space" is -cu.pv in Vega position
    87  	// so we need to flip the interval
    88  	return cu.pv.Sub(res).Neg()
    89  }
    90  
    91  func (m *MarketDepth) getActiveAMMs(ctx context.Context) map[string][]entities.AMMPool {
    92  	ammByMarket := map[string][]entities.AMMPool{}
    93  	amms, err := m.ammStore.ListActive(ctx)
    94  	if err != nil {
    95  		m.log.Warn("unable to query AMM's for market-depth",
    96  			logging.Error(err),
    97  		)
    98  	}
    99  
   100  	for _, amm := range amms {
   101  		marketID := string(amm.MarketID)
   102  		if _, ok := ammByMarket[marketID]; !ok {
   103  			ammByMarket[marketID] = []entities.AMMPool{}
   104  		}
   105  
   106  		ammByMarket[marketID] = append(ammByMarket[marketID], amm)
   107  	}
   108  	return ammByMarket
   109  }
   110  
// getCalculationBounds returns the ascending price levels (market precision)
// across which AMM volume is expanded around reference. The result is cached
// in cache.levels under reference.String(); only the most recent reference is
// kept.
//
// The levels form three bands:
//
//	estLow  .. accLow   large steps of eStep, flagged as estimated
//	accLow  .. accHigh  single-tick steps, flagged as accurate
//	accHigh .. estHigh  large steps of eStep, flagged as estimated
//
// accLow/accHigh lie AmmFullExpansionPercentage either side of reference. The
// estimated bands extend up to AmmMaxEstimatedSteps further steps, each
// AmmEstimatedStepPercentage of reference. Everything is clamped to
// cache.lowestBound/highestBound. priceFactor is used to compute each level's
// asset-precision price.
//
// NOTE(review): the cache key ignores lowestBound/highestBound and the config;
// confirm cache.levels is invalidated when the set of AMMs changes.
func (m *MarketDepth) getCalculationBounds(cache *ammCache, reference num.Decimal, priceFactor num.Decimal) []*level {
	// reuse the levels already computed for this reference price
	if levels, ok := cache.levels[reference.String()]; ok {
		return levels
	}

	lowestBound := cache.lowestBound
	highestBound := cache.highestBound

	// first lets calculate the region we will expand accurately, this will be some percentage either side of the reference price
	factor := num.DecimalFromFloat(m.cfg.AmmFullExpansionPercentage).Div(hundred)

	// cap the expansion at 100% so that (1 - factor) below can never go negative, which would not fit in a Uint
	factor = num.MinD(factor, num.DecimalOne())

	referenceU, _ := num.UintFromDecimal(reference)
	accHigh, _ := num.UintFromDecimal(reference.Mul(num.DecimalOne().Add(factor)))
	accLow, _ := num.UintFromDecimal(reference.Mul(num.DecimalOne().Sub(factor)))

	// always want some volume so if for some reason the bounds were set too low so we calculated a sub-tick expansion make it at least one
	// NOTE(review): this Sub would underflow if referenceU were 0 (reference < 1); confirm callers never pass such a reference
	if accHigh.EQ(referenceU) {
		accHigh.Add(referenceU, num.UintOne())
		accLow.Sub(referenceU, num.UintOne())
	}

	// this is the percentage of the reference price to take in estimated steps
	stepFactor := num.DecimalFromFloat(m.cfg.AmmEstimatedStepPercentage).Div(hundred)

	// this is how many of those steps to take
	maxEstimatedSteps := m.cfg.AmmMaxEstimatedSteps

	// and so this is the size of the estimated step (truncated, so it can be zero for small references)
	eStep, _ := num.UintFromDecimal(reference.Mul(stepFactor))

	// estLow is clamped at zero rather than allowed to underflow
	eRange := num.UintZero().Mul(eStep, num.NewUint(maxEstimatedSteps))
	estLow := num.UintZero().Sub(accLow, num.Min(accLow, eRange))
	estHigh := num.UintZero().Add(accHigh, eRange)

	// cap steps to the lowest/highest boundaries of all AMMs
	lowD, _ := num.UintFromDecimal(lowestBound)
	if accLow.LTE(lowD) {
		accLow = lowD.Clone()
		estLow = lowD.Clone()
	}

	highD, _ := num.UintFromDecimal(highestBound)
	if accHigh.GTE(highD) {
		accHigh = highD.Clone()
		estHigh = highD.Clone()
	}

	// the estimated band overshoots lowD, so pull it back in: take the largest n such that
	// accLow - (n * eStep) >= lowD, keeping estLow on the accLow - n*eStep grid and inside the AMM bounds.
	// Here accLow > lowD > estLow, so eRange (and therefore eStep) is non-zero and the Div is safe.
	if estLow.LT(lowD) {
		delta, _ := num.UintZero().Delta(accLow, lowD)
		delta.Div(delta, eStep)
		estLow = num.UintZero().Sub(accLow, delta.Mul(delta, eStep))
	}

	// likewise take the largest n such that accHigh + (n * eStep) <= highD (eStep is non-zero here too)
	if estHigh.GT(highD) {
		delta, _ := num.UintZero().Delta(accHigh, highD)
		delta.Div(delta, eStep)
		estHigh = num.UintZero().Add(accHigh, delta.Mul(delta, eStep))
	}

	levels := []*level{}

	// we now have our four prices [estLow, accLow, accHigh, estHigh] where from
	// estLow -> accLow   : we will take big price steps
	// accLow -> accHigh  : we will take price steps of one-tick
	// accHigh -> estHigh : we will take big price steps
	price := estLow.Clone()

	// larger steps from estLow -> accLow
	for price.LT(accLow) {
		levels = append(levels, newLevel(price, true, priceFactor))
		price = num.UintZero().Add(price, eStep)
	}

	// now smaller steps from accLow -> accHigh
	for price.LTE(accHigh) {
		levels = append(levels, newLevel(price, false, priceFactor))
		price = num.UintZero().Add(price, num.UintOne())
	}

	// now back to large steps for accHigh -> estHigh
	for price.LTE(estHigh) {
		levels = append(levels, newLevel(price, true, priceFactor))
		price = num.UintZero().Add(price, eStep)
	}

	// replace (rather than extend) the cache so only the latest reference's levels are kept
	cache.levels = map[string][]*level{
		reference.String(): levels,
	}

	return levels
}
   208  
   209  func (m *MarketDepth) getReference(ctx context.Context, marketID string) (num.Decimal, error) {
   210  	marketData, err := m.marketData.GetMarketDataByID(ctx, marketID)
   211  	if err != nil {
   212  		m.log.Warn("unable to get market-data for market",
   213  			logging.String("market-id", marketID),
   214  			logging.Error(err),
   215  		)
   216  		return num.DecimalZero(), err
   217  	}
   218  
   219  	reference := marketData.MidPrice
   220  	if !marketData.IndicativePrice.IsZero() {
   221  		reference = marketData.IndicativePrice
   222  	}
   223  
   224  	if reference.IsZero() {
   225  		m.log.Warn("cannot calculate market-depth for AMM, no reference point available",
   226  			logging.String("mid-price", marketData.MidPrice.String()),
   227  			logging.String("indicative-price", marketData.IndicativePrice.String()),
   228  		)
   229  		return num.DecimalZero(), ErrNoAMMVolumeReference
   230  	}
   231  
   232  	return reference, nil
   233  }
   234  
// expandByLevels converts pool into synthetic limit orders by walking the
// intervals between consecutive entries of levels. Prices are in market
// precision and assumed ascending, as getCalculationBounds produces them.
//
// For each interval [level1, level2]:
//   - intervals entirely below the pool's lower bound are skipped, and the walk
//     stops at the first interval starting at or above the pool's upper bound;
//   - level1/level2 are snapped to the pool's lower/upper bounds when outside them;
//   - the volume is the difference between the pool's implied positions at
//     level1 and level2 on the curve that level1 falls in;
//   - if the implied position at level2 is below the pool's current position the
//     order is a sell at level2's price, otherwise a buy at level1's price.
//
// When an interval straddles the pool's current position, the part of the
// volume above it is added to the previously emitted order. The part below is
// carried forward and added to the next order emitted. Intervals whose volume
// truncates to zero are merged into the following interval (level1 is not
// advanced).
//
// It returns the orders, a parallel slice flagging orders that touch an
// estimated level, and any error from loading the pool's position. An empty
// levels slice yields no orders.
//
// NOTE(review): the split volume is dropped if there is no previous order, or
// if no later interval emits an order; confirm this is acceptable.
func (m *MarketDepth) expandByLevels(pool entities.AMMPool, levels []*level, priceFactor num.Decimal) ([]*types.Order, []bool, error) {
	if len(levels) == 0 {
		return nil, nil, nil
	}

	// get the AMM party's current open volume in this market
	pos, err := m.getAMMPosition(pool.MarketID.String(), pool.AmmPartyID.String())
	if err != nil {
		return nil, nil, err
	}

	ammDefn := definitionFromEntity(pool, pos, priceFactor)

	estimated := []bool{}
	orders := []*types.Order{}

	// level1 is carried across iterations rather than read from levels[i]: it only moves forward when an
	// interval is skipped or consumed, so intervals that yield no volume are merged into the next one
	level1 := levels[0]
	// volume carried forward after stepping over the pool's position, added to the next order emitted
	extraVolume := int64(0)
	for i := range levels {
		// the last level has no upper neighbour to form an interval with
		if i == len(levels)-1 {
			break
		}

		level2 := levels[i+1]

		// check if the interval is fully outside of the AMM range
		if ammDefn.lower.low.GTE(level2.price) {
			level1 = level2
			continue
		}
		if ammDefn.upper.high.LTE(level1.price) {
			break
		}

		// snap to AMM boundaries
		if level1.price.LT(ammDefn.lower.low) {
			level1 = &level{
				price:      ammDefn.lower.low,
				assetPrice: ammDefn.lower.assetLow,
				assetSqrt:  ammDefn.lower.sqrtLow,
				estimated:  level1.estimated,
			}
		}

		if level2.price.GT(ammDefn.upper.high) {
			level2 = &level{
				price:      ammDefn.upper.high,
				assetPrice: ammDefn.upper.assetHigh,
				assetSqrt:  ammDefn.upper.sqrtHigh,
				estimated:  level2.estimated,
			}
		}

		// pick which curve we are in, based on where the interval starts
		cu := ammDefn.lower
		if level1.price.GTE(ammDefn.lower.high) {
			cu = ammDefn.upper
		}

		// implied positions at both ends of the interval
		v1 := cu.impliedPosition(level1.assetSqrt, cu.sqrtHigh)
		v2 := cu.impliedPosition(level2.assetSqrt, cu.sqrtHigh)

		retPrice := level1.price
		side := types.SideBuy

		// at level2 the pool would hold less than its current position, so this volume is offered as a sell at level2
		if v2.LessThan(ammDefn.position) {
			side = types.SideSell
			retPrice = level2.price

			// if we've stepped over the pool's position we need to split the volume and add it to the outer levels
			if v1.GreaterThan(ammDefn.position) {
				volume := v1.Sub(ammDefn.position).Abs().IntPart()

				// we want to add the volume to the previous order, because thats the price in marketDP when rounded away
				// from the fair-price
				if len(orders) != 0 {
					o := orders[len(orders)-1]
					o.Size += uint64(volume)
					o.Remaining += uint64(volume)
				}

				// we need to add this volume to the price level we step to next
				extraVolume = ammDefn.position.Sub(v2).Abs().IntPart()
				level1 = level2
				continue
			}
		}
		// calculate the volume, truncated to a whole number of units
		volume := v1.Sub(v2).Abs().IntPart()

		// if the volume is zero the AMM must be sparse here, so we keep widening the interval until we have at least 1 volume
		// so we'll continue and not shuffle along level1
		if volume == 0 {
			continue
		}

		// this is extra volume from when we stepped over the AMM's fair-price
		if extraVolume != 0 {
			volume += extraVolume
			extraVolume = 0
		}

		orders = append(
			orders,
			m.makeOrder(retPrice, ammDefn.partyID, uint64(volume), side),
		)
		estimated = append(estimated, level1.estimated || level2.estimated)

		// shuffle
		level1 = level2
	}
	return orders, estimated, nil
}
   350  
   351  func (m *MarketDepth) InitialiseAMMs(ctx context.Context) {
   352  	active := m.getActiveAMMs(ctx)
   353  	if len(active) == 0 {
   354  		return
   355  	}
   356  
   357  	// expand all these AMM's from the midpoint
   358  	for marketID, amms := range active {
   359  		md := m.getDepth(marketID)
   360  
   361  		cache, err := m.getAMMCache(marketID)
   362  		if err != nil {
   363  			m.log.Panic("unable to expand AMM's for market",
   364  				logging.Error(err),
   365  				logging.String("market-id", marketID),
   366  			)
   367  		}
   368  
   369  		priceFactor := cache.priceFactor
   370  
   371  		// add it to our active list, we want to do this even if we fail to get a reference
   372  		for _, a := range amms {
   373  			cache.addAMM(a)
   374  		}
   375  
   376  		reference, err := m.getReference(ctx, marketID)
   377  		if err != nil {
   378  			continue
   379  		}
   380  
   381  		levels := m.getCalculationBounds(cache, reference, priceFactor)
   382  
   383  		for _, amm := range amms {
   384  			orders, estimated, err := m.expandByLevels(amm, levels, priceFactor)
   385  			if err != nil {
   386  				continue
   387  			}
   388  
   389  			if len(orders) == 0 {
   390  				continue
   391  			}
   392  
   393  			// save them in the cache
   394  			cache.ammOrders[amm.AmmPartyID.String()] = orders
   395  
   396  			for i := range orders {
   397  				md.AddAMMOrder(orders[i], estimated[i])
   398  				if estimated[i] {
   399  					cache.estimatedOrder[orders[i].ID] = struct{}{}
   400  				}
   401  			}
   402  		}
   403  	}
   404  }
   405  
   406  func (m *MarketDepth) ExpandAMM(ctx context.Context, pool entities.AMMPool, priceFactor num.Decimal) ([]*types.Order, []bool, error) {
   407  	reference, err := m.getReference(ctx, pool.MarketID.String())
   408  	if err == ErrNoAMMVolumeReference {
   409  		// if we can't get a reference to expand from then the market must be fresh and we will just use the pool's base
   410  		reference = pool.ParametersBase
   411  	} else if err != nil {
   412  		return nil, nil, err
   413  	}
   414  
   415  	cache, err := m.getAMMCache(string(pool.MarketID))
   416  	if err != nil {
   417  		return nil, nil, err
   418  	}
   419  
   420  	levels := m.getCalculationBounds(cache, reference, priceFactor)
   421  
   422  	return m.expandByLevels(pool, levels, priceFactor)
   423  }
   424  
   425  func (m *MarketDepth) makeOrder(price *num.Uint, partyID string, volume uint64, side types.Side) *types.Order {
   426  	return &types.Order{
   427  		ID:               vgcrypto.RandomHash(),
   428  		Party:            partyID,
   429  		Price:            price,
   430  		Status:           entities.OrderStatusActive,
   431  		Type:             entities.OrderTypeLimit,
   432  		TimeInForce:      entities.OrderTimeInForceGTC,
   433  		Size:             volume,
   434  		Remaining:        volume,
   435  		GeneratedOffbook: true,
   436  		Side:             side,
   437  	}
   438  }
   439  
   440  // refreshAMM is used when an AMM has either traded or its definition has changed.
   441  func (m *MarketDepth) refreshAMM(pool entities.AMMPool, depth *entities.MarketDepth) {
   442  	marketID := pool.MarketID.String()
   443  	ammParty := pool.AmmPartyID.String()
   444  
   445  	// get all the AMM details from the cache
   446  	cache, err := m.getAMMCache(marketID)
   447  	if err != nil {
   448  		m.log.Warn("unable to refresh AMM expansion",
   449  			logging.Error(err),
   450  			logging.String("market-id", marketID),
   451  		)
   452  	}
   453  
   454  	// remove any expanded orders the AMM already has in the depth
   455  	existing := cache.ammOrders[ammParty]
   456  	for _, o := range existing {
   457  		o.Status = entities.OrderStatusCancelled
   458  
   459  		_, estimated := cache.estimatedOrder[o.ID]
   460  		delete(cache.estimatedOrder, o.ID)
   461  
   462  		depth.AddOrderUpdate(o, estimated)
   463  	}
   464  
   465  	if pool.Status == entities.AMMStatusCancelled || pool.Status == entities.AMMStatusStopped {
   466  		cache.removeAMM(ammParty)
   467  		return
   468  	}
   469  
   470  	cache.addAMM(pool)
   471  
   472  	// expand it again into new orders and push them into the market depth
   473  	orders, estimated, _ := m.ExpandAMM(context.Background(), pool, cache.priceFactor)
   474  	for i := range orders {
   475  		depth.AddOrderUpdate(orders[i], estimated[i])
   476  		if estimated[i] {
   477  			cache.estimatedOrder[orders[i].ID] = struct{}{}
   478  		}
   479  	}
   480  	cache.ammOrders[ammParty] = orders
   481  }
   482  
   483  // refreshAMM is used when an AMM has either traded or its definition has changed.
   484  func (m *MarketDepth) OnAMMUpdate(pool entities.AMMPool, vegaTime time.Time, seqNum uint64) {
   485  	m.mu.Lock()
   486  	defer m.mu.Unlock()
   487  
   488  	if !m.sequential(vegaTime, seqNum) {
   489  		return
   490  	}
   491  
   492  	depth := m.getDepth(pool.MarketID.String())
   493  	depth.SequenceNumber = m.sequenceNumber
   494  
   495  	m.refreshAMM(pool, depth)
   496  }
   497  
   498  func (m *MarketDepth) onAMMTraded(ammParty, marketID string) {
   499  	cache, err := m.getAMMCache(marketID)
   500  	if err != nil {
   501  		m.log.Warn("unable to refresh AMM expansion",
   502  			logging.Error(err),
   503  			logging.String("market-id", marketID),
   504  		)
   505  	}
   506  
   507  	pool, ok := cache.activeAMMs[ammParty]
   508  	if !ok {
   509  		m.log.Panic("market-depth out of sync -- received trade event for AMM that doesn't exist")
   510  	}
   511  
   512  	depth := m.getDepth(pool.MarketID.String())
   513  	depth.SequenceNumber = m.sequenceNumber
   514  	m.refreshAMM(pool, depth)
   515  }
   516  
   517  func (m *MarketDepth) isAMMOrder(order *types.Order) bool {
   518  	c, ok := m.ammCache[order.MarketID]
   519  	if !ok {
   520  		return false
   521  	}
   522  
   523  	_, ok = c.activeAMMs[order.Party]
   524  	return ok
   525  }
   526  
   527  func (m *MarketDepth) getAMMCache(marketID string) (*ammCache, error) {
   528  	if cache, ok := m.ammCache[marketID]; ok {
   529  		return cache, nil
   530  	}
   531  
   532  	// first time we've seen this market lets get the price factor
   533  	market, err := m.markets.GetByID(context.Background(), marketID)
   534  	if err != nil {
   535  		return nil, err
   536  	}
   537  
   538  	assetID, err := market.ToProto().GetAsset()
   539  	if err != nil {
   540  		return nil, err
   541  	}
   542  
   543  	asset, err := m.assetStore.GetByID(context.Background(), assetID)
   544  	if err != nil {
   545  		return nil, err
   546  	}
   547  
   548  	priceFactor := num.DecimalOne()
   549  	if exp := asset.Decimals - market.DecimalPlaces; exp != 0 {
   550  		priceFactor = num.DecimalFromInt64(10).Pow(num.DecimalFromInt64(int64(exp)))
   551  	}
   552  
   553  	cache := &ammCache{
   554  		priceFactor:    priceFactor,
   555  		ammOrders:      map[string][]*types.Order{},
   556  		activeAMMs:     map[string]entities.AMMPool{},
   557  		estimatedOrder: map[string]struct{}{},
   558  		levels:         map[string][]*level{},
   559  	}
   560  	m.ammCache[marketID] = cache
   561  
   562  	return cache, nil
   563  }
   564  
   565  func (m *MarketDepth) getAMMPosition(marketID, partyID string) (int64, error) {
   566  	p, err := m.positions.GetByMarketAndParty(context.Background(), marketID, partyID)
   567  	if err == nil {
   568  		return p.OpenVolume, nil
   569  	}
   570  
   571  	if err == entities.ErrNotFound {
   572  		return 0, nil
   573  	}
   574  
   575  	return 0, err
   576  }
   577  
   578  func definitionFromEntity(ent entities.AMMPool, position int64, priceFactor num.Decimal) *ammDefn {
   579  	base, _ := num.UintFromDecimal(ent.ParametersBase)
   580  	low := base.Clone()
   581  	high := base.Clone()
   582  
   583  	if ent.ParametersLowerBound != nil {
   584  		low, _ = num.UintFromDecimal(*ent.ParametersLowerBound)
   585  	}
   586  
   587  	if ent.ParametersUpperBound != nil {
   588  		high, _ = num.UintFromDecimal(*ent.ParametersUpperBound)
   589  	}
   590  
   591  	assetHigh, _ := num.UintFromDecimal(high.ToDecimal().Mul(priceFactor))
   592  	assetBase, _ := num.UintFromDecimal(base.ToDecimal().Mul(priceFactor))
   593  	assetLow, _ := num.UintFromDecimal(low.ToDecimal().Mul(priceFactor))
   594  
   595  	return &ammDefn{
   596  		position: num.DecimalFromInt64(position),
   597  		lower: &curve{
   598  			low:       low,
   599  			high:      base,
   600  			assetLow:  assetLow,
   601  			assetHigh: assetBase,
   602  			sqrtLow:   num.UintOne().Sqrt(assetLow),
   603  			sqrtHigh:  num.UintOne().Sqrt(assetBase),
   604  			isLower:   true,
   605  			l:         ent.LowerVirtualLiquidity,
   606  			pv:        ent.LowerTheoreticalPosition,
   607  		},
   608  		upper: &curve{
   609  			low:       base,
   610  			high:      high,
   611  			assetLow:  assetBase,
   612  			assetHigh: assetHigh,
   613  			sqrtLow:   num.UintOne().Sqrt(assetBase),
   614  			sqrtHigh:  num.UintOne().Sqrt(assetHigh),
   615  			l:         ent.UpperVirtualLiquidity,
   616  			pv:        ent.UpperTheoreticalPosition,
   617  		},
   618  	}
   619  }