go.uber.org/yarpc@v1.72.1/internal/outboundmiddleware/chain_test.go (about)

     1  // Copyright (c) 2022 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package outboundmiddleware
    22  
    23  import (
    24  	"bytes"
    25  	"context"
    26  	"errors"
    27  	"io/ioutil"
    28  	"testing"
    29  
    30  	"github.com/golang/mock/gomock"
    31  	"github.com/stretchr/testify/assert"
    32  	"github.com/stretchr/testify/require"
    33  	"go.uber.org/yarpc/api/middleware"
    34  	"go.uber.org/yarpc/api/middleware/middlewaretest"
    35  	"go.uber.org/yarpc/api/transport"
    36  	"go.uber.org/yarpc/api/transport/transporttest"
    37  	"go.uber.org/yarpc/api/x/introspection"
    38  	"go.uber.org/yarpc/internal/testtime"
    39  )
    40  
    41  type countOutboundMiddleware struct{ Count int }
    42  
    43  func (c *countOutboundMiddleware) Call(
    44  	ctx context.Context, req *transport.Request, o transport.UnaryOutbound) (*transport.Response, error) {
    45  	c.Count++
    46  	return o.Call(ctx, req)
    47  }
    48  
    49  func (c *countOutboundMiddleware) CallOneway(ctx context.Context, req *transport.Request, o transport.OnewayOutbound) (transport.Ack, error) {
    50  	c.Count++
    51  	return o.CallOneway(ctx, req)
    52  }
    53  
    54  func (c *countOutboundMiddleware) CallStream(ctx context.Context, req *transport.StreamRequest, o transport.StreamOutbound) (*transport.ClientStream, error) {
    55  	c.Count++
    56  	return o.CallStream(ctx, req)
    57  }
    58  
    59  var retryUnaryOutbound middleware.UnaryOutboundFunc = func(
    60  	ctx context.Context, req *transport.Request, o transport.UnaryOutbound) (*transport.Response, error) {
    61  	res, err := o.Call(ctx, req)
    62  	if err != nil {
    63  		res, err = o.Call(ctx, req)
    64  	}
    65  	return res, err
    66  }
    67  
    68  func TestUnaryChain(t *testing.T) {
    69  	before := &countOutboundMiddleware{}
    70  	after := &countOutboundMiddleware{}
    71  
    72  	tests := []struct {
    73  		desc string
    74  		mw   middleware.UnaryOutbound
    75  	}{
    76  		{"flat chain", UnaryChain(before, retryUnaryOutbound, nil, after)},
    77  		{"nested chain", UnaryChain(before, UnaryChain(retryUnaryOutbound, after, nil))},
    78  	}
    79  
    80  	for _, tt := range tests {
    81  		t.Run(tt.desc, func(t *testing.T) {
    82  			before.Count, after.Count = 0, 0
    83  			mockCtrl := gomock.NewController(t)
    84  			defer mockCtrl.Finish()
    85  			ctx, cancel := context.WithTimeout(context.Background(), testtime.Second)
    86  			defer cancel()
    87  
    88  			req := &transport.Request{
    89  				Caller:    "somecaller",
    90  				Service:   "someservice",
    91  				Encoding:  transport.Encoding("raw"),
    92  				Procedure: "hello",
    93  				Body:      bytes.NewReader([]byte{1, 2, 3}),
    94  			}
    95  			res := &transport.Response{
    96  				Body: ioutil.NopCloser(bytes.NewReader([]byte{4, 5, 6})),
    97  			}
    98  			o := transporttest.NewMockUnaryOutbound(mockCtrl)
    99  			o.EXPECT().Call(ctx, req).After(
   100  				o.EXPECT().Call(ctx, req).Return(nil, errors.New("great sadness")),
   101  			).Return(res, nil)
   102  
   103  			gotRes, err := middleware.ApplyUnaryOutbound(o, tt.mw).Call(ctx, req)
   104  
   105  			assert.NoError(t, err, "expected success")
   106  			assert.Equal(t, 1, before.Count, "expected outer middleware to be called once")
   107  			assert.Equal(t, 2, after.Count, "expected inner middleware to be called twice")
   108  			assert.Equal(t, res, gotRes, "expected response to match")
   109  		})
   110  	}
   111  }
   112  
   113  var retryOnewayOutbound middleware.OnewayOutboundFunc = func(
   114  	ctx context.Context, req *transport.Request, o transport.OnewayOutbound) (transport.Ack, error) {
   115  	res, err := o.CallOneway(ctx, req)
   116  	if err != nil {
   117  		res, err = o.CallOneway(ctx, req)
   118  	}
   119  	return res, err
   120  }
   121  
   122  func TestOnewayChain(t *testing.T) {
   123  	before := &countOutboundMiddleware{}
   124  	after := &countOutboundMiddleware{}
   125  
   126  	tests := []struct {
   127  		desc string
   128  		mw   middleware.OnewayOutbound
   129  	}{
   130  		{"flat chain", OnewayChain(before, retryOnewayOutbound, nil, after)},
   131  		{"flat chain", OnewayChain(before, OnewayChain(retryOnewayOutbound, after, nil))},
   132  	}
   133  
   134  	for _, tt := range tests {
   135  		t.Run(tt.desc, func(t *testing.T) {
   136  			mockCtrl := gomock.NewController(t)
   137  			defer mockCtrl.Finish()
   138  			ctx, cancel := context.WithTimeout(context.Background(), testtime.Second)
   139  			defer cancel()
   140  
   141  			var res transport.Ack
   142  			req := &transport.Request{
   143  				Caller:    "somecaller",
   144  				Service:   "someservice",
   145  				Encoding:  transport.Encoding("raw"),
   146  				Procedure: "hello",
   147  				Body:      bytes.NewReader([]byte{1, 2, 3}),
   148  			}
   149  			o := transporttest.NewMockOnewayOutbound(mockCtrl)
   150  			before.Count, after.Count = 0, 0
   151  			o.EXPECT().CallOneway(ctx, req).After(
   152  				o.EXPECT().CallOneway(ctx, req).Return(nil, errors.New("great sadness")),
   153  			).Return(res, nil)
   154  
   155  			gotRes, err := middleware.ApplyOnewayOutbound(o, tt.mw).CallOneway(ctx, req)
   156  
   157  			assert.NoError(t, err, "expected success")
   158  			assert.Equal(t, 1, before.Count, "expected outer middleware to be called once")
   159  			assert.Equal(t, 2, after.Count, "expected inner middleware to be called twice")
   160  			assert.Equal(t, res, gotRes, "expected response to match")
   161  		})
   162  	}
   163  }
   164  
   165  func TestEmptyChain(t *testing.T) {
   166  	errMsg := "expected nop Outbound"
   167  
   168  	t.Run("unary", func(t *testing.T) {
   169  		require.Equal(t, middleware.NopUnaryOutbound, UnaryChain(), errMsg)
   170  	})
   171  
   172  	t.Run("oneway", func(t *testing.T) {
   173  		require.Equal(t, middleware.NopOnewayOutbound, OnewayChain(), errMsg)
   174  	})
   175  }
   176  
   177  func TestSingleOutboundChain(t *testing.T) {
   178  	ctrl := gomock.NewController(t)
   179  
   180  	t.Run("unary", func(t *testing.T) {
   181  		out := middlewaretest.NewMockUnaryOutbound(ctrl)
   182  		require.Equal(t, out, UnaryChain(out))
   183  	})
   184  
   185  	t.Run("oneway", func(t *testing.T) {
   186  		out := middlewaretest.NewMockOnewayOutbound(ctrl)
   187  		require.Equal(t, out, OnewayChain(out))
   188  	})
   189  }
   190  
   191  func TestUnaryChainExec(t *testing.T) {
   192  	ctrl := gomock.NewController(t)
   193  	out := transporttest.NewMockUnaryOutbound(ctrl)
   194  
   195  	chain := &unaryChainExec{Final: out}
   196  
   197  	// start
   198  	out.EXPECT().Start().Return(nil)
   199  	assert.NoError(t, chain.Start(), "could not start outbound")
   200  
   201  	// transports
   202  	out.EXPECT().Transports()
   203  	chain.Transports()
   204  
   205  	// is running
   206  	out.EXPECT().IsRunning().Return(true)
   207  	assert.True(t, chain.IsRunning(), "expected outbound to be running")
   208  
   209  	// stop
   210  	out.EXPECT().Stop().Return(nil)
   211  	assert.NoError(t, chain.Stop(), "unexpected error stopping outbound")
   212  }
   213  
   214  func TestOnewayChainExec(t *testing.T) {
   215  	ctrl := gomock.NewController(t)
   216  	out := transporttest.NewMockOnewayOutbound(ctrl)
   217  
   218  	chain := &onewayChainExec{Final: out}
   219  
   220  	// start
   221  	out.EXPECT().Start().Return(nil)
   222  	assert.NoError(t, chain.Start(), "could not start outbound")
   223  
   224  	// transports
   225  	out.EXPECT().Transports()
   226  	chain.Transports()
   227  
   228  	// is running
   229  	out.EXPECT().IsRunning().Return(true)
   230  	assert.True(t, chain.IsRunning(), "expected outbound to be running")
   231  
   232  	// stop
   233  	out.EXPECT().Stop().Return(nil)
   234  	assert.NoError(t, chain.Stop(), "unexpected error stopping outbound")
   235  }
   236  
   237  func TestIntrospect(t *testing.T) {
   238  	ctrl := gomock.NewController(t)
   239  	expectStatus := introspection.OutboundStatusNotSupported
   240  	errMsg := "expected not supported status"
   241  
   242  	t.Run("unary", func(t *testing.T) {
   243  		out := transporttest.NewMockUnaryOutbound(ctrl)
   244  		chain := &unaryChainExec{Final: out}
   245  		assert.Equal(t, expectStatus, chain.Introspect(), errMsg)
   246  	})
   247  
   248  	t.Run("oneway", func(t *testing.T) {
   249  		out := transporttest.NewMockOnewayOutbound(ctrl)
   250  		chain := &onewayChainExec{Final: out}
   251  		assert.Equal(t, expectStatus, chain.Introspect(), errMsg)
   252  	})
   253  }
   254  
   255  var retryStreamOutbound middleware.StreamOutboundFunc = func(
   256  	ctx context.Context, req *transport.StreamRequest, o transport.StreamOutbound) (*transport.ClientStream, error) {
   257  	res, err := o.CallStream(ctx, req)
   258  	if err != nil {
   259  		res, err = o.CallStream(ctx, req)
   260  	}
   261  	return res, err
   262  }
   263  
   264  func TestStreamChain(t *testing.T) {
   265  	before := &countOutboundMiddleware{}
   266  	after := &countOutboundMiddleware{}
   267  
   268  	tests := []struct {
   269  		desc string
   270  		mw   middleware.StreamOutbound
   271  	}{
   272  		{"flat chain", StreamChain(before, retryStreamOutbound, nil, after)},
   273  		{"nested chain", StreamChain(before, StreamChain(retryStreamOutbound, after, nil))},
   274  		{"single chain", StreamChain(StreamChain(before), retryStreamOutbound, StreamChain(after), StreamChain())},
   275  	}
   276  
   277  	for _, tt := range tests {
   278  		t.Run(tt.desc, func(t *testing.T) {
   279  			mockCtrl := gomock.NewController(t)
   280  			defer mockCtrl.Finish()
   281  			ctx, cancel := context.WithTimeout(context.Background(), testtime.Second)
   282  			defer cancel()
   283  
   284  			var res *transport.ClientStream
   285  			req := &transport.StreamRequest{
   286  				Meta: &transport.RequestMeta{
   287  					Caller:    "somecaller",
   288  					Service:   "someservice",
   289  					Encoding:  transport.Encoding("raw"),
   290  					Procedure: "hello",
   291  				},
   292  			}
   293  			o := transporttest.NewMockStreamOutbound(mockCtrl)
   294  
   295  			before.Count, after.Count = 0, 0
   296  			o.EXPECT().CallStream(ctx, req).After(
   297  				o.EXPECT().CallStream(ctx, req).Return(nil, errors.New("great sadness")),
   298  			).Return(res, nil)
   299  
   300  			mw := middleware.ApplyStreamOutbound(o, tt.mw)
   301  			gotRes, err := mw.CallStream(ctx, req)
   302  
   303  			assert.NoError(t, err, "expected success")
   304  			assert.Equal(t, 1, before.Count, "expected outer middleware to be called once")
   305  			assert.Equal(t, 2, after.Count, "expected inner middleware to be called twice")
   306  			assert.Equal(t, res, gotRes, "expected response to match")
   307  		})
   308  	}
   309  }
   310  
   311  func TestStreamChainExecFuncs(t *testing.T) {
   312  	mockCtrl := gomock.NewController(t)
   313  	defer mockCtrl.Finish()
   314  
   315  	o := transporttest.NewMockStreamOutbound(mockCtrl)
   316  	o.EXPECT().Stop()
   317  	o.EXPECT().Start()
   318  	o.EXPECT().Transports()
   319  	o.EXPECT().IsRunning().Return(true)
   320  
   321  	mw := streamChainExec{Final: o}
   322  
   323  	assert.Nil(t, mw.Start())
   324  	assert.True(t, mw.IsRunning())
   325  	assert.Nil(t, mw.Stop())
   326  	assert.Len(t, mw.Transports(), 0)
   327  }