github.com/avenga/couper@v1.12.2/handler/middleware/uid.go (about) 1 package middleware 2 3 import ( 4 "context" 5 "net/http" 6 "regexp" 7 8 "github.com/google/uuid" 9 "github.com/rs/xid" 10 11 "github.com/avenga/couper/config" 12 "github.com/avenga/couper/config/request" 13 "github.com/avenga/couper/errors" 14 ) 15 16 var regexUID = regexp.MustCompile(`^[a-zA-Z0-9@=/+-]{12,64}$`) 17 18 type UID struct { 19 conf *config.Settings 20 devProxyHeader string 21 generate UIDFunc 22 handler http.Handler 23 } 24 25 func NewUIDHandler(conf *config.Settings, devProxy string) Next { 26 return func(handler http.Handler) *NextHandler { 27 return NewHandler(&UID{ 28 conf: conf, 29 devProxyHeader: devProxy, 30 generate: NewUIDFunc(conf.RequestIDFormat), 31 handler: handler, 32 }, handler) 33 } 34 } 35 36 // ServeHTTP generates a unique request-id and add this id to the request context and 37 // at least the response header even on error case. 38 func (u *UID) ServeHTTP(rw http.ResponseWriter, req *http.Request) { 39 uid, err := u.newUID(req.Header) 40 41 *req = *req.WithContext(context.WithValue(req.Context(), request.UID, uid)) 42 43 if u.conf.RequestIDClientHeader != "" { 44 rw.Header().Set(u.conf.RequestIDClientHeader, uid) 45 } 46 47 if err != nil { 48 errors.DefaultHTML.WithError(errors.ClientRequest.With(err)).ServeHTTP(rw, req) 49 return 50 } 51 52 if u.conf.RequestIDBackendHeader != "" { 53 req.Header.Set(u.conf.RequestIDBackendHeader, uid) 54 } 55 56 u.handler.ServeHTTP(rw, req) 57 } 58 59 func (u *UID) newUID(header http.Header) (string, error) { 60 if u.conf.RequestIDAcceptFromHeader != "" { 61 if v := header.Get(u.conf.RequestIDAcceptFromHeader); v != "" { 62 if !regexUID.MatchString(v) { 63 return u.generate(), errors.ClientRequest. 64 Messagef("invalid request-id header value: %s: %s", u.conf.RequestIDAcceptFromHeader, v) 65 } 66 67 return v, nil 68 } 69 } else if httpsDevProxyID := header.Get(u.devProxyHeader); httpsDevProxyID != "" { 70 header.Del(u.devProxyHeader) 71 return httpsDevProxyID, nil 72 } 73 return u.generate(), nil 74 } 75 76 // UIDFunc wraps different unique id implementations. 77 type UIDFunc func() string 78 79 func NewUIDFunc(requestIDFormat string) UIDFunc { 80 var fn UIDFunc 81 if requestIDFormat == "uuid4" { 82 uuid.EnableRandPool() // Enabling the pool may improve the UUID generation throughput significantly. 83 fn = func() string { 84 return uuid.NewString() 85 } 86 } else { 87 fn = func() string { 88 return xid.New().String() 89 } 90 } 91 return fn 92 }