github.com/ncruces/go-sqlite3@v0.15.1-0.20240520133447-53eef1510ff0/ext/stats/stats.go (about)

     1  // Package stats provides aggregate functions for statistics.
     2  //
     3  // Provided functions:
     4  //   - stddev_pop: population standard deviation
     5  //   - stddev_samp: sample standard deviation
     6  //   - var_pop: population variance
     7  //   - var_samp: sample variance
     8  //   - covar_pop: population covariance
     9  //   - covar_samp: sample covariance
    10  //   - corr: correlation coefficient
    11  //   - regr_r2: correlation coefficient squared
    12  //   - regr_avgx: average of the independent variable
    13  //   - regr_avgy: average of the dependent variable
    14  //   - regr_sxx: sum of the squares of the independent variable
    15  //   - regr_syy: sum of the squares of the dependent variable
    16  //   - regr_sxy: sum of the products of each pair of variables
    17  //   - regr_count: count non-null pairs of variables
    18  //   - regr_slope: slope of the least-squares-fit linear equation
    19  //   - regr_intercept: y-intercept of the least-squares-fit linear equation
    20  //   - regr_json: all regr stats in a JSON object
    21  //
    22  // These join the [Built-in Aggregate Functions]:
    23  //   - count: count rows/values
    24  //   - sum: sum values
    25  //   - avg: average value
    26  //   - min: minimum value
    27  //   - max: maximum value
    28  //
    29  // See: [ANSI SQL Aggregate Functions]
    30  //
    31  // [Built-in Aggregate Functions]: https://sqlite.org/lang_aggfunc.html
    32  // [ANSI SQL Aggregate Functions]: https://www.oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html
    33  package stats
    34  
    35  import "github.com/ncruces/go-sqlite3"
    36  
    37  // Register registers statistics functions.
    38  func Register(db *sqlite3.Conn) {
    39  	flags := sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
    40  	db.CreateWindowFunction("var_pop", 1, flags, newVariance(var_pop))
    41  	db.CreateWindowFunction("var_samp", 1, flags, newVariance(var_samp))
    42  	db.CreateWindowFunction("stddev_pop", 1, flags, newVariance(stddev_pop))
    43  	db.CreateWindowFunction("stddev_samp", 1, flags, newVariance(stddev_samp))
    44  	db.CreateWindowFunction("covar_pop", 2, flags, newCovariance(var_pop))
    45  	db.CreateWindowFunction("covar_samp", 2, flags, newCovariance(var_samp))
    46  	db.CreateWindowFunction("corr", 2, flags, newCovariance(corr))
    47  	db.CreateWindowFunction("regr_r2", 2, flags, newCovariance(regr_r2))
    48  	db.CreateWindowFunction("regr_sxx", 2, flags, newCovariance(regr_sxx))
    49  	db.CreateWindowFunction("regr_syy", 2, flags, newCovariance(regr_syy))
    50  	db.CreateWindowFunction("regr_sxy", 2, flags, newCovariance(regr_sxy))
    51  	db.CreateWindowFunction("regr_avgx", 2, flags, newCovariance(regr_avgx))
    52  	db.CreateWindowFunction("regr_avgy", 2, flags, newCovariance(regr_avgy))
    53  	db.CreateWindowFunction("regr_slope", 2, flags, newCovariance(regr_slope))
    54  	db.CreateWindowFunction("regr_intercept", 2, flags, newCovariance(regr_intercept))
    55  	db.CreateWindowFunction("regr_count", 2, flags, newCovariance(regr_count))
    56  	db.CreateWindowFunction("regr_json", 2, flags, newCovariance(regr_json))
    57  }
    58  
