github.com/Jeffail/benthos/v3@v3.65.0/lib/processor/rate_limit_test.go (about)

     1  package processor
     2  
     3  import (
     4  	"errors"
     5  	"reflect"
     6  	"sync/atomic"
     7  	"testing"
     8  	"time"
     9  
    10  	"github.com/Jeffail/benthos/v3/lib/log"
    11  	"github.com/Jeffail/benthos/v3/lib/message"
    12  	"github.com/Jeffail/benthos/v3/lib/metrics"
    13  	"github.com/Jeffail/benthos/v3/lib/types"
    14  )
    15  
    16  type fakeRateLimit struct {
    17  	resFn func() (time.Duration, error)
    18  }
    19  
    20  func (f fakeRateLimit) Access() (time.Duration, error) {
    21  	return f.resFn()
    22  }
    23  func (f fakeRateLimit) CloseAsync()                      {}
    24  func (f fakeRateLimit) WaitForClose(time.Duration) error { return nil }
    25  
    26  func TestRateLimitBasic(t *testing.T) {
    27  	var hits int32
    28  	rlFn := func() (time.Duration, error) {
    29  		atomic.AddInt32(&hits, 1)
    30  		return 0, nil
    31  	}
    32  
    33  	mgr := &fakeMgr{
    34  		ratelimits: map[string]types.RateLimit{
    35  			"foo": fakeRateLimit{resFn: rlFn},
    36  		},
    37  	}
    38  
    39  	conf := NewConfig()
    40  	conf.RateLimit.Resource = "foo"
    41  	proc, err := NewRateLimit(conf, mgr, log.Noop(), metrics.Noop())
    42  	if err != nil {
    43  		t.Fatal(err)
    44  	}
    45  
    46  	input := message.New([][]byte{
    47  		[]byte(`{"key":"1","value":"foo 1"}`),
    48  		[]byte(`{"key":"2","value":"foo 2"}`),
    49  		[]byte(`{"key":"1","value":"foo 3"}`),
    50  	})
    51  
    52  	output, res := proc.ProcessMessage(input)
    53  	if res != nil {
    54  		t.Fatal(res.Error())
    55  	}
    56  
    57  	if len(output) != 1 {
    58  		t.Fatalf("Wrong count of result messages: %v", len(output))
    59  	}
    60  
    61  	if exp, act := message.GetAllBytes(input), message.GetAllBytes(output[0]); !reflect.DeepEqual(exp, act) {
    62  		t.Errorf("Wrong result messages: %s != %s", act, exp)
    63  	}
    64  
    65  	if exp, act := int32(3), atomic.LoadInt32(&hits); exp != act {
    66  		t.Errorf("Wrong count of rate limit hits: %v != %v", act, exp)
    67  	}
    68  }
    69  
    70  func TestRateLimitClosed(t *testing.T) {
    71  	var hits int32
    72  	rlFn := func() (time.Duration, error) {
    73  		if i := atomic.AddInt32(&hits, 1); i == 2 {
    74  			return 0, types.ErrTypeClosed
    75  		}
    76  		return 0, nil
    77  	}
    78  
    79  	mgr := &fakeMgr{
    80  		ratelimits: map[string]types.RateLimit{
    81  			"foo": fakeRateLimit{resFn: rlFn},
    82  		},
    83  	}
    84  
    85  	conf := NewConfig()
    86  	conf.RateLimit.Resource = "foo"
    87  	proc, err := NewRateLimit(conf, mgr, log.Noop(), metrics.Noop())
    88  	if err != nil {
    89  		t.Fatal(err)
    90  	}
    91  
    92  	input := message.New([][]byte{
    93  		[]byte(`{"key":"1","value":"foo 1"}`),
    94  		[]byte(`{"key":"2","value":"foo 2"}`),
    95  		[]byte(`{"key":"1","value":"foo 3"}`),
    96  	})
    97  
    98  	output, res := proc.ProcessMessage(input)
    99  	if res != nil {
   100  		t.Fatal(res.Error())
   101  	}
   102  
   103  	if len(output) != 1 {
   104  		t.Fatalf("Wrong count of result messages: %v", len(output))
   105  	}
   106  
   107  	if exp, act := message.GetAllBytes(input), message.GetAllBytes(output[0]); !reflect.DeepEqual(exp, act) {
   108  		t.Errorf("Wrong result messages: %s != %s", act, exp)
   109  	}
   110  
   111  	if exp, act := int32(2), atomic.LoadInt32(&hits); exp != act {
   112  		t.Errorf("Wrong count of rate limit hits: %v != %v", act, exp)
   113  	}
   114  }
   115  
   116  func TestRateLimitErroredOut(t *testing.T) {
   117  	rlFn := func() (time.Duration, error) {
   118  		return 0, errors.New("omg foo")
   119  	}
   120  
   121  	mgr := &fakeMgr{
   122  		ratelimits: map[string]types.RateLimit{
   123  			"foo": fakeRateLimit{resFn: rlFn},
   124  		},
   125  	}
   126  
   127  	conf := NewConfig()
   128  	conf.RateLimit.Resource = "foo"
   129  	proc, err := NewRateLimit(conf, mgr, log.Noop(), metrics.Noop())
   130  	if err != nil {
   131  		t.Fatal(err)
   132  	}
   133  
   134  	input := message.New([][]byte{
   135  		[]byte(`{"key":"1","value":"foo 1"}`),
   136  		[]byte(`{"key":"2","value":"foo 2"}`),
   137  		[]byte(`{"key":"1","value":"foo 3"}`),
   138  	})
   139  
   140  	closedChan := make(chan struct{})
   141  	go func() {
   142  		output, res := proc.ProcessMessage(input)
   143  		if res != nil {
   144  			t.Error(res.Error())
   145  		}
   146  
   147  		if len(output) != 1 {
   148  			t.Errorf("Wrong count of result messages: %v", len(output))
   149  		}
   150  
   151  		if exp, act := message.GetAllBytes(input), message.GetAllBytes(output[0]); !reflect.DeepEqual(exp, act) {
   152  			t.Errorf("Wrong result messages: %s != %s", act, exp)
   153  		}
   154  		close(closedChan)
   155  	}()
   156  
   157  	proc.CloseAsync()
   158  	select {
   159  	case <-closedChan:
   160  	case <-time.After(time.Second):
   161  		t.Error("Timed out")
   162  	}
   163  }
   164  
   165  func TestRateLimitBlocked(t *testing.T) {
   166  	rlFn := func() (time.Duration, error) {
   167  		return time.Second * 10, nil
   168  	}
   169  
   170  	mgr := &fakeMgr{
   171  		ratelimits: map[string]types.RateLimit{
   172  			"foo": fakeRateLimit{resFn: rlFn},
   173  		},
   174  	}
   175  
   176  	conf := NewConfig()
   177  	conf.RateLimit.Resource = "foo"
   178  	proc, err := NewRateLimit(conf, mgr, log.Noop(), metrics.Noop())
   179  	if err != nil {
   180  		t.Fatal(err)
   181  	}
   182  
   183  	input := message.New([][]byte{
   184  		[]byte(`{"key":"1","value":"foo 1"}`),
   185  		[]byte(`{"key":"2","value":"foo 2"}`),
   186  		[]byte(`{"key":"1","value":"foo 3"}`),
   187  	})
   188  
   189  	closedChan := make(chan struct{})
   190  	go func() {
   191  		output, res := proc.ProcessMessage(input)
   192  		if res != nil {
   193  			t.Error(res.Error())
   194  		}
   195  
   196  		if len(output) != 1 {
   197  			t.Errorf("Wrong count of result messages: %v", len(output))
   198  		}
   199  
   200  		if exp, act := message.GetAllBytes(input), message.GetAllBytes(output[0]); !reflect.DeepEqual(exp, act) {
   201  			t.Errorf("Wrong result messages: %s != %s", act, exp)
   202  		}
   203  		close(closedChan)
   204  	}()
   205  
   206  	proc.CloseAsync()
   207  	select {
   208  	case <-closedChan:
   209  	case <-time.After(time.Second):
   210  		t.Error("Timed out")
   211  	}
   212  }