github.com/MetalBlockchain/metalgo@v1.11.9/api/server/router.go (about)

     1  // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package server
     5  
     6  import (
     7  	"errors"
     8  	"fmt"
     9  	"net/http"
    10  	"sync"
    11  
    12  	"github.com/gorilla/mux"
    13  
    14  	"github.com/MetalBlockchain/metalgo/utils/set"
    15  )
    16  
    17  var (
    18  	errUnknownBaseURL  = errors.New("unknown base url")
    19  	errUnknownEndpoint = errors.New("unknown endpoint")
    20  	errAlreadyReserved = errors.New("route is either already aliased or already maps to a handle")
    21  )
    22  
    23  type router struct {
    24  	lock   sync.RWMutex
    25  	router *mux.Router
    26  
    27  	routeLock      sync.Mutex
    28  	reservedRoutes set.Set[string]                    // Reserves routes so that there can't be alias that conflict
    29  	aliases        map[string][]string                // Maps a route to a set of reserved routes
    30  	routes         map[string]map[string]http.Handler // Maps routes to a handler
    31  }
    32  
    33  func newRouter() *router {
    34  	return &router{
    35  		router:         mux.NewRouter(),
    36  		reservedRoutes: set.Set[string]{},
    37  		aliases:        make(map[string][]string),
    38  		routes:         make(map[string]map[string]http.Handler),
    39  	}
    40  }
    41  
    42  func (r *router) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
    43  	r.lock.RLock()
    44  	defer r.lock.RUnlock()
    45  
    46  	r.router.ServeHTTP(writer, request)
    47  }
    48  
    49  func (r *router) GetHandler(base, endpoint string) (http.Handler, error) {
    50  	r.routeLock.Lock()
    51  	defer r.routeLock.Unlock()
    52  
    53  	urlBase, exists := r.routes[base]
    54  	if !exists {
    55  		return nil, errUnknownBaseURL
    56  	}
    57  	handler, exists := urlBase[endpoint]
    58  	if !exists {
    59  		return nil, errUnknownEndpoint
    60  	}
    61  	return handler, nil
    62  }
    63  
    64  func (r *router) AddRouter(base, endpoint string, handler http.Handler) error {
    65  	r.lock.Lock()
    66  	defer r.lock.Unlock()
    67  	r.routeLock.Lock()
    68  	defer r.routeLock.Unlock()
    69  
    70  	return r.addRouter(base, endpoint, handler)
    71  }
    72  
    73  func (r *router) addRouter(base, endpoint string, handler http.Handler) error {
    74  	if r.reservedRoutes.Contains(base) {
    75  		return fmt.Errorf("%w: %s", errAlreadyReserved, base)
    76  	}
    77  
    78  	return r.forceAddRouter(base, endpoint, handler)
    79  }
    80  
    81  func (r *router) forceAddRouter(base, endpoint string, handler http.Handler) error {
    82  	endpoints := r.routes[base]
    83  	if endpoints == nil {
    84  		endpoints = make(map[string]http.Handler)
    85  	}
    86  	url := base + endpoint
    87  	if _, exists := endpoints[endpoint]; exists {
    88  		return fmt.Errorf("failed to create endpoint as %s already exists", url)
    89  	}
    90  
    91  	endpoints[endpoint] = handler
    92  	r.routes[base] = endpoints
    93  
    94  	// Name routes based on their URL for easy retrieval in the future
    95  	route := r.router.Handle(url, handler)
    96  	if route == nil {
    97  		return fmt.Errorf("failed to create new route for %s", url)
    98  	}
    99  	route.Name(url)
   100  
   101  	var err error
   102  	if aliases, exists := r.aliases[base]; exists {
   103  		for _, alias := range aliases {
   104  			if innerErr := r.forceAddRouter(alias, endpoint, handler); err == nil {
   105  				err = innerErr
   106  			}
   107  		}
   108  	}
   109  	return err
   110  }
   111  
   112  func (r *router) AddAlias(base string, aliases ...string) error {
   113  	r.lock.Lock()
   114  	defer r.lock.Unlock()
   115  	r.routeLock.Lock()
   116  	defer r.routeLock.Unlock()
   117  
   118  	for _, alias := range aliases {
   119  		if r.reservedRoutes.Contains(alias) {
   120  			return fmt.Errorf("%w: %s", errAlreadyReserved, alias)
   121  		}
   122  	}
   123  
   124  	for _, alias := range aliases {
   125  		r.reservedRoutes.Add(alias)
   126  	}
   127  
   128  	r.aliases[base] = append(r.aliases[base], aliases...)
   129  
   130  	var err error
   131  	if endpoints, exists := r.routes[base]; exists {
   132  		for endpoint, handler := range endpoints {
   133  			for _, alias := range aliases {
   134  				if innerErr := r.forceAddRouter(alias, endpoint, handler); err == nil {
   135  					err = innerErr
   136  				}
   137  			}
   138  		}
   139  	}
   140  	return err
   141  }