// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package sql

import (
	"context"
	"database/sql/driver"
	"errors"
	"fmt"
	"io"
	"reflect"
	"sort"
	"strconv"
	"strings"
	"sync"
	"testing"
	"time"
)

// fakeDriver is a fake database that implements Go's driver.Driver
// interface, just for testing.
//
// It speaks a query language that's semantically similar to but
// syntactically different and simpler than SQL. The syntax is as
// follows:
//
//   WIPE
//   CREATE|<tablename>|<col>=<type>,<col>=<type>,...
//     where types are: "string", [u]int{8,16,32,64}, "bool"
//   INSERT|<tablename>|col=val,col2=val2,col3=?
//   SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=?
//   SELECT|<tablename>|projectcol1,projectcol2|filtercol=?param1,filtercol2=?param2
//
// Any of these can be preceded by PANIC|<method>|, to cause the
// named method on fakeStmt to panic.
//
// Any of these can be preceded by WAIT|<duration>|, to cause preparing
// and executing the statement to sleep for the specified duration.
//
// Multiple of these can be combined when separated with a semicolon.
//
// When opening a fakeDriver's database, it starts empty with no
// tables. All tables and data are stored in memory only.
46 type fakeDriver struct { 47 mu sync.Mutex // guards 3 following fields 48 openCount int // conn opens 49 closeCount int // conn closes 50 waitCh chan struct{} 51 waitingCh chan struct{} 52 dbs map[string]*fakeDB 53 } 54 55 type fakeConnector struct { 56 name string 57 58 waiter func(context.Context) 59 } 60 61 func (c *fakeConnector) Connect(context.Context) (driver.Conn, error) { 62 conn, err := fdriver.Open(c.name) 63 conn.(*fakeConn).waiter = c.waiter 64 return conn, err 65 } 66 67 func (c *fakeConnector) Driver() driver.Driver { 68 return fdriver 69 } 70 71 type fakeDriverCtx struct { 72 fakeDriver 73 } 74 75 var _ driver.DriverContext = &fakeDriverCtx{} 76 77 func (cc *fakeDriverCtx) OpenConnector(name string) (driver.Connector, error) { 78 return &fakeConnector{name: name}, nil 79 } 80 81 type fakeDB struct { 82 name string 83 84 mu sync.Mutex 85 tables map[string]*table 86 badConn bool 87 allowAny bool 88 } 89 90 type table struct { 91 mu sync.Mutex 92 colname []string 93 coltype []string 94 rows []*row 95 } 96 97 func (t *table) columnIndex(name string) int { 98 for n, nname := range t.colname { 99 if name == nname { 100 return n 101 } 102 } 103 return -1 104 } 105 106 type row struct { 107 cols []interface{} // must be same size as its table colname + coltype 108 } 109 110 type memToucher interface { 111 // touchMem reads & writes some memory, to help find data races. 112 touchMem() 113 } 114 115 type fakeConn struct { 116 db *fakeDB // where to return ourselves to 117 118 currTx *fakeTx 119 120 // Every operation writes to line to enable the race detector 121 // check for data races. 122 line int64 123 124 // Stats for tests: 125 mu sync.Mutex 126 stmtsMade int 127 stmtsClosed int 128 numPrepare int 129 130 // bad connection tests; see isBad() 131 bad bool 132 stickyBad bool 133 134 skipDirtySession bool // tests that use Conn should set this to true. 
135 136 // dirtySession tests ResetSession, true if a query has executed 137 // until ResetSession is called. 138 dirtySession bool 139 140 // The waiter is called before each query. May be used in place of the "WAIT" 141 // directive. 142 waiter func(context.Context) 143 } 144 145 func (c *fakeConn) touchMem() { 146 c.line++ 147 } 148 149 func (c *fakeConn) incrStat(v *int) { 150 c.mu.Lock() 151 *v++ 152 c.mu.Unlock() 153 } 154 155 type fakeTx struct { 156 c *fakeConn 157 } 158 159 type boundCol struct { 160 Column string 161 Placeholder string 162 Ordinal int 163 } 164 165 type fakeStmt struct { 166 memToucher 167 c *fakeConn 168 q string // just for debugging 169 170 cmd string 171 table string 172 panic string 173 wait time.Duration 174 175 next *fakeStmt // used for returning multiple results. 176 177 closed bool 178 179 colName []string // used by CREATE, INSERT, SELECT (selected columns) 180 colType []string // used by CREATE 181 colValue []interface{} // used by INSERT (mix of strings and "?" for bound params) 182 placeholders int // used by INSERT/SELECT: number of ? 
params 183 184 whereCol []boundCol // used by SELECT (all placeholders) 185 186 placeholderConverter []driver.ValueConverter // used by INSERT 187 } 188 189 var fdriver driver.Driver = &fakeDriver{} 190 191 func init() { 192 Register("test", fdriver) 193 } 194 195 func contains(list []string, y string) bool { 196 for _, x := range list { 197 if x == y { 198 return true 199 } 200 } 201 return false 202 } 203 204 type Dummy struct { 205 driver.Driver 206 } 207 208 func TestDrivers(t *testing.T) { 209 unregisterAllDrivers() 210 Register("test", fdriver) 211 Register("invalid", Dummy{}) 212 all := Drivers() 213 if len(all) < 2 || !sort.StringsAreSorted(all) || !contains(all, "test") || !contains(all, "invalid") { 214 t.Fatalf("Drivers = %v, want sorted list with at least [invalid, test]", all) 215 } 216 } 217 218 // hook to simulate connection failures 219 var hookOpenErr struct { 220 sync.Mutex 221 fn func() error 222 } 223 224 func setHookOpenErr(fn func() error) { 225 hookOpenErr.Lock() 226 defer hookOpenErr.Unlock() 227 hookOpenErr.fn = fn 228 } 229 230 // Supports dsn forms: 231 // <dbname> 232 // <dbname>;<opts> (only currently supported option is `badConn`, 233 // which causes driver.ErrBadConn to be returned on 234 // every other conn.Begin()) 235 func (d *fakeDriver) Open(dsn string) (driver.Conn, error) { 236 hookOpenErr.Lock() 237 fn := hookOpenErr.fn 238 hookOpenErr.Unlock() 239 if fn != nil { 240 if err := fn(); err != nil { 241 return nil, err 242 } 243 } 244 parts := strings.Split(dsn, ";") 245 if len(parts) < 1 { 246 return nil, errors.New("fakedb: no database name") 247 } 248 name := parts[0] 249 250 db := d.getDB(name) 251 252 d.mu.Lock() 253 d.openCount++ 254 d.mu.Unlock() 255 conn := &fakeConn{db: db} 256 257 if len(parts) >= 2 && parts[1] == "badConn" { 258 conn.bad = true 259 } 260 if d.waitCh != nil { 261 d.waitingCh <- struct{}{} 262 <-d.waitCh 263 d.waitCh = nil 264 d.waitingCh = nil 265 } 266 return conn, nil 267 } 268 269 func (d *fakeDriver) 
getDB(name string) *fakeDB { 270 d.mu.Lock() 271 defer d.mu.Unlock() 272 if d.dbs == nil { 273 d.dbs = make(map[string]*fakeDB) 274 } 275 db, ok := d.dbs[name] 276 if !ok { 277 db = &fakeDB{name: name} 278 d.dbs[name] = db 279 } 280 return db 281 } 282 283 func (db *fakeDB) wipe() { 284 db.mu.Lock() 285 defer db.mu.Unlock() 286 db.tables = nil 287 } 288 289 func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error { 290 db.mu.Lock() 291 defer db.mu.Unlock() 292 if db.tables == nil { 293 db.tables = make(map[string]*table) 294 } 295 if _, exist := db.tables[name]; exist { 296 return fmt.Errorf("fakedb: table %q already exists", name) 297 } 298 if len(columnNames) != len(columnTypes) { 299 return fmt.Errorf("fakedb: create table of %q len(names) != len(types): %d vs %d", 300 name, len(columnNames), len(columnTypes)) 301 } 302 db.tables[name] = &table{colname: columnNames, coltype: columnTypes} 303 return nil 304 } 305 306 // must be called with db.mu lock held 307 func (db *fakeDB) table(table string) (*table, bool) { 308 if db.tables == nil { 309 return nil, false 310 } 311 t, ok := db.tables[table] 312 return t, ok 313 } 314 315 func (db *fakeDB) columnType(table, column string) (typ string, ok bool) { 316 db.mu.Lock() 317 defer db.mu.Unlock() 318 t, ok := db.table(table) 319 if !ok { 320 return 321 } 322 for n, cname := range t.colname { 323 if cname == column { 324 return t.coltype[n], true 325 } 326 } 327 return "", false 328 } 329 330 func (c *fakeConn) isBad() bool { 331 if c.stickyBad { 332 return true 333 } else if c.bad { 334 if c.db == nil { 335 return false 336 } 337 // alternate between bad conn and not bad conn 338 c.db.badConn = !c.db.badConn 339 return c.db.badConn 340 } else { 341 return false 342 } 343 } 344 345 func (c *fakeConn) isDirtyAndMark() bool { 346 if c.skipDirtySession { 347 return false 348 } 349 if c.currTx != nil { 350 c.dirtySession = true 351 return false 352 } 353 if c.dirtySession { 354 return true 355 } 
356 c.dirtySession = true 357 return false 358 } 359 360 func (c *fakeConn) Begin() (driver.Tx, error) { 361 if c.isBad() { 362 return nil, driver.ErrBadConn 363 } 364 if c.currTx != nil { 365 return nil, errors.New("fakedb: already in a transaction") 366 } 367 c.touchMem() 368 c.currTx = &fakeTx{c: c} 369 return c.currTx, nil 370 } 371 372 var hookPostCloseConn struct { 373 sync.Mutex 374 fn func(*fakeConn, error) 375 } 376 377 func setHookpostCloseConn(fn func(*fakeConn, error)) { 378 hookPostCloseConn.Lock() 379 defer hookPostCloseConn.Unlock() 380 hookPostCloseConn.fn = fn 381 } 382 383 var testStrictClose *testing.T 384 385 // setStrictFakeConnClose sets the t to Errorf on when fakeConn.Close 386 // fails to close. If nil, the check is disabled. 387 func setStrictFakeConnClose(t *testing.T) { 388 testStrictClose = t 389 } 390 391 func (c *fakeConn) ResetSession(ctx context.Context) error { 392 c.dirtySession = false 393 if c.isBad() { 394 return driver.ErrBadConn 395 } 396 return nil 397 } 398 399 func (c *fakeConn) Close() (err error) { 400 drv := fdriver.(*fakeDriver) 401 defer func() { 402 if err != nil && testStrictClose != nil { 403 testStrictClose.Errorf("failed to close a test fakeConn: %v", err) 404 } 405 hookPostCloseConn.Lock() 406 fn := hookPostCloseConn.fn 407 hookPostCloseConn.Unlock() 408 if fn != nil { 409 fn(c, err) 410 } 411 if err == nil { 412 drv.mu.Lock() 413 drv.closeCount++ 414 drv.mu.Unlock() 415 } 416 }() 417 c.touchMem() 418 if c.currTx != nil { 419 return errors.New("fakedb: can't close fakeConn; in a Transaction") 420 } 421 if c.db == nil { 422 return errors.New("fakedb: can't close fakeConn; already closed") 423 } 424 if c.stmtsMade > c.stmtsClosed { 425 return errors.New("fakedb: can't close; dangling statement(s)") 426 } 427 c.db = nil 428 return nil 429 } 430 431 func checkSubsetTypes(allowAny bool, args []driver.NamedValue) error { 432 for _, arg := range args { 433 switch arg.Value.(type) { 434 case int64, float64, bool, nil, 
[]byte, string, time.Time: 435 default: 436 if !allowAny { 437 return fmt.Errorf("fakedb: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value) 438 } 439 } 440 } 441 return nil 442 } 443 444 func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) { 445 // Ensure that ExecContext is called if available. 446 panic("ExecContext was not called.") 447 } 448 449 func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { 450 // This is an optional interface, but it's implemented here 451 // just to check that all the args are of the proper types. 452 // ErrSkip is returned so the caller acts as if we didn't 453 // implement this at all. 454 err := checkSubsetTypes(c.db.allowAny, args) 455 if err != nil { 456 return nil, err 457 } 458 return nil, driver.ErrSkip 459 } 460 461 func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) { 462 // Ensure that ExecContext is called if available. 463 panic("QueryContext was not called.") 464 } 465 466 func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { 467 // This is an optional interface, but it's implemented here 468 // just to check that all the args are of the proper types. 469 // ErrSkip is returned so the caller acts as if we didn't 470 // implement this at all. 471 err := checkSubsetTypes(c.db.allowAny, args) 472 if err != nil { 473 return nil, err 474 } 475 return nil, driver.ErrSkip 476 } 477 478 func errf(msg string, args ...interface{}) error { 479 return errors.New("fakedb: " + fmt.Sprintf(msg, args...)) 480 } 481 482 // parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=? 483 // (note that where columns must always contain ? 
marks, 484 // just a limitation for fakedb) 485 func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, error) { 486 if len(parts) != 3 { 487 stmt.Close() 488 return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts)) 489 } 490 stmt.table = parts[0] 491 492 stmt.colName = strings.Split(parts[1], ",") 493 for n, colspec := range strings.Split(parts[2], ",") { 494 if colspec == "" { 495 continue 496 } 497 nameVal := strings.Split(colspec, "=") 498 if len(nameVal) != 2 { 499 stmt.Close() 500 return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) 501 } 502 column, value := nameVal[0], nameVal[1] 503 _, ok := c.db.columnType(stmt.table, column) 504 if !ok { 505 stmt.Close() 506 return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column) 507 } 508 if !strings.HasPrefix(value, "?") { 509 stmt.Close() 510 return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark", 511 stmt.table, column) 512 } 513 stmt.placeholders++ 514 stmt.whereCol = append(stmt.whereCol, boundCol{Column: column, Placeholder: value, Ordinal: stmt.placeholders}) 515 } 516 return stmt, nil 517 } 518 519 // parts are table|col=type,col2=type2 520 func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (*fakeStmt, error) { 521 if len(parts) != 2 { 522 stmt.Close() 523 return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts)) 524 } 525 stmt.table = parts[0] 526 for n, colspec := range strings.Split(parts[1], ",") { 527 nameType := strings.Split(colspec, "=") 528 if len(nameType) != 2 { 529 stmt.Close() 530 return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) 531 } 532 stmt.colName = append(stmt.colName, nameType[0]) 533 stmt.colType = append(stmt.colType, nameType[1]) 534 } 535 return stmt, nil 536 } 537 538 // parts are table|col=?,col2=val 539 func (c *fakeConn) 
prepareInsert(ctx context.Context, stmt *fakeStmt, parts []string) (*fakeStmt, error) { 540 if len(parts) != 2 { 541 stmt.Close() 542 return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts)) 543 } 544 stmt.table = parts[0] 545 for n, colspec := range strings.Split(parts[1], ",") { 546 nameVal := strings.Split(colspec, "=") 547 if len(nameVal) != 2 { 548 stmt.Close() 549 return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) 550 } 551 column, value := nameVal[0], nameVal[1] 552 ctype, ok := c.db.columnType(stmt.table, column) 553 if !ok { 554 stmt.Close() 555 return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column) 556 } 557 stmt.colName = append(stmt.colName, column) 558 559 if !strings.HasPrefix(value, "?") { 560 var subsetVal interface{} 561 // Convert to driver subset type 562 switch ctype { 563 case "string": 564 subsetVal = []byte(value) 565 case "blob": 566 subsetVal = []byte(value) 567 case "int32": 568 i, err := strconv.Atoi(value) 569 if err != nil { 570 stmt.Close() 571 return nil, errf("invalid conversion to int32 from %q", value) 572 } 573 subsetVal = int64(i) // int64 is a subset type, but not int32 574 case "table": // For testing cursor reads. 
575 c.skipDirtySession = true 576 vparts := strings.Split(value, "!") 577 578 substmt, err := c.PrepareContext(ctx, fmt.Sprintf("SELECT|%s|%s|", vparts[0], strings.Join(vparts[1:], ","))) 579 if err != nil { 580 return nil, err 581 } 582 cursor, err := (substmt.(driver.StmtQueryContext)).QueryContext(ctx, []driver.NamedValue{}) 583 substmt.Close() 584 if err != nil { 585 return nil, err 586 } 587 subsetVal = cursor 588 default: 589 stmt.Close() 590 return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype) 591 } 592 stmt.colValue = append(stmt.colValue, subsetVal) 593 } else { 594 stmt.placeholders++ 595 stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype)) 596 stmt.colValue = append(stmt.colValue, value) 597 } 598 } 599 return stmt, nil 600 } 601 602 // hook to simulate broken connections 603 var hookPrepareBadConn func() bool 604 605 func (c *fakeConn) Prepare(query string) (driver.Stmt, error) { 606 panic("use PrepareContext") 607 } 608 609 func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { 610 c.numPrepare++ 611 if c.db == nil { 612 panic("nil c.db; conn = " + fmt.Sprintf("%#v", c)) 613 } 614 615 if c.stickyBad || (hookPrepareBadConn != nil && hookPrepareBadConn()) { 616 return nil, driver.ErrBadConn 617 } 618 619 c.touchMem() 620 var firstStmt, prev *fakeStmt 621 for _, query := range strings.Split(query, ";") { 622 parts := strings.Split(query, "|") 623 if len(parts) < 1 { 624 return nil, errf("empty query") 625 } 626 stmt := &fakeStmt{q: query, c: c, memToucher: c} 627 if firstStmt == nil { 628 firstStmt = stmt 629 } 630 if len(parts) >= 3 { 631 switch parts[0] { 632 case "PANIC": 633 stmt.panic = parts[1] 634 parts = parts[2:] 635 case "WAIT": 636 wait, err := time.ParseDuration(parts[1]) 637 if err != nil { 638 return nil, errf("expected section after WAIT to be a duration, got %q %v", parts[1], err) 639 } 640 parts = parts[2:] 641 stmt.wait = wait 
642 } 643 } 644 cmd := parts[0] 645 stmt.cmd = cmd 646 parts = parts[1:] 647 648 if c.waiter != nil { 649 c.waiter(ctx) 650 } 651 652 if stmt.wait > 0 { 653 wait := time.NewTimer(stmt.wait) 654 select { 655 case <-wait.C: 656 case <-ctx.Done(): 657 wait.Stop() 658 return nil, ctx.Err() 659 } 660 } 661 662 c.incrStat(&c.stmtsMade) 663 var err error 664 switch cmd { 665 case "WIPE": 666 // Nothing 667 case "SELECT": 668 stmt, err = c.prepareSelect(stmt, parts) 669 case "CREATE": 670 stmt, err = c.prepareCreate(stmt, parts) 671 case "INSERT": 672 stmt, err = c.prepareInsert(ctx, stmt, parts) 673 case "NOSERT": 674 // Do all the prep-work like for an INSERT but don't actually insert the row. 675 // Used for some of the concurrent tests. 676 stmt, err = c.prepareInsert(ctx, stmt, parts) 677 default: 678 stmt.Close() 679 return nil, errf("unsupported command type %q", cmd) 680 } 681 if err != nil { 682 return nil, err 683 } 684 if prev != nil { 685 prev.next = stmt 686 } 687 prev = stmt 688 } 689 return firstStmt, nil 690 } 691 692 func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter { 693 if s.panic == "ColumnConverter" { 694 panic(s.panic) 695 } 696 if len(s.placeholderConverter) == 0 { 697 return driver.DefaultParameterConverter 698 } 699 return s.placeholderConverter[idx] 700 } 701 702 func (s *fakeStmt) Close() error { 703 if s.panic == "Close" { 704 panic(s.panic) 705 } 706 if s.c == nil { 707 panic("nil conn in fakeStmt.Close") 708 } 709 if s.c.db == nil { 710 panic("in fakeStmt.Close, conn's db is nil (already closed)") 711 } 712 s.touchMem() 713 if !s.closed { 714 s.c.incrStat(&s.c.stmtsClosed) 715 s.closed = true 716 } 717 if s.next != nil { 718 s.next.Close() 719 } 720 return nil 721 } 722 723 var errClosed = errors.New("fakedb: statement has been closed") 724 725 // hook to simulate broken connections 726 var hookExecBadConn func() bool 727 728 func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) { 729 panic("Using ExecContext") 
730 } 731 func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { 732 if s.panic == "Exec" { 733 panic(s.panic) 734 } 735 if s.closed { 736 return nil, errClosed 737 } 738 739 if s.c.stickyBad || (hookExecBadConn != nil && hookExecBadConn()) { 740 return nil, driver.ErrBadConn 741 } 742 if s.c.isDirtyAndMark() { 743 return nil, errors.New("fakedb: session is dirty") 744 } 745 746 err := checkSubsetTypes(s.c.db.allowAny, args) 747 if err != nil { 748 return nil, err 749 } 750 s.touchMem() 751 752 if s.wait > 0 { 753 time.Sleep(s.wait) 754 } 755 756 select { 757 default: 758 case <-ctx.Done(): 759 return nil, ctx.Err() 760 } 761 762 db := s.c.db 763 switch s.cmd { 764 case "WIPE": 765 db.wipe() 766 return driver.ResultNoRows, nil 767 case "CREATE": 768 if err := db.createTable(s.table, s.colName, s.colType); err != nil { 769 return nil, err 770 } 771 return driver.ResultNoRows, nil 772 case "INSERT": 773 return s.execInsert(args, true) 774 case "NOSERT": 775 // Do all the prep-work like for an INSERT but don't actually insert the row. 776 // Used for some of the concurrent tests. 777 return s.execInsert(args, false) 778 } 779 return nil, fmt.Errorf("fakedb: unimplemented statement Exec command type of %q", s.cmd) 780 } 781 782 // When doInsert is true, add the row to the table. 783 // When doInsert is false do prep-work and error checking, but don't 784 // actually add the row to the table. 
785 func (s *fakeStmt) execInsert(args []driver.NamedValue, doInsert bool) (driver.Result, error) { 786 db := s.c.db 787 if len(args) != s.placeholders { 788 panic("error in pkg db; should only get here if size is correct") 789 } 790 db.mu.Lock() 791 t, ok := db.table(s.table) 792 db.mu.Unlock() 793 if !ok { 794 return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table) 795 } 796 797 t.mu.Lock() 798 defer t.mu.Unlock() 799 800 var cols []interface{} 801 if doInsert { 802 cols = make([]interface{}, len(t.colname)) 803 } 804 argPos := 0 805 for n, colname := range s.colName { 806 colidx := t.columnIndex(colname) 807 if colidx == -1 { 808 return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname) 809 } 810 var val interface{} 811 if strvalue, ok := s.colValue[n].(string); ok && strings.HasPrefix(strvalue, "?") { 812 if strvalue == "?" { 813 val = args[argPos].Value 814 } else { 815 // Assign value from argument placeholder name. 
816 for _, a := range args { 817 if a.Name == strvalue[1:] { 818 val = a.Value 819 break 820 } 821 } 822 } 823 argPos++ 824 } else { 825 val = s.colValue[n] 826 } 827 if doInsert { 828 cols[colidx] = val 829 } 830 } 831 832 if doInsert { 833 t.rows = append(t.rows, &row{cols: cols}) 834 } 835 return driver.RowsAffected(1), nil 836 } 837 838 // hook to simulate broken connections 839 var hookQueryBadConn func() bool 840 841 func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) { 842 panic("Use QueryContext") 843 } 844 845 func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { 846 if s.panic == "Query" { 847 panic(s.panic) 848 } 849 if s.closed { 850 return nil, errClosed 851 } 852 853 if s.c.stickyBad || (hookQueryBadConn != nil && hookQueryBadConn()) { 854 return nil, driver.ErrBadConn 855 } 856 if s.c.isDirtyAndMark() { 857 return nil, errors.New("fakedb: session is dirty") 858 } 859 860 err := checkSubsetTypes(s.c.db.allowAny, args) 861 if err != nil { 862 return nil, err 863 } 864 865 s.touchMem() 866 db := s.c.db 867 if len(args) != s.placeholders { 868 panic("error in pkg db; should only get here if size is correct") 869 } 870 871 setMRows := make([][]*row, 0, 1) 872 setColumns := make([][]string, 0, 1) 873 setColType := make([][]string, 0, 1) 874 875 for { 876 db.mu.Lock() 877 t, ok := db.table(s.table) 878 db.mu.Unlock() 879 if !ok { 880 return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table) 881 } 882 883 if s.table == "magicquery" { 884 if len(s.whereCol) == 2 && s.whereCol[0].Column == "op" && s.whereCol[1].Column == "millis" { 885 if args[0].Value == "sleep" { 886 time.Sleep(time.Duration(args[1].Value.(int64)) * time.Millisecond) 887 } 888 } 889 } 890 891 t.mu.Lock() 892 893 colIdx := make(map[string]int) // select column name -> column index in table 894 for _, name := range s.colName { 895 idx := t.columnIndex(name) 896 if idx == -1 { 897 t.mu.Unlock() 898 return nil, 
fmt.Errorf("fakedb: unknown column name %q", name) 899 } 900 colIdx[name] = idx 901 } 902 903 mrows := []*row{} 904 rows: 905 for _, trow := range t.rows { 906 // Process the where clause, skipping non-match rows. This is lazy 907 // and just uses fmt.Sprintf("%v") to test equality. Good enough 908 // for test code. 909 for _, wcol := range s.whereCol { 910 idx := t.columnIndex(wcol.Column) 911 if idx == -1 { 912 t.mu.Unlock() 913 return nil, fmt.Errorf("fakedb: invalid where clause column %q", wcol) 914 } 915 tcol := trow.cols[idx] 916 if bs, ok := tcol.([]byte); ok { 917 // lazy hack to avoid sprintf %v on a []byte 918 tcol = string(bs) 919 } 920 var argValue interface{} 921 if wcol.Placeholder == "?" { 922 argValue = args[wcol.Ordinal-1].Value 923 } else { 924 // Assign arg value from placeholder name. 925 for _, a := range args { 926 if a.Name == wcol.Placeholder[1:] { 927 argValue = a.Value 928 break 929 } 930 } 931 } 932 if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) { 933 continue rows 934 } 935 } 936 mrow := &row{cols: make([]interface{}, len(s.colName))} 937 for seli, name := range s.colName { 938 mrow.cols[seli] = trow.cols[colIdx[name]] 939 } 940 mrows = append(mrows, mrow) 941 } 942 943 var colType []string 944 for _, column := range s.colName { 945 colType = append(colType, t.coltype[t.columnIndex(column)]) 946 } 947 948 t.mu.Unlock() 949 950 setMRows = append(setMRows, mrows) 951 setColumns = append(setColumns, s.colName) 952 setColType = append(setColType, colType) 953 954 if s.next == nil { 955 break 956 } 957 s = s.next 958 } 959 960 cursor := &rowsCursor{ 961 parentMem: s.c, 962 posRow: -1, 963 rows: setMRows, 964 cols: setColumns, 965 colType: setColType, 966 errPos: -1, 967 } 968 return cursor, nil 969 } 970 971 func (s *fakeStmt) NumInput() int { 972 if s.panic == "NumInput" { 973 panic(s.panic) 974 } 975 return s.placeholders 976 } 977 978 // hook to simulate broken connections 979 var hookCommitBadConn func() bool 980 981 func (tx 
*fakeTx) Commit() error { 982 tx.c.currTx = nil 983 if hookCommitBadConn != nil && hookCommitBadConn() { 984 return driver.ErrBadConn 985 } 986 tx.c.touchMem() 987 return nil 988 } 989 990 // hook to simulate broken connections 991 var hookRollbackBadConn func() bool 992 993 func (tx *fakeTx) Rollback() error { 994 tx.c.currTx = nil 995 if hookRollbackBadConn != nil && hookRollbackBadConn() { 996 return driver.ErrBadConn 997 } 998 tx.c.touchMem() 999 return nil 1000 } 1001 1002 type rowsCursor struct { 1003 parentMem memToucher 1004 cols [][]string 1005 colType [][]string 1006 posSet int 1007 posRow int 1008 rows [][]*row 1009 closed bool 1010 1011 // errPos and err are for making Next return early with error. 1012 errPos int 1013 err error 1014 1015 // a clone of slices to give out to clients, indexed by the 1016 // original slice's first byte address. we clone them 1017 // just so we're able to corrupt them on close. 1018 bytesClone map[*byte][]byte 1019 1020 // Every operation writes to line to enable the race detector 1021 // check for data races. 1022 // This is separate from the fakeConn.line to allow for drivers that 1023 // can start multiple queries on the same transaction at the same time. 
1024 line int64 1025 } 1026 1027 func (rc *rowsCursor) touchMem() { 1028 rc.parentMem.touchMem() 1029 rc.line++ 1030 } 1031 1032 func (rc *rowsCursor) Close() error { 1033 rc.touchMem() 1034 rc.parentMem.touchMem() 1035 rc.closed = true 1036 return nil 1037 } 1038 1039 func (rc *rowsCursor) Columns() []string { 1040 return rc.cols[rc.posSet] 1041 } 1042 1043 func (rc *rowsCursor) ColumnTypeScanType(index int) reflect.Type { 1044 return colTypeToReflectType(rc.colType[rc.posSet][index]) 1045 } 1046 1047 var rowsCursorNextHook func(dest []driver.Value) error 1048 1049 func (rc *rowsCursor) Next(dest []driver.Value) error { 1050 if rowsCursorNextHook != nil { 1051 return rowsCursorNextHook(dest) 1052 } 1053 1054 if rc.closed { 1055 return errors.New("fakedb: cursor is closed") 1056 } 1057 rc.touchMem() 1058 rc.posRow++ 1059 if rc.posRow == rc.errPos { 1060 return rc.err 1061 } 1062 if rc.posRow >= len(rc.rows[rc.posSet]) { 1063 return io.EOF // per interface spec 1064 } 1065 for i, v := range rc.rows[rc.posSet][rc.posRow].cols { 1066 // TODO(bradfitz): convert to subset types? naah, I 1067 // think the subset types should only be input to 1068 // driver, but the sql package should be able to handle 1069 // a wider range of types coming out of drivers. all 1070 // for ease of drivers, and to prevent drivers from 1071 // messing up conversions or doing them differently. 
1072 dest[i] = v 1073 1074 if bs, ok := v.([]byte); ok { 1075 if rc.bytesClone == nil { 1076 rc.bytesClone = make(map[*byte][]byte) 1077 } 1078 clone, ok := rc.bytesClone[&bs[0]] 1079 if !ok { 1080 clone = make([]byte, len(bs)) 1081 copy(clone, bs) 1082 rc.bytesClone[&bs[0]] = clone 1083 } 1084 dest[i] = clone 1085 } 1086 } 1087 return nil 1088 } 1089 1090 func (rc *rowsCursor) HasNextResultSet() bool { 1091 rc.touchMem() 1092 return rc.posSet < len(rc.rows)-1 1093 } 1094 1095 func (rc *rowsCursor) NextResultSet() error { 1096 rc.touchMem() 1097 if rc.HasNextResultSet() { 1098 rc.posSet++ 1099 rc.posRow = -1 1100 return nil 1101 } 1102 return io.EOF // Per interface spec. 1103 } 1104 1105 // fakeDriverString is like driver.String, but indirects pointers like 1106 // DefaultValueConverter. 1107 // 1108 // This could be surprising behavior to retroactively apply to 1109 // driver.String now that Go1 is out, but this is convenient for 1110 // our TestPointerParamsAndScans. 1111 // 1112 type fakeDriverString struct{} 1113 1114 func (fakeDriverString) ConvertValue(v interface{}) (driver.Value, error) { 1115 switch c := v.(type) { 1116 case string, []byte: 1117 return v, nil 1118 case *string: 1119 if c == nil { 1120 return nil, nil 1121 } 1122 return *c, nil 1123 } 1124 return fmt.Sprintf("%v", v), nil 1125 } 1126 1127 type anyTypeConverter struct{} 1128 1129 func (anyTypeConverter) ConvertValue(v interface{}) (driver.Value, error) { 1130 return v, nil 1131 } 1132 1133 func converterForType(typ string) driver.ValueConverter { 1134 switch typ { 1135 case "bool": 1136 return driver.Bool 1137 case "nullbool": 1138 return driver.Null{Converter: driver.Bool} 1139 case "int32": 1140 return driver.Int32 1141 case "nullint32": 1142 return driver.Null{Converter: driver.DefaultParameterConverter} 1143 case "string": 1144 return driver.NotNull{Converter: fakeDriverString{}} 1145 case "nullstring": 1146 return driver.Null{Converter: fakeDriverString{}} 1147 case "int64": 1148 // 
TODO(coopernurse): add type-specific converter 1149 return driver.NotNull{Converter: driver.DefaultParameterConverter} 1150 case "nullint64": 1151 // TODO(coopernurse): add type-specific converter 1152 return driver.Null{Converter: driver.DefaultParameterConverter} 1153 case "float64": 1154 // TODO(coopernurse): add type-specific converter 1155 return driver.NotNull{Converter: driver.DefaultParameterConverter} 1156 case "nullfloat64": 1157 // TODO(coopernurse): add type-specific converter 1158 return driver.Null{Converter: driver.DefaultParameterConverter} 1159 case "datetime": 1160 return driver.NotNull{Converter: driver.DefaultParameterConverter} 1161 case "nulldatetime": 1162 return driver.Null{Converter: driver.DefaultParameterConverter} 1163 case "any": 1164 return anyTypeConverter{} 1165 } 1166 panic("invalid fakedb column type of " + typ) 1167 } 1168 1169 func colTypeToReflectType(typ string) reflect.Type { 1170 switch typ { 1171 case "bool": 1172 return reflect.TypeOf(false) 1173 case "nullbool": 1174 return reflect.TypeOf(NullBool{}) 1175 case "int32": 1176 return reflect.TypeOf(int32(0)) 1177 case "nullint32": 1178 return reflect.TypeOf(NullInt32{}) 1179 case "string": 1180 return reflect.TypeOf("") 1181 case "nullstring": 1182 return reflect.TypeOf(NullString{}) 1183 case "int64": 1184 return reflect.TypeOf(int64(0)) 1185 case "nullint64": 1186 return reflect.TypeOf(NullInt64{}) 1187 case "float64": 1188 return reflect.TypeOf(float64(0)) 1189 case "nullfloat64": 1190 return reflect.TypeOf(NullFloat64{}) 1191 case "datetime": 1192 return reflect.TypeOf(time.Time{}) 1193 case "any": 1194 return reflect.TypeOf(new(interface{})).Elem() 1195 } 1196 panic("invalid fakedb column type of " + typ) 1197 }