github.com/AESNooper/go/src@v0.0.0-20220218095104-b56a4ab1bbbb/database/sql/fakedb_test.go (about) 1 // Copyright 2011 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package sql 6 7 import ( 8 "context" 9 "database/sql/driver" 10 "errors" 11 "fmt" 12 "io" 13 "reflect" 14 "sort" 15 "strconv" 16 "strings" 17 "sync" 18 "testing" 19 "time" 20 ) 21 22 // fakeDriver is a fake database that implements Go's driver.Driver 23 // interface, just for testing. 24 // 25 // It speaks a query language that's semantically similar to but 26 // syntactically different and simpler than SQL. The syntax is as 27 // follows: 28 // 29 // WIPE 30 // CREATE|<tablename>|<col>=<type>,<col>=<type>,... 31 // where types are: "string", [u]int{8,16,32,64}, "bool" 32 // INSERT|<tablename>|col=val,col2=val2,col3=? 33 // SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=? 34 // SELECT|<tablename>|projectcol1,projectcol2|filtercol=?param1,filtercol2=?param2 35 // 36 // Any of these can be preceded by PANIC|<method>|, to cause the 37 // named method on fakeStmt to panic. 38 // 39 // Any of these can be proceeded by WAIT|<duration>|, to cause the 40 // named method on fakeStmt to sleep for the specified duration. 41 // 42 // Multiple of these can be combined when separated with a semicolon. 43 // 44 // When opening a fakeDriver's database, it starts empty with no 45 // tables. All tables and data are stored in memory only. 46 type fakeDriver struct { 47 mu sync.Mutex // guards 3 following fields 48 openCount int // conn opens 49 closeCount int // conn closes 50 waitCh chan struct{} 51 waitingCh chan struct{} 52 dbs map[string]*fakeDB 53 } 54 55 type fakeConnector struct { 56 name string 57 58 waiter func(context.Context) 59 closed bool 60 } 61 62 func (c *fakeConnector) Connect(context.Context) (driver.Conn, error) { 63 conn, err := fdriver.Open(c.name) 64 conn.(*fakeConn).waiter = c.waiter 65 return conn, err 66 } 67 68 func (c *fakeConnector) Driver() driver.Driver { 69 return fdriver 70 } 71 72 func (c *fakeConnector) Close() error { 73 if c.closed { 74 return errors.New("fakedb: connector is closed") 75 } 76 c.closed = true 77 return nil 78 } 79 80 type fakeDriverCtx struct { 81 fakeDriver 82 } 83 84 var _ driver.DriverContext = &fakeDriverCtx{} 85 86 func (cc *fakeDriverCtx) OpenConnector(name string) (driver.Connector, error) { 87 return &fakeConnector{name: name}, nil 88 } 89 90 type fakeDB struct { 91 name string 92 93 mu sync.Mutex 94 tables map[string]*table 95 badConn bool 96 allowAny bool 97 } 98 99 type fakeError struct { 100 Message string 101 Wrapped error 102 } 103 104 func (err fakeError) Error() string { 105 return err.Message 106 } 107 108 func (err fakeError) Unwrap() error { 109 return err.Wrapped 110 } 111 112 type table struct { 113 mu sync.Mutex 114 colname []string 115 coltype []string 116 rows []*row 117 } 118 119 func (t *table) columnIndex(name string) int { 120 for n, nname := range t.colname { 121 if name == nname { 122 return n 123 } 124 } 125 return -1 126 } 127 128 type row struct { 129 cols []interface{} // must be same size as its table colname + coltype 130 } 131 132 type memToucher interface { 133 // touchMem reads & writes some memory, to help find data races. 134 touchMem() 135 } 136 137 type fakeConn struct { 138 db *fakeDB // where to return ourselves to 139 140 currTx *fakeTx 141 142 // Every operation writes to line to enable the race detector 143 // check for data races. 144 line int64 145 146 // Stats for tests: 147 mu sync.Mutex 148 stmtsMade int 149 stmtsClosed int 150 numPrepare int 151 152 // bad connection tests; see isBad() 153 bad bool 154 stickyBad bool 155 156 skipDirtySession bool // tests that use Conn should set this to true. 157 158 // dirtySession tests ResetSession, true if a query has executed 159 // until ResetSession is called. 160 dirtySession bool 161 162 // The waiter is called before each query. May be used in place of the "WAIT" 163 // directive. 164 waiter func(context.Context) 165 } 166 167 func (c *fakeConn) touchMem() { 168 c.line++ 169 } 170 171 func (c *fakeConn) incrStat(v *int) { 172 c.mu.Lock() 173 *v++ 174 c.mu.Unlock() 175 } 176 177 type fakeTx struct { 178 c *fakeConn 179 } 180 181 type boundCol struct { 182 Column string 183 Placeholder string 184 Ordinal int 185 } 186 187 type fakeStmt struct { 188 memToucher 189 c *fakeConn 190 q string // just for debugging 191 192 cmd string 193 table string 194 panic string 195 wait time.Duration 196 197 next *fakeStmt // used for returning multiple results. 198 199 closed bool 200 201 colName []string // used by CREATE, INSERT, SELECT (selected columns) 202 colType []string // used by CREATE 203 colValue []interface{} // used by INSERT (mix of strings and "?" for bound params) 204 placeholders int // used by INSERT/SELECT: number of ? params 205 206 whereCol []boundCol // used by SELECT (all placeholders) 207 208 placeholderConverter []driver.ValueConverter // used by INSERT 209 } 210 211 var fdriver driver.Driver = &fakeDriver{} 212 213 func init() { 214 Register("test", fdriver) 215 } 216 217 func contains(list []string, y string) bool { 218 for _, x := range list { 219 if x == y { 220 return true 221 } 222 } 223 return false 224 } 225 226 type Dummy struct { 227 driver.Driver 228 } 229 230 func TestDrivers(t *testing.T) { 231 unregisterAllDrivers() 232 Register("test", fdriver) 233 Register("invalid", Dummy{}) 234 all := Drivers() 235 if len(all) < 2 || !sort.StringsAreSorted(all) || !contains(all, "test") || !contains(all, "invalid") { 236 t.Fatalf("Drivers = %v, want sorted list with at least [invalid, test]", all) 237 } 238 } 239 240 // hook to simulate connection failures 241 var hookOpenErr struct { 242 sync.Mutex 243 fn func() error 244 } 245 246 func setHookOpenErr(fn func() error) { 247 hookOpenErr.Lock() 248 defer hookOpenErr.Unlock() 249 hookOpenErr.fn = fn 250 } 251 252 // Supports dsn forms: 253 // <dbname> 254 // <dbname>;<opts> (only currently supported option is `badConn`, 255 // which causes driver.ErrBadConn to be returned on 256 // every other conn.Begin()) 257 func (d *fakeDriver) Open(dsn string) (driver.Conn, error) { 258 hookOpenErr.Lock() 259 fn := hookOpenErr.fn 260 hookOpenErr.Unlock() 261 if fn != nil { 262 if err := fn(); err != nil { 263 return nil, err 264 } 265 } 266 parts := strings.Split(dsn, ";") 267 if len(parts) < 1 { 268 return nil, errors.New("fakedb: no database name") 269 } 270 name := parts[0] 271 272 db := d.getDB(name) 273 274 d.mu.Lock() 275 d.openCount++ 276 d.mu.Unlock() 277 conn := &fakeConn{db: db} 278 279 if len(parts) >= 2 && parts[1] == "badConn" { 280 conn.bad = true 281 } 282 if d.waitCh != nil { 283 d.waitingCh <- struct{}{} 284 <-d.waitCh 285 d.waitCh = nil 286 d.waitingCh = nil 287 } 288 return conn, nil 289 } 290 291 func (d *fakeDriver) getDB(name string) *fakeDB { 292 d.mu.Lock() 293 defer d.mu.Unlock() 294 if d.dbs == nil { 295 d.dbs = make(map[string]*fakeDB) 296 } 297 db, ok := d.dbs[name] 298 if !ok { 299 db = &fakeDB{name: name} 300 d.dbs[name] = db 301 } 302 return db 303 } 304 305 func (db *fakeDB) wipe() { 306 db.mu.Lock() 307 defer db.mu.Unlock() 308 db.tables = nil 309 } 310 311 func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error { 312 db.mu.Lock() 313 defer db.mu.Unlock() 314 if db.tables == nil { 315 db.tables = make(map[string]*table) 316 } 317 if _, exist := db.tables[name]; exist { 318 return fmt.Errorf("fakedb: table %q already exists", name) 319 } 320 if len(columnNames) != len(columnTypes) { 321 return fmt.Errorf("fakedb: create table of %q len(names) != len(types): %d vs %d", 322 name, len(columnNames), len(columnTypes)) 323 } 324 db.tables[name] = &table{colname: columnNames, coltype: columnTypes} 325 return nil 326 } 327 328 // must be called with db.mu lock held 329 func (db *fakeDB) table(table string) (*table, bool) { 330 if db.tables == nil { 331 return nil, false 332 } 333 t, ok := db.tables[table] 334 return t, ok 335 } 336 337 func (db *fakeDB) columnType(table, column string) (typ string, ok bool) { 338 db.mu.Lock() 339 defer db.mu.Unlock() 340 t, ok := db.table(table) 341 if !ok { 342 return 343 } 344 for n, cname := range t.colname { 345 if cname == column { 346 return t.coltype[n], true 347 } 348 } 349 return "", false 350 } 351 352 func (c *fakeConn) isBad() bool { 353 if c.stickyBad { 354 return true 355 } else if c.bad { 356 if c.db == nil { 357 return false 358 } 359 // alternate between bad conn and not bad conn 360 c.db.badConn = !c.db.badConn 361 return c.db.badConn 362 } else { 363 return false 364 } 365 } 366 367 func (c *fakeConn) isDirtyAndMark() bool { 368 if c.skipDirtySession { 369 return false 370 } 371 if c.currTx != nil { 372 c.dirtySession = true 373 return false 374 } 375 if c.dirtySession { 376 return true 377 } 378 c.dirtySession = true 379 return false 380 } 381 382 func (c *fakeConn) Begin() (driver.Tx, error) { 383 if c.isBad() { 384 return nil, fakeError{Wrapped: driver.ErrBadConn} 385 } 386 if c.currTx != nil { 387 return nil, errors.New("fakedb: already in a transaction") 388 } 389 c.touchMem() 390 c.currTx = &fakeTx{c: c} 391 return c.currTx, nil 392 } 393 394 var hookPostCloseConn struct { 395 sync.Mutex 396 fn func(*fakeConn, error) 397 } 398 399 func setHookpostCloseConn(fn func(*fakeConn, error)) { 400 hookPostCloseConn.Lock() 401 defer hookPostCloseConn.Unlock() 402 hookPostCloseConn.fn = fn 403 } 404 405 var testStrictClose *testing.T 406 407 // setStrictFakeConnClose sets the t to Errorf on when fakeConn.Close 408 // fails to close. If nil, the check is disabled. 409 func setStrictFakeConnClose(t *testing.T) { 410 testStrictClose = t 411 } 412 413 func (c *fakeConn) ResetSession(ctx context.Context) error { 414 c.dirtySession = false 415 c.currTx = nil 416 if c.isBad() { 417 return fakeError{Message: "Reset Session: bad conn", Wrapped: driver.ErrBadConn} 418 } 419 return nil 420 } 421 422 var _ driver.Validator = (*fakeConn)(nil) 423 424 func (c *fakeConn) IsValid() bool { 425 return !c.isBad() 426 } 427 428 func (c *fakeConn) Close() (err error) { 429 drv := fdriver.(*fakeDriver) 430 defer func() { 431 if err != nil && testStrictClose != nil { 432 testStrictClose.Errorf("failed to close a test fakeConn: %v", err) 433 } 434 hookPostCloseConn.Lock() 435 fn := hookPostCloseConn.fn 436 hookPostCloseConn.Unlock() 437 if fn != nil { 438 fn(c, err) 439 } 440 if err == nil { 441 drv.mu.Lock() 442 drv.closeCount++ 443 drv.mu.Unlock() 444 } 445 }() 446 c.touchMem() 447 if c.currTx != nil { 448 return errors.New("fakedb: can't close fakeConn; in a Transaction") 449 } 450 if c.db == nil { 451 return errors.New("fakedb: can't close fakeConn; already closed") 452 } 453 if c.stmtsMade > c.stmtsClosed { 454 return errors.New("fakedb: can't close; dangling statement(s)") 455 } 456 c.db = nil 457 return nil 458 } 459 460 func checkSubsetTypes(allowAny bool, args []driver.NamedValue) error { 461 for _, arg := range args { 462 switch arg.Value.(type) { 463 case int64, float64, bool, nil, []byte, string, time.Time: 464 default: 465 if !allowAny { 466 return fmt.Errorf("fakedb: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value) 467 } 468 } 469 } 470 return nil 471 } 472 473 func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) { 474 // Ensure that ExecContext is called if available. 475 panic("ExecContext was not called.") 476 } 477 478 func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { 479 // This is an optional interface, but it's implemented here 480 // just to check that all the args are of the proper types. 481 // ErrSkip is returned so the caller acts as if we didn't 482 // implement this at all. 483 err := checkSubsetTypes(c.db.allowAny, args) 484 if err != nil { 485 return nil, err 486 } 487 return nil, driver.ErrSkip 488 } 489 490 func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) { 491 // Ensure that ExecContext is called if available. 492 panic("QueryContext was not called.") 493 } 494 495 func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { 496 // This is an optional interface, but it's implemented here 497 // just to check that all the args are of the proper types. 498 // ErrSkip is returned so the caller acts as if we didn't 499 // implement this at all. 500 err := checkSubsetTypes(c.db.allowAny, args) 501 if err != nil { 502 return nil, err 503 } 504 return nil, driver.ErrSkip 505 } 506 507 func errf(msg string, args ...interface{}) error { 508 return errors.New("fakedb: " + fmt.Sprintf(msg, args...)) 509 } 510 511 // parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=? 512 // (note that where columns must always contain ? marks, 513 // just a limitation for fakedb) 514 func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, error) { 515 if len(parts) != 3 { 516 stmt.Close() 517 return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts)) 518 } 519 stmt.table = parts[0] 520 521 stmt.colName = strings.Split(parts[1], ",") 522 for n, colspec := range strings.Split(parts[2], ",") { 523 if colspec == "" { 524 continue 525 } 526 nameVal := strings.Split(colspec, "=") 527 if len(nameVal) != 2 { 528 stmt.Close() 529 return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) 530 } 531 column, value := nameVal[0], nameVal[1] 532 _, ok := c.db.columnType(stmt.table, column) 533 if !ok { 534 stmt.Close() 535 return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column) 536 } 537 if !strings.HasPrefix(value, "?") { 538 stmt.Close() 539 return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark", 540 stmt.table, column) 541 } 542 stmt.placeholders++ 543 stmt.whereCol = append(stmt.whereCol, boundCol{Column: column, Placeholder: value, Ordinal: stmt.placeholders}) 544 } 545 return stmt, nil 546 } 547 548 // parts are table|col=type,col2=type2 549 func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (*fakeStmt, error) { 550 if len(parts) != 2 { 551 stmt.Close() 552 return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts)) 553 } 554 stmt.table = parts[0] 555 for n, colspec := range strings.Split(parts[1], ",") { 556 nameType := strings.Split(colspec, "=") 557 if len(nameType) != 2 { 558 stmt.Close() 559 return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) 560 } 561 stmt.colName = append(stmt.colName, nameType[0]) 562 stmt.colType = append(stmt.colType, nameType[1]) 563 } 564 return stmt, nil 565 } 566 567 // parts are table|col=?,col2=val 568 func (c *fakeConn) prepareInsert(ctx context.Context, stmt *fakeStmt, parts []string) (*fakeStmt, error) { 569 if len(parts) != 2 { 570 stmt.Close() 571 return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts)) 572 } 573 stmt.table = parts[0] 574 for n, colspec := range strings.Split(parts[1], ",") { 575 nameVal := strings.Split(colspec, "=") 576 if len(nameVal) != 2 { 577 stmt.Close() 578 return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) 579 } 580 column, value := nameVal[0], nameVal[1] 581 ctype, ok := c.db.columnType(stmt.table, column) 582 if !ok { 583 stmt.Close() 584 return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column) 585 } 586 stmt.colName = append(stmt.colName, column) 587 588 if !strings.HasPrefix(value, "?") { 589 var subsetVal interface{} 590 // Convert to driver subset type 591 switch ctype { 592 case "string": 593 subsetVal = []byte(value) 594 case "blob": 595 subsetVal = []byte(value) 596 case "int32": 597 i, err := strconv.Atoi(value) 598 if err != nil { 599 stmt.Close() 600 return nil, errf("invalid conversion to int32 from %q", value) 601 } 602 subsetVal = int64(i) // int64 is a subset type, but not int32 603 case "table": // For testing cursor reads. 604 c.skipDirtySession = true 605 vparts := strings.Split(value, "!") 606 607 substmt, err := c.PrepareContext(ctx, fmt.Sprintf("SELECT|%s|%s|", vparts[0], strings.Join(vparts[1:], ","))) 608 if err != nil { 609 return nil, err 610 } 611 cursor, err := (substmt.(driver.StmtQueryContext)).QueryContext(ctx, []driver.NamedValue{}) 612 substmt.Close() 613 if err != nil { 614 return nil, err 615 } 616 subsetVal = cursor 617 default: 618 stmt.Close() 619 return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype) 620 } 621 stmt.colValue = append(stmt.colValue, subsetVal) 622 } else { 623 stmt.placeholders++ 624 stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype)) 625 stmt.colValue = append(stmt.colValue, value) 626 } 627 } 628 return stmt, nil 629 } 630 631 // hook to simulate broken connections 632 var hookPrepareBadConn func() bool 633 634 func (c *fakeConn) Prepare(query string) (driver.Stmt, error) { 635 panic("use PrepareContext") 636 } 637 638 func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { 639 c.numPrepare++ 640 if c.db == nil { 641 panic("nil c.db; conn = " + fmt.Sprintf("%#v", c)) 642 } 643 644 if c.stickyBad || (hookPrepareBadConn != nil && hookPrepareBadConn()) { 645 return nil, fakeError{Message: "Preapre: Sticky Bad", Wrapped: driver.ErrBadConn} 646 } 647 648 c.touchMem() 649 var firstStmt, prev *fakeStmt 650 for _, query := range strings.Split(query, ";") { 651 parts := strings.Split(query, "|") 652 if len(parts) < 1 { 653 return nil, errf("empty query") 654 } 655 stmt := &fakeStmt{q: query, c: c, memToucher: c} 656 if firstStmt == nil { 657 firstStmt = stmt 658 } 659 if len(parts) >= 3 { 660 switch parts[0] { 661 case "PANIC": 662 stmt.panic = parts[1] 663 parts = parts[2:] 664 case "WAIT": 665 wait, err := time.ParseDuration(parts[1]) 666 if err != nil { 667 return nil, errf("expected section after WAIT to be a duration, got %q %v", parts[1], err) 668 } 669 parts = parts[2:] 670 stmt.wait = wait 671 } 672 } 673 cmd := parts[0] 674 stmt.cmd = cmd 675 parts = parts[1:] 676 677 if c.waiter != nil { 678 c.waiter(ctx) 679 } 680 681 if stmt.wait > 0 { 682 wait := time.NewTimer(stmt.wait) 683 select { 684 case <-wait.C: 685 case <-ctx.Done(): 686 wait.Stop() 687 return nil, ctx.Err() 688 } 689 } 690 691 c.incrStat(&c.stmtsMade) 692 var err error 693 switch cmd { 694 case "WIPE": 695 // Nothing 696 case "SELECT": 697 stmt, err = c.prepareSelect(stmt, parts) 698 case "CREATE": 699 stmt, err = c.prepareCreate(stmt, parts) 700 case "INSERT": 701 stmt, err = c.prepareInsert(ctx, stmt, parts) 702 case "NOSERT": 703 // Do all the prep-work like for an INSERT but don't actually insert the row. 704 // Used for some of the concurrent tests. 705 stmt, err = c.prepareInsert(ctx, stmt, parts) 706 default: 707 stmt.Close() 708 return nil, errf("unsupported command type %q", cmd) 709 } 710 if err != nil { 711 return nil, err 712 } 713 if prev != nil { 714 prev.next = stmt 715 } 716 prev = stmt 717 } 718 return firstStmt, nil 719 } 720 721 func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter { 722 if s.panic == "ColumnConverter" { 723 panic(s.panic) 724 } 725 if len(s.placeholderConverter) == 0 { 726 return driver.DefaultParameterConverter 727 } 728 return s.placeholderConverter[idx] 729 } 730 731 func (s *fakeStmt) Close() error { 732 if s.panic == "Close" { 733 panic(s.panic) 734 } 735 if s.c == nil { 736 panic("nil conn in fakeStmt.Close") 737 } 738 if s.c.db == nil { 739 panic("in fakeStmt.Close, conn's db is nil (already closed)") 740 } 741 s.touchMem() 742 if !s.closed { 743 s.c.incrStat(&s.c.stmtsClosed) 744 s.closed = true 745 } 746 if s.next != nil { 747 s.next.Close() 748 } 749 return nil 750 } 751 752 var errClosed = errors.New("fakedb: statement has been closed") 753 754 // hook to simulate broken connections 755 var hookExecBadConn func() bool 756 757 func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) { 758 panic("Using ExecContext") 759 } 760 761 var errFakeConnSessionDirty = errors.New("fakedb: session is dirty") 762 763 func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { 764 if s.panic == "Exec" { 765 panic(s.panic) 766 } 767 if s.closed { 768 return nil, errClosed 769 } 770 771 if s.c.stickyBad || (hookExecBadConn != nil && hookExecBadConn()) { 772 return nil, fakeError{Message: "Exec: Sticky Bad", Wrapped: driver.ErrBadConn} 773 } 774 if s.c.isDirtyAndMark() { 775 return nil, errFakeConnSessionDirty 776 } 777 778 err := checkSubsetTypes(s.c.db.allowAny, args) 779 if err != nil { 780 return nil, err 781 } 782 s.touchMem() 783 784 if s.wait > 0 { 785 time.Sleep(s.wait) 786 } 787 788 select { 789 default: 790 case <-ctx.Done(): 791 return nil, ctx.Err() 792 } 793 794 db := s.c.db 795 switch s.cmd { 796 case "WIPE": 797 db.wipe() 798 return driver.ResultNoRows, nil 799 case "CREATE": 800 if err := db.createTable(s.table, s.colName, s.colType); err != nil { 801 return nil, err 802 } 803 return driver.ResultNoRows, nil 804 case "INSERT": 805 return s.execInsert(args, true) 806 case "NOSERT": 807 // Do all the prep-work like for an INSERT but don't actually insert the row. 808 // Used for some of the concurrent tests. 809 return s.execInsert(args, false) 810 } 811 return nil, fmt.Errorf("fakedb: unimplemented statement Exec command type of %q", s.cmd) 812 } 813 814 // When doInsert is true, add the row to the table. 815 // When doInsert is false do prep-work and error checking, but don't 816 // actually add the row to the table. 817 func (s *fakeStmt) execInsert(args []driver.NamedValue, doInsert bool) (driver.Result, error) { 818 db := s.c.db 819 if len(args) != s.placeholders { 820 panic("error in pkg db; should only get here if size is correct") 821 } 822 db.mu.Lock() 823 t, ok := db.table(s.table) 824 db.mu.Unlock() 825 if !ok { 826 return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table) 827 } 828 829 t.mu.Lock() 830 defer t.mu.Unlock() 831 832 var cols []interface{} 833 if doInsert { 834 cols = make([]interface{}, len(t.colname)) 835 } 836 argPos := 0 837 for n, colname := range s.colName { 838 colidx := t.columnIndex(colname) 839 if colidx == -1 { 840 return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname) 841 } 842 var val interface{} 843 if strvalue, ok := s.colValue[n].(string); ok && strings.HasPrefix(strvalue, "?") { 844 if strvalue == "?" { 845 val = args[argPos].Value 846 } else { 847 // Assign value from argument placeholder name. 848 for _, a := range args { 849 if a.Name == strvalue[1:] { 850 val = a.Value 851 break 852 } 853 } 854 } 855 argPos++ 856 } else { 857 val = s.colValue[n] 858 } 859 if doInsert { 860 cols[colidx] = val 861 } 862 } 863 864 if doInsert { 865 t.rows = append(t.rows, &row{cols: cols}) 866 } 867 return driver.RowsAffected(1), nil 868 } 869 870 // hook to simulate broken connections 871 var hookQueryBadConn func() bool 872 873 func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) { 874 panic("Use QueryContext") 875 } 876 877 func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { 878 if s.panic == "Query" { 879 panic(s.panic) 880 } 881 if s.closed { 882 return nil, errClosed 883 } 884 885 if s.c.stickyBad || (hookQueryBadConn != nil && hookQueryBadConn()) { 886 return nil, fakeError{Message: "Query: Sticky Bad", Wrapped: driver.ErrBadConn} 887 } 888 if s.c.isDirtyAndMark() { 889 return nil, errFakeConnSessionDirty 890 } 891 892 err := checkSubsetTypes(s.c.db.allowAny, args) 893 if err != nil { 894 return nil, err 895 } 896 897 s.touchMem() 898 db := s.c.db 899 if len(args) != s.placeholders { 900 panic("error in pkg db; should only get here if size is correct") 901 } 902 903 setMRows := make([][]*row, 0, 1) 904 setColumns := make([][]string, 0, 1) 905 setColType := make([][]string, 0, 1) 906 907 for { 908 db.mu.Lock() 909 t, ok := db.table(s.table) 910 db.mu.Unlock() 911 if !ok { 912 return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table) 913 } 914 915 if s.table == "magicquery" { 916 if len(s.whereCol) == 2 && s.whereCol[0].Column == "op" && s.whereCol[1].Column == "millis" { 917 if args[0].Value == "sleep" { 918 time.Sleep(time.Duration(args[1].Value.(int64)) * time.Millisecond) 919 } 920 } 921 } 922 if s.table == "tx_status" && s.colName[0] == "tx_status" { 923 txStatus := "autocommit" 924 if s.c.currTx != nil { 925 txStatus = "transaction" 926 } 927 cursor := &rowsCursor{ 928 parentMem: s.c, 929 posRow: -1, 930 rows: [][]*row{ 931 { 932 { 933 cols: []interface{}{ 934 txStatus, 935 }, 936 }, 937 }, 938 }, 939 cols: [][]string{ 940 { 941 "tx_status", 942 }, 943 }, 944 colType: [][]string{ 945 { 946 "string", 947 }, 948 }, 949 errPos: -1, 950 } 951 return cursor, nil 952 } 953 954 t.mu.Lock() 955 956 colIdx := make(map[string]int) // select column name -> column index in table 957 for _, name := range s.colName { 958 idx := t.columnIndex(name) 959 if idx == -1 { 960 t.mu.Unlock() 961 return nil, fmt.Errorf("fakedb: unknown column name %q", name) 962 } 963 colIdx[name] = idx 964 } 965 966 mrows := []*row{} 967 rows: 968 for _, trow := range t.rows { 969 // Process the where clause, skipping non-match rows. This is lazy 970 // and just uses fmt.Sprintf("%v") to test equality. Good enough 971 // for test code. 972 for _, wcol := range s.whereCol { 973 idx := t.columnIndex(wcol.Column) 974 if idx == -1 { 975 t.mu.Unlock() 976 return nil, fmt.Errorf("fakedb: invalid where clause column %q", wcol) 977 } 978 tcol := trow.cols[idx] 979 if bs, ok := tcol.([]byte); ok { 980 // lazy hack to avoid sprintf %v on a []byte 981 tcol = string(bs) 982 } 983 var argValue interface{} 984 if wcol.Placeholder == "?" { 985 argValue = args[wcol.Ordinal-1].Value 986 } else { 987 // Assign arg value from placeholder name. 988 for _, a := range args { 989 if a.Name == wcol.Placeholder[1:] { 990 argValue = a.Value 991 break 992 } 993 } 994 } 995 if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) { 996 continue rows 997 } 998 } 999 mrow := &row{cols: make([]interface{}, len(s.colName))} 1000 for seli, name := range s.colName { 1001 mrow.cols[seli] = trow.cols[colIdx[name]] 1002 } 1003 mrows = append(mrows, mrow) 1004 } 1005 1006 var colType []string 1007 for _, column := range s.colName { 1008 colType = append(colType, t.coltype[t.columnIndex(column)]) 1009 } 1010 1011 t.mu.Unlock() 1012 1013 setMRows = append(setMRows, mrows) 1014 setColumns = append(setColumns, s.colName) 1015 setColType = append(setColType, colType) 1016 1017 if s.next == nil { 1018 break 1019 } 1020 s = s.next 1021 } 1022 1023 cursor := &rowsCursor{ 1024 parentMem: s.c, 1025 posRow: -1, 1026 rows: setMRows, 1027 cols: setColumns, 1028 colType: setColType, 1029 errPos: -1, 1030 } 1031 return cursor, nil 1032 } 1033 1034 func (s *fakeStmt) NumInput() int { 1035 if s.panic == "NumInput" { 1036 panic(s.panic) 1037 } 1038 return s.placeholders 1039 } 1040 1041 // hook to simulate broken connections 1042 var hookCommitBadConn func() bool 1043 1044 func (tx *fakeTx) Commit() error { 1045 tx.c.currTx = nil 1046 if hookCommitBadConn != nil && hookCommitBadConn() { 1047 return fakeError{Message: "Commit: Hook Bad Conn", Wrapped: driver.ErrBadConn} 1048 } 1049 tx.c.touchMem() 1050 return nil 1051 } 1052 1053 // hook to simulate broken connections 1054 var hookRollbackBadConn func() bool 1055 1056 func (tx *fakeTx) Rollback() error { 1057 tx.c.currTx = nil 1058 if hookRollbackBadConn != nil && hookRollbackBadConn() { 1059 return fakeError{Message: "Rollback: Hook Bad Conn", Wrapped: driver.ErrBadConn} 1060 } 1061 tx.c.touchMem() 1062 return nil 1063 } 1064 1065 type rowsCursor struct { 1066 parentMem memToucher 1067 cols [][]string 1068 colType [][]string 1069 posSet int 1070 posRow int 1071 rows [][]*row 1072 closed bool 1073 1074 // errPos and err are for making Next return early with error. 1075 errPos int 1076 err error 1077 1078 // a clone of slices to give out to clients, indexed by the 1079 // original slice's first byte address. we clone them 1080 // just so we're able to corrupt them on close. 1081 bytesClone map[*byte][]byte 1082 1083 // Every operation writes to line to enable the race detector 1084 // check for data races. 1085 // This is separate from the fakeConn.line to allow for drivers that 1086 // can start multiple queries on the same transaction at the same time. 1087 line int64 1088 } 1089 1090 func (rc *rowsCursor) touchMem() { 1091 rc.parentMem.touchMem() 1092 rc.line++ 1093 } 1094 1095 func (rc *rowsCursor) Close() error { 1096 rc.touchMem() 1097 rc.parentMem.touchMem() 1098 rc.closed = true 1099 return nil 1100 } 1101 1102 func (rc *rowsCursor) Columns() []string { 1103 return rc.cols[rc.posSet] 1104 } 1105 1106 func (rc *rowsCursor) ColumnTypeScanType(index int) reflect.Type { 1107 return colTypeToReflectType(rc.colType[rc.posSet][index]) 1108 } 1109 1110 var rowsCursorNextHook func(dest []driver.Value) error 1111 1112 func (rc *rowsCursor) Next(dest []driver.Value) error { 1113 if rowsCursorNextHook != nil { 1114 return rowsCursorNextHook(dest) 1115 } 1116 1117 if rc.closed { 1118 return errors.New("fakedb: cursor is closed") 1119 } 1120 rc.touchMem() 1121 rc.posRow++ 1122 if rc.posRow == rc.errPos { 1123 return rc.err 1124 } 1125 if rc.posRow >= len(rc.rows[rc.posSet]) { 1126 return io.EOF // per interface spec 1127 } 1128 for i, v := range rc.rows[rc.posSet][rc.posRow].cols { 1129 // TODO(bradfitz): convert to subset types? naah, I 1130 // think the subset types should only be input to 1131 // driver, but the sql package should be able to handle 1132 // a wider range of types coming out of drivers. all 1133 // for ease of drivers, and to prevent drivers from 1134 // messing up conversions or doing them differently. 1135 dest[i] = v 1136 1137 if bs, ok := v.([]byte); ok { 1138 if rc.bytesClone == nil { 1139 rc.bytesClone = make(map[*byte][]byte) 1140 } 1141 clone, ok := rc.bytesClone[&bs[0]] 1142 if !ok { 1143 clone = make([]byte, len(bs)) 1144 copy(clone, bs) 1145 rc.bytesClone[&bs[0]] = clone 1146 } 1147 dest[i] = clone 1148 } 1149 } 1150 return nil 1151 } 1152 1153 func (rc *rowsCursor) HasNextResultSet() bool { 1154 rc.touchMem() 1155 return rc.posSet < len(rc.rows)-1 1156 } 1157 1158 func (rc *rowsCursor) NextResultSet() error { 1159 rc.touchMem() 1160 if rc.HasNextResultSet() { 1161 rc.posSet++ 1162 rc.posRow = -1 1163 return nil 1164 } 1165 return io.EOF // Per interface spec. 1166 } 1167 1168 // fakeDriverString is like driver.String, but indirects pointers like 1169 // DefaultValueConverter. 1170 // 1171 // This could be surprising behavior to retroactively apply to 1172 // driver.String now that Go1 is out, but this is convenient for 1173 // our TestPointerParamsAndScans. 1174 // 1175 type fakeDriverString struct{} 1176 1177 func (fakeDriverString) ConvertValue(v interface{}) (driver.Value, error) { 1178 switch c := v.(type) { 1179 case string, []byte: 1180 return v, nil 1181 case *string: 1182 if c == nil { 1183 return nil, nil 1184 } 1185 return *c, nil 1186 } 1187 return fmt.Sprintf("%v", v), nil 1188 } 1189 1190 type anyTypeConverter struct{} 1191 1192 func (anyTypeConverter) ConvertValue(v interface{}) (driver.Value, error) { 1193 return v, nil 1194 } 1195 1196 func converterForType(typ string) driver.ValueConverter { 1197 switch typ { 1198 case "bool": 1199 return driver.Bool 1200 case "nullbool": 1201 return driver.Null{Converter: driver.Bool} 1202 case "byte", "int16": 1203 return driver.NotNull{Converter: driver.DefaultParameterConverter} 1204 case "int32": 1205 return driver.Int32 1206 case "nullbyte", "nullint32", "nullint16": 1207 return driver.Null{Converter: driver.DefaultParameterConverter} 1208 case "string": 1209 return driver.NotNull{Converter: fakeDriverString{}} 1210 case "nullstring": 1211 return driver.Null{Converter: fakeDriverString{}} 1212 case "int64": 1213 // TODO(coopernurse): add type-specific converter 1214 return driver.NotNull{Converter: driver.DefaultParameterConverter} 1215 case "nullint64": 1216 // TODO(coopernurse): add type-specific converter 1217 return driver.Null{Converter: driver.DefaultParameterConverter} 1218 case "float64": 1219 // TODO(coopernurse): add type-specific converter 1220 return driver.NotNull{Converter: driver.DefaultParameterConverter} 1221 case "nullfloat64": 1222 // TODO(coopernurse): add type-specific converter 1223 return driver.Null{Converter: driver.DefaultParameterConverter} 1224 case "datetime": 1225 return driver.NotNull{Converter: driver.DefaultParameterConverter} 1226 case "nulldatetime": 1227 return driver.Null{Converter: driver.DefaultParameterConverter} 1228 case "any": 1229 return anyTypeConverter{} 1230 } 1231 panic("invalid fakedb column type of " + typ) 1232 } 1233 1234 func colTypeToReflectType(typ string) reflect.Type { 1235 switch typ { 1236 case "bool": 1237 return reflect.TypeOf(false) 1238 case "nullbool": 1239 return reflect.TypeOf(NullBool{}) 1240 case "int16": 1241 return reflect.TypeOf(int16(0)) 1242 case "nullint16": 1243 return reflect.TypeOf(NullInt16{}) 1244 case "int32": 1245 return reflect.TypeOf(int32(0)) 1246 case "nullint32": 1247 return reflect.TypeOf(NullInt32{}) 1248 case "string": 1249 return reflect.TypeOf("") 1250 case "nullstring": 1251 return reflect.TypeOf(NullString{}) 1252 case "int64": 1253 return reflect.TypeOf(int64(0)) 1254 case "nullint64": 1255 return reflect.TypeOf(NullInt64{}) 1256 case "float64": 1257 return reflect.TypeOf(float64(0)) 1258 case "nullfloat64": 1259 return reflect.TypeOf(NullFloat64{}) 1260 case "datetime": 1261 return reflect.TypeOf(time.Time{}) 1262 case "any": 1263 return reflect.TypeOf(new(interface{})).Elem() 1264 } 1265 panic("invalid fakedb column type of " + typ) 1266 }