code.vegaprotocol.io/vega@v0.79.0/datanode/service/position.go (about)

     1  // Copyright (C) 2023 Gobalsky Labs Limited
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    15  
    16  package service
    17  
    18  import (
    19  	"context"
    20  
    21  	"code.vegaprotocol.io/vega/datanode/entities"
    22  	"code.vegaprotocol.io/vega/datanode/utils"
    23  	"code.vegaprotocol.io/vega/logging"
    24  
    25  	lru "github.com/hashicorp/golang-lru"
    26  	"golang.org/x/exp/slices"
    27  )
    28  
    29  type PositionStore interface {
    30  	Flush(ctx context.Context) ([]entities.Position, error)
    31  	Add(ctx context.Context, p entities.Position) error
    32  	GetByMarketAndParty(ctx context.Context, marketID string, partyID string) (entities.Position, error)
    33  	GetByMarketAndParties(ctx context.Context, marketIDRaw string, partyIDsRaw []string) ([]entities.Position, error)
    34  	GetByMarket(ctx context.Context, marketID string) ([]entities.Position, error)
    35  	GetByParty(ctx context.Context, partyID string) ([]entities.Position, error)
    36  	GetByPartyConnection(ctx context.Context, partyID []string, marketID []string, pagination entities.CursorPagination) ([]entities.Position, entities.PageInfo, error)
    37  	GetByTxHash(ctx context.Context, txHash entities.TxHash) ([]entities.Position, error)
    38  	GetAll(ctx context.Context) ([]entities.Position, error)
    39  }
    40  
    41  type positionCacheKey struct {
    42  	MarketID entities.MarketID
    43  	PartyID  entities.PartyID
    44  }
    45  type Position struct {
    46  	store    PositionStore
    47  	observer utils.Observer[entities.Position]
    48  	cache    *lru.Cache
    49  }
    50  
    51  func NewPosition(store PositionStore, log *logging.Logger) *Position {
    52  	cache, err := lru.New(10000)
    53  	if err != nil {
    54  		panic(err)
    55  	}
    56  	return &Position{
    57  		store:    store,
    58  		observer: utils.NewObserver[entities.Position]("positions", log, 0, 0),
    59  		cache:    cache,
    60  	}
    61  }
    62  
    63  func (p *Position) Flush(ctx context.Context) error {
    64  	flushed, err := p.store.Flush(ctx)
    65  	if err != nil {
    66  		return err
    67  	}
    68  	p.observer.Notify(flushed)
    69  	return nil
    70  }
    71  
    72  func (p *Position) Add(ctx context.Context, pos entities.Position) error {
    73  	key := positionCacheKey{pos.MarketID, pos.PartyID}
    74  	p.cache.Add(key, pos)
    75  	return p.store.Add(ctx, pos)
    76  }
    77  
    78  func (p *Position) GetByMarketAndParties(ctx context.Context, marketID string, partyIDs []string) ([]entities.Position, error) {
    79  	missedParties := make([]string, 0, len(partyIDs))
    80  	ret := make([]entities.Position, 0, len(partyIDs))
    81  	key := positionCacheKey{
    82  		MarketID: entities.MarketID(marketID),
    83  	}
    84  	for _, partyID := range partyIDs {
    85  		key.PartyID = entities.PartyID(partyID)
    86  		if v, ok := p.cache.Get(key); ok {
    87  			switch val := v.(type) {
    88  			case entities.Position:
    89  				ret = append(ret, val)
    90  			default:
    91  				// this includes errors from cache, ignore them and try again?
    92  				missedParties = append(missedParties, partyID)
    93  			}
    94  		} else {
    95  			missedParties = append(missedParties, partyID)
    96  		}
    97  	}
    98  	// everything was cached, we're done
    99  	if len(missedParties) == 0 {
   100  		return ret, nil
   101  	}
   102  	storePos, err := p.store.GetByMarketAndParties(ctx, marketID, missedParties)
   103  	// append the positions from store to those from cache
   104  	ret = append(ret, storePos...)
   105  	if err == nil {
   106  		// we had cache misses, and got them from store, so add them to cache
   107  		for _, sp := range storePos {
   108  			key.PartyID = sp.PartyID
   109  			p.cache.Add(key, sp)
   110  		}
   111  	}
   112  	return ret, err
   113  }
   114  
   115  func (p *Position) GetByMarketAndParty(ctx context.Context, marketID string, partyID string) (entities.Position, error) {
   116  	key := positionCacheKey{entities.MarketID(marketID), entities.PartyID(partyID)}
   117  	value, ok := p.cache.Get(key)
   118  	if ok {
   119  		// make sure the value in cache is a position entity, ignore errors
   120  		if v, ok := value.(entities.Position); ok {
   121  			return v, nil
   122  		}
   123  	}
   124  	// either cache miss, or an error was cached, either way fall back to store and update cache
   125  	pos, err := p.store.GetByMarketAndParty(
   126  		ctx, marketID, partyID)
   127  	// let's not cache errors here
   128  	if err == nil {
   129  		p.cache.Add(key, pos)
   130  	}
   131  
   132  	return pos, err
   133  }
   134  
   135  func (p *Position) GetByMarket(ctx context.Context, marketID string) ([]entities.Position, error) {
   136  	return p.store.GetByMarket(ctx, marketID)
   137  }
   138  
   139  func (p *Position) GetByParty(ctx context.Context, partyID entities.PartyID) ([]entities.Position, error) {
   140  	return p.store.GetByParty(ctx, partyID.String())
   141  }
   142  
   143  func (p *Position) GetByTxHash(ctx context.Context, txHash entities.TxHash) ([]entities.Position, error) {
   144  	return p.store.GetByTxHash(ctx, txHash)
   145  }
   146  
   147  func (p *Position) GetByPartyConnection(ctx context.Context, partyIDs []entities.PartyID, marketIDs []entities.MarketID, pagination entities.CursorPagination) ([]entities.Position, entities.PageInfo, error) {
   148  	ps := make([]string, len(partyIDs))
   149  	for i, p := range partyIDs {
   150  		ps[i] = p.String()
   151  	}
   152  
   153  	ms := make([]string, len(marketIDs))
   154  	for i, m := range marketIDs {
   155  		ms[i] = m.String()
   156  	}
   157  	return p.store.GetByPartyConnection(ctx, ps, ms, pagination)
   158  }
   159  
   160  func (p *Position) GetAll(ctx context.Context) ([]entities.Position, error) {
   161  	return p.store.GetAll(ctx)
   162  }
   163  
   164  func (p *Position) Observe(ctx context.Context, retries int, partyID, marketID string) (<-chan []entities.Position, uint64) {
   165  	ch, ref := p.observer.Observe(ctx,
   166  		retries,
   167  		func(pos entities.Position) bool {
   168  			return (len(marketID) == 0 || marketID == pos.MarketID.String()) &&
   169  				(len(partyID) == 0 || partyID == pos.PartyID.String())
   170  		})
   171  	return ch, ref
   172  }
   173  
   174  func (p *Position) ObserveMany(ctx context.Context, retries int, marketID string, parties ...string) (<-chan []entities.Position, uint64) {
   175  	ch, ref := p.observer.Observe(ctx,
   176  		retries,
   177  		func(pos entities.Position) bool {
   178  			return (len(marketID) == 0 || marketID == pos.MarketID.String()) &&
   179  				(len(parties) == 0 || slices.Contains(parties, pos.PartyID.String()))
   180  		})
   181  	return ch, ref
   182  }