// Copyright (c) 2022 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

20 21 package outboundmiddleware 22 23 import ( 24 "bytes" 25 "context" 26 "errors" 27 "io/ioutil" 28 "testing" 29 30 "github.com/golang/mock/gomock" 31 "github.com/stretchr/testify/assert" 32 "github.com/stretchr/testify/require" 33 "go.uber.org/yarpc/api/middleware" 34 "go.uber.org/yarpc/api/middleware/middlewaretest" 35 "go.uber.org/yarpc/api/transport" 36 "go.uber.org/yarpc/api/transport/transporttest" 37 "go.uber.org/yarpc/api/x/introspection" 38 "go.uber.org/yarpc/internal/testtime" 39 ) 40 41 type countOutboundMiddleware struct{ Count int } 42 43 func (c *countOutboundMiddleware) Call( 44 ctx context.Context, req *transport.Request, o transport.UnaryOutbound) (*transport.Response, error) { 45 c.Count++ 46 return o.Call(ctx, req) 47 } 48 49 func (c *countOutboundMiddleware) CallOneway(ctx context.Context, req *transport.Request, o transport.OnewayOutbound) (transport.Ack, error) { 50 c.Count++ 51 return o.CallOneway(ctx, req) 52 } 53 54 func (c *countOutboundMiddleware) CallStream(ctx context.Context, req *transport.StreamRequest, o transport.StreamOutbound) (*transport.ClientStream, error) { 55 c.Count++ 56 return o.CallStream(ctx, req) 57 } 58 59 var retryUnaryOutbound middleware.UnaryOutboundFunc = func( 60 ctx context.Context, req *transport.Request, o transport.UnaryOutbound) (*transport.Response, error) { 61 res, err := o.Call(ctx, req) 62 if err != nil { 63 res, err = o.Call(ctx, req) 64 } 65 return res, err 66 } 67 68 func TestUnaryChain(t *testing.T) { 69 before := &countOutboundMiddleware{} 70 after := &countOutboundMiddleware{} 71 72 tests := []struct { 73 desc string 74 mw middleware.UnaryOutbound 75 }{ 76 {"flat chain", UnaryChain(before, retryUnaryOutbound, nil, after)}, 77 {"nested chain", UnaryChain(before, UnaryChain(retryUnaryOutbound, after, nil))}, 78 } 79 80 for _, tt := range tests { 81 t.Run(tt.desc, func(t *testing.T) { 82 before.Count, after.Count = 0, 0 83 mockCtrl := gomock.NewController(t) 84 defer mockCtrl.Finish() 85 ctx, cancel 
:= context.WithTimeout(context.Background(), testtime.Second) 86 defer cancel() 87 88 req := &transport.Request{ 89 Caller: "somecaller", 90 Service: "someservice", 91 Encoding: transport.Encoding("raw"), 92 Procedure: "hello", 93 Body: bytes.NewReader([]byte{1, 2, 3}), 94 } 95 res := &transport.Response{ 96 Body: ioutil.NopCloser(bytes.NewReader([]byte{4, 5, 6})), 97 } 98 o := transporttest.NewMockUnaryOutbound(mockCtrl) 99 o.EXPECT().Call(ctx, req).After( 100 o.EXPECT().Call(ctx, req).Return(nil, errors.New("great sadness")), 101 ).Return(res, nil) 102 103 gotRes, err := middleware.ApplyUnaryOutbound(o, tt.mw).Call(ctx, req) 104 105 assert.NoError(t, err, "expected success") 106 assert.Equal(t, 1, before.Count, "expected outer middleware to be called once") 107 assert.Equal(t, 2, after.Count, "expected inner middleware to be called twice") 108 assert.Equal(t, res, gotRes, "expected response to match") 109 }) 110 } 111 } 112 113 var retryOnewayOutbound middleware.OnewayOutboundFunc = func( 114 ctx context.Context, req *transport.Request, o transport.OnewayOutbound) (transport.Ack, error) { 115 res, err := o.CallOneway(ctx, req) 116 if err != nil { 117 res, err = o.CallOneway(ctx, req) 118 } 119 return res, err 120 } 121 122 func TestOnewayChain(t *testing.T) { 123 before := &countOutboundMiddleware{} 124 after := &countOutboundMiddleware{} 125 126 tests := []struct { 127 desc string 128 mw middleware.OnewayOutbound 129 }{ 130 {"flat chain", OnewayChain(before, retryOnewayOutbound, nil, after)}, 131 {"flat chain", OnewayChain(before, OnewayChain(retryOnewayOutbound, after, nil))}, 132 } 133 134 for _, tt := range tests { 135 t.Run(tt.desc, func(t *testing.T) { 136 mockCtrl := gomock.NewController(t) 137 defer mockCtrl.Finish() 138 ctx, cancel := context.WithTimeout(context.Background(), testtime.Second) 139 defer cancel() 140 141 var res transport.Ack 142 req := &transport.Request{ 143 Caller: "somecaller", 144 Service: "someservice", 145 Encoding: 
transport.Encoding("raw"), 146 Procedure: "hello", 147 Body: bytes.NewReader([]byte{1, 2, 3}), 148 } 149 o := transporttest.NewMockOnewayOutbound(mockCtrl) 150 before.Count, after.Count = 0, 0 151 o.EXPECT().CallOneway(ctx, req).After( 152 o.EXPECT().CallOneway(ctx, req).Return(nil, errors.New("great sadness")), 153 ).Return(res, nil) 154 155 gotRes, err := middleware.ApplyOnewayOutbound(o, tt.mw).CallOneway(ctx, req) 156 157 assert.NoError(t, err, "expected success") 158 assert.Equal(t, 1, before.Count, "expected outer middleware to be called once") 159 assert.Equal(t, 2, after.Count, "expected inner middleware to be called twice") 160 assert.Equal(t, res, gotRes, "expected response to match") 161 }) 162 } 163 } 164 165 func TestEmptyChain(t *testing.T) { 166 errMsg := "expected nop Outbound" 167 168 t.Run("unary", func(t *testing.T) { 169 require.Equal(t, middleware.NopUnaryOutbound, UnaryChain(), errMsg) 170 }) 171 172 t.Run("oneway", func(t *testing.T) { 173 require.Equal(t, middleware.NopOnewayOutbound, OnewayChain(), errMsg) 174 }) 175 } 176 177 func TestSingleOutboundChain(t *testing.T) { 178 ctrl := gomock.NewController(t) 179 180 t.Run("unary", func(t *testing.T) { 181 out := middlewaretest.NewMockUnaryOutbound(ctrl) 182 require.Equal(t, out, UnaryChain(out)) 183 }) 184 185 t.Run("oneway", func(t *testing.T) { 186 out := middlewaretest.NewMockOnewayOutbound(ctrl) 187 require.Equal(t, out, OnewayChain(out)) 188 }) 189 } 190 191 func TestUnaryChainExec(t *testing.T) { 192 ctrl := gomock.NewController(t) 193 out := transporttest.NewMockUnaryOutbound(ctrl) 194 195 chain := &unaryChainExec{Final: out} 196 197 // start 198 out.EXPECT().Start().Return(nil) 199 assert.NoError(t, chain.Start(), "could not start outbound") 200 201 // transports 202 out.EXPECT().Transports() 203 chain.Transports() 204 205 // is running 206 out.EXPECT().IsRunning().Return(true) 207 assert.True(t, chain.IsRunning(), "expected outbound to be running") 208 209 // stop 210 
out.EXPECT().Stop().Return(nil) 211 assert.NoError(t, chain.Stop(), "unexpected error stopping outbound") 212 } 213 214 func TestOnewayChainExec(t *testing.T) { 215 ctrl := gomock.NewController(t) 216 out := transporttest.NewMockOnewayOutbound(ctrl) 217 218 chain := &onewayChainExec{Final: out} 219 220 // start 221 out.EXPECT().Start().Return(nil) 222 assert.NoError(t, chain.Start(), "could not start outbound") 223 224 // transports 225 out.EXPECT().Transports() 226 chain.Transports() 227 228 // is running 229 out.EXPECT().IsRunning().Return(true) 230 assert.True(t, chain.IsRunning(), "expected outbound to be running") 231 232 // stop 233 out.EXPECT().Stop().Return(nil) 234 assert.NoError(t, chain.Stop(), "unexpected error stopping outbound") 235 } 236 237 func TestIntrospect(t *testing.T) { 238 ctrl := gomock.NewController(t) 239 expectStatus := introspection.OutboundStatusNotSupported 240 errMsg := "expected not supported status" 241 242 t.Run("unary", func(t *testing.T) { 243 out := transporttest.NewMockUnaryOutbound(ctrl) 244 chain := &unaryChainExec{Final: out} 245 assert.Equal(t, expectStatus, chain.Introspect(), errMsg) 246 }) 247 248 t.Run("oneway", func(t *testing.T) { 249 out := transporttest.NewMockOnewayOutbound(ctrl) 250 chain := &onewayChainExec{Final: out} 251 assert.Equal(t, expectStatus, chain.Introspect(), errMsg) 252 }) 253 } 254 255 var retryStreamOutbound middleware.StreamOutboundFunc = func( 256 ctx context.Context, req *transport.StreamRequest, o transport.StreamOutbound) (*transport.ClientStream, error) { 257 res, err := o.CallStream(ctx, req) 258 if err != nil { 259 res, err = o.CallStream(ctx, req) 260 } 261 return res, err 262 } 263 264 func TestStreamChain(t *testing.T) { 265 before := &countOutboundMiddleware{} 266 after := &countOutboundMiddleware{} 267 268 tests := []struct { 269 desc string 270 mw middleware.StreamOutbound 271 }{ 272 {"flat chain", StreamChain(before, retryStreamOutbound, nil, after)}, 273 {"nested chain", 
StreamChain(before, StreamChain(retryStreamOutbound, after, nil))}, 274 {"single chain", StreamChain(StreamChain(before), retryStreamOutbound, StreamChain(after), StreamChain())}, 275 } 276 277 for _, tt := range tests { 278 t.Run(tt.desc, func(t *testing.T) { 279 mockCtrl := gomock.NewController(t) 280 defer mockCtrl.Finish() 281 ctx, cancel := context.WithTimeout(context.Background(), testtime.Second) 282 defer cancel() 283 284 var res *transport.ClientStream 285 req := &transport.StreamRequest{ 286 Meta: &transport.RequestMeta{ 287 Caller: "somecaller", 288 Service: "someservice", 289 Encoding: transport.Encoding("raw"), 290 Procedure: "hello", 291 }, 292 } 293 o := transporttest.NewMockStreamOutbound(mockCtrl) 294 295 before.Count, after.Count = 0, 0 296 o.EXPECT().CallStream(ctx, req).After( 297 o.EXPECT().CallStream(ctx, req).Return(nil, errors.New("great sadness")), 298 ).Return(res, nil) 299 300 mw := middleware.ApplyStreamOutbound(o, tt.mw) 301 gotRes, err := mw.CallStream(ctx, req) 302 303 assert.NoError(t, err, "expected success") 304 assert.Equal(t, 1, before.Count, "expected outer middleware to be called once") 305 assert.Equal(t, 2, after.Count, "expected inner middleware to be called twice") 306 assert.Equal(t, res, gotRes, "expected response to match") 307 }) 308 } 309 } 310 311 func TestStreamChainExecFuncs(t *testing.T) { 312 mockCtrl := gomock.NewController(t) 313 defer mockCtrl.Finish() 314 315 o := transporttest.NewMockStreamOutbound(mockCtrl) 316 o.EXPECT().Stop() 317 o.EXPECT().Start() 318 o.EXPECT().Transports() 319 o.EXPECT().IsRunning().Return(true) 320 321 mw := streamChainExec{Final: o} 322 323 assert.Nil(t, mw.Start()) 324 assert.True(t, mw.IsRunning()) 325 assert.Nil(t, mw.Stop()) 326 assert.Len(t, mw.Transports(), 0) 327 }