// Source: github.com/goproxy0/go@v0.0.0-20171111080102-49cc0c489d2c/src/database/sql/fakedb_test.go

// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package sql

import (
	"context"
	"database/sql/driver"
	"errors"
	"fmt"
	"io"
	"log"
	"reflect"
	"sort"
	"strconv"
	"strings"
	"sync"
	"testing"
	"time"
)

// Reference log so the import stays in use even when no ad-hoc debug
// logging is present in this file.
var _ = log.Printf

// fakeDriver is a fake database that implements Go's driver.Driver
// interface, just for testing.
//
// It speaks a query language that's semantically similar to but
// syntactically different and simpler than SQL. The syntax is as
// follows:
//
//   WIPE
//   CREATE|<tablename>|<col>=<type>,<col>=<type>,...
//     where the types understood by converterForType are: "bool",
//     "nullbool", "int32", "string", "nullstring", "int64", "nullint64",
//     "float64", "nullfloat64", "datetime" and "any"
//   INSERT|<tablename>|col=val,col2=val2,col3=?
//   NOSERT|<tablename>|col=val,col2=val2,col3=?
//     (prepared and executed like INSERT, but the row is not stored)
//   SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=?
//   SELECT|<tablename>|projectcol1,projectcol2|filtercol=?param1,filtercol2=?param2
//
// Any of these can be preceded by PANIC|<method>|, to cause the
// named method on fakeStmt (ColumnConverter, Close, Exec, Query or
// NumInput) to panic.
//
// Any of these can be preceded by WAIT|<duration>|, to make the
// statement wait for the specified duration while it is being prepared
// (the wait is abandoned if the context is done first) and sleep for it
// again when it is executed via ExecContext.
//
// Only one such prefix is recognized per statement, and only when the
// statement has at least three |-separated sections.
//
// Multiple of these can be combined when separated with a semicolon.
//
// When opening a fakeDriver's database, it starts empty with no
// tables. All tables and data are stored in memory only.
49 type fakeDriver struct { 50 mu sync.Mutex // guards 3 following fields 51 openCount int // conn opens 52 closeCount int // conn closes 53 waitCh chan struct{} 54 waitingCh chan struct{} 55 dbs map[string]*fakeDB 56 } 57 58 type fakeConnector struct { 59 name string 60 61 waiter func(context.Context) 62 } 63 64 func (c *fakeConnector) Connect(context.Context) (driver.Conn, error) { 65 conn, err := fdriver.Open(c.name) 66 conn.(*fakeConn).waiter = c.waiter 67 return conn, err 68 } 69 70 func (c *fakeConnector) Driver() driver.Driver { 71 return fdriver 72 } 73 74 type fakeDB struct { 75 name string 76 77 mu sync.Mutex 78 tables map[string]*table 79 badConn bool 80 allowAny bool 81 } 82 83 type table struct { 84 mu sync.Mutex 85 colname []string 86 coltype []string 87 rows []*row 88 } 89 90 func (t *table) columnIndex(name string) int { 91 for n, nname := range t.colname { 92 if name == nname { 93 return n 94 } 95 } 96 return -1 97 } 98 99 type row struct { 100 cols []interface{} // must be same size as its table colname + coltype 101 } 102 103 type memToucher interface { 104 // touchMem reads & writes some memory, to help find data races. 105 touchMem() 106 } 107 108 type fakeConn struct { 109 db *fakeDB // where to return ourselves to 110 111 currTx *fakeTx 112 113 // Every operation writes to line to enable the race detector 114 // check for data races. 115 line int64 116 117 // Stats for tests: 118 mu sync.Mutex 119 stmtsMade int 120 stmtsClosed int 121 numPrepare int 122 123 // bad connection tests; see isBad() 124 bad bool 125 stickyBad bool 126 127 skipDirtySession bool // tests that use Conn should set this to true. 128 129 // dirtySession tests ResetSession, true if a query has executed 130 // until ResetSession is called. 131 dirtySession bool 132 133 // The waiter is called before each query. May be used in place of the "WAIT" 134 // directive. 
135 waiter func(context.Context) 136 } 137 138 func (c *fakeConn) touchMem() { 139 c.line++ 140 } 141 142 func (c *fakeConn) incrStat(v *int) { 143 c.mu.Lock() 144 *v++ 145 c.mu.Unlock() 146 } 147 148 type fakeTx struct { 149 c *fakeConn 150 } 151 152 type boundCol struct { 153 Column string 154 Placeholder string 155 Ordinal int 156 } 157 158 type fakeStmt struct { 159 memToucher 160 c *fakeConn 161 q string // just for debugging 162 163 cmd string 164 table string 165 panic string 166 wait time.Duration 167 168 next *fakeStmt // used for returning multiple results. 169 170 closed bool 171 172 colName []string // used by CREATE, INSERT, SELECT (selected columns) 173 colType []string // used by CREATE 174 colValue []interface{} // used by INSERT (mix of strings and "?" for bound params) 175 placeholders int // used by INSERT/SELECT: number of ? params 176 177 whereCol []boundCol // used by SELECT (all placeholders) 178 179 placeholderConverter []driver.ValueConverter // used by INSERT 180 } 181 182 var fdriver driver.Driver = &fakeDriver{} 183 184 func init() { 185 Register("test", fdriver) 186 } 187 188 func contains(list []string, y string) bool { 189 for _, x := range list { 190 if x == y { 191 return true 192 } 193 } 194 return false 195 } 196 197 type Dummy struct { 198 driver.Driver 199 } 200 201 func TestDrivers(t *testing.T) { 202 unregisterAllDrivers() 203 Register("test", fdriver) 204 Register("invalid", Dummy{}) 205 all := Drivers() 206 if len(all) < 2 || !sort.StringsAreSorted(all) || !contains(all, "test") || !contains(all, "invalid") { 207 t.Fatalf("Drivers = %v, want sorted list with at least [invalid, test]", all) 208 } 209 } 210 211 // hook to simulate connection failures 212 var hookOpenErr struct { 213 sync.Mutex 214 fn func() error 215 } 216 217 func setHookOpenErr(fn func() error) { 218 hookOpenErr.Lock() 219 defer hookOpenErr.Unlock() 220 hookOpenErr.fn = fn 221 } 222 223 // Supports dsn forms: 224 // <dbname> 225 // <dbname>;<opts> (only 
currently supported option is `badConn`, 226 // which causes driver.ErrBadConn to be returned on 227 // every other conn.Begin()) 228 func (d *fakeDriver) Open(dsn string) (driver.Conn, error) { 229 hookOpenErr.Lock() 230 fn := hookOpenErr.fn 231 hookOpenErr.Unlock() 232 if fn != nil { 233 if err := fn(); err != nil { 234 return nil, err 235 } 236 } 237 parts := strings.Split(dsn, ";") 238 if len(parts) < 1 { 239 return nil, errors.New("fakedb: no database name") 240 } 241 name := parts[0] 242 243 db := d.getDB(name) 244 245 d.mu.Lock() 246 d.openCount++ 247 d.mu.Unlock() 248 conn := &fakeConn{db: db} 249 250 if len(parts) >= 2 && parts[1] == "badConn" { 251 conn.bad = true 252 } 253 if d.waitCh != nil { 254 d.waitingCh <- struct{}{} 255 <-d.waitCh 256 d.waitCh = nil 257 d.waitingCh = nil 258 } 259 return conn, nil 260 } 261 262 func (d *fakeDriver) getDB(name string) *fakeDB { 263 d.mu.Lock() 264 defer d.mu.Unlock() 265 if d.dbs == nil { 266 d.dbs = make(map[string]*fakeDB) 267 } 268 db, ok := d.dbs[name] 269 if !ok { 270 db = &fakeDB{name: name} 271 d.dbs[name] = db 272 } 273 return db 274 } 275 276 func (db *fakeDB) wipe() { 277 db.mu.Lock() 278 defer db.mu.Unlock() 279 db.tables = nil 280 } 281 282 func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error { 283 db.mu.Lock() 284 defer db.mu.Unlock() 285 if db.tables == nil { 286 db.tables = make(map[string]*table) 287 } 288 if _, exist := db.tables[name]; exist { 289 return fmt.Errorf("table %q already exists", name) 290 } 291 if len(columnNames) != len(columnTypes) { 292 return fmt.Errorf("create table of %q len(names) != len(types): %d vs %d", 293 name, len(columnNames), len(columnTypes)) 294 } 295 db.tables[name] = &table{colname: columnNames, coltype: columnTypes} 296 return nil 297 } 298 299 // must be called with db.mu lock held 300 func (db *fakeDB) table(table string) (*table, bool) { 301 if db.tables == nil { 302 return nil, false 303 } 304 t, ok := db.tables[table] 305 return 
t, ok 306 } 307 308 func (db *fakeDB) columnType(table, column string) (typ string, ok bool) { 309 db.mu.Lock() 310 defer db.mu.Unlock() 311 t, ok := db.table(table) 312 if !ok { 313 return 314 } 315 for n, cname := range t.colname { 316 if cname == column { 317 return t.coltype[n], true 318 } 319 } 320 return "", false 321 } 322 323 func (c *fakeConn) isBad() bool { 324 if c.stickyBad { 325 return true 326 } else if c.bad { 327 if c.db == nil { 328 return false 329 } 330 // alternate between bad conn and not bad conn 331 c.db.badConn = !c.db.badConn 332 return c.db.badConn 333 } else { 334 return false 335 } 336 } 337 338 func (c *fakeConn) isDirtyAndMark() bool { 339 if c.skipDirtySession { 340 return false 341 } 342 if c.currTx != nil { 343 c.dirtySession = true 344 return false 345 } 346 if c.dirtySession { 347 return true 348 } 349 c.dirtySession = true 350 return false 351 } 352 353 func (c *fakeConn) Begin() (driver.Tx, error) { 354 if c.isBad() { 355 return nil, driver.ErrBadConn 356 } 357 if c.currTx != nil { 358 return nil, errors.New("already in a transaction") 359 } 360 c.touchMem() 361 c.currTx = &fakeTx{c: c} 362 return c.currTx, nil 363 } 364 365 var hookPostCloseConn struct { 366 sync.Mutex 367 fn func(*fakeConn, error) 368 } 369 370 func setHookpostCloseConn(fn func(*fakeConn, error)) { 371 hookPostCloseConn.Lock() 372 defer hookPostCloseConn.Unlock() 373 hookPostCloseConn.fn = fn 374 } 375 376 var testStrictClose *testing.T 377 378 // setStrictFakeConnClose sets the t to Errorf on when fakeConn.Close 379 // fails to close. If nil, the check is disabled. 
380 func setStrictFakeConnClose(t *testing.T) { 381 testStrictClose = t 382 } 383 384 func (c *fakeConn) ResetSession(ctx context.Context) error { 385 c.dirtySession = false 386 if c.isBad() { 387 return driver.ErrBadConn 388 } 389 return nil 390 } 391 392 func (c *fakeConn) Close() (err error) { 393 drv := fdriver.(*fakeDriver) 394 defer func() { 395 if err != nil && testStrictClose != nil { 396 testStrictClose.Errorf("failed to close a test fakeConn: %v", err) 397 } 398 hookPostCloseConn.Lock() 399 fn := hookPostCloseConn.fn 400 hookPostCloseConn.Unlock() 401 if fn != nil { 402 fn(c, err) 403 } 404 if err == nil { 405 drv.mu.Lock() 406 drv.closeCount++ 407 drv.mu.Unlock() 408 } 409 }() 410 c.touchMem() 411 if c.currTx != nil { 412 return errors.New("can't close fakeConn; in a Transaction") 413 } 414 if c.db == nil { 415 return errors.New("can't close fakeConn; already closed") 416 } 417 if c.stmtsMade > c.stmtsClosed { 418 return errors.New("can't close; dangling statement(s)") 419 } 420 c.db = nil 421 return nil 422 } 423 424 func checkSubsetTypes(allowAny bool, args []driver.NamedValue) error { 425 for _, arg := range args { 426 switch arg.Value.(type) { 427 case int64, float64, bool, nil, []byte, string, time.Time: 428 default: 429 if !allowAny { 430 return fmt.Errorf("fakedb_test: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value) 431 } 432 } 433 } 434 return nil 435 } 436 437 func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) { 438 // Ensure that ExecContext is called if available. 439 panic("ExecContext was not called.") 440 } 441 442 func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { 443 // This is an optional interface, but it's implemented here 444 // just to check that all the args are of the proper types. 445 // ErrSkip is returned so the caller acts as if we didn't 446 // implement this at all. 
447 err := checkSubsetTypes(c.db.allowAny, args) 448 if err != nil { 449 return nil, err 450 } 451 return nil, driver.ErrSkip 452 } 453 454 func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) { 455 // Ensure that ExecContext is called if available. 456 panic("QueryContext was not called.") 457 } 458 459 func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { 460 // This is an optional interface, but it's implemented here 461 // just to check that all the args are of the proper types. 462 // ErrSkip is returned so the caller acts as if we didn't 463 // implement this at all. 464 err := checkSubsetTypes(c.db.allowAny, args) 465 if err != nil { 466 return nil, err 467 } 468 return nil, driver.ErrSkip 469 } 470 471 func errf(msg string, args ...interface{}) error { 472 return errors.New("fakedb: " + fmt.Sprintf(msg, args...)) 473 } 474 475 // parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=? 476 // (note that where columns must always contain ? 
marks, 477 // just a limitation for fakedb) 478 func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, error) { 479 if len(parts) != 3 { 480 stmt.Close() 481 return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts)) 482 } 483 stmt.table = parts[0] 484 485 stmt.colName = strings.Split(parts[1], ",") 486 for n, colspec := range strings.Split(parts[2], ",") { 487 if colspec == "" { 488 continue 489 } 490 nameVal := strings.Split(colspec, "=") 491 if len(nameVal) != 2 { 492 stmt.Close() 493 return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) 494 } 495 column, value := nameVal[0], nameVal[1] 496 _, ok := c.db.columnType(stmt.table, column) 497 if !ok { 498 stmt.Close() 499 return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column) 500 } 501 if !strings.HasPrefix(value, "?") { 502 stmt.Close() 503 return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark", 504 stmt.table, column) 505 } 506 stmt.placeholders++ 507 stmt.whereCol = append(stmt.whereCol, boundCol{Column: column, Placeholder: value, Ordinal: stmt.placeholders}) 508 } 509 return stmt, nil 510 } 511 512 // parts are table|col=type,col2=type2 513 func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (*fakeStmt, error) { 514 if len(parts) != 2 { 515 stmt.Close() 516 return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts)) 517 } 518 stmt.table = parts[0] 519 for n, colspec := range strings.Split(parts[1], ",") { 520 nameType := strings.Split(colspec, "=") 521 if len(nameType) != 2 { 522 stmt.Close() 523 return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) 524 } 525 stmt.colName = append(stmt.colName, nameType[0]) 526 stmt.colType = append(stmt.colType, nameType[1]) 527 } 528 return stmt, nil 529 } 530 531 // parts are table|col=?,col2=val 532 func (c *fakeConn) 
prepareInsert(stmt *fakeStmt, parts []string) (*fakeStmt, error) { 533 if len(parts) != 2 { 534 stmt.Close() 535 return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts)) 536 } 537 stmt.table = parts[0] 538 for n, colspec := range strings.Split(parts[1], ",") { 539 nameVal := strings.Split(colspec, "=") 540 if len(nameVal) != 2 { 541 stmt.Close() 542 return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) 543 } 544 column, value := nameVal[0], nameVal[1] 545 ctype, ok := c.db.columnType(stmt.table, column) 546 if !ok { 547 stmt.Close() 548 return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column) 549 } 550 stmt.colName = append(stmt.colName, column) 551 552 if !strings.HasPrefix(value, "?") { 553 var subsetVal interface{} 554 // Convert to driver subset type 555 switch ctype { 556 case "string": 557 subsetVal = []byte(value) 558 case "blob": 559 subsetVal = []byte(value) 560 case "int32": 561 i, err := strconv.Atoi(value) 562 if err != nil { 563 stmt.Close() 564 return nil, errf("invalid conversion to int32 from %q", value) 565 } 566 subsetVal = int64(i) // int64 is a subset type, but not int32 567 default: 568 stmt.Close() 569 return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype) 570 } 571 stmt.colValue = append(stmt.colValue, subsetVal) 572 } else { 573 stmt.placeholders++ 574 stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype)) 575 stmt.colValue = append(stmt.colValue, value) 576 } 577 } 578 return stmt, nil 579 } 580 581 // hook to simulate broken connections 582 var hookPrepareBadConn func() bool 583 584 func (c *fakeConn) Prepare(query string) (driver.Stmt, error) { 585 panic("use PrepareContext") 586 } 587 588 func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { 589 c.numPrepare++ 590 if c.db == nil { 591 panic("nil c.db; conn = " + fmt.Sprintf("%#v", 
c)) 592 } 593 594 if c.stickyBad || (hookPrepareBadConn != nil && hookPrepareBadConn()) { 595 return nil, driver.ErrBadConn 596 } 597 598 c.touchMem() 599 var firstStmt, prev *fakeStmt 600 for _, query := range strings.Split(query, ";") { 601 parts := strings.Split(query, "|") 602 if len(parts) < 1 { 603 return nil, errf("empty query") 604 } 605 stmt := &fakeStmt{q: query, c: c, memToucher: c} 606 if firstStmt == nil { 607 firstStmt = stmt 608 } 609 if len(parts) >= 3 { 610 switch parts[0] { 611 case "PANIC": 612 stmt.panic = parts[1] 613 parts = parts[2:] 614 case "WAIT": 615 wait, err := time.ParseDuration(parts[1]) 616 if err != nil { 617 return nil, errf("expected section after WAIT to be a duration, got %q %v", parts[1], err) 618 } 619 parts = parts[2:] 620 stmt.wait = wait 621 } 622 } 623 cmd := parts[0] 624 stmt.cmd = cmd 625 parts = parts[1:] 626 627 if c.waiter != nil { 628 c.waiter(ctx) 629 } 630 631 if stmt.wait > 0 { 632 wait := time.NewTimer(stmt.wait) 633 select { 634 case <-wait.C: 635 case <-ctx.Done(): 636 wait.Stop() 637 return nil, ctx.Err() 638 } 639 } 640 641 c.incrStat(&c.stmtsMade) 642 var err error 643 switch cmd { 644 case "WIPE": 645 // Nothing 646 case "SELECT": 647 stmt, err = c.prepareSelect(stmt, parts) 648 case "CREATE": 649 stmt, err = c.prepareCreate(stmt, parts) 650 case "INSERT": 651 stmt, err = c.prepareInsert(stmt, parts) 652 case "NOSERT": 653 // Do all the prep-work like for an INSERT but don't actually insert the row. 654 // Used for some of the concurrent tests. 
655 stmt, err = c.prepareInsert(stmt, parts) 656 default: 657 stmt.Close() 658 return nil, errf("unsupported command type %q", cmd) 659 } 660 if err != nil { 661 return nil, err 662 } 663 if prev != nil { 664 prev.next = stmt 665 } 666 prev = stmt 667 } 668 return firstStmt, nil 669 } 670 671 func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter { 672 if s.panic == "ColumnConverter" { 673 panic(s.panic) 674 } 675 if len(s.placeholderConverter) == 0 { 676 return driver.DefaultParameterConverter 677 } 678 return s.placeholderConverter[idx] 679 } 680 681 func (s *fakeStmt) Close() error { 682 if s.panic == "Close" { 683 panic(s.panic) 684 } 685 if s.c == nil { 686 panic("nil conn in fakeStmt.Close") 687 } 688 if s.c.db == nil { 689 panic("in fakeStmt.Close, conn's db is nil (already closed)") 690 } 691 s.touchMem() 692 if !s.closed { 693 s.c.incrStat(&s.c.stmtsClosed) 694 s.closed = true 695 } 696 if s.next != nil { 697 s.next.Close() 698 } 699 return nil 700 } 701 702 var errClosed = errors.New("fakedb: statement has been closed") 703 704 // hook to simulate broken connections 705 var hookExecBadConn func() bool 706 707 func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) { 708 panic("Using ExecContext") 709 } 710 func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { 711 if s.panic == "Exec" { 712 panic(s.panic) 713 } 714 if s.closed { 715 return nil, errClosed 716 } 717 718 if s.c.stickyBad || (hookExecBadConn != nil && hookExecBadConn()) { 719 return nil, driver.ErrBadConn 720 } 721 if s.c.isDirtyAndMark() { 722 return nil, errors.New("session is dirty") 723 } 724 725 err := checkSubsetTypes(s.c.db.allowAny, args) 726 if err != nil { 727 return nil, err 728 } 729 s.touchMem() 730 731 if s.wait > 0 { 732 time.Sleep(s.wait) 733 } 734 735 select { 736 default: 737 case <-ctx.Done(): 738 return nil, ctx.Err() 739 } 740 741 db := s.c.db 742 switch s.cmd { 743 case "WIPE": 744 db.wipe() 745 
return driver.ResultNoRows, nil 746 case "CREATE": 747 if err := db.createTable(s.table, s.colName, s.colType); err != nil { 748 return nil, err 749 } 750 return driver.ResultNoRows, nil 751 case "INSERT": 752 return s.execInsert(args, true) 753 case "NOSERT": 754 // Do all the prep-work like for an INSERT but don't actually insert the row. 755 // Used for some of the concurrent tests. 756 return s.execInsert(args, false) 757 } 758 fmt.Printf("EXEC statement, cmd=%q: %#v\n", s.cmd, s) 759 return nil, fmt.Errorf("unimplemented statement Exec command type of %q", s.cmd) 760 } 761 762 // When doInsert is true, add the row to the table. 763 // When doInsert is false do prep-work and error checking, but don't 764 // actually add the row to the table. 765 func (s *fakeStmt) execInsert(args []driver.NamedValue, doInsert bool) (driver.Result, error) { 766 db := s.c.db 767 if len(args) != s.placeholders { 768 panic("error in pkg db; should only get here if size is correct") 769 } 770 db.mu.Lock() 771 t, ok := db.table(s.table) 772 db.mu.Unlock() 773 if !ok { 774 return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table) 775 } 776 777 t.mu.Lock() 778 defer t.mu.Unlock() 779 780 var cols []interface{} 781 if doInsert { 782 cols = make([]interface{}, len(t.colname)) 783 } 784 argPos := 0 785 for n, colname := range s.colName { 786 colidx := t.columnIndex(colname) 787 if colidx == -1 { 788 return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname) 789 } 790 var val interface{} 791 if strvalue, ok := s.colValue[n].(string); ok && strings.HasPrefix(strvalue, "?") { 792 if strvalue == "?" { 793 val = args[argPos].Value 794 } else { 795 // Assign value from argument placeholder name. 
796 for _, a := range args { 797 if a.Name == strvalue[1:] { 798 val = a.Value 799 break 800 } 801 } 802 } 803 argPos++ 804 } else { 805 val = s.colValue[n] 806 } 807 if doInsert { 808 cols[colidx] = val 809 } 810 } 811 812 if doInsert { 813 t.rows = append(t.rows, &row{cols: cols}) 814 } 815 return driver.RowsAffected(1), nil 816 } 817 818 // hook to simulate broken connections 819 var hookQueryBadConn func() bool 820 821 func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) { 822 panic("Use QueryContext") 823 } 824 825 func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { 826 if s.panic == "Query" { 827 panic(s.panic) 828 } 829 if s.closed { 830 return nil, errClosed 831 } 832 833 if s.c.stickyBad || (hookQueryBadConn != nil && hookQueryBadConn()) { 834 return nil, driver.ErrBadConn 835 } 836 if s.c.isDirtyAndMark() { 837 return nil, errors.New("session is dirty") 838 } 839 840 err := checkSubsetTypes(s.c.db.allowAny, args) 841 if err != nil { 842 return nil, err 843 } 844 845 s.touchMem() 846 db := s.c.db 847 if len(args) != s.placeholders { 848 panic("error in pkg db; should only get here if size is correct") 849 } 850 851 setMRows := make([][]*row, 0, 1) 852 setColumns := make([][]string, 0, 1) 853 setColType := make([][]string, 0, 1) 854 855 for { 856 db.mu.Lock() 857 t, ok := db.table(s.table) 858 db.mu.Unlock() 859 if !ok { 860 return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table) 861 } 862 863 if s.table == "magicquery" { 864 if len(s.whereCol) == 2 && s.whereCol[0].Column == "op" && s.whereCol[1].Column == "millis" { 865 if args[0].Value == "sleep" { 866 time.Sleep(time.Duration(args[1].Value.(int64)) * time.Millisecond) 867 } 868 } 869 } 870 871 t.mu.Lock() 872 873 colIdx := make(map[string]int) // select column name -> column index in table 874 for _, name := range s.colName { 875 idx := t.columnIndex(name) 876 if idx == -1 { 877 t.mu.Unlock() 878 return nil, fmt.Errorf("fakedb: 
unknown column name %q", name) 879 } 880 colIdx[name] = idx 881 } 882 883 mrows := []*row{} 884 rows: 885 for _, trow := range t.rows { 886 // Process the where clause, skipping non-match rows. This is lazy 887 // and just uses fmt.Sprintf("%v") to test equality. Good enough 888 // for test code. 889 for _, wcol := range s.whereCol { 890 idx := t.columnIndex(wcol.Column) 891 if idx == -1 { 892 t.mu.Unlock() 893 return nil, fmt.Errorf("db: invalid where clause column %q", wcol) 894 } 895 tcol := trow.cols[idx] 896 if bs, ok := tcol.([]byte); ok { 897 // lazy hack to avoid sprintf %v on a []byte 898 tcol = string(bs) 899 } 900 var argValue interface{} 901 if wcol.Placeholder == "?" { 902 argValue = args[wcol.Ordinal-1].Value 903 } else { 904 // Assign arg value from placeholder name. 905 for _, a := range args { 906 if a.Name == wcol.Placeholder[1:] { 907 argValue = a.Value 908 break 909 } 910 } 911 } 912 if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) { 913 continue rows 914 } 915 } 916 mrow := &row{cols: make([]interface{}, len(s.colName))} 917 for seli, name := range s.colName { 918 mrow.cols[seli] = trow.cols[colIdx[name]] 919 } 920 mrows = append(mrows, mrow) 921 } 922 923 var colType []string 924 for _, column := range s.colName { 925 colType = append(colType, t.coltype[t.columnIndex(column)]) 926 } 927 928 t.mu.Unlock() 929 930 setMRows = append(setMRows, mrows) 931 setColumns = append(setColumns, s.colName) 932 setColType = append(setColType, colType) 933 934 if s.next == nil { 935 break 936 } 937 s = s.next 938 } 939 940 cursor := &rowsCursor{ 941 parentMem: s.c, 942 posRow: -1, 943 rows: setMRows, 944 cols: setColumns, 945 colType: setColType, 946 errPos: -1, 947 } 948 return cursor, nil 949 } 950 951 func (s *fakeStmt) NumInput() int { 952 if s.panic == "NumInput" { 953 panic(s.panic) 954 } 955 return s.placeholders 956 } 957 958 // hook to simulate broken connections 959 var hookCommitBadConn func() bool 960 961 func (tx *fakeTx) Commit() error 
{ 962 tx.c.currTx = nil 963 if hookCommitBadConn != nil && hookCommitBadConn() { 964 return driver.ErrBadConn 965 } 966 tx.c.touchMem() 967 return nil 968 } 969 970 // hook to simulate broken connections 971 var hookRollbackBadConn func() bool 972 973 func (tx *fakeTx) Rollback() error { 974 tx.c.currTx = nil 975 if hookRollbackBadConn != nil && hookRollbackBadConn() { 976 return driver.ErrBadConn 977 } 978 tx.c.touchMem() 979 return nil 980 } 981 982 type rowsCursor struct { 983 parentMem memToucher 984 cols [][]string 985 colType [][]string 986 posSet int 987 posRow int 988 rows [][]*row 989 closed bool 990 991 // errPos and err are for making Next return early with error. 992 errPos int 993 err error 994 995 // a clone of slices to give out to clients, indexed by the 996 // the original slice's first byte address. we clone them 997 // just so we're able to corrupt them on close. 998 bytesClone map[*byte][]byte 999 1000 // Every operation writes to line to enable the race detector 1001 // check for data races. 1002 // This is separate from the fakeConn.line to allow for drivers that 1003 // can start multiple queries on the same transaction at the same time. 
1004 line int64 1005 } 1006 1007 func (rc *rowsCursor) touchMem() { 1008 rc.parentMem.touchMem() 1009 rc.line++ 1010 } 1011 1012 func (rc *rowsCursor) Close() error { 1013 if !rc.closed { 1014 for _, bs := range rc.bytesClone { 1015 bs[0] = 255 // first byte corrupted 1016 } 1017 } 1018 rc.touchMem() 1019 rc.parentMem.touchMem() 1020 rc.closed = true 1021 return nil 1022 } 1023 1024 func (rc *rowsCursor) Columns() []string { 1025 return rc.cols[rc.posSet] 1026 } 1027 1028 func (rc *rowsCursor) ColumnTypeScanType(index int) reflect.Type { 1029 return colTypeToReflectType(rc.colType[rc.posSet][index]) 1030 } 1031 1032 var rowsCursorNextHook func(dest []driver.Value) error 1033 1034 func (rc *rowsCursor) Next(dest []driver.Value) error { 1035 if rowsCursorNextHook != nil { 1036 return rowsCursorNextHook(dest) 1037 } 1038 1039 if rc.closed { 1040 return errors.New("fakedb: cursor is closed") 1041 } 1042 rc.touchMem() 1043 rc.posRow++ 1044 if rc.posRow == rc.errPos { 1045 return rc.err 1046 } 1047 if rc.posRow >= len(rc.rows[rc.posSet]) { 1048 return io.EOF // per interface spec 1049 } 1050 for i, v := range rc.rows[rc.posSet][rc.posRow].cols { 1051 // TODO(bradfitz): convert to subset types? naah, I 1052 // think the subset types should only be input to 1053 // driver, but the sql package should be able to handle 1054 // a wider range of types coming out of drivers. all 1055 // for ease of drivers, and to prevent drivers from 1056 // messing up conversions or doing them differently. 
1057 dest[i] = v 1058 1059 if bs, ok := v.([]byte); ok { 1060 if rc.bytesClone == nil { 1061 rc.bytesClone = make(map[*byte][]byte) 1062 } 1063 clone, ok := rc.bytesClone[&bs[0]] 1064 if !ok { 1065 clone = make([]byte, len(bs)) 1066 copy(clone, bs) 1067 rc.bytesClone[&bs[0]] = clone 1068 } 1069 dest[i] = clone 1070 } 1071 } 1072 return nil 1073 } 1074 1075 func (rc *rowsCursor) HasNextResultSet() bool { 1076 rc.touchMem() 1077 return rc.posSet < len(rc.rows)-1 1078 } 1079 1080 func (rc *rowsCursor) NextResultSet() error { 1081 rc.touchMem() 1082 if rc.HasNextResultSet() { 1083 rc.posSet++ 1084 rc.posRow = -1 1085 return nil 1086 } 1087 return io.EOF // Per interface spec. 1088 } 1089 1090 // fakeDriverString is like driver.String, but indirects pointers like 1091 // DefaultValueConverter. 1092 // 1093 // This could be surprising behavior to retroactively apply to 1094 // driver.String now that Go1 is out, but this is convenient for 1095 // our TestPointerParamsAndScans. 1096 // 1097 type fakeDriverString struct{} 1098 1099 func (fakeDriverString) ConvertValue(v interface{}) (driver.Value, error) { 1100 switch c := v.(type) { 1101 case string, []byte: 1102 return v, nil 1103 case *string: 1104 if c == nil { 1105 return nil, nil 1106 } 1107 return *c, nil 1108 } 1109 return fmt.Sprintf("%v", v), nil 1110 } 1111 1112 type anyTypeConverter struct{} 1113 1114 func (anyTypeConverter) ConvertValue(v interface{}) (driver.Value, error) { 1115 return v, nil 1116 } 1117 1118 func converterForType(typ string) driver.ValueConverter { 1119 switch typ { 1120 case "bool": 1121 return driver.Bool 1122 case "nullbool": 1123 return driver.Null{Converter: driver.Bool} 1124 case "int32": 1125 return driver.Int32 1126 case "string": 1127 return driver.NotNull{Converter: fakeDriverString{}} 1128 case "nullstring": 1129 return driver.Null{Converter: fakeDriverString{}} 1130 case "int64": 1131 // TODO(coopernurse): add type-specific converter 1132 return driver.NotNull{Converter: 
driver.DefaultParameterConverter} 1133 case "nullint64": 1134 // TODO(coopernurse): add type-specific converter 1135 return driver.Null{Converter: driver.DefaultParameterConverter} 1136 case "float64": 1137 // TODO(coopernurse): add type-specific converter 1138 return driver.NotNull{Converter: driver.DefaultParameterConverter} 1139 case "nullfloat64": 1140 // TODO(coopernurse): add type-specific converter 1141 return driver.Null{Converter: driver.DefaultParameterConverter} 1142 case "datetime": 1143 return driver.DefaultParameterConverter 1144 case "any": 1145 return anyTypeConverter{} 1146 } 1147 panic("invalid fakedb column type of " + typ) 1148 } 1149 1150 func colTypeToReflectType(typ string) reflect.Type { 1151 switch typ { 1152 case "bool": 1153 return reflect.TypeOf(false) 1154 case "nullbool": 1155 return reflect.TypeOf(NullBool{}) 1156 case "int32": 1157 return reflect.TypeOf(int32(0)) 1158 case "string": 1159 return reflect.TypeOf("") 1160 case "nullstring": 1161 return reflect.TypeOf(NullString{}) 1162 case "int64": 1163 return reflect.TypeOf(int64(0)) 1164 case "nullint64": 1165 return reflect.TypeOf(NullInt64{}) 1166 case "float64": 1167 return reflect.TypeOf(float64(0)) 1168 case "nullfloat64": 1169 return reflect.TypeOf(NullFloat64{}) 1170 case "datetime": 1171 return reflect.TypeOf(time.Time{}) 1172 case "any": 1173 return reflect.TypeOf(new(interface{})).Elem() 1174 } 1175 panic("invalid fakedb column type of " + typ) 1176 }