github.com/renbou/grpcbridge@v0.0.2-0.20240416012907-bcbd8b12648a/routing/pattern_router.go (about)

     1  package routing
     2  
     3  import (
     4  	"container/list"
     5  	"errors"
     6  	"fmt"
     7  	"net/http"
     8  	"strings"
     9  	"sync"
    10  	"sync/atomic"
    11  
    12  	"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
    13  	"github.com/renbou/grpcbridge/bridgedesc"
    14  	"github.com/renbou/grpcbridge/bridgelog"
    15  	"github.com/renbou/grpcbridge/grpcadapter"
    16  	"github.com/renbou/grpcbridge/internal/httprule"
    17  	"github.com/renbou/grpcbridge/internal/syncset"
    18  	"google.golang.org/grpc/codes"
    19  	"google.golang.org/grpc/status"
    20  )
    21  
    22  // PatternRouterOpts define all the optional settings which can be set for [PatternRouter].
    23  type PatternRouterOpts struct {
    24  	// Logs are discarded by default.
    25  	Logger bridgelog.Logger
    26  }
    27  
    28  func (o PatternRouterOpts) withDefaults() PatternRouterOpts {
    29  	if o.Logger == nil {
    30  		o.Logger = bridgelog.Discard()
    31  	}
    32  
    33  	return o
    34  }
    35  
    36  // PatternRouter is a router meant for routing HTTP requests with non-gRPC URLs/contents.
    37  // It uses pattern-based route matching like the one used in [gRPC-Gateway], but additionally supports dynamic routing updates
    38  // via [PatternRouterWatcher], meant to be used with a description resolver such as the one in the [github.com/renbou/grpcbridge/reflection] package.
    39  //
    40  // Unlike gRPC-Gateway it doesn't support POST->GET fallbacks and X-HTTP-Method-Override,
    41  // since such features can easily become a source of security issues for an unsuspecting developer.
    42  // By the same logic, request paths aren't cleaned, i.e. multiple slashes, ./.. elements aren't removed.
    43  //
    44  // [gRPC-Gateway]: https://github.com/grpc-ecosystem/grpc-gateway
    45  type PatternRouter struct {
    46  	pool       grpcadapter.ClientPool
    47  	logger     bridgelog.Logger
    48  	routes     *mutablePatternRoutingTable
    49  	watcherSet *syncset.SyncSet[string]
    50  }
    51  
    52  // NewPatternRouter initializes a new [PatternRouter] with the specified connection pool and options.
    53  //
    54  // The connection pool will be used to perform a simple retrieval of the connection to a target by its name.
    55  // for more complex connection routing this router's [PatternRouter.RouteHTTP] can be wrapped to return a
    56  // different connection based on the matched method and HTTP request parameters.
    57  func NewPatternRouter(pool grpcadapter.ClientPool, opts PatternRouterOpts) *PatternRouter {
    58  	opts = opts.withDefaults()
    59  
    60  	return &PatternRouter{
    61  		pool:       pool,
    62  		logger:     opts.Logger.WithComponent("grpcbridge.routing"),
    63  		routes:     newMutablePatternRoutingTable(),
    64  		watcherSet: syncset.New[string](),
    65  	}
    66  }
    67  
    68  // RouteHTTP routes the HTTP request based on its URL path and method
    69  // using the target descriptions received via updates through [PatternRouterWatcher.UpdateDesc].
    70  //
    71  // Errors returned by RouteHTTP are gRPC status.Status errors with the code set accordingly.
    72  // Currently, the NotFound, InvalidArgument, and Unavailable codes are returned.
    73  // Additionally, it can return an error implementing interface { HTTPStatus() int } to set a custom status code, but it doesn't currently do so.
    74  //
    75  // Performance-wise it is notable that updates to the routing information don't block RouteHTTP, happening fully in the background.
    76  func (pr *PatternRouter) RouteHTTP(r *http.Request) (grpcadapter.ClientConn, HTTPRoute, error) {
    77  	// Try to follow the same steps as in https://github.com/grpc-ecosystem/grpc-gateway/blob/main/runtime/mux.go#L328 (ServeMux.ServeHTTP).
    78  	// Specifically, use RawPath for pattern matching, since it will be properly decoded by the pattern itself.
    79  	path := r.URL.RawPath
    80  	if path == "" {
    81  		path = r.URL.Path
    82  	}
    83  
    84  	if !strings.HasPrefix(path, "/") {
    85  		return nil, HTTPRoute{}, status.Error(codes.InvalidArgument, http.StatusText(http.StatusBadRequest))
    86  	}
    87  
    88  	pathComponents := strings.Split(path[1:], "/")
    89  	lastPathComponent := pathComponents[len(pathComponents)-1]
    90  	matchComponents := make([]string, len(pathComponents))
    91  
    92  	var routeErr error
    93  	var matched bool
    94  	var matchedRoute HTTPRoute
    95  
    96  	pr.routes.iterate(r.Method, func(target *bridgedesc.Target, route *patternRoute) bool {
    97  		var verb string
    98  		patternVerb := route.pattern.Verb()
    99  
   100  		verbIdx := -1
   101  		if patternVerb != "" && strings.HasSuffix(lastPathComponent, ":"+patternVerb) {
   102  			verbIdx = len(lastPathComponent) - len(patternVerb) - 1
   103  		}
   104  
   105  		// path segments consisting only of verbs aren't allowed
   106  		if verbIdx == 0 {
   107  			routeErr = status.Error(codes.NotFound, http.StatusText(http.StatusNotFound))
   108  			return false
   109  		}
   110  
   111  		matchComponents = matchComponents[:len(pathComponents)]
   112  		copy(matchComponents, pathComponents)
   113  
   114  		if verbIdx > 0 {
   115  			matchComponents[len(matchComponents)-1], verb = lastPathComponent[:verbIdx], lastPathComponent[verbIdx+1:]
   116  		}
   117  
   118  		// Perform unescaping as specified for gRPC transcoding in https://github.com/googleapis/googleapis/blob/e0677a395947c2f3f3411d7202a6868a7b069a41/google/api/http.proto#L295.
   119  		params, err := route.pattern.MatchAndEscape(matchComponents, verb, runtime.UnescapingModeAllExceptReserved)
   120  		if err != nil {
   121  			var mse runtime.MalformedSequenceError
   122  			if ok := errors.As(err, &mse); ok {
   123  				routeErr = status.Error(codes.InvalidArgument, err.Error())
   124  				return false
   125  			}
   126  
   127  			// Ignore runtime.ErrNotMatch
   128  			return true
   129  		}
   130  
   131  		// Avoid returning empty maps when not needed.
   132  		if len(params) == 0 {
   133  			params = nil
   134  		}
   135  
   136  		// Found match
   137  		matched = true
   138  		matchedRoute = HTTPRoute{
   139  			Target:     target,
   140  			Service:    route.service,
   141  			Method:     route.method,
   142  			Binding:    route.binding,
   143  			PathParams: params,
   144  		}
   145  		return false
   146  	})
   147  
   148  	if routeErr != nil {
   149  		return nil, HTTPRoute{}, routeErr
   150  	} else if !matched {
   151  		return nil, HTTPRoute{}, status.Error(codes.NotFound, http.StatusText(http.StatusNotFound))
   152  	}
   153  
   154  	conn, ok := pr.pool.Get(matchedRoute.Target.Name)
   155  	if !ok {
   156  		return nil, HTTPRoute{}, status.Errorf(codes.Unavailable, "no connection available to target %q", matchedRoute.Target.Name)
   157  	}
   158  
   159  	return conn, matchedRoute, nil
   160  }
   161  
   162  // Watch starts watching the specified target for description changes.
   163  // It returns a [*PatternRouterWatcher] through which new updates for this target can be applied.
   164  //
   165  // It is an error to try Watch()ing the same target multiple times on a single PatternRouter instance,
   166  // the previous [PatternRouterWatcher] must be explicitly closed before launching a new one.
   167  // Instead of trying to synchronize such procedures, however, it's better to have a properly defined lifecycle
   168  // for each possible target, with clear logic about when it gets added or removed to/from all the components of a bridge.
   169  func (pr *PatternRouter) Watch(target string) (*PatternRouterWatcher, error) {
   170  	if pr.watcherSet.Add(target) {
   171  		return &PatternRouterWatcher{pr: pr, target: target, logger: pr.logger.With("target", target)}, nil
   172  	}
   173  
   174  	return nil, ErrAlreadyWatching
   175  }
   176  
   177  // PatternRouterWatcher is a description update watcher created for a specific target in the context of a [PatternRouter] instance.
   178  // New PatternRouterWatchers are created through [PatternRouter.Watch].
   179  type PatternRouterWatcher struct {
   180  	pr     *PatternRouter
   181  	logger bridgelog.Logger
   182  	target string
   183  	closed atomic.Bool
   184  }
   185  
   186  // UpdateDesc updates the description of the target this watcher is watching.
   187  // After the watcher is Close()d, UpdateDesc becomes a no-op, to avoid writing meaningless updates to the router.
   188  // Note that desc.Name must match the target this watcher was created for, otherwise the update will be ignored.
   189  //
   190  // Updates to the routing information are made without any locking,
   191  // instead replacing the currently present info with the updated one using an atomic pointer.
   192  //
   193  // UpdateDesc returns only when the routing state has been completely updated on the router,
   194  // which should be used to synchronize the target description update polling/watching logic.
   195  func (prw *PatternRouterWatcher) UpdateDesc(desc *bridgedesc.Target) {
   196  	if prw.closed.Load() {
   197  		return
   198  	}
   199  
   200  	if desc.Name != prw.target {
   201  		// use PatternRouter logger without the "target" field
   202  		prw.pr.logger.Error("PatternRouterWatcher got update for different target, will ignore", "watcher_target", prw.target, "update_target", desc.Name)
   203  		return
   204  	}
   205  
   206  	routes := buildPatternRoutes(desc, prw.logger)
   207  	prw.pr.routes.addTarget(desc, routes)
   208  }
   209  
   210  // ReportError is currently a no-op, present simply to implement the Watcher interface
   211  // of the grpcbridge description resolvers, such as the one in [github.com/renbou/grpcbridge/reflection].
   212  func (prw *PatternRouterWatcher) ReportError(error) {}
   213  
   214  // Close closes the watcher, preventing further updates from being applied to the router through it.
   215  // It is an error to call Close() multiple times on the same watcher, and doing so will result in a panic.
   216  func (prw *PatternRouterWatcher) Close() {
   217  	if !prw.closed.CompareAndSwap(false, true) {
   218  		panic("grpcbridge: PatternRouterWatcher.Close() called multiple times")
   219  	}
   220  
   221  	// Fully remove the target's routes, only then mark the watcher as closed.
   222  	prw.pr.routes.removeTarget(prw.target)
   223  	prw.pr.watcherSet.Remove(prw.target)
   224  }
   225  
   226  func buildPatternRoutes(desc *bridgedesc.Target, logger bridgelog.Logger) map[string][]patternRoute {
   227  	builder := newPatternRouteBuilder()
   228  
   229  	for svcIdx := range desc.Services {
   230  		svc := &desc.Services[svcIdx]
   231  		methods := desc.Services[svcIdx].Methods // avoid copying the whole desc structures in loop
   232  		for methodIdx := range methods {
   233  			method := &methods[methodIdx]
   234  
   235  			if len(method.Bindings) < 1 {
   236  				if routeErr := builder.addDefault(svc, method); routeErr != nil {
   237  					logger.Error("failed to add default HTTP binding for gRPC method with no defined bindings",
   238  						"service", svc.Name, "method", method.RPCName,
   239  						"error", routeErr,
   240  					)
   241  				} else {
   242  					logger.Debug("added default HTTP binding for gRPC method", "service", svc.Name, "method", method.RPCName)
   243  				}
   244  				continue
   245  			}
   246  
   247  			for bindingIdx := range method.Bindings {
   248  				binding := &method.Bindings[bindingIdx]
   249  				if routeErr := builder.addBinding(svc, method, binding); routeErr != nil {
   250  					logger.Error("failed to add HTTP binding for gRPC method",
   251  						"service", svc.Name, "method", method.RPCName,
   252  						"binding.method", binding.HTTPMethod, "binding.pattern", binding.Pattern,
   253  						"error", routeErr,
   254  					)
   255  				} else {
   256  					logger.Debug("added HTTP binding for gRPC method",
   257  						"service", svc.Name, "method", method.RPCName,
   258  						"binding.method", binding.HTTPMethod, "binding.pattern", binding.Pattern,
   259  					)
   260  				}
   261  			}
   262  		}
   263  	}
   264  
   265  	return builder.routes
   266  }
   267  
   268  type patternRoute struct {
   269  	service *bridgedesc.Service
   270  	method  *bridgedesc.Method
   271  	binding *bridgedesc.Binding
   272  	pattern runtime.Pattern
   273  }
   274  
   275  func buildPattern(route string) (runtime.Pattern, error) {
   276  	compiler, err := httprule.Parse(route)
   277  	if err != nil {
   278  		return runtime.Pattern{}, fmt.Errorf("parsing route: %w", err)
   279  	}
   280  
   281  	tp := compiler.Compile()
   282  
   283  	pattern, routeErr := runtime.NewPattern(tp.Version, tp.OpCodes, tp.Pool, tp.Verb)
   284  	if routeErr != nil {
   285  		return runtime.Pattern{}, fmt.Errorf("creating route pattern matcher: %w", routeErr)
   286  	}
   287  
   288  	return pattern, nil
   289  }
   290  
   291  // patternRouteBuilder is a helper structure used for building the routing table for a single target.
   292  type patternRouteBuilder struct {
   293  	routes map[string][]patternRoute // http method -> routes
   294  }
   295  
   296  func newPatternRouteBuilder() patternRouteBuilder {
   297  	return patternRouteBuilder{
   298  		routes: make(map[string][]patternRoute),
   299  	}
   300  }
   301  
   302  func (rb *patternRouteBuilder) addBinding(s *bridgedesc.Service, m *bridgedesc.Method, b *bridgedesc.Binding) error {
   303  	pr, err := buildPattern(b.Pattern)
   304  	if err != nil {
   305  		return fmt.Errorf("building pattern for %s: %w", b.HTTPMethod, err)
   306  	}
   307  
   308  	rb.routes[b.HTTPMethod] = append(rb.routes[b.HTTPMethod], patternRoute{
   309  		service: s,
   310  		method:  m,
   311  		binding: b,
   312  		pattern: pr,
   313  	})
   314  	return nil
   315  }
   316  
   317  func (rb *patternRouteBuilder) addDefault(s *bridgedesc.Service, m *bridgedesc.Method) error {
   318  	// Default gRPC form, as specified in https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests.
   319  	return rb.addBinding(s, m, bridgedesc.DefaultBinding(m))
   320  }
   321  
   322  // targetPatternRoutes define the routes for a single target+method combination.
   323  type targetPatternRoutes struct {
   324  	target *bridgedesc.Target
   325  	routes []patternRoute
   326  }
   327  
   328  type methodPatternRoutes struct {
   329  	method string
   330  	link   *list.Element // Value of type targetPatternRoutes
   331  }
   332  
   333  // mutablePatternRoutingTable is a pattern-based routing table which can be modified with all operations protected by a mutex.
   334  type mutablePatternRoutingTable struct {
   335  	// protects all the routing state, including the pointer to the static table,
   336  	// which must be updated while the mutex is held to avoid overwriting new state with old state.
   337  	mu          sync.Mutex
   338  	static      atomic.Pointer[staticPatternRoutingTable]
   339  	routes      map[string]*list.List            // http method -> linked list of targetPatternRoutes
   340  	targetLinks map[string][]methodPatternRoutes // target -> list of elements to be modified
   341  }
   342  
   343  func newMutablePatternRoutingTable() *mutablePatternRoutingTable {
   344  	mt := &mutablePatternRoutingTable{
   345  		routes:      make(map[string]*list.List),
   346  		targetLinks: make(map[string][]methodPatternRoutes),
   347  	}
   348  
   349  	mt.static.Store(&staticPatternRoutingTable{})
   350  
   351  	return mt
   352  }
   353  
   354  // addTargets adds or updates the routes of a target and updates the static pointer.
   355  func (mt *mutablePatternRoutingTable) addTarget(target *bridgedesc.Target, routes map[string][]patternRoute) {
   356  	mt.mu.Lock()
   357  	defer mt.mu.Unlock()
   358  
   359  	newMethodLinks := make([]methodPatternRoutes, 0, len(routes))
   360  
   361  	// Update existing elements without recreating them.
   362  	for _, link := range mt.targetLinks[target.Name] {
   363  		patternRoutes, ok := routes[link.method]
   364  		if !ok {
   365  			// delete existing link, no more routes for this method
   366  			mt.removeRoute(link.method, link.link)
   367  			continue
   368  		}
   369  
   370  		// set new pattern routes via the link & mark link as in-use
   371  		link.link.Value = targetPatternRoutes{target: target, routes: patternRoutes}
   372  		newMethodLinks = append(newMethodLinks, link)
   373  
   374  		// mark method as handled, we don't need to re-add it
   375  		delete(routes, link.method)
   376  	}
   377  
   378  	// Add routes for new methods.
   379  	for method, patternRoutes := range routes {
   380  		link := mt.addRoute(method, targetPatternRoutes{target: target, routes: patternRoutes})
   381  		newMethodLinks = append(newMethodLinks, methodPatternRoutes{method: method, link: link})
   382  	}
   383  
   384  	mt.targetLinks[target.Name] = newMethodLinks
   385  
   386  	mt.static.Store(mt.commit())
   387  }
   388  
   389  // removeTarget removes all routes of a target and updates the static pointer.
   390  func (mt *mutablePatternRoutingTable) removeTarget(target string) {
   391  	mt.mu.Lock()
   392  	defer mt.mu.Unlock()
   393  
   394  	for _, link := range mt.targetLinks[target] {
   395  		mt.removeRoute(link.method, link.link)
   396  	}
   397  
   398  	delete(mt.targetLinks, target)
   399  
   400  	mt.static.Store(mt.commit())
   401  }
   402  
   403  func (mt *mutablePatternRoutingTable) removeRoute(method string, link *list.Element) {
   404  	lst := mt.routes[method]
   405  	lst.Remove(link)
   406  	if lst.Len() == 0 {
   407  		delete(mt.routes, method)
   408  	}
   409  }
   410  
   411  func (mt *mutablePatternRoutingTable) addRoute(method string, route targetPatternRoutes) *list.Element {
   412  	lst, ok := mt.routes[method]
   413  	if !ok {
   414  		lst = list.New()
   415  		mt.routes[method] = lst
   416  	}
   417  
   418  	return lst.PushBack(route)
   419  }
   420  
   421  func (mt *mutablePatternRoutingTable) commit() *staticPatternRoutingTable {
   422  	routes := make(map[string]*list.List, len(mt.routes))
   423  	for method, list := range mt.routes {
   424  		routes[method] = cloneLinkedList(list)
   425  	}
   426  
   427  	return &staticPatternRoutingTable{routes: routes}
   428  }
   429  
   430  func (mt *mutablePatternRoutingTable) iterate(method string, fn func(target *bridgedesc.Target, route *patternRoute) bool) {
   431  	mt.static.Load().iterate(method, fn)
   432  }
   433  
   434  // staticPatternRoutingTable is a pattern-based routing table which can only be read.
   435  // it is created by mutablePatternRoutingTable when modifications occur.
   436  type staticPatternRoutingTable struct {
   437  	routes map[string]*list.List // http method -> linked list of targetPatternRoutes
   438  }
   439  
   440  func (st *staticPatternRoutingTable) iterate(method string, fn func(target *bridgedesc.Target, route *patternRoute) bool) {
   441  	list, ok := st.routes[method]
   442  	if !ok {
   443  		return
   444  	}
   445  
   446  	for e := list.Front(); e != nil; e = e.Next() {
   447  		pr := e.Value.(targetPatternRoutes)
   448  		for i := range pr.routes {
   449  			if !fn(pr.target, &pr.routes[i]) {
   450  				return
   451  			}
   452  		}
   453  	}
   454  }
   455  
   456  func cloneLinkedList(l *list.List) *list.List {
   457  	cp := list.New()
   458  	for e := l.Front(); e != nil; e = e.Next() {
   459  		cp.PushBack(e.Value)
   460  	}
   461  	return cp
   462  }