github.com/avenga/couper@v1.12.2/handler/proxy.go (about) 1 package handler 2 3 import ( 4 "context" 5 "fmt" 6 "io" 7 "net/http" 8 "net/http/httputil" 9 "strings" 10 "sync" 11 "time" 12 13 "github.com/hashicorp/hcl/v2" 14 "github.com/hashicorp/hcl/v2/hclsyntax" 15 "github.com/sirupsen/logrus" 16 17 hclbody "github.com/avenga/couper/config/body" 18 "github.com/avenga/couper/config/request" 19 "github.com/avenga/couper/errors" 20 "github.com/avenga/couper/eval" 21 "github.com/avenga/couper/handler/ascii" 22 "github.com/avenga/couper/handler/transport" 23 "github.com/avenga/couper/internal/seetie" 24 "github.com/avenga/couper/server/writer" 25 ) 26 27 // headerBlacklist lists all header keys which will be removed after 28 // context variable evaluation to ensure to not pass them upstream. 29 var headerBlacklist = []string{"Authorization", "Cookie"} 30 31 // Proxy wraps a httputil.ReverseProxy to apply additional configuration context 32 // and have control over the roundtrip configuration. 33 type Proxy struct { 34 allowWS bool 35 backend http.RoundTripper 36 context *hclsyntax.Body 37 logger *logrus.Entry 38 } 39 40 func NewProxy(backend http.RoundTripper, ctx *hclsyntax.Body, allowWS bool, logger *logrus.Entry) *Proxy { 41 proxy := &Proxy{ 42 allowWS: allowWS, 43 backend: backend, 44 context: ctx, 45 logger: logger, 46 } 47 48 return proxy 49 } 50 51 func (p *Proxy) RoundTrip(req *http.Request) (*http.Response, error) { 52 // 1. Apply proxy blacklist 53 for _, key := range headerBlacklist { 54 req.Header.Del(key) 55 } 56 57 hclCtx := eval.ContextFromRequest(req).HCLContextSync() 58 59 // 2. Apply proxy-body 60 err := eval.ApplyRequestContext(hclCtx, p.context, req) 61 if err != nil { 62 return nil, err 63 } 64 65 // 3. Apply websockets-body 66 outCtx, err := p.applyWebsocketsRequest(hclCtx, req) 67 if err != nil { 68 return nil, err 69 } 70 71 // 4. 
apply some hcl context 72 expStatusVal, err := eval.ValueFromBodyAttribute(hclCtx, p.context, "expected_status") 73 if err != nil { 74 return nil, err 75 } 76 77 outCtx = context.WithValue(outCtx, request.EndpointExpectedStatus, seetie.ValueToIntSlice(expStatusVal)) 78 79 *req = *req.WithContext(outCtx) 80 81 if err = p.registerWebsocketsResponse(req); err != nil { 82 return nil, err 83 } 84 85 // the chore reverse-proxy part 86 if req.ContentLength == 0 { 87 req.Body = nil // Issue 16036: nil Body for http.Transport retries 88 } 89 if req.Body != nil { 90 defer req.Body.Close() 91 } 92 req.Close = false 93 94 reqUpType := upgradeType(req.Header) 95 if !ascii.IsPrint(reqUpType) { 96 return nil, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType) 97 } 98 99 transport.RemoveConnectionHeaders(req.Header) 100 101 // Remove hop-by-hop headers to the backend. Especially 102 // important is "Connection" because we want a persistent 103 // connection, regardless of what the client sent to us. 104 for _, h := range transport.HopHeaders { 105 req.Header.Del(h) 106 } 107 108 // TODO: trailer header here 109 110 // After stripping all the hop-by-hop connection headers above, add back any 111 // necessary for protocol upgrades, such as for websockets. 
112 if reqUpType != "" { 113 req.Header.Set("Connection", "Upgrade") 114 req.Header.Set("Upgrade", reqUpType) 115 } 116 117 beresp, err := p.backend.RoundTrip(req) 118 if err != nil { 119 return nil, err 120 } 121 122 // Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc) 123 if beresp.StatusCode == http.StatusSwitchingProtocols { 124 return beresp, p.handleUpgradeResponse(req, beresp) 125 } 126 127 transport.RemoveConnectionHeaders(beresp.Header) 128 transport.RemoveHopHeaders(beresp.Header) 129 130 evalCtx := eval.ContextFromRequest(req) 131 err = eval.ApplyResponseContext(evalCtx.HCLContextSync(), p.context, beresp) 132 133 return beresp, err 134 } 135 136 func upgradeType(h http.Header) string { 137 conn, exist := h["Connection"] 138 if !exist { 139 return "" 140 } 141 for _, v := range conn { 142 if strings.ToLower(v) == "upgrade" { 143 return h.Get("Upgrade") 144 } 145 } 146 return "" 147 } 148 149 func (p *Proxy) applyWebsocketsRequest(hclCtx *hcl.EvalContext, req *http.Request) (context.Context, error) { 150 outCtx := req.Context() 151 if p.allowWS { 152 outCtx = context.WithValue(outCtx, request.WebsocketsAllowed, p.allowWS) 153 } else { 154 return outCtx, nil 155 } 156 157 // This method needs the 'request.WebsocketsAllowed' flag in the 'req.context'. 
158 if !eval.IsUpgradeRequest(req.WithContext(outCtx)) { 159 return outCtx, nil 160 } 161 162 wsBody := p.getWebsocketsBody() 163 if wsBody == nil { // applies if just the websockets attribute is given 164 return outCtx, nil 165 } 166 167 if err := eval.ApplyRequestContext(hclCtx, wsBody, req); err != nil { 168 return nil, err 169 } 170 171 attr, ok := wsBody.Attributes["timeout"] 172 if !ok { 173 return outCtx, nil 174 } 175 176 val, err := eval.Value(hclCtx, attr.Expr) 177 if err != nil { 178 return nil, err 179 } 180 181 str := seetie.ValueToString(val) 182 183 timeout, err := time.ParseDuration(str) 184 if str != "" && err != nil { 185 return nil, err 186 } 187 188 outCtx = context.WithValue(outCtx, request.WebsocketsTimeout, timeout) 189 return outCtx, nil 190 } 191 192 func (p *Proxy) registerWebsocketsResponse(req *http.Request) error { 193 if !eval.IsUpgradeRequest(req) { 194 return nil 195 } 196 197 wsBody := p.getWebsocketsBody() 198 evalCtx := eval.ContextFromRequest(req) 199 200 if rw, ok := req.Context().Value(request.ResponseWriter).(*writer.Response); ok { 201 rw.AddModifier(evalCtx.HCLContextSync(), wsBody, p.context) 202 } 203 204 return nil 205 } 206 207 func (p *Proxy) getWebsocketsBody() *hclsyntax.Body { 208 wss := hclbody.BlocksOfType(p.context, "websockets") 209 if len(wss) != 1 { 210 return nil 211 } 212 213 return wss[0].Body 214 } 215 216 func (p *Proxy) handleUpgradeResponse(req *http.Request, res *http.Response) error { 217 rw, ok := req.Context().Value(request.ResponseWriter).(http.ResponseWriter) 218 if !ok { 219 return fmt.Errorf("can't switch protocols using non-ResponseWriter type %T", rw) 220 } 221 222 reqUpType := upgradeType(req.Header) 223 resUpType := upgradeType(res.Header) 224 if !ascii.IsPrint(resUpType) { // We know reqUpType is ASCII, it's checked by the caller. 
225 return fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType) 226 } 227 if !ascii.EqualFold(reqUpType, resUpType) { 228 return fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType) 229 } 230 231 hj, ok := rw.(http.Hijacker) 232 if !ok { 233 return fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw) 234 } 235 backConn, ok := res.Body.(io.ReadWriteCloser) 236 if !ok { 237 return fmt.Errorf("internal error: 101 switching protocols response with non-writable body") 238 } 239 240 backConnCloseCh := make(chan bool) 241 go func() { 242 // Ensure that the cancellation of a request closes the backend. 243 // See issue https://golang.org/issue/35559. 244 select { 245 case <-req.Context().Done(): 246 case <-backConnCloseCh: 247 } 248 backConn.Close() 249 }() 250 251 defer close(backConnCloseCh) 252 253 conn, brw, err := hj.Hijack() 254 if err != nil { 255 return fmt.Errorf("hijack failed on protocol switch: %v", err) 256 } 257 defer conn.Close() 258 259 copyHeader(rw.Header(), res.Header) 260 261 res.Header = rw.Header() 262 res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above 263 if err := res.Write(brw); err != nil { 264 return fmt.Errorf("response write: %v", err) 265 } 266 if err := brw.Flush(); err != nil { 267 return fmt.Errorf("response flush: %v", err) 268 } 269 errc := make(chan error, 1) 270 spc := switchProtocolCopier{user: conn, backend: backConn} 271 go spc.copyToBackend(errc) 272 go spc.copyFromBackend(errc) 273 <-errc 274 return nil 275 } 276 277 func copyHeader(dst, src http.Header) { 278 for k, vv := range src { 279 for _, v := range vv { 280 dst.Add(k, v) 281 } 282 } 283 } 284 285 func flushInterval(res *http.Response) time.Duration { 286 resCT := res.Header.Get("Content-Type") 287 288 // For Server-Sent Events responses, flush immediately. 
289 // The MIME type is defined in https://www.w3.org/TR/eventsource/#text-event-stream 290 if resCT == "text/event-stream" { 291 return -1 // negative means immediately 292 } 293 294 // We might have the case of streaming for which Content-Length might be unset. 295 if res.ContentLength == -1 { 296 return -1 297 } 298 299 return time.Millisecond * 100 300 } 301 302 var bufferPool httputil.BufferPool 303 304 func copyResponse(dst io.Writer, src io.Reader, flushInterval time.Duration) error { 305 if flushInterval != 0 { 306 if wf, ok := dst.(writeFlusher); ok { 307 mlw := &maxLatencyWriter{ 308 dst: wf, 309 latency: flushInterval, 310 } 311 defer mlw.stop() 312 313 // set up initial timer so headers get flushed even if body writes are delayed 314 mlw.flushPending = true 315 mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush) 316 317 dst = mlw 318 } 319 } 320 321 var buf []byte 322 if bufferPool != nil { 323 buf = bufferPool.Get() 324 defer bufferPool.Put(buf) 325 } 326 _, err := copyBuffer(dst, src, buf) 327 return err 328 } 329 330 // copyBuffer returns any write errors or non-EOF read errors, and the amount 331 // of bytes written. 
332 func copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) { 333 if len(buf) == 0 { 334 buf = make([]byte, 32*1024) 335 } 336 var written int64 337 for { 338 nr, rerr := src.Read(buf) 339 if rerr != nil && rerr != io.EOF && rerr != context.Canceled { 340 return 0, errors.Server.With(rerr).Message("read error during body copy") 341 } 342 if nr > 0 { 343 nw, werr := dst.Write(buf[:nr]) 344 if nw > 0 { 345 written += int64(nw) 346 } 347 if werr != nil { 348 return written, werr 349 } 350 if nr != nw { 351 return written, io.ErrShortWrite 352 } 353 } 354 if rerr != nil { 355 if rerr == io.EOF { 356 rerr = nil 357 } 358 return written, rerr 359 } 360 } 361 } 362 363 type writeFlusher interface { 364 io.Writer 365 http.Flusher 366 } 367 368 type maxLatencyWriter struct { 369 dst writeFlusher 370 latency time.Duration // non-zero; negative means to flush immediately 371 372 mu sync.Mutex // protects t, flushPending, and dst.Flush 373 t *time.Timer 374 flushPending bool 375 } 376 377 func (m *maxLatencyWriter) Write(p []byte) (n int, err error) { 378 m.mu.Lock() 379 defer m.mu.Unlock() 380 n, err = m.dst.Write(p) 381 if m.latency < 0 { 382 m.dst.Flush() 383 return 384 } 385 if m.flushPending { 386 return 387 } 388 if m.t == nil { 389 m.t = time.AfterFunc(m.latency, m.delayedFlush) 390 } else { 391 m.t.Reset(m.latency) 392 } 393 m.flushPending = true 394 return 395 } 396 397 func (m *maxLatencyWriter) delayedFlush() { 398 m.mu.Lock() 399 defer m.mu.Unlock() 400 if !m.flushPending { // if stop was called but AfterFunc already started this goroutine 401 return 402 } 403 m.dst.Flush() 404 m.flushPending = false 405 } 406 407 func (m *maxLatencyWriter) stop() { 408 m.mu.Lock() 409 defer m.mu.Unlock() 410 m.flushPending = false 411 if m.t != nil { 412 m.t.Stop() 413 } 414 } 415 416 // switchProtocolCopier exists so goroutines proxying data back and 417 // forth have nice names in stacks. 
type switchProtocolCopier struct {
	user    io.ReadWriter
	backend io.ReadWriter
}

// copyFromBackend relays data read from the backend to the user until
// io.Copy returns, then sends its error (nil on a clean EOF) to errc.
func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
	dst, src := c.user, c.backend
	_, copyErr := io.Copy(dst, src)
	errc <- copyErr
}

// copyToBackend relays data read from the user to the backend until
// io.Copy returns, then sends its error (nil on a clean EOF) to errc.
func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
	dst, src := c.backend, c.user
	_, copyErr := io.Copy(dst, src)
	errc <- copyErr
}