code.vegaprotocol.io/vega@v0.79.0/core/execution/amm/pool.go (about)

     1  // Copyright (C) 2023 Gobalsky Labs Limited
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    15  
    16  package amm
    17  
    18  import (
    19  	"errors"
    20  	"fmt"
    21  
    22  	"code.vegaprotocol.io/vega/core/idgeneration"
    23  	"code.vegaprotocol.io/vega/core/types"
    24  	"code.vegaprotocol.io/vega/libs/num"
    25  	"code.vegaprotocol.io/vega/libs/ptr"
    26  	"code.vegaprotocol.io/vega/logging"
    27  	snapshotpb "code.vegaprotocol.io/vega/protos/vega/snapshot/v1"
    28  )
    29  
// ephemeralPosition keeps track of the pool's position as if its generated orders had traded.
// Only the signed position size is tracked.
type ephemeralPosition struct {
	// size is the pool's signed position (long > 0, short < 0) including the effect of any
	// orders generated so far in the current matching process.
	size int64
}
    34  
// curve describes one side of the AMM's liquidity. The lower curve runs from the lower bound up to the
// base price, and the upper curve runs from the base price up to the upper bound.
type curve struct {
	l       num.Decimal // virtual liquidity
	high    *num.Uint   // high price value, upper bound if upper curve, base price if lower curve
	low     *num.Uint   // low price value, base price if upper curve, lower bound if lower curve
	empty   bool        // if true the curve is of zero length and represents no liquidity on this side of the amm
	isLower bool        // whether the curve is for the lower curve or the upper curve

	// the theoretical position of the curve at its lower boundary (curve-position is 0 at the upper boundary)
	// note that this equals Vega's position at the boundary only in the lower curve, since Vega position == curve-position
	// in the upper curve Vega's position == 0 => position of `pv` in curve-position, Vega's position -pv => 0 in curve-position
	pv num.Decimal

	// cached values for non-empty curves: l / sqrt(high), and sqrt(high)
	lDivSqrtPu num.Decimal
	sqrtHigh   num.Decimal
}
    50  
    51  // positionAtPrice returns the position of the AMM if its fair-price were the given price. This
    52  // will be signed for long/short as usual.
    53  func (c *curve) positionAtPrice(sqrt sqrtFn, price *num.Uint) int64 {
    54  	pos := impliedPosition(sqrt(price), c.sqrtHigh, c.l)
    55  	if c.isLower {
    56  		return pos.IntPart()
    57  	}
    58  
    59  	// if we are in the upper curve the position of 0 in "curve-space" is -cu.pv in Vega position
    60  	// so we need to flip the interval
    61  	return -c.pv.Sub(pos).IntPart()
    62  }
    63  
    64  // singleVolumePrice returns the price that is 1 volume away from the given price in the given direction.
    65  // If the AMM's commitment is low this may be more than one-tick away from `p`.
    66  func (c *curve) singleVolumePrice(sqrt sqrtFn, p *num.Uint, side types.Side) *num.Uint {
    67  	if c.empty {
    68  		panic("should not be calculating single-volume step on empty curve")
    69  	}
    70  
    71  	// for best buy:  (L * sqrt(pu) / (L + sqrt(pu)))^2
    72  	// for best sell: (L * sqrt(pu) / (L - sqrt(pu)))^2
    73  	var denom num.Decimal
    74  	if side == types.SideBuy {
    75  		denom = c.l.Add(sqrt(p))
    76  	} else {
    77  		denom = c.l.Sub(sqrt(p))
    78  	}
    79  
    80  	np := c.l.Mul(sqrt(p)).Div(denom)
    81  	np = np.Mul(np)
    82  
    83  	if side == types.SideSell {
    84  		// have to make sure we round away `p`
    85  		np = np.Ceil()
    86  	}
    87  
    88  	adj, _ := num.UintFromDecimal(np)
    89  	return adj
    90  }
    91  
    92  // singleVolumeDelta returns the price interval between p and the price that represents 1 volume movement.
    93  func (c *curve) singleVolumeDelta(sqrt sqrtFn, p *num.Uint, side types.Side) *num.Uint {
    94  	adj := c.singleVolumePrice(sqrt, p, side)
    95  	delta, _ := num.UintZero().Delta(p, adj)
    96  	return delta
    97  }
    98  
    99  // check will return an error is the curve contains too many price-levels where there is 0 volume.
   100  func (c *curve) check(sqrt sqrtFn, oneTick *num.Uint, allowedEmptyLevels uint64) error {
   101  	if c.empty {
   102  		return nil
   103  	}
   104  
   105  	if c.pv.LessThan(num.DecimalOne()) {
   106  		return ErrCommitmentTooLow
   107  	}
   108  
   109  	// curve is valid if
   110  	// n * oneTick > pu - (L * sqrt(pu) / (L + sqrt(pu)))^2
   111  	adj := c.singleVolumePrice(sqrt, c.high, types.SideBuy)
   112  	delta := num.UintZero().Sub(c.high, adj)
   113  
   114  	// the plus one is because if allowable empty levels is 0, then the biggest delta allowed is 1
   115  	maxDelta := num.UintZero().Mul(oneTick, num.NewUint(allowedEmptyLevels+1))
   116  
   117  	// now this price delta must be less that the given maximum
   118  	if delta.GT(maxDelta) {
   119  		return ErrCommitmentTooLow
   120  	}
   121  	return nil
   122  }
   123  
