github.com/Jeffail/benthos/v3@v3.65.0/lib/input/reader/async_preserver_test.go (about)

     1  package reader
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"reflect"
     7  	"sync"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/Jeffail/benthos/v3/internal/batch"
    12  	"github.com/Jeffail/benthos/v3/lib/message"
    13  	"github.com/Jeffail/benthos/v3/lib/response"
    14  	"github.com/Jeffail/benthos/v3/lib/types"
    15  	"github.com/stretchr/testify/assert"
    16  	"github.com/stretchr/testify/require"
    17  )
    18  
    19  type mockAsyncReader struct {
    20  	msgsToSnd []types.Message
    21  	ackRcvd   []error
    22  
    23  	connChan         chan error
    24  	readChan         chan error
    25  	ackChan          chan error
    26  	closeAsyncChan   chan struct{}
    27  	waitForCloseChan chan error
    28  }
    29  
    30  func newMockAsyncReader() *mockAsyncReader {
    31  	return &mockAsyncReader{
    32  		connChan:         make(chan error),
    33  		readChan:         make(chan error),
    34  		ackChan:          make(chan error),
    35  		closeAsyncChan:   make(chan struct{}),
    36  		waitForCloseChan: make(chan error),
    37  	}
    38  }
    39  
    40  func (r *mockAsyncReader) ConnectWithContext(ctx context.Context) error {
    41  	cerr, open := <-r.connChan
    42  	if !open {
    43  		return types.ErrNotConnected
    44  	}
    45  	return cerr
    46  }
    47  
    48  func (r *mockAsyncReader) ReadWithContext(ctx context.Context) (types.Message, AsyncAckFn, error) {
    49  	select {
    50  	case <-ctx.Done():
    51  		return nil, nil, types.ErrTimeout
    52  	case err, open := <-r.readChan:
    53  		if !open {
    54  			return nil, nil, types.ErrNotConnected
    55  		}
    56  		if err != nil {
    57  			return nil, nil, err
    58  		}
    59  	}
    60  	r.ackRcvd = append(r.ackRcvd, errors.New("ack not received"))
    61  	i := len(r.ackRcvd) - 1
    62  
    63  	var nextMsg types.Message = message.New(nil)
    64  	if len(r.msgsToSnd) > 0 {
    65  		nextMsg = r.msgsToSnd[0]
    66  		r.msgsToSnd = r.msgsToSnd[1:]
    67  	}
    68  
    69  	return nextMsg.DeepCopy(), func(ctx context.Context, res types.Response) error {
    70  		if res.SkipAck() {
    71  			return nil
    72  		}
    73  		r.ackRcvd[i] = res.Error()
    74  		return <-r.ackChan
    75  	}, nil
    76  }
    77  
// CloseAsync blocks until the test sends a value on (or closes)
// closeAsyncChan, letting the test control exactly when a close request is
// accepted.
func (r *mockAsyncReader) CloseAsync() {
	<-r.closeAsyncChan
}
    81  
    82  func (r *mockAsyncReader) WaitForClose(time.Duration) error {
    83  	return <-r.waitForCloseChan
    84  }
    85  
    86  //------------------------------------------------------------------------------
    87  
    88  func TestAsyncPreserverClose(t *testing.T) {
    89  	t.Parallel()
    90  
    91  	ctx, cancel := context.WithTimeout(context.Background(), time.Second*2)
    92  	defer cancel()
    93  
    94  	readerImpl := newMockAsyncReader()
    95  	pres := NewAsyncPreserver(readerImpl)
    96  
    97  	exp := errors.New("foo error")
    98  
    99  	wg := sync.WaitGroup{}
   100  	wg.Add(1)
   101  
   102  	go func() {
   103  		if err := pres.ConnectWithContext(ctx); err != nil {
   104  			t.Error(err)
   105  		}
   106  		pres.CloseAsync()
   107  		if act := pres.WaitForClose(time.Second); act != exp {
   108  			t.Errorf("Wrong error returned: %v != %v", act, exp)
   109  		}
   110  		wg.Done()
   111  	}()
   112  
   113  	select {
   114  	case readerImpl.connChan <- nil:
   115  	case <-time.After(time.Second):
   116  		t.Error("Timed out")
   117  	}
   118  
   119  	select {
   120  	case readerImpl.closeAsyncChan <- struct{}{}:
   121  	case <-time.After(time.Second):
   122  		t.Error("Timed out")
   123  	}
   124  
   125  	select {
   126  	case readerImpl.waitForCloseChan <- exp:
   127  	case <-time.After(time.Second):
   128  		t.Error("Timed out")
   129  	}
   130  
   131  	wg.Wait()
   132  }
   133  
