github.com/grafana/pyroscope@v1.18.0/pkg/util/interceptor.go (about)

     1  package util
     2  
     3  import (
     4  	"context"
     5  	"time"
     6  
     7  	"connectrpc.com/connect"
     8  	"github.com/go-kit/log"
     9  	"github.com/go-kit/log/level"
    10  	"github.com/grafana/dskit/tracing"
    11  
    12  	"github.com/grafana/pyroscope/pkg/tenant"
    13  )
    14  
    15  type timeoutInterceptor struct {
    16  	timeout time.Duration
    17  }
    18  
    19  // WithTimeout returns a new timeout interceptor.
    20  func WithTimeout(timeout time.Duration) connect.Interceptor {
    21  	return timeoutInterceptor{timeout: timeout}
    22  }
    23  
    24  func (s timeoutInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
    25  	return func(ctx context.Context, ar connect.AnyRequest) (connect.AnyResponse, error) {
    26  		ctx, cancel := context.WithTimeout(ctx, s.timeout)
    27  		defer cancel()
    28  		return next(ctx, ar)
    29  	}
    30  }
    31  
    32  func (s timeoutInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
    33  	return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn {
    34  		ctx, cancel := context.WithTimeout(ctx, s.timeout)
    35  		defer cancel()
    36  		return next(ctx, spec)
    37  	}
    38  }
    39  
    40  func (s timeoutInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
    41  	return func(ctx context.Context, shc connect.StreamingHandlerConn) error {
    42  		ctx, cancel := context.WithTimeout(ctx, s.timeout)
    43  		defer cancel()
    44  		return next(ctx, shc)
    45  	}
    46  }
    47  
    48  // NewLogInterceptor logs the request parameters.
    49  // It logs all kinds of requests.
    50  func NewLogInterceptor(logger log.Logger) connect.UnaryInterceptorFunc {
    51  	interceptor := func(next connect.UnaryFunc) connect.UnaryFunc {
    52  		return func(
    53  			ctx context.Context,
    54  			req connect.AnyRequest,
    55  		) (connect.AnyResponse, error) {
    56  			begin := time.Now()
    57  			tenantID, err := tenant.ExtractTenantIDFromContext(ctx)
    58  			if err != nil {
    59  				tenantID = "anonymous"
    60  			}
    61  			traceID, ok := tracing.ExtractTraceID(ctx)
    62  			if !ok {
    63  				traceID = "unknown"
    64  			}
    65  			defer func() {
    66  				level.Info(logger).Log(
    67  					"msg", "request parameters",
    68  					"route", req.Spec().Procedure,
    69  					"tenant", tenantID,
    70  					"traceID", traceID,
    71  					"duration", time.Since(begin),
    72  				)
    73  			}()
    74  
    75  			return next(ctx, req)
    76  		}
    77  	}
    78  	return interceptor
    79  }