github.com/cellofellow/gopkg@v0.0.0-20140722061823-eec0544a62ad/database/sqlite3/sqltest/sqltest.go (about)

     1  // +build ingore
     2  
     3  package sqltest
     4  
     5  import (
     6  	"database/sql"
     7  	"fmt"
     8  	"math/rand"
     9  	"regexp"
    10  	"strconv"
    11  	"sync"
    12  	"testing"
    13  	"time"
    14  )
    15  
// Dialect identifies the SQL flavor of the database under test. The helpers
// in this file switch on it to pick dialect-specific SQL fragments.
type Dialect int

// Dialects recognized by the test suite.
const (
	SQLITE Dialect = iota
	POSTGRESQL
	MYSQL
)
    23  
// DB bundles a *testing.T and a *sql.DB (both embedded, so their methods are
// promoted) with the dialect of the database under test.
type DB struct {
	*testing.T
	*sql.DB
	dialect Dialect
	once    sync.Once // used by the row benchmarks to run makeBench only once
}
    30  
// db is the package-level database handle used by every test and benchmark
// in this file; RunTests assigns it.
var db *DB

// testTables lists the tables that are created and dropped during the test;
// tearDown drops each of them.
var testTables = []string{"foo", "bar", "t", "bench"}
    35  
// tests is the list of test functions that RunTests hands to
// testing.RunTests.
var tests = []testing.InternalTest{
	{"TestBlobs", TestBlobs},
	{"TestManyQueryRow", TestManyQueryRow},
	{"TestTxQuery", TestTxQuery},
	{"TestPreparedStmt", TestPreparedStmt},
}
    42  
