github.com/avenga/couper@v1.12.2/handler/middleware/uid.go (about)

     1  package middleware
     2  
import (
	"context"
	"net/http"
	"regexp"
	"sync"

	"github.com/google/uuid"
	"github.com/rs/xid"

	"github.com/avenga/couper/config"
	"github.com/avenga/couper/config/request"
	"github.com/avenga/couper/errors"
)
    15  
    16  var regexUID = regexp.MustCompile(`^[a-zA-Z0-9@=/+-]{12,64}$`)
    17  
    18  type UID struct {
    19  	conf           *config.Settings
    20  	devProxyHeader string
    21  	generate       UIDFunc
    22  	handler        http.Handler
    23  }
    24  
    25  func NewUIDHandler(conf *config.Settings, devProxy string) Next {
    26  	return func(handler http.Handler) *NextHandler {
    27  		return NewHandler(&UID{
    28  			conf:           conf,
    29  			devProxyHeader: devProxy,
    30  			generate:       NewUIDFunc(conf.RequestIDFormat),
    31  			handler:        handler,
    32  		}, handler)
    33  	}
    34  }
    35  
    36  // ServeHTTP generates a unique request-id and add this id to the request context and
    37  // at least the response header even on error case.
    38  func (u *UID) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
    39  	uid, err := u.newUID(req.Header)
    40  
    41  	*req = *req.WithContext(context.WithValue(req.Context(), request.UID, uid))
    42  
    43  	if u.conf.RequestIDClientHeader != "" {
    44  		rw.Header().Set(u.conf.RequestIDClientHeader, uid)
    45  	}
    46  
    47  	if err != nil {
    48  		errors.DefaultHTML.WithError(errors.ClientRequest.With(err)).ServeHTTP(rw, req)
    49  		return
    50  	}
    51  
    52  	if u.conf.RequestIDBackendHeader != "" {
    53  		req.Header.Set(u.conf.RequestIDBackendHeader, uid)
    54  	}
    55  
    56  	u.handler.ServeHTTP(rw, req)
    57  }
    58  
    59  func (u *UID) newUID(header http.Header) (string, error) {
    60  	if u.conf.RequestIDAcceptFromHeader != "" {
    61  		if v := header.Get(u.conf.RequestIDAcceptFromHeader); v != "" {
    62  			if !regexUID.MatchString(v) {
    63  				return u.generate(), errors.ClientRequest.
    64  					Messagef("invalid request-id header value: %s: %s", u.conf.RequestIDAcceptFromHeader, v)
    65  			}
    66  
    67  			return v, nil
    68  		}
    69  	} else if httpsDevProxyID := header.Get(u.devProxyHeader); httpsDevProxyID != "" {
    70  		header.Del(u.devProxyHeader)
    71  		return httpsDevProxyID, nil
    72  	}
    73  	return u.generate(), nil
    74  }
    75  
    76  // UIDFunc wraps different unique id implementations.
    77  type UIDFunc func() string
    78  
    79  func NewUIDFunc(requestIDFormat string) UIDFunc {
    80  	var fn UIDFunc
    81  	if requestIDFormat == "uuid4" {
    82  		uuid.EnableRandPool() // Enabling the pool may improve the UUID generation throughput significantly.
    83  		fn = func() string {
    84  			return uuid.NewString()
    85  		}
    86  	} else {
    87  		fn = func() string {
    88  			return xid.New().String()
    89  		}
    90  	}
    91  	return fn
    92  }