github.com/avenga/couper@v1.12.2/handler/proxy.go (about)

     1  package handler
     2  
import (
	"context"
	"fmt"
	"io"
	"mime"
	"net/http"
	"net/http/httputil"
	"strings"
	"sync"
	"time"

	"github.com/hashicorp/hcl/v2"
	"github.com/hashicorp/hcl/v2/hclsyntax"
	"github.com/sirupsen/logrus"

	hclbody "github.com/avenga/couper/config/body"
	"github.com/avenga/couper/config/request"
	"github.com/avenga/couper/errors"
	"github.com/avenga/couper/eval"
	"github.com/avenga/couper/handler/ascii"
	"github.com/avenga/couper/handler/transport"
	"github.com/avenga/couper/internal/seetie"
	"github.com/avenga/couper/server/writer"
)
    26  
    27  // headerBlacklist lists all header keys which will be removed after
    28  // context variable evaluation to ensure to not pass them upstream.
    29  var headerBlacklist = []string{"Authorization", "Cookie"}
    30  
    31  // Proxy wraps a httputil.ReverseProxy to apply additional configuration context
    32  // and have control over the roundtrip configuration.
    33  type Proxy struct {
    34  	allowWS bool
    35  	backend http.RoundTripper
    36  	context *hclsyntax.Body
    37  	logger  *logrus.Entry
    38  }
    39  
    40  func NewProxy(backend http.RoundTripper, ctx *hclsyntax.Body, allowWS bool, logger *logrus.Entry) *Proxy {
    41  	proxy := &Proxy{
    42  		allowWS: allowWS,
    43  		backend: backend,
    44  		context: ctx,
    45  		logger:  logger,
    46  	}
    47  
    48  	return proxy
    49  }
    50  