// TestAsyncPreserverNackThenClose scripts this sequence:
//  1. A message is read.
//  2. The wrapped reader reports types.ErrTypeClosed while that message is
//     still outstanding.
//  3. The message is nacked.
//
// The test expects the read in step 2 to return types.ErrTimeout, the next
// read to return the resent message, and types.ErrTypeClosed to surface only
// after the resent message has been acked.
//
// NOTE(review): the ack goroutine below is not joined, so a failing assertion
// inside it could fire after the test has returned.
func TestAsyncPreserverNackThenClose(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), time.Second*2)
	defer cancel()

	readerImpl := newMockAsyncReader()
	readerImpl.msgsToSnd = []types.Message{
		message.New([][]byte{[]byte("hello world")}),
	}
	pres := NewAsyncPreserver(readerImpl)

	wg := sync.WaitGroup{}
	wg.Add(1)

	// Feeder: releases the mock's blocking calls in the exact order the
	// preserver is expected to make them.
	go func() {
		defer wg.Done()

		// Initial connect succeeds.
		select {
		case readerImpl.connChan <- nil:
		case <-ctx.Done():
			t.Error("Timed out")
		}

		// First read delivers "hello world".
		select {
		case readerImpl.readChan <- nil:
		case <-ctx.Done():
			t.Error("Timed out")
		}

		// Second read: the wrapped reader reports that it has closed.
		select {
		case readerImpl.readChan <- types.ErrTypeClosed:
		case <-ctx.Done():
			t.Error("Timed out")
		}

		// Result for the wrapped ack of the original read, reached when the
		// resent message is acked.
		select {
		case readerImpl.ackChan <- nil:
		case <-ctx.Done():
			t.Error("Timed out")
		}

		// A final read again reports that the reader is closed.
		select {
		case readerImpl.readChan <- types.ErrTypeClosed:
		case <-ctx.Done():
			t.Error("Timed out")
		}

		select {
		case readerImpl.closeAsyncChan <- struct{}{}:
		case <-ctx.Done():
			t.Error("Timed out")
		}

		select {
		case readerImpl.waitForCloseChan <- nil:
		case <-ctx.Done():
			t.Error("Timed out")
		}
	}()

	err := pres.ConnectWithContext(ctx)
	assert.NoError(t, err)

	_, ackFn1, err := pres.ReadWithContext(ctx)
	assert.NoError(t, err)

	// Nack the first message shortly after the next read has started.
	go func() {
		time.Sleep(time.Millisecond * 10)
		assert.NoError(t, ackFn1(ctx, response.NewError(errors.New("rejected"))))
	}()

	// The mock reports ErrTypeClosed here, but a message is still outstanding.
	// The mock only returns ErrTimeout on ctx expiry, so this ErrTimeout must
	// come from the preserver itself.
	_, _, err = pres.ReadWithContext(ctx)
	assert.Equal(t, types.ErrTimeout, err)

	// The feeder's next read value is an error, so this successful read can
	// only be the preserver resending the nacked message.
	_, ackFn2, err := pres.ReadWithContext(ctx)
	assert.NoError(t, err)
	assert.NoError(t, ackFn2(ctx, response.NewAck()))

	// Nothing is outstanding any more, so the close is now reported.
	_, _, err = pres.ReadWithContext(ctx)
	assert.Equal(t, types.ErrTypeClosed, err)

	pres.CloseAsync()
	err = pres.WaitForClose(time.Second)
	assert.NoError(t, err)

	wg.Wait()
}
   222  
