github.com/icyphox/x@v0.0.355-0.20220311094250-029bd783e8b8/proxy/proxy_full_test.go (about) 1 package proxy 2 3 import ( 4 "bytes" 5 "context" 6 "encoding/json" 7 "fmt" 8 "io" 9 "io/ioutil" 10 "net/http" 11 "net/http/httptest" 12 "net/http/httputil" 13 "net/url" 14 "testing" 15 "time" 16 17 "github.com/gorilla/websocket" 18 19 "github.com/pkg/errors" 20 21 "github.com/ory/x/httpx" 22 23 "github.com/stretchr/testify/assert" 24 "github.com/stretchr/testify/require" 25 26 "github.com/ory/x/urlx" 27 ) 28 29 // This test is a full integration test for the proxy. 30 // It does not have to cover **all** edge cases included in the rewrite 31 // unit test, but should use all features like path prefix, ... 32 33 const statusTestFailure = 555 34 35 type ( 36 remoteT struct { 37 w http.ResponseWriter 38 r *http.Request 39 t *testing.T 40 failed bool 41 } 42 testingRoundTripper struct { 43 t *testing.T 44 rt http.RoundTripper 45 } 46 ) 47 48 func (t *remoteT) Errorf(format string, args ...interface{}) { 49 t.failed = true 50 t.w.WriteHeader(statusTestFailure) 51 t.t.Errorf(format, args...) 52 } 53 54 func (t *remoteT) Header() http.Header { 55 return t.w.Header() 56 } 57 58 func (t *remoteT) Write(i []byte) (int, error) { 59 if t.failed { 60 return 0, nil 61 } 62 return t.w.Write(i) 63 } 64 65 func (t *remoteT) WriteHeader(statusCode int) { 66 if t.failed { 67 return 68 } 69 t.w.WriteHeader(statusCode) 70 } 71 72 func (rt *testingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { 73 resp, err := rt.rt.RoundTrip(req) 74 require.NoError(rt.t, err) 75 76 if resp.StatusCode == statusTestFailure { 77 rt.t.Error("got test failure from the server, see output above") 78 rt.t.FailNow() 79 } 80 81 return resp, err 82 } 83 84 func TestFullIntegration(t *testing.T) { 85 upstream, upstreamHandler := httpx.NewChanHandler(0) 86 upstreamServer := httptest.NewTLSServer(upstream) 87 defer upstreamServer.Close() 88 89 // create the proxy 90 hostMapper := make(chan func(*http.Request) (*HostConfig, error)) 91 reqMiddleware := make(chan ReqMiddleware) 92 respMiddleware := make(chan RespMiddleware) 93 94 type CustomErrorReq func(*http.Request, error) 95 type CustomErrorResp func(*http.Response, error) error 96 97 onErrorReq := make(chan CustomErrorReq) 98 onErrorResp := make(chan CustomErrorResp) 99 100 proxy := httptest.NewTLSServer(New( 101 func(_ context.Context, r *http.Request) (*HostConfig, error) { 102 return (<-hostMapper)(r) 103 }, 104 WithTransport(upstreamServer.Client().Transport), 105 WithReqMiddleware(func(req *http.Request, config *HostConfig, body []byte) ([]byte, error) { 106 f := <-reqMiddleware 107 if f == nil { 108 return body, nil 109 } 110 return f(req, config, body) 111 }), 112 WithRespMiddleware(func(resp *http.Response, config *HostConfig, body []byte) ([]byte, error) { 113 f := <-respMiddleware 114 if f == nil { 115 return body, nil 116 } 117 return f(resp, config, body) 118 }), 119 WithOnError(func(request *http.Request, err error) { 120 f := <-onErrorReq 121 if f == nil { 122 return 123 } 124 f(request, err) 125 }, func(response *http.Response, err error) error { 126 f := <-onErrorResp 127 if f == nil { 128 return nil 129 } 130 return f(response, err) 131 }))) 132 133 cl := proxy.Client() 134 cl.Transport = &testingRoundTripper{t, cl.Transport} 135 cl.CheckRedirect = func(*http.Request, []*http.Request) error { 136 return http.ErrUseLastResponse 137 } 138 139 for _, tc := range []struct { 140 desc string 141 hostMapper func(host string) (*HostConfig, error) 142 handler func(assert *assert.Assertions, w http.ResponseWriter, r *http.Request) 143 request func(t *testing.T) *http.Request 144 assertResponse func(t *testing.T, r *http.Response) 145 reqMiddleware ReqMiddleware 146 respMiddleware RespMiddleware 147 onErrReq CustomErrorReq 148 onErrResp CustomErrorResp 149 }{ 150 { 151 desc: "body replacement", 152 hostMapper: func(host string) (*HostConfig, error) { 153 if host != "example.com" { 154 return nil, fmt.Errorf("got unexpected host %s, expected 'example.com'", host) 155 } 156 return &HostConfig{ 157 CookieDomain: "example.com", 158 PathPrefix: "/foo", 159 }, nil 160 }, 161 handler: func(assert *assert.Assertions, w http.ResponseWriter, r *http.Request) { 162 body, err := io.ReadAll(r.Body) 163 assert.NoError(err) 164 assert.Equal(fmt.Sprintf("some random content containing the request URL and path prefix %s/bar but also other stuff", upstreamServer.URL), string(body)) 165 166 _, err = w.Write([]byte(fmt.Sprintf("just responding with my own URL: %s/baz and some path of course", upstreamServer.URL))) 167 assert.NoError(err) 168 }, 169 request: func(t *testing.T) *http.Request { 170 req, err := http.NewRequest(http.MethodPost, proxy.URL+"/foo", bytes.NewBufferString(fmt.Sprintf("some random content containing the request URL and path prefix %s/bar but also other stuff", upstreamServer.URL))) 171 require.NoError(t, err) 172 req.Host = "example.com" 173 return req 174 }, 175 assertResponse: func(t *testing.T, resp *http.Response) { 176 assert.Equal(t, http.StatusOK, resp.StatusCode) 177 178 body, err := io.ReadAll(resp.Body) 179 require.NoError(t, err) 180 assert.Equal(t, "just responding with my own URL: https://example.com/foo/baz and some path of course", string(body)) 181 }, 182 }, 183 { 184 desc: "redirection replacement", 185 hostMapper: func(host string) (*HostConfig, error) { 186 if host != "redirect.me" { 187 return nil, fmt.Errorf("got unexpected host %s, expected 'redirect.me'", host) 188 } 189 return &HostConfig{ 190 CookieDomain: "redirect.me", 191 }, nil 192 }, 193 handler: func(_ *assert.Assertions, w http.ResponseWriter, r *http.Request) { 194 http.Redirect(w, r, upstreamServer.URL+"/redirection/target", http.StatusSeeOther) 195 }, 196 request: func(t *testing.T) *http.Request { 197 req, err := http.NewRequest(http.MethodGet, proxy.URL, nil) 198 require.NoError(t, err) 199 req.Host = "redirect.me" 200 return req 201 }, 202 assertResponse: func(t *testing.T, r *http.Response) { 203 assert.Equal(t, http.StatusSeeOther, r.StatusCode) 204 assert.Equal(t, "https://redirect.me/redirection/target", r.Header.Get("Location")) 205 }, 206 }, 207 { 208 desc: "cookie replacement", 209 hostMapper: func(host string) (*HostConfig, error) { 210 if host != "auth.cookie.love" { 211 return nil, fmt.Errorf("got unexpected host %s, expected 'cookie.love'", host) 212 } 213 return &HostConfig{ 214 CookieDomain: "cookie.love", 215 }, nil 216 }, 217 handler: func(assert *assert.Assertions, w http.ResponseWriter, r *http.Request) { 218 http.SetCookie(w, &http.Cookie{ 219 Name: "auth", 220 Value: "my random cookie", 221 Domain: urlx.ParseOrPanic(upstreamServer.URL).Hostname(), 222 }) 223 _, err := w.Write([]byte("OK")) 224 assert.NoError(err) 225 }, 226 request: func(t *testing.T) *http.Request { 227 req, err := http.NewRequest(http.MethodGet, proxy.URL, nil) 228 require.NoError(t, err) 229 req.Host = "auth.cookie.love" 230 return req 231 }, 232 assertResponse: func(t *testing.T, r *http.Response) { 233 cookies := r.Cookies() 234 require.Len(t, cookies, 1) 235 c := cookies[0] 236 assert.Equal(t, "auth", c.Name) 237 assert.Equal(t, "my random cookie", c.Value) 238 assert.Equal(t, "cookie.love", c.Domain) 239 }, 240 }, 241 { 242 desc: "custom middleware", 243 hostMapper: func(host string) (*HostConfig, error) { 244 return &HostConfig{}, nil 245 }, 246 handler: func(assert *assert.Assertions, w http.ResponseWriter, r *http.Request) { 247 assert.Equal("noauth.example.com", r.Host) 248 b, err := ioutil.ReadAll(r.Body) 249 assert.NoError(err) 250 assert.Equal("this is a new body", string(b)) 251 252 _, err = w.Write([]byte("OK")) 253 assert.NoError(err) 254 }, 255 request: func(t *testing.T) *http.Request { 256 req, err := http.NewRequest(http.MethodPost, proxy.URL, bytes.NewReader([]byte("body"))) 257 require.NoError(t, err) 258 req.Host = "auth.example.com" 259 return req 260 }, 261 assertResponse: func(t *testing.T, r *http.Response) { 262 body, err := io.ReadAll(r.Body) 263 require.NoError(t, err) 264 assert.Equal(t, "OK", string(body)) 265 assert.Equal(t, "1234", r.Header.Get("Some-Header")) 266 }, 267 reqMiddleware: func(req *http.Request, config *HostConfig, body []byte) ([]byte, error) { 268 req.Host = "noauth.example.com" 269 body = []byte("this is a new body") 270 return body, nil 271 }, 272 respMiddleware: func(resp *http.Response, config *HostConfig, body []byte) ([]byte, error) { 273 resp.Header.Add("Some-Header", "1234") 274 return body, nil 275 }, 276 }, 277 { 278 desc: "custom request errors", 279 hostMapper: func(host string) (*HostConfig, error) { 280 return &HostConfig{}, errors.New("some host mapper error occurred") 281 }, 282 handler: func(assert *assert.Assertions, w http.ResponseWriter, r *http.Request) { 283 _, err := w.Write([]byte("OK")) 284 assert.NoError(err) 285 }, 286 request: func(t *testing.T) *http.Request { 287 req, err := http.NewRequest(http.MethodPost, proxy.URL, bytes.NewReader([]byte("body"))) 288 require.NoError(t, err) 289 req.Host = "auth.example.com" 290 return req 291 }, 292 assertResponse: func(t *testing.T, r *http.Response) { 293 return 294 }, 295 onErrReq: func(request *http.Request, err error) { 296 assert.Error(t, err) 297 assert.Equal(t, "some host mapper error occurred", err.Error()) 298 }, 299 }, 300 { 301 desc: "custom response errors", 302 hostMapper: func(host string) (*HostConfig, error) { 303 return &HostConfig{}, nil 304 }, 305 handler: func(assert *assert.Assertions, w http.ResponseWriter, r *http.Request) { 306 _, err := w.Write([]byte("OK")) 307 assert.NoError(err) 308 }, 309 request: func(t *testing.T) *http.Request { 310 req, err := http.NewRequest(http.MethodPost, proxy.URL, bytes.NewReader([]byte("body"))) 311 require.NoError(t, err) 312 req.Host = "auth.example.com" 313 return req 314 }, 315 assertResponse: func(t *testing.T, r *http.Response) { 316 return 317 }, 318 respMiddleware: func(resp *http.Response, config *HostConfig, body []byte) ([]byte, error) { 319 return nil, errors.New("some response middleware error") 320 }, 321 onErrResp: func(response *http.Response, err error) error { 322 assert.Error(t, err) 323 assert.Equal(t, "some response middleware error", err.Error()) 324 return err 325 }, 326 }, 327 } { 328 t.Run("case="+tc.desc, func(t *testing.T) { 329 go func() { 330 hostMapper <- func(r *http.Request) (*HostConfig, error) { 331 host := r.Host 332 hc, err := tc.hostMapper(host) 333 if err == nil { 334 hc.UpstreamHost = urlx.ParseOrPanic(upstreamServer.URL).Host 335 hc.UpstreamScheme = urlx.ParseOrPanic(upstreamServer.URL).Scheme 336 hc.TargetHost = hc.UpstreamHost 337 hc.TargetScheme = hc.UpstreamScheme 338 } 339 return hc, err 340 } 341 reqMiddleware <- tc.reqMiddleware 342 upstreamHandler <- func(w http.ResponseWriter, r *http.Request) { 343 t := &remoteT{t: t, w: w, r: r} 344 tc.handler(assert.New(t), t, r) 345 } 346 respMiddleware <- tc.respMiddleware 347 }() 348 349 go func() { 350 onErrorReq <- tc.onErrReq 351 }() 352 353 go func() { 354 onErrorResp <- tc.onErrResp 355 }() 356 357 resp, err := cl.Do(tc.request(t)) 358 require.NoError(t, err) 359 tc.assertResponse(t, resp) 360 }) 361 } 362 } 363 364 func TestBetweenReverseProxies(t *testing.T) { 365 // the target thinks it is running under the targetHost, while actually it is behind all three proxies 366 targetHost := "foobar.ory.sh" 367 targetHandler, c := httpx.NewChanHandler(1) 368 target := httptest.NewServer(targetHandler) 369 370 revProxyHandler := httputil.NewSingleHostReverseProxy(urlx.ParseOrPanic(target.URL)) 371 revProxy := httptest.NewServer(revProxyHandler) 372 373 thisProxy := httptest.NewServer(New(func(ctx context.Context, _ *http.Request) (*HostConfig, error) { 374 return &HostConfig{ 375 CookieDomain: "sh", 376 UpstreamHost: urlx.ParseOrPanic(revProxy.URL).Host, 377 UpstreamScheme: urlx.ParseOrPanic(revProxy.URL).Scheme, 378 TargetScheme: "http", 379 TargetHost: targetHost, 380 }, nil 381 })) 382 383 ingressHandler := httputil.NewSingleHostReverseProxy(urlx.ParseOrPanic(thisProxy.URL)) 384 ingress := httptest.NewServer(ingressHandler) 385 386 // In this scenario we want to force the use of the X-Forwarded-Host header instead of the Host header. 387 singleHostDirector := ingressHandler.Director 388 ingressHandler.Director = func(req *http.Request) { 389 singleHostDirector(req) 390 req.Header.Set("X-Forwarded-Host", req.Host) 391 req.Host = urlx.ParseOrPanic(ingress.URL).Host 392 } 393 394 t.Run("case=replaces body", func(t *testing.T) { 395 const pattern = "Hello, I am available under http://%s!" 396 c <- func(w http.ResponseWriter, r *http.Request) { 397 fmt.Fprintf(w, pattern, targetHost) 398 } 399 400 host := "example.com" 401 req, err := http.NewRequest(http.MethodGet, ingress.URL, nil) 402 require.NoError(t, err) 403 req.Host = host 404 405 resp, err := http.DefaultClient.Do(req) 406 require.NoError(t, err) 407 body, err := io.ReadAll(resp.Body) 408 require.NoError(t, err) 409 assert.Equal(t, fmt.Sprintf(pattern, host), string(body)) 410 }) 411 412 t.Run("case=replaces cookies", func(t *testing.T) { 413 c <- func(w http.ResponseWriter, r *http.Request) { 414 http.SetCookie(w, &http.Cookie{ 415 Name: "foo", 416 Value: "setting this cookie for my own domain", 417 Domain: targetHost, 418 Secure: true, 419 }) 420 } 421 422 req, err := http.NewRequest(http.MethodGet, ingress.URL, nil) 423 require.NoError(t, err) 424 req.Host = "example.com" 425 426 resp, err := http.DefaultClient.Do(req) 427 require.NoError(t, err) 428 429 cookies := resp.Cookies() 430 require.Len(t, cookies, 1) 431 assert.Equal(t, "foo", cookies[0].Name) 432 assert.Equal(t, "setting this cookie for my own domain", cookies[0].Value) 433 assert.Equal(t, "sh", cookies[0].Domain) 434 assert.Equal(t, false, cookies[0].Secure) 435 }) 436 437 t.Run("case=replaces location", func(t *testing.T) { 438 c <- func(w http.ResponseWriter, r *http.Request) { 439 http.Redirect(w, r, "http://"+targetHost, http.StatusSeeOther) 440 } 441 442 host := "example.com" 443 req, err := http.NewRequest(http.MethodGet, ingress.URL, nil) 444 require.NoError(t, err) 445 req.Host = host 446 447 resp, err := (&http.Client{ 448 CheckRedirect: func(req *http.Request, via []*http.Request) error { 449 return http.ErrUseLastResponse 450 }, 451 }).Do(req) 452 require.NoError(t, err) 453 454 assert.Equal(t, http.StatusSeeOther, resp.StatusCode) 455 assert.Equal(t, "http://"+host, resp.Header.Get("Location")) 456 }) 457 } 458 459 func TestProxyProtoMix(t *testing.T) { 460 const exposedHost = "foo.bar" 461 462 setup := func(t *testing.T, targetServerFunc, upstreamServerFunc func(http.Handler) *httptest.Server) (chan<- http.HandlerFunc, string, string, *http.Client) { 463 targetHandler, targetHandlerC := httpx.NewChanHandler(1) 464 targetServer := targetServerFunc(targetHandler) 465 466 upstream := httputil.NewSingleHostReverseProxy(urlx.ParseOrPanic(targetServer.URL)) 467 upstream.Transport = targetServer.Client().Transport 468 upstreamServer := upstreamServerFunc(upstream) 469 470 proxy := httptest.NewServer(New(func(ctx context.Context, r *http.Request) (*HostConfig, error) { 471 return &HostConfig{ 472 CookieDomain: exposedHost, 473 UpstreamHost: urlx.ParseOrPanic(upstreamServer.URL).Host, 474 UpstreamScheme: urlx.ParseOrPanic(upstreamServer.URL).Scheme, 475 TargetHost: urlx.ParseOrPanic(targetServer.URL).Host, 476 TargetScheme: urlx.ParseOrPanic(targetServer.URL).Scheme, 477 }, nil 478 }, WithTransport(upstreamServer.Client().Transport))) 479 client := proxy.Client() 480 client.CheckRedirect = func(*http.Request, []*http.Request) error { 481 return http.ErrUseLastResponse 482 } 483 484 return targetHandlerC, targetServer.URL, proxy.URL, client 485 } 486 487 for _, tc := range []struct { 488 name string 489 newUpstreamServer, newTargetServer func(http.Handler) *httptest.Server 490 }{ 491 { 492 name: "upstream http, target https", 493 newUpstreamServer: httptest.NewServer, 494 newTargetServer: httptest.NewTLSServer, 495 }, 496 { 497 name: "upstream https, target http", 498 newUpstreamServer: httptest.NewTLSServer, 499 newTargetServer: httptest.NewServer, 500 }, 501 } { 502 t.Run("case="+tc.name, func(t *testing.T) { 503 handler, targetURL, proxyURL, client := setup(t, httptest.NewTLSServer, httptest.NewServer) 504 505 t.Run("case=redirect", func(t *testing.T) { 506 handler <- func(w http.ResponseWriter, r *http.Request) { 507 http.Redirect(w, r, targetURL+"/see-other", http.StatusSeeOther) 508 } 509 510 req, err := http.NewRequest(http.MethodGet, proxyURL, nil) 511 require.NoError(t, err) 512 req.Host = exposedHost 513 514 resp, err := client.Do(req) 515 require.NoError(t, err) 516 assert.Equal(t, "http://"+exposedHost+"/see-other", resp.Header.Get("Location")) 517 }) 518 519 t.Run("case=body rewrite", func(t *testing.T) { 520 const template = "Hello, I am %s, who are you?" 521 522 handler <- func(w http.ResponseWriter, r *http.Request) { 523 _, _ = w.Write([]byte(fmt.Sprintf(template, targetURL))) 524 } 525 526 req, err := http.NewRequest(http.MethodGet, proxyURL, nil) 527 require.NoError(t, err) 528 req.Host = exposedHost 529 530 resp, err := client.Do(req) 531 require.NoError(t, err) 532 body, err := io.ReadAll(resp.Body) 533 require.NoError(t, err) 534 assert.Equal(t, fmt.Sprintf(template, "http://"+exposedHost), string(body)) 535 }) 536 537 t.Run("case=secure cookies", func(t *testing.T) { 538 handler <- func(w http.ResponseWriter, r *http.Request) { 539 cookie := &http.Cookie{ 540 Name: "foo", 541 Value: "bar", 542 Domain: stripPort(urlx.ParseOrPanic(targetURL).Host), 543 Secure: true, 544 } 545 http.SetCookie(w, cookie) 546 _, _ = w.Write([]byte("please eat this cookie")) 547 } 548 549 req, err := http.NewRequest(http.MethodGet, proxyURL, nil) 550 require.NoError(t, err) 551 req.Host = exposedHost 552 553 resp, err := client.Do(req) 554 require.NoError(t, err) 555 556 cookies := resp.Cookies() 557 require.Len(t, cookies, 1) 558 assert.Equal(t, "foo", cookies[0].Name) 559 assert.Equal(t, "bar", cookies[0].Value) 560 assert.Equal(t, exposedHost, cookies[0].Domain) 561 assert.Equal(t, false, cookies[0].Secure) 562 }) 563 }) 564 } 565 } 566 567 func TestProxyWebsocketRequests(t *testing.T) { 568 // create an echo server that uses websockets to communicate 569 setupWebsocketServer := func(ctx context.Context) *httptest.Server { 570 upgrader := websocket.Upgrader{} 571 mux := http.NewServeMux() 572 mux.Handle("/echo", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 573 c, err := upgrader.Upgrade(w, r, nil) 574 require.NoError(t, err) 575 defer c.Close() 576 for { 577 select { 578 case <-ctx.Done(): 579 return 580 default: 581 mt, message, err := c.ReadMessage() 582 if err != nil { 583 return 584 } 585 require.NotEmpty(t, message) 586 err = c.WriteMessage(mt, message) 587 require.NoError(t, err) 588 } 589 } 590 })) 591 return httptest.NewServer(mux) 592 } 593 594 setupProxy := func(targetServer *httptest.Server) *httptest.Server { 595 proxy := httptest.NewServer(New(func(ctx context.Context, r *http.Request) (*HostConfig, error) { 596 return &HostConfig{ 597 UpstreamHost: urlx.ParseOrPanic(targetServer.URL).Host, 598 UpstreamScheme: urlx.ParseOrPanic(targetServer.URL).Scheme, 599 TargetHost: urlx.ParseOrPanic(targetServer.URL).Host, 600 TargetScheme: urlx.ParseOrPanic(targetServer.URL).Scheme, 601 }, nil 602 })) 603 604 return proxy 605 } 606 607 t.Logf("Creating websocket server with proxy with context timeout of 5 seconds") 608 ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) 609 610 t.Cleanup(cancel) 611 612 websocketServer := setupWebsocketServer(ctx) 613 defer websocketServer.Close() 614 615 proxyServer := setupProxy(websocketServer) 616 defer proxyServer.Close() 617 618 u := url.URL{Scheme: "ws", Host: urlx.ParseOrPanic(proxyServer.URL).Host, Path: "/echo"} 619 620 c, _, err := websocket.DefaultDialer.Dial(u.String(), nil) 621 require.NoError(t, err) 622 defer c.Close() 623 624 messages := make(chan []byte, 2) 625 626 // setup message reader 627 go func(ctx context.Context) { 628 for { 629 select { 630 case <-ctx.Done(): 631 return 632 default: 633 _, message, err := c.ReadMessage() 634 if err != nil { 635 return 636 } 637 messages <- message 638 t.Logf("Received message from websocket client: %s\n", message) 639 } 640 } 641 }(ctx) 642 643 // write a message 644 testMessage := "test" 645 testJson := json.RawMessage(`{"data":"1234"}`) 646 t.Logf("Writing message to websocket server: %s\n", testMessage) 647 require.NoError(t, c.WriteMessage(websocket.TextMessage, []byte(testMessage))) 648 t.Logf("Writing message to websocket server: %s\n", testJson) 649 require.NoError(t, c.WriteJSON(testJson)) 650 651 readChannel := func() []byte { 652 select { 653 case msg := <-messages: 654 return msg 655 case <-ctx.Done(): 656 return []byte("") 657 } 658 } 659 660 require.Equalf(t, testMessage, string(readChannel()), "could not retrieve the test message from the websocket server") 661 require.JSONEqf(t, string(testJson), string(readChannel()), "could not retrieve the test json from the websocket server") 662 }