github.com/google/martian/v3@v3.3.3/proxy.go (about) 1 // Copyright 2015 Google Inc. All rights reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package martian 16 17 import ( 18 "bufio" 19 "bytes" 20 "crypto/tls" 21 "errors" 22 "io" 23 "net" 24 "net/http" 25 "net/http/httputil" 26 "net/url" 27 "regexp" 28 "sync" 29 "time" 30 31 "github.com/google/martian/v3/log" 32 "github.com/google/martian/v3/mitm" 33 "github.com/google/martian/v3/nosigpipe" 34 "github.com/google/martian/v3/proxyutil" 35 "github.com/google/martian/v3/trafficshape" 36 ) 37 38 var errClose = errors.New("closing connection") 39 var noop = Noop("martian") 40 41 func isCloseable(err error) bool { 42 if neterr, ok := err.(net.Error); ok && neterr.Timeout() { 43 return true 44 } 45 46 switch err { 47 case io.EOF, io.ErrClosedPipe, errClose: 48 return true 49 } 50 51 return false 52 } 53 54 // Proxy is an HTTP proxy with support for TLS MITM and customizable behavior. 55 type Proxy struct { 56 roundTripper http.RoundTripper 57 dial func(string, string) (net.Conn, error) 58 timeout time.Duration 59 mitm *mitm.Config 60 proxyURL *url.URL 61 conns sync.WaitGroup 62 connsMu sync.Mutex // protects conns.Add/Wait from concurrent access 63 closing chan bool 64 65 reqmod RequestModifier 66 resmod ResponseModifier 67 } 68 69 // NewProxy returns a new HTTP proxy. 
func NewProxy() *Proxy {
	// A non-nil but empty TLSNextProto map keeps http.Transport from upgrading
	// upstream requests to HTTP/2 (Go 1.6+).
	// TODO(adamtanner): Remove this once Martian can support HTTP/2.
	transport := &http.Transport{
		TLSNextProto:          make(map[string]func(string, *tls.Conn) http.RoundTripper),
		Proxy:                 http.ProxyFromEnvironment,
		TLSHandshakeTimeout:   10 * time.Second,
		ExpectContinueTimeout: time.Second,
	}

	p := &Proxy{
		roundTripper: transport,
		timeout:      5 * time.Minute,
		closing:      make(chan bool),
		reqmod:       noop,
		resmod:       noop,
	}

	dialer := &net.Dialer{
		Timeout:   30 * time.Second,
		KeepAlive: 30 * time.Second,
	}
	// SetDial also installs the wrapped dialer on the transport above.
	p.SetDial(dialer.Dial)

	return p
}

// GetRoundTripper gets the http.RoundTripper of the proxy.
func (p *Proxy) GetRoundTripper() http.RoundTripper {
	return p.roundTripper
}

// SetRoundTripper sets the http.RoundTripper of the proxy.
//
// If rt is an *http.Transport, it is also reconfigured in place:
//   - HTTP/2 upgrades are disabled.
//   - Its Proxy routes through the downstream proxy URL set via
//     SetDownstreamProxy; with none set, no proxy is used.
//   - Its Dial is set to the proxy's dial func.
func (p *Proxy) SetRoundTripper(rt http.RoundTripper) {
	p.roundTripper = rt

	tr, isTransport := rt.(*http.Transport)
	if !isTransport {
		return
	}
	tr.TLSNextProto = make(map[string]func(string, *tls.Conn) http.RoundTripper)
	tr.Proxy = http.ProxyURL(p.proxyURL)
	tr.Dial = p.dial
}

// SetDownstreamProxy sets the proxy that receives requests from the upstream
// proxy. If the current round tripper is an *http.Transport, its Proxy func is
// updated to match.
func (p *Proxy) SetDownstreamProxy(proxyURL *url.URL) {
	p.proxyURL = proxyURL

	tr, isTransport := p.roundTripper.(*http.Transport)
	if !isTransport {
		return
	}
	tr.Proxy = http.ProxyURL(proxyURL)
}

