github.com/muhammadn/cortex@v1.9.1-0.20220510110439-46bb7000d03d/tools/querytee/proxy.go (about)

     1  package querytee
     2  
     3  import (
     4  	"context"
     5  	"flag"
     6  	"fmt"
     7  	"net"
     8  	"net/http"
     9  	"net/http/httputil"
    10  	"net/url"
    11  	"strconv"
    12  	"strings"
    13  	"sync"
    14  	"time"
    15  
    16  	"github.com/go-kit/log"
    17  	"github.com/go-kit/log/level"
    18  	"github.com/gorilla/mux"
    19  	"github.com/pkg/errors"
    20  	"github.com/prometheus/client_golang/prometheus"
    21  )
    22  
    23  var (
    24  	errMinBackends = errors.New("at least 1 backend is required")
    25  )
    26  
    27  type ProxyConfig struct {
    28  	ServerServicePort              int
    29  	BackendEndpoints               string
    30  	PreferredBackend               string
    31  	BackendReadTimeout             time.Duration
    32  	CompareResponses               bool
    33  	ValueComparisonTolerance       float64
    34  	PassThroughNonRegisteredRoutes bool
    35  }
    36  
    37  func (cfg *ProxyConfig) RegisterFlags(f *flag.FlagSet) {
    38  	f.IntVar(&cfg.ServerServicePort, "server.service-port", 80, "The port where the query-tee service listens to.")
    39  	f.StringVar(&cfg.BackendEndpoints, "backend.endpoints", "", "Comma separated list of backend endpoints to query.")
    40  	f.StringVar(&cfg.PreferredBackend, "backend.preferred", "", "The hostname of the preferred backend when selecting the response to send back to the client. If no preferred backend is configured then the query-tee will send back to the client the first successful response received without waiting for other backends.")
    41  	f.DurationVar(&cfg.BackendReadTimeout, "backend.read-timeout", 90*time.Second, "The timeout when reading the response from a backend.")
    42  	f.BoolVar(&cfg.CompareResponses, "proxy.compare-responses", false, "Compare responses between preferred and secondary endpoints for supported routes.")
    43  	f.Float64Var(&cfg.ValueComparisonTolerance, "proxy.value-comparison-tolerance", 0.000001, "The tolerance to apply when comparing floating point values in the responses. 0 to disable tolerance and require exact match (not recommended).")
    44  	f.BoolVar(&cfg.PassThroughNonRegisteredRoutes, "proxy.passthrough-non-registered-routes", false, "Passthrough requests for non-registered routes to preferred backend.")
    45  }
    46  
    47  type Route struct {
    48  	Path               string
    49  	RouteName          string
    50  	Methods            []string
    51  	ResponseComparator ResponsesComparator
    52  }
    53  
    54  type Proxy struct {
    55  	cfg      ProxyConfig
    56  	backends []*ProxyBackend
    57  	logger   log.Logger
    58  	metrics  *ProxyMetrics
    59  	routes   []Route
    60  
    61  	// The HTTP server used to run the proxy service.
    62  	srv         *http.Server
    63  	srvListener net.Listener
    64  
    65  	// Wait group used to wait until the server has done.
    66  	done sync.WaitGroup
    67  }
    68  
    69  func NewProxy(cfg ProxyConfig, logger log.Logger, routes []Route, registerer prometheus.Registerer) (*Proxy, error) {
    70  	if cfg.CompareResponses && cfg.PreferredBackend == "" {
    71  		return nil, fmt.Errorf("when enabling comparison of results -backend.preferred flag must be set to hostname of preferred backend")
    72  	}
    73  
    74  	if cfg.PassThroughNonRegisteredRoutes && cfg.PreferredBackend == "" {
    75  		return nil, fmt.Errorf("when enabling passthrough for non-registered routes -backend.preferred flag must be set to hostname of backend where those requests needs to be passed")
    76  	}
    77  
    78  	p := &Proxy{
    79  		cfg:     cfg,
    80  		logger:  logger,
    81  		metrics: NewProxyMetrics(registerer),
    82  		routes:  routes,
    83  	}
    84  
    85  	// Parse the backend endpoints (comma separated).
    86  	parts := strings.Split(cfg.BackendEndpoints, ",")
    87  
    88  	for idx, part := range parts {
    89  		// Skip empty ones.
    90  		part = strings.TrimSpace(part)
    91  		if part == "" {
    92  			continue
    93  		}
    94  
    95  		u, err := url.Parse(part)
    96  		if err != nil {
    97  			return nil, errors.Wrapf(err, "invalid backend endpoint %s", part)
    98  		}
    99  
   100  		// The backend name is hardcoded as the backend hostname.
   101  		name := u.Hostname()
   102  		preferred := name == cfg.PreferredBackend
   103  
   104  		// In tests we have the same hostname for all backends, so we also
   105  		// support a numeric preferred backend which is the index in the list
   106  		// of backends.
   107  		if preferredIdx, err := strconv.Atoi(cfg.PreferredBackend); err == nil {
   108  			preferred = preferredIdx == idx
   109  		}
   110  
   111  		p.backends = append(p.backends, NewProxyBackend(name, u, cfg.BackendReadTimeout, preferred))
   112  	}
   113  
   114  	// At least 1 backend is required
   115  	if len(p.backends) < 1 {
   116  		return nil, errMinBackends
   117  	}
   118  
   119  	// If the preferred backend is configured, then it must exists among the actual backends.
   120  	if cfg.PreferredBackend != "" {
   121  		exists := false
   122  		for _, b := range p.backends {
   123  			if b.preferred {
   124  				exists = true
   125  				break
   126  			}
   127  		}
   128  
   129  		if !exists {
   130  			return nil, fmt.Errorf("the preferred backend (hostname) has not been found among the list of configured backends")
   131  		}
   132  	}
   133  
   134  	if cfg.CompareResponses && len(p.backends) != 2 {
   135  		return nil, fmt.Errorf("when enabling comparison of results number of backends should be 2 exactly")
   136  	}
   137  
   138  	// At least 2 backends are suggested
   139  	if len(p.backends) < 2 {
   140  		level.Warn(p.logger).Log("msg", "The proxy is running with only 1 backend. At least 2 backends are required to fulfil the purpose of the proxy and compare results.")
   141  	}
   142  
   143  	return p, nil
   144  }
   145  
   146  func (p *Proxy) Start() error {
   147  	// Setup listener first, so we can fail early if the port is in use.
   148  	listener, err := net.Listen("tcp", fmt.Sprintf(":%d", p.cfg.ServerServicePort))
   149  	if err != nil {
   150  		return err
   151  	}
   152  
   153  	router := mux.NewRouter()
   154  
   155  	// Health check endpoint.
   156  	router.Path("/").Methods("GET").Handler(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
   157  		w.WriteHeader(http.StatusOK)
   158  	}))
   159  
   160  	// register routes
   161  	for _, route := range p.routes {
   162  		var comparator ResponsesComparator
   163  		if p.cfg.CompareResponses {
   164  			comparator = route.ResponseComparator
   165  		}
   166  		router.Path(route.Path).Methods(route.Methods...).Handler(NewProxyEndpoint(p.backends, route.RouteName, p.metrics, p.logger, comparator))
   167  	}
   168  
   169  	if p.cfg.PassThroughNonRegisteredRoutes {
   170  		for _, backend := range p.backends {
   171  			if backend.preferred {
   172  				router.PathPrefix("/").Handler(httputil.NewSingleHostReverseProxy(backend.endpoint))
   173  				break
   174  			}
   175  		}
   176  	}
   177  
   178  	p.srvListener = listener
   179  	p.srv = &http.Server{
   180  		ReadTimeout:  1 * time.Minute,
   181  		WriteTimeout: 2 * time.Minute,
   182  		Handler:      router,
   183  	}
   184  
   185  	// Run in a dedicated goroutine.
   186  	p.done.Add(1)
   187  	go func() {
   188  		defer p.done.Done()
   189  
   190  		if err := p.srv.Serve(p.srvListener); err != nil {
   191  			level.Error(p.logger).Log("msg", "Proxy server failed", "err", err)
   192  		}
   193  	}()
   194  
   195  	level.Info(p.logger).Log("msg", "The proxy is up and running.")
   196  	return nil
   197  }
   198  
   199  func (p *Proxy) Stop() error {
   200  	if p.srv == nil {
   201  		return nil
   202  	}
   203  
   204  	return p.srv.Shutdown(context.Background())
   205  }
   206  
   207  func (p *Proxy) Await() {
   208  	// Wait until terminated.
   209  	p.done.Wait()
   210  }
   211  
   212  func (p *Proxy) Endpoint() string {
   213  	if p.srvListener == nil {
   214  		return ""
   215  	}
   216  
   217  	return p.srvListener.Addr().String()
   218  }