github.com/onflow/flow-go@v0.35.7-crescendo-preview.23-atree-inlining/engine/access/subscription/streamer_test.go (about)

     1  package subscription_test
     2  
import (
	"context"
	"fmt"
	"sync/atomic"
	"testing"
	"time"

	"github.com/google/uuid"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"

	"github.com/onflow/flow-go/engine"
	streammock "github.com/onflow/flow-go/engine/access/state_stream/mock"
	"github.com/onflow/flow-go/engine/access/subscription"
	"github.com/onflow/flow-go/utils/unittest"
)
    18  
// testData is one scripted response for the mocked Streamable used in TestStream.
// Field order matters: values are built with positional composite literals.
type testData struct {
	data string // value returned by Next and, when err is nil, expected as the Send payload
	err  error  // when non-nil, returned by Next and expected to be passed to Fail
}
    23  
    24  var testErr = fmt.Errorf("test error")
    25  
    26  func TestStream(t *testing.T) {
    27  	t.Parallel()
    28  
    29  	ctx := context.Background()
    30  	timeout := subscription.DefaultSendTimeout
    31  
    32  	sub := streammock.NewStreamable(t)
    33  	sub.On("ID").Return(uuid.NewString())
    34  
    35  	tests := []testData{}
    36  	for i := 0; i < 4; i++ {
    37  		tests = append(tests, testData{fmt.Sprintf("test%d", i), nil})
    38  	}
    39  	tests = append(tests, testData{"", testErr})
    40  
    41  	broadcaster := engine.NewBroadcaster()
    42  	streamer := subscription.NewStreamer(unittest.Logger(), broadcaster, timeout, subscription.DefaultResponseLimit, sub)
    43  
    44  	for _, d := range tests {
    45  		sub.On("Next", mock.Anything).Return(d.data, d.err).Once()
    46  		if d.err == nil {
    47  			sub.On("Send", mock.Anything, d.data, timeout).Return(nil).Once()
    48  		} else {
    49  			mocked := sub.On("Fail", mock.Anything).Return().Once()
    50  			mocked.RunFn = func(args mock.Arguments) {
    51  				assert.ErrorIs(t, args.Get(0).(error), d.err)
    52  			}
    53  		}
    54  	}
    55  
    56  	broadcaster.Publish()
    57  
    58  	unittest.RequireReturnsBefore(t, func() {
    59  		streamer.Stream(ctx)
    60  	}, 100*time.Millisecond, "streamer.Stream() should return quickly")
    61  }
    62  
    63  func TestStreamRatelimited(t *testing.T) {
    64  	t.Parallel()
    65  
    66  	ctx := context.Background()
    67  	timeout := subscription.DefaultSendTimeout
    68  	duration := 100 * time.Millisecond
    69  
    70  	for _, limit := range []float64{0.2, 3, 20, 500} {
    71  		t.Run(fmt.Sprintf("responses are limited - %.1f rps", limit), func(t *testing.T) {
    72  			sub := streammock.NewStreamable(t)
    73  			sub.On("ID").Return(uuid.NewString())
    74  
    75  			broadcaster := engine.NewBroadcaster()
    76  			streamer := subscription.NewStreamer(unittest.Logger(), broadcaster, timeout, limit, sub)
    77  
    78  			var nextCalls, sendCalls int
    79  			sub.On("Next", mock.Anything).Return("data", nil).Run(func(args mock.Arguments) {
    80  				nextCalls++
    81  			})
    82  			sub.On("Send", mock.Anything, "data", timeout).Return(nil).Run(func(args mock.Arguments) {
    83  				sendCalls++
    84  			})
    85  
    86  			broadcaster.Publish()
    87  
    88  			unittest.RequireNeverReturnBefore(t, func() {
    89  				streamer.Stream(ctx)
    90  			}, duration, "streamer.Stream() should never stop")
    91  
    92  			// check the number of calls and make sure they are sane.
    93  			// ratelimit uses a token bucket algorithm which adds 1 token every 1/r seconds. This
    94  			// comes to roughly 10% of r within 100ms.
    95  			//
    96  			// Add a large buffer since the algorithm only guarantees the rate over longer time
    97  			// ranges. Since this test covers various orders of magnitude, we can still validate it
    98  			// is working as expected.
    99  			target := int(limit * float64(duration) / float64(time.Second))
   100  			if target == 0 {
   101  				target = 1
   102  			}
   103  
   104  			assert.LessOrEqual(t, nextCalls, target*3)
   105  			assert.LessOrEqual(t, sendCalls, target*3)
   106  		})
   107  	}
   108  }
   109  
   110  // TestLongStreamRatelimited tests that the streamer is uses the correct rate limit over a longer
   111  // period of time
   112  func TestLongStreamRatelimited(t *testing.T) {
   113  	t.Parallel()
   114  
   115  	unittest.SkipUnless(t, unittest.TEST_LONG_RUNNING, "skipping long stream rate limit test")
   116  
   117  	ctx := context.Background()
   118  	timeout := subscription.DefaultSendTimeout
   119  
   120  	limit := 5.0
   121  	duration := 30 * time.Second
   122  
   123  	sub := streammock.NewStreamable(t)
   124  	sub.On("ID").Return(uuid.NewString())
   125  
   126  	broadcaster := engine.NewBroadcaster()
   127  	streamer := subscription.NewStreamer(unittest.Logger(), broadcaster, timeout, limit, sub)
   128  
   129  	var nextCalls, sendCalls int
   130  	sub.On("Next", mock.Anything).Return("data", nil).Run(func(args mock.Arguments) {
   131  		nextCalls++
   132  	})
   133  	sub.On("Send", mock.Anything, "data", timeout).Return(nil).Run(func(args mock.Arguments) {
   134  		sendCalls++
   135  	})
   136  
   137  	broadcaster.Publish()
   138  
   139  	unittest.RequireNeverReturnBefore(t, func() {
   140  		streamer.Stream(ctx)
   141  	}, duration, "streamer.Stream() should never stop")
   142  
   143  	// check the number of calls and make sure they are sane.
   144  	// over a longer time, the rate limit should be more accurate
   145  	target := int(limit) * int(duration/time.Second)
   146  	diff := 5 // 5 ~= 3% of 150 expected
   147  
   148  	assert.LessOrEqual(t, nextCalls, target+diff)
   149  	assert.GreaterOrEqual(t, nextCalls, target-diff)
   150  
   151  	assert.LessOrEqual(t, sendCalls, target+diff)
   152  	assert.GreaterOrEqual(t, sendCalls, target-diff)
   153  }