github.com/avenga/couper@v1.12.2/handler/proxy_test.go (about)

     1  package handler_test
     2  
     3  import (
     4  	"context"
     5  	"net/http"
     6  	"net/http/httptest"
     7  	"testing"
     8  	"time"
     9  
    10  	"github.com/hashicorp/hcl/v2/hclsyntax"
    11  
    12  	"github.com/avenga/couper/config"
    13  	"github.com/avenga/couper/config/body"
    14  	"github.com/avenga/couper/config/request"
    15  	"github.com/avenga/couper/eval"
    16  	"github.com/avenga/couper/handler"
    17  	"github.com/avenga/couper/handler/transport"
    18  	"github.com/avenga/couper/internal/test"
    19  )
    20  
    21  func TestProxy_BlacklistHeaderRemoval(t *testing.T) {
    22  	log, _ := test.NewLogger()
    23  	logEntry := log.WithContext(context.Background())
    24  	p := handler.NewProxy(
    25  		transport.NewBackend(body.NewHCLSyntaxBodyWithStringAttr("origin", "https://1.2.3.4"), &transport.Config{
    26  			Origin: "https://1.2.3.4/",
    27  		}, nil, logEntry),
    28  		&hclsyntax.Body{},
    29  		false,
    30  		logEntry,
    31  	)
    32  
    33  	outreq := httptest.NewRequest("GET", "https://1.2.3.4/", nil)
    34  	outreq.Header.Set("Authorization", "Basic 123")
    35  	outreq.Header.Set("Cookie", "123")
    36  	outreq = outreq.WithContext(eval.NewContext(nil, &config.Defaults{}, "").WithClientRequest(outreq))
    37  	ctx, cancel := context.WithDeadline(context.WithValue(context.Background(), request.RoundTripProxy, true), time.Now().Add(time.Millisecond*50))
    38  	outreq = outreq.WithContext(ctx)
    39  	defer cancel()
    40  
    41  	_, _ = p.RoundTrip(outreq)
    42  
    43  	if outreq.Header.Get("Authorization") != "" {
    44  		t.Error("Expected removed Authorization header")
    45  	}
    46  
    47  	if outreq.Header.Get("Cookie") != "" {
    48  		t.Error("Expected removed Cookie header")
    49  	}
    50  }
    51  
    52  func TestProxy_WebsocketsAllowed(t *testing.T) {
    53  	log, _ := test.NewLogger()
    54  	logEntry := log.WithContext(context.Background())
    55  
    56  	origin := test.NewBackend()
    57  
    58  	pNotAllowed := handler.NewProxy(
    59  		transport.NewBackend(body.NewHCLSyntaxBodyWithStringAttr("origin", origin.Addr()), &transport.Config{
    60  			Origin: origin.Addr(),
    61  		}, nil, logEntry),
    62  		&hclsyntax.Body{},
    63  		false,
    64  		logEntry,
    65  	)
    66  
    67  	pAllowed := handler.NewProxy(
    68  		transport.NewBackend(body.NewHCLSyntaxBodyWithStringAttr("origin", origin.Addr()), &transport.Config{
    69  			Origin: origin.Addr(),
    70  		}, nil, logEntry),
    71  		&hclsyntax.Body{},
    72  		true,
    73  		logEntry,
    74  	)
    75  
    76  	headers := http.Header{
    77  		"Connection": []string{"upgrade"},
    78  		"Upgrade":    []string{"websocket"},
    79  	}
    80  
    81  	outreqN := httptest.NewRequest("GET", "http://couper.local/ws", nil)
    82  	outreqA := httptest.NewRequest("GET", "http://couper.local/ws", nil)
    83  
    84  	outCtx := context.WithValue(context.Background(), request.RoundTripProxy, true)
    85  
    86  	for _, r := range []*http.Request{outreqN, outreqA} {
    87  		for h := range headers {
    88  			r.Header.Set(h, headers.Get(h))
    89  		}
    90  	}
    91  
    92  	resN, _ := pNotAllowed.RoundTrip(outreqN.WithContext(outCtx))
    93  	resA, _ := pAllowed.RoundTrip(outreqA.WithContext(outCtx))
    94  
    95  	if resN.StatusCode != http.StatusBadRequest {
    96  		t.Errorf("expected a bad request on ws endpoint without related headers, got: %d", resN.StatusCode)
    97  	}
    98  
    99  	if resA.StatusCode != http.StatusSwitchingProtocols {
   100  		t.Errorf("expcted passed Connection and Upgrade header which results in 101, got: %d", resA.StatusCode)
   101  	}
   102  }