go.uber.org/yarpc@v1.72.1/transport/grpc/outbound_test.go (about) 1 // Copyright (c) 2022 Uber Technologies, Inc. 2 // 3 // Permission is hereby granted, free of charge, to any person obtaining a copy 4 // of this software and associated documentation files (the "Software"), to deal 5 // in the Software without restriction, including without limitation the rights 6 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 // copies of the Software, and to permit persons to whom the Software is 8 // furnished to do so, subject to the following conditions: 9 // 10 // The above copyright notice and this permission notice shall be included in 11 // all copies or substantial portions of the Software. 12 // 13 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 // THE SOFTWARE. 20 21 package grpc 22 23 import ( 24 "bytes" 25 "context" 26 "net" 27 "testing" 28 "time" 29 30 "github.com/gogo/protobuf/types" 31 "github.com/golang/mock/gomock" 32 "github.com/stretchr/testify/assert" 33 "github.com/stretchr/testify/require" 34 "go.uber.org/yarpc/api/peer" 35 "go.uber.org/yarpc/api/peer/peertest" 36 "go.uber.org/yarpc/api/transport" 37 "go.uber.org/yarpc/yarpcerrors" 38 "google.golang.org/grpc" 39 ) 40 41 // shared between Unary and Streaming InvalidHeaderValue tests. 42 var malformedValues = []string{ 43 "value with line feed\n", 44 "value with carriage return\r", 45 "value with Nul" + string('\x00'), 46 } 47 48 func TestTransportNamer(t *testing.T) { 49 assert.Equal(t, TransportName, NewTransport().NewOutbound(nil).TransportName()) 50 } 51 52 func TestNoRequest(t *testing.T) { 53 tran := NewTransport() 54 out := tran.NewSingleOutbound("localhost:0") 55 56 _, err := out.Call(context.Background(), nil) 57 assert.Equal(t, yarpcerrors.InvalidArgumentErrorf("request for grpc outbound was nil"), err) 58 } 59 60 func TestCallWithInvalidHeaderValue(t *testing.T) { 61 listener, err := net.Listen("tcp", "127.0.0.1:0") 62 require.NoError(t, err) 63 64 tran := NewTransport() 65 out := tran.NewSingleOutbound(listener.Addr().String()) 66 require.NoError(t, tran.Start()) 67 require.NoError(t, out.Start()) 68 defer tran.Stop() 69 defer out.Stop() 70 71 for _, v := range malformedValues { 72 ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100) 73 defer cancel() 74 req := &transport.Request{ 75 Caller: "caller", 76 Service: "service", 77 Encoding: transport.Encoding("raw"), 78 Procedure: "proc", 79 Headers: transport.NewHeaders().With("valid-key", v), 80 } 81 _, err = out.Call(ctx, req) 82 83 require.Contains(t, err.Error(), yarpcerrors.InvalidArgumentErrorf("grpc request header value contains invalid characters including ASCII 0xd, 0xa, or 0x0").Error()) 84 } 85 } 86 87 func TestCallStreamWhenNotRunning(t *testing.T) { 88 listener, err := net.Listen("tcp", "127.0.0.1:0") 89 require.NoError(t, err) 90 91 tran := NewTransport() 92 out := tran.NewSingleOutbound(listener.Addr().String()) 93 94 ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*10) 95 defer cancel() 96 _, err = out.CallStream(ctx, &transport.StreamRequest{}) 97 98 require.Contains(t, err.Error(), context.DeadlineExceeded.Error()) 99 } 100 101 func TestCallStreamWithNoRequestMeta(t *testing.T) { 102 listener, err := net.Listen("tcp", "127.0.0.1:0") 103 require.NoError(t, err) 104 105 tran := NewTransport() 106 out := tran.NewSingleOutbound(listener.Addr().String()) 107 require.NoError(t, tran.Start()) 108 require.NoError(t, out.Start()) 109 defer tran.Stop() 110 defer out.Stop() 111 112 ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100) 113 defer cancel() 114 _, err = out.CallStream(ctx, &transport.StreamRequest{}) 115 116 require.Contains(t, err.Error(), yarpcerrors.InvalidArgumentErrorf("stream request requires a request metadata").Error()) 117 } 118 119 func TestCallWithReservedHeaderKey(t *testing.T) { 120 listener, err := net.Listen("tcp", "127.0.0.1:0") 121 require.NoError(t, err) 122 123 tran := NewTransport() 124 out := tran.NewSingleOutbound(listener.Addr().String()) 125 require.NoError(t, tran.Start()) 126 require.NoError(t, out.Start()) 127 defer tran.Stop() 128 defer out.Stop() 129 130 ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100) 131 defer cancel() 132 req := &transport.StreamRequest{ 133 Meta: &transport.RequestMeta{ 134 Caller: "caller", 135 Service: "service", 136 Encoding: transport.Encoding("raw"), 137 Procedure: "proc", 138 Headers: transport.NewHeaders().With("rpc-caller", "reserved header"), 139 }, 140 } 141 _, err = out.CallStream(ctx, req) 142 143 require.Contains(t, err.Error(), yarpcerrors.InvalidArgumentErrorf("cannot use reserved header in application headers: rpc-caller").Error()) 144 } 145 146 func TestCallStreamWithInvalidProcedure(t *testing.T) { 147 listener, err := net.Listen("tcp", "127.0.0.1:0") 148 require.NoError(t, err) 149 150 tran := NewTransport() 151 out := tran.NewSingleOutbound(listener.Addr().String()) 152 require.NoError(t, tran.Start()) 153 require.NoError(t, out.Start()) 154 defer tran.Stop() 155 defer out.Stop() 156 157 ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100) 158 defer cancel() 159 req := &transport.StreamRequest{ 160 Meta: &transport.RequestMeta{ 161 Caller: "caller", 162 Service: "service", 163 Encoding: transport.Encoding("raw"), 164 Procedure: "", 165 }, 166 } 167 _, err = out.CallStream(ctx, req) 168 169 require.Contains(t, err.Error(), yarpcerrors.InvalidArgumentErrorf("invalid procedure name: ").Error()) 170 } 171 172 func TestCallStreamWithInvalidHeaderValue(t *testing.T) { 173 listener, err := net.Listen("tcp", "127.0.0.1:0") 174 require.NoError(t, err) 175 176 tran := NewTransport() 177 out := tran.NewSingleOutbound(listener.Addr().String()) 178 require.NoError(t, tran.Start()) 179 require.NoError(t, out.Start()) 180 defer tran.Stop() 181 defer out.Stop() 182 183 for _, v := range malformedValues { 184 ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100) 185 defer cancel() 186 req := &transport.StreamRequest{ 187 Meta: &transport.RequestMeta{ 188 Caller: "caller", 189 Service: "service", 190 Encoding: transport.Encoding("raw"), 191 Procedure: "proc", 192 Headers: transport.NewHeaders().With("valid-key", v), 193 }, 194 } 195 _, err = out.CallStream(ctx, req) 196 197 require.Contains(t, err.Error(), yarpcerrors.InvalidArgumentErrorf("grpc request header value contains invalid characters including ASCII 0xd, 0xa, or 0x0").Error()) 198 } 199 } 200 201 func TestCallStreamWithChooserError(t *testing.T) { 202 mockCtrl := gomock.NewController(t) 203 defer mockCtrl.Finish() 204 205 chooser := peertest.NewMockChooser(mockCtrl) 206 chooser.EXPECT().Start() 207 chooser.EXPECT().Stop() 208 chooser.EXPECT().Choose(gomock.Any(), gomock.Any()).Return(nil, nil, yarpcerrors.InternalErrorf("error")) 209 210 tran := NewTransport() 211 out := tran.NewOutbound(chooser) 212 213 require.NoError(t, tran.Start()) 214 require.NoError(t, out.Start()) 215 defer tran.Stop() 216 defer out.Stop() 217 218 ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100) 219 defer cancel() 220 req := &transport.StreamRequest{ 221 Meta: &transport.RequestMeta{ 222 Caller: "caller", 223 Service: "service", 224 Encoding: transport.Encoding("raw"), 225 Procedure: "proc", 226 }, 227 } 228 _, err := out.CallStream(ctx, req) 229 230 require.Contains(t, err.Error(), yarpcerrors.InternalErrorf("error").Error()) 231 } 232 233 func TestCallStreamWithInvalidPeer(t *testing.T) { 234 mockCtrl := gomock.NewController(t) 235 defer mockCtrl.Finish() 236 237 fakePeer := peertest.NewMockPeer(mockCtrl) 238 chooser := peertest.NewMockChooser(mockCtrl) 239 chooser.EXPECT().Start() 240 chooser.EXPECT().Stop() 241 chooser.EXPECT().Choose(gomock.Any(), gomock.Any()).Return(fakePeer, func(error) {}, nil) 242 243 tran := NewTransport() 244 out := tran.NewOutbound(chooser) 245 246 require.NoError(t, tran.Start()) 247 require.NoError(t, out.Start()) 248 defer tran.Stop() 249 defer out.Stop() 250 251 ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100) 252 defer cancel() 253 req := &transport.StreamRequest{ 254 Meta: &transport.RequestMeta{ 255 Caller: "caller", 256 Service: "service", 257 Encoding: transport.Encoding("raw"), 258 Procedure: "proc", 259 }, 260 } 261 _, err := out.CallStream(ctx, req) 262 263 require.Contains( 264 t, 265 err.Error(), 266 peer.ErrInvalidPeerConversion{ 267 Peer: fakePeer, 268 ExpectedType: "*grpcPeer", 269 }.Error(), 270 ) 271 } 272 273 func TestCallServiceMatch(t *testing.T) { 274 tests := []struct { 275 msg string 276 headerKey string 277 headerValue string 278 wantErr bool 279 }{ 280 { 281 msg: "call service match success", 282 headerKey: ServiceHeader, 283 headerValue: "Service", 284 }, 285 { 286 msg: "call service match failed", 287 headerKey: ServiceHeader, 288 headerValue: "ThisIsWrongSvcName", 289 wantErr: true, 290 }, 291 { 292 msg: "no service name response header", 293 }, 294 } 295 for _, tt := range tests { 296 t.Run(tt.msg, func(t *testing.T) { 297 server := grpc.NewServer( 298 grpc.UnknownServiceHandler(func(srv interface{}, stream grpc.ServerStream) error { 299 responseWriter := newResponseWriter() 300 defer responseWriter.Close() 301 302 if tt.headerKey != "" { 303 responseWriter.AddSystemHeader(tt.headerKey, tt.headerValue) 304 } 305 306 // Send the response attributes back and end the stream. 307 if sendErr := stream.SendMsg(&types.Empty{}); sendErr != nil { 308 // We couldn't send the response. 309 return sendErr 310 } 311 if responseWriter.md != nil { 312 stream.SetTrailer(responseWriter.md) 313 } 314 return nil 315 }), 316 ) 317 listener, err := net.Listen("tcp", "127.0.0.1:0") 318 require.NoError(t, err) 319 go func() { 320 err := server.Serve(listener) 321 require.NoError(t, err) 322 }() 323 defer server.Stop() 324 325 grpcTransport := NewTransport() 326 out := grpcTransport.NewSingleOutbound(listener.Addr().String()) 327 require.NoError(t, grpcTransport.Start()) 328 require.NoError(t, out.Start()) 329 defer grpcTransport.Stop() 330 defer out.Stop() 331 332 ctx, cancel := context.WithTimeout(context.Background(), time.Second) 333 defer cancel() 334 req := &transport.Request{ 335 Service: "Service", 336 Procedure: "Hello", 337 Body: bytes.NewReader([]byte("world")), 338 } 339 _, err = out.Call(ctx, req) 340 if tt.wantErr { 341 require.Error(t, err) 342 assert.Contains(t, err.Error(), "does not match") 343 } else { 344 require.NoError(t, err) 345 } 346 }) 347 } 348 } 349 350 func TestOutboundIntrospection(t *testing.T) { 351 listener, err := net.Listen("tcp", "127.0.0.1:0") 352 require.NoError(t, err) 353 354 grpcTransport := NewTransport() 355 o := grpcTransport.NewSingleOutbound(listener.Addr().String()) 356 357 assert.Equal(t, TransportName, o.Introspect().Transport) 358 assert.Equal(t, "Stopped", o.Introspect().State) 359 assert.False(t, o.IsRunning()) 360 361 require.NoError(t, o.Start(), "could not start outbound") 362 assert.Equal(t, "Running", o.Introspect().State) 363 364 require.NoError(t, o.Stop(), "could not stop outbound") 365 assert.Equal(t, "Stopped", o.Introspect().State) 366 }