code.vegaprotocol.io/vega@v0.79.0/datanode/service/position.go (about) 1 // Copyright (C) 2023 Gobalsky Labs Limited 2 // 3 // This program is free software: you can redistribute it and/or modify 4 // it under the terms of the GNU Affero General Public License as 5 // published by the Free Software Foundation, either version 3 of the 6 // License, or (at your option) any later version. 7 // 8 // This program is distributed in the hope that it will be useful, 9 // but WITHOUT ANY WARRANTY; without even the implied warranty of 10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 // GNU Affero General Public License for more details. 12 // 13 // You should have received a copy of the GNU Affero General Public License 14 // along with this program. If not, see <http://www.gnu.org/licenses/>. 15 16 package service 17 18 import ( 19 "context" 20 21 "code.vegaprotocol.io/vega/datanode/entities" 22 "code.vegaprotocol.io/vega/datanode/utils" 23 "code.vegaprotocol.io/vega/logging" 24 25 lru "github.com/hashicorp/golang-lru" 26 "golang.org/x/exp/slices" 27 ) 28 29 type PositionStore interface { 30 Flush(ctx context.Context) ([]entities.Position, error) 31 Add(ctx context.Context, p entities.Position) error 32 GetByMarketAndParty(ctx context.Context, marketID string, partyID string) (entities.Position, error) 33 GetByMarketAndParties(ctx context.Context, marketIDRaw string, partyIDsRaw []string) ([]entities.Position, error) 34 GetByMarket(ctx context.Context, marketID string) ([]entities.Position, error) 35 GetByParty(ctx context.Context, partyID string) ([]entities.Position, error) 36 GetByPartyConnection(ctx context.Context, partyID []string, marketID []string, pagination entities.CursorPagination) ([]entities.Position, entities.PageInfo, error) 37 GetByTxHash(ctx context.Context, txHash entities.TxHash) ([]entities.Position, error) 38 GetAll(ctx context.Context) ([]entities.Position, error) 39 } 40 41 type positionCacheKey struct { 42 MarketID entities.MarketID 43 PartyID entities.PartyID 44 } 45 type Position struct { 46 store PositionStore 47 observer utils.Observer[entities.Position] 48 cache *lru.Cache 49 } 50 51 func NewPosition(store PositionStore, log *logging.Logger) *Position { 52 cache, err := lru.New(10000) 53 if err != nil { 54 panic(err) 55 } 56 return &Position{ 57 store: store, 58 observer: utils.NewObserver[entities.Position]("positions", log, 0, 0), 59 cache: cache, 60 } 61 } 62 63 func (p *Position) Flush(ctx context.Context) error { 64 flushed, err := p.store.Flush(ctx) 65 if err != nil { 66 return err 67 } 68 p.observer.Notify(flushed) 69 return nil 70 } 71 72 func (p *Position) Add(ctx context.Context, pos entities.Position) error { 73 key := positionCacheKey{pos.MarketID, pos.PartyID} 74 p.cache.Add(key, pos) 75 return p.store.Add(ctx, pos) 76 } 77 78 func (p *Position) GetByMarketAndParties(ctx context.Context, marketID string, partyIDs []string) ([]entities.Position, error) { 79 missedParties := make([]string, 0, len(partyIDs)) 80 ret := make([]entities.Position, 0, len(partyIDs)) 81 key := positionCacheKey{ 82 MarketID: entities.MarketID(marketID), 83 } 84 for _, partyID := range partyIDs { 85 key.PartyID = entities.PartyID(partyID) 86 if v, ok := p.cache.Get(key); ok { 87 switch val := v.(type) { 88 case entities.Position: 89 ret = append(ret, val) 90 default: 91 // this includes errors from cache, ignore them and try again? 92 missedParties = append(missedParties, partyID) 93 } 94 } else { 95 missedParties = append(missedParties, partyID) 96 } 97 } 98 // everything was cached, we're done 99 if len(missedParties) == 0 { 100 return ret, nil 101 } 102 storePos, err := p.store.GetByMarketAndParties(ctx, marketID, missedParties) 103 // append the positions from store to those from cache 104 ret = append(ret, storePos...) 105 if err == nil { 106 // we had cache misses, and got them from store, so add them to cache 107 for _, sp := range storePos { 108 key.PartyID = sp.PartyID 109 p.cache.Add(key, sp) 110 } 111 } 112 return ret, err 113 } 114 115 func (p *Position) GetByMarketAndParty(ctx context.Context, marketID string, partyID string) (entities.Position, error) { 116 key := positionCacheKey{entities.MarketID(marketID), entities.PartyID(partyID)} 117 value, ok := p.cache.Get(key) 118 if ok { 119 // make sure the value in cache is a position entity, ignore errors 120 if v, ok := value.(entities.Position); ok { 121 return v, nil 122 } 123 } 124 // either cache miss, or an error was cached, either way fall back to store and update cache 125 pos, err := p.store.GetByMarketAndParty( 126 ctx, marketID, partyID) 127 // let's not cache errors here 128 if err == nil { 129 p.cache.Add(key, pos) 130 } 131 132 return pos, err 133 } 134 135 func (p *Position) GetByMarket(ctx context.Context, marketID string) ([]entities.Position, error) { 136 return p.store.GetByMarket(ctx, marketID) 137 } 138 139 func (p *Position) GetByParty(ctx context.Context, partyID entities.PartyID) ([]entities.Position, error) { 140 return p.store.GetByParty(ctx, partyID.String()) 141 } 142 143 func (p *Position) GetByTxHash(ctx context.Context, txHash entities.TxHash) ([]entities.Position, error) { 144 return p.store.GetByTxHash(ctx, txHash) 145 } 146 147 func (p *Position) GetByPartyConnection(ctx context.Context, partyIDs []entities.PartyID, marketIDs []entities.MarketID, pagination entities.CursorPagination) ([]entities.Position, entities.PageInfo, error) { 148 ps := make([]string, len(partyIDs)) 149 for i, p := range partyIDs { 150 ps[i] = p.String() 151 } 152 153 ms := make([]string, len(marketIDs)) 154 for i, m := range marketIDs { 155 ms[i] = m.String() 156 } 157 return p.store.GetByPartyConnection(ctx, ps, ms, pagination) 158 } 159 160 func (p *Position) GetAll(ctx context.Context) ([]entities.Position, error) { 161 return p.store.GetAll(ctx) 162 } 163 164 func (p *Position) Observe(ctx context.Context, retries int, partyID, marketID string) (<-chan []entities.Position, uint64) { 165 ch, ref := p.observer.Observe(ctx, 166 retries, 167 func(pos entities.Position) bool { 168 return (len(marketID) == 0 || marketID == pos.MarketID.String()) && 169 (len(partyID) == 0 || partyID == pos.PartyID.String()) 170 }) 171 return ch, ref 172 } 173 174 func (p *Position) ObserveMany(ctx context.Context, retries int, marketID string, parties ...string) (<-chan []entities.Position, uint64) { 175 ch, ref := p.observer.Observe(ctx, 176 retries, 177 func(pos entities.Position) bool { 178 return (len(marketID) == 0 || marketID == pos.MarketID.String()) && 179 (len(parties) == 0 || slices.Contains(parties, pos.PartyID.String())) 180 }) 181 return ch, ref 182 }