github.com/xmidt-org/webpa-common@v1.11.9/xhttp/fanout/handler.go (about) 1 package fanout 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "io/ioutil" 8 "net/http" 9 "net/url" 10 11 "github.com/go-kit/kit/log" 12 "github.com/go-kit/kit/log/level" 13 gokithttp "github.com/go-kit/kit/transport/http" 14 "github.com/xmidt-org/webpa-common/logging" 15 "github.com/xmidt-org/webpa-common/tracing" 16 "github.com/xmidt-org/webpa-common/tracing/tracinghttp" 17 ) 18 19 var ( 20 errNoFanoutURLs = errors.New("No fanout URLs") 21 errBadTransactor = errors.New("Transactor did not conform to stdlib API") 22 ) 23 24 // Option provides a single configuration option for a fanout Handler 25 type Option func(*Handler) 26 27 // WithShouldTerminate configures a custom termination predicate for the fanout. If terminate 28 // is nil, DefaultShouldTerminate is used. 29 func WithShouldTerminate(terminate ShouldTerminateFunc) Option { 30 return func(h *Handler) { 31 if terminate != nil { 32 h.shouldTerminate = terminate 33 } else { 34 h.shouldTerminate = DefaultShouldTerminate 35 } 36 } 37 } 38 39 // WithErrorEncoder configures a custom error encoder for errors that occur during fanout setup. 40 // If encoder is nil, go-kit's DefaultErrorEncoder is used. 41 func WithErrorEncoder(encoder gokithttp.ErrorEncoder) Option { 42 return func(h *Handler) { 43 if encoder != nil { 44 h.errorEncoder = encoder 45 } else { 46 h.errorEncoder = gokithttp.DefaultErrorEncoder 47 } 48 } 49 } 50 51 // WithTransactor configures a custom HTTP client transaction function. If transactor is nil, 52 // http.DefaultClient.Do is used as the transactor. 53 func WithTransactor(transactor func(*http.Request) (*http.Response, error)) Option { 54 return func(h *Handler) { 55 if transactor != nil { 56 h.transactor = transactor 57 } else { 58 h.transactor = http.DefaultClient.Do 59 } 60 } 61 } 62 63 // WithFanoutBefore adds zero or more request functions that will tailor each fanout request. 64 func WithFanoutBefore(before ...FanoutRequestFunc) Option { 65 return func(h *Handler) { 66 h.before = append(h.before, before...) 67 } 68 } 69 70 // WithClientBefore adds zero or more go-kit RequestFunc functions that will be applied to 71 // each fanout request. 72 func WithClientBefore(before ...gokithttp.RequestFunc) Option { 73 return func(h *Handler) { 74 for _, rf := range before { 75 h.before = append( 76 h.before, 77 func(ctx context.Context, _, fanout *http.Request, _ []byte) (context.Context, error) { 78 return rf(ctx, fanout), nil 79 }, 80 ) 81 } 82 } 83 } 84 85 // WithFanoutAfter adds zero or more response functions that are invoked to tailor the response 86 // when a successful (i.e. terminating) fanout response is received. 87 func WithFanoutAfter(after ...FanoutResponseFunc) Option { 88 return func(h *Handler) { 89 h.after = append(h.after, after...) 90 } 91 } 92 93 // WithClientAfter allows zero or more go-kit ClientResponseFuncs to be used as fanout after functions. 94 func WithClientAfter(after ...gokithttp.ClientResponseFunc) Option { 95 return func(h *Handler) { 96 for _, rf := range after { 97 h.after = append( 98 h.after, 99 func(ctx context.Context, response http.ResponseWriter, result Result) context.Context { 100 return rf(ctx, result.Response) 101 }, 102 ) 103 } 104 } 105 } 106 107 // WithFanoutFailure adds zero or more response functions that are invoked to tailor the response 108 // when a failed fanout responses have been received. 109 func WithFanoutFailure(failure ...FanoutResponseFunc) Option { 110 return func(h *Handler) { 111 h.failure = append(h.failure, failure...) 112 } 113 } 114 115 // WithClientFailure allows zero or more go-kit ClientResponseFuncs to be used as fanout failure functions. 116 func WithClientFailure(failure ...gokithttp.ClientResponseFunc) Option { 117 return func(h *Handler) { 118 for _, rf := range failure { 119 h.failure = append( 120 h.failure, 121 func(ctx context.Context, response http.ResponseWriter, result Result) context.Context { 122 return rf(ctx, result.Response) 123 }, 124 ) 125 } 126 } 127 } 128 129 // WithConfiguration uses a set of (typically injected) fanout configuration options to configure a Handler. 130 // Use of this option will not override the configured Endpoints instance. 131 func WithConfiguration(c Configuration) Option { 132 return func(h *Handler) { 133 WithTransactor(NewTransactor(c))(h) 134 135 authorization := c.authorization() 136 if len(authorization) > 0 { 137 WithClientBefore(gokithttp.SetRequestHeader("Authorization", authorization))(h) 138 } 139 } 140 } 141 142 // Handler is the http.Handler that fans out HTTP requests using the configured Endpoints strategy. 143 type Handler struct { 144 endpoints Endpoints 145 errorEncoder gokithttp.ErrorEncoder 146 before []FanoutRequestFunc 147 after []FanoutResponseFunc 148 failure []FanoutResponseFunc 149 shouldTerminate ShouldTerminateFunc 150 transactor func(*http.Request) (*http.Response, error) 151 } 152 153 // New creates a fanout Handler. The Endpoints strategy is required, and this constructor function will 154 // panic if it is nil. 155 // 156 // By default, all fanout requests have the same HTTP method as the original request, but no body is set.. Clients must use the OriginalBody 157 // strategy to set the original request's body on each fanout request. 158 func New(e Endpoints, options ...Option) *Handler { 159 if e == nil { 160 panic("An Endpoints strategy is required") 161 } 162 163 h := &Handler{ 164 endpoints: e, 165 errorEncoder: gokithttp.DefaultErrorEncoder, 166 shouldTerminate: DefaultShouldTerminate, 167 transactor: http.DefaultClient.Do, 168 } 169 170 for _, o := range options { 171 o(h) 172 } 173 174 return h 175 } 176 177 // newFanoutRequests uses the Endpoints strategy and builds (1) HTTP request for each endpoint. The configured 178 // FanoutRequestFunc options are used to build each request. This method returns an error if no endpoints were returned 179 // by the strategy or if an error reading the original request body occurred. 180 func (h *Handler) newFanoutRequests(fanoutCtx context.Context, original *http.Request) ([]*http.Request, error) { 181 body, err := ioutil.ReadAll(original.Body) 182 if err != nil { 183 return nil, err 184 } 185 186 urls, err := h.endpoints.FanoutURLs(original) 187 if err != nil { 188 return nil, err 189 } else if len(urls) == 0 { 190 return nil, errNoFanoutURLs 191 } 192 193 requests := make([]*http.Request, len(urls)) 194 for i := 0; i < len(urls); i++ { 195 fanout := &http.Request{ 196 Method: original.Method, 197 URL: urls[i], 198 Proto: "HTTP/1.1", 199 ProtoMajor: 1, 200 ProtoMinor: 1, 201 Header: make(http.Header), 202 Host: urls[i].Host, 203 } 204 205 endpointCtx := fanoutCtx 206 var err error 207 for _, rf := range h.before { 208 endpointCtx, err = rf(endpointCtx, original, fanout, body) 209 if err != nil { 210 return nil, err 211 } 212 } 213 214 requests[i] = fanout.WithContext(endpointCtx) 215 } 216 217 return requests, nil 218 } 219 220 // execute performs a single fanout HTTP transaction and sends the result on a channel. This method is invoked 221 // as a goroutine. It takes care of draining the fanout's response prior to returning. 222 func (h *Handler) execute(logger log.Logger, spanner tracing.Spanner, results chan<- Result, request *http.Request) { 223 var ( 224 finisher = spanner.Start(request.URL.String()) 225 result = Result{ 226 Request: request, 227 } 228 ) 229 230 result.Response, result.Err = h.transactor(request) 231 switch { 232 case result.Response != nil: 233 result.StatusCode = result.Response.StatusCode 234 result.ContentType = result.Response.Header.Get("Content-Type") 235 236 var err error 237 if result.Body, err = ioutil.ReadAll(result.Response.Body); err != nil { 238 logger.Log(level.Key(), level.ErrorValue(), logging.MessageKey(), "error reading fanout response body", logging.ErrorKey(), err) 239 } 240 241 if err = result.Response.Body.Close(); err != nil { 242 logger.Log(level.Key(), level.ErrorValue(), logging.MessageKey(), "error closing fanout response body", logging.ErrorKey(), err) 243 } 244 245 case result.Err != nil: 246 result.Body = []byte(fmt.Sprintf("%s", result.Err)) 247 result.ContentType = "text/plain" 248 249 if ue, ok := result.Err.(*url.Error); ok && ue.Err != nil { 250 // unwrap the URL error 251 result.Err = ue.Err 252 } 253 254 if result.Err == context.Canceled || result.Err == context.DeadlineExceeded { 255 result.StatusCode = http.StatusGatewayTimeout 256 } else { 257 result.StatusCode = http.StatusServiceUnavailable 258 } 259 260 default: 261 // this "should" never happen, but just in case set a known status code 262 result.StatusCode = http.StatusServiceUnavailable 263 result.Err = errBadTransactor 264 result.Body = []byte(errBadTransactor.Error()) 265 result.ContentType = "text/plain" 266 } 267 268 result.Span = finisher(result.Err) 269 results <- result 270 } 271 272 // finish takes a terminating fanout result and writes the appropriate information to the top-level response. This method 273 // is only invoked when a particular fanout response terminates the fanout, i.e. is considered successful. 274 func (h *Handler) finish(logger log.Logger, response http.ResponseWriter, result Result, after []FanoutResponseFunc) { 275 ctx := result.Request.Context() 276 for _, rf := range after { 277 // NOTE: we don't use the context for anything here, 278 // but to preserve go-kit semantics we pass it to each after function 279 ctx = rf(ctx, response, result) 280 } 281 282 if len(result.Body) > 0 { 283 if len(result.ContentType) > 0 { 284 response.Header().Set("Content-Type", result.ContentType) 285 } else { 286 response.Header().Set("Content-Type", "application/octet-stream") 287 } 288 289 response.WriteHeader(result.StatusCode) 290 count, err := response.Write(result.Body) 291 if err != nil { 292 logger.Log(level.Key(), level.ErrorValue(), logging.MessageKey(), "wrote fanout response", "bytes", count, logging.ErrorKey(), err) 293 } else { 294 logger.Log(level.Key(), level.DebugValue(), logging.MessageKey(), "wrote fanout response", "bytes", count) 295 } 296 } else { 297 logger.Log(level.Key(), level.DebugValue(), logging.MessageKey(), "wrote fanout response", "statusCode", result.StatusCode) 298 response.WriteHeader(result.StatusCode) 299 } 300 } 301 302 func (h *Handler) ServeHTTP(response http.ResponseWriter, original *http.Request) { 303 var ( 304 fanoutCtx = original.Context() 305 logger = logging.GetLogger(fanoutCtx) 306 requests, err = h.newFanoutRequests(fanoutCtx, original) 307 ) 308 309 if err != nil { 310 logger.Log(level.Key(), level.ErrorValue(), logging.MessageKey(), "unable to create fanout", logging.ErrorKey(), err) 311 h.errorEncoder(fanoutCtx, err, response) 312 return 313 } 314 315 var ( 316 spanner = tracing.NewSpanner() 317 results = make(chan Result, len(requests)) 318 ) 319 320 for _, r := range requests { 321 go h.execute(logger, spanner, results, r) 322 } 323 324 statusCode := 0 325 var latestResponse Result 326 for i := 0; i < len(requests); i++ { 327 select { 328 case <-fanoutCtx.Done(): 329 logger.Log(level.Key(), level.ErrorValue(), logging.MessageKey(), "fanout operation canceled or timed out", "statusCode", http.StatusGatewayTimeout, "url", original.URL, logging.ErrorKey(), fanoutCtx.Err()) 330 response.WriteHeader(http.StatusGatewayTimeout) 331 return 332 333 case r := <-results: 334 tracinghttp.HeadersForSpans("", response.Header(), r.Span) 335 if r.Err != nil { 336 logger.Log(level.Key(), level.ErrorValue(), logging.MessageKey(), "fanout request complete", "statusCode", r.StatusCode, "url", r.Request.URL, logging.ErrorKey(), r.Err) 337 } else { 338 logger.Log(level.Key(), level.DebugValue(), logging.MessageKey(), "fanout request complete", "statusCode", r.StatusCode, "url", r.Request.URL) 339 } 340 341 if h.shouldTerminate(r) { 342 // this was a "success", so no reason to wait any longer 343 h.finish(logger, response, r, h.after) 344 return 345 } 346 347 if statusCode < r.StatusCode { 348 statusCode = r.StatusCode 349 latestResponse = r 350 } 351 } 352 } 353 354 logger.Log(level.Key(), level.ErrorValue(), logging.MessageKey(), "all fanout requests failed", "statusCode", statusCode, "url", original.URL) 355 h.finish(logger, response, latestResponse, h.failure) 356 }