github.com/rakyll/go@v0.0.0-20170216000551-64c02460d703/src/database/sql/fakedb_test.go (about) 1 // Copyright 2011 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package sql 6 7 import ( 8 "context" 9 "database/sql/driver" 10 "errors" 11 "fmt" 12 "io" 13 "log" 14 "reflect" 15 "sort" 16 "strconv" 17 "strings" 18 "sync" 19 "testing" 20 "time" 21 ) 22 23 var _ = log.Printf 24 25 // fakeDriver is a fake database that implements Go's driver.Driver 26 // interface, just for testing. 27 // 28 // It speaks a query language that's semantically similar to but 29 // syntactically different and simpler than SQL. The syntax is as 30 // follows: 31 // 32 // WIPE 33 // CREATE|<tablename>|<col>=<type>,<col>=<type>,... 34 // where types are: "string", [u]int{8,16,32,64}, "bool" 35 // INSERT|<tablename>|col=val,col2=val2,col3=? 36 // SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=? 37 // SELECT|<tablename>|projectcol1,projectcol2|filtercol=?param1,filtercol2=?param2 38 // 39 // Any of these can be preceded by PANIC|<method>|, to cause the 40 // named method on fakeStmt to panic. 41 // 42 // Any of these can be preceded by WAIT|<duration>|, to cause 43 // preparing and executing the statement to wait for the specified duration. 44 // 45 // Multiple of these can be combined when separated with a semicolon. 46 // 47 // When opening a fakeDriver's database, it starts empty with no 48 // tables. All tables and data are stored in memory only. 
49 type fakeDriver struct { 50 mu sync.Mutex // guards 3 following fields 51 openCount int // conn opens 52 closeCount int // conn closes 53 waitCh chan struct{} 54 waitingCh chan struct{} 55 dbs map[string]*fakeDB 56 } 57 58 type fakeDB struct { 59 name string 60 61 mu sync.Mutex 62 tables map[string]*table 63 badConn bool 64 } 65 66 type table struct { 67 mu sync.Mutex 68 colname []string 69 coltype []string 70 rows []*row 71 } 72 73 func (t *table) columnIndex(name string) int { 74 for n, nname := range t.colname { 75 if name == nname { 76 return n 77 } 78 } 79 return -1 80 } 81 82 type row struct { 83 cols []interface{} // must be same size as its table colname + coltype 84 } 85 86 type fakeConn struct { 87 db *fakeDB // where to return ourselves to 88 89 currTx *fakeTx 90 91 // Stats for tests: 92 mu sync.Mutex 93 stmtsMade int 94 stmtsClosed int 95 numPrepare int 96 97 // bad connection tests; see isBad() 98 bad bool 99 stickyBad bool 100 } 101 102 func (c *fakeConn) incrStat(v *int) { 103 c.mu.Lock() 104 *v++ 105 c.mu.Unlock() 106 } 107 108 type fakeTx struct { 109 c *fakeConn 110 } 111 112 type boundCol struct { 113 Column string 114 Placeholder string 115 Ordinal int 116 } 117 118 type fakeStmt struct { 119 c *fakeConn 120 q string // just for debugging 121 122 cmd string 123 table string 124 panic string 125 wait time.Duration 126 127 next *fakeStmt // used for returning multiple results. 128 129 closed bool 130 131 colName []string // used by CREATE, INSERT, SELECT (selected columns) 132 colType []string // used by CREATE 133 colValue []interface{} // used by INSERT (mix of strings and "?" for bound params) 134 placeholders int // used by INSERT/SELECT: number of ? 
params 135 136 whereCol []boundCol // used by SELECT (all placeholders) 137 138 placeholderConverter []driver.ValueConverter // used by INSERT 139 } 140 141 var fdriver driver.Driver = &fakeDriver{} 142 143 func init() { 144 Register("test", fdriver) 145 } 146 147 func contains(list []string, y string) bool { 148 for _, x := range list { 149 if x == y { 150 return true 151 } 152 } 153 return false 154 } 155 156 type Dummy struct { 157 driver.Driver 158 } 159 160 func TestDrivers(t *testing.T) { 161 unregisterAllDrivers() 162 Register("test", fdriver) 163 Register("invalid", Dummy{}) 164 all := Drivers() 165 if len(all) < 2 || !sort.StringsAreSorted(all) || !contains(all, "test") || !contains(all, "invalid") { 166 t.Fatalf("Drivers = %v, want sorted list with at least [invalid, test]", all) 167 } 168 } 169 170 // hook to simulate connection failures 171 var hookOpenErr struct { 172 sync.Mutex 173 fn func() error 174 } 175 176 func setHookOpenErr(fn func() error) { 177 hookOpenErr.Lock() 178 defer hookOpenErr.Unlock() 179 hookOpenErr.fn = fn 180 } 181 182 // Supports dsn forms: 183 // <dbname> 184 // <dbname>;<opts> (only currently supported option is `badConn`, 185 // which causes driver.ErrBadConn to be returned on 186 // every other conn.Begin()) 187 func (d *fakeDriver) Open(dsn string) (driver.Conn, error) { 188 hookOpenErr.Lock() 189 fn := hookOpenErr.fn 190 hookOpenErr.Unlock() 191 if fn != nil { 192 if err := fn(); err != nil { 193 return nil, err 194 } 195 } 196 parts := strings.Split(dsn, ";") 197 if len(parts) < 1 { 198 return nil, errors.New("fakedb: no database name") 199 } 200 name := parts[0] 201 202 db := d.getDB(name) 203 204 d.mu.Lock() 205 d.openCount++ 206 d.mu.Unlock() 207 conn := &fakeConn{db: db} 208 209 if len(parts) >= 2 && parts[1] == "badConn" { 210 conn.bad = true 211 } 212 if d.waitCh != nil { 213 d.waitingCh <- struct{}{} 214 <-d.waitCh 215 d.waitCh = nil 216 d.waitingCh = nil 217 } 218 return conn, nil 219 } 220 221 func (d *fakeDriver) 
getDB(name string) *fakeDB { 222 d.mu.Lock() 223 defer d.mu.Unlock() 224 if d.dbs == nil { 225 d.dbs = make(map[string]*fakeDB) 226 } 227 db, ok := d.dbs[name] 228 if !ok { 229 db = &fakeDB{name: name} 230 d.dbs[name] = db 231 } 232 return db 233 } 234 235 func (db *fakeDB) wipe() { 236 db.mu.Lock() 237 defer db.mu.Unlock() 238 db.tables = nil 239 } 240 241 func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error { 242 db.mu.Lock() 243 defer db.mu.Unlock() 244 if db.tables == nil { 245 db.tables = make(map[string]*table) 246 } 247 if _, exist := db.tables[name]; exist { 248 return fmt.Errorf("table %q already exists", name) 249 } 250 if len(columnNames) != len(columnTypes) { 251 return fmt.Errorf("create table of %q len(names) != len(types): %d vs %d", 252 name, len(columnNames), len(columnTypes)) 253 } 254 db.tables[name] = &table{colname: columnNames, coltype: columnTypes} 255 return nil 256 } 257 258 // must be called with db.mu lock held 259 func (db *fakeDB) table(table string) (*table, bool) { 260 if db.tables == nil { 261 return nil, false 262 } 263 t, ok := db.tables[table] 264 return t, ok 265 } 266 267 func (db *fakeDB) columnType(table, column string) (typ string, ok bool) { 268 db.mu.Lock() 269 defer db.mu.Unlock() 270 t, ok := db.table(table) 271 if !ok { 272 return 273 } 274 for n, cname := range t.colname { 275 if cname == column { 276 return t.coltype[n], true 277 } 278 } 279 return "", false 280 } 281 282 func (c *fakeConn) isBad() bool { 283 if c.stickyBad { 284 return true 285 } else if c.bad { 286 // alternate between bad conn and not bad conn 287 c.db.badConn = !c.db.badConn 288 return c.db.badConn 289 } else { 290 return false 291 } 292 } 293 294 func (c *fakeConn) Begin() (driver.Tx, error) { 295 if c.isBad() { 296 return nil, driver.ErrBadConn 297 } 298 if c.currTx != nil { 299 return nil, errors.New("already in a transaction") 300 } 301 c.currTx = &fakeTx{c: c} 302 return c.currTx, nil 303 } 304 305 var 
hookPostCloseConn struct { 306 sync.Mutex 307 fn func(*fakeConn, error) 308 } 309 310 func setHookpostCloseConn(fn func(*fakeConn, error)) { 311 hookPostCloseConn.Lock() 312 defer hookPostCloseConn.Unlock() 313 hookPostCloseConn.fn = fn 314 } 315 316 var testStrictClose *testing.T 317 318 // setStrictFakeConnClose sets the t to Errorf on when fakeConn.Close 319 // fails to close. If nil, the check is disabled. 320 func setStrictFakeConnClose(t *testing.T) { 321 testStrictClose = t 322 } 323 324 func (c *fakeConn) Close() (err error) { 325 drv := fdriver.(*fakeDriver) 326 defer func() { 327 if err != nil && testStrictClose != nil { 328 testStrictClose.Errorf("failed to close a test fakeConn: %v", err) 329 } 330 hookPostCloseConn.Lock() 331 fn := hookPostCloseConn.fn 332 hookPostCloseConn.Unlock() 333 if fn != nil { 334 fn(c, err) 335 } 336 if err == nil { 337 drv.mu.Lock() 338 drv.closeCount++ 339 drv.mu.Unlock() 340 } 341 }() 342 if c.currTx != nil { 343 return errors.New("can't close fakeConn; in a Transaction") 344 } 345 if c.db == nil { 346 return errors.New("can't close fakeConn; already closed") 347 } 348 if c.stmtsMade > c.stmtsClosed { 349 return errors.New("can't close; dangling statement(s)") 350 } 351 c.db = nil 352 return nil 353 } 354 355 func checkSubsetTypes(args []driver.NamedValue) error { 356 for _, arg := range args { 357 switch arg.Value.(type) { 358 case int64, float64, bool, nil, []byte, string, time.Time: 359 default: 360 return fmt.Errorf("fakedb_test: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value) 361 } 362 } 363 return nil 364 } 365 366 func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) { 367 // Ensure that ExecContext is called if available. 
368 panic("ExecContext was not called.") 369 } 370 371 func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { 372 // This is an optional interface, but it's implemented here 373 // just to check that all the args are of the proper types. 374 // ErrSkip is returned so the caller acts as if we didn't 375 // implement this at all. 376 err := checkSubsetTypes(args) 377 if err != nil { 378 return nil, err 379 } 380 return nil, driver.ErrSkip 381 } 382 383 func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) { 384 // Ensure that ExecContext is called if available. 385 panic("QueryContext was not called.") 386 } 387 388 func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { 389 // This is an optional interface, but it's implemented here 390 // just to check that all the args are of the proper types. 391 // ErrSkip is returned so the caller acts as if we didn't 392 // implement this at all. 393 err := checkSubsetTypes(args) 394 if err != nil { 395 return nil, err 396 } 397 return nil, driver.ErrSkip 398 } 399 400 func errf(msg string, args ...interface{}) error { 401 return errors.New("fakedb: " + fmt.Sprintf(msg, args...)) 402 } 403 404 // parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=? 405 // (note that where columns must always contain ? 
marks, 406 // just a limitation for fakedb) 407 func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, error) { 408 if len(parts) != 3 { 409 stmt.Close() 410 return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts)) 411 } 412 stmt.table = parts[0] 413 414 stmt.colName = strings.Split(parts[1], ",") 415 for n, colspec := range strings.Split(parts[2], ",") { 416 if colspec == "" { 417 continue 418 } 419 nameVal := strings.Split(colspec, "=") 420 if len(nameVal) != 2 { 421 stmt.Close() 422 return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) 423 } 424 column, value := nameVal[0], nameVal[1] 425 _, ok := c.db.columnType(stmt.table, column) 426 if !ok { 427 stmt.Close() 428 return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column) 429 } 430 if !strings.HasPrefix(value, "?") { 431 stmt.Close() 432 return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark", 433 stmt.table, column) 434 } 435 stmt.placeholders++ 436 stmt.whereCol = append(stmt.whereCol, boundCol{Column: column, Placeholder: value, Ordinal: stmt.placeholders}) 437 } 438 return stmt, nil 439 } 440 441 // parts are table|col=type,col2=type2 442 func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (*fakeStmt, error) { 443 if len(parts) != 2 { 444 stmt.Close() 445 return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts)) 446 } 447 stmt.table = parts[0] 448 for n, colspec := range strings.Split(parts[1], ",") { 449 nameType := strings.Split(colspec, "=") 450 if len(nameType) != 2 { 451 stmt.Close() 452 return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) 453 } 454 stmt.colName = append(stmt.colName, nameType[0]) 455 stmt.colType = append(stmt.colType, nameType[1]) 456 } 457 return stmt, nil 458 } 459 460 // parts are table|col=?,col2=val 461 func (c *fakeConn) 
prepareInsert(stmt *fakeStmt, parts []string) (*fakeStmt, error) { 462 if len(parts) != 2 { 463 stmt.Close() 464 return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts)) 465 } 466 stmt.table = parts[0] 467 for n, colspec := range strings.Split(parts[1], ",") { 468 nameVal := strings.Split(colspec, "=") 469 if len(nameVal) != 2 { 470 stmt.Close() 471 return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) 472 } 473 column, value := nameVal[0], nameVal[1] 474 ctype, ok := c.db.columnType(stmt.table, column) 475 if !ok { 476 stmt.Close() 477 return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column) 478 } 479 stmt.colName = append(stmt.colName, column) 480 481 if !strings.HasPrefix(value, "?") { 482 var subsetVal interface{} 483 // Convert to driver subset type 484 switch ctype { 485 case "string": 486 subsetVal = []byte(value) 487 case "blob": 488 subsetVal = []byte(value) 489 case "int32": 490 i, err := strconv.Atoi(value) 491 if err != nil { 492 stmt.Close() 493 return nil, errf("invalid conversion to int32 from %q", value) 494 } 495 subsetVal = int64(i) // int64 is a subset type, but not int32 496 default: 497 stmt.Close() 498 return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype) 499 } 500 stmt.colValue = append(stmt.colValue, subsetVal) 501 } else { 502 stmt.placeholders++ 503 stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype)) 504 stmt.colValue = append(stmt.colValue, value) 505 } 506 } 507 return stmt, nil 508 } 509 510 // hook to simulate broken connections 511 var hookPrepareBadConn func() bool 512 513 func (c *fakeConn) Prepare(query string) (driver.Stmt, error) { 514 panic("use PrepareContext") 515 } 516 517 func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { 518 c.numPrepare++ 519 if c.db == nil { 520 panic("nil c.db; conn = " + fmt.Sprintf("%#v", 
c)) 521 } 522 523 if c.stickyBad || (hookPrepareBadConn != nil && hookPrepareBadConn()) { 524 return nil, driver.ErrBadConn 525 } 526 527 var firstStmt, prev *fakeStmt 528 for _, query := range strings.Split(query, ";") { 529 parts := strings.Split(query, "|") 530 if len(parts) < 1 { 531 return nil, errf("empty query") 532 } 533 stmt := &fakeStmt{q: query, c: c} 534 if firstStmt == nil { 535 firstStmt = stmt 536 } 537 if len(parts) >= 3 { 538 switch parts[0] { 539 case "PANIC": 540 stmt.panic = parts[1] 541 parts = parts[2:] 542 case "WAIT": 543 wait, err := time.ParseDuration(parts[1]) 544 if err != nil { 545 return nil, errf("expected section after WAIT to be a duration, got %q %v", parts[1], err) 546 } 547 parts = parts[2:] 548 stmt.wait = wait 549 } 550 } 551 cmd := parts[0] 552 stmt.cmd = cmd 553 parts = parts[1:] 554 555 if stmt.wait > 0 { 556 wait := time.NewTimer(stmt.wait) 557 select { 558 case <-wait.C: 559 case <-ctx.Done(): 560 wait.Stop() 561 return nil, ctx.Err() 562 } 563 } 564 565 c.incrStat(&c.stmtsMade) 566 var err error 567 switch cmd { 568 case "WIPE": 569 // Nothing 570 case "SELECT": 571 stmt, err = c.prepareSelect(stmt, parts) 572 case "CREATE": 573 stmt, err = c.prepareCreate(stmt, parts) 574 case "INSERT": 575 stmt, err = c.prepareInsert(stmt, parts) 576 case "NOSERT": 577 // Do all the prep-work like for an INSERT but don't actually insert the row. 578 // Used for some of the concurrent tests. 
579 stmt, err = c.prepareInsert(stmt, parts) 580 default: 581 stmt.Close() 582 return nil, errf("unsupported command type %q", cmd) 583 } 584 if err != nil { 585 return nil, err 586 } 587 if prev != nil { 588 prev.next = stmt 589 } 590 prev = stmt 591 } 592 return firstStmt, nil 593 } 594 595 func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter { 596 if s.panic == "ColumnConverter" { 597 panic(s.panic) 598 } 599 if len(s.placeholderConverter) == 0 { 600 return driver.DefaultParameterConverter 601 } 602 return s.placeholderConverter[idx] 603 } 604 605 func (s *fakeStmt) Close() error { 606 if s.panic == "Close" { 607 panic(s.panic) 608 } 609 if s.c == nil { 610 panic("nil conn in fakeStmt.Close") 611 } 612 if s.c.db == nil { 613 panic("in fakeStmt.Close, conn's db is nil (already closed)") 614 } 615 if !s.closed { 616 s.c.incrStat(&s.c.stmtsClosed) 617 s.closed = true 618 } 619 if s.next != nil { 620 s.next.Close() 621 } 622 return nil 623 } 624 625 var errClosed = errors.New("fakedb: statement has been closed") 626 627 // hook to simulate broken connections 628 var hookExecBadConn func() bool 629 630 func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) { 631 panic("Using ExecContext") 632 } 633 func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { 634 if s.panic == "Exec" { 635 panic(s.panic) 636 } 637 if s.closed { 638 return nil, errClosed 639 } 640 641 if s.c.stickyBad || (hookExecBadConn != nil && hookExecBadConn()) { 642 return nil, driver.ErrBadConn 643 } 644 645 err := checkSubsetTypes(args) 646 if err != nil { 647 return nil, err 648 } 649 650 if s.wait > 0 { 651 time.Sleep(s.wait) 652 } 653 654 select { 655 default: 656 case <-ctx.Done(): 657 return nil, ctx.Err() 658 } 659 660 db := s.c.db 661 switch s.cmd { 662 case "WIPE": 663 db.wipe() 664 return driver.ResultNoRows, nil 665 case "CREATE": 666 if err := db.createTable(s.table, s.colName, s.colType); err != nil { 667 return 
nil, err 668 } 669 return driver.ResultNoRows, nil 670 case "INSERT": 671 return s.execInsert(args, true) 672 case "NOSERT": 673 // Do all the prep-work like for an INSERT but don't actually insert the row. 674 // Used for some of the concurrent tests. 675 return s.execInsert(args, false) 676 } 677 fmt.Printf("EXEC statement, cmd=%q: %#v\n", s.cmd, s) 678 return nil, fmt.Errorf("unimplemented statement Exec command type of %q", s.cmd) 679 } 680 681 // When doInsert is true, add the row to the table. 682 // When doInsert is false do prep-work and error checking, but don't 683 // actually add the row to the table. 684 func (s *fakeStmt) execInsert(args []driver.NamedValue, doInsert bool) (driver.Result, error) { 685 db := s.c.db 686 if len(args) != s.placeholders { 687 panic("error in pkg db; should only get here if size is correct") 688 } 689 db.mu.Lock() 690 t, ok := db.table(s.table) 691 db.mu.Unlock() 692 if !ok { 693 return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table) 694 } 695 696 t.mu.Lock() 697 defer t.mu.Unlock() 698 699 var cols []interface{} 700 if doInsert { 701 cols = make([]interface{}, len(t.colname)) 702 } 703 argPos := 0 704 for n, colname := range s.colName { 705 colidx := t.columnIndex(colname) 706 if colidx == -1 { 707 return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname) 708 } 709 var val interface{} 710 if strvalue, ok := s.colValue[n].(string); ok && strings.HasPrefix(strvalue, "?") { 711 if strvalue == "?" { 712 val = args[argPos].Value 713 } else { 714 // Assign value from argument placeholder name. 
715 for _, a := range args { 716 if a.Name == strvalue[1:] { 717 val = a.Value 718 break 719 } 720 } 721 } 722 argPos++ 723 } else { 724 val = s.colValue[n] 725 } 726 if doInsert { 727 cols[colidx] = val 728 } 729 } 730 731 if doInsert { 732 t.rows = append(t.rows, &row{cols: cols}) 733 } 734 return driver.RowsAffected(1), nil 735 } 736 737 // hook to simulate broken connections 738 var hookQueryBadConn func() bool 739 740 func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) { 741 panic("Use QueryContext") 742 } 743 744 func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { 745 if s.panic == "Query" { 746 panic(s.panic) 747 } 748 if s.closed { 749 return nil, errClosed 750 } 751 752 if s.c.stickyBad || (hookQueryBadConn != nil && hookQueryBadConn()) { 753 return nil, driver.ErrBadConn 754 } 755 756 err := checkSubsetTypes(args) 757 if err != nil { 758 return nil, err 759 } 760 761 db := s.c.db 762 if len(args) != s.placeholders { 763 panic("error in pkg db; should only get here if size is correct") 764 } 765 766 setMRows := make([][]*row, 0, 1) 767 setColumns := make([][]string, 0, 1) 768 setColType := make([][]string, 0, 1) 769 770 for { 771 db.mu.Lock() 772 t, ok := db.table(s.table) 773 db.mu.Unlock() 774 if !ok { 775 return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table) 776 } 777 778 if s.table == "magicquery" { 779 if len(s.whereCol) == 2 && s.whereCol[0].Column == "op" && s.whereCol[1].Column == "millis" { 780 if args[0].Value == "sleep" { 781 time.Sleep(time.Duration(args[1].Value.(int64)) * time.Millisecond) 782 } 783 } 784 } 785 786 t.mu.Lock() 787 788 colIdx := make(map[string]int) // select column name -> column index in table 789 for _, name := range s.colName { 790 idx := t.columnIndex(name) 791 if idx == -1 { 792 t.mu.Unlock() 793 return nil, fmt.Errorf("fakedb: unknown column name %q", name) 794 } 795 colIdx[name] = idx 796 } 797 798 mrows := []*row{} 799 rows: 800 for _, trow 
:= range t.rows { 801 // Process the where clause, skipping non-match rows. This is lazy 802 // and just uses fmt.Sprintf("%v") to test equality. Good enough 803 // for test code. 804 for _, wcol := range s.whereCol { 805 idx := t.columnIndex(wcol.Column) 806 if idx == -1 { 807 t.mu.Unlock() 808 return nil, fmt.Errorf("db: invalid where clause column %q", wcol) 809 } 810 tcol := trow.cols[idx] 811 if bs, ok := tcol.([]byte); ok { 812 // lazy hack to avoid sprintf %v on a []byte 813 tcol = string(bs) 814 } 815 var argValue interface{} 816 if wcol.Placeholder == "?" { 817 argValue = args[wcol.Ordinal-1].Value 818 } else { 819 // Assign arg value from placeholder name. 820 for _, a := range args { 821 if a.Name == wcol.Placeholder[1:] { 822 argValue = a.Value 823 break 824 } 825 } 826 } 827 if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) { 828 continue rows 829 } 830 } 831 mrow := &row{cols: make([]interface{}, len(s.colName))} 832 for seli, name := range s.colName { 833 mrow.cols[seli] = trow.cols[colIdx[name]] 834 } 835 mrows = append(mrows, mrow) 836 } 837 838 var colType []string 839 for _, column := range s.colName { 840 colType = append(colType, t.coltype[t.columnIndex(column)]) 841 } 842 843 t.mu.Unlock() 844 845 setMRows = append(setMRows, mrows) 846 setColumns = append(setColumns, s.colName) 847 setColType = append(setColType, colType) 848 849 if s.next == nil { 850 break 851 } 852 s = s.next 853 } 854 855 cursor := &rowsCursor{ 856 posRow: -1, 857 rows: setMRows, 858 cols: setColumns, 859 colType: setColType, 860 errPos: -1, 861 } 862 return cursor, nil 863 } 864 865 func (s *fakeStmt) NumInput() int { 866 if s.panic == "NumInput" { 867 panic(s.panic) 868 } 869 return s.placeholders 870 } 871 872 // hook to simulate broken connections 873 var hookCommitBadConn func() bool 874 875 func (tx *fakeTx) Commit() error { 876 tx.c.currTx = nil 877 if hookCommitBadConn != nil && hookCommitBadConn() { 878 return driver.ErrBadConn 879 } 880 return nil 881 } 
882 883 // hook to simulate broken connections 884 var hookRollbackBadConn func() bool 885 886 func (tx *fakeTx) Rollback() error { 887 tx.c.currTx = nil 888 if hookRollbackBadConn != nil && hookRollbackBadConn() { 889 return driver.ErrBadConn 890 } 891 return nil 892 } 893 894 type rowsCursor struct { 895 cols [][]string 896 colType [][]string 897 posSet int 898 posRow int 899 rows [][]*row 900 closed bool 901 902 // errPos and err are for making Next return early with error. 903 errPos int 904 err error 905 906 // a clone of slices to give out to clients, indexed by the 907 // the original slice's first byte address. we clone them 908 // just so we're able to corrupt them on close. 909 bytesClone map[*byte][]byte 910 } 911 912 func (rc *rowsCursor) Close() error { 913 if !rc.closed { 914 for _, bs := range rc.bytesClone { 915 bs[0] = 255 // first byte corrupted 916 } 917 } 918 rc.closed = true 919 return nil 920 } 921 922 func (rc *rowsCursor) Columns() []string { 923 return rc.cols[rc.posSet] 924 } 925 926 func (rc *rowsCursor) ColumnTypeScanType(index int) reflect.Type { 927 return colTypeToReflectType(rc.colType[rc.posSet][index]) 928 } 929 930 var rowsCursorNextHook func(dest []driver.Value) error 931 932 func (rc *rowsCursor) Next(dest []driver.Value) error { 933 if rowsCursorNextHook != nil { 934 return rowsCursorNextHook(dest) 935 } 936 937 if rc.closed { 938 return errors.New("fakedb: cursor is closed") 939 } 940 rc.posRow++ 941 if rc.posRow == rc.errPos { 942 return rc.err 943 } 944 if rc.posRow >= len(rc.rows[rc.posSet]) { 945 return io.EOF // per interface spec 946 } 947 for i, v := range rc.rows[rc.posSet][rc.posRow].cols { 948 // TODO(bradfitz): convert to subset types? naah, I 949 // think the subset types should only be input to 950 // driver, but the sql package should be able to handle 951 // a wider range of types coming out of drivers. 
all 952 // for ease of drivers, and to prevent drivers from 953 // messing up conversions or doing them differently. 954 dest[i] = v 955 956 if bs, ok := v.([]byte); ok { 957 if rc.bytesClone == nil { 958 rc.bytesClone = make(map[*byte][]byte) 959 } 960 clone, ok := rc.bytesClone[&bs[0]] 961 if !ok { 962 clone = make([]byte, len(bs)) 963 copy(clone, bs) 964 rc.bytesClone[&bs[0]] = clone 965 } 966 dest[i] = clone 967 } 968 } 969 return nil 970 } 971 972 func (rc *rowsCursor) HasNextResultSet() bool { 973 return rc.posSet < len(rc.rows)-1 974 } 975 976 func (rc *rowsCursor) NextResultSet() error { 977 if rc.HasNextResultSet() { 978 rc.posSet++ 979 rc.posRow = -1 980 return nil 981 } 982 return io.EOF // Per interface spec. 983 } 984 985 // fakeDriverString is like driver.String, but indirects pointers like 986 // DefaultValueConverter. 987 // 988 // This could be surprising behavior to retroactively apply to 989 // driver.String now that Go1 is out, but this is convenient for 990 // our TestPointerParamsAndScans. 
991 // 992 type fakeDriverString struct{} 993 994 func (fakeDriverString) ConvertValue(v interface{}) (driver.Value, error) { 995 switch c := v.(type) { 996 case string, []byte: 997 return v, nil 998 case *string: 999 if c == nil { 1000 return nil, nil 1001 } 1002 return *c, nil 1003 } 1004 return fmt.Sprintf("%v", v), nil 1005 } 1006 1007 func converterForType(typ string) driver.ValueConverter { 1008 switch typ { 1009 case "bool": 1010 return driver.Bool 1011 case "nullbool": 1012 return driver.Null{Converter: driver.Bool} 1013 case "int32": 1014 return driver.Int32 1015 case "string": 1016 return driver.NotNull{Converter: fakeDriverString{}} 1017 case "nullstring": 1018 return driver.Null{Converter: fakeDriverString{}} 1019 case "int64": 1020 // TODO(coopernurse): add type-specific converter 1021 return driver.NotNull{Converter: driver.DefaultParameterConverter} 1022 case "nullint64": 1023 // TODO(coopernurse): add type-specific converter 1024 return driver.Null{Converter: driver.DefaultParameterConverter} 1025 case "float64": 1026 // TODO(coopernurse): add type-specific converter 1027 return driver.NotNull{Converter: driver.DefaultParameterConverter} 1028 case "nullfloat64": 1029 // TODO(coopernurse): add type-specific converter 1030 return driver.Null{Converter: driver.DefaultParameterConverter} 1031 case "datetime": 1032 return driver.DefaultParameterConverter 1033 } 1034 panic("invalid fakedb column type of " + typ) 1035 } 1036 1037 func colTypeToReflectType(typ string) reflect.Type { 1038 switch typ { 1039 case "bool": 1040 return reflect.TypeOf(false) 1041 case "nullbool": 1042 return reflect.TypeOf(NullBool{}) 1043 case "int32": 1044 return reflect.TypeOf(int32(0)) 1045 case "string": 1046 return reflect.TypeOf("") 1047 case "nullstring": 1048 return reflect.TypeOf(NullString{}) 1049 case "int64": 1050 return reflect.TypeOf(int64(0)) 1051 case "nullint64": 1052 return reflect.TypeOf(NullInt64{}) 1053 case "float64": 1054 return reflect.TypeOf(float64(0)) 
1055 case "nullfloat64": 1056 return reflect.TypeOf(NullFloat64{}) 1057 case "datetime": 1058 return reflect.TypeOf(time.Time{}) 1059 } 1060 panic("invalid fakedb column type of " + typ) 1061 }