
     1  // Copyright (c) The Cortex Authors.
     2  // Licensed under the Apache License 2.0.
     4  // Copyright 2016 The Prometheus Authors
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
// http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  //
    17  // Mostly lifted from prometheus/web/api/v1/api.go.
    19  package queryrange
    21  import (
    22  	"context"
    23  	"io"
    24  	"io/ioutil"
    25  	"net/http"
    26  	"strings"
    27  	"time"
    29  	""
    30  	""
    31  	""
    32  	""
    33  	""
    34  	""
    35  	""
    36  	""
    37  	""
    39  	""
    40  	""
    41  	""
    42  	""
    43  	""
    44  )
// day is a duration of 24 hours.
const day = 24 * time.Hour
    48  var (
    49  	// PassthroughMiddleware is a noop middleware
    50  	PassthroughMiddleware = MiddlewareFunc(func(next Handler) Handler {
    51  		return next
    52  	})
    54  	errInvalidMinShardingLookback = errors.New("a non-zero value is required for querier.query-ingesters-within when -querier.parallelise-shardable-queries is enabled")
    55  )
// Config for query_range middleware chain.
//
// NewTripperware reads these fields to decide which middlewares are added to
// the chain; Validate checks the combinations that are required to be
// consistent.
type Config struct {
	// SplitQueriesByInterval, when non-zero, makes NewTripperware add the
	// split-by-interval middleware using this value as the interval. Validate
	// requires it to be positive whenever CacheResults is enabled.
	SplitQueriesByInterval time.Duration `yaml:"split_queries_by_interval"`
	// AlignQueriesWithStep, when true, makes NewTripperware add
	// StepAlignMiddleware to the chain.
	AlignQueriesWithStep   bool          `yaml:"align_queries_with_step"`
	// ResultsCacheConfig is embedded; it is validated by Validate and handed to
	// NewResultsCacheMiddleware only when CacheResults is enabled.
	ResultsCacheConfig     `yaml:"results_cache"`
	// CacheResults enables the results cache middleware.
	CacheResults           bool `yaml:"cache_results"`
	// MaxRetries, when greater than zero, makes NewTripperware add the retry
	// middleware configured with this value.
	MaxRetries             int  `yaml:"max_retries"`
	// ShardedQueries is not read by the code visible here.
	// NOTE(review): presumably toggles parallelisation of shardable queries
	// (see the yaml key) — confirm against its consumers.
	ShardedQueries         bool `yaml:"parallelise_shardable_queries"`
	// List of headers which query_range middleware chain would forward to downstream querier.
	ForwardHeaders flagext.StringSlice `yaml:"forward_headers_list"`
}
    69  // Validate validates the config.
    70  func (cfg *Config) Validate(qCfg querier.Config) error {
    71  	if cfg.CacheResults {
    72  		if cfg.SplitQueriesByInterval <= 0 {
    73  			return errors.New("querier.cache-results may only be enabled in conjunction with querier.split-queries-by-interval. Please set the latter")
    74  		}
    75  		if err := cfg.ResultsCacheConfig.Validate(qCfg); err != nil {
    76  			return errors.Wrap(err, "invalid ResultsCache config")
    77  		}
    78  	}
    79  	return nil
    80  }
    82  // HandlerFunc is like http.HandlerFunc, but for Handler.
    83  type HandlerFunc func(context.Context, Request) (Response, error)
    85  // Do implements Handler.
    86  func (q HandlerFunc) Do(ctx context.Context, req Request) (Response, error) {
    87  	return q(ctx, req)
    88  }
    90  // Handler is like http.Handle, but specifically for Prometheus query_range calls.
    91  type Handler interface {
    92  	Do(context.Context, Request) (Response, error)
    93  }
    95  // MiddlewareFunc is like http.HandlerFunc, but for Middleware.
    96  type MiddlewareFunc func(Handler) Handler
    98  // Wrap implements Middleware.
    99  func (q MiddlewareFunc) Wrap(h Handler) Handler {
   100  	return q(h)
   101  }
   103  // Middleware is a higher order Handler.
   104  type Middleware interface {
   105  	Wrap(Handler) Handler
   106  }
   108  // MergeMiddlewares produces a middleware that applies multiple middleware in turn;
   109  // ie Merge(f,g,h).Wrap(handler) == f.Wrap(g.Wrap(h.Wrap(handler)))
   110  func MergeMiddlewares(middleware ...Middleware) Middleware {
   111  	return MiddlewareFunc(func(next Handler) Handler {
   112  		for i := len(middleware) - 1; i >= 0; i-- {
   113  			next = middleware[i].Wrap(next)
   114  		}
   115  		return next
   116  	})
   117  }
   119  // Tripperware is a signature for all http client-side middleware.
   120  type Tripperware func(http.RoundTripper) http.RoundTripper
   122  // RoundTripFunc is to http.RoundTripper what http.HandlerFunc is to http.Handler.
   123  type RoundTripFunc func(*http.Request) (*http.Response, error)
   125  // RoundTrip implements http.RoundTripper.
   126  func (f RoundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) {
   127  	return f(r)
   128  }
   130  // NewTripperware returns a Tripperware configured with middlewares to limit, align, split, retry and cache requests.
   131  func NewTripperware(
   132  	cfg Config,
   133  	log log.Logger,
   134  	limits Limits,
   135  	codec Codec,
   136  	cacheExtractor Extractor,
   137  	engineOpts promql.EngineOpts,
   138  	minShardingLookback time.Duration,
   139  	registerer prometheus.Registerer,
   140  	cacheGenNumberLoader CacheGenNumberLoader,
   141  ) (Tripperware, cache.Cache, error) {
   142  	// Per tenant query metrics.
   143  	queriesPerTenant := promauto.With(registerer).NewCounterVec(prometheus.CounterOpts{
   144  		Name: "cortex_query_frontend_queries_total",
   145  		Help: "Total queries sent per tenant.",
   146  	}, []string{"op", "user"})
   148  	activeUsers := util.NewActiveUsersCleanupWithDefaultValues(func(user string) {
   149  		err := util.DeleteMatchingLabels(queriesPerTenant, map[string]string{"user": user})
   150  		if err != nil {
   151  			level.Warn(log).Log("msg", "failed to remove cortex_query_frontend_queries_total metric for user", "user", user)
   152  		}
   153  	})
   155  	// Metric used to keep track of each middleware execution duration.
   156  	metrics := NewInstrumentMiddlewareMetrics(registerer)
   158  	queryRangeMiddleware := []Middleware{NewLimitsMiddleware(limits)}
   159  	if cfg.AlignQueriesWithStep {
   160  		queryRangeMiddleware = append(queryRangeMiddleware, InstrumentMiddleware("step_align", metrics), StepAlignMiddleware)
   161  	}
   162  	if cfg.SplitQueriesByInterval != 0 {
   163  		staticIntervalFn := func(_ Request) time.Duration { return cfg.SplitQueriesByInterval }
   164  		queryRangeMiddleware = append(queryRangeMiddleware, InstrumentMiddleware("split_by_interval", metrics), SplitByIntervalMiddleware(staticIntervalFn, limits, codec, registerer))
   165  	}
   167  	var c cache.Cache
   168  	if cfg.CacheResults {
   169  		shouldCache := func(r Request) bool {
   170  			return !r.GetCachingOptions().Disabled
   171  		}
   172  		queryCacheMiddleware, cache, err := NewResultsCacheMiddleware(log, cfg.ResultsCacheConfig, constSplitter(cfg.SplitQueriesByInterval), limits, codec, cacheExtractor, cacheGenNumberLoader, shouldCache, registerer)
   173  		if err != nil {
   174  			return nil, nil, err
   175  		}
   176  		c = cache
   177  		queryRangeMiddleware = append(queryRangeMiddleware, InstrumentMiddleware("results_cache", metrics), queryCacheMiddleware)
   178  	}
   180  	if cfg.MaxRetries > 0 {
   181  		queryRangeMiddleware = append(queryRangeMiddleware, InstrumentMiddleware("retry", metrics), NewRetryMiddleware(log, cfg.MaxRetries, NewRetryMiddlewareMetrics(registerer)))
   182  	}
   184  	// Start cleanup. If cleaner stops or fail, we will simply not clean the metrics for inactive users.
   185  	_ = activeUsers.StartAsync(context.Background())
   186  	return func(next http.RoundTripper) http.RoundTripper {
   187  		// Finally, if the user selected any query range middleware, stitch it in.
   188  		if len(queryRangeMiddleware) > 0 {
   189  			queryrange := NewRoundTripper(next, codec, cfg.ForwardHeaders, queryRangeMiddleware...)
   190  			return RoundTripFunc(func(r *http.Request) (*http.Response, error) {
   191  				isQueryRange := strings.HasSuffix(r.URL.Path, "/query_range")
   192  				op := "query"
   193  				if isQueryRange {
   194  					op = "query_range"
   195  				}
   197  				tenantIDs, err := tenant.TenantIDs(r.Context())
   198  				// This should never happen anyways because we have auth middleware before this.
   199  				if err != nil {
   200  					return nil, err
   201  				}
   202  				userStr := tenant.JoinTenantIDs(tenantIDs)
   203  				activeUsers.UpdateUserTimestamp(userStr, time.Now())
   204  				queriesPerTenant.WithLabelValues(op, userStr).Inc()
   206  				if !isQueryRange {
   207  					return next.RoundTrip(r)
   208  				}
   209  				return queryrange.RoundTrip(r)
   210  			})
   211  		}
   212  		return next
   213  	}, c, nil
   214  }
