github.com/blend/go-sdk@v1.20220411.3/reverseproxy/proxy_test.go (about)

     1  /*
     2  
     3  Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file.
     5  
     6  */
     7  
     8  package reverseproxy
     9  
    10  import (
    11  	"bufio"
    12  	"context"
    13  	"fmt"
    14  	"io"
    15  	"net/http"
    16  	"net/http/httptest"
    17  	"net/url"
    18  	"testing"
    19  
    20  	"github.com/blend/go-sdk/assert"
    21  	"github.com/blend/go-sdk/logger"
    22  	"github.com/blend/go-sdk/webutil"
    23  )
    24  
    25  var (
    26  	_ webutil.HTTPTracer        = (*mockHTTPTracer)(nil)
    27  	_ webutil.HTTPTraceFinisher = (*mockHTTPTraceFinisher)(nil)
    28  )
    29  
    30  func Test_Proxy(t *testing.T) {
    31  	its := assert.New(t)
    32  
    33  	mockedEndpoint := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    34  		if protoHeader := r.Header.Get(webutil.HeaderXForwardedProto); protoHeader == "" {
    35  			http.Error(w, "No `X-Forwarded-Proto` header!", http.StatusBadRequest)
    36  			return
    37  		}
    38  		w.WriteHeader(http.StatusOK)
    39  		fmt.Fprint(w, "Ok!")
    40  	}))
    41  	defer mockedEndpoint.Close()
    42  
    43  	target, err := url.Parse(mockedEndpoint.URL)
    44  	its.Nil(err)
    45  
    46  	proxy, err := NewProxy(
    47  		OptProxyUpstream(NewUpstream(target)),
    48  		OptProxySetHeaderValue(webutil.HeaderXForwardedProto, webutil.SchemeHTTP),
    49  	)
    50  	its.Nil(err)
    51  
    52  	mockedProxy := httptest.NewServer(proxy)
    53  	defer mockedProxy.Close()
    54  
    55  	res, err := http.Get(mockedProxy.URL)
    56  	its.Nil(err)
    57  	defer res.Body.Close()
    58  
    59  	its.Empty(res.Header.Get("x-forwarded-proto"))
    60  	its.Empty(res.Header.Get("x-forwarded-port"))
    61  
    62  	fullBody, err := io.ReadAll(res.Body)
    63  	its.Nil(err)
    64  
    65  	mockedContents := string(fullBody)
    66  	its.Equal(http.StatusOK, res.StatusCode)
    67  	its.Equal("Ok!", mockedContents)
    68  }
    69  
    70  func Test_Proxy_Tracer(t *testing.T) {
    71  	its := assert.New(t)
    72  
    73  	mockedEndpoint := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    74  		if protoHeader := r.Header.Get(webutil.HeaderXForwardedProto); protoHeader == "" {
    75  			http.Error(w, "No `X-Forwarded-Proto` header!", http.StatusBadRequest)
    76  			return
    77  		}
    78  		w.WriteHeader(http.StatusOK)
    79  		fmt.Fprint(w, "Ok!")
    80  	}))
    81  	defer mockedEndpoint.Close()
    82  
    83  	target, err := url.Parse(mockedEndpoint.URL)
    84  	its.Nil(err)
    85  
    86  	tracer := &mockHTTPTracer{}
    87  	proxy, err := NewProxy(
    88  		OptProxyUpstream(NewUpstream(target)),
    89  		OptProxySetHeaderValue(webutil.HeaderXForwardedProto, webutil.SchemeHTTP),
    90  		OptProxyTracer(tracer),
    91  	)
    92  	its.Nil(err)
    93  
    94  	mockedProxy := httptest.NewServer(proxy)
    95  	defer mockedProxy.Close()
    96  
    97  	res, err := http.Get(mockedProxy.URL)
    98  	its.Nil(err)
    99  	defer res.Body.Close()
   100  
   101  	its.Equal(http.StatusOK, res.StatusCode)
   102  
   103  	req := tracer.Request
   104  	its.NotNil(req)
   105  	its.Equal("GET", req.Method)
   106  	its.Equal("/", req.URL.String())
   107  	its.Equal(mockedProxy.URL, "http://"+req.Host)
   108  
   109  	its.Equal(http.StatusOK, tracer.StatusCode)
   110  	its.Nil(tracer.Error)
   111  }
   112  
   113  // Referencing https://golang.org/src/net/http/httputil/reverseproxy_test.go
   114  func TestReverseProxyWebSocket(t *testing.T) {
   115  	assert := assert.New(t)
   116  
   117  	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   118  		assert.Equal(UpgradeType(r.Header), "websocket")
   119  
   120  		c, _, err := w.(http.Hijacker).Hijack()
   121  		if err != nil {
   122  			t.Error(err)
   123  			return
   124  		}
   125  		defer c.Close()
   126  		fmt.Fprint(c, "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n")
   127  		bs := bufio.NewScanner(c)
   128  		if !bs.Scan() {
   129  			t.Errorf("backend failed to read line from client: %v", bs.Err())
   130  			return
   131  		}
   132  		fmt.Fprintf(c, "backend got %q\n", bs.Text())
   133  	}))
   134  	defer backendServer.Close()
   135  
   136  	backendURL := MustParseURL(backendServer.URL)
   137  	proxy, err := NewProxy(
   138  		OptProxyUpstream(NewUpstream(backendURL)),
   139  		OptProxySetHeaderValue(webutil.HeaderXForwardedProto, webutil.SchemeHTTP),
   140  	)
   141  	assert.Nil(err)
   142  
   143  	frontendProxy := httptest.NewServer(proxy)
   144  	defer frontendProxy.Close()
   145  
   146  	req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
   147  	req.Header.Set("Connection", "Upgrade")
   148  	req.Header.Set("Upgrade", "websocket")
   149  
   150  	c := frontendProxy.Client()
   151  	res, err := c.Do(req)
   152  	assert.Nil(err)
   153  
   154  	assert.Equal(res.StatusCode, 101)
   155  
   156  	assert.Equal(UpgradeType(req.Header), "websocket")
   157  
   158  	rwc, ok := res.Body.(io.ReadWriteCloser)
   159  	assert.True(ok)
   160  	defer rwc.Close()
   161  
   162  	fmt.Fprint(rwc, "Hello\n")
   163  	bs := bufio.NewScanner(rwc)
   164  	assert.True(bs.Scan())
   165  
   166  	got := bs.Text()
   167  	want := `backend got "Hello"`
   168  	assert.Equal(got, want)
   169  }
   170  
   171  type mockHTTPTracer struct {
   172  	Request    *http.Request
   173  	StatusCode int
   174  	Error      error
   175  }
   176  
   177  func (mht *mockHTTPTracer) Start(req *http.Request) (webutil.HTTPTraceFinisher, *http.Request) {
   178  	mht.Request = req
   179  	return &mockHTTPTraceFinisher{mht}, req
   180  }
   181  
   182  type mockHTTPTraceFinisher struct {
   183  	Tracer *mockHTTPTracer
   184  }
   185  
   186  func (mhtf *mockHTTPTraceFinisher) Finish(statusCode int, err error) {
   187  	mhtf.Tracer.StatusCode = statusCode
   188  	mhtf.Tracer.Error = err
   189  }
   190  
   191  func TestProxy_Panic(t *testing.T) {
   192  	its := assert.New(t)
   193  
   194  	mockedEndpoint := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   195  		if protoHeader := r.Header.Get(webutil.HeaderXForwardedProto); protoHeader == "" {
   196  			http.Error(w, "No `X-Forwarded-Proto` header!", http.StatusBadRequest)
   197  			return
   198  		}
   199  		w.WriteHeader(http.StatusOK)
   200  		fmt.Fprint(w, "Ok!")
   201  	}))
   202  	defer mockedEndpoint.Close()
   203  
   204  	target, err := url.Parse(mockedEndpoint.URL)
   205  	its.Nil(err)
   206  
   207  	log := logger.Memory(io.Discard)
   208  	defer log.Close()
   209  
   210  	errors := make(chan error)
   211  	log.Listen(logger.Fatal, "panic-chan", logger.NewErrorEventListener(func(ctx context.Context, e logger.ErrorEvent) {
   212  		errors <- e.Err
   213  	}))
   214  
   215  	proxy, err := NewProxy(
   216  		OptProxyUpstream(NewUpstream(
   217  			target,
   218  		)),
   219  		OptProxyLog(log),
   220  		OptProxyResolver(func(_ *http.Request, _ []*Upstream) (*Upstream, error) {
   221  			panic("this is just a test")
   222  		}),
   223  		OptProxySetHeaderValue(webutil.HeaderXForwardedProto, webutil.SchemeHTTP),
   224  	)
   225  	its.Nil(err)
   226  
   227  	mockedProxy := httptest.NewServer(proxy)
   228  
   229  	res, err := http.Get(mockedProxy.URL)
   230  	its.Nil(err)
   231  	defer res.Body.Close()
   232  	its.Equal(http.StatusOK, res.StatusCode)
   233  	err = <-errors
   234  	its.NotNil(err)
   235  	its.Equal("this is just a test", err.Error())
   236  }
   237  
   238  func TestProxy_Panic_httpAbortHandler(t *testing.T) {
   239  	its := assert.New(t)
   240  
   241  	var didCallEndpoint bool
   242  	mockedEndpoint := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   243  		defer func() { didCallEndpoint = true }()
   244  		if protoHeader := r.Header.Get(webutil.HeaderXForwardedProto); protoHeader == "" {
   245  			http.Error(w, "No `X-Forwarded-Proto` header!", http.StatusBadRequest)
   246  			return
   247  		}
   248  		w.WriteHeader(http.StatusOK)
   249  		fmt.Fprint(w, "Ok!")
   250  	}))
   251  	defer mockedEndpoint.Close()
   252  
   253  	target, err := url.Parse(mockedEndpoint.URL)
   254  	its.Nil(err)
   255  
   256  	log := logger.Memory(io.Discard)
   257  	defer log.Close()
   258  
   259  	errors := make(chan error, 1)
   260  	log.Listen(logger.Fatal, "panic-chan", logger.NewErrorEventListener(func(ctx context.Context, e logger.ErrorEvent) {
   261  		errors <- e.Err
   262  	}))
   263  
   264  	proxy, err := NewProxy(
   265  		OptProxyUpstream(NewUpstream(
   266  			target,
   267  		)),
   268  		OptProxyLog(log),
   269  		OptProxyResolver(func(_ *http.Request, _ []*Upstream) (*Upstream, error) {
   270  			panic(http.ErrAbortHandler)
   271  		}),
   272  		OptProxySetHeaderValue(webutil.HeaderXForwardedProto, webutil.SchemeHTTP),
   273  	)
   274  	its.Nil(err)
   275  
   276  	mockedProxy := httptest.NewServer(proxy)
   277  
   278  	res, err := http.Get(mockedProxy.URL)
   279  	its.Nil(err)
   280  	defer res.Body.Close()
   281  	its.Equal(http.StatusOK, res.StatusCode)
   282  
   283  	// explicitly drain so we process any errors that would come up
   284  	log.Drain()
   285  
   286  	its.Empty(errors)
   287  	its.False(didCallEndpoint)
   288  }