github.com/icyphox/x@v0.0.355-0.20220311094250-029bd783e8b8/proxy/proxy_full_test.go (about)

     1  package proxy
     2  
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"io/ioutil"
	"net/http"
	"net/http/httptest"
	"net/http/httputil"
	"net/url"
	"sync"
	"testing"
	"time"

	"github.com/gorilla/websocket"
	"github.com/pkg/errors"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/ory/x/httpx"
	"github.com/ory/x/urlx"
)
    28  
    29  // This test is a full integration test for the proxy.
    30  // It does not have to cover **all** edge cases included in the rewrite
    31  // unit test, but should use all features like path prefix, ...
    32  
    33  const statusTestFailure = 555
    34  
    35  type (
    36  	remoteT struct {
    37  		w      http.ResponseWriter
    38  		r      *http.Request
    39  		t      *testing.T
    40  		failed bool
    41  	}
    42  	testingRoundTripper struct {
    43  		t  *testing.T
    44  		rt http.RoundTripper
    45  	}
    46  )
    47  
    48  func (t *remoteT) Errorf(format string, args ...interface{}) {
    49  	t.failed = true
    50  	t.w.WriteHeader(statusTestFailure)
    51  	t.t.Errorf(format, args...)
    52  }
    53  
    54  func (t *remoteT) Header() http.Header {
    55  	return t.w.Header()
    56  }
    57  
    58  func (t *remoteT) Write(i []byte) (int, error) {
    59  	if t.failed {
    60  		return 0, nil
    61  	}
    62  	return t.w.Write(i)
    63  }
    64  
    65  func (t *remoteT) WriteHeader(statusCode int) {
    66  	if t.failed {
    67  		return
    68  	}
    69  	t.w.WriteHeader(statusCode)
    70  }
    71  
    72  func (rt *testingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
    73  	resp, err := rt.rt.RoundTrip(req)
    74  	require.NoError(rt.t, err)
    75  
    76  	if resp.StatusCode == statusTestFailure {
    77  		rt.t.Error("got test failure from the server, see output above")
    78  		rt.t.FailNow()
    79  	}
    80  
    81  	return resp, err
    82  }
    83  
    84  func TestFullIntegration(t *testing.T) {
    85  	upstream, upstreamHandler := httpx.NewChanHandler(0)
    86  	upstreamServer := httptest.NewTLSServer(upstream)
    87  	defer upstreamServer.Close()
    88  
    89  	// create the proxy
    90  	hostMapper := make(chan func(*http.Request) (*HostConfig, error))
    91  	reqMiddleware := make(chan ReqMiddleware)
    92  	respMiddleware := make(chan RespMiddleware)
    93  
    94  	type CustomErrorReq func(*http.Request, error)
    95  	type CustomErrorResp func(*http.Response, error) error
    96  
    97  	onErrorReq := make(chan CustomErrorReq)
    98  	onErrorResp := make(chan CustomErrorResp)
    99  
   100  	proxy := httptest.NewTLSServer(New(
   101  		func(_ context.Context, r *http.Request) (*HostConfig, error) {
   102  			return (<-hostMapper)(r)
   103  		},
   104  		WithTransport(upstreamServer.Client().Transport),
   105  		WithReqMiddleware(func(req *http.Request, config *HostConfig, body []byte) ([]byte, error) {
   106  			f := <-reqMiddleware
   107  			if f == nil {
   108  				return body, nil
   109  			}
   110  			return f(req, config, body)
   111  		}),
   112  		WithRespMiddleware(func(resp *http.Response, config *HostConfig, body []byte) ([]byte, error) {
   113  			f := <-respMiddleware
   114  			if f == nil {
   115  				return body, nil
   116  			}
   117  			return f(resp, config, body)
   118  		}),
   119  		WithOnError(func(request *http.Request, err error) {
   120  			f := <-onErrorReq
   121  			if f == nil {
   122  				return
   123  			}
   124  			f(request, err)
   125  		}, func(response *http.Response, err error) error {
   126  			f := <-onErrorResp
   127  			if f == nil {
   128  				return nil
   129  			}
   130  			return f(response, err)
   131  		})))
   132  
   133  	cl := proxy.Client()
   134  	cl.Transport = &testingRoundTripper{t, cl.Transport}
   135  	cl.CheckRedirect = func(*http.Request, []*http.Request) error {
   136  		return http.ErrUseLastResponse
   137  	}
   138  
   139  	for _, tc := range []struct {
   140  		desc           string
   141  		hostMapper     func(host string) (*HostConfig, error)
   142  		handler        func(assert *assert.Assertions, w http.ResponseWriter, r *http.Request)
   143  		request        func(t *testing.T) *http.Request
   144  		assertResponse func(t *testing.T, r *http.Response)
   145  		reqMiddleware  ReqMiddleware
   146  		respMiddleware RespMiddleware
   147  		onErrReq       CustomErrorReq
   148  		onErrResp      CustomErrorResp
   149  	}{
   150  		{
   151  			desc: "body replacement",
   152  			hostMapper: func(host string) (*HostConfig, error) {
   153  				if host != "example.com" {
   154  					return nil, fmt.Errorf("got unexpected host %s, expected 'example.com'", host)
   155  				}
   156  				return &HostConfig{
   157  					CookieDomain: "example.com",
   158  					PathPrefix:   "/foo",
   159  				}, nil
   160  			},
   161  			handler: func(assert *assert.Assertions, w http.ResponseWriter, r *http.Request) {
   162  				body, err := io.ReadAll(r.Body)
   163  				assert.NoError(err)
   164  				assert.Equal(fmt.Sprintf("some random content containing the request URL and path prefix %s/bar but also other stuff", upstreamServer.URL), string(body))
   165  
   166  				_, err = w.Write([]byte(fmt.Sprintf("just responding with my own URL: %s/baz and some path of course", upstreamServer.URL)))
   167  				assert.NoError(err)
   168  			},
   169  			request: func(t *testing.T) *http.Request {
   170  				req, err := http.NewRequest(http.MethodPost, proxy.URL+"/foo", bytes.NewBufferString(fmt.Sprintf("some random content containing the request URL and path prefix %s/bar but also other stuff", upstreamServer.URL)))
   171  				require.NoError(t, err)
   172  				req.Host = "example.com"
   173  				return req
   174  			},
   175  			assertResponse: func(t *testing.T, resp *http.Response) {
   176  				assert.Equal(t, http.StatusOK, resp.StatusCode)
   177  
   178  				body, err := io.ReadAll(resp.Body)
   179  				require.NoError(t, err)
   180  				assert.Equal(t, "just responding with my own URL: https://example.com/foo/baz and some path of course", string(body))
   181  			},
   182  		},
   183  		{
   184  			desc: "redirection replacement",
   185  			hostMapper: func(host string) (*HostConfig, error) {
   186  				if host != "redirect.me" {
   187  					return nil, fmt.Errorf("got unexpected host %s, expected 'redirect.me'", host)
   188  				}
   189  				return &HostConfig{
   190  					CookieDomain: "redirect.me",
   191  				}, nil
   192  			},
   193  			handler: func(_ *assert.Assertions, w http.ResponseWriter, r *http.Request) {
   194  				http.Redirect(w, r, upstreamServer.URL+"/redirection/target", http.StatusSeeOther)
   195  			},
   196  			request: func(t *testing.T) *http.Request {
   197  				req, err := http.NewRequest(http.MethodGet, proxy.URL, nil)
   198  				require.NoError(t, err)
   199  				req.Host = "redirect.me"
   200  				return req
   201  			},
   202  			assertResponse: func(t *testing.T, r *http.Response) {
   203  				assert.Equal(t, http.StatusSeeOther, r.StatusCode)
   204  				assert.Equal(t, "https://redirect.me/redirection/target", r.Header.Get("Location"))
   205  			},
   206  		},
   207  		{
   208  			desc: "cookie replacement",
   209  			hostMapper: func(host string) (*HostConfig, error) {
   210  				if host != "auth.cookie.love" {
   211  					return nil, fmt.Errorf("got unexpected host %s, expected 'cookie.love'", host)
   212  				}
   213  				return &HostConfig{
   214  					CookieDomain: "cookie.love",
   215  				}, nil
   216  			},
   217  			handler: func(assert *assert.Assertions, w http.ResponseWriter, r *http.Request) {
   218  				http.SetCookie(w, &http.Cookie{
   219  					Name:   "auth",
   220  					Value:  "my random cookie",
   221  					Domain: urlx.ParseOrPanic(upstreamServer.URL).Hostname(),
   222  				})
   223  				_, err := w.Write([]byte("OK"))
   224  				assert.NoError(err)
   225  			},
   226  			request: func(t *testing.T) *http.Request {
   227  				req, err := http.NewRequest(http.MethodGet, proxy.URL, nil)
   228  				require.NoError(t, err)
   229  				req.Host = "auth.cookie.love"
   230  				return req
   231  			},
   232  			assertResponse: func(t *testing.T, r *http.Response) {
   233  				cookies := r.Cookies()
   234  				require.Len(t, cookies, 1)
   235  				c := cookies[0]
   236  				assert.Equal(t, "auth", c.Name)
   237  				assert.Equal(t, "my random cookie", c.Value)
   238  				assert.Equal(t, "cookie.love", c.Domain)
   239  			},
   240  		},
   241  		{
   242  			desc: "custom middleware",
   243  			hostMapper: func(host string) (*HostConfig, error) {
   244  				return &HostConfig{}, nil
   245  			},
   246  			handler: func(assert *assert.Assertions, w http.ResponseWriter, r *http.Request) {
   247  				assert.Equal("noauth.example.com", r.Host)
   248  				b, err := ioutil.ReadAll(r.Body)
   249  				assert.NoError(err)
   250  				assert.Equal("this is a new body", string(b))
   251  
   252  				_, err = w.Write([]byte("OK"))
   253  				assert.NoError(err)
   254  			},
   255  			request: func(t *testing.T) *http.Request {
   256  				req, err := http.NewRequest(http.MethodPost, proxy.URL, bytes.NewReader([]byte("body")))
   257  				require.NoError(t, err)
   258  				req.Host = "auth.example.com"
   259  				return req
   260  			},
   261  			assertResponse: func(t *testing.T, r *http.Response) {
   262  				body, err := io.ReadAll(r.Body)
   263  				require.NoError(t, err)
   264  				assert.Equal(t, "OK", string(body))
   265  				assert.Equal(t, "1234", r.Header.Get("Some-Header"))
   266  			},
   267  			reqMiddleware: func(req *http.Request, config *HostConfig, body []byte) ([]byte, error) {
   268  				req.Host = "noauth.example.com"
   269  				body = []byte("this is a new body")
   270  				return body, nil
   271  			},
   272  			respMiddleware: func(resp *http.Response, config *HostConfig, body []byte) ([]byte, error) {
   273  				resp.Header.Add("Some-Header", "1234")
   274  				return body, nil
   275  			},
   276  		},
   277  		{
   278  			desc: "custom request errors",
   279  			hostMapper: func(host string) (*HostConfig, error) {
   280  				return &HostConfig{}, errors.New("some host mapper error occurred")
   281  			},
   282  			handler: func(assert *assert.Assertions, w http.ResponseWriter, r *http.Request) {
   283  				_, err := w.Write([]byte("OK"))
   284  				assert.NoError(err)
   285  			},
   286  			request: func(t *testing.T) *http.Request {
   287  				req, err := http.NewRequest(http.MethodPost, proxy.URL, bytes.NewReader([]byte("body")))
   288  				require.NoError(t, err)
   289  				req.Host = "auth.example.com"
   290  				return req
   291  			},
   292  			assertResponse: func(t *testing.T, r *http.Response) {
   293  				return
   294  			},
   295  			onErrReq: func(request *http.Request, err error) {
   296  				assert.Error(t, err)
   297  				assert.Equal(t, "some host mapper error occurred", err.Error())
   298  			},
   299  		},
   300  		{
   301  			desc: "custom response errors",
   302  			hostMapper: func(host string) (*HostConfig, error) {
   303  				return &HostConfig{}, nil
   304  			},
   305  			handler: func(assert *assert.Assertions, w http.ResponseWriter, r *http.Request) {
   306  				_, err := w.Write([]byte("OK"))
   307  				assert.NoError(err)
   308  			},
   309  			request: func(t *testing.T) *http.Request {
   310  				req, err := http.NewRequest(http.MethodPost, proxy.URL, bytes.NewReader([]byte("body")))
   311  				require.NoError(t, err)
   312  				req.Host = "auth.example.com"
   313  				return req
   314  			},
   315  			assertResponse: func(t *testing.T, r *http.Response) {
   316  				return
   317  			},
   318  			respMiddleware: func(resp *http.Response, config *HostConfig, body []byte) ([]byte, error) {
   319  				return nil, errors.New("some response middleware error")
   320  			},
   321  			onErrResp: func(response *http.Response, err error) error {
   322  				assert.Error(t, err)
   323  				assert.Equal(t, "some response middleware error", err.Error())
   324  				return err
   325  			},
   326  		},
   327  	} {
   328  		t.Run("case="+tc.desc, func(t *testing.T) {
   329  			go func() {
   330  				hostMapper <- func(r *http.Request) (*HostConfig, error) {
   331  					host := r.Host
   332  					hc, err := tc.hostMapper(host)
   333  					if err == nil {
   334  						hc.UpstreamHost = urlx.ParseOrPanic(upstreamServer.URL).Host
   335  						hc.UpstreamScheme = urlx.ParseOrPanic(upstreamServer.URL).Scheme
   336  						hc.TargetHost = hc.UpstreamHost
   337  						hc.TargetScheme = hc.UpstreamScheme
   338  					}
   339  					return hc, err
   340  				}
   341  				reqMiddleware <- tc.reqMiddleware
   342  				upstreamHandler <- func(w http.ResponseWriter, r *http.Request) {
   343  					t := &remoteT{t: t, w: w, r: r}
   344  					tc.handler(assert.New(t), t, r)
   345  				}
   346  				respMiddleware <- tc.respMiddleware
   347  			}()
   348  
   349  			go func() {
   350  				onErrorReq <- tc.onErrReq
   351  			}()
   352  
   353  			go func() {
   354  				onErrorResp <- tc.onErrResp
   355  			}()
   356  
   357  			resp, err := cl.Do(tc.request(t))
   358  			require.NoError(t, err)
   359  			tc.assertResponse(t, resp)
   360  		})
   361  	}
   362  }
   363  
   364  func TestBetweenReverseProxies(t *testing.T) {
   365  	// the target thinks it is running under the targetHost, while actually it is behind all three proxies
   366  	targetHost := "foobar.ory.sh"
   367  	targetHandler, c := httpx.NewChanHandler(1)
   368  	target := httptest.NewServer(targetHandler)
   369  
   370  	revProxyHandler := httputil.NewSingleHostReverseProxy(urlx.ParseOrPanic(target.URL))
   371  	revProxy := httptest.NewServer(revProxyHandler)
   372  
   373  	thisProxy := httptest.NewServer(New(func(ctx context.Context, _ *http.Request) (*HostConfig, error) {
   374  		return &HostConfig{
   375  			CookieDomain:   "sh",
   376  			UpstreamHost:   urlx.ParseOrPanic(revProxy.URL).Host,
   377  			UpstreamScheme: urlx.ParseOrPanic(revProxy.URL).Scheme,
   378  			TargetScheme:   "http",
   379  			TargetHost:     targetHost,
   380  		}, nil
   381  	}))
   382  
   383  	ingressHandler := httputil.NewSingleHostReverseProxy(urlx.ParseOrPanic(thisProxy.URL))
   384  	ingress := httptest.NewServer(ingressHandler)
   385  
   386  	// In this scenario we want to force the use of the X-Forwarded-Host header instead of the Host header.
   387  	singleHostDirector := ingressHandler.Director
   388  	ingressHandler.Director = func(req *http.Request) {
   389  		singleHostDirector(req)
   390  		req.Header.Set("X-Forwarded-Host", req.Host)
   391  		req.Host = urlx.ParseOrPanic(ingress.URL).Host
   392  	}
   393  
   394  	t.Run("case=replaces body", func(t *testing.T) {
   395  		const pattern = "Hello, I am available under http://%s!"
   396  		c <- func(w http.ResponseWriter, r *http.Request) {
   397  			fmt.Fprintf(w, pattern, targetHost)
   398  		}
   399  
   400  		host := "example.com"
   401  		req, err := http.NewRequest(http.MethodGet, ingress.URL, nil)
   402  		require.NoError(t, err)
   403  		req.Host = host
   404  
   405  		resp, err := http.DefaultClient.Do(req)
   406  		require.NoError(t, err)
   407  		body, err := io.ReadAll(resp.Body)
   408  		require.NoError(t, err)
   409  		assert.Equal(t, fmt.Sprintf(pattern, host), string(body))
   410  	})
   411  
   412  	t.Run("case=replaces cookies", func(t *testing.T) {
   413  		c <- func(w http.ResponseWriter, r *http.Request) {
   414  			http.SetCookie(w, &http.Cookie{
   415  				Name:   "foo",
   416  				Value:  "setting this cookie for my own domain",
   417  				Domain: targetHost,
   418  				Secure: true,
   419  			})
   420  		}
   421  
   422  		req, err := http.NewRequest(http.MethodGet, ingress.URL, nil)
   423  		require.NoError(t, err)
   424  		req.Host = "example.com"
   425  
   426  		resp, err := http.DefaultClient.Do(req)
   427  		require.NoError(t, err)
   428  
   429  		cookies := resp.Cookies()
   430  		require.Len(t, cookies, 1)
   431  		assert.Equal(t, "foo", cookies[0].Name)
   432  		assert.Equal(t, "setting this cookie for my own domain", cookies[0].Value)
   433  		assert.Equal(t, "sh", cookies[0].Domain)
   434  		assert.Equal(t, false, cookies[0].Secure)
   435  	})
   436  
   437  	t.Run("case=replaces location", func(t *testing.T) {
   438  		c <- func(w http.ResponseWriter, r *http.Request) {
   439  			http.Redirect(w, r, "http://"+targetHost, http.StatusSeeOther)
   440  		}
   441  
   442  		host := "example.com"
   443  		req, err := http.NewRequest(http.MethodGet, ingress.URL, nil)
   444  		require.NoError(t, err)
   445  		req.Host = host
   446  
   447  		resp, err := (&http.Client{
   448  			CheckRedirect: func(req *http.Request, via []*http.Request) error {
   449  				return http.ErrUseLastResponse
   450  			},
   451  		}).Do(req)
   452  		require.NoError(t, err)
   453  
   454  		assert.Equal(t, http.StatusSeeOther, resp.StatusCode)
   455  		assert.Equal(t, "http://"+host, resp.Header.Get("Location"))
   456  	})
   457  }
   458  
   459  func TestProxyProtoMix(t *testing.T) {
   460  	const exposedHost = "foo.bar"
   461  
   462  	setup := func(t *testing.T, targetServerFunc, upstreamServerFunc func(http.Handler) *httptest.Server) (chan<- http.HandlerFunc, string, string, *http.Client) {
   463  		targetHandler, targetHandlerC := httpx.NewChanHandler(1)
   464  		targetServer := targetServerFunc(targetHandler)
   465  
   466  		upstream := httputil.NewSingleHostReverseProxy(urlx.ParseOrPanic(targetServer.URL))
   467  		upstream.Transport = targetServer.Client().Transport
   468  		upstreamServer := upstreamServerFunc(upstream)
   469  
   470  		proxy := httptest.NewServer(New(func(ctx context.Context, r *http.Request) (*HostConfig, error) {
   471  			return &HostConfig{
   472  				CookieDomain:   exposedHost,
   473  				UpstreamHost:   urlx.ParseOrPanic(upstreamServer.URL).Host,
   474  				UpstreamScheme: urlx.ParseOrPanic(upstreamServer.URL).Scheme,
   475  				TargetHost:     urlx.ParseOrPanic(targetServer.URL).Host,
   476  				TargetScheme:   urlx.ParseOrPanic(targetServer.URL).Scheme,
   477  			}, nil
   478  		}, WithTransport(upstreamServer.Client().Transport)))
   479  		client := proxy.Client()
   480  		client.CheckRedirect = func(*http.Request, []*http.Request) error {
   481  			return http.ErrUseLastResponse
   482  		}
   483  
   484  		return targetHandlerC, targetServer.URL, proxy.URL, client
   485  	}
   486  
   487  	for _, tc := range []struct {
   488  		name                               string
   489  		newUpstreamServer, newTargetServer func(http.Handler) *httptest.Server
   490  	}{
   491  		{
   492  			name:              "upstream http, target https",
   493  			newUpstreamServer: httptest.NewServer,
   494  			newTargetServer:   httptest.NewTLSServer,
   495  		},
   496  		{
   497  			name:              "upstream https, target http",
   498  			newUpstreamServer: httptest.NewTLSServer,
   499  			newTargetServer:   httptest.NewServer,
   500  		},
   501  	} {
   502  		t.Run("case="+tc.name, func(t *testing.T) {
   503  			handler, targetURL, proxyURL, client := setup(t, httptest.NewTLSServer, httptest.NewServer)
   504  
   505  			t.Run("case=redirect", func(t *testing.T) {
   506  				handler <- func(w http.ResponseWriter, r *http.Request) {
   507  					http.Redirect(w, r, targetURL+"/see-other", http.StatusSeeOther)
   508  				}
   509  
   510  				req, err := http.NewRequest(http.MethodGet, proxyURL, nil)
   511  				require.NoError(t, err)
   512  				req.Host = exposedHost
   513  
   514  				resp, err := client.Do(req)
   515  				require.NoError(t, err)
   516  				assert.Equal(t, "http://"+exposedHost+"/see-other", resp.Header.Get("Location"))
   517  			})
   518  
   519  			t.Run("case=body rewrite", func(t *testing.T) {
   520  				const template = "Hello, I am %s, who are you?"
   521  
   522  				handler <- func(w http.ResponseWriter, r *http.Request) {
   523  					_, _ = w.Write([]byte(fmt.Sprintf(template, targetURL)))
   524  				}
   525  
   526  				req, err := http.NewRequest(http.MethodGet, proxyURL, nil)
   527  				require.NoError(t, err)
   528  				req.Host = exposedHost
   529  
   530  				resp, err := client.Do(req)
   531  				require.NoError(t, err)
   532  				body, err := io.ReadAll(resp.Body)
   533  				require.NoError(t, err)
   534  				assert.Equal(t, fmt.Sprintf(template, "http://"+exposedHost), string(body))
   535  			})
   536  
   537  			t.Run("case=secure cookies", func(t *testing.T) {
   538  				handler <- func(w http.ResponseWriter, r *http.Request) {
   539  					cookie := &http.Cookie{
   540  						Name:   "foo",
   541  						Value:  "bar",
   542  						Domain: stripPort(urlx.ParseOrPanic(targetURL).Host),
   543  						Secure: true,
   544  					}
   545  					http.SetCookie(w, cookie)
   546  					_, _ = w.Write([]byte("please eat this cookie"))
   547  				}
   548  
   549  				req, err := http.NewRequest(http.MethodGet, proxyURL, nil)
   550  				require.NoError(t, err)
   551  				req.Host = exposedHost
   552  
   553  				resp, err := client.Do(req)
   554  				require.NoError(t, err)
   555  
   556  				cookies := resp.Cookies()
   557  				require.Len(t, cookies, 1)
   558  				assert.Equal(t, "foo", cookies[0].Name)
   559  				assert.Equal(t, "bar", cookies[0].Value)
   560  				assert.Equal(t, exposedHost, cookies[0].Domain)
   561  				assert.Equal(t, false, cookies[0].Secure)
   562  			})
   563  		})
   564  	}
   565  }
   566  
   567  func TestProxyWebsocketRequests(t *testing.T) {
   568  	// create an echo server that uses websockets to communicate
   569  	setupWebsocketServer := func(ctx context.Context) *httptest.Server {
   570  		upgrader := websocket.Upgrader{}
   571  		mux := http.NewServeMux()
   572  		mux.Handle("/echo", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   573  			c, err := upgrader.Upgrade(w, r, nil)
   574  			require.NoError(t, err)
   575  			defer c.Close()
   576  			for {
   577  				select {
   578  				case <-ctx.Done():
   579  					return
   580  				default:
   581  					mt, message, err := c.ReadMessage()
   582  					if err != nil {
   583  						return
   584  					}
   585  					require.NotEmpty(t, message)
   586  					err = c.WriteMessage(mt, message)
   587  					require.NoError(t, err)
   588  				}
   589  			}
   590  		}))
   591  		return httptest.NewServer(mux)
   592  	}
   593  
   594  	setupProxy := func(targetServer *httptest.Server) *httptest.Server {
   595  		proxy := httptest.NewServer(New(func(ctx context.Context, r *http.Request) (*HostConfig, error) {
   596  			return &HostConfig{
   597  				UpstreamHost:   urlx.ParseOrPanic(targetServer.URL).Host,
   598  				UpstreamScheme: urlx.ParseOrPanic(targetServer.URL).Scheme,
   599  				TargetHost:     urlx.ParseOrPanic(targetServer.URL).Host,
   600  				TargetScheme:   urlx.ParseOrPanic(targetServer.URL).Scheme,
   601  			}, nil
   602  		}))
   603  
   604  		return proxy
   605  	}
   606  
   607  	t.Logf("Creating websocket server with proxy with context timeout of 5 seconds")
   608  	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
   609  
   610  	t.Cleanup(cancel)
   611  
   612  	websocketServer := setupWebsocketServer(ctx)
   613  	defer websocketServer.Close()
   614  
   615  	proxyServer := setupProxy(websocketServer)
   616  	defer proxyServer.Close()
   617  
   618  	u := url.URL{Scheme: "ws", Host: urlx.ParseOrPanic(proxyServer.URL).Host, Path: "/echo"}
   619  
   620  	c, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
   621  	require.NoError(t, err)
   622  	defer c.Close()
   623  
   624  	messages := make(chan []byte, 2)
   625  
   626  	// setup message reader
   627  	go func(ctx context.Context) {
   628  		for {
   629  			select {
   630  			case <-ctx.Done():
   631  				return
   632  			default:
   633  				_, message, err := c.ReadMessage()
   634  				if err != nil {
   635  					return
   636  				}
   637  				messages <- message
   638  				t.Logf("Received message from websocket client: %s\n", message)
   639  			}
   640  		}
   641  	}(ctx)
   642  
   643  	// write a message
   644  	testMessage := "test"
   645  	testJson := json.RawMessage(`{"data":"1234"}`)
   646  	t.Logf("Writing message to websocket server: %s\n", testMessage)
   647  	require.NoError(t, c.WriteMessage(websocket.TextMessage, []byte(testMessage)))
   648  	t.Logf("Writing message to websocket server: %s\n", testJson)
   649  	require.NoError(t, c.WriteJSON(testJson))
   650  
   651  	readChannel := func() []byte {
   652  		select {
   653  		case msg := <-messages:
   654  			return msg
   655  		case <-ctx.Done():
   656  			return []byte("")
   657  		}
   658  	}
   659  
   660  	require.Equalf(t, testMessage, string(readChannel()), "could not retrieve the test message from the websocket server")
   661  	require.JSONEqf(t, string(testJson), string(readChannel()), "could not retrieve the test json from the websocket server")
   662  }