github.com/yankunsam/loki/v2@v2.6.3-0.20220817130409-389df5235c27/pkg/querier/queryrange/downstreamer.go (about)

     1  package queryrange
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  
     7  	"github.com/go-kit/log/level"
     8  	"github.com/grafana/dskit/tenant"
     9  	"github.com/prometheus/prometheus/model/labels"
    10  	"github.com/prometheus/prometheus/promql"
    11  	"github.com/prometheus/prometheus/promql/parser"
    12  
    13  	"github.com/grafana/loki/pkg/loghttp"
    14  	"github.com/grafana/loki/pkg/logql"
    15  	"github.com/grafana/loki/pkg/logqlmodel"
    16  	"github.com/grafana/loki/pkg/querier/queryrange/queryrangebase"
    17  	"github.com/grafana/loki/pkg/util/spanlogger"
    18  )
    19  
    20  const (
    21  	DefaultDownstreamConcurrency = 128
    22  )
    23  
    24  type DownstreamHandler struct {
    25  	limits Limits
    26  	next   queryrangebase.Handler
    27  }
    28  
    29  func ParamsToLokiRequest(params logql.Params, shards logql.Shards) queryrangebase.Request {
    30  	if params.Start().Equal(params.End()) {
    31  		return &LokiInstantRequest{
    32  			Query:     params.Query(),
    33  			Limit:     params.Limit(),
    34  			TimeTs:    params.Start(),
    35  			Direction: params.Direction(),
    36  			Path:      "/loki/api/v1/query", // TODO(owen-d): make this derivable
    37  			Shards:    shards.Encode(),
    38  		}
    39  	}
    40  	return &LokiRequest{
    41  		Query:     params.Query(),
    42  		Limit:     params.Limit(),
    43  		Step:      params.Step().Milliseconds(),
    44  		Interval:  params.Interval().Milliseconds(),
    45  		StartTs:   params.Start(),
    46  		EndTs:     params.End(),
    47  		Direction: params.Direction(),
    48  		Path:      "/loki/api/v1/query_range", // TODO(owen-d): make this derivable
    49  		Shards:    shards.Encode(),
    50  	}
    51  }
    52  
    53  // Note: After the introduction of the LimitedRoundTripper,
    54  // bounding concurrency in the downstreamer is mostly redundant
    55  // The reason we don't remove it is to prevent malicious queries
    56  // from creating an unreasonably large number of goroutines, such as
    57  // the case of a query like `a / a / a / a / a ..etc`, which could try
    58  // to shard each leg, quickly dispatching an unreasonable number of goroutines.
    59  // In the future, it's probably better to replace this with a channel based API
    60  // so we don't have to do all this ugly edge case handling/accounting
    61  func (h DownstreamHandler) Downstreamer(ctx context.Context) logql.Downstreamer {
    62  	p := DefaultDownstreamConcurrency
    63  
    64  	// We may increase parallelism above the default,
    65  	// ensure we don't end up bottlenecking here.
    66  	if user, err := tenant.TenantID(ctx); err == nil {
    67  		if x := h.limits.MaxQueryParallelism(user); x > 0 {
    68  			p = x
    69  		}
    70  	}
    71  
    72  	locks := make(chan struct{}, p)
    73  	for i := 0; i < p; i++ {
    74  		locks <- struct{}{}
    75  	}
    76  	return &instance{
    77  		parallelism: p,
    78  		locks:       locks,
    79  		handler:     h.next,
    80  	}
    81  }
    82  
    83  // instance is an intermediate struct for controlling concurrency across a single query
    84  type instance struct {
    85  	parallelism int
    86  	locks       chan struct{}
    87  	handler     queryrangebase.Handler
    88  }
    89  
    90  func (in instance) Downstream(ctx context.Context, queries []logql.DownstreamQuery) ([]logqlmodel.Result, error) {
    91  	return in.For(ctx, queries, func(qry logql.DownstreamQuery) (logqlmodel.Result, error) {
    92  		req := ParamsToLokiRequest(qry.Params, qry.Shards).WithQuery(qry.Expr.String())
    93  		logger, ctx := spanlogger.New(ctx, "DownstreamHandler.instance")
    94  		defer logger.Finish()
    95  		level.Debug(logger).Log("shards", fmt.Sprintf("%+v", qry.Shards), "query", req.GetQuery(), "step", req.GetStep())
    96  
    97  		res, err := in.handler.Do(ctx, req)
    98  		if err != nil {
    99  			return logqlmodel.Result{}, err
   100  		}
   101  		return ResponseToResult(res)
   102  	})
   103  }
   104  
   105  // For runs a function against a list of queries, collecting the results or returning an error. The indices are preserved such that input[i] maps to output[i].
   106  func (in instance) For(
   107  	ctx context.Context,
   108  	queries []logql.DownstreamQuery,
   109  	fn func(logql.DownstreamQuery) (logqlmodel.Result, error),
   110  ) ([]logqlmodel.Result, error) {
   111  	type resp struct {
   112  		i   int
   113  		res logqlmodel.Result
   114  		err error
   115  	}
   116  
   117  	ctx, cancel := context.WithCancel(ctx)
   118  	defer cancel()
   119  	ch := make(chan resp)
   120  
   121  	// Make one goroutine to dispatch the other goroutines, bounded by instance parallelism
   122  	go func() {
   123  		for i := 0; i < len(queries); i++ {
   124  			select {
   125  			case <-ctx.Done():
   126  				break
   127  			case <-in.locks:
   128  				go func(i int) {
   129  					// release lock back into pool
   130  					defer func() {
   131  						in.locks <- struct{}{}
   132  					}()
   133  
   134  					res, err := fn(queries[i])
   135  					response := resp{
   136  						i:   i,
   137  						res: res,
   138  						err: err,
   139  					}
   140  
   141  					// Feed the result into the channel unless the work has completed.
   142  					select {
   143  					case <-ctx.Done():
   144  					case ch <- response:
   145  					}
   146  				}(i)
   147  			}
   148  		}
   149  	}()
   150  
   151  	results := make([]logqlmodel.Result, len(queries))
   152  	for i := 0; i < len(queries); i++ {
   153  		select {
   154  		case <-ctx.Done():
   155  			return nil, ctx.Err()
   156  		case resp := <-ch:
   157  			if resp.err != nil {
   158  				return nil, resp.err
   159  			}
   160  			results[resp.i] = resp.res
   161  		}
   162  	}
   163  	return results, nil
   164  }
   165  
   166  // convert to matrix
   167  func sampleStreamToMatrix(streams []queryrangebase.SampleStream) parser.Value {
   168  	xs := make(promql.Matrix, 0, len(streams))
   169  	for _, stream := range streams {
   170  		x := promql.Series{}
   171  		x.Metric = make(labels.Labels, 0, len(stream.Labels))
   172  		for _, l := range stream.Labels {
   173  			x.Metric = append(x.Metric, labels.Label(l))
   174  		}
   175  
   176  		x.Points = make([]promql.Point, 0, len(stream.Samples))
   177  		for _, sample := range stream.Samples {
   178  			x.Points = append(x.Points, promql.Point{
   179  				T: sample.TimestampMs,
   180  				V: sample.Value,
   181  			})
   182  		}
   183  
   184  		xs = append(xs, x)
   185  	}
   186  	return xs
   187  }
   188  
   189  func sampleStreamToVector(streams []queryrangebase.SampleStream) parser.Value {
   190  	xs := make(promql.Vector, 0, len(streams))
   191  	for _, stream := range streams {
   192  		x := promql.Sample{}
   193  		x.Metric = make(labels.Labels, 0, len(stream.Labels))
   194  		for _, l := range stream.Labels {
   195  			x.Metric = append(x.Metric, labels.Label(l))
   196  		}
   197  
   198  		x.Point = promql.Point{
   199  			T: stream.Samples[0].TimestampMs,
   200  			V: stream.Samples[0].Value,
   201  		}
   202  
   203  		xs = append(xs, x)
   204  	}
   205  	return xs
   206  }
   207  
   208  func ResponseToResult(resp queryrangebase.Response) (logqlmodel.Result, error) {
   209  	switch r := resp.(type) {
   210  	case *LokiResponse:
   211  		if r.Error != "" {
   212  			return logqlmodel.Result{}, fmt.Errorf("%s: %s", r.ErrorType, r.Error)
   213  		}
   214  
   215  		streams := make(logqlmodel.Streams, 0, len(r.Data.Result))
   216  
   217  		for _, stream := range r.Data.Result {
   218  			streams = append(streams, stream)
   219  		}
   220  
   221  		return logqlmodel.Result{
   222  			Statistics: r.Statistics,
   223  			Data:       streams,
   224  		}, nil
   225  
   226  	case *LokiPromResponse:
   227  		if r.Response.Error != "" {
   228  			return logqlmodel.Result{}, fmt.Errorf("%s: %s", r.Response.ErrorType, r.Response.Error)
   229  		}
   230  		if r.Response.Data.ResultType == loghttp.ResultTypeVector {
   231  			return logqlmodel.Result{
   232  				Statistics: r.Statistics,
   233  				Data:       sampleStreamToVector(r.Response.Data.Result),
   234  			}, nil
   235  		}
   236  		return logqlmodel.Result{
   237  			Statistics: r.Statistics,
   238  			Data:       sampleStreamToMatrix(r.Response.Data.Result),
   239  		}, nil
   240  
   241  	default:
   242  		return logqlmodel.Result{}, fmt.Errorf("cannot decode (%T)", resp)
   243  	}
   244  }