google.golang.org/grpc@v1.72.2/internal/transport/handler_server_test.go (about) 1 /* 2 * 3 * Copyright 2016 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19 package transport 20 21 import ( 22 "context" 23 "errors" 24 "fmt" 25 "io" 26 "net/http" 27 "net/http/httptest" 28 "net/url" 29 "reflect" 30 "sync" 31 "testing" 32 "time" 33 34 epb "google.golang.org/genproto/googleapis/rpc/errdetails" 35 "google.golang.org/grpc/codes" 36 "google.golang.org/grpc/mem" 37 "google.golang.org/grpc/metadata" 38 "google.golang.org/grpc/status" 39 "google.golang.org/protobuf/proto" 40 "google.golang.org/protobuf/protoadapt" 41 "google.golang.org/protobuf/types/known/durationpb" 42 ) 43 44 func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) { 45 type testCase struct { 46 name string 47 req *http.Request 48 wantErr string 49 wantErrCode int 50 modrw func(http.ResponseWriter) http.ResponseWriter 51 check func(*serverHandlerTransport, *testCase) error 52 } 53 tests := []testCase{ 54 { 55 name: "bad method", 56 req: &http.Request{ 57 ProtoMajor: 2, 58 Method: "GET", 59 Header: http.Header{}, 60 }, 61 wantErr: `invalid gRPC request method "GET"`, 62 wantErrCode: http.StatusMethodNotAllowed, 63 }, 64 { 65 name: "bad content type", 66 req: &http.Request{ 67 ProtoMajor: 2, 68 Method: "POST", 69 Header: http.Header{ 70 "Content-Type": {"application/foo"}, 71 }, 72 }, 73 wantErr: `invalid gRPC request content-type "application/foo"`, 74 wantErrCode: http.StatusUnsupportedMediaType, 75 }, 76 { 77 name: "http/1.1", 78 req: &http.Request{ 79 ProtoMajor: 1, 80 ProtoMinor: 1, 81 Method: "POST", 82 Header: http.Header{"Content-Type": []string{"application/grpc"}}, 83 }, 84 wantErr: "gRPC requires HTTP/2", 85 wantErrCode: http.StatusHTTPVersionNotSupported, 86 }, 87 { 88 name: "not flusher", 89 req: &http.Request{ 90 ProtoMajor: 2, 91 Method: "POST", 92 Header: http.Header{ 93 "Content-Type": {"application/grpc"}, 94 }, 95 }, 96 modrw: func(w http.ResponseWriter) http.ResponseWriter { 97 // Return w without its Flush method 98 type onlyCloseNotifier interface { 99 http.ResponseWriter 100 } 101 return struct{ onlyCloseNotifier }{w.(onlyCloseNotifier)} 102 }, 103 wantErr: "gRPC requires a ResponseWriter supporting http.Flusher", 104 wantErrCode: http.StatusInternalServerError, 105 }, 106 { 107 name: "valid", 108 req: &http.Request{ 109 ProtoMajor: 2, 110 Method: "POST", 111 Header: http.Header{ 112 "Content-Type": {"application/grpc"}, 113 }, 114 URL: &url.URL{ 115 Path: "/service/foo.bar", 116 }, 117 }, 118 check: func(t *serverHandlerTransport, tt *testCase) error { 119 if t.req != tt.req { 120 return fmt.Errorf("t.req = %p; want %p", t.req, tt.req) 121 } 122 if t.rw == nil { 123 return errors.New("t.rw = nil; want non-nil") 124 } 125 return nil 126 }, 127 }, 128 { 129 name: "with timeout", 130 req: &http.Request{ 131 ProtoMajor: 2, 132 Method: "POST", 133 Header: http.Header{ 134 "Content-Type": []string{"application/grpc"}, 135 "Grpc-Timeout": {"200m"}, 136 }, 137 URL: &url.URL{ 138 Path: "/service/foo.bar", 139 }, 140 }, 141 check: func(t *serverHandlerTransport, _ *testCase) error { 142 if !t.timeoutSet { 143 return errors.New("timeout not set") 144 } 145 if want := 200 * time.Millisecond; t.timeout != want { 146 return fmt.Errorf("timeout = %v; want %v", t.timeout, want) 147 } 148 return nil 149 }, 150 }, 151 { 152 name: "with bad timeout", 153 req: &http.Request{ 154 ProtoMajor: 2, 155 Method: "POST", 156 Header: http.Header{ 157 "Content-Type": []string{"application/grpc"}, 158 "Grpc-Timeout": {"tomorrow"}, 159 }, 160 URL: &url.URL{ 161 Path: "/service/foo.bar", 162 }, 163 }, 164 wantErr: `rpc error: code = Internal desc = malformed grpc-timeout: transport: timeout unit is not recognized: "tomorrow"`, 165 wantErrCode: http.StatusBadRequest, 166 }, 167 { 168 name: "with metadata", 169 req: &http.Request{ 170 ProtoMajor: 2, 171 Method: "POST", 172 Header: http.Header{ 173 "Content-Type": []string{"application/grpc"}, 174 "meta-foo": {"foo-val"}, 175 "meta-bar": {"bar-val1", "bar-val2"}, 176 "user-agent": {"x/y a/b"}, 177 }, 178 URL: &url.URL{ 179 Path: "/service/foo.bar", 180 }, 181 }, 182 check: func(ht *serverHandlerTransport, _ *testCase) error { 183 want := metadata.MD{ 184 "meta-bar": {"bar-val1", "bar-val2"}, 185 "user-agent": {"x/y a/b"}, 186 "meta-foo": {"foo-val"}, 187 "content-type": {"application/grpc"}, 188 } 189 190 if !reflect.DeepEqual(ht.headerMD, want) { 191 return fmt.Errorf("metadata = %#v; want %#v", ht.headerMD, want) 192 } 193 return nil 194 }, 195 }, 196 } 197 198 for _, tt := range tests { 199 rrec := httptest.NewRecorder() 200 rw := http.ResponseWriter(testHandlerResponseWriter{ 201 ResponseRecorder: rrec, 202 }) 203 204 if tt.modrw != nil { 205 rw = tt.modrw(rw) 206 } 207 got, gotErr := NewServerHandlerTransport(rw, tt.req, nil, mem.DefaultBufferPool()) 208 if (gotErr != nil) != (tt.wantErr != "") || (gotErr != nil && gotErr.Error() != tt.wantErr) { 209 t.Errorf("%s: error = %q; want %q", tt.name, gotErr.Error(), tt.wantErr) 210 continue 211 } 212 if tt.wantErrCode == 0 { 213 tt.wantErrCode = http.StatusOK 214 } 215 if rrec.Code != tt.wantErrCode { 216 t.Errorf("%s: code = %d; want %d", tt.name, rrec.Code, tt.wantErrCode) 217 continue 218 } 219 if gotErr != nil { 220 continue 221 } 222 if tt.check != nil { 223 if err := tt.check(got.(*serverHandlerTransport), &tt); err != nil { 224 t.Errorf("%s: %v", tt.name, err) 225 } 226 } 227 } 228 } 229 230 type testHandlerResponseWriter struct { 231 *httptest.ResponseRecorder 232 } 233 234 func (w testHandlerResponseWriter) Flush() {} 235 236 func newTestHandlerResponseWriter() http.ResponseWriter { 237 return testHandlerResponseWriter{ 238 ResponseRecorder: httptest.NewRecorder(), 239 } 240 } 241 242 type handleStreamTest struct { 243 t *testing.T 244 bodyw *io.PipeWriter 245 rw testHandlerResponseWriter 246 ht *serverHandlerTransport 247 } 248 249 func newHandleStreamTest(t *testing.T) *handleStreamTest { 250 bodyr, bodyw := io.Pipe() 251 req := &http.Request{ 252 ProtoMajor: 2, 253 Method: "POST", 254 Header: http.Header{ 255 "Content-Type": {"application/grpc"}, 256 }, 257 URL: &url.URL{ 258 Path: "/service/foo.bar", 259 }, 260 Body: bodyr, 261 } 262 rw := newTestHandlerResponseWriter().(testHandlerResponseWriter) 263 ht, err := NewServerHandlerTransport(rw, req, nil, mem.DefaultBufferPool()) 264 if err != nil { 265 t.Fatal(err) 266 } 267 return &handleStreamTest{ 268 t: t, 269 bodyw: bodyw, 270 ht: ht.(*serverHandlerTransport), 271 rw: rw, 272 } 273 } 274 275 func (s) TestHandlerTransport_HandleStreams(t *testing.T) { 276 st := newHandleStreamTest(t) 277 handleStream := func(s *ServerStream) { 278 if want := "/service/foo.bar"; s.method != want { 279 t.Errorf("stream method = %q; want %q", s.method, want) 280 } 281 282 if err := s.SetHeader(metadata.Pairs("custom-header", "Custom header value")); err != nil { 283 t.Error(err) 284 } 285 286 if err := s.SetTrailer(metadata.Pairs("custom-trailer", "Custom trailer value")); err != nil { 287 t.Error(err) 288 } 289 290 if err := s.SetSendCompress("gzip"); err != nil { 291 t.Error(err) 292 } 293 294 md := metadata.Pairs("custom-header", "Another custom header value") 295 if err := s.SendHeader(md); err != nil { 296 t.Error(err) 297 } 298 delete(md, "custom-header") 299 300 if err := s.SetHeader(metadata.Pairs("too-late", "Header value that should be ignored")); err == nil { 301 t.Error("expected SetHeader call after SendHeader to fail") 302 } 303 304 if err := s.SendHeader(metadata.Pairs("too-late", "This header value should be ignored as well")); err == nil { 305 t.Error("expected second SendHeader call to fail") 306 } 307 308 if err := s.SetSendCompress("snappy"); err == nil { 309 t.Error("expected second SetSendCompress call to fail") 310 } 311 312 st.bodyw.Close() // no body 313 s.WriteStatus(status.New(codes.OK, "")) 314 } 315 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 316 defer cancel() 317 st.ht.HandleStreams( 318 ctx, func(s *ServerStream) { go handleStream(s) }, 319 ) 320 wantHeader := http.Header{ 321 "Date": nil, 322 "Content-Type": {"application/grpc"}, 323 "Trailer": {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"}, 324 "Custom-Header": {"Custom header value", "Another custom header value"}, 325 "Grpc-Encoding": {"gzip"}, 326 } 327 wantTrailer := http.Header{ 328 "Grpc-Status": {"0"}, 329 "Custom-Trailer": {"Custom trailer value"}, 330 } 331 checkHeaderAndTrailer(t, st.rw, wantHeader, wantTrailer) 332 } 333 334 // Tests that codes.Unimplemented will close the body, per comment in handler_server.go. 335 func (s) TestHandlerTransport_HandleStreams_Unimplemented(t *testing.T) { 336 handleStreamCloseBodyTest(t, codes.Unimplemented, "thingy is unimplemented") 337 } 338 339 // Tests that codes.InvalidArgument will close the body, per comment in handler_server.go. 340 func (s) TestHandlerTransport_HandleStreams_InvalidArgument(t *testing.T) { 341 handleStreamCloseBodyTest(t, codes.InvalidArgument, "bad arg") 342 } 343 344 func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string) { 345 st := newHandleStreamTest(t) 346 347 handleStream := func(s *ServerStream) { 348 s.WriteStatus(status.New(statusCode, msg)) 349 } 350 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 351 defer cancel() 352 st.ht.HandleStreams( 353 ctx, func(s *ServerStream) { go handleStream(s) }, 354 ) 355 wantHeader := http.Header{ 356 "Date": nil, 357 "Content-Type": {"application/grpc"}, 358 "Trailer": {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"}, 359 } 360 wantTrailer := http.Header{ 361 "Grpc-Status": {fmt.Sprint(uint32(statusCode))}, 362 "Grpc-Message": {encodeGrpcMessage(msg)}, 363 } 364 checkHeaderAndTrailer(t, st.rw, wantHeader, wantTrailer) 365 } 366 367 func (s) TestHandlerTransport_HandleStreams_Timeout(t *testing.T) { 368 bodyr, bodyw := io.Pipe() 369 req := &http.Request{ 370 ProtoMajor: 2, 371 Method: "POST", 372 Header: http.Header{ 373 "Content-Type": {"application/grpc"}, 374 "Grpc-Timeout": {"200m"}, 375 }, 376 URL: &url.URL{ 377 Path: "/service/foo.bar", 378 }, 379 Body: bodyr, 380 } 381 rw := newTestHandlerResponseWriter().(testHandlerResponseWriter) 382 ht, err := NewServerHandlerTransport(rw, req, nil, mem.DefaultBufferPool()) 383 if err != nil { 384 t.Fatal(err) 385 } 386 runStream := func(s *ServerStream) { 387 defer bodyw.Close() 388 select { 389 case <-s.ctx.Done(): 390 case <-time.After(5 * time.Second): 391 t.Errorf("timeout waiting for ctx.Done") 392 return 393 } 394 err := s.ctx.Err() 395 if err != context.DeadlineExceeded { 396 t.Errorf("ctx.Err = %v; want %v", err, context.DeadlineExceeded) 397 return 398 } 399 s.WriteStatus(status.New(codes.DeadlineExceeded, "too slow")) 400 } 401 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 402 defer cancel() 403 ht.HandleStreams( 404 ctx, func(s *ServerStream) { go runStream(s) }, 405 ) 406 wantHeader := http.Header{ 407 "Date": nil, 408 "Content-Type": {"application/grpc"}, 409 "Trailer": {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"}, 410 } 411 wantTrailer := http.Header{ 412 "Grpc-Status": {"4"}, 413 "Grpc-Message": {encodeGrpcMessage("too slow")}, 414 } 415 checkHeaderAndTrailer(t, rw, wantHeader, wantTrailer) 416 } 417 418 // TestHandlerTransport_HandleStreams_MultiWriteStatus ensures that 419 // concurrent "WriteStatus"s do not panic writing to closed "writes" channel. 420 func (s) TestHandlerTransport_HandleStreams_MultiWriteStatus(t *testing.T) { 421 testHandlerTransportHandleStreams(t, func(st *handleStreamTest, s *ServerStream) { 422 if want := "/service/foo.bar"; s.method != want { 423 t.Errorf("stream method = %q; want %q", s.method, want) 424 } 425 st.bodyw.Close() // no body 426 427 var wg sync.WaitGroup 428 wg.Add(5) 429 for i := 0; i < 5; i++ { 430 go func() { 431 defer wg.Done() 432 s.WriteStatus(status.New(codes.OK, "")) 433 }() 434 } 435 wg.Wait() 436 }) 437 } 438 439 // TestHandlerTransport_HandleStreams_WriteStatusWrite ensures that "Write" 440 // following "WriteStatus" does not panic writing to closed "writes" channel. 441 func (s) TestHandlerTransport_HandleStreams_WriteStatusWrite(t *testing.T) { 442 testHandlerTransportHandleStreams(t, func(st *handleStreamTest, s *ServerStream) { 443 if want := "/service/foo.bar"; s.method != want { 444 t.Errorf("stream method = %q; want %q", s.method, want) 445 } 446 st.bodyw.Close() // no body 447 448 s.WriteStatus(status.New(codes.OK, "")) 449 s.Write([]byte("hdr"), newBufferSlice([]byte("data")), &WriteOptions{}) 450 }) 451 } 452 453 func testHandlerTransportHandleStreams(t *testing.T, handleStream func(st *handleStreamTest, s *ServerStream)) { 454 st := newHandleStreamTest(t) 455 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 456 t.Cleanup(cancel) 457 st.ht.HandleStreams( 458 ctx, func(s *ServerStream) { go handleStream(st, s) }, 459 ) 460 } 461 462 func (s) TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) { 463 errDetails := []protoadapt.MessageV1{ 464 &epb.RetryInfo{ 465 RetryDelay: &durationpb.Duration{Seconds: 60}, 466 }, 467 &epb.ResourceInfo{ 468 ResourceType: "foo bar", 469 ResourceName: "service.foo.bar", 470 Owner: "User", 471 }, 472 } 473 474 statusCode := codes.ResourceExhausted 475 msg := "you are being throttled" 476 st, err := status.New(statusCode, msg).WithDetails(errDetails...) 477 if err != nil { 478 t.Fatal(err) 479 } 480 481 stBytes, err := proto.Marshal(st.Proto()) 482 if err != nil { 483 t.Fatal(err) 484 } 485 486 hst := newHandleStreamTest(t) 487 handleStream := func(s *ServerStream) { 488 s.WriteStatus(st) 489 } 490 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 491 defer cancel() 492 hst.ht.HandleStreams( 493 ctx, func(s *ServerStream) { go handleStream(s) }, 494 ) 495 wantHeader := http.Header{ 496 "Date": nil, 497 "Content-Type": {"application/grpc"}, 498 "Trailer": {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"}, 499 } 500 wantTrailer := http.Header{ 501 "Grpc-Status": {fmt.Sprint(uint32(statusCode))}, 502 "Grpc-Message": {encodeGrpcMessage(msg)}, 503 "Grpc-Status-Details-Bin": {encodeBinHeader(stBytes)}, 504 } 505 506 checkHeaderAndTrailer(t, hst.rw, wantHeader, wantTrailer) 507 } 508 509 // TestHandlerTransport_Drain verifies that Drain() is not implemented 510 // by `serverHandlerTransport`. 511 func (s) TestHandlerTransport_Drain(t *testing.T) { 512 defer func() { recover() }() 513 st := newHandleStreamTest(t) 514 st.ht.Drain("whatever") 515 t.Errorf("serverHandlerTransport.Drain() should have panicked") 516 } 517 518 // checkHeaderAndTrailer checks that the resulting header and trailer matches the expectation. 519 func checkHeaderAndTrailer(t *testing.T, rw testHandlerResponseWriter, wantHeader, wantTrailer http.Header) { 520 // For trailer-only responses, the trailer values might be reported as part of the Header. They will however 521 // be present in Trailer in either case. Hence, normalize the header by removing all trailer values. 522 actualHeader := rw.Result().Header.Clone() 523 for _, trailerKey := range actualHeader["Trailer"] { 524 actualHeader.Del(trailerKey) 525 } 526 527 if !reflect.DeepEqual(actualHeader, wantHeader) { 528 t.Errorf("Header mismatch.\n got: %#v\n want: %#v", actualHeader, wantHeader) 529 } 530 if actualTrailer := rw.Result().Trailer; !reflect.DeepEqual(actualTrailer, wantTrailer) { 531 t.Errorf("Trailer mismatch.\n got: %#v\n want: %#v", actualTrailer, wantTrailer) 532 } 533 }