// TestAsyncPreserverCloseThenAck exercises the case where the wrapped reader
// reports types.ErrTypeClosed while a message is still awaiting its ack. The
// ack arrives ~10ms later, and the test expects types.ErrTypeClosed from that
// read.
//
// NOTE(review): the ordering (close held back until the ack lands) is
// presumably what the preserver implements. This test does not strictly
// enforce it, because CloseAsync is only released after the feeder's ackChan
// step either way.
// NOTE(review): the ack goroutine below is not joined, so a failing assertion
// inside it could fire after the test has returned.
func TestAsyncPreserverCloseThenAck(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), time.Second*2)
	defer cancel()

	readerImpl := newMockAsyncReader()
	readerImpl.msgsToSnd = []types.Message{
		message.New([][]byte{[]byte("hello world")}),
	}
	pres := NewAsyncPreserver(readerImpl)

	wg := sync.WaitGroup{}
	wg.Add(1)

	// Feeder: releases the mock's blocking calls in the expected order.
	go func() {
		defer wg.Done()

		// Initial connect succeeds.
		select {
		case readerImpl.connChan <- nil:
		case <-ctx.Done():
			t.Error("Timed out")
		}

		// First read delivers "hello world".
		select {
		case readerImpl.readChan <- nil:
		case <-ctx.Done():
			t.Error("Timed out")
		}

		// Second read: the wrapped reader reports that it has closed.
		select {
		case readerImpl.readChan <- types.ErrTypeClosed:
		case <-ctx.Done():
			t.Error("Timed out")
		}

		// Result for the wrapped ack of the first message.
		select {
		case readerImpl.ackChan <- nil:
		case <-ctx.Done():
			t.Error("Timed out")
		}

		select {
		case readerImpl.closeAsyncChan <- struct{}{}:
		case <-ctx.Done():
			t.Error("Timed out")
		}

		select {
		case readerImpl.waitForCloseChan <- nil:
		case <-ctx.Done():
			t.Error("Timed out")
		}
	}()

	err := pres.ConnectWithContext(ctx)
	assert.NoError(t, err)

	_, ackFn1, err := pres.ReadWithContext(ctx)
	assert.NoError(t, err)

	// Ack the first message shortly after the next read has started.
	go func() {
		time.Sleep(time.Millisecond * 10)
		assert.NoError(t, ackFn1(ctx, response.NewAck()))
	}()

	_, _, err = pres.ReadWithContext(ctx)
	assert.Equal(t, types.ErrTypeClosed, err)

	pres.CloseAsync()
	err = pres.WaitForClose(time.Second)
	assert.NoError(t, err)

	wg.Wait()
}
   298  
// TestAsyncPreserverCloseThenNackThenAck scripts this sequence:
//  1. The wrapped reader reports types.ErrTypeClosed while a message is
//     outstanding.
//  2. That message is nacked, and the test expects types.ErrTimeout from the
//     pending read.
//  3. The message is resent.
//  4. The resent message is acked while another read is in flight.
//
// That final read is expected to return types.ErrTypeClosed.
//
// NOTE(review): the ack goroutines below are not joined, so a failing
// assertion inside them could fire after the test has returned.
func TestAsyncPreserverCloseThenNackThenAck(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), time.Second*2)
	defer cancel()

	readerImpl := newMockAsyncReader()
	readerImpl.msgsToSnd = []types.Message{
		message.New([][]byte{[]byte("hello world")}),
	}
	pres := NewAsyncPreserver(readerImpl)

	wg := sync.WaitGroup{}
	wg.Add(1)

	// Feeder: releases the mock's blocking calls in the expected order.
	go func() {
		defer wg.Done()

		// Initial connect succeeds.
		select {
		case readerImpl.connChan <- nil:
		case <-ctx.Done():
			t.Error("Timed out")
		}

		// First read delivers "hello world".
		select {
		case readerImpl.readChan <- nil:
		case <-ctx.Done():
			t.Error("Timed out")
		}

		// Second underlying read: closed.
		select {
		case readerImpl.readChan <- types.ErrTypeClosed:
		case <-ctx.Done():
			t.Error("Timed out")
		}

		// Third underlying read (after the resend was handed out): closed.
		select {
		case readerImpl.readChan <- types.ErrTypeClosed:
		case <-ctx.Done():
			t.Error("Timed out")
		}

		// Result for the wrapped ack of the original read, reached when the
		// resent message is acked.
		select {
		case readerImpl.ackChan <- nil:
		case <-ctx.Done():
			t.Error("Timed out")
		}

		select {
		case readerImpl.closeAsyncChan <- struct{}{}:
		case <-ctx.Done():
			t.Error("Timed out")
		}

		select {
		case readerImpl.waitForCloseChan <- nil:
		case <-ctx.Done():
			t.Error("Timed out")
		}
	}()

	err := pres.ConnectWithContext(ctx)
	assert.NoError(t, err)

	_, ackFn1, err := pres.ReadWithContext(ctx)
	assert.NoError(t, err)

	// Nack the first message while the next read is in flight.
	go func() {
		time.Sleep(time.Millisecond * 100)
		assert.NoError(t, ackFn1(ctx, response.NewError(errors.New("huh"))))
	}()

	// The mock only returns ErrTimeout on ctx expiry, so this ErrTimeout comes
	// from the preserver after the nack lands.
	_, _, err = pres.ReadWithContext(ctx)
	assert.Equal(t, types.ErrTimeout, err)

	// The resent message.
	_, ackFn2, err := pres.ReadWithContext(ctx)
	require.NoError(t, err)

	// Ack the resent message while the next read is in flight.
	go func() {
		time.Sleep(time.Millisecond * 100)
		assert.NoError(t, ackFn2(ctx, response.NewAck()))
	}()

	_, _, err = pres.ReadWithContext(ctx)
	assert.Equal(t, types.ErrTypeClosed, err)

	pres.CloseAsync()
	err = pres.WaitForClose(time.Second)
	assert.NoError(t, err)

	wg.Wait()
}
   391  
