code.vegaprotocol.io/vega@v0.79.0/libs/subscribers/stream_subscriber_test.go (about)

     1  // Copyright (C) 2023 Gobalsky Labs Limited
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    15  
    16  //go:build !race
    17  // +build !race
    18  
    19  package subscribers_test
    20  
    21  import (
    22  	"context"
    23  	"sync"
    24  	"testing"
    25  
    26  	"code.vegaprotocol.io/vega/core/events"
    27  	dtypes "code.vegaprotocol.io/vega/core/types"
    28  	"code.vegaprotocol.io/vega/libs/subscribers"
    29  	types "code.vegaprotocol.io/vega/protos/vega"
    30  	eventspb "code.vegaprotocol.io/vega/protos/vega/events/v1"
    31  
    32  	"github.com/stretchr/testify/assert"
    33  )
    34  
    35  type tstStreamSub struct {
    36  	*subscribers.StreamSub
    37  	ctx   context.Context
    38  	cfunc context.CancelFunc
    39  }
    40  
    41  type accEvt interface {
    42  	events.Event
    43  	Account() types.Account
    44  }
    45  
    46  func getTestStreamSub(types []events.Type, bufSize int, filters ...subscribers.EventFilter) *tstStreamSub {
    47  	ctx, cfunc := context.WithCancel(context.Background())
    48  	return &tstStreamSub{
    49  		StreamSub: subscribers.NewStreamSub(ctx, types, bufSize, filters...),
    50  		ctx:       ctx,
    51  		cfunc:     cfunc,
    52  	}
    53  }
    54  
    55  func accMarketIDFilter(mID string) subscribers.EventFilter {
    56  	return func(e events.Event) bool {
    57  		ae, ok := e.(accEvt)
    58  		if !ok {
    59  			return false
    60  		}
    61  		if ae.Account().MarketId != mID {
    62  			return false
    63  		}
    64  		return true
    65  	}
    66  }
    67  
    68  func TestUnfilteredSubscription(t *testing.T) {
    69  	t.Run("Stream subscriber without filters, no events", testUnfilteredNoEvents)
    70  	t.Run("Stream subscriber without filters - with events", testUnfilteredWithEventsPush)
    71  }
    72  
    73  func TestFilteredSubscription(t *testing.T) {
    74  	t.Run("Stream subscriber with filter - no valid events", testFilteredNoValidEvents)
    75  	t.Run("Stream subscriber with filter - some valid events", testFilteredSomeValidEvents)
    76  }
    77  
    78  func TestSubscriberTypes(t *testing.T) {
    79  	t.Run("Stream subscriber for all event types", testFilterAll)
    80  }
    81  
    82  func TestSubscriberBuffered(t *testing.T) {
    83  	t.Run("Batched stream subscriber", testBatchedStreamSubscriber)
    84  }
    85  
    86  func TestMidChannelDone(t *testing.T) {
    87  	t.Run("Stream subscriber stops mid event stream", testCloseChannelWrite)
    88  }
    89  
    90  func testUnfilteredNoEvents(t *testing.T) {
    91  	sub := getTestStreamSub([]events.Type{events.AccountEvent}, 0)
    92  	wg := sync.WaitGroup{}
    93  	wg.Add(1)
    94  	var data []*eventspb.BusEvent
    95  	go func() {
    96  		data = sub.GetData(context.Background())
    97  		wg.Done()
    98  	}()
    99  	sub.cfunc() // cancel ctx
   100  	wg.Wait()
   101  	// we expect to see no events
   102  	assert.Equal(t, 0, len(data))
   103  }
   104  
   105  func testUnfilteredWithEventsPush(t *testing.T) {
   106  	sub := getTestStreamSub([]events.Type{events.AccountEvent}, 0)
   107  	defer sub.cfunc()
   108  	set := []events.Event{
   109  		events.NewAccountEvent(sub.ctx, dtypes.Account{
   110  			ID: "acc-1",
   111  		}),
   112  		events.NewAccountEvent(sub.ctx, dtypes.Account{
   113  			ID: "acc-2",
   114  		}),
   115  	}
   116  
   117  	data := []*eventspb.BusEvent{}
   118  	done := make(chan struct{})
   119  	getData := func() {
   120  		done <- struct{}{}
   121  		data = sub.GetData(context.Background())
   122  		done <- struct{}{}
   123  	}
   124  
   125  	go getData()
   126  
   127  	<-done
   128  	sub.Push(set...)
   129  	<-done
   130  	// we expect to see no events
   131  	assert.Equal(t, len(set), len(data))
   132  	last := events.NewAccountEvent(sub.ctx, dtypes.Account{
   133  		ID: "acc-3",
   134  	})
   135  
   136  	go getData()
   137  
   138  	<-done
   139  	sub.Push(last)
   140  	<-done
   141  	assert.Equal(t, 1, len(data))
   142  	rt, err := events.ProtoToInternal(data[0].Type)
   143  	assert.NoError(t, err)
   144  	assert.Equal(t, 1, len(rt))
   145  	assert.Equal(t, events.AccountEvent, rt[0])
   146  	acc := data[0].GetAccount()
   147  	assert.NotNil(t, acc)
   148  	assert.Equal(t, last.Account().Id, acc.Id)
   149  }
   150  
   151  func testFilteredNoValidEvents(t *testing.T) {
   152  	sub := getTestStreamSub([]events.Type{events.AccountEvent}, 0, accMarketIDFilter("valid"))
   153  	set := []events.Event{
   154  		events.NewAccountEvent(sub.ctx, dtypes.Account{
   155  			ID:       "acc-1",
   156  			MarketID: "invalid",
   157  		}),
   158  		events.NewAccountEvent(sub.ctx, dtypes.Account{
   159  			ID:       "acc-2",
   160  			MarketID: "also-invalid",
   161  		}),
   162  	}
   163  	sub.Push(set...)
   164  	wg := sync.WaitGroup{}
   165  	wg.Add(1)
   166  	var data []*eventspb.BusEvent
   167  	go func() {
   168  		data = sub.GetData(context.Background())
   169  		wg.Done()
   170  	}()
   171  	sub.cfunc()
   172  	wg.Wait()
   173  	// we expect to see no events
   174  	assert.Equal(t, 0, len(data))
   175  }
   176  
   177  func testFilteredSomeValidEvents(t *testing.T) {
   178  	sub := getTestStreamSub([]events.Type{events.AccountEvent}, 0, accMarketIDFilter("valid"))
   179  	defer sub.cfunc()
   180  	set := []events.Event{
   181  		events.NewAccountEvent(sub.ctx, dtypes.Account{
   182  			ID:       "acc-1",
   183  			MarketID: "invalid",
   184  		}),
   185  		events.NewAccountEvent(sub.ctx, dtypes.Account{
   186  			ID:       "acc-2",
   187  			MarketID: "valid",
   188  		}),
   189  	}
   190  
   191  	data := []*eventspb.BusEvent{}
   192  	done := make(chan struct{})
   193  	getData := func() {
   194  		done <- struct{}{}
   195  		data = sub.GetData(context.Background())
   196  		done <- struct{}{}
   197  	}
   198  	go getData()
   199  
   200  	<-done
   201  	sub.Push(set...)
   202  	<-done
   203  	// we expect to see no events
   204  	assert.Equal(t, 1, len(data))
   205  }
   206  
   207  func testFilterAll(t *testing.T) {
   208  	sub := getTestStreamSub([]events.Type{events.All}, 0)
   209  	assert.Nil(t, sub.Types())
   210  }
   211  
