// Copyright (c) 2022 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
20 21 package yarpc 22 23 import ( 24 "bytes" 25 "context" 26 "errors" 27 "io/ioutil" 28 "testing" 29 30 "github.com/golang/mock/gomock" 31 "github.com/stretchr/testify/assert" 32 "github.com/stretchr/testify/require" 33 "go.uber.org/yarpc/api/middleware" 34 "go.uber.org/yarpc/api/transport" 35 "go.uber.org/yarpc/api/transport/transporttest" 36 "go.uber.org/yarpc/internal/testtime" 37 ) 38 39 var ( 40 retryUnaryInbound middleware.UnaryInboundFunc = func( 41 ctx context.Context, req *transport.Request, resw transport.ResponseWriter, h transport.UnaryHandler) error { 42 if err := h.Handle(ctx, req, resw); err != nil { 43 return h.Handle(ctx, req, resw) 44 } 45 return nil 46 } 47 48 retryOnewayInbound middleware.OnewayInboundFunc = func( 49 ctx context.Context, req *transport.Request, h transport.OnewayHandler) error { 50 if err := h.HandleOneway(ctx, req); err != nil { 51 return h.HandleOneway(ctx, req) 52 } 53 return nil 54 } 55 56 retryUnaryOutbound middleware.UnaryOutboundFunc = func( 57 ctx context.Context, req *transport.Request, o transport.UnaryOutbound) (*transport.Response, error) { 58 res, err := o.Call(ctx, req) 59 if err != nil { 60 res, err = o.Call(ctx, req) 61 } 62 return res, err 63 } 64 65 retryOnewayOutbound middleware.OnewayOutboundFunc = func( 66 ctx context.Context, req *transport.Request, o transport.OnewayOutbound) (transport.Ack, error) { 67 res, err := o.CallOneway(ctx, req) 68 if err != nil { 69 res, err = o.CallOneway(ctx, req) 70 } 71 return res, err 72 } 73 ) 74 75 type countInboundMiddleware struct{ Count int } 76 77 func (c *countInboundMiddleware) Handle( 78 ctx context.Context, req *transport.Request, resw transport.ResponseWriter, h transport.UnaryHandler) error { 79 c.Count++ 80 return h.Handle(ctx, req, resw) 81 } 82 83 func (c *countInboundMiddleware) HandleOneway(ctx context.Context, req *transport.Request, h transport.OnewayHandler) error { 84 c.Count++ 85 return h.HandleOneway(ctx, req) 86 } 87 88 type countOutboundMiddleware 
struct{ Count int } 89 90 func (c *countOutboundMiddleware) Call( 91 ctx context.Context, req *transport.Request, o transport.UnaryOutbound) (*transport.Response, error) { 92 c.Count++ 93 return o.Call(ctx, req) 94 } 95 96 func (c *countOutboundMiddleware) CallOneway(ctx context.Context, req *transport.Request, o transport.OnewayOutbound) (transport.Ack, error) { 97 c.Count++ 98 return o.CallOneway(ctx, req) 99 } 100 101 func TestUnaryInboundMiddleware(t *testing.T) { 102 before := &countInboundMiddleware{} 103 after := &countInboundMiddleware{} 104 105 tests := []struct { 106 desc string 107 mw middleware.UnaryInbound 108 }{ 109 {"flat chain", UnaryInboundMiddleware(before, retryUnaryInbound, after, nil)}, 110 {"nested chain", UnaryInboundMiddleware(before, UnaryInboundMiddleware(retryUnaryInbound, nil, after))}, 111 } 112 113 for _, tt := range tests { 114 t.Run(tt.desc, func(t *testing.T) { 115 before.Count, after.Count = 0, 0 116 mockCtrl := gomock.NewController(t) 117 defer mockCtrl.Finish() 118 ctx, cancel := context.WithTimeout(context.Background(), testtime.Second) 119 defer cancel() 120 121 req := &transport.Request{ 122 Caller: "somecaller", 123 Service: "someservice", 124 Encoding: transport.Encoding("raw"), 125 Procedure: "hello", 126 Body: bytes.NewReader([]byte{1, 2, 3}), 127 } 128 resw := new(transporttest.FakeResponseWriter) 129 h := transporttest.NewMockUnaryHandler(mockCtrl) 130 h.EXPECT().Handle(ctx, req, resw).After( 131 h.EXPECT().Handle(ctx, req, resw).Return(errors.New("great sadness")), 132 ).Return(nil) 133 134 err := middleware.ApplyUnaryInbound(h, tt.mw).Handle(ctx, req, resw) 135 136 assert.NoError(t, err, "expected success") 137 assert.Equal(t, 1, before.Count, "expected outer inbound middleware to be called once") 138 assert.Equal(t, 2, after.Count, "expected inner inbound middleware to be called twice") 139 }) 140 } 141 } 142 143 func TestOnewayInboundMiddleware(t *testing.T) { 144 before := &countInboundMiddleware{} 145 after := 
&countInboundMiddleware{} 146 147 tests := []struct { 148 desc string 149 mw middleware.OnewayInbound 150 }{ 151 {"flat chain", OnewayInboundMiddleware(before, retryOnewayInbound, after, nil)}, 152 {"nested chain", OnewayInboundMiddleware(before, OnewayInboundMiddleware(retryOnewayInbound, nil, after))}, 153 } 154 155 for _, tt := range tests { 156 t.Run(tt.desc, func(t *testing.T) { 157 before.Count, after.Count = 0, 0 158 mockCtrl := gomock.NewController(t) 159 defer mockCtrl.Finish() 160 ctx, cancel := context.WithTimeout(context.Background(), testtime.Second) 161 defer cancel() 162 163 req := &transport.Request{ 164 Caller: "somecaller", 165 Service: "someservice", 166 Encoding: transport.Encoding("raw"), 167 Procedure: "hello", 168 Body: bytes.NewReader([]byte{1, 2, 3}), 169 } 170 h := transporttest.NewMockOnewayHandler(mockCtrl) 171 h.EXPECT().HandleOneway(ctx, req).After( 172 h.EXPECT().HandleOneway(ctx, req).Return(errors.New("great sadness")), 173 ).Return(nil) 174 175 err := middleware.ApplyOnewayInbound(h, tt.mw).HandleOneway(ctx, req) 176 177 assert.NoError(t, err, "expected success") 178 assert.Equal(t, 1, before.Count, "expected outer inbound middleware to be called once") 179 assert.Equal(t, 2, after.Count, "expected inner inbound middleware to be called twice") 180 }) 181 } 182 } 183 184 func TestUnaryOutboundMiddleware(t *testing.T) { 185 before := &countOutboundMiddleware{} 186 after := &countOutboundMiddleware{} 187 188 tests := []struct { 189 desc string 190 mw middleware.UnaryOutbound 191 }{ 192 {"flat chain", UnaryOutboundMiddleware(before, retryUnaryOutbound, nil, after)}, 193 {"nested chain", UnaryOutboundMiddleware(before, UnaryOutboundMiddleware(retryUnaryOutbound, after, nil))}, 194 } 195 196 for _, tt := range tests { 197 t.Run(tt.desc, func(t *testing.T) { 198 before.Count, after.Count = 0, 0 199 mockCtrl := gomock.NewController(t) 200 defer mockCtrl.Finish() 201 ctx, cancel := context.WithTimeout(context.Background(), testtime.Second) 
202 defer cancel() 203 204 req := &transport.Request{ 205 Caller: "somecaller", 206 Service: "someservice", 207 Encoding: transport.Encoding("raw"), 208 Procedure: "hello", 209 Body: bytes.NewReader([]byte{1, 2, 3}), 210 } 211 res := &transport.Response{ 212 Body: ioutil.NopCloser(bytes.NewReader([]byte{4, 5, 6})), 213 } 214 o := transporttest.NewMockUnaryOutbound(mockCtrl) 215 o.EXPECT().Call(ctx, req).After( 216 o.EXPECT().Call(ctx, req).Return(nil, errors.New("great sadness")), 217 ).Return(res, nil) 218 219 gotRes, err := middleware.ApplyUnaryOutbound(o, tt.mw).Call(ctx, req) 220 221 assert.NoError(t, err, "expected success") 222 assert.Equal(t, 1, before.Count, "expected outer middleware to be called once") 223 assert.Equal(t, 2, after.Count, "expected inner middleware to be called twice") 224 assert.Equal(t, res, gotRes, "expected response to match") 225 }) 226 } 227 } 228 229 func TestOnewayOutboundMiddleware(t *testing.T) { 230 before := &countOutboundMiddleware{} 231 after := &countOutboundMiddleware{} 232 233 tests := []struct { 234 desc string 235 mw middleware.OnewayOutbound 236 }{ 237 {"flat chain", OnewayOutboundMiddleware(before, retryOnewayOutbound, nil, after)}, 238 {"flat chain", OnewayOutboundMiddleware(before, OnewayOutboundMiddleware(retryOnewayOutbound, after, nil))}, 239 } 240 241 for _, tt := range tests { 242 t.Run(tt.desc, func(t *testing.T) { 243 mockCtrl := gomock.NewController(t) 244 defer mockCtrl.Finish() 245 ctx, cancel := context.WithTimeout(context.Background(), testtime.Second) 246 defer cancel() 247 248 var res transport.Ack 249 req := &transport.Request{ 250 Caller: "somecaller", 251 Service: "someservice", 252 Encoding: transport.Encoding("raw"), 253 Procedure: "hello", 254 Body: bytes.NewReader([]byte{1, 2, 3}), 255 } 256 o := transporttest.NewMockOnewayOutbound(mockCtrl) 257 before.Count, after.Count = 0, 0 258 o.EXPECT().CallOneway(ctx, req).After( 259 o.EXPECT().CallOneway(ctx, req).Return(nil, errors.New("great sadness")), 
260 ).Return(res, nil) 261 262 gotRes, err := middleware.ApplyOnewayOutbound(o, tt.mw).CallOneway(ctx, req) 263 264 assert.NoError(t, err, "expected success") 265 assert.Equal(t, 1, before.Count, "expected outer middleware to be called once") 266 assert.Equal(t, 2, after.Count, "expected inner middleware to be called twice") 267 assert.Equal(t, res, gotRes, "expected response to match") 268 }) 269 } 270 } 271 272 func TestStreamInboundMiddlewareChain(t *testing.T) { 273 mockCtrl := gomock.NewController(t) 274 defer mockCtrl.Finish() 275 276 stream, err := transport.NewServerStream(transporttest.NewMockStream(mockCtrl)) 277 require.NoError(t, err) 278 handler := transporttest.NewMockStreamHandler(mockCtrl) 279 handler.EXPECT().HandleStream(stream) 280 281 inboundMW := StreamInboundMiddleware( 282 middleware.NopStreamInbound, 283 middleware.NopStreamInbound, 284 middleware.NopStreamInbound, 285 middleware.NopStreamInbound, 286 ) 287 288 h := middleware.ApplyStreamInbound(handler, inboundMW) 289 290 assert.NoError(t, h.HandleStream(stream)) 291 } 292 293 func TestStreamOutboundMiddlewareChain(t *testing.T) { 294 mockCtrl := gomock.NewController(t) 295 defer mockCtrl.Finish() 296 297 ctx := context.Background() 298 req := &transport.StreamRequest{} 299 300 stream, err := transport.NewClientStream(transporttest.NewMockStreamCloser(mockCtrl)) 301 require.NoError(t, err) 302 303 out := transporttest.NewMockStreamOutbound(mockCtrl) 304 out.EXPECT().CallStream(ctx, req).Return(stream, nil) 305 306 mw := StreamOutboundMiddleware( 307 middleware.NopStreamOutbound, 308 middleware.NopStreamOutbound, 309 middleware.NopStreamOutbound, 310 middleware.NopStreamOutbound, 311 ) 312 313 o := middleware.ApplyStreamOutbound(out, mw) 314 315 s, err := o.CallStream(ctx, req) 316 317 assert.NoError(t, err) 318 assert.Equal(t, stream, s) 319 }