// SetTimeout sets the per-request timeout applied to client connections.
func (p *Proxy) SetTimeout(timeout time.Duration) { p.timeout = timeout }

// SetMITM sets the config used to MITM CONNECT requests; nil disables MITM.
func (p *Proxy) SetMITM(config *mitm.Config) { p.mitm = config }

// SetDial sets the dial func used to establish a connection.
func (p *Proxy) SetDial(dial func(string, string) (net.Conn, error)) {
	p.dial = func(network, addr string) (net.Conn, error) {
		c, err := dial(network, addr)
		if err != nil {
			// There is no usable connection to configure; pass the result
			// through untouched.
			return c, err
		}
		nosigpipe.IgnoreSIGPIPE(c)
		return c, nil
	}

	// Make an *http.Transport round tripper dial through the wrapper as well.
	if tr, ok := p.roundTripper.(*http.Transport); ok {
		tr.Dial = p.dial
	}
}

// Close sets the proxy to the closing state so it stops receiving new
// connections, finishes processing any inflight requests, and closes existing
// connections without reading any more requests from them.
//
// Close may be called more than once, including concurrently. Later calls only
// wait for the remaining connections to finish.
func (p *Proxy) Close() {
	log.Infof("martian: closing down proxy")

	// connsMu serializes Close calls so p.closing is closed exactly once;
	// closing an already-closed channel panics. It also guards conns.Wait
	// against concurrent conns.Add calls from new connections.
	p.connsMu.Lock()
	defer p.connsMu.Unlock()

	if !p.Closing() {
		close(p.closing)
	}

	log.Infof("martian: waiting for connections to close")
	p.conns.Wait()
	log.Infof("martian: all connections closed")
}

// Closing returns whether the proxy is in the closing state.
func (p *Proxy) Closing() bool {
	select {
	case <-p.closing:
		return true
	default:
		return false
	}
}

// SetRequestModifier sets the request modifier. A nil reqmod installs the
// default noop modifier.
func (p *Proxy) SetRequestModifier(reqmod RequestModifier) {
	if reqmod == nil {
		reqmod = noop
	}

	p.reqmod = reqmod
}

// SetResponseModifier sets the response modifier. A nil resmod installs the
// default noop modifier.
func (p *Proxy) SetResponseModifier(resmod ResponseModifier) {
	if resmod == nil {
		resmod = noop
	}

	p.resmod = resmod
}

// Serve accepts connections from the listener and handles the requests.
func (p *Proxy) Serve(l net.Listener) error {
	defer l.Close()

	// delay is the backoff used after temporary Accept errors. It doubles from
	// 5ms up to 1s and is reset after every successful Accept.
	var delay time.Duration
	for {
		if p.Closing() {
			return nil
		}

		conn, err := l.Accept()
		if err != nil {
			if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
				if delay == 0 {
					delay = 5 * time.Millisecond
				} else {
					delay *= 2
				}
				if maxDelay := time.Second; delay > maxDelay {
					delay = maxDelay
				}

				log.Debugf("martian: temporary error on accept: %v", err)
				time.Sleep(delay)
				continue
			}

			if errors.Is(err, net.ErrClosed) {
				log.Debugf("martian: listener closed, returning")
				return err
			}

			log.Errorf("martian: failed to accept: %v", err)
			return err
		}
		// Only configure conn once Accept has succeeded; on error it is not a
		// usable connection.
		nosigpipe.IgnoreSIGPIPE(conn)
		delay = 0
		log.Debugf("martian: accepted connection from %s", conn.RemoteAddr())

		if tconn, ok := conn.(*net.TCPConn); ok {
			tconn.SetKeepAlive(true)
			tconn.SetKeepAlivePeriod(3 * time.Minute)
		}

		go p.handleLoop(conn)
	}
}

