github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/telemetry/reporter.go (about)

     1  // Package telemetry implements a client for reporting telemetry data used to
     2  // prioritize development of SpiceDB.
     3  //
     4  // For more information, see:
     5  // https://github.com/authzed/spicedb/blob/main/TELEMETRY.md
     6  package telemetry
     7  
import (
	"bytes"
	"context"
	"crypto/tls"
	"fmt"
	"io"
	"math/rand"
	"net/http"
	"net/url"
	"sort"
	"time"

	prompb "buf.build/gen/go/prometheus/prometheus/protocolbuffers/go"
	"github.com/cenkalti/backoff/v4"
	"github.com/gogo/protobuf/proto"
	"github.com/golang/snappy"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/common/expfmt"
	"github.com/prometheus/common/model"

	log "github.com/authzed/spicedb/internal/logging"
	"github.com/authzed/spicedb/pkg/x509util"
)
    30  
    31  const (
    32  	// DefaultEndpoint is the endpoint to which telemetry will report if none
    33  	// other is specified.
    34  	DefaultEndpoint = "https://telemetry.authzed.com"
    35  
    36  	// DefaultInterval is the default amount of time to wait between telemetry
    37  	// reports.
    38  	DefaultInterval = 1 * time.Hour
    39  
    40  	// MaxElapsedTimeBetweenReports is the maximum amount of time that the
    41  	// telemetry reporter will attempt to write to the telemetry endpoint
    42  	// before terminating the reporter.
    43  	MaxElapsedTimeBetweenReports = 168 * time.Hour
    44  
    45  	// MinimumAllowedInterval is the minimum amount of time one can request
    46  	// between telemetry reports.
    47  	MinimumAllowedInterval = 1 * time.Minute
    48  )
    49  
    50  func writeTimeSeries(ctx context.Context, client *http.Client, endpoint string, ts []*prompb.TimeSeries) error {
    51  	// Reference upstream client:
    52  	// https://github.com/prometheus/prometheus/blob/6555cc68caf8d8f323056e497ae7bb1e32a81667/storage/remote/client.go#L191
    53  	pbBytes, err := proto.Marshal(&prompb.WriteRequest{
    54  		Timeseries: ts,
    55  	})
    56  	if err != nil {
    57  		return fmt.Errorf("failed to marshal Prometheus remote write protobuf: %w", err)
    58  	}
    59  	compressedPB := snappy.Encode(nil, pbBytes)
    60  
    61  	r, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(compressedPB))
    62  	if err != nil {
    63  		return fmt.Errorf("failed to create Prometheus remote write http request: %w", err)
    64  	}
    65  
    66  	r.Header.Add("X-Prometheus-Remote-Write-Version", "0.1.0")
    67  	r.Header.Add("Content-Encoding", "snappy")
    68  	r.Header.Set("Content-Type", "application/x-protobuf")
    69  
    70  	resp, err := client.Do(r)
    71  	if err != nil {
    72  		return fmt.Errorf("failed to send Prometheus remote write: %w", err)
    73  	}
    74  	defer resp.Body.Close()
    75  
    76  	if resp.StatusCode/100 != 2 {
    77  		body, _ := io.ReadAll(resp.Body)
    78  		return fmt.Errorf(
    79  			"unexpected Prometheus remote write response: %d: %s",
    80  			resp.StatusCode,
    81  			string(body),
    82  		)
    83  	}
    84  
    85  	return nil
    86  }
    87  
    88  func discoverTimeseries(registry *prometheus.Registry) (allTS []*prompb.TimeSeries, err error) {
    89  	metricFams, err := registry.Gather()
    90  	if err != nil {
    91  		return nil, fmt.Errorf("failed to gather telemetry metrics: %w", err)
    92  	}
    93  
    94  	defaultTimestamp := model.Time(time.Now().UnixNano() / int64(time.Millisecond))
    95  	sampleVector, err := expfmt.ExtractSamples(&expfmt.DecodeOptions{
    96  		Timestamp: defaultTimestamp,
    97  	}, metricFams...)
    98  	if err != nil {
    99  		return nil, fmt.Errorf("unable to extract sample from metrics families: %w", err)
   100  	}
   101  
   102  	for _, sample := range sampleVector {
   103  		allTS = append(allTS, &prompb.TimeSeries{
   104  			Labels: convertLabels(sample.Metric),
   105  			Samples: []*prompb.Sample{{
   106  				Value:     float64(sample.Value),
   107  				Timestamp: int64(sample.Timestamp),
   108  			}},
   109  		})
   110  	}
   111  
   112  	return
   113  }
   114  
   115  func discoverAndWriteMetrics(
   116  	ctx context.Context,
   117  	registry *prometheus.Registry,
   118  	client *http.Client,
   119  	endpoint string,
   120  ) error {
   121  	ts, err := discoverTimeseries(registry)
   122  	if err != nil {
   123  		return err
   124  	}
   125  
   126  	return writeTimeSeries(ctx, client, endpoint, ts)
   127  }
   128  
