github.com/ltltlt/go-source-code@v0.0.0-20190830023027-95be009773aa/database/sql/fakedb_test.go (about) 1 // Copyright 2011 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package sql 6 7 import ( 8 "context" 9 "database/sql/driver" 10 "errors" 11 "fmt" 12 "io" 13 "log" 14 "reflect" 15 "sort" 16 "strconv" 17 "strings" 18 "sync" 19 "testing" 20 "time" 21 ) 22 23 var _ = log.Printf 24 25 // fakeDriver is a fake database that implements Go's driver.Driver 26 // interface, just for testing. 27 // 28 // It speaks a query language that's semantically similar to but 29 // syntactically different and simpler than SQL. The syntax is as 30 // follows: 31 // 32 // WIPE 33 // CREATE|<tablename>|<col>=<type>,<col>=<type>,... 34 // where types are: "string", [u]int{8,16,32,64}, "bool" 35 // INSERT|<tablename>|col=val,col2=val2,col3=? 36 // SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=? 37 // SELECT|<tablename>|projectcol1,projectcol2|filtercol=?param1,filtercol2=?param2 38 // 39 // Any of these can be preceded by PANIC|<method>|, to cause the 40 // named method on fakeStmt to panic. 41 // 42 // Any of these can be proceeded by WAIT|<duration>|, to cause the 43 // named method on fakeStmt to sleep for the specified duration. 44 // 45 // Multiple of these can be combined when separated with a semicolon. 46 // 47 // When opening a fakeDriver's database, it starts empty with no 48 // tables. All tables and data are stored in memory only. 49 type fakeDriver struct { 50 mu sync.Mutex // guards 3 following fields 51 openCount int // conn opens 52 closeCount int // conn closes 53 waitCh chan struct{} 54 waitingCh chan struct{} 55 dbs map[string]*fakeDB 56 } 57 58 type fakeConnector struct { 59 name string 60 61 waiter func(context.Context) 62 } 63 64 func (c *fakeConnector) Connect(context.Context) (driver.Conn, error) { 65 conn, err := fdriver.Open(c.name) 66 conn.(*fakeConn).waiter = c.waiter 67 return conn, err 68 } 69 70 func (c *fakeConnector) Driver() driver.Driver { 71 return fdriver 72 } 73 74 type fakeDriverCtx struct { 75 fakeDriver 76 } 77 78 var _ driver.DriverContext = &fakeDriverCtx{} 79 80 func (cc *fakeDriverCtx) OpenConnector(name string) (driver.Connector, error) { 81 return &fakeConnector{name: name}, nil 82 } 83 84 type fakeDB struct { 85 name string 86 87 mu sync.Mutex 88 tables map[string]*table 89 badConn bool 90 allowAny bool 91 } 92 93 type table struct { 94 mu sync.Mutex 95 colname []string 96 coltype []string 97 rows []*row 98 } 99 100 func (t *table) columnIndex(name string) int { 101 for n, nname := range t.colname { 102 if name == nname { 103 return n 104 } 105 } 106 return -1 107 } 108 109 type row struct { 110 cols []interface{} // must be same size as its table colname + coltype 111 } 112 113 type memToucher interface { 114 // touchMem reads & writes some memory, to help find data races. 115 touchMem() 116 } 117 118 type fakeConn struct { 119 db *fakeDB // where to return ourselves to 120 121 currTx *fakeTx 122 123 // Every operation writes to line to enable the race detector 124 // check for data races. 125 line int64 126 127 // Stats for tests: 128 mu sync.Mutex 129 stmtsMade int 130 stmtsClosed int 131 numPrepare int 132 133 // bad connection tests; see isBad() 134 bad bool 135 stickyBad bool 136 137 skipDirtySession bool // tests that use Conn should set this to true. 138 139 // dirtySession tests ResetSession, true if a query has executed 140 // until ResetSession is called. 141 dirtySession bool 142 143 // The waiter is called before each query. May be used in place of the "WAIT" 144 // directive. 145 waiter func(context.Context) 146 } 147 148 func (c *fakeConn) touchMem() { 149 c.line++ 150 } 151 152 func (c *fakeConn) incrStat(v *int) { 153 c.mu.Lock() 154 *v++ 155 c.mu.Unlock() 156 } 157 158 type fakeTx struct { 159 c *fakeConn 160 } 161 162 type boundCol struct { 163 Column string 164 Placeholder string 165 Ordinal int 166 } 167 168 type fakeStmt struct { 169 memToucher 170 c *fakeConn 171 q string // just for debugging 172 173 cmd string 174 table string 175 panic string 176 wait time.Duration 177 178 next *fakeStmt // used for returning multiple results. 179 180 closed bool 181 182 colName []string // used by CREATE, INSERT, SELECT (selected columns) 183 colType []string // used by CREATE 184 colValue []interface{} // used by INSERT (mix of strings and "?" for bound params) 185 placeholders int // used by INSERT/SELECT: number of ? params 186 187 whereCol []boundCol // used by SELECT (all placeholders) 188 189 placeholderConverter []driver.ValueConverter // used by INSERT 190 } 191 192 var fdriver driver.Driver = &fakeDriver{} 193 194 func init() { 195 Register("test", fdriver) 196 } 197 198 func contains(list []string, y string) bool { 199 for _, x := range list { 200 if x == y { 201 return true 202 } 203 } 204 return false 205 } 206 207 type Dummy struct { 208 driver.Driver 209 } 210 211 func TestDrivers(t *testing.T) { 212 unregisterAllDrivers() 213 Register("test", fdriver) 214 Register("invalid", Dummy{}) 215 all := Drivers() 216 if len(all) < 2 || !sort.StringsAreSorted(all) || !contains(all, "test") || !contains(all, "invalid") { 217 t.Fatalf("Drivers = %v, want sorted list with at least [invalid, test]", all) 218 } 219 } 220 221 // hook to simulate connection failures 222 var hookOpenErr struct { 223 sync.Mutex 224 fn func() error 225 } 226 227 func setHookOpenErr(fn func() error) { 228 hookOpenErr.Lock() 229 defer hookOpenErr.Unlock() 230 hookOpenErr.fn = fn 231 } 232 233 // Supports dsn forms: 234 // <dbname> 235 // <dbname>;<opts> (only currently supported option is `badConn`, 236 // which causes driver.ErrBadConn to be returned on 237 // every other conn.Begin()) 238 func (d *fakeDriver) Open(dsn string) (driver.Conn, error) { 239 hookOpenErr.Lock() 240 fn := hookOpenErr.fn 241 hookOpenErr.Unlock() 242 if fn != nil { 243 if err := fn(); err != nil { 244 return nil, err 245 } 246 } 247 parts := strings.Split(dsn, ";") 248 if len(parts) < 1 { 249 return nil, errors.New("fakedb: no database name") 250 } 251 name := parts[0] 252 253 db := d.getDB(name) 254 255 d.mu.Lock() 256 d.openCount++ 257 d.mu.Unlock() 258 conn := &fakeConn{db: db} 259 260 if len(parts) >= 2 && parts[1] == "badConn" { 261 conn.bad = true 262 } 263 if d.waitCh != nil { 264 d.waitingCh <- struct{}{} 265 <-d.waitCh 266 d.waitCh = nil 267 d.waitingCh = nil 268 } 269 return conn, nil 270 } 271 272 func (d *fakeDriver) getDB(name string) *fakeDB { 273 d.mu.Lock() 274 defer d.mu.Unlock() 275 if d.dbs == nil { 276 d.dbs = make(map[string]*fakeDB) 277 } 278 db, ok := d.dbs[name] 279 if !ok { 280 db = &fakeDB{name: name} 281 d.dbs[name] = db 282 } 283 return db 284 } 285 286 func (db *fakeDB) wipe() { 287 db.mu.Lock() 288 defer db.mu.Unlock() 289 db.tables = nil 290 } 291 292 func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error { 293 db.mu.Lock() 294 defer db.mu.Unlock() 295 if db.tables == nil { 296 db.tables = make(map[string]*table) 297 } 298 if _, exist := db.tables[name]; exist { 299 return fmt.Errorf("table %q already exists", name) 300 } 301 if len(columnNames) != len(columnTypes) { 302 return fmt.Errorf("create table of %q len(names) != len(types): %d vs %d", 303 name, len(columnNames), len(columnTypes)) 304 } 305 db.tables[name] = &table{colname: columnNames, coltype: columnTypes} 306 return nil 307 } 308 309 // must be called with db.mu lock held 310 func (db *fakeDB) table(table string) (*table, bool) { 311 if db.tables == nil { 312 return nil, false 313 } 314 t, ok := db.tables[table] 315 return t, ok 316 } 317 318 func (db *fakeDB) columnType(table, column string) (typ string, ok bool) { 319 db.mu.Lock() 320 defer db.mu.Unlock() 321 t, ok := db.table(table) 322 if !ok { 323 return 324 } 325 for n, cname := range t.colname { 326 if cname == column { 327 return t.coltype[n], true 328 } 329 } 330 return "", false 331 } 332 333 func (c *fakeConn) isBad() bool { 334 if c.stickyBad { 335 return true 336 } else if c.bad { 337 if c.db == nil { 338 return false 339 } 340 // alternate between bad conn and not bad conn 341 c.db.badConn = !c.db.badConn 342 return c.db.badConn 343 } else { 344 return false 345 } 346 } 347 348 func (c *fakeConn) isDirtyAndMark() bool { 349 if c.skipDirtySession { 350 return false 351 } 352 if c.currTx != nil { 353 c.dirtySession = true 354 return false 355 } 356 if c.dirtySession { 357 return true 358 } 359 c.dirtySession = true 360 return false 361 } 362 363 func (c *fakeConn) Begin() (driver.Tx, error) { 364 if c.isBad() { 365 return nil, driver.ErrBadConn 366 } 367 if c.currTx != nil { 368 return nil, errors.New("already in a transaction") 369 } 370 c.touchMem() 371 c.currTx = &fakeTx{c: c} 372 return c.currTx, nil 373 } 374 375 var hookPostCloseConn struct { 376 sync.Mutex 377 fn func(*fakeConn, error) 378 } 379 380 func setHookpostCloseConn(fn func(*fakeConn, error)) { 381 hookPostCloseConn.Lock() 382 defer hookPostCloseConn.Unlock() 383 hookPostCloseConn.fn = fn 384 } 385 386 var testStrictClose *testing.T 387 388 // setStrictFakeConnClose sets the t to Errorf on when fakeConn.Close 389 // fails to close. If nil, the check is disabled. 390 func setStrictFakeConnClose(t *testing.T) { 391 testStrictClose = t 392 } 393 394 func (c *fakeConn) ResetSession(ctx context.Context) error { 395 c.dirtySession = false 396 if c.isBad() { 397 return driver.ErrBadConn 398 } 399 return nil 400 } 401 402 func (c *fakeConn) Close() (err error) { 403 drv := fdriver.(*fakeDriver) 404 defer func() { 405 if err != nil && testStrictClose != nil { 406 testStrictClose.Errorf("failed to close a test fakeConn: %v", err) 407 } 408 hookPostCloseConn.Lock() 409 fn := hookPostCloseConn.fn 410 hookPostCloseConn.Unlock() 411 if fn != nil { 412 fn(c, err) 413 } 414 if err == nil { 415 drv.mu.Lock() 416 drv.closeCount++ 417 drv.mu.Unlock() 418 } 419 }() 420 c.touchMem() 421 if c.currTx != nil { 422 return errors.New("can't close fakeConn; in a Transaction") 423 } 424 if c.db == nil { 425 return errors.New("can't close fakeConn; already closed") 426 } 427 if c.stmtsMade > c.stmtsClosed { 428 return errors.New("can't close; dangling statement(s)") 429 } 430 c.db = nil 431 return nil 432 } 433 434 func checkSubsetTypes(allowAny bool, args []driver.NamedValue) error { 435 for _, arg := range args { 436 switch arg.Value.(type) { 437 case int64, float64, bool, nil, []byte, string, time.Time: 438 default: 439 if !allowAny { 440 return fmt.Errorf("fakedb_test: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value) 441 } 442 } 443 } 444 return nil 445 } 446 447 func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) { 448 // Ensure that ExecContext is called if available. 449 panic("ExecContext was not called.") 450 } 451 452 func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { 453 // This is an optional interface, but it's implemented here 454 // just to check that all the args are of the proper types. 455 // ErrSkip is returned so the caller acts as if we didn't 456 // implement this at all. 457 err := checkSubsetTypes(c.db.allowAny, args) 458 if err != nil { 459 return nil, err 460 } 461 return nil, driver.ErrSkip 462 } 463 464 func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) { 465 // Ensure that ExecContext is called if available. 466 panic("QueryContext was not called.") 467 } 468 469 func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { 470 // This is an optional interface, but it's implemented here 471 // just to check that all the args are of the proper types. 472 // ErrSkip is returned so the caller acts as if we didn't 473 // implement this at all. 474 err := checkSubsetTypes(c.db.allowAny, args) 475 if err != nil { 476 return nil, err 477 } 478 return nil, driver.ErrSkip 479 } 480 481 func errf(msg string, args ...interface{}) error { 482 return errors.New("fakedb: " + fmt.Sprintf(msg, args...)) 483 } 484 485 // parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=? 486 // (note that where columns must always contain ? marks, 487 // just a limitation for fakedb) 488 func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, error) { 489 if len(parts) != 3 { 490 stmt.Close() 491 return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts)) 492 } 493 stmt.table = parts[0] 494 495 stmt.colName = strings.Split(parts[1], ",") 496 for n, colspec := range strings.Split(parts[2], ",") { 497 if colspec == "" { 498 continue 499 } 500 nameVal := strings.Split(colspec, "=") 501 if len(nameVal) != 2 { 502 stmt.Close() 503 return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) 504 } 505 column, value := nameVal[0], nameVal[1] 506 _, ok := c.db.columnType(stmt.table, column) 507 if !ok { 508 stmt.Close() 509 return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column) 510 } 511 if !strings.HasPrefix(value, "?") { 512 stmt.Close() 513 return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark", 514 stmt.table, column) 515 } 516 stmt.placeholders++ 517 stmt.whereCol = append(stmt.whereCol, boundCol{Column: column, Placeholder: value, Ordinal: stmt.placeholders}) 518 } 519 return stmt, nil 520 } 521 522 // parts are table|col=type,col2=type2 523 func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (*fakeStmt, error) { 524 if len(parts) != 2 { 525 stmt.Close() 526 return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts)) 527 } 528 stmt.table = parts[0] 529 for n, colspec := range strings.Split(parts[1], ",") { 530 nameType := strings.Split(colspec, "=") 531 if len(nameType) != 2 { 532 stmt.Close() 533 return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) 534 } 535 stmt.colName = append(stmt.colName, nameType[0]) 536 stmt.colType = append(stmt.colType, nameType[1]) 537 } 538 return stmt, nil 539 } 540 541 // parts are table|col=?,col2=val 542 func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (*fakeStmt, error) { 543 if len(parts) != 2 { 544 stmt.Close() 545 return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts)) 546 } 547 stmt.table = parts[0] 548 for n, colspec := range strings.Split(parts[1], ",") { 549 nameVal := strings.Split(colspec, "=") 550 if len(nameVal) != 2 { 551 stmt.Close() 552 return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) 553 } 554 column, value := nameVal[0], nameVal[1] 555 ctype, ok := c.db.columnType(stmt.table, column) 556 if !ok { 557 stmt.Close() 558 return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column) 559 } 560 stmt.colName = append(stmt.colName, column) 561 562 if !strings.HasPrefix(value, "?") { 563 var subsetVal interface{} 564 // Convert to driver subset type 565 switch ctype { 566 case "string": 567 subsetVal = []byte(value) 568 case "blob": 569 subsetVal = []byte(value) 570 case "int32": 571 i, err := strconv.Atoi(value) 572 if err != nil { 573 stmt.Close() 574 return nil, errf("invalid conversion to int32 from %q", value) 575 } 576 subsetVal = int64(i) // int64 is a subset type, but not int32 577 default: 578 stmt.Close() 579 return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype) 580 } 581 stmt.colValue = append(stmt.colValue, subsetVal) 582 } else { 583 stmt.placeholders++ 584 stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype)) 585 stmt.colValue = append(stmt.colValue, value) 586 } 587 } 588 return stmt, nil 589 } 590 591 // hook to simulate broken connections 592 var hookPrepareBadConn func() bool 593 594 func (c *fakeConn) Prepare(query string) (driver.Stmt, error) { 595 panic("use PrepareContext") 596 } 597 598 func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { 599 c.numPrepare++ 600 if c.db == nil { 601 panic("nil c.db; conn = " + fmt.Sprintf("%#v", c)) 602 } 603 604 if c.stickyBad || (hookPrepareBadConn != nil && hookPrepareBadConn()) { 605 return nil, driver.ErrBadConn 606 } 607 608 c.touchMem() 609 var firstStmt, prev *fakeStmt 610 for _, query := range strings.Split(query, ";") { 611 parts := strings.Split(query, "|") 612 if len(parts) < 1 { 613 return nil, errf("empty query") 614 } 615 stmt := &fakeStmt{q: query, c: c, memToucher: c} 616 if firstStmt == nil { 617 firstStmt = stmt 618 } 619 if len(parts) >= 3 { 620 switch parts[0] { 621 case "PANIC": 622 stmt.panic = parts[1] 623 parts = parts[2:] 624 case "WAIT": 625 wait, err := time.ParseDuration(parts[1]) 626 if err != nil { 627 return nil, errf("expected section after WAIT to be a duration, got %q %v", parts[1], err) 628 } 629 parts = parts[2:] 630 stmt.wait = wait 631 } 632 } 633 cmd := parts[0] 634 stmt.cmd = cmd 635 parts = parts[1:] 636 637 if c.waiter != nil { 638 c.waiter(ctx) 639 } 640 641 if stmt.wait > 0 { 642 wait := time.NewTimer(stmt.wait) 643 select { 644 case <-wait.C: 645 case <-ctx.Done(): 646 wait.Stop() 647 return nil, ctx.Err() 648 } 649 } 650 651 c.incrStat(&c.stmtsMade) 652 var err error 653 switch cmd { 654 case "WIPE": 655 // Nothing 656 case "SELECT": 657 stmt, err = c.prepareSelect(stmt, parts) 658 case "CREATE": 659 stmt, err = c.prepareCreate(stmt, parts) 660 case "INSERT": 661 stmt, err = c.prepareInsert(stmt, parts) 662 case "NOSERT": 663 // Do all the prep-work like for an INSERT but don't actually insert the row. 664 // Used for some of the concurrent tests. 665 stmt, err = c.prepareInsert(stmt, parts) 666 default: 667 stmt.Close() 668 return nil, errf("unsupported command type %q", cmd) 669 } 670 if err != nil { 671 return nil, err 672 } 673 if prev != nil { 674 prev.next = stmt 675 } 676 prev = stmt 677 } 678 return firstStmt, nil 679 } 680 681 func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter { 682 if s.panic == "ColumnConverter" { 683 panic(s.panic) 684 } 685 if len(s.placeholderConverter) == 0 { 686 return driver.DefaultParameterConverter 687 } 688 return s.placeholderConverter[idx] 689 } 690 691 func (s *fakeStmt) Close() error { 692 if s.panic == "Close" { 693 panic(s.panic) 694 } 695 if s.c == nil { 696 panic("nil conn in fakeStmt.Close") 697 } 698 if s.c.db == nil { 699 panic("in fakeStmt.Close, conn's db is nil (already closed)") 700 } 701 s.touchMem() 702 if !s.closed { 703 s.c.incrStat(&s.c.stmtsClosed) 704 s.closed = true 705 } 706 if s.next != nil { 707 s.next.Close() 708 } 709 return nil 710 } 711 712 var errClosed = errors.New("fakedb: statement has been closed") 713 714 // hook to simulate broken connections 715 var hookExecBadConn func() bool 716 717 func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) { 718 panic("Using ExecContext") 719 } 720 func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { 721 if s.panic == "Exec" { 722 panic(s.panic) 723 } 724 if s.closed { 725 return nil, errClosed 726 } 727 728 if s.c.stickyBad || (hookExecBadConn != nil && hookExecBadConn()) { 729 return nil, driver.ErrBadConn 730 } 731 if s.c.isDirtyAndMark() { 732 return nil, errors.New("session is dirty") 733 } 734 735 err := checkSubsetTypes(s.c.db.allowAny, args) 736 if err != nil { 737 return nil, err 738 } 739 s.touchMem() 740 741 if s.wait > 0 { 742 time.Sleep(s.wait) 743 } 744 745 select { 746 default: 747 case <-ctx.Done(): 748 return nil, ctx.Err() 749 } 750 751 db := s.c.db 752 switch s.cmd { 753 case "WIPE": 754 db.wipe() 755 return driver.ResultNoRows, nil 756 case "CREATE": 757 if err := db.createTable(s.table, s.colName, s.colType); err != nil { 758 return nil, err 759 } 760 return driver.ResultNoRows, nil 761 case "INSERT": 762 return s.execInsert(args, true) 763 case "NOSERT": 764 // Do all the prep-work like for an INSERT but don't actually insert the row. 765 // Used for some of the concurrent tests. 766 return s.execInsert(args, false) 767 } 768 fmt.Printf("EXEC statement, cmd=%q: %#v\n", s.cmd, s) 769 return nil, fmt.Errorf("unimplemented statement Exec command type of %q", s.cmd) 770 } 771 772 // When doInsert is true, add the row to the table. 773 // When doInsert is false do prep-work and error checking, but don't 774 // actually add the row to the table. 775 func (s *fakeStmt) execInsert(args []driver.NamedValue, doInsert bool) (driver.Result, error) { 776 db := s.c.db 777 if len(args) != s.placeholders { 778 panic("error in pkg db; should only get here if size is correct") 779 } 780 db.mu.Lock() 781 t, ok := db.table(s.table) 782 db.mu.Unlock() 783 if !ok { 784 return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table) 785 } 786 787 t.mu.Lock() 788 defer t.mu.Unlock() 789 790 var cols []interface{} 791 if doInsert { 792 cols = make([]interface{}, len(t.colname)) 793 } 794 argPos := 0 795 for n, colname := range s.colName { 796 colidx := t.columnIndex(colname) 797 if colidx == -1 { 798 return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname) 799 } 800 var val interface{} 801 if strvalue, ok := s.colValue[n].(string); ok && strings.HasPrefix(strvalue, "?") { 802 if strvalue == "?" { 803 val = args[argPos].Value 804 } else { 805 // Assign value from argument placeholder name. 806 for _, a := range args { 807 if a.Name == strvalue[1:] { 808 val = a.Value 809 break 810 } 811 } 812 } 813 argPos++ 814 } else { 815 val = s.colValue[n] 816 } 817 if doInsert { 818 cols[colidx] = val 819 } 820 } 821 822 if doInsert { 823 t.rows = append(t.rows, &row{cols: cols}) 824 } 825 return driver.RowsAffected(1), nil 826 } 827 828 // hook to simulate broken connections 829 var hookQueryBadConn func() bool 830 831 func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) { 832 panic("Use QueryContext") 833 } 834 835 func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { 836 if s.panic == "Query" { 837 panic(s.panic) 838 } 839 if s.closed { 840 return nil, errClosed 841 } 842 843 if s.c.stickyBad || (hookQueryBadConn != nil && hookQueryBadConn()) { 844 return nil, driver.ErrBadConn 845 } 846 if s.c.isDirtyAndMark() { 847 return nil, errors.New("session is dirty") 848 } 849 850 err := checkSubsetTypes(s.c.db.allowAny, args) 851 if err != nil { 852 return nil, err 853 } 854 855 s.touchMem() 856 db := s.c.db 857 if len(args) != s.placeholders { 858 panic("error in pkg db; should only get here if size is correct") 859 } 860 861 setMRows := make([][]*row, 0, 1) 862 setColumns := make([][]string, 0, 1) 863 setColType := make([][]string, 0, 1) 864 865 for { 866 db.mu.Lock() 867 t, ok := db.table(s.table) 868 db.mu.Unlock() 869 if !ok { 870 return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table) 871 } 872 873 if s.table == "magicquery" { 874 if len(s.whereCol) == 2 && s.whereCol[0].Column == "op" && s.whereCol[1].Column == "millis" { 875 if args[0].Value == "sleep" { 876 time.Sleep(time.Duration(args[1].Value.(int64)) * time.Millisecond) 877 } 878 } 879 } 880 881 t.mu.Lock() 882 883 colIdx := make(map[string]int) // select column name -> column index in table 884 for _, name := range s.colName { 885 idx := t.columnIndex(name) 886 if idx == -1 { 887 t.mu.Unlock() 888 return nil, fmt.Errorf("fakedb: unknown column name %q", name) 889 } 890 colIdx[name] = idx 891 } 892 893 mrows := []*row{} 894 rows: 895 for _, trow := range t.rows { 896 // Process the where clause, skipping non-match rows. This is lazy 897 // and just uses fmt.Sprintf("%v") to test equality. Good enough 898 // for test code. 899 for _, wcol := range s.whereCol { 900 idx := t.columnIndex(wcol.Column) 901 if idx == -1 { 902 t.mu.Unlock() 903 return nil, fmt.Errorf("db: invalid where clause column %q", wcol) 904 } 905 tcol := trow.cols[idx] 906 if bs, ok := tcol.([]byte); ok { 907 // lazy hack to avoid sprintf %v on a []byte 908 tcol = string(bs) 909 } 910 var argValue interface{} 911 if wcol.Placeholder == "?" { 912 argValue = args[wcol.Ordinal-1].Value 913 } else { 914 // Assign arg value from placeholder name. 915 for _, a := range args { 916 if a.Name == wcol.Placeholder[1:] { 917 argValue = a.Value 918 break 919 } 920 } 921 } 922 if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) { 923 continue rows 924 } 925 } 926 mrow := &row{cols: make([]interface{}, len(s.colName))} 927 for seli, name := range s.colName { 928 mrow.cols[seli] = trow.cols[colIdx[name]] 929 } 930 mrows = append(mrows, mrow) 931 } 932 933 var colType []string 934 for _, column := range s.colName { 935 colType = append(colType, t.coltype[t.columnIndex(column)]) 936 } 937 938 t.mu.Unlock() 939 940 setMRows = append(setMRows, mrows) 941 setColumns = append(setColumns, s.colName) 942 setColType = append(setColType, colType) 943 944 if s.next == nil { 945 break 946 } 947 s = s.next 948 } 949 950 cursor := &rowsCursor{ 951 parentMem: s.c, 952 posRow: -1, 953 rows: setMRows, 954 cols: setColumns, 955 colType: setColType, 956 errPos: -1, 957 } 958 return cursor, nil 959 } 960 961 func (s *fakeStmt) NumInput() int { 962 if s.panic == "NumInput" { 963 panic(s.panic) 964 } 965 return s.placeholders 966 } 967 968 // hook to simulate broken connections 969 var hookCommitBadConn func() bool 970 971 func (tx *fakeTx) Commit() error { 972 tx.c.currTx = nil 973 if hookCommitBadConn != nil && hookCommitBadConn() { 974 return driver.ErrBadConn 975 } 976 tx.c.touchMem() 977 return nil 978 } 979 980 // hook to simulate broken connections 981 var hookRollbackBadConn func() bool 982 983 func (tx *fakeTx) Rollback() error { 984 tx.c.currTx = nil 985 if hookRollbackBadConn != nil && hookRollbackBadConn() { 986 return driver.ErrBadConn 987 } 988 tx.c.touchMem() 989 return nil 990 } 991 992 type rowsCursor struct { 993 parentMem memToucher 994 cols [][]string 995 colType [][]string 996 posSet int 997 posRow int 998 rows [][]*row 999 closed bool 1000 1001 // errPos and err are for making Next return early with error. 1002 errPos int 1003 err error 1004 1005 // a clone of slices to give out to clients, indexed by the 1006 // the original slice's first byte address. we clone them 1007 // just so we're able to corrupt them on close. 1008 bytesClone map[*byte][]byte 1009 1010 // Every operation writes to line to enable the race detector 1011 // check for data races. 1012 // This is separate from the fakeConn.line to allow for drivers that 1013 // can start multiple queries on the same transaction at the same time. 1014 line int64 1015 } 1016 1017 func (rc *rowsCursor) touchMem() { 1018 rc.parentMem.touchMem() 1019 rc.line++ 1020 } 1021 1022 func (rc *rowsCursor) Close() error { 1023 rc.touchMem() 1024 rc.parentMem.touchMem() 1025 rc.closed = true 1026 return nil 1027 } 1028 1029 func (rc *rowsCursor) Columns() []string { 1030 return rc.cols[rc.posSet] 1031 } 1032 1033 func (rc *rowsCursor) ColumnTypeScanType(index int) reflect.Type { 1034 return colTypeToReflectType(rc.colType[rc.posSet][index]) 1035 } 1036 1037 var rowsCursorNextHook func(dest []driver.Value) error 1038 1039 func (rc *rowsCursor) Next(dest []driver.Value) error { 1040 if rowsCursorNextHook != nil { 1041 return rowsCursorNextHook(dest) 1042 } 1043 1044 if rc.closed { 1045 return errors.New("fakedb: cursor is closed") 1046 } 1047 rc.touchMem() 1048 rc.posRow++ 1049 if rc.posRow == rc.errPos { 1050 return rc.err 1051 } 1052 if rc.posRow >= len(rc.rows[rc.posSet]) { 1053 return io.EOF // per interface spec 1054 } 1055 for i, v := range rc.rows[rc.posSet][rc.posRow].cols { 1056 // TODO(bradfitz): convert to subset types? naah, I 1057 // think the subset types should only be input to 1058 // driver, but the sql package should be able to handle 1059 // a wider range of types coming out of drivers. all 1060 // for ease of drivers, and to prevent drivers from 1061 // messing up conversions or doing them differently. 1062 dest[i] = v 1063 1064 if bs, ok := v.([]byte); ok { 1065 if rc.bytesClone == nil { 1066 rc.bytesClone = make(map[*byte][]byte) 1067 } 1068 clone, ok := rc.bytesClone[&bs[0]] 1069 if !ok { 1070 clone = make([]byte, len(bs)) 1071 copy(clone, bs) 1072 rc.bytesClone[&bs[0]] = clone 1073 } 1074 dest[i] = clone 1075 } 1076 } 1077 return nil 1078 } 1079 1080 func (rc *rowsCursor) HasNextResultSet() bool { 1081 rc.touchMem() 1082 return rc.posSet < len(rc.rows)-1 1083 } 1084 1085 func (rc *rowsCursor) NextResultSet() error { 1086 rc.touchMem() 1087 if rc.HasNextResultSet() { 1088 rc.posSet++ 1089 rc.posRow = -1 1090 return nil 1091 } 1092 return io.EOF // Per interface spec. 1093 } 1094 1095 // fakeDriverString is like driver.String, but indirects pointers like 1096 // DefaultValueConverter. 1097 // 1098 // This could be surprising behavior to retroactively apply to 1099 // driver.String now that Go1 is out, but this is convenient for 1100 // our TestPointerParamsAndScans. 1101 // 1102 type fakeDriverString struct{} 1103 1104 func (fakeDriverString) ConvertValue(v interface{}) (driver.Value, error) { 1105 switch c := v.(type) { 1106 case string, []byte: 1107 return v, nil 1108 case *string: 1109 if c == nil { 1110 return nil, nil 1111 } 1112 return *c, nil 1113 } 1114 return fmt.Sprintf("%v", v), nil 1115 } 1116 1117 type anyTypeConverter struct{} 1118 1119 func (anyTypeConverter) ConvertValue(v interface{}) (driver.Value, error) { 1120 return v, nil 1121 } 1122 1123 func converterForType(typ string) driver.ValueConverter { 1124 switch typ { 1125 case "bool": 1126 return driver.Bool 1127 case "nullbool": 1128 return driver.Null{Converter: driver.Bool} 1129 case "int32": 1130 return driver.Int32 1131 case "string": 1132 return driver.NotNull{Converter: fakeDriverString{}} 1133 case "nullstring": 1134 return driver.Null{Converter: fakeDriverString{}} 1135 case "int64": 1136 // TODO(coopernurse): add type-specific converter 1137 return driver.NotNull{Converter: driver.DefaultParameterConverter} 1138 case "nullint64": 1139 // TODO(coopernurse): add type-specific converter 1140 return driver.Null{Converter: driver.DefaultParameterConverter} 1141 case "float64": 1142 // TODO(coopernurse): add type-specific converter 1143 return driver.NotNull{Converter: driver.DefaultParameterConverter} 1144 case "nullfloat64": 1145 // TODO(coopernurse): add type-specific converter 1146 return driver.Null{Converter: driver.DefaultParameterConverter} 1147 case "datetime": 1148 return driver.DefaultParameterConverter 1149 case "any": 1150 return anyTypeConverter{} 1151 } 1152 panic("invalid fakedb column type of " + typ) 1153 } 1154 1155 func colTypeToReflectType(typ string) reflect.Type { 1156 switch typ { 1157 case "bool": 1158 return reflect.TypeOf(false) 1159 case "nullbool": 1160 return reflect.TypeOf(NullBool{}) 1161 case "int32": 1162 return reflect.TypeOf(int32(0)) 1163 case "string": 1164 return reflect.TypeOf("") 1165 case "nullstring": 1166 return reflect.TypeOf(NullString{}) 1167 case "int64": 1168 return reflect.TypeOf(int64(0)) 1169 case "nullint64": 1170 return reflect.TypeOf(NullInt64{}) 1171 case "float64": 1172 return reflect.TypeOf(float64(0)) 1173 case "nullfloat64": 1174 return reflect.TypeOf(NullFloat64{}) 1175 case "datetime": 1176 return reflect.TypeOf(time.Time{}) 1177 case "any": 1178 return reflect.TypeOf(new(interface{})).Elem() 1179 } 1180 panic("invalid fakedb column type of " + typ) 1181 }