// testBatchedStreamSubscriber exercises a StreamSub created with a batch size
// of 5. The test expects:
//   - GetData to keep waiting until 5 events have been collected;
//   - UpdateBatchSize to hand back whatever is left over;
//   - once the batch size equals the size of the sent set, GetData to return
//     exactly one set's worth of events.
//
// Events are written directly into sub.C() (not via Push) by sendRoutine,
// which closes its channel argument once the write has been accepted.
func testBatchedStreamSubscriber(t *testing.T) {
	mID := "market-id"
	sub := getTestStreamSub([]events.Type{events.All}, 5)
	defer sub.cfunc()
	// sent is closed by sendRoutine after each batch is accepted. rec is first
	// used as a "receiver is about to call GetData" signal, then closed once
	// GetData returns.
	sent, rec := make(chan struct{}), make(chan struct{})
	set1 := []events.Event{
		events.NewAccountEvent(sub.ctx, dtypes.Account{
			ID:       "acc1",
			MarketID: mID,
		}),
		events.NewAccountEvent(sub.ctx, dtypes.Account{
			ID:       "acc2",
			MarketID: mID,
		}),
		events.NewAccountEvent(sub.ctx, dtypes.Account{
			ID:       "acc50",
			MarketID: "other-market",
		}),
	}
	sendRoutine := func(ch chan struct{}, sub *tstStreamSub, set []events.Event) {
		sub.C() <- set
		close(ch)
	}

	var data []*eventspb.BusEvent
	// start the receiver; the close(rec) after GetData returns is what makes
	// reading data safe in the main goroutine after <-rec
	go func() {
		rec <- struct{}{}
		data = sub.GetData(context.Background())
		close(rec)
	}()
	<-rec

	go sendRoutine(sent, sub, set1)
	// ensure the first batch (3 events) was accepted
	<-sent
	// 3 events are fewer than the batch size of 5, so the receiver is expected
	// to still be waiting in GetData. Send a second batch, which ought to take
	// the collected total past the batch size.
	sent = make(chan struct{})
	go sendRoutine(sent, sub, set1)
	<-rec
	// batch size reached: GetData is expected to have returned exactly 5 events
	assert.Equal(t, 5, len(data))
	// 6 events have now been sent and 5 returned, leaving 1 behind; changing the
	// batch size ought to hand back that 1 remaining event
	<-sent
	data = sub.UpdateBatchSize(sub.ctx, len(set1)) // set batch size to match test-data set
	assert.Equal(t, 1, len(data))                  // we should have drained the buffer
	sent = make(chan struct{})
	go sendRoutine(sent, sub, set1)
	<-sent
	// no receiver goroutine needed: the batch size is now 3 and 3 events were
	// sent, so GetData can be called directly from the test goroutine
	data = sub.GetData(context.Background())
	assert.Equal(t, 3, len(data))
	// same check again, this time calling GetData from a goroutine and waiting
	// on rec; it ought to produce exactly the same result
	sent = make(chan struct{})
	go sendRoutine(sent, sub, set1)
	<-sent
	rec = make(chan struct{})
	// batch size is 3 and 3 events were sent, so GetData ought to return
	go func() {
		data = sub.GetData(context.Background())
		close(rec)
	}()
	<-rec
	assert.Equal(t, 3, len(data))
}
   277  