// Reporter is the function signature shared by the telemetry reporters in
// this package. It is invoked with the context that bounds its work and
// returns a non-nil error if reporting had to be abandoned.
type Reporter func(ctx context.Context) error
   130  
   131  // RemoteReporter creates a telemetry reporter with the specified parameters, or errors
   132  // if the configuration was invalid.
   133  func RemoteReporter(
   134  	registry *prometheus.Registry,
   135  	endpoint string,
   136  	caOverridePath string,
   137  	interval time.Duration,
   138  ) (Reporter, error) {
   139  	if _, err := url.Parse(endpoint); err != nil {
   140  		return nil, fmt.Errorf("invalid telemetry endpoint: %w", err)
   141  	}
   142  	if interval < MinimumAllowedInterval {
   143  		return nil, fmt.Errorf("invalid telemetry reporting interval: %s < %s", interval, MinimumAllowedInterval)
   144  	}
   145  	if endpoint == DefaultEndpoint && interval != DefaultInterval {
   146  		return nil, fmt.Errorf("cannot change the telemetry reporting interval for the default endpoint")
   147  	}
   148  
   149  	client := &http.Client{}
   150  	if caOverridePath != "" {
   151  		pool, err := x509util.CustomCertPool(caOverridePath)
   152  		if err != nil {
   153  			return nil, fmt.Errorf("invalid custom cert pool path `%s`: %w", caOverridePath, err)
   154  		}
   155  
   156  		t := &http.Transport{
   157  			TLSClientConfig: &tls.Config{
   158  				RootCAs:    pool,
   159  				MinVersion: tls.VersionTLS12,
   160  			},
   161  		}
   162  
   163  		client.Transport = t
   164  	}
   165  
   166  	return func(ctx context.Context) error {
   167  		// nolint:gosec
   168  		// G404 use of non cryptographically secure random number generator is not a security concern here,
   169  		// as this is only used to smear the startup delay out over 10% of the reporting interval
   170  		startupDelay := time.Duration(rand.Int63n(int64(interval.Seconds()/10))) * time.Second
   171  
   172  		log.Ctx(ctx).Info().
   173  			Stringer("interval", interval).
   174  			Str("endpoint", endpoint).
   175  			Stringer("next", startupDelay).
   176  			Msg("telemetry reporter scheduled")
   177  
   178  		backoffInterval := backoff.NewExponentialBackOff()
   179  		backoffInterval.InitialInterval = interval
   180  		backoffInterval.MaxInterval = MaxElapsedTimeBetweenReports
   181  		backoffInterval.MaxElapsedTime = 0
   182  
   183  		// Must reset the backoff object after changing parameters
   184  		backoffInterval.Reset()
   185  
   186  		ticker := time.After(startupDelay)
   187  
   188  		for {
   189  			select {
   190  			case <-ticker:
   191  				nextPush := backoffInterval.InitialInterval
   192  				if err := discoverAndWriteMetrics(ctx, registry, client, endpoint); err != nil {
   193  					nextPush = backoffInterval.NextBackOff()
   194  					log.Ctx(ctx).Warn().
   195  						Err(err).
   196  						Str("endpoint", endpoint).
   197  						Stringer("next", nextPush).
   198  						Msg("failed to push telemetry metric")
   199  				} else {
   200  					log.Ctx(ctx).Debug().
   201  						Str("endpoint", endpoint).
   202  						Stringer("next", nextPush).
   203  						Msg("reported telemetry")
   204  					backoffInterval.Reset()
   205  				}
   206  				if nextPush == backoff.Stop {
   207  					return fmt.Errorf(
   208  						"exceeded maximum time between successful reports of %s",
   209  						MaxElapsedTimeBetweenReports,
   210  					)
   211  				}
   212  				ticker = time.After(nextPush)
   213  
   214  			case <-ctx.Done():
   215  				return nil
   216  			}
   217  		}
   218  	}, nil
   219  }
   220  
   221  func DisabledReporter(ctx context.Context) error {
   222  	log.Ctx(ctx).Info().Msg("telemetry disabled")
   223  	return nil
   224  }
   225  
   226  func SilentlyDisabledReporter(_ context.Context) error {
   227  	return nil
   228  }
   229  
   230  func convertLabels(labels model.Metric) []*prompb.Label {
   231  	out := make([]*prompb.Label, 0, len(labels))
   232  	for name, value := range labels {
   233  		out = append(out, &prompb.Label{
   234  			Name:  string(name),
   235  			Value: string(value),
   236  		})
   237  	}
   238  	return out
   239  }