github.com/varialus/godfly@v0.0.0-20130904042352-1934f9f095ab/src/pkg/database/sql/fakedb_test.go (about) 1 // Copyright 2011 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package sql 6 7 import ( 8 "database/sql/driver" 9 "errors" 10 "fmt" 11 "io" 12 "log" 13 "strconv" 14 "strings" 15 "sync" 16 "testing" 17 "time" 18 ) 19 20 var _ = log.Printf 21 22 // fakeDriver is a fake database that implements Go's driver.Driver 23 // interface, just for testing. 24 // 25 // It speaks a query language that's semantically similar to but 26 // syntantically different and simpler than SQL. The syntax is as 27 // follows: 28 // 29 // WIPE 30 // CREATE|<tablename>|<col>=<type>,<col>=<type>,... 31 // where types are: "string", [u]int{8,16,32,64}, "bool" 32 // INSERT|<tablename>|col=val,col2=val2,col3=? 33 // SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=? 34 // 35 // When opening a fakeDriver's database, it starts empty with no 36 // tables. All tables and data are stored in memory only. 37 type fakeDriver struct { 38 mu sync.Mutex // guards 3 following fields 39 openCount int // conn opens 40 closeCount int // conn closes 41 dbs map[string]*fakeDB 42 } 43 44 type fakeDB struct { 45 name string 46 47 mu sync.Mutex 48 free []*fakeConn 49 tables map[string]*table 50 badConn bool 51 } 52 53 type table struct { 54 mu sync.Mutex 55 colname []string 56 coltype []string 57 rows []*row 58 } 59 60 func (t *table) columnIndex(name string) int { 61 for n, nname := range t.colname { 62 if name == nname { 63 return n 64 } 65 } 66 return -1 67 } 68 69 type row struct { 70 cols []interface{} // must be same size as its table colname + coltype 71 } 72 73 func (r *row) clone() *row { 74 nrow := &row{cols: make([]interface{}, len(r.cols))} 75 copy(nrow.cols, r.cols) 76 return nrow 77 } 78 79 type fakeConn struct { 80 db *fakeDB // where to return ourselves to 81 82 currTx *fakeTx 83 84 // Stats for tests: 85 mu sync.Mutex 86 stmtsMade int 87 stmtsClosed int 88 numPrepare int 89 bad bool 90 } 91 92 func (c *fakeConn) incrStat(v *int) { 93 c.mu.Lock() 94 *v++ 95 c.mu.Unlock() 96 } 97 98 type fakeTx struct { 99 c *fakeConn 100 } 101 102 type fakeStmt struct { 103 c *fakeConn 104 q string // just for debugging 105 106 cmd string 107 table string 108 109 closed bool 110 111 colName []string // used by CREATE, INSERT, SELECT (selected columns) 112 colType []string // used by CREATE 113 colValue []interface{} // used by INSERT (mix of strings and "?" for bound params) 114 placeholders int // used by INSERT/SELECT: number of ? params 115 116 whereCol []string // used by SELECT (all placeholders) 117 118 placeholderConverter []driver.ValueConverter // used by INSERT 119 } 120 121 var fdriver driver.Driver = &fakeDriver{} 122 123 func init() { 124 Register("test", fdriver) 125 } 126 127 // Supports dsn forms: 128 // <dbname> 129 // <dbname>;<opts> (only currently supported option is `badConn`, 130 // which causes driver.ErrBadConn to be returned on 131 // every other conn.Begin()) 132 func (d *fakeDriver) Open(dsn string) (driver.Conn, error) { 133 parts := strings.Split(dsn, ";") 134 if len(parts) < 1 { 135 return nil, errors.New("fakedb: no database name") 136 } 137 name := parts[0] 138 139 db := d.getDB(name) 140 141 d.mu.Lock() 142 d.openCount++ 143 d.mu.Unlock() 144 conn := &fakeConn{db: db} 145 146 if len(parts) >= 2 && parts[1] == "badConn" { 147 conn.bad = true 148 } 149 return conn, nil 150 } 151 152 func (d *fakeDriver) getDB(name string) *fakeDB { 153 d.mu.Lock() 154 defer d.mu.Unlock() 155 if d.dbs == nil { 156 d.dbs = make(map[string]*fakeDB) 157 } 158 db, ok := d.dbs[name] 159 if !ok { 160 db = &fakeDB{name: name} 161 d.dbs[name] = db 162 } 163 return db 164 } 165 166 func (db *fakeDB) wipe() { 167 db.mu.Lock() 168 defer db.mu.Unlock() 169 db.tables = nil 170 } 171 172 func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error { 173 db.mu.Lock() 174 defer db.mu.Unlock() 175 if db.tables == nil { 176 db.tables = make(map[string]*table) 177 } 178 if _, exist := db.tables[name]; exist { 179 return fmt.Errorf("table %q already exists", name) 180 } 181 if len(columnNames) != len(columnTypes) { 182 return fmt.Errorf("create table of %q len(names) != len(types): %d vs %d", 183 name, len(columnNames), len(columnTypes)) 184 } 185 db.tables[name] = &table{colname: columnNames, coltype: columnTypes} 186 return nil 187 } 188 189 // must be called with db.mu lock held 190 func (db *fakeDB) table(table string) (*table, bool) { 191 if db.tables == nil { 192 return nil, false 193 } 194 t, ok := db.tables[table] 195 return t, ok 196 } 197 198 func (db *fakeDB) columnType(table, column string) (typ string, ok bool) { 199 db.mu.Lock() 200 defer db.mu.Unlock() 201 t, ok := db.table(table) 202 if !ok { 203 return 204 } 205 for n, cname := range t.colname { 206 if cname == column { 207 return t.coltype[n], true 208 } 209 } 210 return "", false 211 } 212 213 func (c *fakeConn) isBad() bool { 214 // if not simulating bad conn, do nothing 215 if !c.bad { 216 return false 217 } 218 // alternate between bad conn and not bad conn 219 c.db.badConn = !c.db.badConn 220 return c.db.badConn 221 } 222 223 func (c *fakeConn) Begin() (driver.Tx, error) { 224 if c.isBad() { 225 return nil, driver.ErrBadConn 226 } 227 if c.currTx != nil { 228 return nil, errors.New("already in a transaction") 229 } 230 c.currTx = &fakeTx{c: c} 231 return c.currTx, nil 232 } 233 234 var hookPostCloseConn struct { 235 sync.Mutex 236 fn func(*fakeConn, error) 237 } 238 239 func setHookpostCloseConn(fn func(*fakeConn, error)) { 240 hookPostCloseConn.Lock() 241 defer hookPostCloseConn.Unlock() 242 hookPostCloseConn.fn = fn 243 } 244 245 var testStrictClose *testing.T 246 247 // setStrictFakeConnClose sets the t to Errorf on when fakeConn.Close 248 // fails to close. If nil, the check is disabled. 249 func setStrictFakeConnClose(t *testing.T) { 250 testStrictClose = t 251 } 252 253 func (c *fakeConn) Close() (err error) { 254 drv := fdriver.(*fakeDriver) 255 defer func() { 256 if err != nil && testStrictClose != nil { 257 testStrictClose.Errorf("failed to close a test fakeConn: %v", err) 258 } 259 hookPostCloseConn.Lock() 260 fn := hookPostCloseConn.fn 261 hookPostCloseConn.Unlock() 262 if fn != nil { 263 fn(c, err) 264 } 265 if err == nil { 266 drv.mu.Lock() 267 drv.closeCount++ 268 drv.mu.Unlock() 269 } 270 }() 271 if c.currTx != nil { 272 return errors.New("can't close fakeConn; in a Transaction") 273 } 274 if c.db == nil { 275 return errors.New("can't close fakeConn; already closed") 276 } 277 if c.stmtsMade > c.stmtsClosed { 278 return errors.New("can't close; dangling statement(s)") 279 } 280 c.db = nil 281 return nil 282 } 283 284 func checkSubsetTypes(args []driver.Value) error { 285 for n, arg := range args { 286 switch arg.(type) { 287 case int64, float64, bool, nil, []byte, string, time.Time: 288 default: 289 return fmt.Errorf("fakedb_test: invalid argument #%d: %v, type %T", n+1, arg, arg) 290 } 291 } 292 return nil 293 } 294 295 func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) { 296 // This is an optional interface, but it's implemented here 297 // just to check that all the args are of the proper types. 298 // ErrSkip is returned so the caller acts as if we didn't 299 // implement this at all. 300 err := checkSubsetTypes(args) 301 if err != nil { 302 return nil, err 303 } 304 return nil, driver.ErrSkip 305 } 306 307 func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) { 308 // This is an optional interface, but it's implemented here 309 // just to check that all the args are of the proper types. 310 // ErrSkip is returned so the caller acts as if we didn't 311 // implement this at all. 312 err := checkSubsetTypes(args) 313 if err != nil { 314 return nil, err 315 } 316 return nil, driver.ErrSkip 317 } 318 319 func errf(msg string, args ...interface{}) error { 320 return errors.New("fakedb: " + fmt.Sprintf(msg, args...)) 321 } 322 323 // parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=? 324 // (note that where columns must always contain ? marks, 325 // just a limitation for fakedb) 326 func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, error) { 327 if len(parts) != 3 { 328 stmt.Close() 329 return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts)) 330 } 331 stmt.table = parts[0] 332 stmt.colName = strings.Split(parts[1], ",") 333 for n, colspec := range strings.Split(parts[2], ",") { 334 if colspec == "" { 335 continue 336 } 337 nameVal := strings.Split(colspec, "=") 338 if len(nameVal) != 2 { 339 stmt.Close() 340 return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) 341 } 342 column, value := nameVal[0], nameVal[1] 343 _, ok := c.db.columnType(stmt.table, column) 344 if !ok { 345 stmt.Close() 346 return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column) 347 } 348 if value != "?" { 349 stmt.Close() 350 return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark", 351 stmt.table, column) 352 } 353 stmt.whereCol = append(stmt.whereCol, column) 354 stmt.placeholders++ 355 } 356 return stmt, nil 357 } 358 359 // parts are table|col=type,col2=type2 360 func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, error) { 361 if len(parts) != 2 { 362 stmt.Close() 363 return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts)) 364 } 365 stmt.table = parts[0] 366 for n, colspec := range strings.Split(parts[1], ",") { 367 nameType := strings.Split(colspec, "=") 368 if len(nameType) != 2 { 369 stmt.Close() 370 return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) 371 } 372 stmt.colName = append(stmt.colName, nameType[0]) 373 stmt.colType = append(stmt.colType, nameType[1]) 374 } 375 return stmt, nil 376 } 377 378 // parts are table|col=?,col2=val 379 func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, error) { 380 if len(parts) != 2 { 381 stmt.Close() 382 return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts)) 383 } 384 stmt.table = parts[0] 385 for n, colspec := range strings.Split(parts[1], ",") { 386 nameVal := strings.Split(colspec, "=") 387 if len(nameVal) != 2 { 388 stmt.Close() 389 return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) 390 } 391 column, value := nameVal[0], nameVal[1] 392 ctype, ok := c.db.columnType(stmt.table, column) 393 if !ok { 394 stmt.Close() 395 return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column) 396 } 397 stmt.colName = append(stmt.colName, column) 398 399 if value != "?" { 400 var subsetVal interface{} 401 // Convert to driver subset type 402 switch ctype { 403 case "string": 404 subsetVal = []byte(value) 405 case "blob": 406 subsetVal = []byte(value) 407 case "int32": 408 i, err := strconv.Atoi(value) 409 if err != nil { 410 stmt.Close() 411 return nil, errf("invalid conversion to int32 from %q", value) 412 } 413 subsetVal = int64(i) // int64 is a subset type, but not int32 414 default: 415 stmt.Close() 416 return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype) 417 } 418 stmt.colValue = append(stmt.colValue, subsetVal) 419 } else { 420 stmt.placeholders++ 421 stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype)) 422 stmt.colValue = append(stmt.colValue, "?") 423 } 424 } 425 return stmt, nil 426 } 427 428 func (c *fakeConn) Prepare(query string) (driver.Stmt, error) { 429 c.numPrepare++ 430 if c.db == nil { 431 panic("nil c.db; conn = " + fmt.Sprintf("%#v", c)) 432 } 433 parts := strings.Split(query, "|") 434 if len(parts) < 1 { 435 return nil, errf("empty query") 436 } 437 cmd := parts[0] 438 parts = parts[1:] 439 stmt := &fakeStmt{q: query, c: c, cmd: cmd} 440 c.incrStat(&c.stmtsMade) 441 switch cmd { 442 case "WIPE": 443 // Nothing 444 case "SELECT": 445 return c.prepareSelect(stmt, parts) 446 case "CREATE": 447 return c.prepareCreate(stmt, parts) 448 case "INSERT": 449 return c.prepareInsert(stmt, parts) 450 case "NOSERT": 451 // Do all the prep-work like for an INSERT but don't actually insert the row. 452 // Used for some of the concurrent tests. 453 return c.prepareInsert(stmt, parts) 454 default: 455 stmt.Close() 456 return nil, errf("unsupported command type %q", cmd) 457 } 458 return stmt, nil 459 } 460 461 func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter { 462 if len(s.placeholderConverter) == 0 { 463 return driver.DefaultParameterConverter 464 } 465 return s.placeholderConverter[idx] 466 } 467 468 func (s *fakeStmt) Close() error { 469 if s.c == nil { 470 panic("nil conn in fakeStmt.Close") 471 } 472 if s.c.db == nil { 473 panic("in fakeStmt.Close, conn's db is nil (already closed)") 474 } 475 if !s.closed { 476 s.c.incrStat(&s.c.stmtsClosed) 477 s.closed = true 478 } 479 return nil 480 } 481 482 var errClosed = errors.New("fakedb: statement has been closed") 483 484 func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) { 485 if s.closed { 486 return nil, errClosed 487 } 488 err := checkSubsetTypes(args) 489 if err != nil { 490 return nil, err 491 } 492 493 db := s.c.db 494 switch s.cmd { 495 case "WIPE": 496 db.wipe() 497 return driver.ResultNoRows, nil 498 case "CREATE": 499 if err := db.createTable(s.table, s.colName, s.colType); err != nil { 500 return nil, err 501 } 502 return driver.ResultNoRows, nil 503 case "INSERT": 504 return s.execInsert(args, true) 505 case "NOSERT": 506 // Do all the prep-work like for an INSERT but don't actually insert the row. 507 // Used for some of the concurrent tests. 508 return s.execInsert(args, false) 509 } 510 fmt.Printf("EXEC statement, cmd=%q: %#v\n", s.cmd, s) 511 return nil, fmt.Errorf("unimplemented statement Exec command type of %q", s.cmd) 512 } 513 514 // When doInsert is true, add the row to the table. 515 // When doInsert is false do prep-work and error checking, but don't 516 // actually add the row to the table. 517 func (s *fakeStmt) execInsert(args []driver.Value, doInsert bool) (driver.Result, error) { 518 db := s.c.db 519 if len(args) != s.placeholders { 520 panic("error in pkg db; should only get here if size is correct") 521 } 522 db.mu.Lock() 523 t, ok := db.table(s.table) 524 db.mu.Unlock() 525 if !ok { 526 return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table) 527 } 528 529 t.mu.Lock() 530 defer t.mu.Unlock() 531 532 var cols []interface{} 533 if doInsert { 534 cols = make([]interface{}, len(t.colname)) 535 } 536 argPos := 0 537 for n, colname := range s.colName { 538 colidx := t.columnIndex(colname) 539 if colidx == -1 { 540 return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname) 541 } 542 var val interface{} 543 if strvalue, ok := s.colValue[n].(string); ok && strvalue == "?" { 544 val = args[argPos] 545 argPos++ 546 } else { 547 val = s.colValue[n] 548 } 549 if doInsert { 550 cols[colidx] = val 551 } 552 } 553 554 if doInsert { 555 t.rows = append(t.rows, &row{cols: cols}) 556 } 557 return driver.RowsAffected(1), nil 558 } 559 560 func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) { 561 if s.closed { 562 return nil, errClosed 563 } 564 err := checkSubsetTypes(args) 565 if err != nil { 566 return nil, err 567 } 568 569 db := s.c.db 570 if len(args) != s.placeholders { 571 panic("error in pkg db; should only get here if size is correct") 572 } 573 574 db.mu.Lock() 575 t, ok := db.table(s.table) 576 db.mu.Unlock() 577 if !ok { 578 return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table) 579 } 580 581 if s.table == "magicquery" { 582 if len(s.whereCol) == 2 && s.whereCol[0] == "op" && s.whereCol[1] == "millis" { 583 if args[0] == "sleep" { 584 time.Sleep(time.Duration(args[1].(int64)) * time.Millisecond) 585 } 586 } 587 } 588 589 t.mu.Lock() 590 defer t.mu.Unlock() 591 592 colIdx := make(map[string]int) // select column name -> column index in table 593 for _, name := range s.colName { 594 idx := t.columnIndex(name) 595 if idx == -1 { 596 return nil, fmt.Errorf("fakedb: unknown column name %q", name) 597 } 598 colIdx[name] = idx 599 } 600 601 mrows := []*row{} 602 rows: 603 for _, trow := range t.rows { 604 // Process the where clause, skipping non-match rows. This is lazy 605 // and just uses fmt.Sprintf("%v") to test equality. Good enough 606 // for test code. 607 for widx, wcol := range s.whereCol { 608 idx := t.columnIndex(wcol) 609 if idx == -1 { 610 return nil, fmt.Errorf("db: invalid where clause column %q", wcol) 611 } 612 tcol := trow.cols[idx] 613 if bs, ok := tcol.([]byte); ok { 614 // lazy hack to avoid sprintf %v on a []byte 615 tcol = string(bs) 616 } 617 if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", args[widx]) { 618 continue rows 619 } 620 } 621 mrow := &row{cols: make([]interface{}, len(s.colName))} 622 for seli, name := range s.colName { 623 mrow.cols[seli] = trow.cols[colIdx[name]] 624 } 625 mrows = append(mrows, mrow) 626 } 627 628 cursor := &rowsCursor{ 629 pos: -1, 630 rows: mrows, 631 cols: s.colName, 632 errPos: -1, 633 } 634 return cursor, nil 635 } 636 637 func (s *fakeStmt) NumInput() int { 638 return s.placeholders 639 } 640 641 func (tx *fakeTx) Commit() error { 642 tx.c.currTx = nil 643 return nil 644 } 645 646 func (tx *fakeTx) Rollback() error { 647 tx.c.currTx = nil 648 return nil 649 } 650 651 type rowsCursor struct { 652 cols []string 653 pos int 654 rows []*row 655 closed bool 656 657 // errPos and err are for making Next return early with error. 658 errPos int 659 err error 660 661 // a clone of slices to give out to clients, indexed by the 662 // the original slice's first byte address. we clone them 663 // just so we're able to corrupt them on close. 664 bytesClone map[*byte][]byte 665 } 666 667 func (rc *rowsCursor) Close() error { 668 if !rc.closed { 669 for _, bs := range rc.bytesClone { 670 bs[0] = 255 // first byte corrupted 671 } 672 } 673 rc.closed = true 674 return nil 675 } 676 677 func (rc *rowsCursor) Columns() []string { 678 return rc.cols 679 } 680 681 func (rc *rowsCursor) Next(dest []driver.Value) error { 682 if rc.closed { 683 return errors.New("fakedb: cursor is closed") 684 } 685 rc.pos++ 686 if rc.pos == rc.errPos { 687 return rc.err 688 } 689 if rc.pos >= len(rc.rows) { 690 return io.EOF // per interface spec 691 } 692 for i, v := range rc.rows[rc.pos].cols { 693 // TODO(bradfitz): convert to subset types? naah, I 694 // think the subset types should only be input to 695 // driver, but the sql package should be able to handle 696 // a wider range of types coming out of drivers. all 697 // for ease of drivers, and to prevent drivers from 698 // messing up conversions or doing them differently. 699 dest[i] = v 700 701 if bs, ok := v.([]byte); ok { 702 if rc.bytesClone == nil { 703 rc.bytesClone = make(map[*byte][]byte) 704 } 705 clone, ok := rc.bytesClone[&bs[0]] 706 if !ok { 707 clone = make([]byte, len(bs)) 708 copy(clone, bs) 709 rc.bytesClone[&bs[0]] = clone 710 } 711 dest[i] = clone 712 } 713 } 714 return nil 715 } 716 717 // fakeDriverString is like driver.String, but indirects pointers like 718 // DefaultValueConverter. 719 // 720 // This could be surprising behavior to retroactively apply to 721 // driver.String now that Go1 is out, but this is convenient for 722 // our TestPointerParamsAndScans. 723 // 724 type fakeDriverString struct{} 725 726 func (fakeDriverString) ConvertValue(v interface{}) (driver.Value, error) { 727 switch c := v.(type) { 728 case string, []byte: 729 return v, nil 730 case *string: 731 if c == nil { 732 return nil, nil 733 } 734 return *c, nil 735 } 736 return fmt.Sprintf("%v", v), nil 737 } 738 739 func converterForType(typ string) driver.ValueConverter { 740 switch typ { 741 case "bool": 742 return driver.Bool 743 case "nullbool": 744 return driver.Null{Converter: driver.Bool} 745 case "int32": 746 return driver.Int32 747 case "string": 748 return driver.NotNull{Converter: fakeDriverString{}} 749 case "nullstring": 750 return driver.Null{Converter: fakeDriverString{}} 751 case "int64": 752 // TODO(coopernurse): add type-specific converter 753 return driver.NotNull{Converter: driver.DefaultParameterConverter} 754 case "nullint64": 755 // TODO(coopernurse): add type-specific converter 756 return driver.Null{Converter: driver.DefaultParameterConverter} 757 case "float64": 758 // TODO(coopernurse): add type-specific converter 759 return driver.NotNull{Converter: driver.DefaultParameterConverter} 760 case "nullfloat64": 761 // TODO(coopernurse): add type-specific converter 762 return driver.Null{Converter: driver.DefaultParameterConverter} 763 case "datetime": 764 return driver.DefaultParameterConverter 765 } 766 panic("invalid fakedb column type of " + typ) 767 }