// Source: github.com/sanprasirt/go@v0.0.0-20170607001320-a027466e4b6d/src/database/sql/fakedb_test.go

// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package sql

import (
	"context"
	"database/sql/driver"
	"errors"
	"fmt"
	"io"
	"log"
	"reflect"
	"sort"
	"strconv"
	"strings"
	"sync"
	"testing"
	"time"
)

var _ = log.Printf

// fakeDriver is a fake database that implements Go's driver.Driver
// interface, just for testing.
//
// It speaks a query language that's semantically similar to but
// syntactically different and simpler than SQL. The syntax is as
// follows:
//
//   WIPE
//   CREATE|<tablename>|<col>=<type>,<col>=<type>,...
//     where types are: "string", [u]int{8,16,32,64}, "bool"
//   INSERT|<tablename>|col=val,col2=val2,col3=?
//   SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=?
//   SELECT|<tablename>|projectcol1,projectcol2|filtercol=?param1,filtercol2=?param2
//
// Any of these can be preceded by PANIC|<method>|, to cause the
// named method on fakeStmt to panic.
//
// Any of these can be preceded by WAIT|<duration>|, to cause preparing
// the statement (and executing it with ExecContext) to sleep for the
// specified duration.
//
// Multiple of these can be combined when separated with a semicolon.
//
// When opening a fakeDriver's database, it starts empty with no
// tables. All tables and data are stored in memory only.
49 type fakeDriver struct { 50 mu sync.Mutex // guards 3 following fields 51 openCount int // conn opens 52 closeCount int // conn closes 53 waitCh chan struct{} 54 waitingCh chan struct{} 55 dbs map[string]*fakeDB 56 } 57 58 type fakeDB struct { 59 name string 60 61 mu sync.Mutex 62 tables map[string]*table 63 badConn bool 64 allowAny bool 65 } 66 67 type table struct { 68 mu sync.Mutex 69 colname []string 70 coltype []string 71 rows []*row 72 } 73 74 func (t *table) columnIndex(name string) int { 75 for n, nname := range t.colname { 76 if name == nname { 77 return n 78 } 79 } 80 return -1 81 } 82 83 type row struct { 84 cols []interface{} // must be same size as its table colname + coltype 85 } 86 87 type fakeConn struct { 88 db *fakeDB // where to return ourselves to 89 90 currTx *fakeTx 91 92 // Stats for tests: 93 mu sync.Mutex 94 stmtsMade int 95 stmtsClosed int 96 numPrepare int 97 98 // bad connection tests; see isBad() 99 bad bool 100 stickyBad bool 101 } 102 103 func (c *fakeConn) incrStat(v *int) { 104 c.mu.Lock() 105 *v++ 106 c.mu.Unlock() 107 } 108 109 type fakeTx struct { 110 c *fakeConn 111 } 112 113 type boundCol struct { 114 Column string 115 Placeholder string 116 Ordinal int 117 } 118 119 type fakeStmt struct { 120 c *fakeConn 121 q string // just for debugging 122 123 cmd string 124 table string 125 panic string 126 wait time.Duration 127 128 next *fakeStmt // used for returning multiple results. 129 130 closed bool 131 132 colName []string // used by CREATE, INSERT, SELECT (selected columns) 133 colType []string // used by CREATE 134 colValue []interface{} // used by INSERT (mix of strings and "?" for bound params) 135 placeholders int // used by INSERT/SELECT: number of ? 
params 136 137 whereCol []boundCol // used by SELECT (all placeholders) 138 139 placeholderConverter []driver.ValueConverter // used by INSERT 140 } 141 142 var fdriver driver.Driver = &fakeDriver{} 143 144 func init() { 145 Register("test", fdriver) 146 } 147 148 func contains(list []string, y string) bool { 149 for _, x := range list { 150 if x == y { 151 return true 152 } 153 } 154 return false 155 } 156 157 type Dummy struct { 158 driver.Driver 159 } 160 161 func TestDrivers(t *testing.T) { 162 unregisterAllDrivers() 163 Register("test", fdriver) 164 Register("invalid", Dummy{}) 165 all := Drivers() 166 if len(all) < 2 || !sort.StringsAreSorted(all) || !contains(all, "test") || !contains(all, "invalid") { 167 t.Fatalf("Drivers = %v, want sorted list with at least [invalid, test]", all) 168 } 169 } 170 171 // hook to simulate connection failures 172 var hookOpenErr struct { 173 sync.Mutex 174 fn func() error 175 } 176 177 func setHookOpenErr(fn func() error) { 178 hookOpenErr.Lock() 179 defer hookOpenErr.Unlock() 180 hookOpenErr.fn = fn 181 } 182 183 // Supports dsn forms: 184 // <dbname> 185 // <dbname>;<opts> (only currently supported option is `badConn`, 186 // which causes driver.ErrBadConn to be returned on 187 // every other conn.Begin()) 188 func (d *fakeDriver) Open(dsn string) (driver.Conn, error) { 189 hookOpenErr.Lock() 190 fn := hookOpenErr.fn 191 hookOpenErr.Unlock() 192 if fn != nil { 193 if err := fn(); err != nil { 194 return nil, err 195 } 196 } 197 parts := strings.Split(dsn, ";") 198 if len(parts) < 1 { 199 return nil, errors.New("fakedb: no database name") 200 } 201 name := parts[0] 202 203 db := d.getDB(name) 204 205 d.mu.Lock() 206 d.openCount++ 207 d.mu.Unlock() 208 conn := &fakeConn{db: db} 209 210 if len(parts) >= 2 && parts[1] == "badConn" { 211 conn.bad = true 212 } 213 if d.waitCh != nil { 214 d.waitingCh <- struct{}{} 215 <-d.waitCh 216 d.waitCh = nil 217 d.waitingCh = nil 218 } 219 return conn, nil 220 } 221 222 func (d *fakeDriver) 
getDB(name string) *fakeDB { 223 d.mu.Lock() 224 defer d.mu.Unlock() 225 if d.dbs == nil { 226 d.dbs = make(map[string]*fakeDB) 227 } 228 db, ok := d.dbs[name] 229 if !ok { 230 db = &fakeDB{name: name} 231 d.dbs[name] = db 232 } 233 return db 234 } 235 236 func (db *fakeDB) wipe() { 237 db.mu.Lock() 238 defer db.mu.Unlock() 239 db.tables = nil 240 } 241 242 func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error { 243 db.mu.Lock() 244 defer db.mu.Unlock() 245 if db.tables == nil { 246 db.tables = make(map[string]*table) 247 } 248 if _, exist := db.tables[name]; exist { 249 return fmt.Errorf("table %q already exists", name) 250 } 251 if len(columnNames) != len(columnTypes) { 252 return fmt.Errorf("create table of %q len(names) != len(types): %d vs %d", 253 name, len(columnNames), len(columnTypes)) 254 } 255 db.tables[name] = &table{colname: columnNames, coltype: columnTypes} 256 return nil 257 } 258 259 // must be called with db.mu lock held 260 func (db *fakeDB) table(table string) (*table, bool) { 261 if db.tables == nil { 262 return nil, false 263 } 264 t, ok := db.tables[table] 265 return t, ok 266 } 267 268 func (db *fakeDB) columnType(table, column string) (typ string, ok bool) { 269 db.mu.Lock() 270 defer db.mu.Unlock() 271 t, ok := db.table(table) 272 if !ok { 273 return 274 } 275 for n, cname := range t.colname { 276 if cname == column { 277 return t.coltype[n], true 278 } 279 } 280 return "", false 281 } 282 283 func (c *fakeConn) isBad() bool { 284 if c.stickyBad { 285 return true 286 } else if c.bad { 287 // alternate between bad conn and not bad conn 288 c.db.badConn = !c.db.badConn 289 return c.db.badConn 290 } else { 291 return false 292 } 293 } 294 295 func (c *fakeConn) Begin() (driver.Tx, error) { 296 if c.isBad() { 297 return nil, driver.ErrBadConn 298 } 299 if c.currTx != nil { 300 return nil, errors.New("already in a transaction") 301 } 302 c.currTx = &fakeTx{c: c} 303 return c.currTx, nil 304 } 305 306 var 
hookPostCloseConn struct { 307 sync.Mutex 308 fn func(*fakeConn, error) 309 } 310 311 func setHookpostCloseConn(fn func(*fakeConn, error)) { 312 hookPostCloseConn.Lock() 313 defer hookPostCloseConn.Unlock() 314 hookPostCloseConn.fn = fn 315 } 316 317 var testStrictClose *testing.T 318 319 // setStrictFakeConnClose sets the t to Errorf on when fakeConn.Close 320 // fails to close. If nil, the check is disabled. 321 func setStrictFakeConnClose(t *testing.T) { 322 testStrictClose = t 323 } 324 325 func (c *fakeConn) Close() (err error) { 326 drv := fdriver.(*fakeDriver) 327 defer func() { 328 if err != nil && testStrictClose != nil { 329 testStrictClose.Errorf("failed to close a test fakeConn: %v", err) 330 } 331 hookPostCloseConn.Lock() 332 fn := hookPostCloseConn.fn 333 hookPostCloseConn.Unlock() 334 if fn != nil { 335 fn(c, err) 336 } 337 if err == nil { 338 drv.mu.Lock() 339 drv.closeCount++ 340 drv.mu.Unlock() 341 } 342 }() 343 if c.currTx != nil { 344 return errors.New("can't close fakeConn; in a Transaction") 345 } 346 if c.db == nil { 347 return errors.New("can't close fakeConn; already closed") 348 } 349 if c.stmtsMade > c.stmtsClosed { 350 return errors.New("can't close; dangling statement(s)") 351 } 352 c.db = nil 353 return nil 354 } 355 356 func checkSubsetTypes(allowAny bool, args []driver.NamedValue) error { 357 for _, arg := range args { 358 switch arg.Value.(type) { 359 case int64, float64, bool, nil, []byte, string, time.Time: 360 default: 361 if !allowAny { 362 return fmt.Errorf("fakedb_test: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value) 363 } 364 } 365 } 366 return nil 367 } 368 369 func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) { 370 // Ensure that ExecContext is called if available. 
371 panic("ExecContext was not called.") 372 } 373 374 func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { 375 // This is an optional interface, but it's implemented here 376 // just to check that all the args are of the proper types. 377 // ErrSkip is returned so the caller acts as if we didn't 378 // implement this at all. 379 err := checkSubsetTypes(c.db.allowAny, args) 380 if err != nil { 381 return nil, err 382 } 383 return nil, driver.ErrSkip 384 } 385 386 func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) { 387 // Ensure that ExecContext is called if available. 388 panic("QueryContext was not called.") 389 } 390 391 func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { 392 // This is an optional interface, but it's implemented here 393 // just to check that all the args are of the proper types. 394 // ErrSkip is returned so the caller acts as if we didn't 395 // implement this at all. 396 err := checkSubsetTypes(c.db.allowAny, args) 397 if err != nil { 398 return nil, err 399 } 400 return nil, driver.ErrSkip 401 } 402 403 func errf(msg string, args ...interface{}) error { 404 return errors.New("fakedb: " + fmt.Sprintf(msg, args...)) 405 } 406 407 // parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=? 408 // (note that where columns must always contain ? 
marks, 409 // just a limitation for fakedb) 410 func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, error) { 411 if len(parts) != 3 { 412 stmt.Close() 413 return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts)) 414 } 415 stmt.table = parts[0] 416 417 stmt.colName = strings.Split(parts[1], ",") 418 for n, colspec := range strings.Split(parts[2], ",") { 419 if colspec == "" { 420 continue 421 } 422 nameVal := strings.Split(colspec, "=") 423 if len(nameVal) != 2 { 424 stmt.Close() 425 return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) 426 } 427 column, value := nameVal[0], nameVal[1] 428 _, ok := c.db.columnType(stmt.table, column) 429 if !ok { 430 stmt.Close() 431 return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column) 432 } 433 if !strings.HasPrefix(value, "?") { 434 stmt.Close() 435 return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark", 436 stmt.table, column) 437 } 438 stmt.placeholders++ 439 stmt.whereCol = append(stmt.whereCol, boundCol{Column: column, Placeholder: value, Ordinal: stmt.placeholders}) 440 } 441 return stmt, nil 442 } 443 444 // parts are table|col=type,col2=type2 445 func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (*fakeStmt, error) { 446 if len(parts) != 2 { 447 stmt.Close() 448 return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts)) 449 } 450 stmt.table = parts[0] 451 for n, colspec := range strings.Split(parts[1], ",") { 452 nameType := strings.Split(colspec, "=") 453 if len(nameType) != 2 { 454 stmt.Close() 455 return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) 456 } 457 stmt.colName = append(stmt.colName, nameType[0]) 458 stmt.colType = append(stmt.colType, nameType[1]) 459 } 460 return stmt, nil 461 } 462 463 // parts are table|col=?,col2=val 464 func (c *fakeConn) 
prepareInsert(stmt *fakeStmt, parts []string) (*fakeStmt, error) { 465 if len(parts) != 2 { 466 stmt.Close() 467 return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts)) 468 } 469 stmt.table = parts[0] 470 for n, colspec := range strings.Split(parts[1], ",") { 471 nameVal := strings.Split(colspec, "=") 472 if len(nameVal) != 2 { 473 stmt.Close() 474 return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) 475 } 476 column, value := nameVal[0], nameVal[1] 477 ctype, ok := c.db.columnType(stmt.table, column) 478 if !ok { 479 stmt.Close() 480 return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column) 481 } 482 stmt.colName = append(stmt.colName, column) 483 484 if !strings.HasPrefix(value, "?") { 485 var subsetVal interface{} 486 // Convert to driver subset type 487 switch ctype { 488 case "string": 489 subsetVal = []byte(value) 490 case "blob": 491 subsetVal = []byte(value) 492 case "int32": 493 i, err := strconv.Atoi(value) 494 if err != nil { 495 stmt.Close() 496 return nil, errf("invalid conversion to int32 from %q", value) 497 } 498 subsetVal = int64(i) // int64 is a subset type, but not int32 499 default: 500 stmt.Close() 501 return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype) 502 } 503 stmt.colValue = append(stmt.colValue, subsetVal) 504 } else { 505 stmt.placeholders++ 506 stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype)) 507 stmt.colValue = append(stmt.colValue, value) 508 } 509 } 510 return stmt, nil 511 } 512 513 // hook to simulate broken connections 514 var hookPrepareBadConn func() bool 515 516 func (c *fakeConn) Prepare(query string) (driver.Stmt, error) { 517 panic("use PrepareContext") 518 } 519 520 func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { 521 c.numPrepare++ 522 if c.db == nil { 523 panic("nil c.db; conn = " + fmt.Sprintf("%#v", 
c)) 524 } 525 526 if c.stickyBad || (hookPrepareBadConn != nil && hookPrepareBadConn()) { 527 return nil, driver.ErrBadConn 528 } 529 530 var firstStmt, prev *fakeStmt 531 for _, query := range strings.Split(query, ";") { 532 parts := strings.Split(query, "|") 533 if len(parts) < 1 { 534 return nil, errf("empty query") 535 } 536 stmt := &fakeStmt{q: query, c: c} 537 if firstStmt == nil { 538 firstStmt = stmt 539 } 540 if len(parts) >= 3 { 541 switch parts[0] { 542 case "PANIC": 543 stmt.panic = parts[1] 544 parts = parts[2:] 545 case "WAIT": 546 wait, err := time.ParseDuration(parts[1]) 547 if err != nil { 548 return nil, errf("expected section after WAIT to be a duration, got %q %v", parts[1], err) 549 } 550 parts = parts[2:] 551 stmt.wait = wait 552 } 553 } 554 cmd := parts[0] 555 stmt.cmd = cmd 556 parts = parts[1:] 557 558 if stmt.wait > 0 { 559 wait := time.NewTimer(stmt.wait) 560 select { 561 case <-wait.C: 562 case <-ctx.Done(): 563 wait.Stop() 564 return nil, ctx.Err() 565 } 566 } 567 568 c.incrStat(&c.stmtsMade) 569 var err error 570 switch cmd { 571 case "WIPE": 572 // Nothing 573 case "SELECT": 574 stmt, err = c.prepareSelect(stmt, parts) 575 case "CREATE": 576 stmt, err = c.prepareCreate(stmt, parts) 577 case "INSERT": 578 stmt, err = c.prepareInsert(stmt, parts) 579 case "NOSERT": 580 // Do all the prep-work like for an INSERT but don't actually insert the row. 581 // Used for some of the concurrent tests. 
582 stmt, err = c.prepareInsert(stmt, parts) 583 default: 584 stmt.Close() 585 return nil, errf("unsupported command type %q", cmd) 586 } 587 if err != nil { 588 return nil, err 589 } 590 if prev != nil { 591 prev.next = stmt 592 } 593 prev = stmt 594 } 595 return firstStmt, nil 596 } 597 598 func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter { 599 if s.panic == "ColumnConverter" { 600 panic(s.panic) 601 } 602 if len(s.placeholderConverter) == 0 { 603 return driver.DefaultParameterConverter 604 } 605 return s.placeholderConverter[idx] 606 } 607 608 func (s *fakeStmt) Close() error { 609 if s.panic == "Close" { 610 panic(s.panic) 611 } 612 if s.c == nil { 613 panic("nil conn in fakeStmt.Close") 614 } 615 if s.c.db == nil { 616 panic("in fakeStmt.Close, conn's db is nil (already closed)") 617 } 618 if !s.closed { 619 s.c.incrStat(&s.c.stmtsClosed) 620 s.closed = true 621 } 622 if s.next != nil { 623 s.next.Close() 624 } 625 return nil 626 } 627 628 var errClosed = errors.New("fakedb: statement has been closed") 629 630 // hook to simulate broken connections 631 var hookExecBadConn func() bool 632 633 func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) { 634 panic("Using ExecContext") 635 } 636 func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { 637 if s.panic == "Exec" { 638 panic(s.panic) 639 } 640 if s.closed { 641 return nil, errClosed 642 } 643 644 if s.c.stickyBad || (hookExecBadConn != nil && hookExecBadConn()) { 645 return nil, driver.ErrBadConn 646 } 647 648 err := checkSubsetTypes(s.c.db.allowAny, args) 649 if err != nil { 650 return nil, err 651 } 652 653 if s.wait > 0 { 654 time.Sleep(s.wait) 655 } 656 657 select { 658 default: 659 case <-ctx.Done(): 660 return nil, ctx.Err() 661 } 662 663 db := s.c.db 664 switch s.cmd { 665 case "WIPE": 666 db.wipe() 667 return driver.ResultNoRows, nil 668 case "CREATE": 669 if err := db.createTable(s.table, s.colName, s.colType); err != nil 
{ 670 return nil, err 671 } 672 return driver.ResultNoRows, nil 673 case "INSERT": 674 return s.execInsert(args, true) 675 case "NOSERT": 676 // Do all the prep-work like for an INSERT but don't actually insert the row. 677 // Used for some of the concurrent tests. 678 return s.execInsert(args, false) 679 } 680 fmt.Printf("EXEC statement, cmd=%q: %#v\n", s.cmd, s) 681 return nil, fmt.Errorf("unimplemented statement Exec command type of %q", s.cmd) 682 } 683 684 // When doInsert is true, add the row to the table. 685 // When doInsert is false do prep-work and error checking, but don't 686 // actually add the row to the table. 687 func (s *fakeStmt) execInsert(args []driver.NamedValue, doInsert bool) (driver.Result, error) { 688 db := s.c.db 689 if len(args) != s.placeholders { 690 panic("error in pkg db; should only get here if size is correct") 691 } 692 db.mu.Lock() 693 t, ok := db.table(s.table) 694 db.mu.Unlock() 695 if !ok { 696 return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table) 697 } 698 699 t.mu.Lock() 700 defer t.mu.Unlock() 701 702 var cols []interface{} 703 if doInsert { 704 cols = make([]interface{}, len(t.colname)) 705 } 706 argPos := 0 707 for n, colname := range s.colName { 708 colidx := t.columnIndex(colname) 709 if colidx == -1 { 710 return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname) 711 } 712 var val interface{} 713 if strvalue, ok := s.colValue[n].(string); ok && strings.HasPrefix(strvalue, "?") { 714 if strvalue == "?" { 715 val = args[argPos].Value 716 } else { 717 // Assign value from argument placeholder name. 
718 for _, a := range args { 719 if a.Name == strvalue[1:] { 720 val = a.Value 721 break 722 } 723 } 724 } 725 argPos++ 726 } else { 727 val = s.colValue[n] 728 } 729 if doInsert { 730 cols[colidx] = val 731 } 732 } 733 734 if doInsert { 735 t.rows = append(t.rows, &row{cols: cols}) 736 } 737 return driver.RowsAffected(1), nil 738 } 739 740 // hook to simulate broken connections 741 var hookQueryBadConn func() bool 742 743 func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) { 744 panic("Use QueryContext") 745 } 746 747 func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { 748 if s.panic == "Query" { 749 panic(s.panic) 750 } 751 if s.closed { 752 return nil, errClosed 753 } 754 755 if s.c.stickyBad || (hookQueryBadConn != nil && hookQueryBadConn()) { 756 return nil, driver.ErrBadConn 757 } 758 759 err := checkSubsetTypes(s.c.db.allowAny, args) 760 if err != nil { 761 return nil, err 762 } 763 764 db := s.c.db 765 if len(args) != s.placeholders { 766 panic("error in pkg db; should only get here if size is correct") 767 } 768 769 setMRows := make([][]*row, 0, 1) 770 setColumns := make([][]string, 0, 1) 771 setColType := make([][]string, 0, 1) 772 773 for { 774 db.mu.Lock() 775 t, ok := db.table(s.table) 776 db.mu.Unlock() 777 if !ok { 778 return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table) 779 } 780 781 if s.table == "magicquery" { 782 if len(s.whereCol) == 2 && s.whereCol[0].Column == "op" && s.whereCol[1].Column == "millis" { 783 if args[0].Value == "sleep" { 784 time.Sleep(time.Duration(args[1].Value.(int64)) * time.Millisecond) 785 } 786 } 787 } 788 789 t.mu.Lock() 790 791 colIdx := make(map[string]int) // select column name -> column index in table 792 for _, name := range s.colName { 793 idx := t.columnIndex(name) 794 if idx == -1 { 795 t.mu.Unlock() 796 return nil, fmt.Errorf("fakedb: unknown column name %q", name) 797 } 798 colIdx[name] = idx 799 } 800 801 mrows := []*row{} 802 rows: 
803 for _, trow := range t.rows { 804 // Process the where clause, skipping non-match rows. This is lazy 805 // and just uses fmt.Sprintf("%v") to test equality. Good enough 806 // for test code. 807 for _, wcol := range s.whereCol { 808 idx := t.columnIndex(wcol.Column) 809 if idx == -1 { 810 t.mu.Unlock() 811 return nil, fmt.Errorf("db: invalid where clause column %q", wcol) 812 } 813 tcol := trow.cols[idx] 814 if bs, ok := tcol.([]byte); ok { 815 // lazy hack to avoid sprintf %v on a []byte 816 tcol = string(bs) 817 } 818 var argValue interface{} 819 if wcol.Placeholder == "?" { 820 argValue = args[wcol.Ordinal-1].Value 821 } else { 822 // Assign arg value from placeholder name. 823 for _, a := range args { 824 if a.Name == wcol.Placeholder[1:] { 825 argValue = a.Value 826 break 827 } 828 } 829 } 830 if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) { 831 continue rows 832 } 833 } 834 mrow := &row{cols: make([]interface{}, len(s.colName))} 835 for seli, name := range s.colName { 836 mrow.cols[seli] = trow.cols[colIdx[name]] 837 } 838 mrows = append(mrows, mrow) 839 } 840 841 var colType []string 842 for _, column := range s.colName { 843 colType = append(colType, t.coltype[t.columnIndex(column)]) 844 } 845 846 t.mu.Unlock() 847 848 setMRows = append(setMRows, mrows) 849 setColumns = append(setColumns, s.colName) 850 setColType = append(setColType, colType) 851 852 if s.next == nil { 853 break 854 } 855 s = s.next 856 } 857 858 cursor := &rowsCursor{ 859 posRow: -1, 860 rows: setMRows, 861 cols: setColumns, 862 colType: setColType, 863 errPos: -1, 864 } 865 return cursor, nil 866 } 867 868 func (s *fakeStmt) NumInput() int { 869 if s.panic == "NumInput" { 870 panic(s.panic) 871 } 872 return s.placeholders 873 } 874 875 // hook to simulate broken connections 876 var hookCommitBadConn func() bool 877 878 func (tx *fakeTx) Commit() error { 879 tx.c.currTx = nil 880 if hookCommitBadConn != nil && hookCommitBadConn() { 881 return driver.ErrBadConn 882 } 883 
return nil 884 } 885 886 // hook to simulate broken connections 887 var hookRollbackBadConn func() bool 888 889 func (tx *fakeTx) Rollback() error { 890 tx.c.currTx = nil 891 if hookRollbackBadConn != nil && hookRollbackBadConn() { 892 return driver.ErrBadConn 893 } 894 return nil 895 } 896 897 type rowsCursor struct { 898 cols [][]string 899 colType [][]string 900 posSet int 901 posRow int 902 rows [][]*row 903 closed bool 904 905 // errPos and err are for making Next return early with error. 906 errPos int 907 err error 908 909 // a clone of slices to give out to clients, indexed by the 910 // the original slice's first byte address. we clone them 911 // just so we're able to corrupt them on close. 912 bytesClone map[*byte][]byte 913 } 914 915 func (rc *rowsCursor) Close() error { 916 if !rc.closed { 917 for _, bs := range rc.bytesClone { 918 bs[0] = 255 // first byte corrupted 919 } 920 } 921 rc.closed = true 922 return nil 923 } 924 925 func (rc *rowsCursor) Columns() []string { 926 return rc.cols[rc.posSet] 927 } 928 929 func (rc *rowsCursor) ColumnTypeScanType(index int) reflect.Type { 930 return colTypeToReflectType(rc.colType[rc.posSet][index]) 931 } 932 933 var rowsCursorNextHook func(dest []driver.Value) error 934 935 func (rc *rowsCursor) Next(dest []driver.Value) error { 936 if rowsCursorNextHook != nil { 937 return rowsCursorNextHook(dest) 938 } 939 940 if rc.closed { 941 return errors.New("fakedb: cursor is closed") 942 } 943 rc.posRow++ 944 if rc.posRow == rc.errPos { 945 return rc.err 946 } 947 if rc.posRow >= len(rc.rows[rc.posSet]) { 948 return io.EOF // per interface spec 949 } 950 for i, v := range rc.rows[rc.posSet][rc.posRow].cols { 951 // TODO(bradfitz): convert to subset types? naah, I 952 // think the subset types should only be input to 953 // driver, but the sql package should be able to handle 954 // a wider range of types coming out of drivers. 
all 955 // for ease of drivers, and to prevent drivers from 956 // messing up conversions or doing them differently. 957 dest[i] = v 958 959 if bs, ok := v.([]byte); ok { 960 if rc.bytesClone == nil { 961 rc.bytesClone = make(map[*byte][]byte) 962 } 963 clone, ok := rc.bytesClone[&bs[0]] 964 if !ok { 965 clone = make([]byte, len(bs)) 966 copy(clone, bs) 967 rc.bytesClone[&bs[0]] = clone 968 } 969 dest[i] = clone 970 } 971 } 972 return nil 973 } 974 975 func (rc *rowsCursor) HasNextResultSet() bool { 976 return rc.posSet < len(rc.rows)-1 977 } 978 979 func (rc *rowsCursor) NextResultSet() error { 980 if rc.HasNextResultSet() { 981 rc.posSet++ 982 rc.posRow = -1 983 return nil 984 } 985 return io.EOF // Per interface spec. 986 } 987 988 // fakeDriverString is like driver.String, but indirects pointers like 989 // DefaultValueConverter. 990 // 991 // This could be surprising behavior to retroactively apply to 992 // driver.String now that Go1 is out, but this is convenient for 993 // our TestPointerParamsAndScans. 
994 // 995 type fakeDriverString struct{} 996 997 func (fakeDriverString) ConvertValue(v interface{}) (driver.Value, error) { 998 switch c := v.(type) { 999 case string, []byte: 1000 return v, nil 1001 case *string: 1002 if c == nil { 1003 return nil, nil 1004 } 1005 return *c, nil 1006 } 1007 return fmt.Sprintf("%v", v), nil 1008 } 1009 1010 type anyTypeConverter struct{} 1011 1012 func (anyTypeConverter) ConvertValue(v interface{}) (driver.Value, error) { 1013 return v, nil 1014 } 1015 1016 func converterForType(typ string) driver.ValueConverter { 1017 switch typ { 1018 case "bool": 1019 return driver.Bool 1020 case "nullbool": 1021 return driver.Null{Converter: driver.Bool} 1022 case "int32": 1023 return driver.Int32 1024 case "string": 1025 return driver.NotNull{Converter: fakeDriverString{}} 1026 case "nullstring": 1027 return driver.Null{Converter: fakeDriverString{}} 1028 case "int64": 1029 // TODO(coopernurse): add type-specific converter 1030 return driver.NotNull{Converter: driver.DefaultParameterConverter} 1031 case "nullint64": 1032 // TODO(coopernurse): add type-specific converter 1033 return driver.Null{Converter: driver.DefaultParameterConverter} 1034 case "float64": 1035 // TODO(coopernurse): add type-specific converter 1036 return driver.NotNull{Converter: driver.DefaultParameterConverter} 1037 case "nullfloat64": 1038 // TODO(coopernurse): add type-specific converter 1039 return driver.Null{Converter: driver.DefaultParameterConverter} 1040 case "datetime": 1041 return driver.DefaultParameterConverter 1042 case "any": 1043 return anyTypeConverter{} 1044 } 1045 panic("invalid fakedb column type of " + typ) 1046 } 1047 1048 func colTypeToReflectType(typ string) reflect.Type { 1049 switch typ { 1050 case "bool": 1051 return reflect.TypeOf(false) 1052 case "nullbool": 1053 return reflect.TypeOf(NullBool{}) 1054 case "int32": 1055 return reflect.TypeOf(int32(0)) 1056 case "string": 1057 return reflect.TypeOf("") 1058 case "nullstring": 1059 return 
reflect.TypeOf(NullString{}) 1060 case "int64": 1061 return reflect.TypeOf(int64(0)) 1062 case "nullint64": 1063 return reflect.TypeOf(NullInt64{}) 1064 case "float64": 1065 return reflect.TypeOf(float64(0)) 1066 case "nullfloat64": 1067 return reflect.TypeOf(NullFloat64{}) 1068 case "datetime": 1069 return reflect.TypeOf(time.Time{}) 1070 case "any": 1071 return reflect.TypeOf(new(interface{})).Elem() 1072 } 1073 panic("invalid fakedb column type of " + typ) 1074 }