// handleLoop serves requests from a single client connection. It stops when:
//   - the proxy is already closing,
//   - the session or context cannot be created, or
//   - a request handler returns an error for which isCloseable is true.
//
// The connection is registered in p.conns so Close can wait for it, and it is
// closed on return.
func (p *Proxy) handleLoop(conn net.Conn) {
	p.connsMu.Lock()
	p.conns.Add(1)
	p.connsMu.Unlock()
	defer p.conns.Done()
	defer conn.Close()
	if p.Closing() {
		return
	}

	brw := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))

	s, err := newSession(conn, brw)
	if err != nil {
		log.Errorf("martian: failed to create session: %v", err)
		return
	}

	ctx, err := withSession(s)
	if err != nil {
		log.Errorf("martian: failed to create context: %v", err)
		return
	}

	for {
		// Each request gets a fresh deadline of p.timeout.
		deadline := time.Now().Add(p.timeout)
		conn.SetDeadline(deadline)

		if err := p.handle(ctx, conn, brw); isCloseable(err) {
			log.Debugf("martian: closing connection: %v", conn.RemoteAddr())
			return
		}
	}
}

// readRequest reads the next request from brw.
//
// The read runs in its own goroutine so that readRequest can return as soon
// as the proxy starts closing. Both result channels have a buffer of one, so
// that goroutine never blocks on its send after readRequest has returned; it
// finishes once the connection is closed or its deadline passes. Any read
// error, and proxy shutdown, are reported as errClose.
func (p *Proxy) readRequest(ctx *Context, conn net.Conn, brw *bufio.ReadWriter) (*http.Request, error) {
	var req *http.Request
	reqc := make(chan *http.Request, 1)
	errc := make(chan error, 1)
	go func() {
		r, err := http.ReadRequest(brw.Reader)
		if err != nil {
			errc <- err
			return
		}
		reqc <- r
	}()
	select {
	case err := <-errc:
		if isCloseable(err) {
			log.Debugf("martian: connection closed prematurely: %v", err)
		} else {
			log.Errorf("martian: failed to read request: %v", err)
		}

		// TODO: TCPConn.WriteClose() to avoid sending an RST to the client.

		return nil, errClose
	case req = <-reqc:
	case <-p.closing:
		return nil, errClose
	}

	return req, nil
}