// Kinds select which statistic an aggregate reports from its Value method.
// The numeric values come from iota, so the order of this list matters only
// in that each name must stay distinct.
const (
	// var_pop and var_samp are shared by both aggregates: variance maps them
	// to var_pop/var_samp, covariance maps them to covar_pop/covar_samp.
	var_pop = iota
	var_samp
	// stddev_pop and stddev_samp are handled by variance only.
	stddev_pop
	stddev_samp
	// The remaining kinds are handled by covariance only.
	corr
	regr_r2
	regr_sxx
	regr_syy
	regr_sxy
	regr_avgx
	regr_avgy
	regr_slope
	regr_intercept
	regr_count // reported as an integer rather than a float
	regr_json  // reported as text rather than a float
)
    76  
    77  func newVariance(kind int) func() sqlite3.AggregateFunction {
    78  	return func() sqlite3.AggregateFunction { return &variance{kind: kind} }
    79  }
    80  
// variance is the aggregate behind the one-argument functions
// (var_pop, var_samp, stddev_pop, stddev_samp).
type variance struct {
	// kind selects which statistic Value reports.
	kind int
	// welford is declared elsewhere in the package; its promoted methods
	// (enqueue, dequeue and the statistic getters) are what Step, Inverse
	// and Value call.
	welford
}
    85  
    86  func (fn *variance) Value(ctx sqlite3.Context) {
    87  	var r float64
    88  	switch fn.kind {
    89  	case var_pop:
    90  		r = fn.var_pop()
    91  	case var_samp:
    92  		r = fn.var_samp()
    93  	case stddev_pop:
    94  		r = fn.stddev_pop()
    95  	case stddev_samp:
    96  		r = fn.stddev_samp()
    97  	}
    98  	ctx.ResultFloat(r)
    99  }
   100  
   101  func (fn *variance) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
   102  	if a := arg[0]; a.NumericType() != sqlite3.NULL {
   103  		fn.enqueue(a.Float())
   104  	}
   105  }
   106  
   107  func (fn *variance) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
   108  	if a := arg[0]; a.NumericType() != sqlite3.NULL {
   109  		fn.dequeue(a.Float())
   110  	}
   111  }
   112  
   113  func newCovariance(kind int) func() sqlite3.AggregateFunction {
   114  	return func() sqlite3.AggregateFunction { return &covariance{kind: kind} }
   115  }
   116  
// covariance is the aggregate behind the two-argument functions
// (covar_pop, covar_samp, corr and the regr_* family).
type covariance struct {
	// kind selects which statistic Value reports.
	kind int
	// welford2 is declared elsewhere in the package; its promoted methods
	// (enqueue, dequeue and the statistic getters) are what Step, Inverse
	// and Value call.
	welford2
}
   121  
   122  func (fn *covariance) Value(ctx sqlite3.Context) {
   123  	var r float64
   124  	switch fn.kind {
   125  	case var_pop:
   126  		r = fn.covar_pop()
   127  	case var_samp:
   128  		r = fn.covar_samp()
   129  	case corr:
   130  		r = fn.correlation()
   131  	case regr_r2:
   132  		r = fn.regr_r2()
   133  	case regr_sxx:
   134  		r = fn.regr_sxx()
   135  	case regr_syy:
   136  		r = fn.regr_syy()
   137  	case regr_sxy:
   138  		r = fn.regr_sxy()
   139  	case regr_avgx:
   140  		r = fn.regr_avgx()
   141  	case regr_avgy:
   142  		r = fn.regr_avgy()
   143  	case regr_slope:
   144  		r = fn.regr_slope()
   145  	case regr_intercept:
   146  		r = fn.regr_intercept()
   147  	case regr_count:
   148  		ctx.ResultInt64(fn.regr_count())
   149  		return
   150  	case regr_json:
   151  		ctx.ResultText(fn.regr_json())
   152  		return
   153  	}
   154  	ctx.ResultFloat(r)
   155  }
   156  
   157  func (fn *covariance) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
   158  	a, b := arg[0], arg[1]
   159  	if a.NumericType() != sqlite3.NULL && b.NumericType() != sqlite3.NULL {
   160  		fn.enqueue(a.Float(), b.Float())
   161  	}
   162  }
   163  
   164  func (fn *covariance) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
   165  	a, b := arg[0], arg[1]
   166  	if a.NumericType() != sqlite3.NULL && b.NumericType() != sqlite3.NULL {
   167  		fn.dequeue(a.Float(), b.Float())
   168  	}
   169  }