github.com/ncruces/go-sqlite3@v0.15.1-0.20240520133447-53eef1510ff0/ext/stats/stats_test.go (about)

     1  package stats_test
     2  
     3  import (
     4  	"math"
     5  	"testing"
     6  
     7  	"github.com/ncruces/go-sqlite3"
     8  	_ "github.com/ncruces/go-sqlite3/embed"
     9  	"github.com/ncruces/go-sqlite3/ext/stats"
    10  	_ "github.com/ncruces/go-sqlite3/tests/testcfg"
    11  )
    12  
    13  func TestRegister_variance(t *testing.T) {
    14  	t.Parallel()
    15  
    16  	db, err := sqlite3.Open(":memory:")
    17  	if err != nil {
    18  		t.Fatal(err)
    19  	}
    20  	defer db.Close()
    21  
    22  	stats.Register(db)
    23  
    24  	err = db.Exec(`CREATE TABLE data (x)`)
    25  	if err != nil {
    26  		t.Fatal(err)
    27  	}
    28  
    29  	err = db.Exec(`INSERT INTO data (x) VALUES (4), (7.0), ('13'), (NULL), (16)`)
    30  	if err != nil {
    31  		t.Fatal(err)
    32  	}
    33  
    34  	stmt, _, err := db.Prepare(`
    35  		SELECT
    36  			sum(x), avg(x),
    37  			var_samp(x), var_pop(x),
    38  			stddev_samp(x), stddev_pop(x)
    39  		FROM data`)
    40  	if err != nil {
    41  		t.Fatal(err)
    42  	}
    43  	defer stmt.Close()
    44  
    45  	if stmt.Step() {
    46  		if got := stmt.ColumnFloat(0); got != 40 {
    47  			t.Errorf("got %v, want 40", got)
    48  		}
    49  		if got := stmt.ColumnFloat(1); got != 10 {
    50  			t.Errorf("got %v, want 10", got)
    51  		}
    52  		if got := stmt.ColumnFloat(2); got != 30 {
    53  			t.Errorf("got %v, want 30", got)
    54  		}
    55  		if got := stmt.ColumnFloat(3); got != 22.5 {
    56  			t.Errorf("got %v, want 22.5", got)
    57  		}
    58  		if got := stmt.ColumnFloat(4); got != math.Sqrt(30) {
    59  			t.Errorf("got %v, want √30", got)
    60  		}
    61  		if got := stmt.ColumnFloat(5); got != math.Sqrt(22.5) {
    62  			t.Errorf("got %v, want √22.5", got)
    63  		}
    64  	}
    65  
    66  	{
    67  		stmt, _, err := db.Prepare(`SELECT var_samp(x) OVER (ROWS 1 PRECEDING) FROM data`)
    68  		if err != nil {
    69  			t.Fatal(err)
    70  		}
    71  		defer stmt.Close()
    72  
    73  		want := [...]float64{0, 4.5, 18, 0, 0}
    74  		for i := 0; stmt.Step(); i++ {
    75  			if got := stmt.ColumnFloat(0); got != want[i] {
    76  				t.Errorf("got %v, want %v", got, want[i])
    77  			}
    78  			if got := stmt.ColumnType(0); (got == sqlite3.FLOAT) != (want[i] != 0) {
    79  				t.Errorf("got %v, want %v", got, want[i])
    80  			}
    81  		}
    82  	}
    83  }
    84  
    85  func TestRegister_covariance(t *testing.T) {
    86  	t.Parallel()
    87  
    88  	db, err := sqlite3.Open(":memory:")
    89  	if err != nil {
    90  		t.Fatal(err)
    91  	}
    92  	defer db.Close()
    93  
    94  	stats.Register(db)
    95  
    96  	err = db.Exec(`CREATE TABLE data (y, x)`)
    97  	if err != nil {
    98  		t.Fatal(err)
    99  	}
   100  
   101  	err = db.Exec(`INSERT INTO data (y, x) VALUES (3, 70), (5, 80), (2, 60), (7, 90), (4, 75)`)
   102  	if err != nil {
   103  		t.Fatal(err)
   104  	}
   105  
   106  	stmt, _, err := db.Prepare(`SELECT
   107  		corr(y, x), covar_samp(y, x), covar_pop(y, x),
   108  		regr_avgy(y, x), regr_avgx(y, x),
   109  		regr_syy(y, x), regr_sxx(y, x), regr_sxy(y, x),
   110  		regr_slope(y, x), regr_intercept(y, x), regr_r2(y, x),
   111  		regr_count(y, x), regr_json(y, x)
   112  		FROM data`)
   113  	if err != nil {
   114  		t.Fatal(err)
   115  	}
   116  	defer stmt.Close()
   117  
   118  	if stmt.Step() {
   119  		if got := stmt.ColumnFloat(0); got != 0.9881049293224639 {
   120  			t.Errorf("got %v, want 0.9881049293224639", got)
   121  		}
   122  		if got := stmt.ColumnFloat(1); got != 21.25 {
   123  			t.Errorf("got %v, want 21.25", got)
   124  		}
   125  		if got := stmt.ColumnFloat(2); got != 17 {
   126  			t.Errorf("got %v, want 17", got)
   127  		}
   128  		if got := stmt.ColumnFloat(3); got != 4.2 {
   129  			t.Errorf("got %v, want 4.2", got)
   130  		}
   131  		if got := stmt.ColumnFloat(4); got != 75 {
   132  			t.Errorf("got %v, want 75", got)
   133  		}
   134  		if got := stmt.ColumnFloat(5); got != 14.8 {
   135  			t.Errorf("got %v, want 14.8", got)
   136  		}
   137  		if got := stmt.ColumnFloat(6); got != 500 {
   138  			t.Errorf("got %v, want 500", got)
   139  		}
   140  		if got := stmt.ColumnFloat(7); got != 85 {
   141  			t.Errorf("got %v, want 85", got)
   142  		}
   143  		if got := stmt.ColumnFloat(8); got != 0.17 {
   144  			t.Errorf("got %v, want 0.17", got)
   145  		}
   146  		if got := stmt.ColumnFloat(9); got != -8.55 {
   147  			t.Errorf("got %v, want -8.55", got)
   148  		}
   149  		if got := stmt.ColumnFloat(10); got != 0.9763513513513513 {
   150  			t.Errorf("got %v, want 0.9763513513513513", got)
   151  		}
   152  		if got := stmt.ColumnInt(11); got != 5 {
   153  			t.Errorf("got %v, want 5", got)
   154  		}
   155  		var a map[string]float64
   156  		if err := stmt.ColumnJSON(12, &a); err != nil {
   157  			t.Error(err)
   158  		} else if got := a["count"]; got != 5 {
   159  			t.Errorf("got %v, want 5", got)
   160  		}
   161  	}
   162  
   163  	{
   164  		stmt, _, err := db.Prepare(`SELECT covar_samp(y, x) OVER (ROWS 1 PRECEDING) FROM data`)
   165  		if err != nil {
   166  			t.Fatal(err)
   167  		}
   168  		defer stmt.Close()
   169  
   170  		want := [...]float64{0, 10, 30, 75, 22.5}
   171  		for i := 0; stmt.Step(); i++ {
   172  			if got := stmt.ColumnFloat(0); got != want[i] {
   173  				t.Errorf("got %v, want %v", got, want[i])
   174  			}
   175  			if got := stmt.ColumnType(0); (got == sqlite3.FLOAT) != (want[i] != 0) {
   176  				t.Errorf("got %v, want %v", got, want[i])
   177  			}
   178  		}
   179  	}
   180  }
   181  
   182  func Benchmark_average(b *testing.B) {
   183  	sqlite3.Initialize()
   184  	b.ResetTimer()
   185  
   186  	db, err := sqlite3.Open(":memory:")
   187  	if err != nil {
   188  		b.Fatal(err)
   189  	}
   190  	defer db.Close()
   191  
   192  	stmt, _, err := db.Prepare(`SELECT avg(value) FROM generate_series(0, ?)`)
   193  	if err != nil {
   194  		b.Fatal(err)
   195  	}
   196  	defer stmt.Close()
   197  
   198  	err = stmt.BindInt(1, b.N)
   199  	if err != nil {
   200  		b.Fatal(err)
   201  	}
   202  
   203  	if stmt.Step() {
   204  		want := float64(b.N) / 2
   205  		if got := stmt.ColumnFloat(0); got != want {
   206  			b.Errorf("got %v, want %v", got, want)
   207  		}
   208  	}
   209  
   210  	err = stmt.Err()
   211  	if err != nil {
   212  		b.Error(err)
   213  	}
   214  }
   215  
   216  func Benchmark_variance(b *testing.B) {
   217  	sqlite3.Initialize()
   218  	b.ResetTimer()
   219  
   220  	db, err := sqlite3.Open(":memory:")
   221  	if err != nil {
   222  		b.Fatal(err)
   223  	}
   224  	defer db.Close()
   225  
   226  	stats.Register(db)
   227  
   228  	stmt, _, err := db.Prepare(`SELECT var_pop(value) FROM generate_series(0, ?)`)
   229  	if err != nil {
   230  		b.Fatal(err)
   231  	}
   232  	defer stmt.Close()
   233  
   234  	err = stmt.BindInt(1, b.N)
   235  	if err != nil {
   236  		b.Fatal(err)
   237  	}
   238  
   239  	if stmt.Step() && b.N > 100 {
   240  		want := float64(b.N*b.N) / 12
   241  		if got := stmt.ColumnFloat(0); want > (got-want)*float64(b.N) {
   242  			b.Errorf("got %v, want %v", got, want)
   243  		}
   244  	}
   245  
   246  	err = stmt.Err()
   247  	if err != nil {
   248  		b.Error(err)
   249  	}
   250  }