// benchmarks is the list of benchmark functions that RunTests runs through
// testing.Benchmark when not in short mode.
var benchmarks = []testing.InternalBenchmark{
	{"BenchmarkExec", BenchmarkExec},
	{"BenchmarkQuery", BenchmarkQuery},
	{"BenchmarkParams", BenchmarkParams},
	{"BenchmarkStmt", BenchmarkStmt},
	{"BenchmarkRows", BenchmarkRows},
	{"BenchmarkStmtRows", BenchmarkStmtRows},
}
    51  
    52  // RunTests runs the SQL test suite
    53  func RunTests(t *testing.T, d *sql.DB, dialect Dialect) {
    54  	db = &DB{t, d, dialect, sync.Once{}}
    55  	testing.RunTests(func(string, string) (bool, error) { return true, nil }, tests)
    56  
    57  	if !testing.Short() {
    58  		for _, b := range benchmarks {
    59  			fmt.Printf("%-20s", b.Name)
    60  			r := testing.Benchmark(b.F)
    61  			fmt.Printf("%10d %10.0f req/s\n", r.N, float64(r.N)/r.T.Seconds())
    62  		}
    63  	}
    64  	db.tearDown()
    65  }
    66  
    67  func (db *DB) mustExec(sql string, args ...interface{}) sql.Result {
    68  	res, err := db.Exec(sql, args...)
    69  	if err != nil {
    70  		db.Fatalf("Error running %q: %v", sql, err)
    71  	}
    72  	return res
    73  }
    74  
    75  func (db *DB) tearDown() {
    76  	for _, tbl := range testTables {
    77  		switch db.dialect {
    78  		case SQLITE:
    79  			db.mustExec("drop table if exists " + tbl)
    80  		case MYSQL, POSTGRESQL:
    81  			db.mustExec("drop table if exists " + tbl)
    82  		default:
    83  			db.Fatal("unkown dialect")
    84  		}
    85  	}
    86  }
    87  
    88  // q replaces ? parameters if needed
    89  func (db *DB) q(sql string) string {
    90  	switch db.dialect {
    91  	case POSTGRESQL: // repace with $1, $2, ..
    92  		qrx := regexp.MustCompile(`\?`)
    93  		n := 0
    94  		return qrx.ReplaceAllStringFunc(sql, func(string) string {
    95  			n++
    96  			return "$" + strconv.Itoa(n)
    97  		})
    98  	}
    99  	return sql
   100  }
   101  
   102  func (db *DB) blobType(size int) string {
   103  	switch db.dialect {
   104  	case SQLITE:
   105  		return fmt.Sprintf("blob[%d]", size)
   106  	case POSTGRESQL:
   107  		return "bytea"
   108  	case MYSQL:
   109  		return fmt.Sprintf("VARBINARY(%d)", size)
   110  	}
   111  	panic("unkown dialect")
   112  }
   113  
   114  func (db *DB) serialPK() string {
   115  	switch db.dialect {
   116  	case SQLITE:
   117  		return "integer primary key autoincrement"
   118  	case POSTGRESQL:
   119  		return "serial primary key"
   120  	case MYSQL:
   121  		return "integer primary key auto_increment"
   122  	}
   123  	panic("unkown dialect")
   124  }
   125  
   126  func (db *DB) now() string {
   127  	switch db.dialect {
   128  	case SQLITE:
   129  		return "datetime('now')"
   130  	case POSTGRESQL:
   131  		return "now()"
   132  	case MYSQL:
   133  		return "now()"
   134  	}
   135  	panic("unkown dialect")
   136  }
   137  
   138  func makeBench() {
   139  	if _, err := db.Exec("create table bench (n varchar(32), i integer, d double, s varchar(32), t datetime)"); err != nil {
   140  		panic(err)
   141  	}
   142  	st, err := db.Prepare("insert into bench values (?, ?, ?, ?, ?)")
   143  	if err != nil {
   144  		panic(err)
   145  	}
   146  	defer st.Close()
   147  	for i := 0; i < 100; i++ {
   148  		if _, err = st.Exec(nil, i, float64(i), fmt.Sprintf("%d", i), time.Now()); err != nil {
   149  			panic(err)
   150  		}
   151  	}
   152  }
   153  
   154  func TestResult(t *testing.T) {
   155  	db.tearDown()
   156  	db.mustExec("create temporary table test (id " + db.serialPK() + ", name varchar(10))")
   157  
   158  	for i := 1; i < 3; i++ {
   159  		r := db.mustExec(db.q("insert into test (name) values (?)"), fmt.Sprintf("row %d", i))
   160  		n, err := r.RowsAffected()
   161  		if err != nil {
   162  			t.Fatal(err)
   163  		}
   164  		if n != 1 {
   165  			t.Errorf("got %v, want %v", n, 1)
   166  		}
   167  		n, err = r.LastInsertId()
   168  		if err != nil {
   169  			t.Fatal(err)
   170  		}
   171  		if n != int64(i) {
   172  			t.Errorf("got %v, want %v", n, i)
   173  		}
   174  	}
   175  	if _, err := db.Exec("error!"); err == nil {
   176  		t.Fatalf("expected error")
   177  	}
   178  }
   179  
   180  func TestBlobs(t *testing.T) {
   181  	db.tearDown()
   182  	var blob = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
   183  	db.mustExec("create table foo (id integer primary key, bar " + db.blobType(16) + ")")
   184  	db.mustExec(db.q("insert into foo (id, bar) values(?,?)"), 0, blob)
   185  
   186  	want := fmt.Sprintf("%x", blob)
   187  
   188  	b := make([]byte, 16)
   189  	err := db.QueryRow(db.q("select bar from foo where id = ?"), 0).Scan(&b)
   190  	got := fmt.Sprintf("%x", b)
   191  	if err != nil {
   192  		t.Errorf("[]byte scan: %v", err)
   193  	} else if got != want {
   194  		t.Errorf("for []byte, got %q; want %q", got, want)
   195  	}
   196  
   197  	err = db.QueryRow(db.q("select bar from foo where id = ?"), 0).Scan(&got)
   198  	want = string(blob)
   199  	if err != nil {
   200  		t.Errorf("string scan: %v", err)
   201  	} else if got != want {
   202  		t.Errorf("for string, got %q; want %q", got, want)
   203  	}
   204  }
   205  
   206  func TestManyQueryRow(t *testing.T) {
   207  	if testing.Short() {
   208  		t.Log("skipping in short mode")
   209  		return
   210  	}
   211  	db.tearDown()
   212  	db.mustExec("create table foo (id integer primary key, name varchar(50))")
   213  	db.mustExec(db.q("insert into foo (id, name) values(?,?)"), 1, "bob")
   214  	var name string
   215  	for i := 0; i < 10000; i++ {
   216  		err := db.QueryRow(db.q("select name from foo where id = ?"), 1).Scan(&name)
   217  		if err != nil || name != "bob" {
   218  			t.Fatalf("on query %d: err=%v, name=%q", i, err, name)
   219  		}
   220  	}
   221  }
   222  
   223  func TestTxQuery(t *testing.T) {
   224  	db.tearDown()
   225  	tx, err := db.Begin()
   226  	if err != nil {
   227  		t.Fatal(err)
   228  	}
   229  	defer tx.Rollback()
   230  
   231  	_, err = tx.Exec("create table foo (id integer primary key, name varchar(50))")
   232  	if err != nil {
   233  		t.Fatal(err)
   234  	}
   235  
   236  	_, err = tx.Exec(db.q("insert into foo (id, name) values(?,?)"), 1, "bob")
   237  	if err != nil {
   238  		t.Fatal(err)
   239  	}
   240  
   241  	r, err := tx.Query(db.q("select name from foo where id = ?"), 1)
   242  	if err != nil {
   243  		t.Fatal(err)
   244  	}
   245  	defer r.Close()
   246  
   247  	if !r.Next() {
   248  		if r.Err() != nil {
   249  			t.Fatal(err)
   250  		}
   251  		t.Fatal("expected one rows")
   252  	}
   253  
   254  	var name string
   255  	err = r.Scan(&name)
   256  	if err != nil {
   257  		t.Fatal(err)
   258  	}
   259  }
   260  
   261  func TestPreparedStmt(t *testing.T) {
   262  	db.tearDown()
   263  	db.mustExec("CREATE TABLE t (count INT)")
   264  	sel, err := db.Prepare("SELECT count FROM t ORDER BY count DESC")
   265  	if err != nil {
   266  		t.Fatalf("prepare 1: %v", err)
   267  	}
   268  	ins, err := db.Prepare(db.q("INSERT INTO t (count) VALUES (?)"))
   269  	if err != nil {
   270  		t.Fatalf("prepare 2: %v", err)
   271  	}
   272  
   273  	for n := 1; n <= 3; n++ {
   274  		if _, err := ins.Exec(n); err != nil {
   275  			t.Fatalf("insert(%d) = %v", n, err)
   276  		}
   277  	}
   278  
   279  	const nRuns = 10
   280  	ch := make(chan bool)
   281  	for i := 0; i < nRuns; i++ {
   282  		go func() {
   283  			defer func() {
   284  				ch <- true
   285  			}()
   286  			for j := 0; j < 10; j++ {
   287  				count := 0
   288  				if err := sel.QueryRow().Scan(&count); err != nil && err != sql.ErrNoRows {
   289  					t.Errorf("Query: %v", err)
   290  					return
   291  				}
   292  				if _, err := ins.Exec(rand.Intn(100)); err != nil {
   293  					t.Errorf("Insert: %v", err)
   294  					return
   295  				}
   296  			}
   297  		}()
   298  	}
   299  	for i := 0; i < nRuns; i++ {
   300  		<-ch
   301  	}
   302  }
   303  