// handleConnectRequest handles a CONNECT request. It first runs the request
// modifier; if that hijacks the session, it stops there.
//
// With MITM enabled (p.mitm != nil), it answers 200 itself and sniffs the
// first byte the client sends through the tunnel:
//   - A TLS handshake record (0x16) is terminated locally with the MITM
//     config. An "h2" negotiation is handed to the HTTP/2 proxy; otherwise
//     HTTP/1 requests keep being served over the TLS connection.
//   - Anything else is served as plain HTTP over the original connection.
//
// Without MITM, it opens a tunnel via p.connect, relays the upstream response,
// then copies bytes in both directions until both sides finish. In that case
// it returns errClose.
func (p *Proxy) handleConnectRequest(ctx *Context, req *http.Request, session *Session, brw *bufio.ReadWriter, conn net.Conn) error {
	if err := p.reqmod.ModifyRequest(req); err != nil {
		log.Errorf("martian: error modifying CONNECT request: %v", err)
		proxyutil.Warning(req.Header, err)
	}
	if session.Hijacked() {
		log.Debugf("martian: connection hijacked by request modifier")
		return nil
	}

	if p.mitm != nil {
		log.Debugf("martian: attempting MITM for connection: %s / %s", req.Host, req.URL.String())

		res := proxyutil.NewResponse(200, nil, req)

		if err := p.resmod.ModifyResponse(res); err != nil {
			log.Errorf("martian: error modifying CONNECT response: %v", err)
			proxyutil.Warning(res.Header, err)
		}
		if session.Hijacked() {
			log.Infof("martian: connection hijacked by response modifier")
			return nil
		}

		if err := res.Write(brw); err != nil {
			log.Errorf("martian: got error while writing response back to client: %v", err)
		}
		if err := brw.Flush(); err != nil {
			log.Errorf("martian: got error while flushing response back to client: %v", err)
		}

		log.Debugf("martian: completed MITM for connection: %s", req.Host)

		b := make([]byte, 1)
		if _, err := brw.Read(b); err != nil {
			log.Errorf("martian: error peeking message through CONNECT tunnel to determine type: %v", err)
			// No byte was read, so there is nothing to sniff or replay. Close the
			// connection instead of feeding a fabricated zero byte to the HTTP
			// parser.
			return errClose
		}

		// Drain all of the rest of the buffered data. The bytes are already in
		// memory, so a single Read copies all of them.
		buf := make([]byte, brw.Reader.Buffered())
		brw.Read(buf)

		// 22 is the TLS handshake.
		// https://tools.ietf.org/html/rfc5246#section-6.2.1
		if b[0] == 22 {
			// Prepend the previously read data to be read again by the TLS
			// server.
			tlsconn := tls.Server(&peekedConn{conn, io.MultiReader(bytes.NewReader(b), bytes.NewReader(buf), conn)}, p.mitm.TLSForHost(req.Host))

			if err := tlsconn.Handshake(); err != nil {
				p.mitm.HandshakeErrorCallback(req, err)
				return err
			}
			if tlsconn.ConnectionState().NegotiatedProtocol == "h2" {
				return p.mitm.H2Config().Proxy(p.closing, tlsconn, req.URL)
			}

			var nconn net.Conn = tlsconn
			// If the original connection is a traffic shaped connection, wrap
			// the tls connection inside a traffic shaped connection too.
			if ptsconn, ok := conn.(*trafficshape.Conn); ok {
				nconn = ptsconn.Listener.GetTrafficShapedConn(tlsconn)
			}
			brw.Writer.Reset(nconn)
			brw.Reader.Reset(nconn)
			return p.handle(ctx, nconn, brw)
		}

		// Prepend the previously read data to be read again by http.ReadRequest.
		brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), bytes.NewReader(buf), conn))
		return p.handle(ctx, conn, brw)
	}

	log.Debugf("martian: attempting to establish CONNECT tunnel: %s", req.URL.Host)
	res, cconn, cerr := p.connect(req)
	if cerr != nil {
		log.Errorf("martian: failed to CONNECT: %v", cerr)
		res = proxyutil.NewResponse(502, nil, req)
		proxyutil.Warning(res.Header, cerr)

		if err := p.resmod.ModifyResponse(res); err != nil {
			log.Errorf("martian: error modifying CONNECT response: %v", err)
			proxyutil.Warning(res.Header, err)
		}
		if session.Hijacked() {
			log.Infof("martian: connection hijacked by response modifier")
			return nil
		}

		if err := res.Write(brw); err != nil {
			log.Errorf("martian: got error while writing response back to client: %v", err)
		}
		err := brw.Flush()
		if err != nil {
			log.Errorf("martian: got error while flushing response back to client: %v", err)
		}
		return err
	}
	defer res.Body.Close()
	defer cconn.Close()

	if err := p.resmod.ModifyResponse(res); err != nil {
		log.Errorf("martian: error modifying CONNECT response: %v", err)
		proxyutil.Warning(res.Header, err)
	}
	if session.Hijacked() {
		log.Infof("martian: connection hijacked by response modifier")
		return nil
	}

	res.ContentLength = -1
	if err := res.Write(brw); err != nil {
		log.Errorf("martian: got error while writing response back to client: %v", err)
	}
	if err := brw.Flush(); err != nil {
		log.Errorf("martian: got error while flushing response back to client: %v", err)
	}

	cbw := bufio.NewWriter(cconn)
	cbr := bufio.NewReader(cconn)
	defer cbw.Flush()

	// copySync copies r into w until r is exhausted, then signals on donec.
	copySync := func(w io.Writer, r io.Reader, donec chan<- bool) {
		if _, err := io.Copy(w, r); err != nil && err != io.EOF {
			log.Errorf("martian: failed to copy CONNECT tunnel: %v", err)
		}

		log.Debugf("martian: CONNECT tunnel finished copying")
		donec <- true
	}

	donec := make(chan bool, 2)
	go copySync(cbw, brw, donec)
	go copySync(brw, cbr, donec)

	log.Debugf("martian: established CONNECT tunnel, proxying traffic")
	<-donec
	<-donec
	log.Debugf("martian: closed CONNECT tunnel")

	return errClose
}

