github.com/onflow/flow-go@v0.33.17/engine/access/state_stream/backend/streamer_test.go (about)

     1  package backend_test
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/google/uuid"
    10  	"github.com/stretchr/testify/assert"
    11  	"github.com/stretchr/testify/mock"
    12  
    13  	"github.com/onflow/flow-go/engine"
    14  	"github.com/onflow/flow-go/engine/access/state_stream"
    15  	"github.com/onflow/flow-go/engine/access/state_stream/backend"
    16  	streammock "github.com/onflow/flow-go/engine/access/state_stream/mock"
    17  	"github.com/onflow/flow-go/utils/unittest"
    18  )
    19  
    20  type testData struct {
    21  	data string
    22  	err  error
    23  }
    24  
    25  var testErr = fmt.Errorf("test error")
    26  
    27  func TestStream(t *testing.T) {
    28  	t.Parallel()
    29  
    30  	ctx := context.Background()
    31  	timeout := state_stream.DefaultSendTimeout
    32  
    33  	sub := streammock.NewStreamable(t)
    34  	sub.On("ID").Return(uuid.NewString())
    35  
    36  	tests := []testData{}
    37  	for i := 0; i < 4; i++ {
    38  		tests = append(tests, testData{fmt.Sprintf("test%d", i), nil})
    39  	}
    40  	tests = append(tests, testData{"", testErr})
    41  
    42  	broadcaster := engine.NewBroadcaster()
    43  	streamer := backend.NewStreamer(unittest.Logger(), broadcaster, timeout, state_stream.DefaultResponseLimit, sub)
    44  
    45  	for _, d := range tests {
    46  		sub.On("Next", mock.Anything).Return(d.data, d.err).Once()
    47  		if d.err == nil {
    48  			sub.On("Send", mock.Anything, d.data, timeout).Return(nil).Once()
    49  		} else {
    50  			mocked := sub.On("Fail", mock.Anything).Return().Once()
    51  			mocked.RunFn = func(args mock.Arguments) {
    52  				assert.ErrorIs(t, args.Get(0).(error), d.err)
    53  			}
    54  		}
    55  	}
    56  
    57  	broadcaster.Publish()
    58  
    59  	unittest.RequireReturnsBefore(t, func() {
    60  		streamer.Stream(ctx)
    61  	}, 100*time.Millisecond, "streamer.Stream() should return quickly")
    62  }
    63  
    64  func TestStreamRatelimited(t *testing.T) {
    65  	t.Parallel()
    66  
    67  	ctx := context.Background()
    68  	timeout := state_stream.DefaultSendTimeout
    69  	duration := 100 * time.Millisecond
    70  
    71  	for _, limit := range []float64{0.2, 3, 20, 500} {
    72  		t.Run(fmt.Sprintf("responses are limited - %.1f rps", limit), func(t *testing.T) {
    73  			sub := streammock.NewStreamable(t)
    74  			sub.On("ID").Return(uuid.NewString())
    75  
    76  			broadcaster := engine.NewBroadcaster()
    77  			streamer := backend.NewStreamer(unittest.Logger(), broadcaster, timeout, limit, sub)
    78  
    79  			var nextCalls, sendCalls int
    80  			sub.On("Next", mock.Anything).Return("data", nil).Run(func(args mock.Arguments) {
    81  				nextCalls++
    82  			})
    83  			sub.On("Send", mock.Anything, "data", timeout).Return(nil).Run(func(args mock.Arguments) {
    84  				sendCalls++
    85  			})
    86  
    87  			broadcaster.Publish()
    88  
    89  			unittest.RequireNeverReturnBefore(t, func() {
    90  				streamer.Stream(ctx)
    91  			}, duration, "streamer.Stream() should never stop")
    92  
    93  			// check the number of calls and make sure they are sane.
    94  			// ratelimit uses a token bucket algorithm which adds 1 token every 1/r seconds. This
    95  			// comes to roughly 10% of r within 100ms.
    96  			//
    97  			// Add a large buffer since the algorithm only guarantees the rate over longer time
    98  			// ranges. Since this test covers various orders of magnitude, we can still validate it
    99  			// is working as expected.
   100  			target := int(limit * float64(duration) / float64(time.Second))
   101  			if target == 0 {
   102  				target = 1
   103  			}
   104  
   105  			assert.LessOrEqual(t, nextCalls, target*3)
   106  			assert.LessOrEqual(t, sendCalls, target*3)
   107  		})
   108  	}
   109  }
   110  
   111  // TestLongStreamRatelimited tests that the streamer is uses the correct rate limit over a longer
   112  // period of time
   113  func TestLongStreamRatelimited(t *testing.T) {
   114  	t.Parallel()
   115  
   116  	unittest.SkipUnless(t, unittest.TEST_LONG_RUNNING, "skipping long stream rate limit test")
   117  
   118  	ctx := context.Background()
   119  	timeout := state_stream.DefaultSendTimeout
   120  
   121  	limit := 5.0
   122  	duration := 30 * time.Second
   123  
   124  	sub := streammock.NewStreamable(t)
   125  	sub.On("ID").Return(uuid.NewString())
   126  
   127  	broadcaster := engine.NewBroadcaster()
   128  	streamer := backend.NewStreamer(unittest.Logger(), broadcaster, timeout, limit, sub)
   129  
   130  	var nextCalls, sendCalls int
   131  	sub.On("Next", mock.Anything).Return("data", nil).Run(func(args mock.Arguments) {
   132  		nextCalls++
   133  	})
   134  	sub.On("Send", mock.Anything, "data", timeout).Return(nil).Run(func(args mock.Arguments) {
   135  		sendCalls++
   136  	})
   137  
   138  	broadcaster.Publish()
   139  
   140  	unittest.RequireNeverReturnBefore(t, func() {
   141  		streamer.Stream(ctx)
   142  	}, duration, "streamer.Stream() should never stop")
   143  
   144  	// check the number of calls and make sure they are sane.
   145  	// over a longer time, the rate limit should be more accurate
   146  	target := int(limit) * int(duration/time.Second)
   147  	diff := 5 // 5 ~= 3% of 150 expected
   148  
   149  	assert.LessOrEqual(t, nextCalls, target+diff)
   150  	assert.GreaterOrEqual(t, nextCalls, target-diff)
   151  
   152  	assert.LessOrEqual(t, sendCalls, target+diff)
   153  	assert.GreaterOrEqual(t, sendCalls, target-diff)
   154  }