github.com/xmidt-org/webpa-common@v1.11.9/xhttp/fanout/handler.go (about)

     1  package fanout
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"io/ioutil"
     8  	"net/http"
     9  	"net/url"
    10  
    11  	"github.com/go-kit/kit/log"
    12  	"github.com/go-kit/kit/log/level"
    13  	gokithttp "github.com/go-kit/kit/transport/http"
    14  	"github.com/xmidt-org/webpa-common/logging"
    15  	"github.com/xmidt-org/webpa-common/tracing"
    16  	"github.com/xmidt-org/webpa-common/tracing/tracinghttp"
    17  )
    18  
    19  var (
    20  	errNoFanoutURLs  = errors.New("No fanout URLs")
    21  	errBadTransactor = errors.New("Transactor did not conform to stdlib API")
    22  )
    23  
    24  // Option provides a single configuration option for a fanout Handler
    25  type Option func(*Handler)
    26  
    27  // WithShouldTerminate configures a custom termination predicate for the fanout.  If terminate
    28  // is nil, DefaultShouldTerminate is used.
    29  func WithShouldTerminate(terminate ShouldTerminateFunc) Option {
    30  	return func(h *Handler) {
    31  		if terminate != nil {
    32  			h.shouldTerminate = terminate
    33  		} else {
    34  			h.shouldTerminate = DefaultShouldTerminate
    35  		}
    36  	}
    37  }
    38  
    39  // WithErrorEncoder configures a custom error encoder for errors that occur during fanout setup.
    40  // If encoder is nil, go-kit's DefaultErrorEncoder is used.
    41  func WithErrorEncoder(encoder gokithttp.ErrorEncoder) Option {
    42  	return func(h *Handler) {
    43  		if encoder != nil {
    44  			h.errorEncoder = encoder
    45  		} else {
    46  			h.errorEncoder = gokithttp.DefaultErrorEncoder
    47  		}
    48  	}
    49  }
    50  
    51  // WithTransactor configures a custom HTTP client transaction function.  If transactor is nil,
    52  // http.DefaultClient.Do is used as the transactor.
    53  func WithTransactor(transactor func(*http.Request) (*http.Response, error)) Option {
    54  	return func(h *Handler) {
    55  		if transactor != nil {
    56  			h.transactor = transactor
    57  		} else {
    58  			h.transactor = http.DefaultClient.Do
    59  		}
    60  	}
    61  }
    62  
    63  // WithFanoutBefore adds zero or more request functions that will tailor each fanout request.
    64  func WithFanoutBefore(before ...FanoutRequestFunc) Option {
    65  	return func(h *Handler) {
    66  		h.before = append(h.before, before...)
    67  	}
    68  }
    69  
    70  // WithClientBefore adds zero or more go-kit RequestFunc functions that will be applied to
    71  // each fanout request.
    72  func WithClientBefore(before ...gokithttp.RequestFunc) Option {
    73  	return func(h *Handler) {
    74  		for _, rf := range before {
    75  			h.before = append(
    76  				h.before,
    77  				func(ctx context.Context, _, fanout *http.Request, _ []byte) (context.Context, error) {
    78  					return rf(ctx, fanout), nil
    79  				},
    80  			)
    81  		}
    82  	}
    83  }
    84  
    85  // WithFanoutAfter adds zero or more response functions that are invoked to tailor the response
    86  // when a successful (i.e. terminating) fanout response is received.
    87  func WithFanoutAfter(after ...FanoutResponseFunc) Option {
    88  	return func(h *Handler) {
    89  		h.after = append(h.after, after...)
    90  	}
    91  }
    92  
    93  // WithClientAfter allows zero or more go-kit ClientResponseFuncs to be used as fanout after functions.
    94  func WithClientAfter(after ...gokithttp.ClientResponseFunc) Option {
    95  	return func(h *Handler) {
    96  		for _, rf := range after {
    97  			h.after = append(
    98  				h.after,
    99  				func(ctx context.Context, response http.ResponseWriter, result Result) context.Context {
   100  					return rf(ctx, result.Response)
   101  				},
   102  			)
   103  		}
   104  	}
   105  }
   106  
   107  // WithFanoutFailure adds zero or more response functions that are invoked to tailor the response
   108  // when a failed fanout responses have been received.
   109  func WithFanoutFailure(failure ...FanoutResponseFunc) Option {
   110  	return func(h *Handler) {
   111  		h.failure = append(h.failure, failure...)
   112  	}
   113  }
   114  
   115  // WithClientFailure allows zero or more go-kit ClientResponseFuncs to be used as fanout failure functions.
   116  func WithClientFailure(failure ...gokithttp.ClientResponseFunc) Option {
   117  	return func(h *Handler) {
   118  		for _, rf := range failure {
   119  			h.failure = append(
   120  				h.failure,
   121  				func(ctx context.Context, response http.ResponseWriter, result Result) context.Context {
   122  					return rf(ctx, result.Response)
   123  				},
   124  			)
   125  		}
   126  	}
   127  }
   128  
   129  // WithConfiguration uses a set of (typically injected) fanout configuration options to configure a Handler.
   130  // Use of this option will not override the configured Endpoints instance.
   131  func WithConfiguration(c Configuration) Option {
   132  	return func(h *Handler) {
   133  		WithTransactor(NewTransactor(c))(h)
   134  
   135  		authorization := c.authorization()
   136  		if len(authorization) > 0 {
   137  			WithClientBefore(gokithttp.SetRequestHeader("Authorization", authorization))(h)
   138  		}
   139  	}
   140  }
   141  
   142  // Handler is the http.Handler that fans out HTTP requests using the configured Endpoints strategy.
   143  type Handler struct {
   144  	endpoints       Endpoints
   145  	errorEncoder    gokithttp.ErrorEncoder
   146  	before          []FanoutRequestFunc
   147  	after           []FanoutResponseFunc
   148  	failure         []FanoutResponseFunc
   149  	shouldTerminate ShouldTerminateFunc
   150  	transactor      func(*http.Request) (*http.Response, error)
   151  }
   152  
   153  // New creates a fanout Handler.  The Endpoints strategy is required, and this constructor function will
   154  // panic if it is nil.
   155  //
   156  // By default, all fanout requests have the same HTTP method as the original request, but no body is set..  Clients must use the OriginalBody
   157  // strategy to set the original request's body on each fanout request.
   158  func New(e Endpoints, options ...Option) *Handler {
   159  	if e == nil {
   160  		panic("An Endpoints strategy is required")
   161  	}
   162  
   163  	h := &Handler{
   164  		endpoints:       e,
   165  		errorEncoder:    gokithttp.DefaultErrorEncoder,
   166  		shouldTerminate: DefaultShouldTerminate,
   167  		transactor:      http.DefaultClient.Do,
   168  	}
   169  
   170  	for _, o := range options {
   171  		o(h)
   172  	}
   173  
   174  	return h
   175  }
   176  
   177  // newFanoutRequests uses the Endpoints strategy and builds (1) HTTP request for each endpoint.  The configured
   178  // FanoutRequestFunc options are used to build each request.  This method returns an error if no endpoints were returned
   179  // by the strategy or if an error reading the original request body occurred.
   180  func (h *Handler) newFanoutRequests(fanoutCtx context.Context, original *http.Request) ([]*http.Request, error) {
   181  	body, err := ioutil.ReadAll(original.Body)
   182  	if err != nil {
   183  		return nil, err
   184  	}
   185  
   186  	urls, err := h.endpoints.FanoutURLs(original)
   187  	if err != nil {
   188  		return nil, err
   189  	} else if len(urls) == 0 {
   190  		return nil, errNoFanoutURLs
   191  	}
   192  
   193  	requests := make([]*http.Request, len(urls))
   194  	for i := 0; i < len(urls); i++ {
   195  		fanout := &http.Request{
   196  			Method:     original.Method,
   197  			URL:        urls[i],
   198  			Proto:      "HTTP/1.1",
   199  			ProtoMajor: 1,
   200  			ProtoMinor: 1,
   201  			Header:     make(http.Header),
   202  			Host:       urls[i].Host,
   203  		}
   204  
   205  		endpointCtx := fanoutCtx
   206  		var err error
   207  		for _, rf := range h.before {
   208  			endpointCtx, err = rf(endpointCtx, original, fanout, body)
   209  			if err != nil {
   210  				return nil, err
   211  			}
   212  		}
   213  
   214  		requests[i] = fanout.WithContext(endpointCtx)
   215  	}
   216  
   217  	return requests, nil
   218  }
   219  
   220  // execute performs a single fanout HTTP transaction and sends the result on a channel.  This method is invoked
   221  // as a goroutine.  It takes care of draining the fanout's response prior to returning.
   222  func (h *Handler) execute(logger log.Logger, spanner tracing.Spanner, results chan<- Result, request *http.Request) {
   223  	var (
   224  		finisher = spanner.Start(request.URL.String())
   225  		result   = Result{
   226  			Request: request,
   227  		}
   228  	)
   229  
   230  	result.Response, result.Err = h.transactor(request)
   231  	switch {
   232  	case result.Response != nil:
   233  		result.StatusCode = result.Response.StatusCode
   234  		result.ContentType = result.Response.Header.Get("Content-Type")
   235  
   236  		var err error
   237  		if result.Body, err = ioutil.ReadAll(result.Response.Body); err != nil {
   238  			logger.Log(level.Key(), level.ErrorValue(), logging.MessageKey(), "error reading fanout response body", logging.ErrorKey(), err)
   239  		}
   240  
   241  		if err = result.Response.Body.Close(); err != nil {
   242  			logger.Log(level.Key(), level.ErrorValue(), logging.MessageKey(), "error closing fanout response body", logging.ErrorKey(), err)
   243  		}
   244  
   245  	case result.Err != nil:
   246  		result.Body = []byte(fmt.Sprintf("%s", result.Err))
   247  		result.ContentType = "text/plain"
   248  
   249  		if ue, ok := result.Err.(*url.Error); ok && ue.Err != nil {
   250  			// unwrap the URL error
   251  			result.Err = ue.Err
   252  		}
   253  
   254  		if result.Err == context.Canceled || result.Err == context.DeadlineExceeded {
   255  			result.StatusCode = http.StatusGatewayTimeout
   256  		} else {
   257  			result.StatusCode = http.StatusServiceUnavailable
   258  		}
   259  
   260  	default:
   261  		// this "should" never happen, but just in case set a known status code
   262  		result.StatusCode = http.StatusServiceUnavailable
   263  		result.Err = errBadTransactor
   264  		result.Body = []byte(errBadTransactor.Error())
   265  		result.ContentType = "text/plain"
   266  	}
   267  
   268  	result.Span = finisher(result.Err)
   269  	results <- result
   270  }
   271  
   272  // finish takes a terminating fanout result and writes the appropriate information to the top-level response.  This method
   273  // is only invoked when a particular fanout response terminates the fanout, i.e. is considered successful.
   274  func (h *Handler) finish(logger log.Logger, response http.ResponseWriter, result Result, after []FanoutResponseFunc) {
   275  	ctx := result.Request.Context()
   276  	for _, rf := range after {
   277  		// NOTE: we don't use the context for anything here,
   278  		// but to preserve go-kit semantics we pass it to each after function
   279  		ctx = rf(ctx, response, result)
   280  	}
   281  
   282  	if len(result.Body) > 0 {
   283  		if len(result.ContentType) > 0 {
   284  			response.Header().Set("Content-Type", result.ContentType)
   285  		} else {
   286  			response.Header().Set("Content-Type", "application/octet-stream")
   287  		}
   288  
   289  		response.WriteHeader(result.StatusCode)
   290  		count, err := response.Write(result.Body)
   291  		if err != nil {
   292  			logger.Log(level.Key(), level.ErrorValue(), logging.MessageKey(), "wrote fanout response", "bytes", count, logging.ErrorKey(), err)
   293  		} else {
   294  			logger.Log(level.Key(), level.DebugValue(), logging.MessageKey(), "wrote fanout response", "bytes", count)
   295  		}
   296  	} else {
   297  		logger.Log(level.Key(), level.DebugValue(), logging.MessageKey(), "wrote fanout response", "statusCode", result.StatusCode)
   298  		response.WriteHeader(result.StatusCode)
   299  	}
   300  }
   301  
   302  func (h *Handler) ServeHTTP(response http.ResponseWriter, original *http.Request) {
   303  	var (
   304  		fanoutCtx     = original.Context()
   305  		logger        = logging.GetLogger(fanoutCtx)
   306  		requests, err = h.newFanoutRequests(fanoutCtx, original)
   307  	)
   308  
   309  	if err != nil {
   310  		logger.Log(level.Key(), level.ErrorValue(), logging.MessageKey(), "unable to create fanout", logging.ErrorKey(), err)
   311  		h.errorEncoder(fanoutCtx, err, response)
   312  		return
   313  	}
   314  
   315  	var (
   316  		spanner = tracing.NewSpanner()
   317  		results = make(chan Result, len(requests))
   318  	)
   319  
   320  	for _, r := range requests {
   321  		go h.execute(logger, spanner, results, r)
   322  	}
   323  
   324  	statusCode := 0
   325  	var latestResponse Result
   326  	for i := 0; i < len(requests); i++ {
   327  		select {
   328  		case <-fanoutCtx.Done():
   329  			logger.Log(level.Key(), level.ErrorValue(), logging.MessageKey(), "fanout operation canceled or timed out", "statusCode", http.StatusGatewayTimeout, "url", original.URL, logging.ErrorKey(), fanoutCtx.Err())
   330  			response.WriteHeader(http.StatusGatewayTimeout)
   331  			return
   332  
   333  		case r := <-results:
   334  			tracinghttp.HeadersForSpans("", response.Header(), r.Span)
   335  			if r.Err != nil {
   336  				logger.Log(level.Key(), level.ErrorValue(), logging.MessageKey(), "fanout request complete", "statusCode", r.StatusCode, "url", r.Request.URL, logging.ErrorKey(), r.Err)
   337  			} else {
   338  				logger.Log(level.Key(), level.DebugValue(), logging.MessageKey(), "fanout request complete", "statusCode", r.StatusCode, "url", r.Request.URL)
   339  			}
   340  
   341  			if h.shouldTerminate(r) {
   342  				// this was a "success", so no reason to wait any longer
   343  				h.finish(logger, response, r, h.after)
   344  				return
   345  			}
   346  
   347  			if statusCode < r.StatusCode {
   348  				statusCode = r.StatusCode
   349  				latestResponse = r
   350  			}
   351  		}
   352  	}
   353  
   354  	logger.Log(level.Key(), level.ErrorValue(), logging.MessageKey(), "all fanout requests failed", "statusCode", statusCode, "url", original.URL)
   355  	h.finish(logger, response, latestResponse, h.failure)
   356  }