// Pool is an AMM whose liquidity is described by two curves (lower and upper) joined at the base price.
type Pool struct {
	log *logging.Logger
	ID  string
	// AMMParty is the party whose position and general/margin accounts belong to the pool.
	AMMParty    string
	Commitment  *num.Uint
	ProposedFee num.Decimal
	Parameters  *types.ConcentratedLiquidityParameters

	asset  string
	market string
	// owner is the party that submitted the AMM.
	owner      string
	collateral Collateral
	position   Position
	// priceFactor converts prices (base and bounds) into asset precision; positionFactor scales the
	// curves' theoretical positions.
	priceFactor               num.Decimal
	positionFactor            num.Decimal
	SlippageTolerance         num.Decimal
	MinimumPriceChangeTrigger num.Decimal

	// current pool status
	status types.AMMPoolStatus

	// sqrt function to use.
	sqrt sqrtFn

	// the two curves joined at base-price used to determine price and volume in the pool
	// lower is used when the pool is long.
	lower *curve
	upper *curve

	// during the matching process across price levels we need to keep track of the pool's potential positions
	// as if those matching orders were to trade. This is so that when we generate more orders at the next price level
	// for the same incoming order, the second round of generated orders are priced as if the first round had traded.
	eph *ephemeralPosition

	maxCalculationLevels *num.Uint // maximum number of price levels the AMM will be expanded into
	oneTick              *num.Uint // one price tick, at least 1

	// cache of derived values (e.g. fair-price by position)
	cache *poolCache
}
   163  
   164  func NewPool(
   165  	log *logging.Logger,
   166  	id,
   167  	ammParty,
   168  	asset string,
   169  	submit *types.SubmitAMM,
   170  	sqrt sqrtFn,
   171  	collateral Collateral,
   172  	position Position,
   173  	rf *types.RiskFactor,
   174  	sf *types.ScalingFactors,
   175  	linearSlippage num.Decimal,
   176  	priceFactor num.Decimal,
   177  	positionFactor num.Decimal,
   178  	maxCalculationLevels *num.Uint,
   179  	allowedEmptyAMMLevels uint64,
   180  	slippageTolerance num.Decimal,
   181  	minimumPriceChangeTrigger num.Decimal,
   182  ) (*Pool, error) {
   183  	oneTick, _ := num.UintFromDecimal(priceFactor)
   184  	pool := &Pool{
   185  		log:                       log,
   186  		ID:                        id,
   187  		AMMParty:                  ammParty,
   188  		Commitment:                submit.CommitmentAmount,
   189  		ProposedFee:               submit.ProposedFee,
   190  		Parameters:                submit.Parameters,
   191  		market:                    submit.MarketID,
   192  		owner:                     submit.Party,
   193  		asset:                     asset,
   194  		sqrt:                      sqrt,
   195  		collateral:                collateral,
   196  		position:                  position,
   197  		priceFactor:               priceFactor,
   198  		positionFactor:            positionFactor,
   199  		oneTick:                   num.Max(num.UintOne(), oneTick),
   200  		status:                    types.AMMPoolStatusActive,
   201  		maxCalculationLevels:      maxCalculationLevels,
   202  		cache:                     NewPoolCache(),
   203  		SlippageTolerance:         slippageTolerance,
   204  		MinimumPriceChangeTrigger: minimumPriceChangeTrigger,
   205  	}
   206  
   207  	if submit.Parameters.DataSourceID != nil {
   208  		pool.status = types.AMMPoolStatusPending
   209  		pool.lower = emptyCurve(num.UintZero(), true)
   210  		pool.upper = emptyCurve(num.UintZero(), false)
   211  		return pool, nil
   212  	}
   213  
   214  	err := pool.setCurves(rf, sf, linearSlippage, allowedEmptyAMMLevels)
   215  	if err != nil {
   216  		return nil, err
   217  	}
   218  	return pool, nil
   219  }
   220  
   221  func NewPoolFromProto(
   222  	log *logging.Logger,
   223  	sqrt sqrtFn,
   224  	collateral Collateral,
   225  	position Position,
   226  	state *snapshotpb.PoolMapEntry_Pool,
   227  	party string,
   228  	priceFactor num.Decimal,
   229  	positionFactor num.Decimal,
   230  ) (*Pool, error) {
   231  	oneTick, _ := num.UintFromDecimal(priceFactor)
   232  
   233  	var lowerLeverage, upperLeverage *num.Decimal
   234  	if state.Parameters.LeverageAtLowerBound != nil {
   235  		l, err := num.DecimalFromString(*state.Parameters.LeverageAtLowerBound)
   236  		if err != nil {
   237  			return nil, err
   238  		}
   239  		lowerLeverage = &l
   240  	}
   241  	if state.Parameters.LeverageAtUpperBound != nil {
   242  		l, err := num.DecimalFromString(*state.Parameters.LeverageAtUpperBound)
   243  		if err != nil {
   244  			return nil, err
   245  		}
   246  		upperLeverage = &l
   247  	}
   248  
   249  	base, overflow := num.UintFromString(state.Parameters.Base, 10)
   250  	if overflow {
   251  		return nil, fmt.Errorf("failed to convert string to Uint: %s", state.Parameters.Base)
   252  	}
   253  
   254  	var lower, upper *num.Uint
   255  	if state.Parameters.LowerBound != nil {
   256  		lower, overflow = num.UintFromString(*state.Parameters.LowerBound, 10)
   257  		if overflow {
   258  			return nil, fmt.Errorf("failed to convert string to Uint: %s", *state.Parameters.LowerBound)
   259  		}
   260  	}
   261  
   262  	if state.Parameters.UpperBound != nil {
   263  		upper, overflow = num.UintFromString(*state.Parameters.UpperBound, 10)
   264  		if overflow {
   265  			return nil, fmt.Errorf("failed to convert string to Uint: %s", *state.Parameters.UpperBound)
   266  		}
   267  	}
   268  
   269  	upperCu, err := NewCurveFromProto(state.Upper)
   270  	if err != nil {
   271  		return nil, err
   272  	}
   273  
   274  	lowerCu, err := NewCurveFromProto(state.Lower)
   275  	lowerCu.isLower = true
   276  	if err != nil {
   277  		return nil, err
   278  	}
   279  
   280  	proposedFee, err := num.DecimalFromString(state.ProposedFee)
   281  	if err != nil {
   282  		return nil, err
   283  	}
   284  
   285  	var slippageTolerance num.Decimal
   286  	if state.SlippageTolerance != "" {
   287  		slippageTolerance, err = num.DecimalFromString(state.SlippageTolerance)
   288  		if err != nil {
   289  			return nil, err
   290  		}
   291  	}
   292  
   293  	minimumPriceChangeTrigger := num.DecimalZero()
   294  	if state.MinimumPriceChangeTrigger != "" {
   295  		minimumPriceChangeTrigger, err = num.DecimalFromString(state.MinimumPriceChangeTrigger)
   296  		if err != nil {
   297  			return nil, err
   298  		}
   299  	}
   300  
   301  	return &Pool{
   302  		log:         log,
   303  		ID:          state.Id,
   304  		AMMParty:    state.AmmPartyId,
   305  		Commitment:  num.MustUintFromString(state.Commitment, 10),
   306  		ProposedFee: proposedFee,
   307  		Parameters: &types.ConcentratedLiquidityParameters{
   308  			Base:                 base,
   309  			LowerBound:           lower,
   310  			UpperBound:           upper,
   311  			LeverageAtLowerBound: lowerLeverage,
   312  			LeverageAtUpperBound: upperLeverage,
   313  			DataSourceID:         state.Parameters.DataSourceId,
   314  		},
   315  		owner:                     party,
   316  		market:                    state.Market,
   317  		asset:                     state.Asset,
   318  		sqrt:                      sqrt,
   319  		collateral:                collateral,
   320  		position:                  position,
   321  		lower:                     lowerCu,
   322  		upper:                     upperCu,
   323  		priceFactor:               priceFactor,
   324  		positionFactor:            positionFactor,
   325  		oneTick:                   num.Max(num.UintOne(), oneTick),
   326  		status:                    state.Status,
   327  		cache:                     NewPoolCache(),
   328  		SlippageTolerance:         slippageTolerance,
   329  		MinimumPriceChangeTrigger: minimumPriceChangeTrigger,
   330  	}, nil
   331  }
   332  
   333  func NewCurveFromProto(c *snapshotpb.PoolMapEntry_Curve) (*curve, error) {
   334  	l, err := num.DecimalFromString(c.L)
   335  	if err != nil {
   336  		return nil, err
   337  	}
   338  
   339  	pv, err := num.DecimalFromString(c.Pv)
   340  	if err != nil {
   341  		return nil, err
   342  	}
   343  
   344  	high, overflow := num.UintFromString(c.High, 10)
   345  	if overflow {
   346  		return nil, fmt.Errorf("failed to convert string to Uint: %s", c.High)
   347  	}
   348  
   349  	low, overflow := num.UintFromString(c.Low, 10)
   350  	if overflow {
   351  		return nil, fmt.Errorf("failed to convert string to Uint: %s", c.Low)
   352  	}
   353  
   354  	var sqrtHigh, lDivSqrtPu num.Decimal
   355  	if !c.Empty {
   356  		sqrtHigh = num.UintOne().Sqrt(high)
   357  		lDivSqrtPu = l.Div(sqrtHigh)
   358  	}
   359  
   360  	return &curve{
   361  		l:          l,
   362  		high:       high,
   363  		low:        low,
   364  		empty:      c.Empty,
   365  		pv:         pv,
   366  		sqrtHigh:   sqrtHigh,
   367  		lDivSqrtPu: lDivSqrtPu,
   368  	}, nil
   369  }
   370  
   371  func (p *Pool) IntoProto() *snapshotpb.PoolMapEntry_Pool {
   372  	return &snapshotpb.PoolMapEntry_Pool{
   373  		Id:          p.ID,
   374  		AmmPartyId:  p.AMMParty,
   375  		Commitment:  p.Commitment.String(),
   376  		ProposedFee: p.ProposedFee.String(),
   377  		Parameters:  p.Parameters.ToProtoEvent(),
   378  		Market:      p.market,
   379  		Asset:       p.asset,
   380  		Lower: &snapshotpb.PoolMapEntry_Curve{
   381  			L:     p.lower.l.String(),
   382  			High:  p.lower.high.String(),
   383  			Low:   p.lower.low.String(),
   384  			Empty: p.lower.empty,
   385  			Pv:    p.lower.pv.String(),
   386  		},
   387  		Upper: &snapshotpb.PoolMapEntry_Curve{
   388  			L:     p.upper.l.String(),
   389  			High:  p.upper.high.String(),
   390  			Low:   p.upper.low.String(),
   391  			Empty: p.upper.empty,
   392  			Pv:    p.upper.pv.String(),
   393  		},
   394  		Status:                    p.status,
   395  		SlippageTolerance:         p.SlippageTolerance.String(),
   396  		MinimumPriceChangeTrigger: p.MinimumPriceChangeTrigger.String(),
   397  	}
   398  }
   399  
   400  // checkPosition will return false if its position exists outside of the curve boundaries and so the AMM
   401  // is invalid.
   402  func (p *Pool) checkPosition() bool {
   403  	pos := p.getPosition()
   404  
   405  	if pos > p.lower.pv.IntPart() {
   406  		return false
   407  	}
   408  
   409  	if -pos > p.upper.pv.IntPart() {
   410  		return false
   411  	}
   412  
   413  	return true
   414  }
   415  
