github.com/m3db/m3@v1.5.0/src/dbnode/client/host_queue_fetch_batch_test.go (about)

     1  // Copyright (c) 2018 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package client
    22  
    23  import (
    24  	"fmt"
    25  	"sync"
    26  	"testing"
    27  
    28  	"github.com/m3db/m3/src/dbnode/generated/thrift/rpc"
    29  
    30  	"github.com/golang/mock/gomock"
    31  	"github.com/stretchr/testify/assert"
    32  	"github.com/uber/tchannel-go/thrift"
    33  )
    34  
    35  func TestHostQueueFetchBatches(t *testing.T) {
    36  	namespace := "testNs"
    37  	ids := []string{"foo", "bar", "baz", "qux"}
    38  	result := &rpc.FetchBatchRawResult_{}
    39  	for range ids {
    40  		result.Elements = append(result.Elements, &rpc.FetchRawResult_{Segments: []*rpc.Segments{}})
    41  	}
    42  	var expected []hostQueueResult
    43  	for i := range ids {
    44  		expected = append(expected, hostQueueResult{result.Elements[i].Segments, nil})
    45  	}
    46  	testHostQueueFetchBatches(t, namespace, ids, result, expected, nil, func(results []hostQueueResult) {
    47  		assert.Equal(t, expected, results)
    48  	})
    49  }
    50  
    51  func TestHostQueueFetchBatchesV2MultiNS(t *testing.T) {
    52  	ids := []string{"foo", "bar", "baz", "qux"}
    53  	result := &rpc.FetchBatchRawResult_{}
    54  	for range ids {
    55  		result.Elements = append(result.Elements, &rpc.FetchRawResult_{Segments: []*rpc.Segments{}})
    56  	}
    57  	var expected []hostQueueResult
    58  	for i := range ids {
    59  		expected = append(expected, hostQueueResult{result.Elements[i].Segments, nil})
    60  	}
    61  	opts := newHostQueueTestOptions().SetUseV2BatchAPIs(true)
    62  	ctrl := gomock.NewController(t)
    63  	defer ctrl.Finish()
    64  
    65  	mockConnPool := NewMockconnectionPool(ctrl)
    66  
    67  	queue := newTestHostQueue(opts)
    68  	queue.connPool = mockConnPool
    69  
    70  	// Open.
    71  	mockConnPool.EXPECT().Open()
    72  	queue.Open()
    73  	assert.Equal(t, statusOpen, queue.status)
    74  
    75  	// Prepare callback for fetches.
    76  	var (
    77  		results []hostQueueResult
    78  		wg      sync.WaitGroup
    79  	)
    80  	callback := func(r interface{}, err error) {
    81  		results = append(results, hostQueueResult{r, err})
    82  		wg.Done()
    83  	}
    84  
    85  	fetchBatches := []*fetchBatchOp{}
    86  	for i, id := range ids {
    87  		fetchBatch := &fetchBatchOp{
    88  			request: rpc.FetchBatchRawRequest{
    89  				NameSpace: []byte(fmt.Sprintf("ns-%d", i)),
    90  			},
    91  			requestV2Elements: []rpc.FetchBatchRawV2RequestElement{
    92  				{
    93  					ID:         []byte(id),
    94  					RangeStart: int64(i),
    95  					RangeEnd:   int64(i + 1),
    96  				},
    97  			},
    98  		}
    99  		fetchBatches = append(fetchBatches, fetchBatch)
   100  		fetchBatch.completionFns = append(fetchBatch.completionFns, callback)
   101  	}
   102  	wg.Add(len(ids))
   103  
   104  	// Prepare mocks for flush
   105  	mockClient := rpc.NewMockTChanNode(ctrl)
   106  
   107  	verifyFetchBatchRawV2 := func(ctx thrift.Context, req *rpc.FetchBatchRawV2Request) {
   108  		assert.Equal(t, len(ids), len(req.NameSpaces))
   109  		for i, ns := range req.NameSpaces {
   110  			assert.Equal(t, []byte(fmt.Sprintf("ns-%d", i)), ns)
   111  		}
   112  		assert.Equal(t, len(ids), len(req.Elements))
   113  		for i, elem := range req.Elements {
   114  			assert.Equal(t, int64(i), elem.NameSpace)
   115  			assert.Equal(t, int64(i), elem.RangeStart)
   116  			assert.Equal(t, int64(i+1), elem.RangeEnd)
   117  			assert.Equal(t, []byte(ids[i]), elem.ID)
   118  		}
   119  	}
   120  
   121  	mockClient.EXPECT().
   122  		FetchBatchRawV2(gomock.Any(), gomock.Any()).
   123  		Do(verifyFetchBatchRawV2).
   124  		Return(result, nil)
   125  
   126  	mockConnPool.EXPECT().NextClient().Return(mockClient, &noopPooledChannel{}, nil)
   127  
   128  	for _, fetchBatch := range fetchBatches {
   129  		assert.NoError(t, queue.Enqueue(fetchBatch))
   130  	}
   131  
   132  	// Wait for fetch to complete.
   133  	wg.Wait()
   134  
   135  	assert.Equal(t, len(ids), len(results))
   136  
   137  	// Close.
   138  	var closeWg sync.WaitGroup
   139  	closeWg.Add(1)
   140  	mockConnPool.EXPECT().Close().Do(func() {
   141  		closeWg.Done()
   142  	})
   143  	queue.Close()
   144  	closeWg.Wait()
   145  }
   146  
   147  func TestHostQueueFetchBatchesErrorOnNextClientUnavailable(t *testing.T) {
   148  	namespace := "testNs"
   149  	ids := []string{"foo", "bar", "baz", "qux"}
   150  	expectedErr := fmt.Errorf("an error")
   151  	var expected []hostQueueResult
   152  	for range ids {
   153  		expected = append(expected, hostQueueResult{nil, expectedErr})
   154  	}
   155  	opts := &testHostQueueFetchBatchesOptions{
   156  		nextClientErr: expectedErr,
   157  	}
   158  	testHostQueueFetchBatches(t, namespace, ids, nil, expected, opts, func(results []hostQueueResult) {
   159  		assert.Equal(t, expected, results)
   160  	})
   161  }
   162  
   163  func TestHostQueueFetchBatchesErrorOnFetchRawBatchError(t *testing.T) {
   164  	namespace := "testNs"
   165  	ids := []string{"foo", "bar", "baz", "qux"}
   166  	expectedErr := fmt.Errorf("an error")
   167  	var expected []hostQueueResult
   168  	for range ids {
   169  		expected = append(expected, hostQueueResult{nil, expectedErr})
   170  	}
   171  	opts := &testHostQueueFetchBatchesOptions{
   172  		fetchRawBatchErr: expectedErr,
   173  	}
   174  	testHostQueueFetchBatches(t, namespace, ids, nil, expected, opts, func(results []hostQueueResult) {
   175  		assert.Equal(t, expected, results)
   176  	})
   177  }
   178  
   179  func TestHostQueueFetchBatchesErrorOnFetchNoResponse(t *testing.T) {
   180  	namespace := "testNs"
   181  	ids := []string{"foo", "bar", "baz", "qux"}
   182  	result := &rpc.FetchBatchRawResult_{}
   183  	for range ids[:len(ids)-1] {
   184  		result.Elements = append(result.Elements, &rpc.FetchRawResult_{Segments: []*rpc.Segments{}})
   185  	}
   186  	var expected []hostQueueResult
   187  	for i := range ids[:len(ids)-1] {
   188  		expected = append(expected, hostQueueResult{result.Elements[i].Segments, nil})
   189  	}
   190  
   191  	testHostQueueFetchBatches(t, namespace, ids, result, expected, nil, func(results []hostQueueResult) {
   192  		assert.Equal(t, expected, results[:len(results)-1])
   193  		lastResult := results[len(results)-1]
   194  		assert.Nil(t, lastResult.result)
   195  		assert.IsType(t, errQueueFetchNoResponse(""), lastResult.err)
   196  	})
   197  }
   198  
   199  func TestHostQueueFetchBatchesErrorOnResultError(t *testing.T) {
   200  	namespace := "testNs"
   201  	ids := []string{"foo", "bar", "baz", "qux"}
   202  	anError := &rpc.Error{Type: rpc.ErrorType_INTERNAL_ERROR, Message: "an error"}
   203  	result := &rpc.FetchBatchRawResult_{}
   204  	for range ids[:len(ids)-1] {
   205  		result.Elements = append(result.Elements, &rpc.FetchRawResult_{Segments: []*rpc.Segments{}})
   206  	}
   207  	result.Elements = append(result.Elements, &rpc.FetchRawResult_{Err: anError})
   208  	var expected []hostQueueResult
   209  	for i := range ids[:len(ids)-1] {
   210  		expected = append(expected, hostQueueResult{result.Elements[i].Segments, nil})
   211  	}
   212  	testHostQueueFetchBatches(t, namespace, ids, result, expected, nil, func(results []hostQueueResult) {
   213  		assert.Equal(t, expected, results[:len(results)-1])
   214  		rpcErr, ok := results[len(results)-1].err.(*rpc.Error)
   215  		assert.True(t, ok)
   216  		assert.Equal(t, anError.Type, rpcErr.Type)
   217  		assert.Equal(t, anError.Message, rpcErr.Message)
   218  	})
   219  }
   220  
   221  type testHostQueueFetchBatchesOptions struct {
   222  	nextClientErr    error
   223  	fetchRawBatchErr error
   224  }
   225  
   226  func testHostQueueFetchBatches(
   227  	t *testing.T,
   228  	namespace string,
   229  	ids []string,
   230  	result *rpc.FetchBatchRawResult_,
   231  	expected []hostQueueResult,
   232  	testOpts *testHostQueueFetchBatchesOptions,
   233  	assertion func(results []hostQueueResult),
   234  ) {
   235  	for _, opts := range []Options{
   236  		newHostQueueTestOptions().SetUseV2BatchAPIs(false),
   237  		newHostQueueTestOptions().SetUseV2BatchAPIs(true),
   238  	} {
   239  		t.Run(fmt.Sprintf("useV2: %v", opts.UseV2BatchAPIs()), func(t *testing.T) {
   240  			ctrl := gomock.NewController(t)
   241  			defer ctrl.Finish()
   242  
   243  			mockConnPool := NewMockconnectionPool(ctrl)
   244  
   245  			queue := newTestHostQueue(opts)
   246  			queue.connPool = mockConnPool
   247  
   248  			// Open
   249  			mockConnPool.EXPECT().Open()
   250  			queue.Open()
   251  			assert.Equal(t, statusOpen, queue.status)
   252  
   253  			// Prepare callback for fetches
   254  			var (
   255  				results []hostQueueResult
   256  				wg      sync.WaitGroup
   257  			)
   258  			callback := func(r interface{}, err error) {
   259  				results = append(results, hostQueueResult{r, err})
   260  				wg.Done()
   261  			}
   262  
   263  			rawIDs := make([][]byte, len(ids))
   264  
   265  			for i, id := range ids {
   266  				rawIDs[i] = []byte(id)
   267  			}
   268  
   269  			var fetchBatch *fetchBatchOp
   270  			if opts.UseV2BatchAPIs() {
   271  				fetchBatch = &fetchBatchOp{
   272  					request: rpc.FetchBatchRawRequest{
   273  						NameSpace: []byte(namespace),
   274  					},
   275  				}
   276  			} else {
   277  				fetchBatch = &fetchBatchOp{
   278  					request: rpc.FetchBatchRawRequest{
   279  						RangeStart: 0,
   280  						RangeEnd:   1,
   281  						NameSpace:  []byte(namespace),
   282  						Ids:        rawIDs,
   283  					},
   284  				}
   285  			}
   286  
   287  			for _, id := range ids {
   288  				if opts.UseV2BatchAPIs() {
   289  					fetchBatch.requestV2Elements = append(fetchBatch.requestV2Elements, rpc.FetchBatchRawV2RequestElement{
   290  						ID:         []byte(id),
   291  						RangeStart: 0,
   292  						RangeEnd:   1,
   293  					})
   294  				}
   295  				fetchBatch.completionFns = append(fetchBatch.completionFns, callback)
   296  			}
   297  			wg.Add(len(ids))
   298  
   299  			// Prepare mocks for flush
   300  			mockClient := rpc.NewMockTChanNode(ctrl)
   301  
   302  			verifyFetchBatchRawV2 := func(ctx thrift.Context, req *rpc.FetchBatchRawV2Request) {
   303  				assert.Equal(t, 1, len(req.NameSpaces))
   304  				assert.Equal(t, len(ids), len(req.Elements))
   305  				for i, elem := range req.Elements {
   306  					assert.Equal(t, int64(0), elem.NameSpace)
   307  					assert.Equal(t, int64(0), elem.RangeStart)
   308  					assert.Equal(t, int64(1), elem.RangeEnd)
   309  					assert.Equal(t, []byte(ids[i]), elem.ID)
   310  				}
   311  			}
   312  			if testOpts != nil && testOpts.nextClientErr != nil {
   313  				mockConnPool.EXPECT().NextClient().Return(nil, nil, testOpts.nextClientErr)
   314  			} else if testOpts != nil && testOpts.fetchRawBatchErr != nil {
   315  				if opts.UseV2BatchAPIs() {
   316  					mockClient.EXPECT().
   317  						FetchBatchRawV2(gomock.Any(), gomock.Any()).
   318  						Do(verifyFetchBatchRawV2).
   319  						Return(nil, testOpts.fetchRawBatchErr)
   320  				} else {
   321  					fetchBatchRaw := func(ctx thrift.Context, req *rpc.FetchBatchRawRequest) {
   322  						assert.Equal(t, &fetchBatch.request, req)
   323  					}
   324  					mockClient.EXPECT().
   325  						FetchBatchRaw(gomock.Any(), gomock.Any()).
   326  						Do(fetchBatchRaw).
   327  						Return(nil, testOpts.fetchRawBatchErr)
   328  				}
   329  				mockConnPool.EXPECT().NextClient().Return(mockClient, &noopPooledChannel{}, nil)
   330  			} else {
   331  				if opts.UseV2BatchAPIs() {
   332  					mockClient.EXPECT().
   333  						FetchBatchRawV2(gomock.Any(), gomock.Any()).
   334  						Do(verifyFetchBatchRawV2).
   335  						Return(result, nil)
   336  				} else {
   337  					fetchBatchRaw := func(ctx thrift.Context, req *rpc.FetchBatchRawRequest) {
   338  						assert.Equal(t, &fetchBatch.request, req)
   339  					}
   340  					mockClient.EXPECT().
   341  						FetchBatchRaw(gomock.Any(), gomock.Any()).
   342  						Do(fetchBatchRaw).
   343  						Return(result, nil)
   344  				}
   345  
   346  				mockConnPool.EXPECT().NextClient().Return(mockClient, &noopPooledChannel{}, nil)
   347  			}
   348  
   349  			// Fetch
   350  			assert.NoError(t, queue.Enqueue(fetchBatch))
   351  
   352  			// Wait for fetch to complete
   353  			wg.Wait()
   354  
   355  			// Assert results match expected
   356  			assertion(results)
   357  
   358  			// Close
   359  			var closeWg sync.WaitGroup
   360  			closeWg.Add(1)
   361  			mockConnPool.EXPECT().Close().Do(func() {
   362  				closeWg.Done()
   363  			})
   364  			queue.Close()
   365  			closeWg.Wait()
   366  		})
   367  	}
   368  }