// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package sql

import (
	"context"
	"database/sql/driver"
	"errors"
	"fmt"
	"io"
	"reflect"
	"sort"
	"strconv"
	"strings"
	"sync"
	"testing"
	"time"
)

// fakeDriver is a fake database that implements Go's driver.Driver
// interface, just for testing.
//
// It speaks a query language that's semantically similar to but
// syntactically different and simpler than SQL.  The syntax is as
// follows:
//
//   WIPE
//   CREATE|<tablename>|<col>=<type>,<col>=<type>,...
//     where types are: "string", [u]int{8,16,32,64}, "bool"
//   INSERT|<tablename>|col=val,col2=val2,col3=?
//   SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=?
//   SELECT|<tablename>|projectcol1,projectcol2|filtercol=?param1,filtercol2=?param2
//
// Any of these can be preceded by PANIC|<method>|, to cause the
// named method on fakeStmt to panic.
//
// Any of these can be preceded by WAIT|<duration>|, to cause preparing
// and executing the statement to wait for the specified duration.
//
// Multiple of these can be combined when separated with a semicolon.
//
// When opening a fakeDriver's database, it starts empty with no
// tables. All tables and data are stored in memory only.
46 type fakeDriver struct { 47 mu sync.Mutex // guards 3 following fields 48 openCount int // conn opens 49 closeCount int // conn closes 50 waitCh chan struct{} 51 waitingCh chan struct{} 52 dbs map[string]*fakeDB 53 } 54 55 type fakeConnector struct { 56 name string 57 58 waiter func(context.Context) 59 } 60 61 func (c *fakeConnector) Connect(context.Context) (driver.Conn, error) { 62 conn, err := fdriver.Open(c.name) 63 conn.(*fakeConn).waiter = c.waiter 64 return conn, err 65 } 66 67 func (c *fakeConnector) Driver() driver.Driver { 68 return fdriver 69 } 70 71 type fakeDriverCtx struct { 72 fakeDriver 73 } 74 75 var _ driver.DriverContext = &fakeDriverCtx{} 76 77 func (cc *fakeDriverCtx) OpenConnector(name string) (driver.Connector, error) { 78 return &fakeConnector{name: name}, nil 79 } 80 81 type fakeDB struct { 82 name string 83 84 mu sync.Mutex 85 tables map[string]*table 86 badConn bool 87 allowAny bool 88 } 89 90 type table struct { 91 mu sync.Mutex 92 colname []string 93 coltype []string 94 rows []*row 95 } 96 97 func (t *table) columnIndex(name string) int { 98 for n, nname := range t.colname { 99 if name == nname { 100 return n 101 } 102 } 103 return -1 104 } 105 106 type row struct { 107 cols []interface{} // must be same size as its table colname + coltype 108 } 109 110 type memToucher interface { 111 // touchMem reads & writes some memory, to help find data races. 112 touchMem() 113 } 114 115 type fakeConn struct { 116 db *fakeDB // where to return ourselves to 117 118 currTx *fakeTx 119 120 // Every operation writes to line to enable the race detector 121 // check for data races. 122 line int64 123 124 // Stats for tests: 125 mu sync.Mutex 126 stmtsMade int 127 stmtsClosed int 128 numPrepare int 129 130 // bad connection tests; see isBad() 131 bad bool 132 stickyBad bool 133 134 skipDirtySession bool // tests that use Conn should set this to true. 
135 136 // dirtySession tests ResetSession, true if a query has executed 137 // until ResetSession is called. 138 dirtySession bool 139 140 // The waiter is called before each query. May be used in place of the "WAIT" 141 // directive. 142 waiter func(context.Context) 143 } 144 145 func (c *fakeConn) touchMem() { 146 c.line++ 147 } 148 149 func (c *fakeConn) incrStat(v *int) { 150 c.mu.Lock() 151 *v++ 152 c.mu.Unlock() 153 } 154 155 type fakeTx struct { 156 c *fakeConn 157 } 158 159 type boundCol struct { 160 Column string 161 Placeholder string 162 Ordinal int 163 } 164 165 type fakeStmt struct { 166 memToucher 167 c *fakeConn 168 q string // just for debugging 169 170 cmd string 171 table string 172 panic string 173 wait time.Duration 174 175 next *fakeStmt // used for returning multiple results. 176 177 closed bool 178 179 colName []string // used by CREATE, INSERT, SELECT (selected columns) 180 colType []string // used by CREATE 181 colValue []interface{} // used by INSERT (mix of strings and "?" for bound params) 182 placeholders int // used by INSERT/SELECT: number of ? 
params 183 184 whereCol []boundCol // used by SELECT (all placeholders) 185 186 placeholderConverter []driver.ValueConverter // used by INSERT 187 } 188 189 var fdriver driver.Driver = &fakeDriver{} 190 191 func init() { 192 Register("test", fdriver) 193 } 194 195 func contains(list []string, y string) bool { 196 for _, x := range list { 197 if x == y { 198 return true 199 } 200 } 201 return false 202 } 203 204 type Dummy struct { 205 driver.Driver 206 } 207 208 func TestDrivers(t *testing.T) { 209 unregisterAllDrivers() 210 Register("test", fdriver) 211 Register("invalid", Dummy{}) 212 all := Drivers() 213 if len(all) < 2 || !sort.StringsAreSorted(all) || !contains(all, "test") || !contains(all, "invalid") { 214 t.Fatalf("Drivers = %v, want sorted list with at least [invalid, test]", all) 215 } 216 } 217 218 // hook to simulate connection failures 219 var hookOpenErr struct { 220 sync.Mutex 221 fn func() error 222 } 223 224 func setHookOpenErr(fn func() error) { 225 hookOpenErr.Lock() 226 defer hookOpenErr.Unlock() 227 hookOpenErr.fn = fn 228 } 229 230 // Supports dsn forms: 231 // <dbname> 232 // <dbname>;<opts> (only currently supported option is `badConn`, 233 // which causes driver.ErrBadConn to be returned on 234 // every other conn.Begin()) 235 func (d *fakeDriver) Open(dsn string) (driver.Conn, error) { 236 hookOpenErr.Lock() 237 fn := hookOpenErr.fn 238 hookOpenErr.Unlock() 239 if fn != nil { 240 if err := fn(); err != nil { 241 return nil, err 242 } 243 } 244 parts := strings.Split(dsn, ";") 245 if len(parts) < 1 { 246 return nil, errors.New("fakedb: no database name") 247 } 248 name := parts[0] 249 250 db := d.getDB(name) 251 252 d.mu.Lock() 253 d.openCount++ 254 d.mu.Unlock() 255 conn := &fakeConn{db: db} 256 257 if len(parts) >= 2 && parts[1] == "badConn" { 258 conn.bad = true 259 } 260 if d.waitCh != nil { 261 d.waitingCh <- struct{}{} 262 <-d.waitCh 263 d.waitCh = nil 264 d.waitingCh = nil 265 } 266 return conn, nil 267 } 268 269 func (d *fakeDriver) 
getDB(name string) *fakeDB { 270 d.mu.Lock() 271 defer d.mu.Unlock() 272 if d.dbs == nil { 273 d.dbs = make(map[string]*fakeDB) 274 } 275 db, ok := d.dbs[name] 276 if !ok { 277 db = &fakeDB{name: name} 278 d.dbs[name] = db 279 } 280 return db 281 } 282 283 func (db *fakeDB) wipe() { 284 db.mu.Lock() 285 defer db.mu.Unlock() 286 db.tables = nil 287 } 288 289 func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error { 290 db.mu.Lock() 291 defer db.mu.Unlock() 292 if db.tables == nil { 293 db.tables = make(map[string]*table) 294 } 295 if _, exist := db.tables[name]; exist { 296 return fmt.Errorf("fakedb: table %q already exists", name) 297 } 298 if len(columnNames) != len(columnTypes) { 299 return fmt.Errorf("fakedb: create table of %q len(names) != len(types): %d vs %d", 300 name, len(columnNames), len(columnTypes)) 301 } 302 db.tables[name] = &table{colname: columnNames, coltype: columnTypes} 303 return nil 304 } 305 306 // must be called with db.mu lock held 307 func (db *fakeDB) table(table string) (*table, bool) { 308 if db.tables == nil { 309 return nil, false 310 } 311 t, ok := db.tables[table] 312 return t, ok 313 } 314 315 func (db *fakeDB) columnType(table, column string) (typ string, ok bool) { 316 db.mu.Lock() 317 defer db.mu.Unlock() 318 t, ok := db.table(table) 319 if !ok { 320 return 321 } 322 for n, cname := range t.colname { 323 if cname == column { 324 return t.coltype[n], true 325 } 326 } 327 return "", false 328 } 329 330 func (c *fakeConn) isBad() bool { 331 if c.stickyBad { 332 return true 333 } else if c.bad { 334 if c.db == nil { 335 return false 336 } 337 // alternate between bad conn and not bad conn 338 c.db.badConn = !c.db.badConn 339 return c.db.badConn 340 } else { 341 return false 342 } 343 } 344 345 func (c *fakeConn) isDirtyAndMark() bool { 346 if c.skipDirtySession { 347 return false 348 } 349 if c.currTx != nil { 350 c.dirtySession = true 351 return false 352 } 353 if c.dirtySession { 354 return true 355 } 
356 c.dirtySession = true 357 return false 358 } 359 360 func (c *fakeConn) Begin() (driver.Tx, error) { 361 if c.isBad() { 362 return nil, driver.ErrBadConn 363 } 364 if c.currTx != nil { 365 return nil, errors.New("fakedb: already in a transaction") 366 } 367 c.touchMem() 368 c.currTx = &fakeTx{c: c} 369 return c.currTx, nil 370 } 371 372 var hookPostCloseConn struct { 373 sync.Mutex 374 fn func(*fakeConn, error) 375 } 376 377 func setHookpostCloseConn(fn func(*fakeConn, error)) { 378 hookPostCloseConn.Lock() 379 defer hookPostCloseConn.Unlock() 380 hookPostCloseConn.fn = fn 381 } 382 383 var testStrictClose *testing.T 384 385 // setStrictFakeConnClose sets the t to Errorf on when fakeConn.Close 386 // fails to close. If nil, the check is disabled. 387 func setStrictFakeConnClose(t *testing.T) { 388 testStrictClose = t 389 } 390 391 func (c *fakeConn) ResetSession(ctx context.Context) error { 392 c.dirtySession = false 393 c.currTx = nil 394 if c.isBad() { 395 return driver.ErrBadConn 396 } 397 return nil 398 } 399 400 var _ driver.Validator = (*fakeConn)(nil) 401 402 func (c *fakeConn) IsValid() bool { 403 return !c.isBad() 404 } 405 406 func (c *fakeConn) Close() (err error) { 407 drv := fdriver.(*fakeDriver) 408 defer func() { 409 if err != nil && testStrictClose != nil { 410 testStrictClose.Errorf("failed to close a test fakeConn: %v", err) 411 } 412 hookPostCloseConn.Lock() 413 fn := hookPostCloseConn.fn 414 hookPostCloseConn.Unlock() 415 if fn != nil { 416 fn(c, err) 417 } 418 if err == nil { 419 drv.mu.Lock() 420 drv.closeCount++ 421 drv.mu.Unlock() 422 } 423 }() 424 c.touchMem() 425 if c.currTx != nil { 426 return errors.New("fakedb: can't close fakeConn; in a Transaction") 427 } 428 if c.db == nil { 429 return errors.New("fakedb: can't close fakeConn; already closed") 430 } 431 if c.stmtsMade > c.stmtsClosed { 432 return errors.New("fakedb: can't close; dangling statement(s)") 433 } 434 c.db = nil 435 return nil 436 } 437 438 func 
checkSubsetTypes(allowAny bool, args []driver.NamedValue) error { 439 for _, arg := range args { 440 switch arg.Value.(type) { 441 case int64, float64, bool, nil, []byte, string, time.Time: 442 default: 443 if !allowAny { 444 return fmt.Errorf("fakedb: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value) 445 } 446 } 447 } 448 return nil 449 } 450 451 func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) { 452 // Ensure that ExecContext is called if available. 453 panic("ExecContext was not called.") 454 } 455 456 func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { 457 // This is an optional interface, but it's implemented here 458 // just to check that all the args are of the proper types. 459 // ErrSkip is returned so the caller acts as if we didn't 460 // implement this at all. 461 err := checkSubsetTypes(c.db.allowAny, args) 462 if err != nil { 463 return nil, err 464 } 465 return nil, driver.ErrSkip 466 } 467 468 func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) { 469 // Ensure that ExecContext is called if available. 470 panic("QueryContext was not called.") 471 } 472 473 func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { 474 // This is an optional interface, but it's implemented here 475 // just to check that all the args are of the proper types. 476 // ErrSkip is returned so the caller acts as if we didn't 477 // implement this at all. 478 err := checkSubsetTypes(c.db.allowAny, args) 479 if err != nil { 480 return nil, err 481 } 482 return nil, driver.ErrSkip 483 } 484 485 func errf(msg string, args ...interface{}) error { 486 return errors.New("fakedb: " + fmt.Sprintf(msg, args...)) 487 } 488 489 // parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=? 490 // (note that where columns must always contain ? 
marks, 491 // just a limitation for fakedb) 492 func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, error) { 493 if len(parts) != 3 { 494 stmt.Close() 495 return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts)) 496 } 497 stmt.table = parts[0] 498 499 stmt.colName = strings.Split(parts[1], ",") 500 for n, colspec := range strings.Split(parts[2], ",") { 501 if colspec == "" { 502 continue 503 } 504 nameVal := strings.Split(colspec, "=") 505 if len(nameVal) != 2 { 506 stmt.Close() 507 return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) 508 } 509 column, value := nameVal[0], nameVal[1] 510 _, ok := c.db.columnType(stmt.table, column) 511 if !ok { 512 stmt.Close() 513 return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column) 514 } 515 if !strings.HasPrefix(value, "?") { 516 stmt.Close() 517 return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark", 518 stmt.table, column) 519 } 520 stmt.placeholders++ 521 stmt.whereCol = append(stmt.whereCol, boundCol{Column: column, Placeholder: value, Ordinal: stmt.placeholders}) 522 } 523 return stmt, nil 524 } 525 526 // parts are table|col=type,col2=type2 527 func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (*fakeStmt, error) { 528 if len(parts) != 2 { 529 stmt.Close() 530 return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts)) 531 } 532 stmt.table = parts[0] 533 for n, colspec := range strings.Split(parts[1], ",") { 534 nameType := strings.Split(colspec, "=") 535 if len(nameType) != 2 { 536 stmt.Close() 537 return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) 538 } 539 stmt.colName = append(stmt.colName, nameType[0]) 540 stmt.colType = append(stmt.colType, nameType[1]) 541 } 542 return stmt, nil 543 } 544 545 // parts are table|col=?,col2=val 546 func (c *fakeConn) 
prepareInsert(ctx context.Context, stmt *fakeStmt, parts []string) (*fakeStmt, error) { 547 if len(parts) != 2 { 548 stmt.Close() 549 return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts)) 550 } 551 stmt.table = parts[0] 552 for n, colspec := range strings.Split(parts[1], ",") { 553 nameVal := strings.Split(colspec, "=") 554 if len(nameVal) != 2 { 555 stmt.Close() 556 return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) 557 } 558 column, value := nameVal[0], nameVal[1] 559 ctype, ok := c.db.columnType(stmt.table, column) 560 if !ok { 561 stmt.Close() 562 return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column) 563 } 564 stmt.colName = append(stmt.colName, column) 565 566 if !strings.HasPrefix(value, "?") { 567 var subsetVal interface{} 568 // Convert to driver subset type 569 switch ctype { 570 case "string": 571 subsetVal = []byte(value) 572 case "blob": 573 subsetVal = []byte(value) 574 case "int32": 575 i, err := strconv.Atoi(value) 576 if err != nil { 577 stmt.Close() 578 return nil, errf("invalid conversion to int32 from %q", value) 579 } 580 subsetVal = int64(i) // int64 is a subset type, but not int32 581 case "table": // For testing cursor reads. 
582 c.skipDirtySession = true 583 vparts := strings.Split(value, "!") 584 585 substmt, err := c.PrepareContext(ctx, fmt.Sprintf("SELECT|%s|%s|", vparts[0], strings.Join(vparts[1:], ","))) 586 if err != nil { 587 return nil, err 588 } 589 cursor, err := (substmt.(driver.StmtQueryContext)).QueryContext(ctx, []driver.NamedValue{}) 590 substmt.Close() 591 if err != nil { 592 return nil, err 593 } 594 subsetVal = cursor 595 default: 596 stmt.Close() 597 return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype) 598 } 599 stmt.colValue = append(stmt.colValue, subsetVal) 600 } else { 601 stmt.placeholders++ 602 stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype)) 603 stmt.colValue = append(stmt.colValue, value) 604 } 605 } 606 return stmt, nil 607 } 608 609 // hook to simulate broken connections 610 var hookPrepareBadConn func() bool 611 612 func (c *fakeConn) Prepare(query string) (driver.Stmt, error) { 613 panic("use PrepareContext") 614 } 615 616 func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { 617 c.numPrepare++ 618 if c.db == nil { 619 panic("nil c.db; conn = " + fmt.Sprintf("%#v", c)) 620 } 621 622 if c.stickyBad || (hookPrepareBadConn != nil && hookPrepareBadConn()) { 623 return nil, driver.ErrBadConn 624 } 625 626 c.touchMem() 627 var firstStmt, prev *fakeStmt 628 for _, query := range strings.Split(query, ";") { 629 parts := strings.Split(query, "|") 630 if len(parts) < 1 { 631 return nil, errf("empty query") 632 } 633 stmt := &fakeStmt{q: query, c: c, memToucher: c} 634 if firstStmt == nil { 635 firstStmt = stmt 636 } 637 if len(parts) >= 3 { 638 switch parts[0] { 639 case "PANIC": 640 stmt.panic = parts[1] 641 parts = parts[2:] 642 case "WAIT": 643 wait, err := time.ParseDuration(parts[1]) 644 if err != nil { 645 return nil, errf("expected section after WAIT to be a duration, got %q %v", parts[1], err) 646 } 647 parts = parts[2:] 648 stmt.wait = wait 
649 } 650 } 651 cmd := parts[0] 652 stmt.cmd = cmd 653 parts = parts[1:] 654 655 if c.waiter != nil { 656 c.waiter(ctx) 657 } 658 659 if stmt.wait > 0 { 660 wait := time.NewTimer(stmt.wait) 661 select { 662 case <-wait.C: 663 case <-ctx.Done(): 664 wait.Stop() 665 return nil, ctx.Err() 666 } 667 } 668 669 c.incrStat(&c.stmtsMade) 670 var err error 671 switch cmd { 672 case "WIPE": 673 // Nothing 674 case "SELECT": 675 stmt, err = c.prepareSelect(stmt, parts) 676 case "CREATE": 677 stmt, err = c.prepareCreate(stmt, parts) 678 case "INSERT": 679 stmt, err = c.prepareInsert(ctx, stmt, parts) 680 case "NOSERT": 681 // Do all the prep-work like for an INSERT but don't actually insert the row. 682 // Used for some of the concurrent tests. 683 stmt, err = c.prepareInsert(ctx, stmt, parts) 684 default: 685 stmt.Close() 686 return nil, errf("unsupported command type %q", cmd) 687 } 688 if err != nil { 689 return nil, err 690 } 691 if prev != nil { 692 prev.next = stmt 693 } 694 prev = stmt 695 } 696 return firstStmt, nil 697 } 698 699 func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter { 700 if s.panic == "ColumnConverter" { 701 panic(s.panic) 702 } 703 if len(s.placeholderConverter) == 0 { 704 return driver.DefaultParameterConverter 705 } 706 return s.placeholderConverter[idx] 707 } 708 709 func (s *fakeStmt) Close() error { 710 if s.panic == "Close" { 711 panic(s.panic) 712 } 713 if s.c == nil { 714 panic("nil conn in fakeStmt.Close") 715 } 716 if s.c.db == nil { 717 panic("in fakeStmt.Close, conn's db is nil (already closed)") 718 } 719 s.touchMem() 720 if !s.closed { 721 s.c.incrStat(&s.c.stmtsClosed) 722 s.closed = true 723 } 724 if s.next != nil { 725 s.next.Close() 726 } 727 return nil 728 } 729 730 var errClosed = errors.New("fakedb: statement has been closed") 731 732 // hook to simulate broken connections 733 var hookExecBadConn func() bool 734 735 func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) { 736 panic("Using ExecContext") 
737 } 738 739 var errFakeConnSessionDirty = errors.New("fakedb: session is dirty") 740 741 func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { 742 if s.panic == "Exec" { 743 panic(s.panic) 744 } 745 if s.closed { 746 return nil, errClosed 747 } 748 749 if s.c.stickyBad || (hookExecBadConn != nil && hookExecBadConn()) { 750 return nil, driver.ErrBadConn 751 } 752 if s.c.isDirtyAndMark() { 753 return nil, errFakeConnSessionDirty 754 } 755 756 err := checkSubsetTypes(s.c.db.allowAny, args) 757 if err != nil { 758 return nil, err 759 } 760 s.touchMem() 761 762 if s.wait > 0 { 763 time.Sleep(s.wait) 764 } 765 766 select { 767 default: 768 case <-ctx.Done(): 769 return nil, ctx.Err() 770 } 771 772 db := s.c.db 773 switch s.cmd { 774 case "WIPE": 775 db.wipe() 776 return driver.ResultNoRows, nil 777 case "CREATE": 778 if err := db.createTable(s.table, s.colName, s.colType); err != nil { 779 return nil, err 780 } 781 return driver.ResultNoRows, nil 782 case "INSERT": 783 return s.execInsert(args, true) 784 case "NOSERT": 785 // Do all the prep-work like for an INSERT but don't actually insert the row. 786 // Used for some of the concurrent tests. 787 return s.execInsert(args, false) 788 } 789 return nil, fmt.Errorf("fakedb: unimplemented statement Exec command type of %q", s.cmd) 790 } 791 792 // When doInsert is true, add the row to the table. 793 // When doInsert is false do prep-work and error checking, but don't 794 // actually add the row to the table. 
795 func (s *fakeStmt) execInsert(args []driver.NamedValue, doInsert bool) (driver.Result, error) { 796 db := s.c.db 797 if len(args) != s.placeholders { 798 panic("error in pkg db; should only get here if size is correct") 799 } 800 db.mu.Lock() 801 t, ok := db.table(s.table) 802 db.mu.Unlock() 803 if !ok { 804 return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table) 805 } 806 807 t.mu.Lock() 808 defer t.mu.Unlock() 809 810 var cols []interface{} 811 if doInsert { 812 cols = make([]interface{}, len(t.colname)) 813 } 814 argPos := 0 815 for n, colname := range s.colName { 816 colidx := t.columnIndex(colname) 817 if colidx == -1 { 818 return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname) 819 } 820 var val interface{} 821 if strvalue, ok := s.colValue[n].(string); ok && strings.HasPrefix(strvalue, "?") { 822 if strvalue == "?" { 823 val = args[argPos].Value 824 } else { 825 // Assign value from argument placeholder name. 
826 for _, a := range args { 827 if a.Name == strvalue[1:] { 828 val = a.Value 829 break 830 } 831 } 832 } 833 argPos++ 834 } else { 835 val = s.colValue[n] 836 } 837 if doInsert { 838 cols[colidx] = val 839 } 840 } 841 842 if doInsert { 843 t.rows = append(t.rows, &row{cols: cols}) 844 } 845 return driver.RowsAffected(1), nil 846 } 847 848 // hook to simulate broken connections 849 var hookQueryBadConn func() bool 850 851 func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) { 852 panic("Use QueryContext") 853 } 854 855 func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { 856 if s.panic == "Query" { 857 panic(s.panic) 858 } 859 if s.closed { 860 return nil, errClosed 861 } 862 863 if s.c.stickyBad || (hookQueryBadConn != nil && hookQueryBadConn()) { 864 return nil, driver.ErrBadConn 865 } 866 if s.c.isDirtyAndMark() { 867 return nil, errFakeConnSessionDirty 868 } 869 870 err := checkSubsetTypes(s.c.db.allowAny, args) 871 if err != nil { 872 return nil, err 873 } 874 875 s.touchMem() 876 db := s.c.db 877 if len(args) != s.placeholders { 878 panic("error in pkg db; should only get here if size is correct") 879 } 880 881 setMRows := make([][]*row, 0, 1) 882 setColumns := make([][]string, 0, 1) 883 setColType := make([][]string, 0, 1) 884 885 for { 886 db.mu.Lock() 887 t, ok := db.table(s.table) 888 db.mu.Unlock() 889 if !ok { 890 return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table) 891 } 892 893 if s.table == "magicquery" { 894 if len(s.whereCol) == 2 && s.whereCol[0].Column == "op" && s.whereCol[1].Column == "millis" { 895 if args[0].Value == "sleep" { 896 time.Sleep(time.Duration(args[1].Value.(int64)) * time.Millisecond) 897 } 898 } 899 } 900 if s.table == "tx_status" && s.colName[0] == "tx_status" { 901 txStatus := "autocommit" 902 if s.c.currTx != nil { 903 txStatus = "transaction" 904 } 905 cursor := &rowsCursor{ 906 parentMem: s.c, 907 posRow: -1, 908 rows: [][]*row{ 909 []*row{ 910 { 
911 cols: []interface{}{ 912 txStatus, 913 }, 914 }, 915 }, 916 }, 917 cols: [][]string{ 918 []string{ 919 "tx_status", 920 }, 921 }, 922 colType: [][]string{ 923 []string{ 924 "string", 925 }, 926 }, 927 errPos: -1, 928 } 929 return cursor, nil 930 } 931 932 t.mu.Lock() 933 934 colIdx := make(map[string]int) // select column name -> column index in table 935 for _, name := range s.colName { 936 idx := t.columnIndex(name) 937 if idx == -1 { 938 t.mu.Unlock() 939 return nil, fmt.Errorf("fakedb: unknown column name %q", name) 940 } 941 colIdx[name] = idx 942 } 943 944 mrows := []*row{} 945 rows: 946 for _, trow := range t.rows { 947 // Process the where clause, skipping non-match rows. This is lazy 948 // and just uses fmt.Sprintf("%v") to test equality. Good enough 949 // for test code. 950 for _, wcol := range s.whereCol { 951 idx := t.columnIndex(wcol.Column) 952 if idx == -1 { 953 t.mu.Unlock() 954 return nil, fmt.Errorf("fakedb: invalid where clause column %q", wcol) 955 } 956 tcol := trow.cols[idx] 957 if bs, ok := tcol.([]byte); ok { 958 // lazy hack to avoid sprintf %v on a []byte 959 tcol = string(bs) 960 } 961 var argValue interface{} 962 if wcol.Placeholder == "?" { 963 argValue = args[wcol.Ordinal-1].Value 964 } else { 965 // Assign arg value from placeholder name. 
966 for _, a := range args { 967 if a.Name == wcol.Placeholder[1:] { 968 argValue = a.Value 969 break 970 } 971 } 972 } 973 if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) { 974 continue rows 975 } 976 } 977 mrow := &row{cols: make([]interface{}, len(s.colName))} 978 for seli, name := range s.colName { 979 mrow.cols[seli] = trow.cols[colIdx[name]] 980 } 981 mrows = append(mrows, mrow) 982 } 983 984 var colType []string 985 for _, column := range s.colName { 986 colType = append(colType, t.coltype[t.columnIndex(column)]) 987 } 988 989 t.mu.Unlock() 990 991 setMRows = append(setMRows, mrows) 992 setColumns = append(setColumns, s.colName) 993 setColType = append(setColType, colType) 994 995 if s.next == nil { 996 break 997 } 998 s = s.next 999 } 1000 1001 cursor := &rowsCursor{ 1002 parentMem: s.c, 1003 posRow: -1, 1004 rows: setMRows, 1005 cols: setColumns, 1006 colType: setColType, 1007 errPos: -1, 1008 } 1009 return cursor, nil 1010 } 1011 1012 func (s *fakeStmt) NumInput() int { 1013 if s.panic == "NumInput" { 1014 panic(s.panic) 1015 } 1016 return s.placeholders 1017 } 1018 1019 // hook to simulate broken connections 1020 var hookCommitBadConn func() bool 1021 1022 func (tx *fakeTx) Commit() error { 1023 tx.c.currTx = nil 1024 if hookCommitBadConn != nil && hookCommitBadConn() { 1025 return driver.ErrBadConn 1026 } 1027 tx.c.touchMem() 1028 return nil 1029 } 1030 1031 // hook to simulate broken connections 1032 var hookRollbackBadConn func() bool 1033 1034 func (tx *fakeTx) Rollback() error { 1035 tx.c.currTx = nil 1036 if hookRollbackBadConn != nil && hookRollbackBadConn() { 1037 return driver.ErrBadConn 1038 } 1039 tx.c.touchMem() 1040 return nil 1041 } 1042 1043 type rowsCursor struct { 1044 parentMem memToucher 1045 cols [][]string 1046 colType [][]string 1047 posSet int 1048 posRow int 1049 rows [][]*row 1050 closed bool 1051 1052 // errPos and err are for making Next return early with error. 
1053 errPos int 1054 err error 1055 1056 // a clone of slices to give out to clients, indexed by the 1057 // original slice's first byte address. we clone them 1058 // just so we're able to corrupt them on close. 1059 bytesClone map[*byte][]byte 1060 1061 // Every operation writes to line to enable the race detector 1062 // check for data races. 1063 // This is separate from the fakeConn.line to allow for drivers that 1064 // can start multiple queries on the same transaction at the same time. 1065 line int64 1066 } 1067 1068 func (rc *rowsCursor) touchMem() { 1069 rc.parentMem.touchMem() 1070 rc.line++ 1071 } 1072 1073 func (rc *rowsCursor) Close() error { 1074 rc.touchMem() 1075 rc.parentMem.touchMem() 1076 rc.closed = true 1077 return nil 1078 } 1079 1080 func (rc *rowsCursor) Columns() []string { 1081 return rc.cols[rc.posSet] 1082 } 1083 1084 func (rc *rowsCursor) ColumnTypeScanType(index int) reflect.Type { 1085 return colTypeToReflectType(rc.colType[rc.posSet][index]) 1086 } 1087 1088 var rowsCursorNextHook func(dest []driver.Value) error 1089 1090 func (rc *rowsCursor) Next(dest []driver.Value) error { 1091 if rowsCursorNextHook != nil { 1092 return rowsCursorNextHook(dest) 1093 } 1094 1095 if rc.closed { 1096 return errors.New("fakedb: cursor is closed") 1097 } 1098 rc.touchMem() 1099 rc.posRow++ 1100 if rc.posRow == rc.errPos { 1101 return rc.err 1102 } 1103 if rc.posRow >= len(rc.rows[rc.posSet]) { 1104 return io.EOF // per interface spec 1105 } 1106 for i, v := range rc.rows[rc.posSet][rc.posRow].cols { 1107 // TODO(bradfitz): convert to subset types? naah, I 1108 // think the subset types should only be input to 1109 // driver, but the sql package should be able to handle 1110 // a wider range of types coming out of drivers. all 1111 // for ease of drivers, and to prevent drivers from 1112 // messing up conversions or doing them differently. 
1113 dest[i] = v 1114 1115 if bs, ok := v.([]byte); ok { 1116 if rc.bytesClone == nil { 1117 rc.bytesClone = make(map[*byte][]byte) 1118 } 1119 clone, ok := rc.bytesClone[&bs[0]] 1120 if !ok { 1121 clone = make([]byte, len(bs)) 1122 copy(clone, bs) 1123 rc.bytesClone[&bs[0]] = clone 1124 } 1125 dest[i] = clone 1126 } 1127 } 1128 return nil 1129 } 1130 1131 func (rc *rowsCursor) HasNextResultSet() bool { 1132 rc.touchMem() 1133 return rc.posSet < len(rc.rows)-1 1134 } 1135 1136 func (rc *rowsCursor) NextResultSet() error { 1137 rc.touchMem() 1138 if rc.HasNextResultSet() { 1139 rc.posSet++ 1140 rc.posRow = -1 1141 return nil 1142 } 1143 return io.EOF // Per interface spec. 1144 } 1145 1146 // fakeDriverString is like driver.String, but indirects pointers like 1147 // DefaultValueConverter. 1148 // 1149 // This could be surprising behavior to retroactively apply to 1150 // driver.String now that Go1 is out, but this is convenient for 1151 // our TestPointerParamsAndScans. 1152 // 1153 type fakeDriverString struct{} 1154 1155 func (fakeDriverString) ConvertValue(v interface{}) (driver.Value, error) { 1156 switch c := v.(type) { 1157 case string, []byte: 1158 return v, nil 1159 case *string: 1160 if c == nil { 1161 return nil, nil 1162 } 1163 return *c, nil 1164 } 1165 return fmt.Sprintf("%v", v), nil 1166 } 1167 1168 type anyTypeConverter struct{} 1169 1170 func (anyTypeConverter) ConvertValue(v interface{}) (driver.Value, error) { 1171 return v, nil 1172 } 1173 1174 func converterForType(typ string) driver.ValueConverter { 1175 switch typ { 1176 case "bool": 1177 return driver.Bool 1178 case "nullbool": 1179 return driver.Null{Converter: driver.Bool} 1180 case "int32": 1181 return driver.Int32 1182 case "nullint32": 1183 return driver.Null{Converter: driver.DefaultParameterConverter} 1184 case "string": 1185 return driver.NotNull{Converter: fakeDriverString{}} 1186 case "nullstring": 1187 return driver.Null{Converter: fakeDriverString{}} 1188 case "int64": 1189 // 
TODO(coopernurse): add type-specific converter 1190 return driver.NotNull{Converter: driver.DefaultParameterConverter} 1191 case "nullint64": 1192 // TODO(coopernurse): add type-specific converter 1193 return driver.Null{Converter: driver.DefaultParameterConverter} 1194 case "float64": 1195 // TODO(coopernurse): add type-specific converter 1196 return driver.NotNull{Converter: driver.DefaultParameterConverter} 1197 case "nullfloat64": 1198 // TODO(coopernurse): add type-specific converter 1199 return driver.Null{Converter: driver.DefaultParameterConverter} 1200 case "datetime": 1201 return driver.NotNull{Converter: driver.DefaultParameterConverter} 1202 case "nulldatetime": 1203 return driver.Null{Converter: driver.DefaultParameterConverter} 1204 case "any": 1205 return anyTypeConverter{} 1206 } 1207 panic("invalid fakedb column type of " + typ) 1208 } 1209 1210 func colTypeToReflectType(typ string) reflect.Type { 1211 switch typ { 1212 case "bool": 1213 return reflect.TypeOf(false) 1214 case "nullbool": 1215 return reflect.TypeOf(NullBool{}) 1216 case "int32": 1217 return reflect.TypeOf(int32(0)) 1218 case "nullint32": 1219 return reflect.TypeOf(NullInt32{}) 1220 case "string": 1221 return reflect.TypeOf("") 1222 case "nullstring": 1223 return reflect.TypeOf(NullString{}) 1224 case "int64": 1225 return reflect.TypeOf(int64(0)) 1226 case "nullint64": 1227 return reflect.TypeOf(NullInt64{}) 1228 case "float64": 1229 return reflect.TypeOf(float64(0)) 1230 case "nullfloat64": 1231 return reflect.TypeOf(NullFloat64{}) 1232 case "datetime": 1233 return reflect.TypeOf(time.Time{}) 1234 case "any": 1235 return reflect.TypeOf(new(interface{})).Elem() 1236 } 1237 panic("invalid fakedb column type of " + typ) 1238 }