// Copyright (c) 2022 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
20 21 package inboundmiddleware 22 23 import ( 24 "bytes" 25 "context" 26 "errors" 27 "testing" 28 29 "github.com/golang/mock/gomock" 30 "github.com/stretchr/testify/assert" 31 "github.com/stretchr/testify/require" 32 "go.uber.org/yarpc/api/middleware" 33 "go.uber.org/yarpc/api/transport" 34 "go.uber.org/yarpc/api/transport/transporttest" 35 "go.uber.org/yarpc/internal/testtime" 36 ) 37 38 type countInboundMiddleware struct{ Count int } 39 40 func (c *countInboundMiddleware) Handle( 41 ctx context.Context, req *transport.Request, resw transport.ResponseWriter, h transport.UnaryHandler) error { 42 c.Count++ 43 return h.Handle(ctx, req, resw) 44 } 45 46 func (c *countInboundMiddleware) HandleOneway(ctx context.Context, req *transport.Request, h transport.OnewayHandler) error { 47 c.Count++ 48 return h.HandleOneway(ctx, req) 49 } 50 51 func (c *countInboundMiddleware) HandleStream(s *transport.ServerStream, h transport.StreamHandler) error { 52 c.Count++ 53 return h.HandleStream(s) 54 } 55 56 var retryUnaryInbound middleware.UnaryInboundFunc = func( 57 ctx context.Context, req *transport.Request, resw transport.ResponseWriter, h transport.UnaryHandler) error { 58 if err := h.Handle(ctx, req, resw); err != nil { 59 return h.Handle(ctx, req, resw) 60 } 61 return nil 62 } 63 64 func TestUnaryChain(t *testing.T) { 65 before := &countInboundMiddleware{} 66 after := &countInboundMiddleware{} 67 68 tests := []struct { 69 desc string 70 mw middleware.UnaryInbound 71 }{ 72 {"flat chain", UnaryChain(before, retryUnaryInbound, after, nil)}, 73 {"nested chain", UnaryChain(before, UnaryChain(retryUnaryInbound, nil, after))}, 74 } 75 76 for _, tt := range tests { 77 t.Run(tt.desc, func(t *testing.T) { 78 before.Count, after.Count = 0, 0 79 mockCtrl := gomock.NewController(t) 80 defer mockCtrl.Finish() 81 ctx, cancel := context.WithTimeout(context.Background(), testtime.Second) 82 defer cancel() 83 84 req := &transport.Request{ 85 Caller: "somecaller", 86 Service: "someservice", 87 
Encoding: transport.Encoding("raw"), 88 Procedure: "hello", 89 Body: bytes.NewReader([]byte{1, 2, 3}), 90 } 91 resw := new(transporttest.FakeResponseWriter) 92 h := transporttest.NewMockUnaryHandler(mockCtrl) 93 h.EXPECT().Handle(ctx, req, resw).After( 94 h.EXPECT().Handle(ctx, req, resw).Return(errors.New("great sadness")), 95 ).Return(nil) 96 97 err := middleware.ApplyUnaryInbound(h, tt.mw).Handle(ctx, req, resw) 98 99 assert.NoError(t, err, "expected success") 100 assert.Equal(t, 1, before.Count, "expected outer inbound middleware to be called once") 101 assert.Equal(t, 2, after.Count, "expected inner inbound middleware to be called twice") 102 }) 103 } 104 } 105 106 var retryOnewayInbound middleware.OnewayInboundFunc = func( 107 ctx context.Context, req *transport.Request, h transport.OnewayHandler) error { 108 if err := h.HandleOneway(ctx, req); err != nil { 109 return h.HandleOneway(ctx, req) 110 } 111 return nil 112 } 113 114 func TestOnewayChain(t *testing.T) { 115 before := &countInboundMiddleware{} 116 after := &countInboundMiddleware{} 117 118 tests := []struct { 119 desc string 120 mw middleware.OnewayInbound 121 }{ 122 {"flat chain", OnewayChain(before, retryOnewayInbound, after, nil)}, 123 {"nested chain", OnewayChain(before, OnewayChain(retryOnewayInbound, nil, after))}, 124 } 125 126 for _, tt := range tests { 127 t.Run(tt.desc, func(t *testing.T) { 128 before.Count, after.Count = 0, 0 129 mockCtrl := gomock.NewController(t) 130 defer mockCtrl.Finish() 131 ctx, cancel := context.WithTimeout(context.Background(), testtime.Second) 132 defer cancel() 133 134 req := &transport.Request{ 135 Caller: "somecaller", 136 Service: "someservice", 137 Encoding: transport.Encoding("raw"), 138 Procedure: "hello", 139 Body: bytes.NewReader([]byte{1, 2, 3}), 140 } 141 h := transporttest.NewMockOnewayHandler(mockCtrl) 142 h.EXPECT().HandleOneway(ctx, req).After( 143 h.EXPECT().HandleOneway(ctx, req).Return(errors.New("great sadness")), 144 ).Return(nil) 145 146 err 
:= middleware.ApplyOnewayInbound(h, tt.mw).HandleOneway(ctx, req) 147 148 assert.NoError(t, err, "expected success") 149 assert.Equal(t, 1, before.Count, "expected outer inbound middleware to be called once") 150 assert.Equal(t, 2, after.Count, "expected inner inbound middleware to be called twice") 151 }) 152 } 153 } 154 155 var retryStreamInbound middleware.StreamInboundFunc = func( 156 s *transport.ServerStream, h transport.StreamHandler) error { 157 if err := h.HandleStream(s); err != nil { 158 return h.HandleStream(s) 159 } 160 return nil 161 } 162 163 func TestStreamChain(t *testing.T) { 164 before := &countInboundMiddleware{} 165 after := &countInboundMiddleware{} 166 167 tests := []struct { 168 desc string 169 mw middleware.StreamInbound 170 }{ 171 {"flat chain", StreamChain(before, retryStreamInbound, after, nil)}, 172 {"nested chain", StreamChain(before, StreamChain(retryStreamInbound, nil, after))}, 173 {"single chain", StreamChain(StreamChain(before), retryStreamInbound, StreamChain(after), StreamChain())}, 174 } 175 176 for _, tt := range tests { 177 t.Run(tt.desc, func(t *testing.T) { 178 before.Count, after.Count = 0, 0 179 mockCtrl := gomock.NewController(t) 180 defer mockCtrl.Finish() 181 s, err := transport.NewServerStream(transporttest.NewMockStream(mockCtrl)) 182 require.NoError(t, err) 183 184 h := transporttest.NewMockStreamHandler(mockCtrl) 185 h.EXPECT().HandleStream(s).After( 186 h.EXPECT().HandleStream(s).Return(errors.New("great sadness")), 187 ).Return(nil) 188 189 err = middleware.ApplyStreamInbound(h, tt.mw).HandleStream(s) 190 191 assert.NoError(t, err, "expected success") 192 assert.Equal(t, 1, before.Count, "expected outer inbound middleware to be called once") 193 assert.Equal(t, 2, after.Count, "expected inner inbound middleware to be called twice") 194 }) 195 } 196 }