github.com/grpc-ecosystem/grpc-gateway/v2@v2.19.1/runtime/handler_test.go (about) 1 package runtime_test 2 3 import ( 4 "context" 5 "io" 6 "net/http" 7 "net/http/httptest" 8 "reflect" 9 "testing" 10 11 "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" 12 pb "github.com/grpc-ecosystem/grpc-gateway/v2/runtime/internal/examplepb" 13 "google.golang.org/grpc/codes" 14 "google.golang.org/grpc/metadata" 15 "google.golang.org/grpc/status" 16 "google.golang.org/protobuf/proto" 17 ) 18 19 type fakeReponseBodyWrapper struct { 20 proto.Message 21 } 22 23 // XXX_ResponseBody returns id of SimpleMessage 24 func (r fakeReponseBodyWrapper) XXX_ResponseBody() interface{} { 25 resp := r.Message.(*pb.SimpleMessage) 26 return resp.Id 27 } 28 29 func TestForwardResponseStream(t *testing.T) { 30 type msg struct { 31 pb proto.Message 32 err error 33 } 34 tests := []struct { 35 name string 36 msgs []msg 37 statusCode int 38 responseBody bool 39 }{{ 40 name: "encoding", 41 msgs: []msg{ 42 {&pb.SimpleMessage{Id: "One"}, nil}, 43 {&pb.SimpleMessage{Id: "Two"}, nil}, 44 }, 45 statusCode: http.StatusOK, 46 }, { 47 name: "empty", 48 statusCode: http.StatusOK, 49 }, { 50 name: "error", 51 msgs: []msg{{nil, status.Errorf(codes.OutOfRange, "400")}}, 52 statusCode: http.StatusBadRequest, 53 }, { 54 name: "stream_error", 55 msgs: []msg{ 56 {&pb.SimpleMessage{Id: "One"}, nil}, 57 {nil, status.Errorf(codes.OutOfRange, "400")}, 58 }, 59 statusCode: http.StatusOK, 60 }, { 61 name: "response body stream case", 62 msgs: []msg{ 63 {fakeReponseBodyWrapper{&pb.SimpleMessage{Id: "One"}}, nil}, 64 {fakeReponseBodyWrapper{&pb.SimpleMessage{Id: "Two"}}, nil}, 65 }, 66 responseBody: true, 67 statusCode: http.StatusOK, 68 }, { 69 name: "response body stream error case", 70 msgs: []msg{ 71 {fakeReponseBodyWrapper{&pb.SimpleMessage{Id: "One"}}, nil}, 72 {nil, status.Errorf(codes.OutOfRange, "400")}, 73 }, 74 responseBody: true, 75 statusCode: http.StatusOK, 76 }} 77 78 newTestRecv := func(t *testing.T, msgs []msg) func() (proto.Message, error) { 79 var count int 80 return func() (proto.Message, error) { 81 if count == len(msgs) { 82 return nil, io.EOF 83 } else if count > len(msgs) { 84 t.Errorf("recv() called %d times for %d messages", count, len(msgs)) 85 } 86 count++ 87 msg := msgs[count-1] 88 return msg.pb, msg.err 89 } 90 } 91 ctx := runtime.NewServerMetadataContext(context.Background(), runtime.ServerMetadata{}) 92 marshaler := &runtime.JSONPb{} 93 for _, tt := range tests { 94 t.Run(tt.name, func(t *testing.T) { 95 recv := newTestRecv(t, tt.msgs) 96 req := httptest.NewRequest("GET", "http://example.com/foo", nil) 97 resp := httptest.NewRecorder() 98 99 runtime.ForwardResponseStream(ctx, runtime.NewServeMux(), marshaler, resp, req, recv) 100 101 w := resp.Result() 102 if w.StatusCode != tt.statusCode { 103 t.Errorf("StatusCode %d want %d", w.StatusCode, tt.statusCode) 104 } 105 if h := w.Header.Get("Transfer-Encoding"); h != "chunked" { 106 t.Errorf("ForwardResponseStream missing header chunked") 107 } 108 body, err := io.ReadAll(w.Body) 109 if err != nil { 110 t.Errorf("Failed to read response body with %v", err) 111 } 112 w.Body.Close() 113 if len(body) > 0 && w.Header.Get("Content-Type") != "application/json" { 114 t.Errorf("Content-Type %s want application/json", w.Header.Get("Content-Type")) 115 } 116 117 var want []byte 118 for i, msg := range tt.msgs { 119 if msg.err != nil { 120 if i == 0 { 121 // Skip non-stream errors 122 t.Skip("checking error encodings") 123 } 124 delimiter := marshaler.Delimiter() 125 st := status.Convert(msg.err) 126 b, err := marshaler.Marshal(map[string]proto.Message{ 127 "error": st.Proto(), 128 }) 129 if err != nil { 130 t.Errorf("marshaler.Marshal() failed %v", err) 131 } 132 errBytes := body[len(want):] 133 if string(errBytes) != string(b)+string(delimiter) { 134 t.Errorf("ForwardResponseStream() = \"%s\" want \"%s\"", errBytes, b) 135 } 136 137 return 138 } 139 140 var b []byte 141 142 if tt.responseBody { 143 // responseBody interface is in runtime package and test is in runtime_test package. hence can't use responseBody directly 144 // So type casting to fakeReponseBodyWrapper struct to verify the data. 145 rb, ok := msg.pb.(fakeReponseBodyWrapper) 146 if !ok { 147 t.Errorf("stream responseBody failed %v", err) 148 } 149 150 b, err = marshaler.Marshal(map[string]interface{}{"result": rb.XXX_ResponseBody()}) 151 } else { 152 b, err = marshaler.Marshal(map[string]interface{}{"result": msg.pb}) 153 } 154 155 if err != nil { 156 t.Errorf("marshaler.Marshal() failed %v", err) 157 } 158 want = append(want, b...) 159 want = append(want, marshaler.Delimiter()...) 160 } 161 162 if string(body) != string(want) { 163 t.Errorf("ForwardResponseStream() = \"%s\" want \"%s\"", body, want) 164 } 165 }) 166 } 167 } 168 169 // A custom marshaler implementation, that doesn't implement the delimited interface 170 type CustomMarshaler struct { 171 m *runtime.JSONPb 172 } 173 174 func (c *CustomMarshaler) Marshal(v interface{}) ([]byte, error) { return c.m.Marshal(v) } 175 func (c *CustomMarshaler) Unmarshal(data []byte, v interface{}) error { return c.m.Unmarshal(data, v) } 176 func (c *CustomMarshaler) NewDecoder(r io.Reader) runtime.Decoder { return c.m.NewDecoder(r) } 177 func (c *CustomMarshaler) NewEncoder(w io.Writer) runtime.Encoder { return c.m.NewEncoder(w) } 178 func (c *CustomMarshaler) ContentType(v interface{}) string { return "Custom-Content-Type" } 179 180 func TestForwardResponseStreamCustomMarshaler(t *testing.T) { 181 type msg struct { 182 pb proto.Message 183 err error 184 } 185 tests := []struct { 186 name string 187 msgs []msg 188 statusCode int 189 }{{ 190 name: "encoding", 191 msgs: []msg{ 192 {&pb.SimpleMessage{Id: "One"}, nil}, 193 {&pb.SimpleMessage{Id: "Two"}, nil}, 194 }, 195 statusCode: http.StatusOK, 196 }, { 197 name: "empty", 198 statusCode: http.StatusOK, 199 }, { 200 name: "error", 201 msgs: []msg{{nil, status.Errorf(codes.OutOfRange, "400")}}, 202 statusCode: http.StatusBadRequest, 203 }, { 204 name: "stream_error", 205 msgs: []msg{ 206 {&pb.SimpleMessage{Id: "One"}, nil}, 207 {nil, status.Errorf(codes.OutOfRange, "400")}, 208 }, 209 statusCode: http.StatusOK, 210 }} 211 212 newTestRecv := func(t *testing.T, msgs []msg) func() (proto.Message, error) { 213 var count int 214 return func() (proto.Message, error) { 215 if count == len(msgs) { 216 return nil, io.EOF 217 } else if count > len(msgs) { 218 t.Errorf("recv() called %d times for %d messages", count, len(msgs)) 219 } 220 count++ 221 msg := msgs[count-1] 222 return msg.pb, msg.err 223 } 224 } 225 ctx := runtime.NewServerMetadataContext(context.Background(), runtime.ServerMetadata{}) 226 marshaler := &CustomMarshaler{&runtime.JSONPb{}} 227 for _, tt := range tests { 228 t.Run(tt.name, func(t *testing.T) { 229 recv := newTestRecv(t, tt.msgs) 230 req := httptest.NewRequest("GET", "http://example.com/foo", nil) 231 resp := httptest.NewRecorder() 232 233 runtime.ForwardResponseStream(ctx, runtime.NewServeMux(), marshaler, resp, req, recv) 234 235 w := resp.Result() 236 if w.StatusCode != tt.statusCode { 237 t.Errorf("StatusCode %d want %d", w.StatusCode, tt.statusCode) 238 } 239 if h := w.Header.Get("Transfer-Encoding"); h != "chunked" { 240 t.Errorf("ForwardResponseStream missing header chunked") 241 } 242 body, err := io.ReadAll(w.Body) 243 if err != nil { 244 t.Errorf("Failed to read response body with %v", err) 245 } 246 w.Body.Close() 247 if len(body) > 0 && w.Header.Get("Content-Type") != "Custom-Content-Type" { 248 t.Errorf("Content-Type %s want Custom-Content-Type", w.Header.Get("Content-Type")) 249 } 250 251 var want []byte 252 for _, msg := range tt.msgs { 253 if msg.err != nil { 254 t.Skip("checking erorr encodings") 255 } 256 b, err := marshaler.Marshal(map[string]proto.Message{"result": msg.pb}) 257 if err != nil { 258 t.Errorf("marshaler.Marshal() failed %v", err) 259 } 260 want = append(want, b...) 261 want = append(want, "\n"...) 262 } 263 264 if string(body) != string(want) { 265 t.Errorf("ForwardResponseStream() = \"%s\" want \"%s\"", body, want) 266 } 267 }) 268 } 269 } 270 271 func TestForwardResponseMessage(t *testing.T) { 272 msg := &pb.SimpleMessage{Id: "One"} 273 tests := []struct { 274 name string 275 marshaler runtime.Marshaler 276 contentType string 277 }{{ 278 name: "standard marshaler", 279 marshaler: &runtime.JSONPb{}, 280 contentType: "application/json", 281 }, { 282 name: "httpbody marshaler", 283 marshaler: &runtime.HTTPBodyMarshaler{&runtime.JSONPb{}}, 284 contentType: "application/json", 285 }, { 286 name: "custom marshaler", 287 marshaler: &CustomMarshaler{&runtime.JSONPb{}}, 288 contentType: "Custom-Content-Type", 289 }} 290 291 ctx := runtime.NewServerMetadataContext(context.Background(), runtime.ServerMetadata{}) 292 for _, tt := range tests { 293 t.Run(tt.name, func(t *testing.T) { 294 req := httptest.NewRequest("GET", "http://example.com/foo", nil) 295 resp := httptest.NewRecorder() 296 297 runtime.ForwardResponseMessage(ctx, runtime.NewServeMux(), tt.marshaler, resp, req, msg) 298 299 w := resp.Result() 300 if w.StatusCode != http.StatusOK { 301 t.Errorf("StatusCode %d want %d", w.StatusCode, http.StatusOK) 302 } 303 if h := w.Header.Get("Content-Type"); h != tt.contentType { 304 t.Errorf("Content-Type %v want %v", h, tt.contentType) 305 } 306 body, err := io.ReadAll(w.Body) 307 if err != nil { 308 t.Errorf("Failed to read response body with %v", err) 309 } 310 w.Body.Close() 311 312 want, err := tt.marshaler.Marshal(msg) 313 if err != nil { 314 t.Errorf("marshaler.Marshal() failed %v", err) 315 } 316 317 if string(body) != string(want) { 318 t.Errorf("ForwardResponseMessage() = \"%s\" want \"%s\"", body, want) 319 } 320 }) 321 } 322 } 323 324 func TestOutgoingHeaderMatcher(t *testing.T) { 325 t.Parallel() 326 msg := &pb.SimpleMessage{Id: "foo"} 327 for _, tc := range []struct { 328 name string 329 md runtime.ServerMetadata 330 headers http.Header 331 matcher runtime.HeaderMatcherFunc 332 }{ 333 { 334 name: "default matcher", 335 md: runtime.ServerMetadata{ 336 HeaderMD: metadata.Pairs( 337 "foo", "bar", 338 "baz", "qux", 339 ), 340 }, 341 headers: http.Header{ 342 "Content-Type": []string{"application/json"}, 343 "Grpc-Metadata-Foo": []string{"bar"}, 344 "Grpc-Metadata-Baz": []string{"qux"}, 345 }, 346 }, 347 { 348 name: "custom matcher", 349 md: runtime.ServerMetadata{ 350 HeaderMD: metadata.Pairs( 351 "foo", "bar", 352 "baz", "qux", 353 ), 354 }, 355 headers: http.Header{ 356 "Content-Type": []string{"application/json"}, 357 "Custom-Foo": []string{"bar"}, 358 }, 359 matcher: func(key string) (string, bool) { 360 switch key { 361 case "foo": 362 return "custom-foo", true 363 default: 364 return "", false 365 } 366 }, 367 }, 368 } { 369 tc := tc 370 t.Run(tc.name, func(t *testing.T) { 371 t.Parallel() 372 ctx := runtime.NewServerMetadataContext(context.Background(), tc.md) 373 374 req := httptest.NewRequest("GET", "http://example.com/foo", nil) 375 resp := httptest.NewRecorder() 376 377 runtime.ForwardResponseMessage(ctx, runtime.NewServeMux(runtime.WithOutgoingHeaderMatcher(tc.matcher)), &runtime.JSONPb{}, resp, req, msg) 378 379 w := resp.Result() 380 defer w.Body.Close() 381 if w.StatusCode != http.StatusOK { 382 t.Fatalf("StatusCode %d want %d", w.StatusCode, http.StatusOK) 383 } 384 385 if !reflect.DeepEqual(w.Header, tc.headers) { 386 t.Fatalf("Header %v want %v", w.Header, tc.headers) 387 } 388 }) 389 } 390 } 391 392 func TestOutgoingTrailerMatcher(t *testing.T) { 393 t.Parallel() 394 msg := &pb.SimpleMessage{Id: "foo"} 395 for _, tc := range []struct { 396 name string 397 md runtime.ServerMetadata 398 caller http.Header 399 headers http.Header 400 trailer http.Header 401 matcher runtime.HeaderMatcherFunc 402 }{ 403 { 404 name: "default matcher, caller accepts", 405 md: runtime.ServerMetadata{ 406 TrailerMD: metadata.Pairs( 407 "foo", "bar", 408 "baz", "qux", 409 ), 410 }, 411 caller: http.Header{ 412 "Te": []string{"trailers"}, 413 }, 414 headers: http.Header{ 415 "Content-Type": []string{"application/json"}, 416 "Trailer": []string{"Grpc-Trailer-Foo,Grpc-Trailer-Baz"}, 417 }, 418 trailer: http.Header{ 419 "Grpc-Trailer-Foo": []string{"bar"}, 420 "Grpc-Trailer-Baz": []string{"qux"}, 421 }, 422 }, 423 { 424 name: "default matcher, caller rejects", 425 md: runtime.ServerMetadata{ 426 TrailerMD: metadata.Pairs( 427 "foo", "bar", 428 "baz", "qux", 429 ), 430 }, 431 headers: http.Header{ 432 "Content-Type": []string{"application/json"}, 433 }, 434 }, 435 { 436 name: "custom matcher", 437 md: runtime.ServerMetadata{ 438 TrailerMD: metadata.Pairs( 439 "foo", "bar", 440 "baz", "qux", 441 ), 442 }, 443 caller: http.Header{ 444 "Te": []string{"trailers"}, 445 }, 446 headers: http.Header{ 447 "Content-Type": []string{"application/json"}, 448 "Trailer": []string{"Custom-Trailer-Foo"}, 449 }, 450 trailer: http.Header{ 451 "Custom-Trailer-Foo": []string{"bar"}, 452 }, 453 matcher: func(key string) (string, bool) { 454 switch key { 455 case "foo": 456 return "custom-trailer-foo", true 457 default: 458 return "", false 459 } 460 }, 461 }, 462 } { 463 tc := tc 464 t.Run(tc.name, func(t *testing.T) { 465 t.Parallel() 466 ctx := runtime.NewServerMetadataContext(context.Background(), tc.md) 467 468 req := httptest.NewRequest("GET", "http://example.com/foo", nil) 469 req.Header = tc.caller 470 resp := httptest.NewRecorder() 471 472 runtime.ForwardResponseMessage(ctx, runtime.NewServeMux(runtime.WithOutgoingTrailerMatcher(tc.matcher)), &runtime.JSONPb{}, resp, req, msg) 473 474 w := resp.Result() 475 _, _ = io.Copy(io.Discard, w.Body) 476 defer w.Body.Close() 477 if w.StatusCode != http.StatusOK { 478 t.Fatalf("StatusCode %d want %d", w.StatusCode, http.StatusOK) 479 } 480 481 if !reflect.DeepEqual(w.Trailer, tc.trailer) { 482 t.Fatalf("Trailer %v want %v", w.Trailer, tc.trailer) 483 } 484 }) 485 } 486 }