github.com/tailscale/sqlite@v0.0.0-20240515181108-c667cbe57c66/sqlstats/sqlstats.go (about)

     1  // Package sqlstats implements an SQLite Tracer that collects query stats.
     2  package sqlstats
     3  
     4  import (
     5  	"context"
     6  	"expvar"
     7  	"fmt"
     8  	"html"
     9  	"net/http"
    10  	"net/url"
    11  	"regexp"
    12  	"sort"
    13  	"strings"
    14  	"sync"
    15  	"sync/atomic"
    16  	"time"
    17  
    18  	"github.com/tailscale/sqlite/sqliteh"
    19  )
    20  
    21  // Tracer implements sqlite.Tracer and collects query stats.
    22  //
    23  // To use, pass the tracer object to sqlite.Connector, then start a debug
    24  // web server with http.HandlerFunc(sqlTracer.Handle).
    25  type Tracer struct {
    26  	TxCount        *expvar.Map
    27  	TxCommit       *expvar.Map
    28  	TxCommitError  *expvar.Map
    29  	TxRollback     *expvar.Map
    30  	TxTotalSeconds *expvar.Map
    31  
    32  	curTxs sync.Map // TraceConnID -> *connStats
    33  
    34  	// Once a query has been seen once, only the read lock
    35  	// is required to update stats.
    36  	//
    37  	// TODO(crawshaw): assuming queries is effectively read-only
    38  	// in the steady state, a sync.Map would be a faster object
    39  	// here.
    40  	mu      sync.RWMutex
    41  	queries map[string]*QueryStats // normalized query -> stats
    42  }
    43  
    44  // Reset resets the state of t to its initial conditions.
    45  func (t *Tracer) Reset() {
    46  	if t.TxCount != nil {
    47  		t.TxCount.Init()
    48  	}
    49  	if t.TxCommit != nil {
    50  		t.TxCommit.Init()
    51  	}
    52  	if t.TxCommitError != nil {
    53  		t.TxCommitError.Init()
    54  	}
    55  	if t.TxRollback != nil {
    56  		t.TxRollback.Init()
    57  	}
    58  	if t.TxTotalSeconds != nil {
    59  		t.TxTotalSeconds.Init()
    60  	}
    61  	t.curTxs.Range(func(key, value any) bool {
    62  		t.curTxs.Delete(key)
    63  		return true
    64  	})
    65  
    66  	t.mu.Lock()
    67  	defer t.mu.Unlock()
    68  	t.queries = nil
    69  }
    70  
    71  type connStats struct {
    72  	mu       sync.Mutex
    73  	why      string
    74  	at       time.Time
    75  	readOnly bool
    76  }
    77  
    78  func (t *Tracer) done(s *connStats) (why string, readOnly bool) {
    79  	s.mu.Lock()
    80  	why = s.why
    81  	readOnly = s.readOnly
    82  	at := s.at
    83  	s.why = ""
    84  	s.at = time.Time{}
    85  	s.readOnly = false
    86  	s.mu.Unlock()
    87  
    88  	if t.TxTotalSeconds != nil {
    89  		sec := time.Since(at).Seconds()
    90  		t.TxTotalSeconds.AddFloat(why, sec)
    91  		if readOnly {
    92  			t.TxTotalSeconds.AddFloat("read", sec)
    93  		} else {
    94  			t.TxTotalSeconds.AddFloat("write", sec)
    95  		}
    96  	}
    97  	return why, readOnly
    98  }
    99  
   100  // QueryStats is a collection of stats for a given Query.
   101  type QueryStats struct {
   102  	Query string
   103  
   104  	// When inside the queries map all fields must be accessed as atomics.
   105  
   106  	// Count represents the number of times this query has been
   107  	// executed.
   108  	Count int64
   109  
   110  	// Errors represents the number of errors encountered executing
   111  	// this query.
   112  	Errors int64
   113  
   114  	// TotalDuration represents the accumulated time spent executing the query.
   115  	TotalDuration time.Duration
   116  
   117  	// MeanDuration represents the average time spent executing the query.
   118  	MeanDuration time.Duration
   119  
   120  	// TODO lastErr atomic.Value
   121  }
   122  
   123  var rxRemoveInClause = regexp.MustCompile(`(?i)\s+in\s*\((?:\s*\d+\s*(?:,\s*\d+\s*)*)\)`)
   124  
   125  func normalizeQuery(q string) string {
   126  	if strings.Contains(q, " in (") || strings.Contains(q, " IN (") {
   127  		q = rxRemoveInClause.ReplaceAllString(q, " IN (...)")
   128  	}
   129  	return q
   130  }
   131  
   132  func (t *Tracer) queryStats(query string) *QueryStats {
   133  	query = normalizeQuery(query)
   134  
   135  	t.mu.RLock()
   136  	stats := t.queries[query]
   137  	t.mu.RUnlock()
   138  
   139  	if stats != nil {
   140  		return stats
   141  	}
   142  
   143  	t.mu.Lock()
   144  	defer t.mu.Unlock()
   145  	if t.queries == nil {
   146  		t.queries = make(map[string]*QueryStats)
   147  	}
   148  	stats = t.queries[query]
   149  	if stats == nil {
   150  		stats = &QueryStats{Query: query}
   151  		t.queries[query] = stats
   152  	}
   153  	return stats
   154  }
   155  
   156  // Collect returns the list of QueryStats pointers from the
   157  // Tracer.
   158  func (t *Tracer) Collect() (rows []*QueryStats) {
   159  	t.mu.RLock()
   160  	defer t.mu.RUnlock()
   161  
   162  	rows = make([]*QueryStats, 0, len(t.queries))
   163  	for query, s := range t.queries {
   164  		row := QueryStats{
   165  			Query:         query,
   166  			Count:         atomic.LoadInt64(&s.Count),
   167  			Errors:        atomic.LoadInt64(&s.Errors),
   168  			TotalDuration: time.Duration(atomic.LoadInt64((*int64)(&s.TotalDuration))),
   169  		}
   170  
   171  		row.MeanDuration = time.Duration(int64(row.TotalDuration) / row.Count)
   172  		rows = append(rows, &row)
   173  	}
   174  	return rows
   175  }
   176  
   177  func (t *Tracer) Query(
   178  	prepCtx context.Context,
   179  	id sqliteh.TraceConnID,
   180  	query string,
   181  	duration time.Duration,
   182  	err error,
   183  ) {
   184  	stats := t.queryStats(query)
   185  
   186  	atomic.AddInt64(&stats.Count, 1)
   187  	atomic.AddInt64((*int64)(&stats.TotalDuration), int64(duration))
   188  
   189  	if err != nil {
   190  		atomic.AddInt64(&stats.Errors, 1)
   191  	}
   192  }
   193  
   194  func (t *Tracer) connStats(id sqliteh.TraceConnID) *connStats {
   195  	var s *connStats
   196  	v, ok := t.curTxs.Load(id)
   197  	if ok {
   198  		s = v.(*connStats)
   199  	} else {
   200  		s = &connStats{}
   201  		t.curTxs.Store(id, s)
   202  	}
   203  	return s
   204  }
   205  
   206  func (t *Tracer) BeginTx(
   207  	beginCtx context.Context,
   208  	id sqliteh.TraceConnID,
   209  	why string,
   210  	readOnly bool,
   211  	err error,
   212  ) {
   213  	s := t.connStats(id)
   214  
   215  	s.mu.Lock()
   216  	s.why = why
   217  	s.at = time.Now()
   218  	s.readOnly = readOnly
   219  	s.mu.Unlock()
   220  
   221  	if t.TxCount != nil {
   222  		t.TxCount.Add(why, 1)
   223  		if readOnly {
   224  			t.TxCount.Add("read", 1)
   225  		} else {
   226  			t.TxCount.Add("write", 1)
   227  		}
   228  	}
   229  }
   230  
   231  func (t *Tracer) Commit(id sqliteh.TraceConnID, err error) {
   232  	s := t.connStats(id)
   233  	why, readOnly := t.done(s)
   234  	if err == nil {
   235  		if t.TxCommit != nil {
   236  			t.TxCommit.Add(why, 1)
   237  			if readOnly {
   238  				t.TxCommit.Add("read", 1)
   239  			} else {
   240  				t.TxCommit.Add("write", 1)
   241  			}
   242  		}
   243  	} else {
   244  		if t.TxCommitError != nil {
   245  			t.TxCommitError.Add(why, 1)
   246  			if readOnly {
   247  				t.TxCommitError.Add("read", 1)
   248  			} else {
   249  				t.TxCommitError.Add("write", 1)
   250  			}
   251  		}
   252  	}
   253  }
   254  
   255  func (t *Tracer) Rollback(id sqliteh.TraceConnID, err error) {
   256  	s := t.connStats(id)
   257  	why, readOnly := t.done(s)
   258  	if t.TxRollback != nil {
   259  		t.TxRollback.Add(why, 1)
   260  		if readOnly {
   261  			t.TxRollback.Add("read", 1)
   262  		} else {
   263  			t.TxRollback.Add("write", 1)
   264  		}
   265  	}
   266  }
   267  
   268  func (t *Tracer) HandleConns(w http.ResponseWriter, r *http.Request) {
   269  	type txSummary struct {
   270  		name     string
   271  		start    time.Time
   272  		readOnly bool
   273  	}
   274  	var summary []txSummary
   275  
   276  	t.curTxs.Range(func(k, v any) bool {
   277  		s := v.(*connStats)
   278  
   279  		s.mu.Lock()
   280  		summary = append(summary, txSummary{
   281  			name:     s.why,
   282  			start:    s.at,
   283  			readOnly: s.readOnly,
   284  		})
   285  		s.mu.Unlock()
   286  
   287  		return true
   288  	})
   289  
   290  	sort.Slice(summary, func(i, j int) bool { return summary[i].start.Before(summary[j].start) })
   291  
   292  	now := time.Now()
   293  
   294  	w.Header().Set("Content-Type", "text/html; charset=utf-8")
   295  	w.WriteHeader(200)
   296  	fmt.Fprintf(w, "<!DOCTYPE html><html><title>sqlite conns</title><body>\n")
   297  	fmt.Fprintf(w, "<p>outstanding sqlite transactions: %d</p>\n", len(summary))
   298  	fmt.Fprintf(w, "<pre>\n")
   299  	for _, s := range summary {
   300  		rw := ""
   301  		if !s.readOnly {
   302  			rw = " read-write"
   303  		}
   304  		fmt.Fprintf(
   305  			w,
   306  			"\n\t%s (%v)%s",
   307  			html.EscapeString(s.name),
   308  			now.Sub(s.start).Round(time.Millisecond),
   309  			rw,
   310  		)
   311  	}
   312  	fmt.Fprintf(w, "</pre></body></html>\n")
   313  }
   314  
   315  func (t *Tracer) Handle(w http.ResponseWriter, r *http.Request) {
   316  	getArgs, _ := url.ParseQuery(r.URL.RawQuery)
   317  	sortParam := strings.TrimSpace(getArgs.Get("sort"))
   318  	rows := t.Collect()
   319  
   320  	switch sortParam {
   321  	case "", "count":
   322  		sort.Slice(rows, func(i, j int) bool { return rows[i].Count > rows[j].Count })
   323  	case "query":
   324  		sort.Slice(rows, func(i, j int) bool { return rows[i].Query < rows[j].Query })
   325  	case "duration":
   326  		sort.Slice(
   327  			rows,
   328  			func(i, j int) bool { return rows[i].TotalDuration > rows[j].TotalDuration },
   329  		)
   330  	case "errors":
   331  		sort.Slice(rows, func(i, j int) bool { return rows[i].Errors > rows[j].Errors })
   332  	case "mean":
   333  		sort.Slice(rows, func(i, j int) bool { return rows[i].MeanDuration > rows[j].MeanDuration })
   334  	default:
   335  		http.Error(w, fmt.Sprintf("unknown sort: %q", sortParam), 400)
   336  	}
   337  
   338  	w.Header().Set("Content-Type", "text/html; charset=utf-8")
   339  	w.WriteHeader(200)
   340  	fmt.Fprintf(w, `<!DOCTYPE html><html><body>
   341  	<p>Trace of SQLite queries run via the github.com/tailscale/sqlite driver.</p>
   342  	<table border="1">
   343  	<tr>
   344  	<th><a href="?sort=query">Query</a></th>
   345  	<th><a href="?sort=count">Count</a></th>
   346  	<th><a href="?sort=duration">Duration</a></th>
   347  	<th><a href="?sort=mean">Mean</a></th>
   348  	<th><a href="?sort=errors">Errors</a></th>
   349  	</tr>
   350  	`)
   351  	for _, row := range rows {
   352  		fmt.Fprintf(w, "<tr><td>%s</td><td>%d</td><td>%s</td><td>%s</td><td>%d</td></tr>\n",
   353  			row.Query,
   354  			row.Count,
   355  			row.TotalDuration.Round(time.Second),
   356  			row.MeanDuration.Round(time.Millisecond),
   357  			row.Errors,
   358  		)
   359  	}
   360  	fmt.Fprintf(w, "</table></body></html>")
   361  }