go.uber.org/yarpc@v1.72.1/middleware_test.go (about)

     1  // Copyright (c) 2022 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package yarpc
    22  
    23  import (
    24  	"bytes"
    25  	"context"
    26  	"errors"
    27  	"io/ioutil"
    28  	"testing"
    29  
    30  	"github.com/golang/mock/gomock"
    31  	"github.com/stretchr/testify/assert"
    32  	"github.com/stretchr/testify/require"
    33  	"go.uber.org/yarpc/api/middleware"
    34  	"go.uber.org/yarpc/api/transport"
    35  	"go.uber.org/yarpc/api/transport/transporttest"
    36  	"go.uber.org/yarpc/internal/testtime"
    37  )
    38  
    39  var (
    40  	retryUnaryInbound middleware.UnaryInboundFunc = func(
    41  		ctx context.Context, req *transport.Request, resw transport.ResponseWriter, h transport.UnaryHandler) error {
    42  		if err := h.Handle(ctx, req, resw); err != nil {
    43  			return h.Handle(ctx, req, resw)
    44  		}
    45  		return nil
    46  	}
    47  
    48  	retryOnewayInbound middleware.OnewayInboundFunc = func(
    49  		ctx context.Context, req *transport.Request, h transport.OnewayHandler) error {
    50  		if err := h.HandleOneway(ctx, req); err != nil {
    51  			return h.HandleOneway(ctx, req)
    52  		}
    53  		return nil
    54  	}
    55  
    56  	retryUnaryOutbound middleware.UnaryOutboundFunc = func(
    57  		ctx context.Context, req *transport.Request, o transport.UnaryOutbound) (*transport.Response, error) {
    58  		res, err := o.Call(ctx, req)
    59  		if err != nil {
    60  			res, err = o.Call(ctx, req)
    61  		}
    62  		return res, err
    63  	}
    64  
    65  	retryOnewayOutbound middleware.OnewayOutboundFunc = func(
    66  		ctx context.Context, req *transport.Request, o transport.OnewayOutbound) (transport.Ack, error) {
    67  		res, err := o.CallOneway(ctx, req)
    68  		if err != nil {
    69  			res, err = o.CallOneway(ctx, req)
    70  		}
    71  		return res, err
    72  	}
    73  )
    74  
    75  type countInboundMiddleware struct{ Count int }
    76  
    77  func (c *countInboundMiddleware) Handle(
    78  	ctx context.Context, req *transport.Request, resw transport.ResponseWriter, h transport.UnaryHandler) error {
    79  	c.Count++
    80  	return h.Handle(ctx, req, resw)
    81  }
    82  
    83  func (c *countInboundMiddleware) HandleOneway(ctx context.Context, req *transport.Request, h transport.OnewayHandler) error {
    84  	c.Count++
    85  	return h.HandleOneway(ctx, req)
    86  }
    87  
    88  type countOutboundMiddleware struct{ Count int }
    89  
    90  func (c *countOutboundMiddleware) Call(
    91  	ctx context.Context, req *transport.Request, o transport.UnaryOutbound) (*transport.Response, error) {
    92  	c.Count++
    93  	return o.Call(ctx, req)
    94  }
    95  
    96  func (c *countOutboundMiddleware) CallOneway(ctx context.Context, req *transport.Request, o transport.OnewayOutbound) (transport.Ack, error) {
    97  	c.Count++
    98  	return o.CallOneway(ctx, req)
    99  }
   100  
   101  func TestUnaryInboundMiddleware(t *testing.T) {
   102  	before := &countInboundMiddleware{}
   103  	after := &countInboundMiddleware{}
   104  
   105  	tests := []struct {
   106  		desc string
   107  		mw   middleware.UnaryInbound
   108  	}{
   109  		{"flat chain", UnaryInboundMiddleware(before, retryUnaryInbound, after, nil)},
   110  		{"nested chain", UnaryInboundMiddleware(before, UnaryInboundMiddleware(retryUnaryInbound, nil, after))},
   111  	}
   112  
   113  	for _, tt := range tests {
   114  		t.Run(tt.desc, func(t *testing.T) {
   115  			before.Count, after.Count = 0, 0
   116  			mockCtrl := gomock.NewController(t)
   117  			defer mockCtrl.Finish()
   118  			ctx, cancel := context.WithTimeout(context.Background(), testtime.Second)
   119  			defer cancel()
   120  
   121  			req := &transport.Request{
   122  				Caller:    "somecaller",
   123  				Service:   "someservice",
   124  				Encoding:  transport.Encoding("raw"),
   125  				Procedure: "hello",
   126  				Body:      bytes.NewReader([]byte{1, 2, 3}),
   127  			}
   128  			resw := new(transporttest.FakeResponseWriter)
   129  			h := transporttest.NewMockUnaryHandler(mockCtrl)
   130  			h.EXPECT().Handle(ctx, req, resw).After(
   131  				h.EXPECT().Handle(ctx, req, resw).Return(errors.New("great sadness")),
   132  			).Return(nil)
   133  
   134  			err := middleware.ApplyUnaryInbound(h, tt.mw).Handle(ctx, req, resw)
   135  
   136  			assert.NoError(t, err, "expected success")
   137  			assert.Equal(t, 1, before.Count, "expected outer inbound middleware to be called once")
   138  			assert.Equal(t, 2, after.Count, "expected inner inbound middleware to be called twice")
   139  		})
   140  	}
   141  }
   142  
   143  func TestOnewayInboundMiddleware(t *testing.T) {
   144  	before := &countInboundMiddleware{}
   145  	after := &countInboundMiddleware{}
   146  
   147  	tests := []struct {
   148  		desc string
   149  		mw   middleware.OnewayInbound
   150  	}{
   151  		{"flat chain", OnewayInboundMiddleware(before, retryOnewayInbound, after, nil)},
   152  		{"nested chain", OnewayInboundMiddleware(before, OnewayInboundMiddleware(retryOnewayInbound, nil, after))},
   153  	}
   154  
   155  	for _, tt := range tests {
   156  		t.Run(tt.desc, func(t *testing.T) {
   157  			before.Count, after.Count = 0, 0
   158  			mockCtrl := gomock.NewController(t)
   159  			defer mockCtrl.Finish()
   160  			ctx, cancel := context.WithTimeout(context.Background(), testtime.Second)
   161  			defer cancel()
   162  
   163  			req := &transport.Request{
   164  				Caller:    "somecaller",
   165  				Service:   "someservice",
   166  				Encoding:  transport.Encoding("raw"),
   167  				Procedure: "hello",
   168  				Body:      bytes.NewReader([]byte{1, 2, 3}),
   169  			}
   170  			h := transporttest.NewMockOnewayHandler(mockCtrl)
   171  			h.EXPECT().HandleOneway(ctx, req).After(
   172  				h.EXPECT().HandleOneway(ctx, req).Return(errors.New("great sadness")),
   173  			).Return(nil)
   174  
   175  			err := middleware.ApplyOnewayInbound(h, tt.mw).HandleOneway(ctx, req)
   176  
   177  			assert.NoError(t, err, "expected success")
   178  			assert.Equal(t, 1, before.Count, "expected outer inbound middleware to be called once")
   179  			assert.Equal(t, 2, after.Count, "expected inner inbound middleware to be called twice")
   180  		})
   181  	}
   182  }
   183  
   184  func TestUnaryOutboundMiddleware(t *testing.T) {
   185  	before := &countOutboundMiddleware{}
   186  	after := &countOutboundMiddleware{}
   187  
   188  	tests := []struct {
   189  		desc string
   190  		mw   middleware.UnaryOutbound
   191  	}{
   192  		{"flat chain", UnaryOutboundMiddleware(before, retryUnaryOutbound, nil, after)},
   193  		{"nested chain", UnaryOutboundMiddleware(before, UnaryOutboundMiddleware(retryUnaryOutbound, after, nil))},
   194  	}
   195  
   196  	for _, tt := range tests {
   197  		t.Run(tt.desc, func(t *testing.T) {
   198  			before.Count, after.Count = 0, 0
   199  			mockCtrl := gomock.NewController(t)
   200  			defer mockCtrl.Finish()
   201  			ctx, cancel := context.WithTimeout(context.Background(), testtime.Second)
   202  			defer cancel()
   203  
   204  			req := &transport.Request{
   205  				Caller:    "somecaller",
   206  				Service:   "someservice",
   207  				Encoding:  transport.Encoding("raw"),
   208  				Procedure: "hello",
   209  				Body:      bytes.NewReader([]byte{1, 2, 3}),
   210  			}
   211  			res := &transport.Response{
   212  				Body: ioutil.NopCloser(bytes.NewReader([]byte{4, 5, 6})),
   213  			}
   214  			o := transporttest.NewMockUnaryOutbound(mockCtrl)
   215  			o.EXPECT().Call(ctx, req).After(
   216  				o.EXPECT().Call(ctx, req).Return(nil, errors.New("great sadness")),
   217  			).Return(res, nil)
   218  
   219  			gotRes, err := middleware.ApplyUnaryOutbound(o, tt.mw).Call(ctx, req)
   220  
   221  			assert.NoError(t, err, "expected success")
   222  			assert.Equal(t, 1, before.Count, "expected outer middleware to be called once")
   223  			assert.Equal(t, 2, after.Count, "expected inner middleware to be called twice")
   224  			assert.Equal(t, res, gotRes, "expected response to match")
   225  		})
   226  	}
   227  }
   228  
   229  func TestOnewayOutboundMiddleware(t *testing.T) {
   230  	before := &countOutboundMiddleware{}
   231  	after := &countOutboundMiddleware{}
   232  
   233  	tests := []struct {
   234  		desc string
   235  		mw   middleware.OnewayOutbound
   236  	}{
   237  		{"flat chain", OnewayOutboundMiddleware(before, retryOnewayOutbound, nil, after)},
   238  		{"flat chain", OnewayOutboundMiddleware(before, OnewayOutboundMiddleware(retryOnewayOutbound, after, nil))},
   239  	}
   240  
   241  	for _, tt := range tests {
   242  		t.Run(tt.desc, func(t *testing.T) {
   243  			mockCtrl := gomock.NewController(t)
   244  			defer mockCtrl.Finish()
   245  			ctx, cancel := context.WithTimeout(context.Background(), testtime.Second)
   246  			defer cancel()
   247  
   248  			var res transport.Ack
   249  			req := &transport.Request{
   250  				Caller:    "somecaller",
   251  				Service:   "someservice",
   252  				Encoding:  transport.Encoding("raw"),
   253  				Procedure: "hello",
   254  				Body:      bytes.NewReader([]byte{1, 2, 3}),
   255  			}
   256  			o := transporttest.NewMockOnewayOutbound(mockCtrl)
   257  			before.Count, after.Count = 0, 0
   258  			o.EXPECT().CallOneway(ctx, req).After(
   259  				o.EXPECT().CallOneway(ctx, req).Return(nil, errors.New("great sadness")),
   260  			).Return(res, nil)
   261  
   262  			gotRes, err := middleware.ApplyOnewayOutbound(o, tt.mw).CallOneway(ctx, req)
   263  
   264  			assert.NoError(t, err, "expected success")
   265  			assert.Equal(t, 1, before.Count, "expected outer middleware to be called once")
   266  			assert.Equal(t, 2, after.Count, "expected inner middleware to be called twice")
   267  			assert.Equal(t, res, gotRes, "expected response to match")
   268  		})
   269  	}
   270  }
   271  
   272  func TestStreamInboundMiddlewareChain(t *testing.T) {
   273  	mockCtrl := gomock.NewController(t)
   274  	defer mockCtrl.Finish()
   275  
   276  	stream, err := transport.NewServerStream(transporttest.NewMockStream(mockCtrl))
   277  	require.NoError(t, err)
   278  	handler := transporttest.NewMockStreamHandler(mockCtrl)
   279  	handler.EXPECT().HandleStream(stream)
   280  
   281  	inboundMW := StreamInboundMiddleware(
   282  		middleware.NopStreamInbound,
   283  		middleware.NopStreamInbound,
   284  		middleware.NopStreamInbound,
   285  		middleware.NopStreamInbound,
   286  	)
   287  
   288  	h := middleware.ApplyStreamInbound(handler, inboundMW)
   289  
   290  	assert.NoError(t, h.HandleStream(stream))
   291  }
   292  
   293  func TestStreamOutboundMiddlewareChain(t *testing.T) {
   294  	mockCtrl := gomock.NewController(t)
   295  	defer mockCtrl.Finish()
   296  
   297  	ctx := context.Background()
   298  	req := &transport.StreamRequest{}
   299  
   300  	stream, err := transport.NewClientStream(transporttest.NewMockStreamCloser(mockCtrl))
   301  	require.NoError(t, err)
   302  
   303  	out := transporttest.NewMockStreamOutbound(mockCtrl)
   304  	out.EXPECT().CallStream(ctx, req).Return(stream, nil)
   305  
   306  	mw := StreamOutboundMiddleware(
   307  		middleware.NopStreamOutbound,
   308  		middleware.NopStreamOutbound,
   309  		middleware.NopStreamOutbound,
   310  		middleware.NopStreamOutbound,
   311  	)
   312  
   313  	o := middleware.ApplyStreamOutbound(out, mw)
   314  
   315  	s, err := o.CallStream(ctx, req)
   316  
   317  	assert.NoError(t, err)
   318  	assert.Equal(t, stream, s)
   319  }