// TestAsyncPreserverCloseViaConnectThenAck scripts the close arriving from a
// reconnect rather than a read:
//  1. The wrapped reader reports types.ErrNotConnected on a read.
//  2. The reconnect returns types.ErrTypeClosed. The test expects the
//     preserver's ConnectWithContext to succeed anyway.
//  3. Once the outstanding message is acked, the next read is expected to
//     return types.ErrTypeClosed.
//
// No further readChan values are scripted after the reconnect. The final
// ErrTypeClosed therefore presumably has to come from state the preserver
// recorded during ConnectWithContext.
//
// NOTE(review): the ack goroutine below is not joined, so a failing assertion
// inside it could fire after the test has returned.
func TestAsyncPreserverCloseViaConnectThenAck(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), time.Second*2)
	defer cancel()

	readerImpl := newMockAsyncReader()
	readerImpl.msgsToSnd = []types.Message{
		message.New([][]byte{[]byte("hello world")}),
	}
	pres := NewAsyncPreserver(readerImpl)

	wg := sync.WaitGroup{}
	wg.Add(1)

	// Feeder: releases the mock's blocking calls in the expected order.
	go func() {
		defer wg.Done()

		// Initial connect succeeds.
		select {
		case readerImpl.connChan <- nil:
		case <-ctx.Done():
			t.Error("Timed out")
		}

		// First read delivers "hello world".
		select {
		case readerImpl.readChan <- nil:
		case <-ctx.Done():
			t.Error("Timed out")
		}

		// Second read: connection lost.
		select {
		case readerImpl.readChan <- types.ErrNotConnected:
		case <-ctx.Done():
			t.Error("Timed out")
		}

		// Reconnect: the wrapped reader reports that it has closed.
		select {
		case readerImpl.connChan <- types.ErrTypeClosed:
		case <-ctx.Done():
			t.Error("Timed out")
		}

		// Result for the wrapped ack of the first message.
		select {
		case readerImpl.ackChan <- nil:
		case <-ctx.Done():
			t.Error("Timed out")
		}

		select {
		case readerImpl.closeAsyncChan <- struct{}{}:
		case <-ctx.Done():
			t.Error("Timed out")
		}

		select {
		case readerImpl.waitForCloseChan <- nil:
		case <-ctx.Done():
			t.Error("Timed out")
		}
	}()

	err := pres.ConnectWithContext(ctx)
	assert.NoError(t, err)

	_, ackFn1, err := pres.ReadWithContext(ctx)
	assert.NoError(t, err)

	_, _, err = pres.ReadWithContext(ctx)
	assert.Equal(t, types.ErrNotConnected, err)

	// The mock returns ErrTypeClosed here; the preserver is expected to hide it.
	err = pres.ConnectWithContext(ctx)
	assert.NoError(t, err)

	// Ack the outstanding message while the next read is in flight.
	go func() {
		time.Sleep(time.Millisecond * 100)
		assert.NoError(t, ackFn1(ctx, response.NewAck()))
	}()

	_, _, err = pres.ReadWithContext(ctx)
	assert.Equal(t, types.ErrTypeClosed, err)

	pres.CloseAsync()
	err = pres.WaitForClose(time.Second)
	assert.NoError(t, err)

	wg.Wait()
}
   479  
   480  func TestAsyncPreserverHappy(t *testing.T) {
   481  	t.Parallel()
   482  
   483  	ctx, cancel := context.WithTimeout(context.Background(), time.Second*2)
   484  	defer cancel()
   485  
   486  	readerImpl := newMockAsyncReader()
   487  	pres := NewAsyncPreserver(readerImpl)
   488  
   489  	expParts := [][]byte{
   490  		[]byte("foo"),
   491  	}
   492  
   493  	go func() {
   494  		select {
   495  		case readerImpl.connChan <- nil:
   496  		case <-time.After(time.Second):
   497  			t.Error("Timed out")
   498  		}
   499  		for _, p := range expParts {
   500  			readerImpl.msgsToSnd = []types.Message{message.New([][]byte{p})}
   501  			select {
   502  			case readerImpl.readChan <- nil:
   503  			case <-time.After(time.Second):
   504  				t.Error("Timed out")
   505  			}
   506  		}
   507  	}()
   508  
   509  	if err := pres.ConnectWithContext(ctx); err != nil {
   510  		t.Error(err)
   511  	}
   512  
   513  	for _, exp := range expParts {
   514  		msg, _, err := pres.ReadWithContext(ctx)
   515  		if err != nil {
   516  			t.Fatal(err)
   517  		}
   518  		if act := msg.Get(0).Get(); !reflect.DeepEqual(act, exp) {
   519  			t.Errorf("Wrong message returned: %v != %v", act, exp)
   520  		}
   521  	}
   522  }
   523  
   524  func TestAsyncPreserverErrorProp(t *testing.T) {
   525  	t.Parallel()
   526  
   527  	ctx, cancel := context.WithTimeout(context.Background(), time.Second*2)
   528  	defer cancel()
   529  
   530  	readerImpl := newMockAsyncReader()
   531  	pres := NewAsyncPreserver(readerImpl)
   532  
   533  	expErr := errors.New("foo")
   534  
   535  	go func() {
   536  		select {
   537  		case readerImpl.connChan <- expErr:
   538  		case <-time.After(time.Second):
   539  			t.Error("Timed out")
   540  		}
   541  		select {
   542  		case readerImpl.readChan <- expErr:
   543  		case <-time.After(time.Second):
   544  			t.Error("Timed out")
   545  		}
   546  		select {
   547  		case readerImpl.readChan <- nil:
   548  		case <-time.After(time.Second):
   549  			t.Error("Timed out")
   550  		}
   551  		select {
   552  		case readerImpl.ackChan <- expErr:
   553  		case <-time.After(time.Second):
   554  			t.Error("Timed out")
   555  		}
   556  	}()
   557  
   558  	if actErr := pres.ConnectWithContext(ctx); expErr != actErr {
   559  		t.Errorf("Wrong error returned: %v != %v", actErr, expErr)
   560  	}
   561  	if _, _, actErr := pres.ReadWithContext(ctx); expErr != actErr {
   562  		t.Errorf("Wrong error returned: %v != %v", actErr, expErr)
   563  	}
   564  	if _, aFn, actErr := pres.ReadWithContext(ctx); actErr != nil {
   565  		t.Fatal(actErr)
   566  	} else if actErr = aFn(ctx, response.NewAck()); expErr != actErr {
   567  		t.Errorf("Wrong error returned: %v != %v", actErr, expErr)
   568  	}
   569  }
   570  
   571  func TestAsyncPreserverErrorBackoff(t *testing.T) {
   572  	t.Parallel()
   573  
   574  	readerImpl := newMockAsyncReader()
   575  	pres := NewAsyncPreserver(readerImpl)
   576  
   577  	go func() {
   578  		select {
   579  		case readerImpl.connChan <- nil:
   580  		case <-time.After(time.Second):
   581  			t.Error("Timed out")
   582  		}
   583  		select {
   584  		case readerImpl.readChan <- nil:
   585  		case <-time.After(time.Second):
   586  			t.Error("Timed out")
   587  		}
   588  		select {
   589  		case readerImpl.closeAsyncChan <- struct{}{}:
   590  		case <-time.After(time.Second):
   591  			t.Error("Timed out")
   592  		}
   593  		select {
   594  		case readerImpl.waitForCloseChan <- nil:
   595  		case <-time.After(time.Second):
   596  			t.Error("Timed out")
   597  		}
   598  	}()
   599  
   600  	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*500)
   601  	defer cancel()
   602  
   603  	require.NoError(t, pres.ConnectWithContext(ctx))
   604  
   605  	i := 0
   606  	for {
   607  		_, aFn, actErr := pres.ReadWithContext(ctx)
   608  		if actErr != nil {
   609  			assert.EqualError(t, actErr, "context deadline exceeded")
   610  			break
   611  		}
   612  		require.NoError(t, aFn(ctx, response.NewError(errors.New("no thanks"))))
   613  		i++
   614  		if i == 10 {
   615  			t.Error("Expected backoff to prevent this")
   616  			break
   617  		}
   618  	}
   619  
   620  	pres.CloseAsync()
   621  	require.NoError(t, pres.WaitForClose(time.Second))
   622  }
   623  
   624  func TestAsyncPreserverBatchError(t *testing.T) {
   625  	t.Parallel()
   626  
   627  	ctx, cancel := context.WithTimeout(context.Background(), time.Second*2)
   628  	defer cancel()
   629  
   630  	readerImpl := newMockAsyncReader()
   631  	pres := NewAsyncPreserver(readerImpl)
   632  
   633  	go func() {
   634  		select {
   635  		case readerImpl.connChan <- nil:
   636  		case <-time.After(time.Second):
   637  			t.Error("Timed out")
   638  		}
   639  		readerImpl.msgsToSnd = []types.Message{
   640  			message.New([][]byte{
   641  				[]byte("foo"),
   642  				[]byte("bar"),
   643  				[]byte("baz"),
   644  				[]byte("buz"),
   645  				[]byte("bev"),
   646  			})}
   647  		select {
   648  		case readerImpl.readChan <- nil:
   649  		case <-time.After(time.Second):
   650  			t.Error("Timed out")
   651  		}
   652  		select {
   653  		case readerImpl.ackChan <- errors.New("ack propagated"):
   654  		case <-time.After(time.Second):
   655  			t.Error("Timed out")
   656  		}
   657  	}()
   658  
   659  	require.NoError(t, pres.ConnectWithContext(ctx))
   660  
   661  	msg, ackFn, err := pres.ReadWithContext(ctx)
   662  	require.NoError(t, err)
   663  	assert.Equal(t, [][]byte{
   664  		[]byte("foo"),
   665  		[]byte("bar"),
   666  		[]byte("baz"),
   667  		[]byte("buz"),
   668  		[]byte("bev"),
   669  	}, message.GetAllBytes(msg))
   670  
   671  	bErr := batch.NewError(msg, errors.New("first"))
   672  	bErr.Failed(1, errors.New("second"))
   673  	bErr.Failed(3, errors.New("third"))
   674  
   675  	require.NoError(t, ackFn(ctx, response.NewError(bErr)))
   676  
   677  	msg, ackFn, err = pres.ReadWithContext(ctx)
   678  	require.NoError(t, err)
   679  	assert.Equal(t, [][]byte{
   680  		[]byte("bar"),
   681  		[]byte("buz"),
   682  	}, message.GetAllBytes(msg))
   683  
   684  	require.EqualError(t, ackFn(ctx, response.NewAck()), "ack propagated")
   685  }
   686  
   687  func TestAsyncPreserverBatchErrorUnordered(t *testing.T) {
   688  	t.Parallel()
   689  
   690  	ctx, cancel := context.WithTimeout(context.Background(), time.Second*2)
   691  	defer cancel()
   692  
   693  	readerImpl := newMockAsyncReader()
   694  	pres := NewAsyncPreserver(readerImpl)
   695  
   696  	go func() {
   697  		select {
   698  		case readerImpl.connChan <- nil:
   699  		case <-time.After(time.Second):
   700  			t.Error("Timed out")
   701  		}
   702  		readerImpl.msgsToSnd = []types.Message{
   703  			message.New([][]byte{
   704  				[]byte("foo"),
   705  				[]byte("bar"),
   706  				[]byte("baz"),
   707  				[]byte("buz"),
   708  				[]byte("bev"),
   709  			})}
   710  		select {
   711  		case readerImpl.readChan <- nil:
   712  		case <-time.After(time.Second):
   713  			t.Error("Timed out")
   714  		}
   715  		select {
   716  		case readerImpl.ackChan <- errors.New("ack propagated"):
   717  		case <-time.After(time.Second):
   718  			t.Error("Timed out")
   719  		}
   720  	}()
   721  
   722  	require.NoError(t, pres.ConnectWithContext(ctx))
   723  
   724  	msg, ackFn, err := pres.ReadWithContext(ctx)
   725  	require.NoError(t, err)
   726  	assert.Equal(t, [][]byte{
   727  		[]byte("foo"),
   728  		[]byte("bar"),
   729  		[]byte("baz"),
   730  		[]byte("buz"),
   731  		[]byte("bev"),
   732  	}, message.GetAllBytes(msg))
   733  
   734  	bMsg := message.New(nil)
   735  	bMsg.Append(msg.Get(1))
   736  	bMsg.Append(msg.Get(3))
   737  	bMsg.Append(msg.Get(0))
   738  	bMsg.Append(msg.Get(4))
   739  	bMsg.Append(msg.Get(2))
   740  
   741  	bErr := batch.NewError(bMsg, errors.New("first"))
   742  	bErr.Failed(1, errors.New("second"))
   743  	bErr.Failed(2, errors.New("third"))
   744  
   745  	require.NoError(t, ackFn(ctx, response.NewError(bErr)))
   746  
   747  	msg, ackFn, err = pres.ReadWithContext(ctx)
   748  	require.NoError(t, err)
   749  	assert.Equal(t, [][]byte{
   750  		[]byte("buz"),
   751  		[]byte("foo"),
   752  	}, message.GetAllBytes(msg))
   753  
   754  	require.EqualError(t, ackFn(ctx, response.NewAck()), "ack propagated")
   755  }
   756  
   757  //------------------------------------------------------------------------------
   758  
   759  func TestAsyncPreserverBuffer(t *testing.T) {
   760  	t.Parallel()
   761  
   762  	ctx, cancel := context.WithTimeout(context.Background(), time.Second*2)
   763  	defer cancel()
   764  
   765  	readerImpl := newMockAsyncReader()
   766  	pres := NewAsyncPreserver(readerImpl)
   767  
   768  	sendMsg := func(content string) {
   769  		readerImpl.msgsToSnd = []types.Message{message.New(
   770  			[][]byte{[]byte(content)},
   771  		)}
   772  		select {
   773  		case readerImpl.readChan <- nil:
   774  		case <-time.After(time.Second):
   775  			t.Error("Timed out")
   776  		}
   777  	}
   778  	sendAck := func() {
   779  		select {
   780  		case readerImpl.ackChan <- nil:
   781  		case <-time.After(time.Second):
   782  			t.Error("Timed out")
   783  		}
   784  	}
   785  
   786  	// Send message normally.
   787  	exp := "msg 1"
   788  	exp2 := "msg 2"
   789  	exp3 := "msg 3"
   790  
   791  	go sendMsg(exp)
   792  	msg, aFn, err := pres.ReadWithContext(ctx)
   793  	if err != nil {
   794  		t.Fatal(err)
   795  	}
   796  	if act := string(msg.Get(0).Get()); exp != act {
   797  		t.Errorf("Wrong message returned: %v != %v", act, exp)
   798  	}
   799  
   800  	// Prime second message.
   801  	go sendMsg(exp2)
   802  
   803  	// Fail previous message, expecting it to be resent.
   804  	_ = aFn(ctx, response.NewError(errors.New("failed")))
   805  	msg, aFn, err = pres.ReadWithContext(ctx)
   806  	if err != nil {
   807  		t.Fatal(err)
   808  	}
   809  	if act := string(msg.Get(0).Get()); exp != act {
   810  		t.Errorf("Wrong message returned: %v != %v", act, exp)
   811  	}
   812  
   813  	// Read the primed message.
   814  	var aFn2 AsyncAckFn
   815  	msg, aFn2, err = pres.ReadWithContext(ctx)
   816  	if err != nil {
   817  		t.Fatal(err)
   818  	}
   819  	if act := string(msg.Get(0).Get()); exp2 != act {
   820  		t.Errorf("Wrong message returned: %v != %v", act, exp2)
   821  	}
   822  
   823  	// Fail both messages, expecting them to be resent.
   824  	_ = aFn(ctx, response.NewError(errors.New("failed again")))
   825  	_ = aFn2(ctx, response.NewError(errors.New("failed again")))
   826  
   827  	// Read both messages.
   828  	msg, aFn, err = pres.ReadWithContext(ctx)
   829  	if err != nil {
   830  		t.Fatal(err)
   831  	}
   832  	if act := string(msg.Get(0).Get()); exp != act {
   833  		t.Errorf("Wrong message returned: %v != %v", act, exp)
   834  	}
   835  	msg, aFn2, err = pres.ReadWithContext(ctx)
   836  	if err != nil {
   837  		t.Fatal(err)
   838  	}
   839  	if act := string(msg.Get(0).Get()); exp2 != act {
   840  		t.Errorf("Wrong message returned: %v != %v", act, exp2)
   841  	}
   842  
   843  	// Prime a new message and also an acknowledgement.
   844  	go sendMsg(exp3)
   845  	go sendAck()
   846  	go sendAck()
   847  
   848  	// Ack all messages.
   849  	_ = aFn(ctx, response.NewAck())
   850  	_ = aFn2(ctx, response.NewAck())
   851  
   852  	msg, _, err = pres.ReadWithContext(ctx)
   853  	if err != nil {
   854  		t.Fatal(err)
   855  	}
   856  	if act := string(msg.Get(0).Get()); exp3 != act {
   857  		t.Errorf("Wrong message returned: %v != %v", act, exp3)
   858  	}
   859  }
   860  
   861  func TestAsyncPreserverBufferBatchedAcks(t *testing.T) {
   862  	t.Parallel()
   863  
   864  	ctx, cancel := context.WithTimeout(context.Background(), time.Second*2)
   865  	defer cancel()
   866  
   867  	readerImpl := newMockAsyncReader()
   868  	pres := NewAsyncPreserver(readerImpl)
   869  
   870  	sendMsg := func(content string) {
   871  		readerImpl.msgsToSnd = []types.Message{message.New(
   872  			[][]byte{[]byte(content)},
   873  		)}
   874  		select {
   875  		case readerImpl.readChan <- nil:
   876  		case <-time.After(time.Second):
   877  			t.Error("Timed out")
   878  		}
   879  	}
   880  	sendAck := func() {
   881  		select {
   882  		case readerImpl.ackChan <- nil:
   883  		case <-time.After(time.Second):
   884  			t.Error("Timed out")
   885  		}
   886  	}
   887  
   888  	messages := []string{
   889  		"msg 1",
   890  		"msg 2",
   891  		"msg 3",
   892  	}
   893  
   894  	ackFns := []AsyncAckFn{}
   895  	for _, exp := range messages {
   896  		go sendMsg(exp)
   897  		msg, aFn, err := pres.ReadWithContext(ctx)
   898  		if err != nil {
   899  			t.Fatal(err)
   900  		}
   901  		ackFns = append(ackFns, aFn)
   902  		if act := string(msg.Get(0).Get()); exp != act {
   903  			t.Errorf("Wrong message returned: %v != %v", act, exp)
   904  		}
   905  	}
   906  
   907  	// Fail all messages, expecting them to be resent.
   908  	for _, aFn := range ackFns {
   909  		_ = aFn(ctx, response.NewError(errors.New("failed again")))
   910  	}
   911  	ackFns = []AsyncAckFn{}
   912  
   913  	for _, exp := range messages {
   914  		msg, aFn, err := pres.ReadWithContext(ctx)
   915  		if err != nil {
   916  			t.Fatal(err)
   917  		}
   918  		ackFns = append(ackFns, aFn)
   919  		if act := string(msg.Get(0).Get()); exp != act {
   920  			t.Errorf("Wrong message returned: %v != %v", act, exp)
   921  		}
   922  	}
   923  
   924  	// Ack all messages.
   925  	go func() {
   926  		for _, aFn := range ackFns {
   927  			_ = aFn(ctx, response.NewAck())
   928  		}
   929  	}()
   930  
   931  	for range ackFns {
   932  		sendAck()
   933  	}
   934  }