// Benchmarks need to use panic() since b.Error errors are lost when
// running via testing.Benchmark(). I would like to run these via
// "go test -bench", but calling Benchmark() from a benchmark test
// currently hangs go.
   308  
   309  func BenchmarkExec(b *testing.B) {
   310  	for i := 0; i < b.N; i++ {
   311  		if _, err := db.Exec("select 1"); err != nil {
   312  			panic(err)
   313  		}
   314  	}
   315  }
   316  
   317  func BenchmarkQuery(b *testing.B) {
   318  	for i := 0; i < b.N; i++ {
   319  		var n sql.NullString
   320  		var i int
   321  		var f float64
   322  		var s string
   323  		//		var t time.Time
   324  		if err := db.QueryRow("select null, 1, 1.1, 'foo'").Scan(&n, &i, &f, &s); err != nil {
   325  			panic(err)
   326  		}
   327  	}
   328  }
   329  
   330  func BenchmarkParams(b *testing.B) {
   331  	for i := 0; i < b.N; i++ {
   332  		var n sql.NullString
   333  		var i int
   334  		var f float64
   335  		var s string
   336  		//		var t time.Time
   337  		if err := db.QueryRow("select ?, ?, ?, ?", nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
   338  			panic(err)
   339  		}
   340  	}
   341  }
   342  
   343  func BenchmarkStmt(b *testing.B) {
   344  	st, err := db.Prepare("select ?, ?, ?, ?")
   345  	if err != nil {
   346  		panic(err)
   347  	}
   348  	defer st.Close()
   349  
   350  	for n := 0; n < b.N; n++ {
   351  		var n sql.NullString
   352  		var i int
   353  		var f float64
   354  		var s string
   355  		//		var t time.Time
   356  		if err := st.QueryRow(nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
   357  			panic(err)
   358  		}
   359  	}
   360  }
   361  
   362  func BenchmarkRows(b *testing.B) {
   363  	db.once.Do(makeBench)
   364  
   365  	for n := 0; n < b.N; n++ {
   366  		var n sql.NullString
   367  		var i int
   368  		var f float64
   369  		var s string
   370  		var t time.Time
   371  		r, err := db.Query("select * from bench")
   372  		if err != nil {
   373  			panic(err)
   374  		}
   375  		for r.Next() {
   376  			if err = r.Scan(&n, &i, &f, &s, &t); err != nil {
   377  				panic(err)
   378  			}
   379  		}
   380  		if err = r.Err(); err != nil {
   381  			panic(err)
   382  		}
   383  	}
   384  }
   385  
   386  func BenchmarkStmtRows(b *testing.B) {
   387  	db.once.Do(makeBench)
   388  
   389  	st, err := db.Prepare("select * from bench")
   390  	if err != nil {
   391  		panic(err)
   392  	}
   393  	defer st.Close()
   394  
   395  	for n := 0; n < b.N; n++ {
   396  		var n sql.NullString
   397  		var i int
   398  		var f float64
   399  		var s string
   400  		var t time.Time
   401  		r, err := st.Query()
   402  		if err != nil {
   403  			panic(err)
   404  		}
   405  		for r.Next() {
   406  			if err = r.Scan(&n, &i, &f, &s, &t); err != nil {
   407  				panic(err)
   408  			}
   409  		}
   410  		if err = r.Err(); err != nil {
   411  			panic(err)
   412  		}
   413  	}
   414  }