// testCloseChannelWrite aims to replicate the crash that occurred when writing
// to a closed channel. A sender goroutine keeps writing the same batch into
// sub.C() while the test cancels the subscriber's context underneath it.
func testCloseChannelWrite(t *testing.T) {
	mID := "tstMarket"
	sub := getTestStreamSub([]events.Type{events.AccountEvent}, 0, accMarketIDFilter(mID))
	set := []events.Event{
		events.NewAccountEvent(sub.ctx, dtypes.Account{
			ID:       "acc1",
			MarketID: mID,
		}),
		events.NewAccountEvent(sub.ctx, dtypes.Account{
			ID:       "acc2",
			MarketID: mID,
		}),
		events.NewAccountEvent(sub.ctx, dtypes.Account{
			ID:       "acc50",
			MarketID: "other-market",
		}),
		events.NewAccountEvent(sub.ctx, dtypes.Account{
			ID:       "acc3",
			MarketID: mID,
		}),
		events.NewAccountEvent(sub.ctx, dtypes.Account{
			ID:       "acc4",
			MarketID: mID,
		}),
		events.NewAccountEvent(sub.ctx, dtypes.Account{
			ID:       "acc51",
			MarketID: "other-market",
		}),
		events.NewAccountEvent(sub.ctx, dtypes.Account{
			ID:       "acc5",
			MarketID: "other-market",
		}),
		events.NewAccountEvent(sub.ctx, dtypes.Account{
			ID:       "acc6",
			MarketID: mID,
		}),
		events.NewAccountEvent(sub.ctx, dtypes.Account{
			ID:       "acc7",
			MarketID: mID,
		}),
	}
	// started is closed once the first batch has been accepted by sub.C()
	started := make(chan struct{})
	wg := sync.WaitGroup{}
	wg.Add(1)
	go func() {
		first := false
		defer wg.Done()
		// keep sending the same batch until the subscriber's Closed() or Skip()
		// channel fires, so the context ends up being cancelled mid-send
		for {
			select {
			case <-sub.Closed():
				return
			case <-sub.Skip():
				return
			case sub.C() <- set:
				if !first {
					first = true
					close(started)
				}
			}
		}
	}()
	// wait until at least one batch has been handed to the subscriber
	<-started
	// collect what the subscriber has so far, then cancel its context while
	// the sender may still be writing, and wait for the sender goroutine to exit
	data := sub.GetData(sub.ctx)
	sub.cfunc()
	wg.Wait()
	// at least the first batch was sent, and it contains events on mID that
	// pass the filter, so this slice ought not to be empty
	assert.NotEmpty(t, data)
}