github.com/code-reading/golang@v0.0.0-20220303082512-ba5bc0e589a3/go/src/database/sql/fakedb_test.go (about) 1 // Copyright 2011 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package sql 6 7 import ( 8 "context" 9 "database/sql/driver" 10 "errors" 11 "fmt" 12 "io" 13 "reflect" 14 "sort" 15 "strconv" 16 "strings" 17 "sync" 18 "testing" 19 "time" 20 ) 21 22 // fakeDriver is a fake database that implements Go's driver.Driver 23 // interface, just for testing. 24 // 25 // It speaks a query language that's semantically similar to but 26 // syntactically different and simpler than SQL. The syntax is as 27 // follows: 28 // 29 // WIPE 30 // CREATE|<tablename>|<col>=<type>,<col>=<type>,... 31 // where types are: "string", [u]int{8,16,32,64}, "bool" 32 // INSERT|<tablename>|col=val,col2=val2,col3=? 33 // SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=? 34 // SELECT|<tablename>|projectcol1,projectcol2|filtercol=?param1,filtercol2=?param2 35 // 36 // Any of these can be preceded by PANIC|<method>|, to cause the 37 // named method on fakeStmt to panic. 38 // 39 // Any of these can be proceeded by WAIT|<duration>|, to cause the 40 // named method on fakeStmt to sleep for the specified duration. 41 // 42 // Multiple of these can be combined when separated with a semicolon. 43 // 44 // When opening a fakeDriver's database, it starts empty with no 45 // tables. All tables and data are stored in memory only. 46 type fakeDriver struct { 47 mu sync.Mutex // guards 3 following fields 48 openCount int // conn opens 49 closeCount int // conn closes 50 waitCh chan struct{} 51 waitingCh chan struct{} 52 dbs map[string]*fakeDB 53 } 54 55 type fakeConnector struct { 56 name string 57 58 waiter func(context.Context) 59 closed bool 60 } 61 62 func (c *fakeConnector) Connect(context.Context) (driver.Conn, error) { 63 conn, err := fdriver.Open(c.name) 64 conn.(*fakeConn).waiter = c.waiter 65 return conn, err 66 } 67 68 func (c *fakeConnector) Driver() driver.Driver { 69 return fdriver 70 } 71 72 func (c *fakeConnector) Close() error { 73 if c.closed { 74 return errors.New("fakedb: connector is closed") 75 } 76 c.closed = true 77 return nil 78 } 79 80 type fakeDriverCtx struct { 81 fakeDriver 82 } 83 84 var _ driver.DriverContext = &fakeDriverCtx{} 85 86 func (cc *fakeDriverCtx) OpenConnector(name string) (driver.Connector, error) { 87 return &fakeConnector{name: name}, nil 88 } 89 90 type fakeDB struct { 91 name string 92 93 mu sync.Mutex 94 tables map[string]*table 95 badConn bool 96 allowAny bool 97 } 98 99 type table struct { 100 mu sync.Mutex 101 colname []string 102 coltype []string 103 rows []*row 104 } 105 106 func (t *table) columnIndex(name string) int { 107 for n, nname := range t.colname { 108 if name == nname { 109 return n 110 } 111 } 112 return -1 113 } 114 115 type row struct { 116 cols []interface{} // must be same size as its table colname + coltype 117 } 118 119 type memToucher interface { 120 // touchMem reads & writes some memory, to help find data races. 121 touchMem() 122 } 123 124 type fakeConn struct { 125 db *fakeDB // where to return ourselves to 126 127 currTx *fakeTx 128 129 // Every operation writes to line to enable the race detector 130 // check for data races. 131 line int64 132 133 // Stats for tests: 134 mu sync.Mutex 135 stmtsMade int 136 stmtsClosed int 137 numPrepare int 138 139 // bad connection tests; see isBad() 140 bad bool 141 stickyBad bool 142 143 skipDirtySession bool // tests that use Conn should set this to true. 144 145 // dirtySession tests ResetSession, true if a query has executed 146 // until ResetSession is called. 147 dirtySession bool 148 149 // The waiter is called before each query. May be used in place of the "WAIT" 150 // directive. 151 waiter func(context.Context) 152 } 153 154 func (c *fakeConn) touchMem() { 155 c.line++ 156 } 157 158 func (c *fakeConn) incrStat(v *int) { 159 c.mu.Lock() 160 *v++ 161 c.mu.Unlock() 162 } 163 164 type fakeTx struct { 165 c *fakeConn 166 } 167 168 type boundCol struct { 169 Column string 170 Placeholder string 171 Ordinal int 172 } 173 174 type fakeStmt struct { 175 memToucher 176 c *fakeConn 177 q string // just for debugging 178 179 cmd string 180 table string 181 panic string 182 wait time.Duration 183 184 next *fakeStmt // used for returning multiple results. 185 186 closed bool 187 188 colName []string // used by CREATE, INSERT, SELECT (selected columns) 189 colType []string // used by CREATE 190 colValue []interface{} // used by INSERT (mix of strings and "?" for bound params) 191 placeholders int // used by INSERT/SELECT: number of ? params 192 193 whereCol []boundCol // used by SELECT (all placeholders) 194 195 placeholderConverter []driver.ValueConverter // used by INSERT 196 } 197 198 var fdriver driver.Driver = &fakeDriver{} 199 200 func init() { 201 Register("test", fdriver) 202 } 203 204 func contains(list []string, y string) bool { 205 for _, x := range list { 206 if x == y { 207 return true 208 } 209 } 210 return false 211 } 212 213 type Dummy struct { 214 driver.Driver 215 } 216 217 func TestDrivers(t *testing.T) { 218 unregisterAllDrivers() 219 Register("test", fdriver) 220 Register("invalid", Dummy{}) 221 all := Drivers() 222 if len(all) < 2 || !sort.StringsAreSorted(all) || !contains(all, "test") || !contains(all, "invalid") { 223 t.Fatalf("Drivers = %v, want sorted list with at least [invalid, test]", all) 224 } 225 } 226 227 // hook to simulate connection failures 228 var hookOpenErr struct { 229 sync.Mutex 230 fn func() error 231 } 232 233 func setHookOpenErr(fn func() error) { 234 hookOpenErr.Lock() 235 defer hookOpenErr.Unlock() 236 hookOpenErr.fn = fn 237 } 238 239 // Supports dsn forms: 240 // <dbname> 241 // <dbname>;<opts> (only currently supported option is `badConn`, 242 // which causes driver.ErrBadConn to be returned on 243 // every other conn.Begin()) 244 func (d *fakeDriver) Open(dsn string) (driver.Conn, error) { 245 hookOpenErr.Lock() 246 fn := hookOpenErr.fn 247 hookOpenErr.Unlock() 248 if fn != nil { 249 if err := fn(); err != nil { 250 return nil, err 251 } 252 } 253 parts := strings.Split(dsn, ";") 254 if len(parts) < 1 { 255 return nil, errors.New("fakedb: no database name") 256 } 257 name := parts[0] 258 259 db := d.getDB(name) 260 261 d.mu.Lock() 262 d.openCount++ 263 d.mu.Unlock() 264 conn := &fakeConn{db: db} 265 266 if len(parts) >= 2 && parts[1] == "badConn" { 267 conn.bad = true 268 } 269 if d.waitCh != nil { 270 d.waitingCh <- struct{}{} 271 <-d.waitCh 272 d.waitCh = nil 273 d.waitingCh = nil 274 } 275 return conn, nil 276 } 277 278 func (d *fakeDriver) getDB(name string) *fakeDB { 279 d.mu.Lock() 280 defer d.mu.Unlock() 281 if d.dbs == nil { 282 d.dbs = make(map[string]*fakeDB) 283 } 284 db, ok := d.dbs[name] 285 if !ok { 286 db = &fakeDB{name: name} 287 d.dbs[name] = db 288 } 289 return db 290 } 291 292 func (db *fakeDB) wipe() { 293 db.mu.Lock() 294 defer db.mu.Unlock() 295 db.tables = nil 296 } 297 298 func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error { 299 db.mu.Lock() 300 defer db.mu.Unlock() 301 if db.tables == nil { 302 db.tables = make(map[string]*table) 303 } 304 if _, exist := db.tables[name]; exist { 305 return fmt.Errorf("fakedb: table %q already exists", name) 306 } 307 if len(columnNames) != len(columnTypes) { 308 return fmt.Errorf("fakedb: create table of %q len(names) != len(types): %d vs %d", 309 name, len(columnNames), len(columnTypes)) 310 } 311 db.tables[name] = &table{colname: columnNames, coltype: columnTypes} 312 return nil 313 } 314 315 // must be called with db.mu lock held 316 func (db *fakeDB) table(table string) (*table, bool) { 317 if db.tables == nil { 318 return nil, false 319 } 320 t, ok := db.tables[table] 321 return t, ok 322 } 323 324 func (db *fakeDB) columnType(table, column string) (typ string, ok bool) { 325 db.mu.Lock() 326 defer db.mu.Unlock() 327 t, ok := db.table(table) 328 if !ok { 329 return 330 } 331 for n, cname := range t.colname { 332 if cname == column { 333 return t.coltype[n], true 334 } 335 } 336 return "", false 337 } 338 339 func (c *fakeConn) isBad() bool { 340 if c.stickyBad { 341 return true 342 } else if c.bad { 343 if c.db == nil { 344 return false 345 } 346 // alternate between bad conn and not bad conn 347 c.db.badConn = !c.db.badConn 348 return c.db.badConn 349 } else { 350 return false 351 } 352 } 353 354 func (c *fakeConn) isDirtyAndMark() bool { 355 if c.skipDirtySession { 356 return false 357 } 358 if c.currTx != nil { 359 c.dirtySession = true 360 return false 361 } 362 if c.dirtySession { 363 return true 364 } 365 c.dirtySession = true 366 return false 367 } 368 369 func (c *fakeConn) Begin() (driver.Tx, error) { 370 if c.isBad() { 371 return nil, driver.ErrBadConn 372 } 373 if c.currTx != nil { 374 return nil, errors.New("fakedb: already in a transaction") 375 } 376 c.touchMem() 377 c.currTx = &fakeTx{c: c} 378 return c.currTx, nil 379 } 380 381 var hookPostCloseConn struct { 382 sync.Mutex 383 fn func(*fakeConn, error) 384 } 385 386 func setHookpostCloseConn(fn func(*fakeConn, error)) { 387 hookPostCloseConn.Lock() 388 defer hookPostCloseConn.Unlock() 389 hookPostCloseConn.fn = fn 390 } 391 392 var testStrictClose *testing.T 393 394 // setStrictFakeConnClose sets the t to Errorf on when fakeConn.Close 395 // fails to close. If nil, the check is disabled. 396 func setStrictFakeConnClose(t *testing.T) { 397 testStrictClose = t 398 } 399 400 func (c *fakeConn) ResetSession(ctx context.Context) error { 401 c.dirtySession = false 402 c.currTx = nil 403 if c.isBad() { 404 return driver.ErrBadConn 405 } 406 return nil 407 } 408 409 var _ driver.Validator = (*fakeConn)(nil) 410 411 func (c *fakeConn) IsValid() bool { 412 return !c.isBad() 413 } 414 415 func (c *fakeConn) Close() (err error) { 416 drv := fdriver.(*fakeDriver) 417 defer func() { 418 if err != nil && testStrictClose != nil { 419 testStrictClose.Errorf("failed to close a test fakeConn: %v", err) 420 } 421 hookPostCloseConn.Lock() 422 fn := hookPostCloseConn.fn 423 hookPostCloseConn.Unlock() 424 if fn != nil { 425 fn(c, err) 426 } 427 if err == nil { 428 drv.mu.Lock() 429 drv.closeCount++ 430 drv.mu.Unlock() 431 } 432 }() 433 c.touchMem() 434 if c.currTx != nil { 435 return errors.New("fakedb: can't close fakeConn; in a Transaction") 436 } 437 if c.db == nil { 438 return errors.New("fakedb: can't close fakeConn; already closed") 439 } 440 if c.stmtsMade > c.stmtsClosed { 441 return errors.New("fakedb: can't close; dangling statement(s)") 442 } 443 c.db = nil 444 return nil 445 } 446 447 func checkSubsetTypes(allowAny bool, args []driver.NamedValue) error { 448 for _, arg := range args { 449 switch arg.Value.(type) { 450 case int64, float64, bool, nil, []byte, string, time.Time: 451 default: 452 if !allowAny { 453 return fmt.Errorf("fakedb: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value) 454 } 455 } 456 } 457 return nil 458 } 459 460 func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) { 461 // Ensure that ExecContext is called if available. 462 panic("ExecContext was not called.") 463 } 464 465 func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { 466 // This is an optional interface, but it's implemented here 467 // just to check that all the args are of the proper types. 468 // ErrSkip is returned so the caller acts as if we didn't 469 // implement this at all. 470 err := checkSubsetTypes(c.db.allowAny, args) 471 if err != nil { 472 return nil, err 473 } 474 return nil, driver.ErrSkip 475 } 476 477 func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) { 478 // Ensure that ExecContext is called if available. 479 panic("QueryContext was not called.") 480 } 481 482 func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { 483 // This is an optional interface, but it's implemented here 484 // just to check that all the args are of the proper types. 485 // ErrSkip is returned so the caller acts as if we didn't 486 // implement this at all. 487 err := checkSubsetTypes(c.db.allowAny, args) 488 if err != nil { 489 return nil, err 490 } 491 return nil, driver.ErrSkip 492 } 493 494 func errf(msg string, args ...interface{}) error { 495 return errors.New("fakedb: " + fmt.Sprintf(msg, args...)) 496 } 497 498 // parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=? 499 // (note that where columns must always contain ? marks, 500 // just a limitation for fakedb) 501 func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, error) { 502 if len(parts) != 3 { 503 stmt.Close() 504 return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts)) 505 } 506 stmt.table = parts[0] 507 508 stmt.colName = strings.Split(parts[1], ",") 509 for n, colspec := range strings.Split(parts[2], ",") { 510 if colspec == "" { 511 continue 512 } 513 nameVal := strings.Split(colspec, "=") 514 if len(nameVal) != 2 { 515 stmt.Close() 516 return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) 517 } 518 column, value := nameVal[0], nameVal[1] 519 _, ok := c.db.columnType(stmt.table, column) 520 if !ok { 521 stmt.Close() 522 return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column) 523 } 524 if !strings.HasPrefix(value, "?") { 525 stmt.Close() 526 return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark", 527 stmt.table, column) 528 } 529 stmt.placeholders++ 530 stmt.whereCol = append(stmt.whereCol, boundCol{Column: column, Placeholder: value, Ordinal: stmt.placeholders}) 531 } 532 return stmt, nil 533 } 534 535 // parts are table|col=type,col2=type2 536 func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (*fakeStmt, error) { 537 if len(parts) != 2 { 538 stmt.Close() 539 return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts)) 540 } 541 stmt.table = parts[0] 542 for n, colspec := range strings.Split(parts[1], ",") { 543 nameType := strings.Split(colspec, "=") 544 if len(nameType) != 2 { 545 stmt.Close() 546 return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) 547 } 548 stmt.colName = append(stmt.colName, nameType[0]) 549 stmt.colType = append(stmt.colType, nameType[1]) 550 } 551 return stmt, nil 552 } 553 554 // parts are table|col=?,col2=val 555 func (c *fakeConn) prepareInsert(ctx context.Context, stmt *fakeStmt, parts []string) (*fakeStmt, error) { 556 if len(parts) != 2 { 557 stmt.Close() 558 return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts)) 559 } 560 stmt.table = parts[0] 561 for n, colspec := range strings.Split(parts[1], ",") { 562 nameVal := strings.Split(colspec, "=") 563 if len(nameVal) != 2 { 564 stmt.Close() 565 return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) 566 } 567 column, value := nameVal[0], nameVal[1] 568 ctype, ok := c.db.columnType(stmt.table, column) 569 if !ok { 570 stmt.Close() 571 return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column) 572 } 573 stmt.colName = append(stmt.colName, column) 574 575 if !strings.HasPrefix(value, "?") { 576 var subsetVal interface{} 577 // Convert to driver subset type 578 switch ctype { 579 case "string": 580 subsetVal = []byte(value) 581 case "blob": 582 subsetVal = []byte(value) 583 case "int32": 584 i, err := strconv.Atoi(value) 585 if err != nil { 586 stmt.Close() 587 return nil, errf("invalid conversion to int32 from %q", value) 588 } 589 subsetVal = int64(i) // int64 is a subset type, but not int32 590 case "table": // For testing cursor reads. 591 c.skipDirtySession = true 592 vparts := strings.Split(value, "!") 593 594 substmt, err := c.PrepareContext(ctx, fmt.Sprintf("SELECT|%s|%s|", vparts[0], strings.Join(vparts[1:], ","))) 595 if err != nil { 596 return nil, err 597 } 598 cursor, err := (substmt.(driver.StmtQueryContext)).QueryContext(ctx, []driver.NamedValue{}) 599 substmt.Close() 600 if err != nil { 601 return nil, err 602 } 603 subsetVal = cursor 604 default: 605 stmt.Close() 606 return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype) 607 } 608 stmt.colValue = append(stmt.colValue, subsetVal) 609 } else { 610 stmt.placeholders++ 611 stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype)) 612 stmt.colValue = append(stmt.colValue, value) 613 } 614 } 615 return stmt, nil 616 } 617 618 // hook to simulate broken connections 619 var hookPrepareBadConn func() bool 620 621 func (c *fakeConn) Prepare(query string) (driver.Stmt, error) { 622 panic("use PrepareContext") 623 } 624 625 func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { 626 c.numPrepare++ 627 if c.db == nil { 628 panic("nil c.db; conn = " + fmt.Sprintf("%#v", c)) 629 } 630 631 if c.stickyBad || (hookPrepareBadConn != nil && hookPrepareBadConn()) { 632 return nil, driver.ErrBadConn 633 } 634 635 c.touchMem() 636 var firstStmt, prev *fakeStmt 637 for _, query := range strings.Split(query, ";") { 638 parts := strings.Split(query, "|") 639 if len(parts) < 1 { 640 return nil, errf("empty query") 641 } 642 stmt := &fakeStmt{q: query, c: c, memToucher: c} 643 if firstStmt == nil { 644 firstStmt = stmt 645 } 646 if len(parts) >= 3 { 647 switch parts[0] { 648 case "PANIC": 649 stmt.panic = parts[1] 650 parts = parts[2:] 651 case "WAIT": 652 wait, err := time.ParseDuration(parts[1]) 653 if err != nil { 654 return nil, errf("expected section after WAIT to be a duration, got %q %v", parts[1], err) 655 } 656 parts = parts[2:] 657 stmt.wait = wait 658 } 659 } 660 cmd := parts[0] 661 stmt.cmd = cmd 662 parts = parts[1:] 663 664 if c.waiter != nil { 665 c.waiter(ctx) 666 } 667 668 if stmt.wait > 0 { 669 wait := time.NewTimer(stmt.wait) 670 select { 671 case <-wait.C: 672 case <-ctx.Done(): 673 wait.Stop() 674 return nil, ctx.Err() 675 } 676 } 677 678 c.incrStat(&c.stmtsMade) 679 var err error 680 switch cmd { 681 case "WIPE": 682 // Nothing 683 case "SELECT": 684 stmt, err = c.prepareSelect(stmt, parts) 685 case "CREATE": 686 stmt, err = c.prepareCreate(stmt, parts) 687 case "INSERT": 688 stmt, err = c.prepareInsert(ctx, stmt, parts) 689 case "NOSERT": 690 // Do all the prep-work like for an INSERT but don't actually insert the row. 691 // Used for some of the concurrent tests. 692 stmt, err = c.prepareInsert(ctx, stmt, parts) 693 default: 694 stmt.Close() 695 return nil, errf("unsupported command type %q", cmd) 696 } 697 if err != nil { 698 return nil, err 699 } 700 if prev != nil { 701 prev.next = stmt 702 } 703 prev = stmt 704 } 705 return firstStmt, nil 706 } 707 708 func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter { 709 if s.panic == "ColumnConverter" { 710 panic(s.panic) 711 } 712 if len(s.placeholderConverter) == 0 { 713 return driver.DefaultParameterConverter 714 } 715 return s.placeholderConverter[idx] 716 } 717 718 func (s *fakeStmt) Close() error { 719 if s.panic == "Close" { 720 panic(s.panic) 721 } 722 if s.c == nil { 723 panic("nil conn in fakeStmt.Close") 724 } 725 if s.c.db == nil { 726 panic("in fakeStmt.Close, conn's db is nil (already closed)") 727 } 728 s.touchMem() 729 if !s.closed { 730 s.c.incrStat(&s.c.stmtsClosed) 731 s.closed = true 732 } 733 if s.next != nil { 734 s.next.Close() 735 } 736 return nil 737 } 738 739 var errClosed = errors.New("fakedb: statement has been closed") 740 741 // hook to simulate broken connections 742 var hookExecBadConn func() bool 743 744 func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) { 745 panic("Using ExecContext") 746 } 747 748 var errFakeConnSessionDirty = errors.New("fakedb: session is dirty") 749 750 func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { 751 if s.panic == "Exec" { 752 panic(s.panic) 753 } 754 if s.closed { 755 return nil, errClosed 756 } 757 758 if s.c.stickyBad || (hookExecBadConn != nil && hookExecBadConn()) { 759 return nil, driver.ErrBadConn 760 } 761 if s.c.isDirtyAndMark() { 762 return nil, errFakeConnSessionDirty 763 } 764 765 err := checkSubsetTypes(s.c.db.allowAny, args) 766 if err != nil { 767 return nil, err 768 } 769 s.touchMem() 770 771 if s.wait > 0 { 772 time.Sleep(s.wait) 773 } 774 775 select { 776 default: 777 case <-ctx.Done(): 778 return nil, ctx.Err() 779 } 780 781 db := s.c.db 782 switch s.cmd { 783 case "WIPE": 784 db.wipe() 785 return driver.ResultNoRows, nil 786 case "CREATE": 787 if err := db.createTable(s.table, s.colName, s.colType); err != nil { 788 return nil, err 789 } 790 return driver.ResultNoRows, nil 791 case "INSERT": 792 return s.execInsert(args, true) 793 case "NOSERT": 794 // Do all the prep-work like for an INSERT but don't actually insert the row. 795 // Used for some of the concurrent tests. 796 return s.execInsert(args, false) 797 } 798 return nil, fmt.Errorf("fakedb: unimplemented statement Exec command type of %q", s.cmd) 799 } 800 801 // When doInsert is true, add the row to the table. 802 // When doInsert is false do prep-work and error checking, but don't 803 // actually add the row to the table. 804 func (s *fakeStmt) execInsert(args []driver.NamedValue, doInsert bool) (driver.Result, error) { 805 db := s.c.db 806 if len(args) != s.placeholders { 807 panic("error in pkg db; should only get here if size is correct") 808 } 809 db.mu.Lock() 810 t, ok := db.table(s.table) 811 db.mu.Unlock() 812 if !ok { 813 return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table) 814 } 815 816 t.mu.Lock() 817 defer t.mu.Unlock() 818 819 var cols []interface{} 820 if doInsert { 821 cols = make([]interface{}, len(t.colname)) 822 } 823 argPos := 0 824 for n, colname := range s.colName { 825 colidx := t.columnIndex(colname) 826 if colidx == -1 { 827 return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname) 828 } 829 var val interface{} 830 if strvalue, ok := s.colValue[n].(string); ok && strings.HasPrefix(strvalue, "?") { 831 if strvalue == "?" { 832 val = args[argPos].Value 833 } else { 834 // Assign value from argument placeholder name. 835 for _, a := range args { 836 if a.Name == strvalue[1:] { 837 val = a.Value 838 break 839 } 840 } 841 } 842 argPos++ 843 } else { 844 val = s.colValue[n] 845 } 846 if doInsert { 847 cols[colidx] = val 848 } 849 } 850 851 if doInsert { 852 t.rows = append(t.rows, &row{cols: cols}) 853 } 854 return driver.RowsAffected(1), nil 855 } 856 857 // hook to simulate broken connections 858 var hookQueryBadConn func() bool 859 860 func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) { 861 panic("Use QueryContext") 862 } 863 864 func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { 865 if s.panic == "Query" { 866 panic(s.panic) 867 } 868 if s.closed { 869 return nil, errClosed 870 } 871 872 if s.c.stickyBad || (hookQueryBadConn != nil && hookQueryBadConn()) { 873 return nil, driver.ErrBadConn 874 } 875 if s.c.isDirtyAndMark() { 876 return nil, errFakeConnSessionDirty 877 } 878 879 err := checkSubsetTypes(s.c.db.allowAny, args) 880 if err != nil { 881 return nil, err 882 } 883 884 s.touchMem() 885 db := s.c.db 886 if len(args) != s.placeholders { 887 panic("error in pkg db; should only get here if size is correct") 888 } 889 890 setMRows := make([][]*row, 0, 1) 891 setColumns := make([][]string, 0, 1) 892 setColType := make([][]string, 0, 1) 893 894 for { 895 db.mu.Lock() 896 t, ok := db.table(s.table) 897 db.mu.Unlock() 898 if !ok { 899 return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table) 900 } 901 902 if s.table == "magicquery" { 903 if len(s.whereCol) == 2 && s.whereCol[0].Column == "op" && s.whereCol[1].Column == "millis" { 904 if args[0].Value == "sleep" { 905 time.Sleep(time.Duration(args[1].Value.(int64)) * time.Millisecond) 906 } 907 } 908 } 909 if s.table == "tx_status" && s.colName[0] == "tx_status" { 910 txStatus := "autocommit" 911 if s.c.currTx != nil { 912 txStatus = "transaction" 913 } 914 cursor := &rowsCursor{ 915 parentMem: s.c, 916 posRow: -1, 917 rows: [][]*row{ 918 { 919 { 920 cols: []interface{}{ 921 txStatus, 922 }, 923 }, 924 }, 925 }, 926 cols: [][]string{ 927 { 928 "tx_status", 929 }, 930 }, 931 colType: [][]string{ 932 { 933 "string", 934 }, 935 }, 936 errPos: -1, 937 } 938 return cursor, nil 939 } 940 941 t.mu.Lock() 942 943 colIdx := make(map[string]int) // select column name -> column index in table 944 for _, name := range s.colName { 945 idx := t.columnIndex(name) 946 if idx == -1 { 947 t.mu.Unlock() 948 return nil, fmt.Errorf("fakedb: unknown column name %q", name) 949 } 950 colIdx[name] = idx 951 } 952 953 mrows := []*row{} 954 rows: 955 for _, trow := range t.rows { 956 // Process the where clause, skipping non-match rows. This is lazy 957 // and just uses fmt.Sprintf("%v") to test equality. Good enough 958 // for test code. 959 for _, wcol := range s.whereCol { 960 idx := t.columnIndex(wcol.Column) 961 if idx == -1 { 962 t.mu.Unlock() 963 return nil, fmt.Errorf("fakedb: invalid where clause column %q", wcol) 964 } 965 tcol := trow.cols[idx] 966 if bs, ok := tcol.([]byte); ok { 967 // lazy hack to avoid sprintf %v on a []byte 968 tcol = string(bs) 969 } 970 var argValue interface{} 971 if wcol.Placeholder == "?" { 972 argValue = args[wcol.Ordinal-1].Value 973 } else { 974 // Assign arg value from placeholder name. 975 for _, a := range args { 976 if a.Name == wcol.Placeholder[1:] { 977 argValue = a.Value 978 break 979 } 980 } 981 } 982 if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) { 983 continue rows 984 } 985 } 986 mrow := &row{cols: make([]interface{}, len(s.colName))} 987 for seli, name := range s.colName { 988 mrow.cols[seli] = trow.cols[colIdx[name]] 989 } 990 mrows = append(mrows, mrow) 991 } 992 993 var colType []string 994 for _, column := range s.colName { 995 colType = append(colType, t.coltype[t.columnIndex(column)]) 996 } 997 998 t.mu.Unlock() 999 1000 setMRows = append(setMRows, mrows) 1001 setColumns = append(setColumns, s.colName) 1002 setColType = append(setColType, colType) 1003 1004 if s.next == nil { 1005 break 1006 } 1007 s = s.next 1008 } 1009 1010 cursor := &rowsCursor{ 1011 parentMem: s.c, 1012 posRow: -1, 1013 rows: setMRows, 1014 cols: setColumns, 1015 colType: setColType, 1016 errPos: -1, 1017 } 1018 return cursor, nil 1019 } 1020 1021 func (s *fakeStmt) NumInput() int { 1022 if s.panic == "NumInput" { 1023 panic(s.panic) 1024 } 1025 return s.placeholders 1026 } 1027 1028 // hook to simulate broken connections 1029 var hookCommitBadConn func() bool 1030 1031 func (tx *fakeTx) Commit() error { 1032 tx.c.currTx = nil 1033 if hookCommitBadConn != nil && hookCommitBadConn() { 1034 return driver.ErrBadConn 1035 } 1036 tx.c.touchMem() 1037 return nil 1038 } 1039 1040 // hook to simulate broken connections 1041 var hookRollbackBadConn func() bool 1042 1043 func (tx *fakeTx) Rollback() error { 1044 tx.c.currTx = nil 1045 if hookRollbackBadConn != nil && hookRollbackBadConn() { 1046 return driver.ErrBadConn 1047 } 1048 tx.c.touchMem() 1049 return nil 1050 } 1051 1052 type rowsCursor struct { 1053 parentMem memToucher 1054 cols [][]string 1055 colType [][]string 1056 posSet int 1057 posRow int 1058 rows [][]*row 1059 closed bool 1060 1061 // errPos and err are for making Next return early with error. 1062 errPos int 1063 err error 1064 1065 // a clone of slices to give out to clients, indexed by the 1066 // original slice's first byte address. we clone them 1067 // just so we're able to corrupt them on close. 1068 bytesClone map[*byte][]byte 1069 1070 // Every operation writes to line to enable the race detector 1071 // check for data races. 1072 // This is separate from the fakeConn.line to allow for drivers that 1073 // can start multiple queries on the same transaction at the same time. 1074 line int64 1075 } 1076 1077 func (rc *rowsCursor) touchMem() { 1078 rc.parentMem.touchMem() 1079 rc.line++ 1080 } 1081 1082 func (rc *rowsCursor) Close() error { 1083 rc.touchMem() 1084 rc.parentMem.touchMem() 1085 rc.closed = true 1086 return nil 1087 } 1088 1089 func (rc *rowsCursor) Columns() []string { 1090 return rc.cols[rc.posSet] 1091 } 1092 1093 func (rc *rowsCursor) ColumnTypeScanType(index int) reflect.Type { 1094 return colTypeToReflectType(rc.colType[rc.posSet][index]) 1095 } 1096 1097 var rowsCursorNextHook func(dest []driver.Value) error 1098 1099 func (rc *rowsCursor) Next(dest []driver.Value) error { 1100 if rowsCursorNextHook != nil { 1101 return rowsCursorNextHook(dest) 1102 } 1103 1104 if rc.closed { 1105 return errors.New("fakedb: cursor is closed") 1106 } 1107 rc.touchMem() 1108 rc.posRow++ 1109 if rc.posRow == rc.errPos { 1110 return rc.err 1111 } 1112 if rc.posRow >= len(rc.rows[rc.posSet]) { 1113 return io.EOF // per interface spec 1114 } 1115 for i, v := range rc.rows[rc.posSet][rc.posRow].cols { 1116 // TODO(bradfitz): convert to subset types? naah, I 1117 // think the subset types should only be input to 1118 // driver, but the sql package should be able to handle 1119 // a wider range of types coming out of drivers. all 1120 // for ease of drivers, and to prevent drivers from 1121 // messing up conversions or doing them differently. 1122 dest[i] = v 1123 1124 if bs, ok := v.([]byte); ok { 1125 if rc.bytesClone == nil { 1126 rc.bytesClone = make(map[*byte][]byte) 1127 } 1128 clone, ok := rc.bytesClone[&bs[0]] 1129 if !ok { 1130 clone = make([]byte, len(bs)) 1131 copy(clone, bs) 1132 rc.bytesClone[&bs[0]] = clone 1133 } 1134 dest[i] = clone 1135 } 1136 } 1137 return nil 1138 } 1139 1140 func (rc *rowsCursor) HasNextResultSet() bool { 1141 rc.touchMem() 1142 return rc.posSet < len(rc.rows)-1 1143 } 1144 1145 func (rc *rowsCursor) NextResultSet() error { 1146 rc.touchMem() 1147 if rc.HasNextResultSet() { 1148 rc.posSet++ 1149 rc.posRow = -1 1150 return nil 1151 } 1152 return io.EOF // Per interface spec. 1153 } 1154 1155 // fakeDriverString is like driver.String, but indirects pointers like 1156 // DefaultValueConverter. 1157 // 1158 // This could be surprising behavior to retroactively apply to 1159 // driver.String now that Go1 is out, but this is convenient for 1160 // our TestPointerParamsAndScans. 1161 // 1162 type fakeDriverString struct{} 1163 1164 func (fakeDriverString) ConvertValue(v interface{}) (driver.Value, error) { 1165 switch c := v.(type) { 1166 case string, []byte: 1167 return v, nil 1168 case *string: 1169 if c == nil { 1170 return nil, nil 1171 } 1172 return *c, nil 1173 } 1174 return fmt.Sprintf("%v", v), nil 1175 } 1176 1177 type anyTypeConverter struct{} 1178 1179 func (anyTypeConverter) ConvertValue(v interface{}) (driver.Value, error) { 1180 return v, nil 1181 } 1182 1183 func converterForType(typ string) driver.ValueConverter { 1184 switch typ { 1185 case "bool": 1186 return driver.Bool 1187 case "nullbool": 1188 return driver.Null{Converter: driver.Bool} 1189 case "byte", "int16": 1190 return driver.NotNull{Converter: driver.DefaultParameterConverter} 1191 case "int32": 1192 return driver.Int32 1193 case "nullbyte", "nullint32", "nullint16": 1194 return driver.Null{Converter: driver.DefaultParameterConverter} 1195 case "string": 1196 return driver.NotNull{Converter: fakeDriverString{}} 1197 case "nullstring": 1198 return driver.Null{Converter: fakeDriverString{}} 1199 case "int64": 1200 // TODO(coopernurse): add type-specific converter 1201 return driver.NotNull{Converter: driver.DefaultParameterConverter} 1202 case "nullint64": 1203 // TODO(coopernurse): add type-specific converter 1204 return driver.Null{Converter: driver.DefaultParameterConverter} 1205 case "float64": 1206 // TODO(coopernurse): add type-specific converter 1207 return driver.NotNull{Converter: driver.DefaultParameterConverter} 1208 case "nullfloat64": 1209 // TODO(coopernurse): add type-specific converter 1210 return driver.Null{Converter: driver.DefaultParameterConverter} 1211 case "datetime": 1212 return driver.NotNull{Converter: driver.DefaultParameterConverter} 1213 case "nulldatetime": 1214 return driver.Null{Converter: driver.DefaultParameterConverter} 1215 case "any": 1216 return anyTypeConverter{} 1217 } 1218 panic("invalid fakedb column type of " + typ) 1219 } 1220 1221 func colTypeToReflectType(typ string) reflect.Type { 1222 switch typ { 1223 case "bool": 1224 return reflect.TypeOf(false) 1225 case "nullbool": 1226 return reflect.TypeOf(NullBool{}) 1227 case "int16": 1228 return reflect.TypeOf(int16(0)) 1229 case "nullint16": 1230 return reflect.TypeOf(NullInt16{}) 1231 case "int32": 1232 return reflect.TypeOf(int32(0)) 1233 case "nullint32": 1234 return reflect.TypeOf(NullInt32{}) 1235 case "string": 1236 return reflect.TypeOf("") 1237 case "nullstring": 1238 return reflect.TypeOf(NullString{}) 1239 case "int64": 1240 return reflect.TypeOf(int64(0)) 1241 case "nullint64": 1242 return reflect.TypeOf(NullInt64{}) 1243 case "float64": 1244 return reflect.TypeOf(float64(0)) 1245 case "nullfloat64": 1246 return reflect.TypeOf(NullFloat64{}) 1247 case "datetime": 1248 return reflect.TypeOf(time.Time{}) 1249 case "any": 1250 return reflect.TypeOf(new(interface{})).Elem() 1251 } 1252 panic("invalid fakedb column type of " + typ) 1253 }