// roundTripper bridges http.RoundTripper and the Handler middleware chain.
// It has both a RoundTrip method (the HTTP-facing entry point, which decodes
// the request and runs it through handler) and a Do method (a Handler
// implementation, used by NewRoundTripper as the innermost link of the chain).
type roundTripper struct {
	// next is the wrapped transport supplied to NewRoundTripper.
	next    http.RoundTripper
	// handler is the merged middleware chain that RoundTrip dispatches to.
	handler Handler
	// codec translates between HTTP requests/responses and Request/Response.
	codec   Codec
	// headers is handed to codec.DecodeRequest when decoding incoming requests.
	headers []string
}
   223  // NewRoundTripper merges a set of middlewares into an handler, then inject it into the `next` roundtripper
   224  // using the codec to translate requests and responses.
   225  func NewRoundTripper(next http.RoundTripper, codec Codec, headers []string, middlewares ...Middleware) http.RoundTripper {
   226  	transport := roundTripper{
   227  		next:    next,
   228  		codec:   codec,
   229  		headers: headers,
   230  	}
   231  	transport.handler = MergeMiddlewares(middlewares...).Wrap(&transport)
   232  	return transport
   233  }
   235  func (q roundTripper) RoundTrip(r *http.Request) (*http.Response, error) {
   237  	// include the headers specified in the roundTripper during decoding the request.
   238  	request, err := q.codec.DecodeRequest(r.Context(), r, q.headers)
   239  	if err != nil {
   240  		return nil, err
   241  	}
   243  	if span := opentracing.SpanFromContext(r.Context()); span != nil {
   244  		request.LogToSpan(span)
   245  	}
   247  	response, err := q.handler.Do(r.Context(), request)
   248  	if err != nil {
   249  		return nil, err
   250  	}
   252  	return q.codec.EncodeResponse(r.Context(), response)
   253  }
   255  // Do implements Handler.
   256  func (q roundTripper) Do(ctx context.Context, r Request) (Response, error) {
   257  	request, err := q.codec.EncodeRequest(ctx, r)
   258  	if err != nil {
   259  		return nil, err
   260  	}
   262  	if err := user.InjectOrgIDIntoHTTPRequest(ctx, request); err != nil {
   263  		return nil, httpgrpc.Errorf(http.StatusBadRequest, err.Error())
   264  	}
   266  	response, err :=
   267  	if err != nil {
   268  		return nil, err
   269  	}
   270  	defer func() {
   271  		io.Copy(ioutil.Discard, io.LimitReader(response.Body, 1024))
   272  		_ = response.Body.Close()
   273  	}()
   275  	return q.codec.DecodeResponse(ctx, response, r)
   276  }