github.com/hellofresh/janus@v0.0.0-20230925145208-ce8de8183c67/pkg/proxy/reverse_proxy.go (about) 1 package proxy 2 3 import ( 4 "fmt" 5 "net/http" 6 "net/http/httputil" 7 "net/url" 8 "os" 9 "strings" 10 11 "github.com/go-chi/chi" 12 "github.com/hellofresh/stats-go/bucket" 13 "github.com/hellofresh/stats-go/client" 14 log "github.com/sirupsen/logrus" 15 "go.opencensus.io/tag" 16 "go.opencensus.io/trace" 17 18 "github.com/hellofresh/janus/pkg/observability" 19 "github.com/hellofresh/janus/pkg/proxy/balancer" 20 "github.com/hellofresh/janus/pkg/router" 21 ) 22 23 const ( 24 statsSection = "upstream" 25 ) 26 27 // NewBalancedReverseProxy creates a reverse proxy that is load balanced 28 func NewBalancedReverseProxy(def *Definition, balancer balancer.Balancer, statsClient client.Client) *httputil.ReverseProxy { 29 return &httputil.ReverseProxy{ 30 Director: createDirector(def, balancer, statsClient), 31 } 32 } 33 34 func createDirector(proxyDefinition *Definition, balancer balancer.Balancer, statsClient client.Client) func(req *http.Request) { 35 paramNameExtractor := router.NewListenPathParamNameExtractor() 36 matcher := router.NewListenPathMatcher() 37 38 return func(req *http.Request) { 39 upstream, err := balancer.Elect(proxyDefinition.Upstreams.Targets.ToBalancerTargets()) 40 if err != nil { 41 log.WithError(err).Error("Could not elect one upstream") 42 return 43 } 44 45 targetURL := upstream.Target 46 47 paramNames := paramNameExtractor.Extract(targetURL) 48 parametrizedPath, err := applyParameters(req, targetURL, paramNames) 49 if err != nil { 50 log.WithError(err).Warn("Unable to extract param from request") 51 } else { 52 targetURL = parametrizedPath 53 } 54 55 log.WithField("target", targetURL).Debug("Target upstream elected") 56 57 target, err := url.Parse(targetURL) 58 if err != nil { 59 log.WithError(err).WithField("upstream_url", targetURL).Error("Could not parse the target URL") 60 return 61 } 62 63 originalURI := req.RequestURI 64 targetQuery := target.RawQuery 65 req.URL.Scheme = target.Scheme 66 req.URL.Host = target.Host 67 path := target.Path 68 69 if proxyDefinition.AppendPath { 70 log.Debug("Appending listen path to the target url") 71 path = singleJoiningSlash(target.Path, req.URL.Path) 72 } 73 74 if proxyDefinition.StripPath { 75 path = singleJoiningSlash(target.Path, req.URL.Path) 76 listenPath := matcher.Extract(proxyDefinition.ListenPath) 77 78 log.WithField("listen_path", listenPath).Debug("Stripping listen path") 79 if len(paramNames) > 0 { 80 path = stripPathWithParams(req, path, listenPath, paramNames) 81 } else { 82 path = strings.Replace(path, listenPath, "", 1) 83 } 84 if !strings.HasSuffix(target.Path, "/") && strings.HasSuffix(path, "/") { 85 path = path[:len(path)-1] 86 } 87 } 88 89 log.WithField("path", path).Debug("Upstream Path") 90 req.URL.Path = path 91 92 // This is very important to avoid problems with ssl verification for the HOST header 93 if proxyDefinition.PreserveHost { 94 log.Debug("Preserving the host header") 95 } else { 96 req.Host = target.Host 97 } 98 99 if targetQuery == "" || req.URL.RawQuery == "" { 100 req.URL.RawQuery = targetQuery + req.URL.RawQuery 101 } else { 102 req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery 103 } 104 105 // Since director modifies cloned request there is no way (or I just did not find one) 106 // to get upstream from logger middleware, so we're logging original request and upstream here 107 // with the same logging level. Original request is here to match two log messages in case 108 // RequestID is not enabled. 109 log.WithFields(log.Fields{ 110 "request": originalURI, 111 "request-id": observability.RequestIDFromContext(req.Context()), 112 "upstream-host": req.URL.Host, 113 "upstream-request": req.URL.RequestURI(), 114 }).Info("Proxying request to the following upstream") 115 116 statsClient.TrackMetric(statsSection, bucket.MetricOperation{req.Host}) 117 118 // Add additional trace attributes 119 addTraceAttributes(req) 120 121 // Insert additional tags 122 ctx, _ := tag.New(req.Context(), tag.Insert(observability.KeyUpstreamPath, upstream.Target)) 123 *req = *req.WithContext(ctx) 124 } 125 } 126 127 func addTraceAttributes(req *http.Request) { 128 ctx := req.Context() 129 span := trace.FromContext(ctx) 130 if span == nil { 131 return 132 } 133 134 host, err := os.Hostname() 135 if host == "" || err != nil { 136 log.WithError(err).Debug("Failed to get host name") 137 host = "unknown" 138 } 139 140 span.AddAttributes( 141 trace.StringAttribute("http.host", host), 142 trace.StringAttribute("http.referrer", req.Referer()), 143 trace.StringAttribute("http.remote_address", req.RemoteAddr), 144 trace.StringAttribute("request.id", observability.RequestIDFromContext(ctx)), 145 ) 146 } 147 148 func applyParameters(req *http.Request, path string, paramNames []string) (string, error) { 149 for _, paramName := range paramNames { 150 paramValue := chi.URLParam(req, paramName) 151 152 if len(paramValue) == 0 { 153 return "", fmt.Errorf("unable to extract {%s} from request", paramName) 154 } 155 156 path = strings.Replace( 157 path, 158 fmt.Sprintf("{%s}", paramName), 159 paramValue, 160 -1, 161 ) 162 } 163 164 return path, nil 165 } 166 167 func singleJoiningSlash(a, b string) string { 168 a = cleanSlashes(a) 169 b = cleanSlashes(b) 170 171 aSlash := strings.HasSuffix(a, "/") 172 bSlash := strings.HasPrefix(b, "/") 173 174 switch { 175 case aSlash && bSlash: 176 return a + b[1:] 177 case !aSlash && !bSlash: 178 if len(b) > 0 { 179 return a + "/" + b 180 } 181 return a 182 } 183 return a + b 184 } 185 186 func cleanSlashes(a string) string { 187 endSlash := strings.HasSuffix(a, "//") 188 startSlash := strings.HasPrefix(a, "//") 189 190 if startSlash { 191 a = "/" + strings.TrimPrefix(a, "//") 192 } 193 194 if endSlash { 195 a = strings.TrimSuffix(a, "//") + "/" 196 } 197 198 return a 199 } 200 201 // chiURLParam is created to allow for mocking of the chi.URLParam function. 202 // This allowed for writing a quick unit test to check that the logic of the function works without having to deal with chi's context requirements. 203 var chiURLParam = chi.URLParam 204 // stripPathWithParams is intended to properly strip the listen path from the requested path when named parameters are used. 205 // From left to right, it removes the first instance of each section of the listenPath and each paramName from the path. 206 func stripPathWithParams(req *http.Request, path string, listenPath string, paramNames []string) string { 207 remove := strings.Split(listenPath, "/") 208 for i := 0; i < len(paramNames); i ++ { 209 remove = append(remove, chiURLParam(req, paramNames[i])) 210 } 211 for i := 1; i < len(remove); i++ { 212 path = strings.Replace(path, "/" + remove[i], "", 1) 213 } 214 return path 215 }