github.com/vmware/transport-go@v1.3.4/plank/pkg/middleware/middleware_manager.go (about)

     1  // Copyright 2019-2021 VMware, Inc.
     2  // SPDX-License-Identifier: BSD-2-Clause
     3  
     4  package middleware
     5  
     6  import (
     7  	"fmt"
     8  	"github.com/gorilla/mux"
     9  	"github.com/vmware/transport-go/plank/utils"
    10  	"net/http"
    11  	"strings"
    12  	"sync"
    13  )
    14  
    15  type MiddlewareManager interface {
    16  	SetGlobalMiddleware(middleware []mux.MiddlewareFunc) error
    17  	SetNewMiddleware(route *mux.Route, middleware []mux.MiddlewareFunc) error
    18  	RemoveMiddleware(route *mux.Route) error
    19  	GetRouteByUriAndMethod(uri, method string) (*mux.Route, error)
    20  	GetRouteByUri(uri string) (*mux.Route, error)
    21  	GetStaticRoute(prefix string) (*mux.Route, error)
    22  }
    23  
    24  type Middleware interface {
    25  	//Intercept(h http.Handler) http.Handler
    26  	Interceptor() mux.MiddlewareFunc
    27  	Name() string
    28  }
    29  
    30  type middlewareManager struct {
    31  	endpointHandlerMap  *map[string]http.HandlerFunc
    32  	originalHandlersMap map[string]http.HandlerFunc
    33  	router              *mux.Router
    34  	mu                  sync.Mutex
    35  }
    36  
    37  func (m *middlewareManager) SetGlobalMiddleware(middleware []mux.MiddlewareFunc) error {
    38  	m.mu.Lock()
    39  	defer m.mu.Unlock()
    40  	m.router.Use(middleware...)
    41  	return nil
    42  }
    43  
    44  func (m *middlewareManager) SetNewMiddleware(route *mux.Route, middleware []mux.MiddlewareFunc) error {
    45  	var key string
    46  	// expection is that a route's name ending with '*' means it's a prefix route
    47  	isPrefixRoute := route.GetName()[len(route.GetName())-1] == '*'
    48  
    49  	if !isPrefixRoute {
    50  		uri, method := m.extractUriVerbFromMuxRoute(route)
    51  		if route == nil {
    52  			return fmt.Errorf("failed to set a new middleware. route does not exist at %s (%s)", uri, method)
    53  		}
    54  		// for REST-bridge service a key is in the format of {uri}-{verb}
    55  		key = uri + "-" + method
    56  	} else {
    57  		// if the route instance is a prefix route use the route name as-is
    58  		key = route.GetName()
    59  	}
    60  
    61  	m.mu.Lock()
    62  	defer m.mu.Unlock()
    63  
    64  	// find if base handler exists first. if not, error out
    65  	original, exists := (*m.endpointHandlerMap)[key]
    66  	if !exists {
    67  		return fmt.Errorf("cannot set middleware. handler does not exist at %s", key)
    68  	}
    69  
    70  	// make a backup of the original handler that has no other middleware attached to it
    71  	if _, exists := m.originalHandlersMap[key]; !exists {
    72  		m.originalHandlersMap[key] = original
    73  	}
    74  
    75  	// build a new middleware chain and apply it
    76  	handler := m.buildMiddlewareChain(middleware, original).(http.HandlerFunc)
    77  	(*m.endpointHandlerMap)[key] = handler
    78  	route.Handler(handler)
    79  
    80  	for _, mw := range middleware {
    81  		utils.Log.Debugf("middleware '%v' registered for %s", mw, key)
    82  	}
    83  
    84  	utils.Log.Infof("New middleware configured for REST bridge at %s", key)
    85  
    86  	return nil
    87  }
    88  
    89  func (m *middlewareManager) RemoveMiddleware(route *mux.Route) error {
    90  	uri, method := m.extractUriVerbFromMuxRoute(route)
    91  	if route == nil {
    92  		return fmt.Errorf("failed to remove middleware. route does not exist at %s (%s)", uri, method)
    93  	}
    94  	m.mu.Lock()
    95  	defer m.mu.Unlock()
    96  	key := uri + "-" + method
    97  	if _, found := (*m.endpointHandlerMap)[key]; !found {
    98  		return fmt.Errorf("failed to remove handler. REST bridge handler does not exist at %s (%s)", uri, method)
    99  	}
   100  	defer func() {
   101  		if r := recover(); r != nil {
   102  			utils.Log.Errorln(r)
   103  		}
   104  	}()
   105  
   106  	(*m.endpointHandlerMap)[key] = m.originalHandlersMap[key]
   107  	route.Handler(m.originalHandlersMap[key])
   108  	utils.Log.Debugf("All middleware have been stripped from %s (%s)", uri, method)
   109  
   110  	return nil
   111  }
   112  
   113  func (m *middlewareManager) GetRouteByUriAndMethod(uri, method string) (*mux.Route, error) {
   114  	m.mu.Lock()
   115  	defer m.mu.Unlock()
   116  	route := m.router.Get(fmt.Sprintf("%s-%s", uri, method))
   117  	if route == nil {
   118  		return nil, fmt.Errorf("no route found at %s (%s)", uri, method)
   119  	}
   120  	return route, nil
   121  }
   122  
   123  func (m *middlewareManager) GetStaticRoute(prefix string) (*mux.Route, error) {
   124  	m.mu.Lock()
   125  	defer m.mu.Unlock()
   126  	routeName := prefix + "*"
   127  	route := m.router.Get(routeName)
   128  	if route == nil {
   129  		return nil, fmt.Errorf("no route found at static prefix %s", routeName)
   130  	}
   131  	return route, nil
   132  }
   133  
   134  func (m *middlewareManager) GetRouteByUri(uri string) (*mux.Route, error) {
   135  	m.mu.Lock()
   136  	defer m.mu.Unlock()
   137  	route := m.router.Get(uri)
   138  	if route == nil {
   139  		return nil, fmt.Errorf("no route found at %s", uri)
   140  	}
   141  	return route, nil
   142  }
   143  
   144  func (m *middlewareManager) buildMiddlewareChain(handlers []mux.MiddlewareFunc, originalHandler http.Handler) http.Handler {
   145  	var idx = len(handlers) - 1
   146  	var finalHandler http.Handler
   147  
   148  	for idx >= 0 {
   149  		var currHandler http.Handler
   150  		if idx == len(handlers)-1 {
   151  			currHandler = originalHandler
   152  		} else {
   153  			currHandler = finalHandler
   154  		}
   155  		middlewareFn := handlers[idx]
   156  		finalHandler = middlewareFn(currHandler)
   157  		idx--
   158  	}
   159  
   160  	return finalHandler
   161  }
   162  
   163  // extractUriVerbFromMuxRoute takes *mux.Route and returns URI and verb as string values
   164  func (m *middlewareManager) extractUriVerbFromMuxRoute(route *mux.Route) (string, string) {
   165  	opRawString := route.GetName()
   166  	delimiterIdx := strings.LastIndex(opRawString, "-")
   167  	return opRawString[:delimiterIdx], opRawString[delimiterIdx+1:]
   168  }
   169  
   170  // NewMiddlewareManager sets up a new middleware manager singleton instance
   171  func NewMiddlewareManager(endpointHandlerMapPtr *map[string]http.HandlerFunc, router *mux.Router) MiddlewareManager {
   172  	return &middlewareManager{
   173  		endpointHandlerMap:  endpointHandlerMapPtr,
   174  		originalHandlersMap: make(map[string]http.HandlerFunc),
   175  		router:              router,
   176  	}
   177  }