// Update returns a new pool built from the receiver's identity, accounts and configuration, with the
// commitment, proposed fee, parameters, slippage tolerance and minimum price change trigger taken from
// `amend`, and its curves regenerated.
//
// An error is returned if:
//   - the new parameters remove the lower bound while the AMM is long,
//   - they remove the upper bound while it is short,
//   - the regenerated curves fail validation, or
//   - the AMM's current position falls outside the amended curves.
//
// If the base-price data source changes and either the current base price is not strictly inside the
// new bounds or the pool is already pending, the returned pool is pending with empty curves.
func (p *Pool) Update(
	amend *types.AmendAMM,
	rf *types.RiskFactor,
	sf *types.ScalingFactors,
	linearSlippage num.Decimal,
	allowedEmptyAMMLevels uint64,
) (*Pool, error) {
	// keep the existing commitment unless the amendment supplies one
	commitment := p.Commitment.Clone()
	if amend.CommitmentAmount != nil {
		commitment = amend.CommitmentAmount
	}

	// a non-positive proposed fee in the amendment keeps the current fee
	proposedFee := p.ProposedFee
	if amend.ProposedFee.IsPositive() {
		proposedFee = amend.ProposedFee
	}

	// parameters can only be updated all at once or not at all
	// NOTE(review): when amend.Parameters is set it is used directly rather than cloned, so the Base
	// overwrite further down also modifies the caller's amendment.
	parameters := p.Parameters.Clone()
	if amend.Parameters != nil {
		parameters = amend.Parameters
	}

	// if an AMM is amended so that it cannot be long (i.e it has no lower curve) but the existing AMM
	// is already long then we cannot make the change since its fair-price will be undefined.
	if parameters.LowerBound == nil && p.getPosition() > 0 {
		return nil, errors.New("cannot remove lower bound when AMM is long")
	}

	// likewise, an AMM that is short needs an upper curve
	if parameters.UpperBound == nil && p.getPosition() < 0 {
		return nil, errors.New("cannot remove upper bound when AMM is short")
	}

	// slippage tolerance and minimum price change trigger are always taken from the amendment
	updated := &Pool{
		log:                       p.log,
		ID:                        p.ID,
		AMMParty:                  p.AMMParty,
		Commitment:                commitment,
		ProposedFee:               proposedFee,
		Parameters:                parameters,
		asset:                     p.asset,
		market:                    p.market,
		owner:                     p.owner,
		collateral:                p.collateral,
		position:                  p.position,
		priceFactor:               p.priceFactor,
		positionFactor:            p.positionFactor,
		status:                    types.AMMPoolStatusActive,
		sqrt:                      p.sqrt,
		oneTick:                   p.oneTick,
		maxCalculationLevels:      p.maxCalculationLevels,
		cache:                     NewPoolCache(),
		SlippageTolerance:         amend.SlippageTolerance,
		MinimumPriceChangeTrigger: amend.MinimumPriceChangeTrigger,
	}

	// data source has changed, if the old base price is within bounds we'll keep it until the update comes in
	// otherwise we'll kick it into pending
	if ptr.UnBox(parameters.DataSourceID) != ptr.UnBox(p.Parameters.DataSourceID) {
		// the lower curve's high is the current base price in asset precision
		base := p.lower.high
		outside := p.IsPending()

		// bounds are converted into asset precision before comparing
		if parameters.UpperBound != nil {
			bound, _ := num.UintFromDecimal(parameters.UpperBound.ToDecimal().Mul(p.priceFactor))
			outside = outside || base.GTE(bound)
		}

		if parameters.LowerBound != nil {
			bound, _ := num.UintFromDecimal(parameters.LowerBound.ToDecimal().Mul(p.priceFactor))
			outside = outside || base.LTE(bound)
		}

		if outside {
			updated.status = types.AMMPoolStatusPending
			updated.lower = emptyCurve(num.UintZero(), true)
			updated.upper = emptyCurve(num.UintZero(), false)
			return updated, nil
		}

		// inherit the old base price; updated.Parameters points at this same struct
		parameters.Base = p.Parameters.Base.Clone()
	}

	if err := updated.setCurves(rf, sf, linearSlippage, allowedEmptyAMMLevels); err != nil {
		return nil, err
	}

	if !updated.checkPosition() {
		return nil, errors.New("AMM's current position is outside of amended bounds - reduce position first")
	}

	return updated, nil
}
   510  
   511  // emptyCurve creates the curve details that represent no liquidity.
   512  func emptyCurve(
   513  	base *num.Uint,
   514  	isLower bool,
   515  ) *curve {
   516  	return &curve{
   517  		l:       num.DecimalZero(),
   518  		pv:      num.DecimalZero(),
   519  		low:     base.Clone(),
   520  		high:    base.Clone(),
   521  		empty:   true,
   522  		isLower: isLower,
   523  	}
   524  }
   525  
   526  // generateCurve creates the curve details and calculates its virtual liquidity.
   527  func generateCurve(
   528  	sqrt sqrtFn,
   529  	commitment,
   530  	low, high *num.Uint,
   531  	riskFactor,
   532  	marginFactor,
   533  	linearSlippage num.Decimal,
   534  	leverageAtBound *num.Decimal,
   535  	positionFactor num.Decimal,
   536  	isLower bool,
   537  ) *curve {
   538  	// rf = 1 / ( mf * ( risk-factor + slippage ) )
   539  	rf := num.DecimalOne().Div(marginFactor.Mul(riskFactor.Add(linearSlippage)))
   540  	if leverageAtBound != nil {
   541  		// rf = min(rf, leverage)
   542  		rf = num.MinD(rf, *leverageAtBound)
   543  	}
   544  
   545  	// we now need to calculate the virtual-liquidity L of the curve from the
   546  	// input parameters: leverage (rf), lower bound price (pl), upper bound price (pu)
   547  	// we first calculate the unit-virtual-liquidity:
   548  	// Lu = sqrt(pu) * sqrt(pl) / sqrt(pu) - sqrt(pl)
   549  
   550  	// sqrt(high) * sqrt(low)
   551  	term1 := sqrt(high).Mul(sqrt(low))
   552  
   553  	// sqrt(high) - sqrt(low)
   554  	term2 := sqrt(high).Sub(sqrt(low))
   555  	lu := term1.Div(term2)
   556  
   557  	// now we calculate average-entry price if we were to trade the entire curve
   558  	// pa := lu * pu * (1 - (lu / lu + pu))
   559  
   560  	// (1 - (lu / lu + pu))
   561  	denom := num.DecimalOne().Sub(lu.Div(lu.Add(sqrt(high))))
   562  
   563  	// lu * pu / denom
   564  	pa := denom.Mul(lu).Mul(sqrt(high))
   565  
   566  	// and now we calculate the theoretical position `pv` which is the total tradeable volume of the curve.
   567  	var pv num.Decimal
   568  	if isLower {
   569  		// pv := rf * cc / ( pl(1 - rf) + rf * pa )
   570  
   571  		// pl * (1 - rf)
   572  		denom := low.ToDecimal().Mul(num.DecimalOne().Sub(rf))
   573  
   574  		// ( pl(1 - rf) + rf * pa )
   575  		denom = denom.Add(pa.Mul(rf))
   576  
   577  		// pv := rf * cc / ( pl(1 - rf) + rf * pa )
   578  		pv = commitment.ToDecimal().Mul(rf).Div(denom)
   579  	} else {
   580  		// pv := rf * cc / ( pu(1 + rf) - rf * pa )
   581  
   582  		// pu * (1 + rf)
   583  		denom := high.ToDecimal().Mul(num.DecimalOne().Add(rf))
   584  
   585  		// ( pu(1 + rf) - rf * pa )
   586  		denom = denom.Sub(pa.Mul(rf))
   587  
   588  		// pv := rf * cc / ( pu(1 + rf) - rf * pa )
   589  		pv = commitment.ToDecimal().Mul(rf).Div(denom).Abs()
   590  	}
   591  
   592  	// now we scale theoretical position by position factor so that is it feeds through into all subsequent equations
   593  	pv = pv.Mul(positionFactor)
   594  	l := pv.Mul(lu)
   595  
   596  	sqrtHigh := sqrt(high)
   597  	lDivSqrtPu := l.Div(sqrtHigh)
   598  
   599  	// and finally calculate L = pv * Lu
   600  	return &curve{
   601  		l:          l,
   602  		low:        low,
   603  		high:       high,
   604  		pv:         pv,
   605  		isLower:    isLower,
   606  		lDivSqrtPu: lDivSqrtPu,
   607  		sqrtHigh:   sqrtHigh,
   608  	}
   609  }
   610  
   611  func (p *Pool) setCurves(
   612  	rfs *types.RiskFactor,
   613  	sfs *types.ScalingFactors,
   614  	linearSlippage num.Decimal,
   615  	allowedEmptyAMMLevels uint64,
   616  ) error {
   617  	// convert the bounds into asset precision
   618  	base, _ := num.UintFromDecimal(p.Parameters.Base.ToDecimal().Mul(p.priceFactor))
   619  	p.lower = emptyCurve(base, true)
   620  	p.upper = emptyCurve(base, false)
   621  
   622  	if p.Parameters.LowerBound != nil {
   623  		lowerBound, _ := num.UintFromDecimal(p.Parameters.LowerBound.ToDecimal().Mul(p.priceFactor))
   624  		p.lower = generateCurve(
   625  			p.sqrt,
   626  			p.Commitment.Clone(),
   627  			lowerBound,
   628  			base,
   629  			rfs.Long,
   630  			sfs.InitialMargin,
   631  			linearSlippage,
   632  			p.Parameters.LeverageAtLowerBound,
   633  			p.positionFactor,
   634  			true,
   635  		)
   636  
   637  		if err := p.lower.check(p.sqrt, p.oneTick.Clone(), allowedEmptyAMMLevels); err != nil {
   638  			return err
   639  		}
   640  	}
   641  
   642  	if p.Parameters.UpperBound != nil {
   643  		upperBound, _ := num.UintFromDecimal(p.Parameters.UpperBound.ToDecimal().Mul(p.priceFactor))
   644  		p.upper = generateCurve(
   645  			p.sqrt,
   646  			p.Commitment.Clone(),
   647  			base.Clone(),
   648  			upperBound,
   649  			rfs.Short,
   650  			sfs.InitialMargin,
   651  			linearSlippage,
   652  			p.Parameters.LeverageAtUpperBound,
   653  			p.positionFactor,
   654  			false,
   655  		)
   656  
   657  		// lets find an interval that represents one volume, it might be a sparse curve
   658  		if err := p.upper.check(p.sqrt, p.oneTick.Clone(), allowedEmptyAMMLevels); err != nil {
   659  			return err
   660  		}
   661  	}
   662  
   663  	return nil
   664  }
   665  
   666  // impliedPosition returns the position of the pool if its fair-price were the given price. `l` is
   667  // the virtual liquidity of the pool, and `sqrtPrice` and `sqrtHigh` are, the square-roots of the
   668  // price to calculate the position for, and higher boundary of the curve.
   669  func impliedPosition(sqrtPrice, sqrtHigh num.Decimal, l num.Decimal) num.Decimal {
   670  	// L * (sqrt(high) - sqrt(price))
   671  	numer := sqrtHigh.Sub(sqrtPrice).Mul(l)
   672  
   673  	// sqrt(high) * sqrt(price)
   674  	denom := sqrtHigh.Mul(sqrtPrice)
   675  
   676  	// L * (sqrt(high) - sqrt(price)) / sqrt(high) * sqrt(price)
   677  	return numer.Div(denom)
   678  }
   679  
   680  // PriceForVolume returns the price the AMM is willing to trade at to match with the given volume of an incoming order.
   681  func (p *Pool) PriceForVolume(volume uint64, side types.Side) *num.Uint {
   682  	return p.priceForVolumeAtPosition(
   683  		volume,
   684  		side,
   685  		p.getPosition(),
   686  		p.FairPrice(),
   687  	)
   688  }
   689  
   690  // priceForVolumeAtPosition returns the price the AMM is willing to trade at to match with the given volume if its position and fair-price
   691  // are as given.
   692  func (p *Pool) priceForVolumeAtPosition(volume uint64, side types.Side, pos int64, fp *num.Uint) *num.Uint {
   693  	if volume == 0 {
   694  		panic("cannot calculate price for zero volume trade")
   695  	}
   696  
   697  	x, y := p.virtualBalances(pos, fp, side)
   698  
   699  	// dy = x*y / (x - dx) - y
   700  	// where y and x are the balances on either side of the pool, and dx is the change in volume
   701  	// then the trade price is dy/dx
   702  	dx := num.DecimalFromInt64(int64(volume))
   703  	if side == types.SideSell {
   704  		// if incoming order is a sell, the AMM is buying so reducing cash balance so dx is negative
   705  		dx = dx.Neg()
   706  	}
   707  
   708  	dy := x.Mul(y).Div(x.Sub(dx)).Sub(y)
   709  
   710  	// dy / dx
   711  	price, overflow := num.UintFromDecimal(dy.Div(dx).Abs())
   712  	if overflow {
   713  		panic("calculated negative price")
   714  	}
   715  	return price
   716  }
   717  