// RoundTrip prepares req according to the proxy configuration, sends it via
// p.backend and returns the backend response with the configuration applied.
//
// req is modified in place: blacklisted and hop-by-hop headers are deleted,
// the proxy body (and, for upgrade requests, the websockets block) is applied,
// and its context is replaced by one that carries the evaluated
// "expected_status" value (and websockets settings, if enabled).
//
// A 101 Switching Protocols response is passed to handleUpgradeResponse, which
// blocks while the upgraded connection is proxied; the response is returned
// together with that function's error. For all other responses the
// connection/hop-by-hop headers are removed and the proxy body is applied to
// the response.
//
// NOTE(review): the error returns before the `defer req.Body.Close()` below do
// not close req.Body; the http.RoundTripper contract asks implementations to
// always close it — confirm whether callers rely on that.
func (p *Proxy) RoundTrip(req *http.Request) (*http.Response, error) {
	// 1. Apply proxy blacklist
	for _, key := range headerBlacklist {
		req.Header.Del(key)
	}

	hclCtx := eval.ContextFromRequest(req).HCLContextSync()

	// 2. Apply proxy-body
	err := eval.ApplyRequestContext(hclCtx, p.context, req)
	if err != nil {
		return nil, err
	}

	// 3. Apply websockets-body
	outCtx, err := p.applyWebsocketsRequest(hclCtx, req)
	if err != nil {
		return nil, err
	}

	// 4. Evaluate the "expected_status" attribute and store it in the request context.
	expStatusVal, err := eval.ValueFromBodyAttribute(hclCtx, p.context, "expected_status")
	if err != nil {
		return nil, err
	}

	outCtx = context.WithValue(outCtx, request.EndpointExpectedStatus, seetie.ValueToIntSlice(expStatusVal))

	// Overwrite the caller's request value so it carries the new context too.
	*req = *req.WithContext(outCtx)

	if err = p.registerWebsocketsResponse(req); err != nil {
		return nil, err
	}

	// The core reverse-proxy part.
	if req.ContentLength == 0 {
		req.Body = nil // Issue 16036: nil Body for http.Transport retries
	}
	if req.Body != nil {
		defer req.Body.Close()
	}
	req.Close = false

	// Capture the requested upgrade protocol before the Connection/Upgrade
	// headers are stripped below.
	reqUpType := upgradeType(req.Header)
	if !ascii.IsPrint(reqUpType) {
		return nil, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType)
	}

	transport.RemoveConnectionHeaders(req.Header)

	// Remove hop-by-hop headers to the backend. Especially
	// important is "Connection" because we want a persistent
	// connection, regardless of what the client sent to us.
	for _, h := range transport.HopHeaders {
		req.Header.Del(h)
	}

	// TODO: trailer header here

	// After stripping all the hop-by-hop connection headers above, add back any
	// necessary for protocol upgrades, such as for websockets.
	if reqUpType != "" {
		req.Header.Set("Connection", "Upgrade")
		req.Header.Set("Upgrade", reqUpType)
	}

	beresp, err := p.backend.RoundTrip(req)
	if err != nil {
		return nil, err
	}

	// Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
	if beresp.StatusCode == http.StatusSwitchingProtocols {
		return beresp, p.handleUpgradeResponse(req, beresp)
	}

	transport.RemoveConnectionHeaders(beresp.Header)
	transport.RemoveHopHeaders(beresp.Header)

	evalCtx := eval.ContextFromRequest(req)
	err = eval.ApplyResponseContext(evalCtx.HCLContextSync(), p.context, beresp)

	// The response is returned even if applying the response context failed.
	return beresp, err
}
   135  
   136  func upgradeType(h http.Header) string {
   137  	conn, exist := h["Connection"]
   138  	if !exist {
   139  		return ""
   140  	}
   141  	for _, v := range conn {
   142  		if strings.ToLower(v) == "upgrade" {
   143  			return h.Get("Upgrade")
   144  		}
   145  	}
   146  	return ""
   147  }
   148  
   149  func (p *Proxy) applyWebsocketsRequest(hclCtx *hcl.EvalContext, req *http.Request) (context.Context, error) {
   150  	outCtx := req.Context()
   151  	if p.allowWS {
   152  		outCtx = context.WithValue(outCtx, request.WebsocketsAllowed, p.allowWS)
   153  	} else {
   154  		return outCtx, nil
   155  	}
   156  
   157  	// This method needs the 'request.WebsocketsAllowed' flag in the 'req.context'.
   158  	if !eval.IsUpgradeRequest(req.WithContext(outCtx)) {
   159  		return outCtx, nil
   160  	}
   161  
   162  	wsBody := p.getWebsocketsBody()
   163  	if wsBody == nil { // applies if just the websockets attribute is given
   164  		return outCtx, nil
   165  	}
   166  
   167  	if err := eval.ApplyRequestContext(hclCtx, wsBody, req); err != nil {
   168  		return nil, err
   169  	}
   170  
   171  	attr, ok := wsBody.Attributes["timeout"]
   172  	if !ok {
   173  		return outCtx, nil
   174  	}
   175  
   176  	val, err := eval.Value(hclCtx, attr.Expr)
   177  	if err != nil {
   178  		return nil, err
   179  	}
   180  
   181  	str := seetie.ValueToString(val)
   182  
   183  	timeout, err := time.ParseDuration(str)
   184  	if str != "" && err != nil {
   185  		return nil, err
   186  	}
   187  
   188  	outCtx = context.WithValue(outCtx, request.WebsocketsTimeout, timeout)
   189  	return outCtx, nil
   190  }
   191  
   192  func (p *Proxy) registerWebsocketsResponse(req *http.Request) error {
   193  	if !eval.IsUpgradeRequest(req) {
   194  		return nil
   195  	}
   196  
   197  	wsBody := p.getWebsocketsBody()
   198  	evalCtx := eval.ContextFromRequest(req)
   199  
   200  	if rw, ok := req.Context().Value(request.ResponseWriter).(*writer.Response); ok {
   201  		rw.AddModifier(evalCtx.HCLContextSync(), wsBody, p.context)
   202  	}
   203  
   204  	return nil
   205  }
   206  
   207  func (p *Proxy) getWebsocketsBody() *hclsyntax.Body {
   208  	wss := hclbody.BlocksOfType(p.context, "websockets")
   209  	if len(wss) != 1 {
   210  		return nil
   211  	}
   212  
   213  	return wss[0].Body
   214  }
   215  
   216  func (p *Proxy) handleUpgradeResponse(req *http.Request, res *http.Response) error {
   217  	rw, ok := req.Context().Value(request.ResponseWriter).(http.ResponseWriter)
   218  	if !ok {
   219  		return fmt.Errorf("can't switch protocols using non-ResponseWriter type %T", rw)
   220  	}
   221  
   222  	reqUpType := upgradeType(req.Header)
   223  	resUpType := upgradeType(res.Header)
   224  	if !ascii.IsPrint(resUpType) { // We know reqUpType is ASCII, it's checked by the caller.
   225  		return fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType)
   226  	}
   227  	if !ascii.EqualFold(reqUpType, resUpType) {
   228  		return fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType)
   229  	}
   230  
   231  	hj, ok := rw.(http.Hijacker)
   232  	if !ok {
   233  		return fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw)
   234  	}
   235  	backConn, ok := res.Body.(io.ReadWriteCloser)
   236  	if !ok {
   237  		return fmt.Errorf("internal error: 101 switching protocols response with non-writable body")
   238  	}
   239  
   240  	backConnCloseCh := make(chan bool)
   241  	go func() {
   242  		// Ensure that the cancellation of a request closes the backend.
   243  		// See issue https://golang.org/issue/35559.
   244  		select {
   245  		case <-req.Context().Done():
   246  		case <-backConnCloseCh:
   247  		}
   248  		backConn.Close()
   249  	}()
   250  
   251  	defer close(backConnCloseCh)
   252  
   253  	conn, brw, err := hj.Hijack()
   254  	if err != nil {
   255  		return fmt.Errorf("hijack failed on protocol switch: %v", err)
   256  	}
   257  	defer conn.Close()
   258  
   259  	copyHeader(rw.Header(), res.Header)
   260  
   261  	res.Header = rw.Header()
   262  	res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above
   263  	if err := res.Write(brw); err != nil {
   264  		return fmt.Errorf("response write: %v", err)
   265  	}
   266  	if err := brw.Flush(); err != nil {
   267  		return fmt.Errorf("response flush: %v", err)
   268  	}
   269  	errc := make(chan error, 1)
   270  	spc := switchProtocolCopier{user: conn, backend: backConn}
   271  	go spc.copyToBackend(errc)
   272  	go spc.copyFromBackend(errc)
   273  	<-errc
   274  	return nil
   275  }
   276  
   277  func copyHeader(dst, src http.Header) {
   278  	for k, vv := range src {
   279  		for _, v := range vv {
   280  			dst.Add(k, v)
   281  		}
   282  	}
   283  }
   284  
   285  func flushInterval(res *http.Response) time.Duration {
   286  	resCT := res.Header.Get("Content-Type")
   287  
   288  	// For Server-Sent Events responses, flush immediately.
   289  	// The MIME type is defined in https://www.w3.org/TR/eventsource/#text-event-stream
   290  	if resCT == "text/event-stream" {
   291  		return -1 // negative means immediately
   292  	}
   293  
   294  	// We might have the case of streaming for which Content-Length might be unset.
   295  	if res.ContentLength == -1 {
   296  		return -1
   297  	}
   298  
   299  	return time.Millisecond * 100
   300  }
   301  
   302  var bufferPool httputil.BufferPool
   303  
   304  func copyResponse(dst io.Writer, src io.Reader, flushInterval time.Duration) error {
   305  	if flushInterval != 0 {
   306  		if wf, ok := dst.(writeFlusher); ok {
   307  			mlw := &maxLatencyWriter{
   308  				dst:     wf,
   309  				latency: flushInterval,
   310  			}
   311  			defer mlw.stop()
   312  
   313  			// set up initial timer so headers get flushed even if body writes are delayed
   314  			mlw.flushPending = true
   315  			mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush)
   316  
   317  			dst = mlw
   318  		}
   319  	}
   320  
   321  	var buf []byte
   322  	if bufferPool != nil {
   323  		buf = bufferPool.Get()
   324  		defer bufferPool.Put(buf)
   325  	}
   326  	_, err := copyBuffer(dst, src, buf)
   327  	return err
   328  }
   329  
   330  // copyBuffer returns any write errors or non-EOF read errors, and the amount
   331  // of bytes written.
   332  func copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) {
   333  	if len(buf) == 0 {
   334  		buf = make([]byte, 32*1024)
   335  	}
   336  	var written int64
   337  	for {
   338  		nr, rerr := src.Read(buf)
   339  		if rerr != nil && rerr != io.EOF && rerr != context.Canceled {
   340  			return 0, errors.Server.With(rerr).Message("read error during body copy")
   341  		}
   342  		if nr > 0 {
   343  			nw, werr := dst.Write(buf[:nr])
   344  			if nw > 0 {
   345  				written += int64(nw)
   346  			}
   347  			if werr != nil {
   348  				return written, werr
   349  			}
   350  			if nr != nw {
   351  				return written, io.ErrShortWrite
   352  			}
   353  		}
   354  		if rerr != nil {
   355  			if rerr == io.EOF {
   356  				rerr = nil
   357  			}
   358  			return written, rerr
   359  		}
   360  	}
   361  }
   362  
   363  type writeFlusher interface {
   364  	io.Writer
   365  	http.Flusher
   366  }
   367  
