code.vegaprotocol.io/vega@v0.79.0/core/broker/mocks/broker_drop_in_mock.go (about)

     1  // Copyright (C) 2023 Gobalsky Labs Limited
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    15  
    16  package mocks
    17  
    18  import (
    19  	"sync"
    20  
    21  	"code.vegaprotocol.io/vega/core/events"
    22  
    23  	"github.com/golang/mock/gomock"
    24  )
    25  
    26  // MockBroker - drop in mock that allows us to check the events themselves in unit tests (and as such ensure the state changes are correct)
    27  // We're only overriding the Send and SendBatch functions. The way in which this is done shouldn't be a problem, even when using DoAndReturn, but you never know...
    28  type MockBroker struct {
    29  	// embed the broker mock here... this is how we can end up with a drop-in replacement
    30  	*MockInterface
    31  
    32  	// settlement has a TestConcurrent test, which causes data race on this wrapped mock
    33  	mu *sync.Mutex
    34  	// all events in a map per type
    35  	// the last of each event type
    36  	// and last events for each event type by ID (e.g. latest order event given the order ID)
    37  	allEvts    map[events.Type][]events.Event
    38  	lastEvts   map[events.Type]events.Event
    39  	lastEvtsID map[events.Type]map[string]events.Event
    40  }
    41  
    42  func NewMockBroker(ctrl *gomock.Controller) *MockBroker {
    43  	mbi := NewMockInterface(ctrl)
    44  	return &MockBroker{
    45  		MockInterface: mbi,
    46  		mu:            &sync.Mutex{},
    47  		allEvts:       map[events.Type][]events.Event{},
    48  		lastEvts:      map[events.Type]events.Event{},
    49  		lastEvtsID:    map[events.Type]map[string]events.Event{},
    50  	}
    51  }
    52  
    53  // Send - first call Send on the underlying mock, then add the argument to the various maps.
    54  func (b *MockBroker) Send(event events.Event) {
    55  	// first call the regular mock
    56  	b.MockInterface.Send(event)
    57  	b.mu.Lock()
    58  	t := event.Type()
    59  	s, ok := b.allEvts[t]
    60  	if !ok {
    61  		s = []events.Event{}
    62  	}
    63  	s = append(s, event)
    64  	b.allEvts[t] = s
    65  	b.lastEvts[t] = event
    66  	if ok, id := isIDEvt(event); ok {
    67  		m, ok := b.lastEvtsID[t]
    68  		if !ok {
    69  			m = map[string]events.Event{}
    70  		}
    71  		m[id] = event
    72  		b.lastEvtsID[t] = m
    73  	}
    74  	b.mu.Unlock()
    75  }
    76  
    77  // SendBatch - same as Send: call mock first, then add arguments to the maps.
    78  func (b *MockBroker) SendBatch(evts []events.Event) {
    79  	b.MockInterface.SendBatch(evts)
    80  	if len(evts) == 0 {
    81  		return
    82  	}
    83  	b.mu.Lock()
    84  	first := evts[0]
    85  	t := first.Type()
    86  	s, ok := b.allEvts[t]
    87  	if !ok {
    88  		s = make([]events.Event, 0, cap(evts))
    89  	}
    90  	s = append(s, evts...)
    91  	b.allEvts[t] = s
    92  	last := evts[len(evts)-1]
    93  	// batched events must all be of the same type anyway
    94  	b.lastEvts[t] = last
    95  	if ok, id := isIDEvt(last); ok {
    96  		m, ok := b.lastEvtsID[t]
    97  		if !ok {
    98  			m = map[string]events.Event{}
    99  		}
   100  		m[id] = last
   101  		b.lastEvtsID[t] = m
   102  	}
   103  	b.mu.Unlock()
   104  }
   105  
   106  // GetAllByType returns all events of a given type the mock has received.
   107  func (b *MockBroker) GetAllByType(t events.Type) []events.Event {
   108  	b.mu.Lock()
   109  	allEvts := b.allEvts
   110  	b.mu.Unlock()
   111  	if s, ok := allEvts[t]; ok {
   112  		return s
   113  	}
   114  	return nil
   115  }
   116  
   117  // GetLastByType returns the most recent event for a given type. If SendBatch was called, this is the last event of the batch.
   118  func (b *MockBroker) GetLastByType(t events.Type) events.Event {
   119  	b.mu.Lock()
   120  	defer b.mu.Unlock()
   121  	return b.lastEvts[t]
   122  }
   123  
   124  // GetLastByTypeAndID returns the last event of a given type, for a specific identified (party, market, order, etc...)
   125  // list of implemented events - and ID's used:
   126  //   - Order (by order ID)
   127  //   - Account (by account ID)
   128  //   - Asset (by asset ID)
   129  //   - Auction (by market ID)
   130  //   - Deposit (party ID)
   131  //   - Proposal (proposal ID)
   132  //   - LP (by party ID)
   133  //   - MarginLevels (party ID)
   134  //   - MarketData (market ID)
   135  //   - PosRes (market ID)
   136  //   - RiskFactor (market ID)
   137  //   - SettleDistressed (party ID)
   138  //   - Vote (currently PartyID, might want to use proposalID, too?)
   139  //   - Withdrawal (PartyID)
   140  func (b *MockBroker) GetLastByTypeAndID(t events.Type, id string) events.Event {
   141  	b.mu.Lock()
   142  	m, ok := b.lastEvtsID[t]
   143  	b.mu.Unlock()
   144  	if !ok {
   145  		return nil
   146  	}
   147  	return m[id]
   148  }
   149  
   150  // @TODO loss socialization. Given that this is something that would impact several parties, there's most likely
   151  // no real point to filtering by ID.
   152  // Not implemented yet, but worth considering:
   153  //   - Trade
   154  //   - TransferResponse
   155  //
   156  // Implemented events:
   157  //   - Order (by order ID)
   158  //   - Account (by account ID)
   159  //   - Asset (by asset ID)
   160  //   - Auction (by market ID)
   161  //   - Deposit (party ID)
   162  //   - Proposal (proposal ID)
   163  //   - LP (by party ID)
   164  //   - MarginLevels (party ID)
   165  //   - MarketData (market ID)
   166  //   - PosRes (market ID)
   167  //   - RiskFactor (market ID)
   168  //   - SettleDistressed (party ID)
   169  //   - Vote (currently PartyID, might want to use proposalID, too?)
   170  //   - Withdrawal (PartyID)
   171  func isIDEvt(e events.Event) (bool, string) {
   172  	switch et := e.(type) {
   173  	case *events.Order:
   174  		return true, et.Order().Id
   175  	case events.Order:
   176  		return true, et.Order().Id
   177  	case *events.Acc:
   178  		return true, et.Account().Id
   179  	case events.Acc:
   180  		return true, et.Account().Id
   181  	case *events.Asset:
   182  		return true, et.Asset().Id
   183  	case events.Asset:
   184  		return true, et.Asset().Id
   185  	case *events.Auction:
   186  		return true, et.MarketID()
   187  	case events.Auction:
   188  		return true, et.MarketID()
   189  	case *events.Deposit:
   190  		return true, et.Deposit().PartyId
   191  	case events.Deposit:
   192  		return true, et.Deposit().PartyId
   193  	case *events.Proposal:
   194  		return true, et.ProposalID()
   195  	case events.Proposal:
   196  		return true, et.ProposalID()
   197  	case *events.LiquidityProvision:
   198  		return true, et.PartyID()
   199  	case events.LiquidityProvision:
   200  		return true, et.PartyID()
   201  	case *events.MarginLevels:
   202  		return true, et.PartyID()
   203  	case events.MarginLevels:
   204  		return true, et.PartyID()
   205  	case *events.MarketData:
   206  		return true, et.MarketID()
   207  	case events.MarketData:
   208  		return true, et.MarketID()
   209  	case *events.PosRes:
   210  		return true, et.MarketID()
   211  	case events.PosRes:
   212  		return true, et.MarketID()
   213  	case *events.RiskFactor:
   214  		return true, et.MarketID()
   215  	case events.RiskFactor:
   216  		return true, et.MarketID()
   217  	case *events.SettleDistressed:
   218  		return true, et.PartyID()
   219  	case events.SettleDistressed:
   220  		return true, et.PartyID()
   221  	case *events.Vote:
   222  		return true, et.PartyID()
   223  	case events.Vote:
   224  		return true, et.PartyID()
   225  	case *events.Withdrawal:
   226  		return true, et.PartyID()
   227  	case events.Withdrawal:
   228  		return true, et.PartyID()
   229  	}
   230  	return false, ""
   231  }