code.vegaprotocol.io/vega@v0.79.0/core/positions/snapshot.go (about)

     1  // Copyright (C) 2023 Gobalsky Labs Limited
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    15  
    16  package positions
    17  
    18  import (
    19  	"context"
    20  	"sort"
    21  
    22  	"code.vegaprotocol.io/vega/core/types"
    23  	"code.vegaprotocol.io/vega/libs/num"
    24  	"code.vegaprotocol.io/vega/libs/proto"
    25  	"code.vegaprotocol.io/vega/libs/ptr"
    26  	"code.vegaprotocol.io/vega/logging"
    27  	snapshotpb "code.vegaprotocol.io/vega/protos/vega/snapshot/v1"
    28  
    29  	"golang.org/x/exp/maps"
    30  )
    31  
    32  type SnapshotEngine struct {
    33  	*Engine
    34  	pl      types.Payload
    35  	data    []byte
    36  	stopped bool
    37  }
    38  
    39  func NewSnapshotEngine(
    40  	log *logging.Logger, config Config, marketID string, broker Broker,
    41  ) *SnapshotEngine {
    42  	return &SnapshotEngine{
    43  		Engine:  New(log, config, marketID, broker),
    44  		pl:      types.Payload{},
    45  		stopped: false,
    46  	}
    47  }
    48  
    49  // StopSnapshots is called when the engines respective market no longer exists. We need to stop
    50  // taking snapshots and communicate to the snapshot engine to remove us as a provider.
    51  func (e *SnapshotEngine) StopSnapshots() {
    52  	e.log.Debug("market has been cleared, stopping snapshot production", logging.String("marketid", e.marketID))
    53  	e.stopped = true
    54  }
    55  
    56  func (e *SnapshotEngine) Namespace() types.SnapshotNamespace {
    57  	return types.PositionsSnapshot
    58  }
    59  
    60  func (e *SnapshotEngine) Keys() []string {
    61  	return []string{e.marketID}
    62  }
    63  
    64  func (e *SnapshotEngine) Stopped() bool {
    65  	return e.stopped
    66  }
    67  
    68  func (e *SnapshotEngine) GetState(k string) ([]byte, []types.StateProvider, error) {
    69  	if k != e.marketID {
    70  		return nil, nil, types.ErrSnapshotKeyDoesNotExist
    71  	}
    72  
    73  	state, err := e.serialise()
    74  	return state, nil, err
    75  }
    76  
    77  func (e *SnapshotEngine) LoadState(_ context.Context, payload *types.Payload) ([]types.StateProvider, error) {
    78  	if e.Namespace() != payload.Data.Namespace() {
    79  		return nil, types.ErrInvalidSnapshotNamespace
    80  	}
    81  
    82  	var err error
    83  	switch pl := payload.Data.(type) {
    84  	case *types.PayloadMarketPositions:
    85  		// Check the payload is for this market
    86  		if e.marketID != pl.MarketPositions.MarketID {
    87  			return nil, types.ErrUnknownSnapshotType
    88  		}
    89  		e.log.Debug("loading snapshot", logging.Int("positions", len(pl.MarketPositions.Positions)))
    90  		for _, p := range pl.MarketPositions.Positions {
    91  			pos := NewMarketPosition(p.PartyID)
    92  			pos.price = p.Price
    93  			pos.buy = p.Buy
    94  			pos.sell = p.Sell
    95  			pos.size = p.Size
    96  			pos.buySumProduct = p.BuySumProduct
    97  			pos.sellSumProduct = p.SellSumProduct
    98  			pos.distressed = p.Distressed
    99  			pos.averageEntryPrice = p.AverageEntryPrice
   100  			e.positionsCpy = append(e.positionsCpy, pos)
   101  			e.positions[p.PartyID] = pos
   102  			if p.Distressed {
   103  				e.distressedPos[p.PartyID] = struct{}{}
   104  			}
   105  
   106  			// This is for migration, on the first time we load from snapshot there won't be an average entry price
   107  			// so take the last price as the current average
   108  			if p.AverageEntryPrice == nil {
   109  				if pos.size != 0 && !pos.price.IsZero() {
   110  					pos.averageEntryPrice = pos.price.Clone()
   111  				} else {
   112  					pos.averageEntryPrice = num.UintZero()
   113  				}
   114  			}
   115  
   116  			// ensure these exists on the first snapshot after the upgrade
   117  			e.partiesHighestVolume[p.PartyID] = &openVolumeRecord{}
   118  		}
   119  
   120  		for _, v := range pl.MarketPositions.PartieRecords {
   121  			if v.LatestOpenInterest != nil && v.LowestOpenInterest != nil {
   122  				e.partiesHighestVolume[v.Party] = &openVolumeRecord{
   123  					Latest:  *v.LatestOpenInterest,
   124  					Highest: *v.LowestOpenInterest,
   125  				}
   126  			}
   127  
   128  			if v.TradedVolume != nil {
   129  				e.partiesTradedSize[v.Party] = *v.TradedVolume
   130  			}
   131  		}
   132  
   133  		e.data, err = proto.Marshal(payload.IntoProto())
   134  		return nil, err
   135  
   136  	default:
   137  		return nil, types.ErrUnknownSnapshotType
   138  	}
   139  }
   140  
   141  // serialise marshal the snapshot state, populating the data field
   142  // with updated values.
   143  func (e *SnapshotEngine) serialise() ([]byte, error) {
   144  	if e.stopped {
   145  		return nil, nil
   146  	}
   147  
   148  	e.log.Debug("serialising snapshot", logging.Int("positions", len(e.positionsCpy)))
   149  	positions := make([]*types.MarketPosition, 0, len(e.positionsCpy))
   150  
   151  	for _, evt := range e.positionsCpy {
   152  		party := evt.Party()
   153  		_, distressed := e.distressedPos[party]
   154  		pos := &types.MarketPosition{
   155  			PartyID:           party,
   156  			Price:             evt.Price(),
   157  			Buy:               evt.Buy(),
   158  			Sell:              evt.Sell(),
   159  			Size:              evt.Size(),
   160  			BuySumProduct:     evt.BuySumProduct(),
   161  			SellSumProduct:    evt.SellSumProduct(),
   162  			Distressed:        distressed,
   163  			AverageEntryPrice: evt.AverageEntryPrice(),
   164  		}
   165  		positions = append(positions, pos)
   166  	}
   167  
   168  	partiesRecordsMap := map[string]*snapshotpb.PartyPositionStats{}
   169  
   170  	// now iterate over both map as some could have been remove
   171  	// when closing positions or being closed out.
   172  	for party, poi := range e.partiesHighestVolume {
   173  		partiesRecordsMap[party] = &snapshotpb.PartyPositionStats{
   174  			Party:              party,
   175  			LowestOpenInterest: ptr.From(poi.Highest),
   176  			LatestOpenInterest: ptr.From(poi.Latest),
   177  		}
   178  	}
   179  
   180  	for party, tradedSize := range e.partiesTradedSize {
   181  		if pr, ok := partiesRecordsMap[party]; ok {
   182  			pr.TradedVolume = ptr.From(tradedSize)
   183  			continue
   184  		}
   185  
   186  		partiesRecordsMap[party] = &snapshotpb.PartyPositionStats{
   187  			Party:        party,
   188  			TradedVolume: ptr.From(tradedSize),
   189  		}
   190  	}
   191  
   192  	partiesRecord := maps.Values(partiesRecordsMap)
   193  	sort.Slice(partiesRecord, func(i, j int) bool {
   194  		return partiesRecord[i].Party < partiesRecord[j].Party
   195  	})
   196  
   197  	e.pl.Data = &types.PayloadMarketPositions{
   198  		MarketPositions: &types.MarketPositions{
   199  			MarketID:      e.marketID,
   200  			Positions:     positions,
   201  			PartieRecords: partiesRecord,
   202  		},
   203  	}
   204  
   205  	var err error
   206  	e.data, err = proto.Marshal(e.pl.IntoProto())
   207  	if err != nil {
   208  		return nil, err
   209  	}
   210  	return e.data, nil
   211  }