// maxLatencyWriter wraps a writeFlusher and flushes it after writes. With a
// negative latency it flushes after every Write. Otherwise a timer flushes dst
// latency after the first write that followed the last flush. All access to
// the timer, the pending flag and dst.Flush is serialised by mu.
type maxLatencyWriter struct {
	dst     writeFlusher
	latency time.Duration // non-zero; negative means to flush immediately

	mu           sync.Mutex // protects t, flushPending, and dst.Flush
	t            *time.Timer
	flushPending bool // a timer-driven flush is scheduled and has not run yet
}

// Write writes p to dst while holding mu and returns dst.Write's result.
// With a negative latency dst is flushed right away (even if the write
// failed). Otherwise, unless a flush is already pending, the timer is created
// or reset to fire after latency and a flush is marked as pending.
func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
	m.mu.Lock()
	defer m.mu.Unlock()
	n, err = m.dst.Write(p)
	if m.latency < 0 {
		m.dst.Flush()
		return
	}
	if m.flushPending {
		return
	}
	if m.t == nil {
		m.t = time.AfterFunc(m.latency, m.delayedFlush)
	} else {
		m.t.Reset(m.latency)
	}
	m.flushPending = true
	return
}

// delayedFlush is the timer callback: if a flush is still pending it flushes
// dst and clears the pending flag.
func (m *maxLatencyWriter) delayedFlush() {
	m.mu.Lock()
	defer m.mu.Unlock()
	if !m.flushPending { // if stop was called but AfterFunc already started this goroutine
		return
	}
	m.dst.Flush()
	m.flushPending = false
}

// stop clears the pending flag and stops the timer, if any. A delayedFlush
// callback that acquires mu after stop will find flushPending false and do
// nothing.
func (m *maxLatencyWriter) stop() {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.flushPending = false
	if m.t != nil {
		m.t.Stop()
	}
}
   415  
   416  // switchProtocolCopier exists so goroutines proxying data back and
   417  // forth have nice names in stacks.
   418  type switchProtocolCopier struct {
   419  	user, backend io.ReadWriter
   420  }
   421  
   422  func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
   423  	_, err := io.Copy(c.user, c.backend)
   424  	errc <- err
   425  }
   426  
   427  func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
   428  	_, err := io.Copy(c.backend, c.user)
   429  	errc <- err
   430  }