storj.io/minio@v0.0.0-20230509071714-0cbc90f649b1/pkg/handlers/forwarder.go (about)

     1  /*
     2   * MinIO Cloud Storage, (C) 2018-2019 MinIO, Inc.
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package handlers
    18  
import (
	"context"
	"errors"
	"net"
	"net/http"
	"net/http/httputil"
	"net/url"
	"strings"
	"time"
)
    28  
    29  const defaultFlushInterval = time.Duration(100) * time.Millisecond
    30  
    31  // Forwarder forwards all incoming HTTP requests to configured transport.
    32  type Forwarder struct {
    33  	RoundTripper http.RoundTripper
    34  	PassHost     bool
    35  	Logger       func(error)
    36  	ErrorHandler func(http.ResponseWriter, *http.Request, error)
    37  
    38  	// internal variables
    39  	rewriter *headerRewriter
    40  }
    41  
    42  // NewForwarder creates an instance of Forwarder based on the provided list of configuration options
    43  func NewForwarder(f *Forwarder) *Forwarder {
    44  	f.rewriter = &headerRewriter{}
    45  	if f.RoundTripper == nil {
    46  		f.RoundTripper = http.DefaultTransport
    47  	}
    48  
    49  	return f
    50  }
    51  
    52  // ServeHTTP forwards HTTP traffic using the configured transport
    53  func (f *Forwarder) ServeHTTP(w http.ResponseWriter, inReq *http.Request) {
    54  	outReq := new(http.Request)
    55  	*outReq = *inReq // includes shallow copies of maps, but we handle this in Director
    56  
    57  	revproxy := httputil.ReverseProxy{
    58  		Director: func(req *http.Request) {
    59  			f.modifyRequest(req, inReq.URL)
    60  		},
    61  		Transport:     f.RoundTripper,
    62  		FlushInterval: defaultFlushInterval,
    63  		ErrorHandler:  f.customErrHandler,
    64  	}
    65  
    66  	if f.ErrorHandler != nil {
    67  		revproxy.ErrorHandler = f.ErrorHandler
    68  	}
    69  
    70  	revproxy.ServeHTTP(w, outReq)
    71  }
    72  
    73  // customErrHandler is originally implemented to avoid having the following error
    74  //    `http: proxy error: context canceled` printed by Golang
    75  func (f *Forwarder) customErrHandler(w http.ResponseWriter, r *http.Request, err error) {
    76  	if f.Logger != nil && err != context.Canceled {
    77  		f.Logger(err)
    78  	}
    79  	w.WriteHeader(http.StatusBadGateway)
    80  }
    81  
    82  func (f *Forwarder) getURLFromRequest(req *http.Request) *url.URL {
    83  	// If the Request was created by Go via a real HTTP request,  RequestURI will
    84  	// contain the original query string. If the Request was created in code, RequestURI
    85  	// will be empty, and we will use the URL object instead
    86  	u := req.URL
    87  	if req.RequestURI != "" {
    88  		parsedURL, err := url.ParseRequestURI(req.RequestURI)
    89  		if err == nil {
    90  			u = parsedURL
    91  		}
    92  	}
    93  	return u
    94  }
    95  
    96  // copyURL provides update safe copy by avoiding shallow copying User field
    97  func copyURL(i *url.URL) *url.URL {
    98  	out := *i
    99  	if i.User != nil {
   100  		u := *i.User
   101  		out.User = &u
   102  	}
   103  	return &out
   104  }
   105  
   106  // Modify the request to handle the target URL
   107  func (f *Forwarder) modifyRequest(outReq *http.Request, target *url.URL) {
   108  	outReq.URL = copyURL(outReq.URL)
   109  	outReq.URL.Scheme = target.Scheme
   110  	outReq.URL.Host = target.Host
   111  
   112  	u := f.getURLFromRequest(outReq)
   113  
   114  	outReq.URL.Path = u.Path
   115  	outReq.URL.RawPath = u.RawPath
   116  	outReq.URL.RawQuery = u.RawQuery
   117  	outReq.RequestURI = "" // Outgoing request should not have RequestURI
   118  
   119  	// Do not pass client Host header unless requested.
   120  	if !f.PassHost {
   121  		outReq.Host = target.Host
   122  	}
   123  
   124  	// TODO: only supports HTTP 1.1 for now.
   125  	outReq.Proto = "HTTP/1.1"
   126  	outReq.ProtoMajor = 1
   127  	outReq.ProtoMinor = 1
   128  
   129  	f.rewriter.Rewrite(outReq)
   130  
   131  	// Disable closeNotify when method GET for http pipelining
   132  	if outReq.Method == http.MethodGet {
   133  		quietReq := outReq.WithContext(context.Background())
   134  		*outReq = *quietReq
   135  	}
   136  }
   137  
   138  // headerRewriter is responsible for removing hop-by-hop headers and setting forwarding headers
   139  type headerRewriter struct{}
   140  
   141  // Clean up IP in case if it is ipv6 address and it has {zone} information in it, like
   142  // "[fe80::d806:a55d:eb1b:49cc%vEthernet (vmxnet3 Ethernet Adapter - Virtual Switch)]:64692"
   143  func ipv6fix(clientIP string) string {
   144  	return strings.Split(clientIP, "%")[0]
   145  }
   146  
   147  func (rw *headerRewriter) Rewrite(req *http.Request) {
   148  	if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
   149  		clientIP = ipv6fix(clientIP)
   150  		if req.Header.Get(xRealIP) == "" {
   151  			req.Header.Set(xRealIP, clientIP)
   152  		}
   153  	}
   154  
   155  	xfProto := req.Header.Get(xForwardedProto)
   156  	if xfProto == "" {
   157  		if req.TLS != nil {
   158  			req.Header.Set(xForwardedProto, "https")
   159  		} else {
   160  			req.Header.Set(xForwardedProto, "http")
   161  		}
   162  	}
   163  
   164  	if xfPort := req.Header.Get(xForwardedPort); xfPort == "" {
   165  		req.Header.Set(xForwardedPort, forwardedPort(req))
   166  	}
   167  
   168  	if xfHost := req.Header.Get(xForwardedHost); xfHost == "" && req.Host != "" {
   169  		req.Header.Set(xForwardedHost, req.Host)
   170  	}
   171  }
   172  
   173  func forwardedPort(req *http.Request) string {
   174  	if req == nil {
   175  		return ""
   176  	}
   177  
   178  	if _, port, err := net.SplitHostPort(req.Host); err == nil && port != "" {
   179  		return port
   180  	}
   181  
   182  	if req.TLS != nil {
   183  		return "443"
   184  	}
   185  
   186  	return "80"
   187  }