// TradableVolumeInRange returns the volume the pool is willing to provide between the two given price levels
// for the side of an incoming order trading with the pool.
//
// If `nil` is provided for price1 the lower curve's low is used, and if `nil` is provided for price2 the
// upper curve's high is used, i.e. the full volume in that direction.
//
// Both prices are mapped to positions and the volume is their difference, restricted to the part of the
// range the incoming side can trade against given the pool's current position. When the pool is closing,
// volume that would take it to the other side of zero is excluded and the result is capped at |position|.
func (p *Pool) TradableVolumeInRange(side types.Side, price1 *num.Uint, price2 *num.Uint) uint64 {
	// nothing to offer if the pool cannot trade against this side
	if !p.canTrade(side) {
		return 0
	}

	pos := p.getPosition()
	st, nd := price1, price2

	if price1 == nil {
		st = p.lower.low
	}

	if price2 == nil {
		nd = p.upper.high
	}

	// an empty range has no volume
	if st.EQ(nd) {
		return 0
	}

	// normalise so that st is the lower price
	if st.GT(nd) {
		st, nd = nd, st
	}

	// map the given st/nd prices into positions, then the difference is the volume.
	// prices beyond a curve's bound are clamped to that bound; the base price itself (lower.high), or a
	// price that falls on an empty curve, maps to position 0.
	asPosition := func(price *num.Uint) int64 {
		switch {
		case price.GT(p.lower.high):
			// in upper curve
			if !p.upper.empty {
				return p.upper.positionAtPrice(p.sqrt, num.Min(p.upper.high, price))
			}
		case price.LT(p.lower.high):
			// in lower curve
			if !p.lower.empty {
				return p.lower.positionAtPrice(p.sqrt, num.Max(p.lower.low, price))
			}
		}
		return 0
	}

	// the implied position grows as the price falls, so stP (lower price) >= ndP (higher price)
	stP := asPosition(st)
	ndP := asPosition(nd)

	if side == types.SideSell {
		// want all buy volume so everything below fair price, i.e. positions longer than the current one
		if pos > stP {
			return 0
		}
		ndP = num.MaxV(pos, ndP)
	}

	if side == types.SideBuy {
		// want all sell volume so everything above fair price, i.e. positions shorter than the current one
		if pos < ndP {
			return 0
		}
		stP = num.MinV(pos, stP)
	}

	if !p.closing() {
		return uint64(stP - ndP)
	}

	if pos > 0 {
		// if closing and long, we have no volume at short prices, so cap range to > 0
		stP = num.MaxV(0, stP)
		ndP = num.MaxV(0, ndP)
	}

	if pos < 0 {
		// if closing and short, we have no volume at long prices, so cap range to < 0
		stP = num.MinV(0, stP)
		ndP = num.MinV(0, ndP)
	}
	// a closing pool never trades more than its current position
	return num.MinV(uint64(stP-ndP), uint64(num.AbsV(pos)))
}
   797  
   798  // TrableVolumeForPrice returns the volume available between the AMM's fair-price and the given
   799  // price and side of an incoming order. It is a special case of TradableVolumeInRange with
   800  // the benefit of accurately using the AMM's position instead of having to calculate the hop
   801  // from fair-price -> position.
   802  func (p *Pool) TradableVolumeForPrice(side types.Side, price *num.Uint) uint64 {
   803  	if side == types.SideSell {
   804  		return p.TradableVolumeInRange(side, price, nil)
   805  	}
   806  	return p.TradableVolumeInRange(side, nil, price)
   807  }
   808  
   809  // getBalance returns the total balance of the pool i.e it's general account + it's margin account.
   810  func (p *Pool) getBalance() *num.Uint {
   811  	general, err := p.collateral.GetPartyGeneralAccount(p.AMMParty, p.asset)
   812  	if err != nil {
   813  		panic("general account not created")
   814  	}
   815  
   816  	margin, err := p.collateral.GetPartyMarginAccount(p.market, p.AMMParty, p.asset)
   817  	if err != nil {
   818  		panic("margin account not created")
   819  	}
   820  
   821  	return num.UintZero().AddSum(general.Balance, margin.Balance)
   822  }
   823  
   824  // setEphemeralPosition is called when we are starting the matching process against this pool
   825  // so that we can track its position and average-entry as it goes through the matching process.
   826  func (p *Pool) setEphemeralPosition() {
   827  	if p.eph != nil {
   828  		return
   829  	}
   830  	p.eph = &ephemeralPosition{
   831  		size: 0,
   832  	}
   833  
   834  	if pos := p.position.GetPositionsByParty(p.AMMParty); len(pos) != 0 {
   835  		p.eph.size = pos[0].Size()
   836  	}
   837  }
   838  
   839  // updateEphemeralPosition sets the pools transient position given a generated order.
   840  func (p *Pool) updateEphemeralPosition(order *types.Order) {
   841  	if order.Side == types.SideSell {
   842  		p.eph.size -= int64(order.Size)
   843  		return
   844  	}
   845  	p.eph.size += int64(order.Size)
   846  }
   847  
   848  // clearEphemeralPosition signifies that the matching process has finished
   849  // and the pool can continue to read it's position from the positions engine.
   850  func (p *Pool) clearEphemeralPosition() {
   851  	p.eph = nil
   852  }
   853  
   854  // getPosition gets the pools current position an average-entry price.
   855  func (p *Pool) getPosition() int64 {
   856  	if p.eph != nil {
   857  		return p.eph.size
   858  	}
   859  
   860  	if pos := p.position.GetPositionsByParty(p.AMMParty); len(pos) != 0 {
   861  		return pos[0].Size()
   862  	}
   863  	return 0
   864  }
   865  
