github.com/cellofellow/gopkg@v0.0.0-20140722061823-eec0544a62ad/database/sqlite3/sqltest/sqltest.go (about) 1 // +build ingore 2 3 package sqltest 4 5 import ( 6 "database/sql" 7 "fmt" 8 "math/rand" 9 "regexp" 10 "strconv" 11 "sync" 12 "testing" 13 "time" 14 ) 15 16 type Dialect int 17 18 const ( 19 SQLITE Dialect = iota 20 POSTGRESQL 21 MYSQL 22 ) 23 24 type DB struct { 25 *testing.T 26 *sql.DB 27 dialect Dialect 28 once sync.Once 29 } 30 31 var db *DB 32 33 // the following tables will be created and dropped during the test 34 var testTables = []string{"foo", "bar", "t", "bench"} 35 36 var tests = []testing.InternalTest{ 37 {"TestBlobs", TestBlobs}, 38 {"TestManyQueryRow", TestManyQueryRow}, 39 {"TestTxQuery", TestTxQuery}, 40 {"TestPreparedStmt", TestPreparedStmt}, 41 } 42 43 var benchmarks = []testing.InternalBenchmark{ 44 {"BenchmarkExec", BenchmarkExec}, 45 {"BenchmarkQuery", BenchmarkQuery}, 46 {"BenchmarkParams", BenchmarkParams}, 47 {"BenchmarkStmt", BenchmarkStmt}, 48 {"BenchmarkRows", BenchmarkRows}, 49 {"BenchmarkStmtRows", BenchmarkStmtRows}, 50 } 51 52 // RunTests runs the SQL test suite 53 func RunTests(t *testing.T, d *sql.DB, dialect Dialect) { 54 db = &DB{t, d, dialect, sync.Once{}} 55 testing.RunTests(func(string, string) (bool, error) { return true, nil }, tests) 56 57 if !testing.Short() { 58 for _, b := range benchmarks { 59 fmt.Printf("%-20s", b.Name) 60 r := testing.Benchmark(b.F) 61 fmt.Printf("%10d %10.0f req/s\n", r.N, float64(r.N)/r.T.Seconds()) 62 } 63 } 64 db.tearDown() 65 } 66 67 func (db *DB) mustExec(sql string, args ...interface{}) sql.Result { 68 res, err := db.Exec(sql, args...) 69 if err != nil { 70 db.Fatalf("Error running %q: %v", sql, err) 71 } 72 return res 73 } 74 75 func (db *DB) tearDown() { 76 for _, tbl := range testTables { 77 switch db.dialect { 78 case SQLITE: 79 db.mustExec("drop table if exists " + tbl) 80 case MYSQL, POSTGRESQL: 81 db.mustExec("drop table if exists " + tbl) 82 default: 83 db.Fatal("unkown dialect") 84 } 85 } 86 } 87 88 // q replaces ? parameters if needed 89 func (db *DB) q(sql string) string { 90 switch db.dialect { 91 case POSTGRESQL: // repace with $1, $2, .. 92 qrx := regexp.MustCompile(`\?`) 93 n := 0 94 return qrx.ReplaceAllStringFunc(sql, func(string) string { 95 n++ 96 return "$" + strconv.Itoa(n) 97 }) 98 } 99 return sql 100 } 101 102 func (db *DB) blobType(size int) string { 103 switch db.dialect { 104 case SQLITE: 105 return fmt.Sprintf("blob[%d]", size) 106 case POSTGRESQL: 107 return "bytea" 108 case MYSQL: 109 return fmt.Sprintf("VARBINARY(%d)", size) 110 } 111 panic("unkown dialect") 112 } 113 114 func (db *DB) serialPK() string { 115 switch db.dialect { 116 case SQLITE: 117 return "integer primary key autoincrement" 118 case POSTGRESQL: 119 return "serial primary key" 120 case MYSQL: 121 return "integer primary key auto_increment" 122 } 123 panic("unkown dialect") 124 } 125 126 func (db *DB) now() string { 127 switch db.dialect { 128 case SQLITE: 129 return "datetime('now')" 130 case POSTGRESQL: 131 return "now()" 132 case MYSQL: 133 return "now()" 134 } 135 panic("unkown dialect") 136 } 137 138 func makeBench() { 139 if _, err := db.Exec("create table bench (n varchar(32), i integer, d double, s varchar(32), t datetime)"); err != nil { 140 panic(err) 141 } 142 st, err := db.Prepare("insert into bench values (?, ?, ?, ?, ?)") 143 if err != nil { 144 panic(err) 145 } 146 defer st.Close() 147 for i := 0; i < 100; i++ { 148 if _, err = st.Exec(nil, i, float64(i), fmt.Sprintf("%d", i), time.Now()); err != nil { 149 panic(err) 150 } 151 } 152 } 153 154 func TestResult(t *testing.T) { 155 db.tearDown() 156 db.mustExec("create temporary table test (id " + db.serialPK() + ", name varchar(10))") 157 158 for i := 1; i < 3; i++ { 159 r := db.mustExec(db.q("insert into test (name) values (?)"), fmt.Sprintf("row %d", i)) 160 n, err := r.RowsAffected() 161 if err != nil { 162 t.Fatal(err) 163 } 164 if n != 1 { 165 t.Errorf("got %v, want %v", n, 1) 166 } 167 n, err = r.LastInsertId() 168 if err != nil { 169 t.Fatal(err) 170 } 171 if n != int64(i) { 172 t.Errorf("got %v, want %v", n, i) 173 } 174 } 175 if _, err := db.Exec("error!"); err == nil { 176 t.Fatalf("expected error") 177 } 178 } 179 180 func TestBlobs(t *testing.T) { 181 db.tearDown() 182 var blob = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} 183 db.mustExec("create table foo (id integer primary key, bar " + db.blobType(16) + ")") 184 db.mustExec(db.q("insert into foo (id, bar) values(?,?)"), 0, blob) 185 186 want := fmt.Sprintf("%x", blob) 187 188 b := make([]byte, 16) 189 err := db.QueryRow(db.q("select bar from foo where id = ?"), 0).Scan(&b) 190 got := fmt.Sprintf("%x", b) 191 if err != nil { 192 t.Errorf("[]byte scan: %v", err) 193 } else if got != want { 194 t.Errorf("for []byte, got %q; want %q", got, want) 195 } 196 197 err = db.QueryRow(db.q("select bar from foo where id = ?"), 0).Scan(&got) 198 want = string(blob) 199 if err != nil { 200 t.Errorf("string scan: %v", err) 201 } else if got != want { 202 t.Errorf("for string, got %q; want %q", got, want) 203 } 204 } 205 206 func TestManyQueryRow(t *testing.T) { 207 if testing.Short() { 208 t.Log("skipping in short mode") 209 return 210 } 211 db.tearDown() 212 db.mustExec("create table foo (id integer primary key, name varchar(50))") 213 db.mustExec(db.q("insert into foo (id, name) values(?,?)"), 1, "bob") 214 var name string 215 for i := 0; i < 10000; i++ { 216 err := db.QueryRow(db.q("select name from foo where id = ?"), 1).Scan(&name) 217 if err != nil || name != "bob" { 218 t.Fatalf("on query %d: err=%v, name=%q", i, err, name) 219 } 220 } 221 } 222 223 func TestTxQuery(t *testing.T) { 224 db.tearDown() 225 tx, err := db.Begin() 226 if err != nil { 227 t.Fatal(err) 228 } 229 defer tx.Rollback() 230 231 _, err = tx.Exec("create table foo (id integer primary key, name varchar(50))") 232 if err != nil { 233 t.Fatal(err) 234 } 235 236 _, err = tx.Exec(db.q("insert into foo (id, name) values(?,?)"), 1, "bob") 237 if err != nil { 238 t.Fatal(err) 239 } 240 241 r, err := tx.Query(db.q("select name from foo where id = ?"), 1) 242 if err != nil { 243 t.Fatal(err) 244 } 245 defer r.Close() 246 247 if !r.Next() { 248 if r.Err() != nil { 249 t.Fatal(err) 250 } 251 t.Fatal("expected one rows") 252 } 253 254 var name string 255 err = r.Scan(&name) 256 if err != nil { 257 t.Fatal(err) 258 } 259 } 260 261 func TestPreparedStmt(t *testing.T) { 262 db.tearDown() 263 db.mustExec("CREATE TABLE t (count INT)") 264 sel, err := db.Prepare("SELECT count FROM t ORDER BY count DESC") 265 if err != nil { 266 t.Fatalf("prepare 1: %v", err) 267 } 268 ins, err := db.Prepare(db.q("INSERT INTO t (count) VALUES (?)")) 269 if err != nil { 270 t.Fatalf("prepare 2: %v", err) 271 } 272 273 for n := 1; n <= 3; n++ { 274 if _, err := ins.Exec(n); err != nil { 275 t.Fatalf("insert(%d) = %v", n, err) 276 } 277 } 278 279 const nRuns = 10 280 ch := make(chan bool) 281 for i := 0; i < nRuns; i++ { 282 go func() { 283 defer func() { 284 ch <- true 285 }() 286 for j := 0; j < 10; j++ { 287 count := 0 288 if err := sel.QueryRow().Scan(&count); err != nil && err != sql.ErrNoRows { 289 t.Errorf("Query: %v", err) 290 return 291 } 292 if _, err := ins.Exec(rand.Intn(100)); err != nil { 293 t.Errorf("Insert: %v", err) 294 return 295 } 296 } 297 }() 298 } 299 for i := 0; i < nRuns; i++ { 300 <-ch 301 } 302 } 303 304 // Benchmarks need to use panic() since b.Error errors are lost when 305 // running via testing.Benchmark() I would like to run these via go 306 // test -bench but calling Benchmark() from a benchmark test 307 // currently hangs go. 308 309 func BenchmarkExec(b *testing.B) { 310 for i := 0; i < b.N; i++ { 311 if _, err := db.Exec("select 1"); err != nil { 312 panic(err) 313 } 314 } 315 } 316 317 func BenchmarkQuery(b *testing.B) { 318 for i := 0; i < b.N; i++ { 319 var n sql.NullString 320 var i int 321 var f float64 322 var s string 323 // var t time.Time 324 if err := db.QueryRow("select null, 1, 1.1, 'foo'").Scan(&n, &i, &f, &s); err != nil { 325 panic(err) 326 } 327 } 328 } 329 330 func BenchmarkParams(b *testing.B) { 331 for i := 0; i < b.N; i++ { 332 var n sql.NullString 333 var i int 334 var f float64 335 var s string 336 // var t time.Time 337 if err := db.QueryRow("select ?, ?, ?, ?", nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil { 338 panic(err) 339 } 340 } 341 } 342 343 func BenchmarkStmt(b *testing.B) { 344 st, err := db.Prepare("select ?, ?, ?, ?") 345 if err != nil { 346 panic(err) 347 } 348 defer st.Close() 349 350 for n := 0; n < b.N; n++ { 351 var n sql.NullString 352 var i int 353 var f float64 354 var s string 355 // var t time.Time 356 if err := st.QueryRow(nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil { 357 panic(err) 358 } 359 } 360 } 361 362 func BenchmarkRows(b *testing.B) { 363 db.once.Do(makeBench) 364 365 for n := 0; n < b.N; n++ { 366 var n sql.NullString 367 var i int 368 var f float64 369 var s string 370 var t time.Time 371 r, err := db.Query("select * from bench") 372 if err != nil { 373 panic(err) 374 } 375 for r.Next() { 376 if err = r.Scan(&n, &i, &f, &s, &t); err != nil { 377 panic(err) 378 } 379 } 380 if err = r.Err(); err != nil { 381 panic(err) 382 } 383 } 384 } 385 386 func BenchmarkStmtRows(b *testing.B) { 387 db.once.Do(makeBench) 388 389 st, err := db.Prepare("select * from bench") 390 if err != nil { 391 panic(err) 392 } 393 defer st.Close() 394 395 for n := 0; n < b.N; n++ { 396 var n sql.NullString 397 var i int 398 var f float64 399 var s string 400 var t time.Time 401 r, err := st.Query() 402 if err != nil { 403 panic(err) 404 } 405 for r.Next() { 406 if err = r.Scan(&n, &i, &f, &s, &t); err != nil { 407 panic(err) 408 } 409 } 410 if err = r.Err(); err != nil { 411 panic(err) 412 } 413 } 414 }