// handle reads one request from brw and processes it.
//
// It marks the session secure for TLS connections, including traffic-shaped
// ones wrapping TLS, and delegates CONNECT to handleConnectRequest. Any other
// request runs through the request modifier, the upstream round trip (a
// failure becomes a 502) and the response modifier. The response is then
// written back to the client, with traffic shaping configured when a shaping
// URL regex matches. It returns errClose when the connection should not be
// reused.
func (p *Proxy) handle(ctx *Context, conn net.Conn, brw *bufio.ReadWriter) error {
	log.Debugf("martian: waiting for request: %v", conn.RemoteAddr())

	req, err := p.readRequest(ctx, conn, brw)
	if err != nil {
		return err
	}
	defer req.Body.Close()

	session := ctx.Session()
	ctx, err = withSession(session)
	if err != nil {
		log.Errorf("martian: failed to build new context: %v", err)
		return err
	}

	link(req, ctx)
	defer unlink(req)

	if tsconn, ok := conn.(*trafficshape.Conn); ok {
		wrconn := tsconn.GetWrappedConn()
		if sconn, ok := wrconn.(*tls.Conn); ok {
			session.MarkSecure()

			cs := sconn.ConnectionState()
			req.TLS = &cs
		}
	}

	if tconn, ok := conn.(*tls.Conn); ok {
		session.MarkSecure()

		cs := tconn.ConnectionState()
		req.TLS = &cs
	}

	req.URL.Scheme = "http"
	if session.IsSecure() {
		log.Infof("martian: forcing HTTPS inside secure session")
		req.URL.Scheme = "https"
	}

	req.RemoteAddr = conn.RemoteAddr().String()
	if req.URL.Host == "" {
		req.URL.Host = req.Host
	}

	if req.Method == "CONNECT" {
		return p.handleConnectRequest(ctx, req, session, brw, conn)
	}

	// Not a CONNECT request
	if err := p.reqmod.ModifyRequest(req); err != nil {
		log.Errorf("martian: error modifying request: %v", err)
		proxyutil.Warning(req.Header, err)
	}
	if session.Hijacked() {
		return nil
	}

	// perform the HTTP roundtrip
	res, err := p.roundTrip(ctx, req)
	if err != nil {
		log.Errorf("martian: failed to round trip: %v", err)
		res = proxyutil.NewResponse(502, nil, req)
		proxyutil.Warning(res.Header, err)
	}
	defer res.Body.Close()

	// set request to original request manually, res.Request may be changed in transport.
	// see https://github.com/google/martian/issues/298
	res.Request = req

	if err := p.resmod.ModifyResponse(res); err != nil {
		log.Errorf("martian: error modifying response: %v", err)
		proxyutil.Warning(res.Header, err)
	}
	if session.Hijacked() {
		log.Infof("martian: connection hijacked by response modifier")
		return nil
	}

	var closing error
	if req.Close || res.Close || p.Closing() {
		log.Debugf("martian: received close request: %v", req.RemoteAddr)
		res.Close = true
		closing = errClose
	}

	// check if conn is a traffic shaped connection.
	if ptsconn, ok := conn.(*trafficshape.Conn); ok {
		ptsconn.Context = &trafficshape.Context{}
		// The URL does not change while matching, so render it only once.
		reqURL := req.URL.String()
		// Check if the request URL matches any URLRegex in Shapes. If so, set the
		// connection's Context with the required information, so that the
		// Write() method of the Conn has access to it.
		for urlregex, buckets := range ptsconn.LocalBuckets {
			if match, _ := regexp.MatchString(urlregex, reqURL); match {
				if rangeStart := proxyutil.GetRangeStart(res); rangeStart > -1 {
					dump, err := httputil.DumpResponse(res, false)
					if err != nil {
						return err
					}
					ptsconn.Context = &trafficshape.Context{
						Shaping:            true,
						Buckets:            buckets,
						GlobalBucket:       ptsconn.GlobalBuckets[urlregex],
						URLRegex:           urlregex,
						RangeStart:         rangeStart,
						ByteOffset:         rangeStart,
						HeaderLen:          int64(len(dump)),
						HeaderBytesWritten: 0,
					}
					// Get the next action to perform, if there.
					ptsconn.Context.NextActionInfo = ptsconn.GetNextActionFromByte(rangeStart)
					// Check if response lies in a throttled byte range.
					ptsconn.Context.ThrottleContext = ptsconn.GetCurrentThrottle(rangeStart)
					if ptsconn.Context.ThrottleContext.ThrottleNow {
						ptsconn.Context.Buckets.WriteBucket.SetCapacity(
							ptsconn.Context.ThrottleContext.Bandwidth)
					}
					log.Infof(
						"trafficshape: Request %s with Range Start: %d matches a Shaping request %s. Enforcing Traffic shaping.",
						req.URL, rangeStart, urlregex)
				}
				break
			}
		}
	}

	err = res.Write(brw)
	if err != nil {
		log.Errorf("martian: got error while writing response back to client: %v", err)
		if _, ok := err.(*trafficshape.ErrForceClose); ok {
			closing = errClose
		}
	}
	err = brw.Flush()
	if err != nil {
		log.Errorf("martian: got error while flushing response back to client: %v", err)
		if _, ok := err.(*trafficshape.ErrForceClose); ok {
			closing = errClose
		}
	}
	return closing
}