// FairPrice returns the fair price of the pool given its current position.
//
// With no position the fair price is the base price (the lower curve's high bound). Otherwise it is
// derived from the curve on the side of the position:
//
//	sqrt(pf) = sqrt(pu) / (1 + pv * sqrt(pu) * 1/L)
//
// where pu is the curve's high price, L is its virtual liquidity and pv is the virtual position:
//
//	pv = pos,      when the pool is long  (lower curve)
//	pv = pos + Pv, when the pool is short (upper curve)
//
// This transformation is needed because, for each curve, the virtual position is 0 at the lower bound.
// That maps directly to the Vega position when the pool is long. When the pool is short, however, the
// Vega position is 0 at the upper bound and negative at the lower bound.
//
// Results for non-zero positions are cached per position. The returned value is always a fresh clone,
// never the cached instance. It panics (via the logger) if the curve on the position's side is empty.
func (p *Pool) FairPrice() *num.Uint {
	pos := p.getPosition()
	if pos == 0 {
		// if no position fair price is base price
		return p.lower.high.Clone()
	}

	if fp, ok := p.cache.getFairPrice(pos); ok {
		return fp.Clone()
	}

	cu := p.lower
	pv := num.DecimalFromInt64(pos)
	if pos < 0 {
		cu = p.upper
		// Pv + pos: shift the negative Vega position onto the upper curve's virtual position
		pv = cu.pv.Add(pv)
	}

	if cu.empty {
		p.log.Panic("should not be calculating fair-price on empty-curve side",
			logging.Bool("lower", cu.isLower),
			logging.Int64("pos", pos),
			logging.String("amm-party", p.AMMParty),
		)
	}

	// pv * sqrt(pu) * (1/L) + 1
	denom := pv.Mul(cu.sqrtHigh).Div(cu.l).Add(num.DecimalOne())

	// sqrt(fp) = sqrt(pu) / denom
	sqrtPf := p.sqrt(cu.high).Div(denom)

	// fair-price = sqrt(fp) * sqrt(fp)
	fp := sqrtPf.Mul(sqrtPf)

	// we want to round such that the price is further away from the base. This is so that once
	// a pool's position is at its boundary we do not report volume that doesn't exist. For example
	// say a pool's upper boundary is 1000 and for it to be at that boundary its position needs to
	// be 10.5. The closest we can get is 10 but then we'd report a fair-price of 999.78. If
	// we use 999 we'd be implying volume between 999 and 1000 which we don't want to trade.
	// For a long position no explicit rounding is done here; the fractional part is presumably
	// truncated by num.UintFromDecimal below (rounding down, away from base) — confirm.
	if pos < 0 {
		fp = fp.Ceil()
	}

	// the second return value is discarded (presumably an overflow flag — confirm); fp is a square so it
	// is never negative.
	fairPrice, _ := num.UintFromDecimal(fp)

	// cache a clone so later mutation of the returned value cannot corrupt the cache
	p.cache.setFairPrice(pos, fairPrice.Clone())

	return fairPrice
}
   926  
   927  // virtualBalancesShort returns the pools x, y balances when the pool has a negative position
   928  //
   929  // x = P + Pv + L / sqrt(pu)
   930  // y = L * sqrt(fair-price).
   931  func (p *Pool) virtualBalancesShort(pos int64, fp *num.Uint) (num.Decimal, num.Decimal) {
   932  	cu := p.upper
   933  	if cu.empty {
   934  		panic("should not be calculating balances on empty-curve side")
   935  	}
   936  
   937  	// lets start with x
   938  
   939  	// P
   940  	term1x := num.DecimalFromInt64(pos)
   941  
   942  	// Pv
   943  	term2x := cu.pv
   944  
   945  	// L / sqrt(pu)
   946  	term3x := cu.lDivSqrtPu
   947  
   948  	// x = P + (cc * rf / pu) + (L / sqrt(pu))
   949  	x := term2x.Add(term3x).Add(term1x)
   950  
   951  	// now lets get y
   952  
   953  	// y = L * sqrt(fair-price)
   954  	y := cu.l.Mul(p.sqrt(fp))
   955  	return x, y
   956  }
   957  
   958  // virtualBalancesLong returns the pools x, y balances when the pool has a positive position
   959  //
   960  // x = P + (L / sqrt(pu))
   961  // y = L * sqrt(fair-price).
   962  func (p *Pool) virtualBalancesLong(pos int64, fp *num.Uint) (num.Decimal, num.Decimal) {
   963  	cu := p.lower
   964  	if cu.empty {
   965  		panic("should not be calculating balances on empty-curve side")
   966  	}
   967  
   968  	// lets start with x
   969  
   970  	// P
   971  	term1x := num.DecimalFromInt64(pos)
   972  
   973  	// L / sqrt(pu)
   974  	term2x := cu.lDivSqrtPu
   975  
   976  	// x = P + (L / sqrt(pu))
   977  	x := term1x.Add(term2x)
   978  
   979  	// now lets move to y
   980  
   981  	// y = L * sqrt(fair-price)
   982  	y := cu.l.Mul(p.sqrt(fp))
   983  	return x, y
   984  }
   985  
   986  // virtualBalances returns the pools x, y values where x is the balance in contracts and y is the balance in asset.
   987  func (p *Pool) virtualBalances(pos int64, fp *num.Uint, side types.Side) (num.Decimal, num.Decimal) {
   988  	switch {
   989  	case pos < 0, pos == 0 && side == types.SideBuy:
   990  		// zero position but incoming is buy which will make pool short
   991  		return p.virtualBalancesShort(pos, fp)
   992  	case pos > 0, pos == 0 && side == types.SideSell:
   993  		// zero position but incoming is sell which will make pool long
   994  		return p.virtualBalancesLong(pos, fp)
   995  	default:
   996  		panic("should not reach here")
   997  	}
   998  }
   999  
  1000  // BestPrice returns the AMM's quote price on the given side. If the AMM's position is fully at a boundary
  1001  // then there is no quote price on that side and false is returned.
  1002  func (p *Pool) BestPrice(side types.Side) (*num.Uint, bool) {
  1003  	if p.IsPending() {
  1004  		return nil, false
  1005  	}
  1006  
  1007  	pos := p.getPosition()
  1008  	fairPrice := p.FairPrice()
  1009  
  1010  	switch side {
  1011  	case types.SideSell:
  1012  		cu := p.lower
  1013  		if pos <= 0 {
  1014  			cu = p.upper
  1015  			// we're short, and want the sell quote price, if we're at the boundary there is not volume left
  1016  			if p.closing() || num.AbsV(pos) >= cu.pv.IntPart() {
  1017  				return nil, false
  1018  			}
  1019  		}
  1020  
  1021  		np := cu.singleVolumePrice(p.sqrt, fairPrice, side)
  1022  		return num.Min(p.upper.high, num.Max(np, fairPrice.AddSum(p.oneTick))), true
  1023  	case types.SideBuy:
  1024  		cu := p.upper
  1025  		if pos >= 0 {
  1026  			cu = p.lower
  1027  			// we're long, and want the buy quote price, if we're at the boundary there is not volume left
  1028  			if p.closing() || pos >= cu.pv.IntPart() {
  1029  				return nil, false
  1030  			}
  1031  		}
  1032  
  1033  		np := cu.singleVolumePrice(p.sqrt, fairPrice, side)
  1034  		return num.Max(p.lower.low, num.Min(np, num.UintZero().Sub(fairPrice, p.oneTick))), true
  1035  	default:
  1036  		panic("should never reach here")
  1037  	}
  1038  }
  1039  
  1040  // BestPriceAndVolume returns the AMM's best price on a given side and the volume available to trade.
  1041  func (p *Pool) BestPriceAndVolume(side types.Side) (*num.Uint, uint64) {
  1042  	// check cache
  1043  	pos := p.getPosition()
  1044  
  1045  	if p, v, ok := p.cache.getBestPrice(pos, side, p.status); ok {
  1046  		return p, v
  1047  	}
  1048  
  1049  	price, ok := p.BestPrice(side)
  1050  	if !ok {
  1051  		return price, 0
  1052  	}
  1053  
  1054  	// now calculate the volume
  1055  	fp := p.FairPrice()
  1056  	if side == types.SideBuy {
  1057  		priceTick := num.Max(p.lower.low, num.UintZero().Sub(fp, p.oneTick))
  1058  
  1059  		if !price.GTE(priceTick) {
  1060  			p.cache.setBestPrice(pos, side, p.status, price, 1)
  1061  			return price, 1 // its low volume so 1 by construction
  1062  		}
  1063  
  1064  		volume := p.TradableVolumeForPrice(types.SideSell, priceTick)
  1065  		p.cache.setBestPrice(pos, side, p.status, priceTick, volume)
  1066  		return priceTick, volume
  1067  	}
  1068  
  1069  	priceTick := num.Min(p.upper.high, num.UintZero().Add(fp, p.oneTick))
  1070  	if !price.LTE(priceTick) {
  1071  		p.cache.setBestPrice(pos, side, p.status, price, 1)
  1072  		return price, 1 // its low volume so 1 by construction
  1073  	}
  1074  
  1075  	volume := p.TradableVolumeForPrice(types.SideBuy, priceTick)
  1076  	p.cache.setBestPrice(pos, side, p.status, priceTick, volume)
  1077  	return priceTick, volume
  1078  }
  1079  
  1080  func (p *Pool) LiquidityFee() num.Decimal {
  1081  	return p.ProposedFee
  1082  }
  1083  
  1084  func (p *Pool) CommitmentAmount() *num.Uint {
  1085  	return p.Commitment.Clone()
  1086  }
  1087  
  1088  func (p *Pool) Owner() string {
  1089  	return p.owner
  1090  }
  1091  
  1092  func (p *Pool) closing() bool {
  1093  	return p.status == types.AMMPoolStatusReduceOnly
  1094  }
  1095  
  1096  func (p *Pool) IsPending() bool {
  1097  	return p.status == types.AMMPoolStatusPending
  1098  }
  1099  
  1100  func (p *Pool) canTrade(side types.Side) bool {
  1101  	if p.IsPending() {
  1102  		return false
  1103  	}
  1104  
  1105  	if !p.closing() {
  1106  		return true
  1107  	}
  1108  
  1109  	pos := p.getPosition()
  1110  	// pool is long incoming order is a buy and will make it shorter, its ok
  1111  	if pos > 0 && side == types.SideBuy {
  1112  		return true
  1113  	}
  1114  	if pos < 0 && side == types.SideSell {
  1115  		return true
  1116  	}
  1117  	return false
  1118  }
  1119  
  1120  func (p *Pool) makeOrder(volume uint64, price *num.Uint, side types.Side, idgen *idgeneration.IDGenerator) *types.Order {
  1121  	order := &types.Order{
  1122  		MarketID:         p.market,
  1123  		Party:            p.AMMParty,
  1124  		Size:             volume,
  1125  		Remaining:        volume,
  1126  		Price:            price,
  1127  		Side:             side,
  1128  		TimeInForce:      types.OrderTimeInForceGTC,
  1129  		Type:             types.OrderTypeLimit,
  1130  		Status:           types.OrderStatusFilled,
  1131  		Reference:        "vamm-" + p.AMMParty,
  1132  		GeneratedOffbook: true,
  1133  	}
  1134  	order.OriginalPrice, _ = num.UintFromDecimal(order.Price.ToDecimal().Div(p.priceFactor))
  1135  
  1136  	if idgen != nil {
  1137  		order.ID = idgen.NextID()
  1138  	}
  1139  	return order
  1140  }