github.com/hellofresh/janus@v0.0.0-20230925145208-ce8de8183c67/pkg/proxy/reverse_proxy.go (about)

     1  package proxy
     2  
     3  import (
     4  	"fmt"
     5  	"net/http"
     6  	"net/http/httputil"
     7  	"net/url"
     8  	"os"
     9  	"strings"
    10  
    11  	"github.com/go-chi/chi"
    12  	"github.com/hellofresh/stats-go/bucket"
    13  	"github.com/hellofresh/stats-go/client"
    14  	log "github.com/sirupsen/logrus"
    15  	"go.opencensus.io/tag"
    16  	"go.opencensus.io/trace"
    17  
    18  	"github.com/hellofresh/janus/pkg/observability"
    19  	"github.com/hellofresh/janus/pkg/proxy/balancer"
    20  	"github.com/hellofresh/janus/pkg/router"
    21  )
    22  
    23  const (
    24  	statsSection = "upstream"
    25  )
    26  
    27  // NewBalancedReverseProxy creates a reverse proxy that is load balanced
    28  func NewBalancedReverseProxy(def *Definition, balancer balancer.Balancer, statsClient client.Client) *httputil.ReverseProxy {
    29  	return &httputil.ReverseProxy{
    30  		Director: createDirector(def, balancer, statsClient),
    31  	}
    32  }
    33  
    34  func createDirector(proxyDefinition *Definition, balancer balancer.Balancer, statsClient client.Client) func(req *http.Request) {
    35  	paramNameExtractor := router.NewListenPathParamNameExtractor()
    36  	matcher := router.NewListenPathMatcher()
    37  
    38  	return func(req *http.Request) {
    39  		upstream, err := balancer.Elect(proxyDefinition.Upstreams.Targets.ToBalancerTargets())
    40  		if err != nil {
    41  			log.WithError(err).Error("Could not elect one upstream")
    42  			return
    43  		}
    44  
    45  		targetURL := upstream.Target
    46  
    47  		paramNames := paramNameExtractor.Extract(targetURL)
    48  		parametrizedPath, err := applyParameters(req, targetURL, paramNames)
    49  		if err != nil {
    50  			log.WithError(err).Warn("Unable to extract param from request")
    51  		} else {
    52  			targetURL = parametrizedPath
    53  		}
    54  
    55  		log.WithField("target", targetURL).Debug("Target upstream elected")
    56  
    57  		target, err := url.Parse(targetURL)
    58  		if err != nil {
    59  			log.WithError(err).WithField("upstream_url", targetURL).Error("Could not parse the target URL")
    60  			return
    61  		}
    62  
    63  		originalURI := req.RequestURI
    64  		targetQuery := target.RawQuery
    65  		req.URL.Scheme = target.Scheme
    66  		req.URL.Host = target.Host
    67  		path := target.Path
    68  
    69  		if proxyDefinition.AppendPath {
    70  			log.Debug("Appending listen path to the target url")
    71  			path = singleJoiningSlash(target.Path, req.URL.Path)
    72  		}
    73  
    74  		if proxyDefinition.StripPath {
    75  			path = singleJoiningSlash(target.Path, req.URL.Path)
    76  			listenPath := matcher.Extract(proxyDefinition.ListenPath)
    77  
    78  			log.WithField("listen_path", listenPath).Debug("Stripping listen path")
    79  			if len(paramNames) > 0 {
    80  				path = stripPathWithParams(req, path, listenPath, paramNames)
    81  			} else {
    82  				path = strings.Replace(path, listenPath, "", 1)
    83  			}
    84  			if !strings.HasSuffix(target.Path, "/") && strings.HasSuffix(path, "/") {
    85  				path = path[:len(path)-1]
    86  			}
    87  		}
    88  
    89  		log.WithField("path", path).Debug("Upstream Path")
    90  		req.URL.Path = path
    91  
    92  		// This is very important to avoid problems with ssl verification for the HOST header
    93  		if proxyDefinition.PreserveHost {
    94  			log.Debug("Preserving the host header")
    95  		} else {
    96  			req.Host = target.Host
    97  		}
    98  
    99  		if targetQuery == "" || req.URL.RawQuery == "" {
   100  			req.URL.RawQuery = targetQuery + req.URL.RawQuery
   101  		} else {
   102  			req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
   103  		}
   104  
   105  		// Since director modifies cloned request there is no way (or I just did not find one)
   106  		// to get upstream from logger middleware, so we're logging original request and upstream here
   107  		// with the same logging level. Original request is here to match two log messages in case
   108  		// RequestID is not enabled.
   109  		log.WithFields(log.Fields{
   110  			"request":          originalURI,
   111  			"request-id":       observability.RequestIDFromContext(req.Context()),
   112  			"upstream-host":    req.URL.Host,
   113  			"upstream-request": req.URL.RequestURI(),
   114  		}).Info("Proxying request to the following upstream")
   115  
   116  		statsClient.TrackMetric(statsSection, bucket.MetricOperation{req.Host})
   117  
   118  		// Add additional trace attributes
   119  		addTraceAttributes(req)
   120  
   121  		// Insert additional tags
   122  		ctx, _ := tag.New(req.Context(), tag.Insert(observability.KeyUpstreamPath, upstream.Target))
   123  		*req = *req.WithContext(ctx)
   124  	}
   125  }
   126  
   127  func addTraceAttributes(req *http.Request) {
   128  	ctx := req.Context()
   129  	span := trace.FromContext(ctx)
   130  	if span == nil {
   131  		return
   132  	}
   133  
   134  	host, err := os.Hostname()
   135  	if host == "" || err != nil {
   136  		log.WithError(err).Debug("Failed to get host name")
   137  		host = "unknown"
   138  	}
   139  
   140  	span.AddAttributes(
   141  		trace.StringAttribute("http.host", host),
   142  		trace.StringAttribute("http.referrer", req.Referer()),
   143  		trace.StringAttribute("http.remote_address", req.RemoteAddr),
   144  		trace.StringAttribute("request.id", observability.RequestIDFromContext(ctx)),
   145  	)
   146  }
   147  
   148  func applyParameters(req *http.Request, path string, paramNames []string) (string, error) {
   149  	for _, paramName := range paramNames {
   150  		paramValue := chi.URLParam(req, paramName)
   151  
   152  		if len(paramValue) == 0 {
   153  			return "", fmt.Errorf("unable to extract {%s} from request", paramName)
   154  		}
   155  
   156  		path = strings.Replace(
   157  			path,
   158  			fmt.Sprintf("{%s}", paramName),
   159  			paramValue,
   160  			-1,
   161  		)
   162  	}
   163  
   164  	return path, nil
   165  }
   166  
   167  func singleJoiningSlash(a, b string) string {
   168  	a = cleanSlashes(a)
   169  	b = cleanSlashes(b)
   170  
   171  	aSlash := strings.HasSuffix(a, "/")
   172  	bSlash := strings.HasPrefix(b, "/")
   173  
   174  	switch {
   175  	case aSlash && bSlash:
   176  		return a + b[1:]
   177  	case !aSlash && !bSlash:
   178  		if len(b) > 0 {
   179  			return a + "/" + b
   180  		}
   181  		return a
   182  	}
   183  	return a + b
   184  }
   185  
   186  func cleanSlashes(a string) string {
   187  	endSlash := strings.HasSuffix(a, "//")
   188  	startSlash := strings.HasPrefix(a, "//")
   189  
   190  	if startSlash {
   191  		a = "/" + strings.TrimPrefix(a, "//")
   192  	}
   193  
   194  	if endSlash {
   195  		a = strings.TrimSuffix(a, "//") + "/"
   196  	}
   197  
   198  	return a
   199  }
   200  
   201  // chiURLParam is created to allow for mocking of the chi.URLParam function.
   202  // This allowed for writing a quick unit test to check that the logic of the function works without having to deal with chi's context requirements.
   203  var chiURLParam = chi.URLParam
   204  // stripPathWithParams is intended to properly strip the listen path from the requested path when named parameters are used.
   205  // From left to right, it removes the first instance of each section of the listenPath and each paramName from the path.
   206  func stripPathWithParams(req *http.Request, path string, listenPath string, paramNames []string) string {
   207  	remove := strings.Split(listenPath, "/")
   208  	for i := 0; i < len(paramNames); i ++ {
   209  		remove = append(remove, chiURLParam(req, paramNames[i]))
   210  	}
   211  	for i := 1; i < len(remove); i++ {
   212  		path = strings.Replace(path, "/" + remove[i], "", 1)
   213  	}
   214  	return path
   215  }