// A peekedConn subverts the net.Conn.Read implementation, primarily so that
// sniffed bytes can be transparently prepended.
type peekedConn struct {
	net.Conn
	r io.Reader
}

// Read allows control over the embedded net.Conn's read data. By using an
// io.MultiReader one can read from a conn, and then replace what they read, to
// be read again.
func (c *peekedConn) Read(buf []byte) (int, error) { return c.r.Read(buf) }

// roundTrip sends req upstream through p.roundTripper. If the request's
// context asks to skip the round trip, it instead returns a 200 response built
// locally by proxyutil.NewResponse.
func (p *Proxy) roundTrip(ctx *Context, req *http.Request) (*http.Response, error) {
	if ctx.SkippingRoundTrip() {
		log.Debugf("martian: skipping round trip")
		return proxyutil.NewResponse(200, nil, req), nil
	}

	return p.roundTripper.RoundTrip(req)
}

// connect opens the upstream side of a CONNECT tunnel.
//
// With a downstream proxy configured, it dials that proxy, forwards req to it
// and returns the proxy's response along with the connection. Bytes the
// downstream proxy sent after its response headers may already sit in the
// read buffer. When they do, the returned conn replays them before reading
// from the network, so they are not lost. If writing the request or reading
// the response fails, the dialed connection is closed and the error returned.
//
// Without a downstream proxy, it dials req.URL.Host directly and returns a
// locally built 200 response.
func (p *Proxy) connect(req *http.Request) (*http.Response, net.Conn, error) {
	if p.proxyURL != nil {
		log.Debugf("martian: CONNECT with downstream proxy: %s", p.proxyURL.Host)

		conn, err := p.dial("tcp", p.proxyURL.Host)
		if err != nil {
			return nil, nil, err
		}
		pbw := bufio.NewWriter(conn)
		pbr := bufio.NewReader(conn)

		if err := req.Write(pbw); err != nil {
			conn.Close()
			return nil, nil, err
		}
		if err := pbw.Flush(); err != nil {
			conn.Close()
			return nil, nil, err
		}

		res, err := http.ReadResponse(pbr, req)
		if err != nil {
			conn.Close()
			return nil, nil, err
		}

		// The caller wraps the returned conn in a fresh bufio.Reader, so any
		// bytes already buffered in pbr must be served through it first.
		if pbr.Buffered() > 0 {
			return res, &peekedConn{conn, pbr}, nil
		}
		return res, conn, nil
	}

	log.Debugf("martian: CONNECT to host directly: %s", req.URL.Host)

	conn, err := p.dial("tcp", req.URL.Host)
	if err != nil {
		return nil, nil, err
	}

	return proxyutil.NewResponse(200, nil, req), conn, nil
}