github.com/vmware/transport-go@v1.3.4/plank/pkg/middleware/middleware_manager.go (about) 1 // Copyright 2019-2021 VMware, Inc. 2 // SPDX-License-Identifier: BSD-2-Clause 3 4 package middleware 5 6 import ( 7 "fmt" 8 "github.com/gorilla/mux" 9 "github.com/vmware/transport-go/plank/utils" 10 "net/http" 11 "strings" 12 "sync" 13 ) 14 15 type MiddlewareManager interface { 16 SetGlobalMiddleware(middleware []mux.MiddlewareFunc) error 17 SetNewMiddleware(route *mux.Route, middleware []mux.MiddlewareFunc) error 18 RemoveMiddleware(route *mux.Route) error 19 GetRouteByUriAndMethod(uri, method string) (*mux.Route, error) 20 GetRouteByUri(uri string) (*mux.Route, error) 21 GetStaticRoute(prefix string) (*mux.Route, error) 22 } 23 24 type Middleware interface { 25 //Intercept(h http.Handler) http.Handler 26 Interceptor() mux.MiddlewareFunc 27 Name() string 28 } 29 30 type middlewareManager struct { 31 endpointHandlerMap *map[string]http.HandlerFunc 32 originalHandlersMap map[string]http.HandlerFunc 33 router *mux.Router 34 mu sync.Mutex 35 } 36 37 func (m *middlewareManager) SetGlobalMiddleware(middleware []mux.MiddlewareFunc) error { 38 m.mu.Lock() 39 defer m.mu.Unlock() 40 m.router.Use(middleware...) 41 return nil 42 } 43 44 func (m *middlewareManager) SetNewMiddleware(route *mux.Route, middleware []mux.MiddlewareFunc) error { 45 var key string 46 // expection is that a route's name ending with '*' means it's a prefix route 47 isPrefixRoute := route.GetName()[len(route.GetName())-1] == '*' 48 49 if !isPrefixRoute { 50 uri, method := m.extractUriVerbFromMuxRoute(route) 51 if route == nil { 52 return fmt.Errorf("failed to set a new middleware. route does not exist at %s (%s)", uri, method) 53 } 54 // for REST-bridge service a key is in the format of {uri}-{verb} 55 key = uri + "-" + method 56 } else { 57 // if the route instance is a prefix route use the route name as-is 58 key = route.GetName() 59 } 60 61 m.mu.Lock() 62 defer m.mu.Unlock() 63 64 // find if base handler exists first. if not, error out 65 original, exists := (*m.endpointHandlerMap)[key] 66 if !exists { 67 return fmt.Errorf("cannot set middleware. handler does not exist at %s", key) 68 } 69 70 // make a backup of the original handler that has no other middleware attached to it 71 if _, exists := m.originalHandlersMap[key]; !exists { 72 m.originalHandlersMap[key] = original 73 } 74 75 // build a new middleware chain and apply it 76 handler := m.buildMiddlewareChain(middleware, original).(http.HandlerFunc) 77 (*m.endpointHandlerMap)[key] = handler 78 route.Handler(handler) 79 80 for _, mw := range middleware { 81 utils.Log.Debugf("middleware '%v' registered for %s", mw, key) 82 } 83 84 utils.Log.Infof("New middleware configured for REST bridge at %s", key) 85 86 return nil 87 } 88 89 func (m *middlewareManager) RemoveMiddleware(route *mux.Route) error { 90 uri, method := m.extractUriVerbFromMuxRoute(route) 91 if route == nil { 92 return fmt.Errorf("failed to remove middleware. route does not exist at %s (%s)", uri, method) 93 } 94 m.mu.Lock() 95 defer m.mu.Unlock() 96 key := uri + "-" + method 97 if _, found := (*m.endpointHandlerMap)[key]; !found { 98 return fmt.Errorf("failed to remove handler. REST bridge handler does not exist at %s (%s)", uri, method) 99 } 100 defer func() { 101 if r := recover(); r != nil { 102 utils.Log.Errorln(r) 103 } 104 }() 105 106 (*m.endpointHandlerMap)[key] = m.originalHandlersMap[key] 107 route.Handler(m.originalHandlersMap[key]) 108 utils.Log.Debugf("All middleware have been stripped from %s (%s)", uri, method) 109 110 return nil 111 } 112 113 func (m *middlewareManager) GetRouteByUriAndMethod(uri, method string) (*mux.Route, error) { 114 m.mu.Lock() 115 defer m.mu.Unlock() 116 route := m.router.Get(fmt.Sprintf("%s-%s", uri, method)) 117 if route == nil { 118 return nil, fmt.Errorf("no route found at %s (%s)", uri, method) 119 } 120 return route, nil 121 } 122 123 func (m *middlewareManager) GetStaticRoute(prefix string) (*mux.Route, error) { 124 m.mu.Lock() 125 defer m.mu.Unlock() 126 routeName := prefix + "*" 127 route := m.router.Get(routeName) 128 if route == nil { 129 return nil, fmt.Errorf("no route found at static prefix %s", routeName) 130 } 131 return route, nil 132 } 133 134 func (m *middlewareManager) GetRouteByUri(uri string) (*mux.Route, error) { 135 m.mu.Lock() 136 defer m.mu.Unlock() 137 route := m.router.Get(uri) 138 if route == nil { 139 return nil, fmt.Errorf("no route found at %s", uri) 140 } 141 return route, nil 142 } 143 144 func (m *middlewareManager) buildMiddlewareChain(handlers []mux.MiddlewareFunc, originalHandler http.Handler) http.Handler { 145 var idx = len(handlers) - 1 146 var finalHandler http.Handler 147 148 for idx >= 0 { 149 var currHandler http.Handler 150 if idx == len(handlers)-1 { 151 currHandler = originalHandler 152 } else { 153 currHandler = finalHandler 154 } 155 middlewareFn := handlers[idx] 156 finalHandler = middlewareFn(currHandler) 157 idx-- 158 } 159 160 return finalHandler 161 } 162 163 // extractUriVerbFromMuxRoute takes *mux.Route and returns URI and verb as string values 164 func (m *middlewareManager) extractUriVerbFromMuxRoute(route *mux.Route) (string, string) { 165 opRawString := route.GetName() 166 delimiterIdx := strings.LastIndex(opRawString, "-") 167 return opRawString[:delimiterIdx], opRawString[delimiterIdx+1:] 168 } 169 170 // NewMiddlewareManager sets up a new middleware manager singleton instance 171 func NewMiddlewareManager(endpointHandlerMapPtr *map[string]http.HandlerFunc, router *mux.Router) MiddlewareManager { 172 return &middlewareManager{ 173 endpointHandlerMap: endpointHandlerMapPtr, 174 originalHandlersMap: make(map[string]http.HandlerFunc), 175 router: router, 176 } 177 }