github.com/twelsh-aw/go/src@v0.0.0-20230516233729-a56fe86a7c81/database/sql/fakedb_test.go (about) 1 // Copyright 2011 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package sql 6 7 import ( 8 "context" 9 "database/sql/driver" 10 "errors" 11 "fmt" 12 "io" 13 "reflect" 14 "sort" 15 "strconv" 16 "strings" 17 "sync" 18 "testing" 19 "time" 20 ) 21 22 // fakeDriver is a fake database that implements Go's driver.Driver 23 // interface, just for testing. 24 // 25 // It speaks a query language that's semantically similar to but 26 // syntactically different and simpler than SQL. The syntax is as 27 // follows: 28 // 29 // WIPE 30 // CREATE|<tablename>|<col>=<type>,<col>=<type>,... 31 // where types are: "string", [u]int{8,16,32,64}, "bool" 32 // INSERT|<tablename>|col=val,col2=val2,col3=? 33 // SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=? 34 // SELECT|<tablename>|projectcol1,projectcol2|filtercol=?param1,filtercol2=?param2 35 // 36 // Any of these can be preceded by PANIC|<method>|, to cause the 37 // named method on fakeStmt to panic. 38 // 39 // Any of these can be proceeded by WAIT|<duration>|, to cause the 40 // named method on fakeStmt to sleep for the specified duration. 41 // 42 // Multiple of these can be combined when separated with a semicolon. 43 // 44 // When opening a fakeDriver's database, it starts empty with no 45 // tables. All tables and data are stored in memory only. 46 type fakeDriver struct { 47 mu sync.Mutex // guards 3 following fields 48 openCount int // conn opens 49 closeCount int // conn closes 50 waitCh chan struct{} 51 waitingCh chan struct{} 52 dbs map[string]*fakeDB 53 } 54 55 type fakeConnector struct { 56 name string 57 58 waiter func(context.Context) 59 closed bool 60 } 61 62 func (c *fakeConnector) Connect(context.Context) (driver.Conn, error) { 63 conn, err := fdriver.Open(c.name) 64 conn.(*fakeConn).waiter = c.waiter 65 return conn, err 66 } 67 68 func (c *fakeConnector) Driver() driver.Driver { 69 return fdriver 70 } 71 72 func (c *fakeConnector) Close() error { 73 if c.closed { 74 return errors.New("fakedb: connector is closed") 75 } 76 c.closed = true 77 return nil 78 } 79 80 type fakeDriverCtx struct { 81 fakeDriver 82 } 83 84 var _ driver.DriverContext = &fakeDriverCtx{} 85 86 func (cc *fakeDriverCtx) OpenConnector(name string) (driver.Connector, error) { 87 return &fakeConnector{name: name}, nil 88 } 89 90 type fakeDB struct { 91 name string 92 93 mu sync.Mutex 94 tables map[string]*table 95 badConn bool 96 allowAny bool 97 } 98 99 type fakeError struct { 100 Message string 101 Wrapped error 102 } 103 104 func (err fakeError) Error() string { 105 return err.Message 106 } 107 108 func (err fakeError) Unwrap() error { 109 return err.Wrapped 110 } 111 112 type table struct { 113 mu sync.Mutex 114 colname []string 115 coltype []string 116 rows []*row 117 } 118 119 func (t *table) columnIndex(name string) int { 120 for n, nname := range t.colname { 121 if name == nname { 122 return n 123 } 124 } 125 return -1 126 } 127 128 type row struct { 129 cols []any // must be same size as its table colname + coltype 130 } 131 132 type memToucher interface { 133 // touchMem reads & writes some memory, to help find data races. 134 touchMem() 135 } 136 137 type fakeConn struct { 138 db *fakeDB // where to return ourselves to 139 140 currTx *fakeTx 141 142 // Every operation writes to line to enable the race detector 143 // check for data races. 144 line int64 145 146 // Stats for tests: 147 mu sync.Mutex 148 stmtsMade int 149 stmtsClosed int 150 numPrepare int 151 152 // bad connection tests; see isBad() 153 bad bool 154 stickyBad bool 155 156 skipDirtySession bool // tests that use Conn should set this to true. 157 158 // dirtySession tests ResetSession, true if a query has executed 159 // until ResetSession is called. 160 dirtySession bool 161 162 // The waiter is called before each query. May be used in place of the "WAIT" 163 // directive. 164 waiter func(context.Context) 165 } 166 167 func (c *fakeConn) touchMem() { 168 c.line++ 169 } 170 171 func (c *fakeConn) incrStat(v *int) { 172 c.mu.Lock() 173 *v++ 174 c.mu.Unlock() 175 } 176 177 type fakeTx struct { 178 c *fakeConn 179 } 180 181 type boundCol struct { 182 Column string 183 Placeholder string 184 Ordinal int 185 } 186 187 type fakeStmt struct { 188 memToucher 189 c *fakeConn 190 q string // just for debugging 191 192 cmd string 193 table string 194 panic string 195 wait time.Duration 196 197 next *fakeStmt // used for returning multiple results. 198 199 closed bool 200 201 colName []string // used by CREATE, INSERT, SELECT (selected columns) 202 colType []string // used by CREATE 203 colValue []any // used by INSERT (mix of strings and "?" for bound params) 204 placeholders int // used by INSERT/SELECT: number of ? params 205 206 whereCol []boundCol // used by SELECT (all placeholders) 207 208 placeholderConverter []driver.ValueConverter // used by INSERT 209 } 210 211 var fdriver driver.Driver = &fakeDriver{} 212 213 func init() { 214 Register("test", fdriver) 215 } 216 217 func contains(list []string, y string) bool { 218 for _, x := range list { 219 if x == y { 220 return true 221 } 222 } 223 return false 224 } 225 226 type Dummy struct { 227 driver.Driver 228 } 229 230 func TestDrivers(t *testing.T) { 231 unregisterAllDrivers() 232 Register("test", fdriver) 233 Register("invalid", Dummy{}) 234 all := Drivers() 235 if len(all) < 2 || !sort.StringsAreSorted(all) || !contains(all, "test") || !contains(all, "invalid") { 236 t.Fatalf("Drivers = %v, want sorted list with at least [invalid, test]", all) 237 } 238 } 239 240 // hook to simulate connection failures 241 var hookOpenErr struct { 242 sync.Mutex 243 fn func() error 244 } 245 246 func setHookOpenErr(fn func() error) { 247 hookOpenErr.Lock() 248 defer hookOpenErr.Unlock() 249 hookOpenErr.fn = fn 250 } 251 252 // Supports dsn forms: 253 // 254 // <dbname> 255 // <dbname>;<opts> (only currently supported option is `badConn`, 256 // which causes driver.ErrBadConn to be returned on 257 // every other conn.Begin()) 258 func (d *fakeDriver) Open(dsn string) (driver.Conn, error) { 259 hookOpenErr.Lock() 260 fn := hookOpenErr.fn 261 hookOpenErr.Unlock() 262 if fn != nil { 263 if err := fn(); err != nil { 264 return nil, err 265 } 266 } 267 parts := strings.Split(dsn, ";") 268 if len(parts) < 1 { 269 return nil, errors.New("fakedb: no database name") 270 } 271 name := parts[0] 272 273 db := d.getDB(name) 274 275 d.mu.Lock() 276 d.openCount++ 277 d.mu.Unlock() 278 conn := &fakeConn{db: db} 279 280 if len(parts) >= 2 && parts[1] == "badConn" { 281 conn.bad = true 282 } 283 if d.waitCh != nil { 284 d.waitingCh <- struct{}{} 285 <-d.waitCh 286 d.waitCh = nil 287 d.waitingCh = nil 288 } 289 return conn, nil 290 } 291 292 func (d *fakeDriver) getDB(name string) *fakeDB { 293 d.mu.Lock() 294 defer d.mu.Unlock() 295 if d.dbs == nil { 296 d.dbs = make(map[string]*fakeDB) 297 } 298 db, ok := d.dbs[name] 299 if !ok { 300 db = &fakeDB{name: name} 301 d.dbs[name] = db 302 } 303 return db 304 } 305 306 func (db *fakeDB) wipe() { 307 db.mu.Lock() 308 defer db.mu.Unlock() 309 db.tables = nil 310 } 311 312 func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error { 313 db.mu.Lock() 314 defer db.mu.Unlock() 315 if db.tables == nil { 316 db.tables = make(map[string]*table) 317 } 318 if _, exist := db.tables[name]; exist { 319 return fmt.Errorf("fakedb: table %q already exists", name) 320 } 321 if len(columnNames) != len(columnTypes) { 322 return fmt.Errorf("fakedb: create table of %q len(names) != len(types): %d vs %d", 323 name, len(columnNames), len(columnTypes)) 324 } 325 db.tables[name] = &table{colname: columnNames, coltype: columnTypes} 326 return nil 327 } 328 329 // must be called with db.mu lock held 330 func (db *fakeDB) table(table string) (*table, bool) { 331 if db.tables == nil { 332 return nil, false 333 } 334 t, ok := db.tables[table] 335 return t, ok 336 } 337 338 func (db *fakeDB) columnType(table, column string) (typ string, ok bool) { 339 db.mu.Lock() 340 defer db.mu.Unlock() 341 t, ok := db.table(table) 342 if !ok { 343 return 344 } 345 for n, cname := range t.colname { 346 if cname == column { 347 return t.coltype[n], true 348 } 349 } 350 return "", false 351 } 352 353 func (c *fakeConn) isBad() bool { 354 if c.stickyBad { 355 return true 356 } else if c.bad { 357 if c.db == nil { 358 return false 359 } 360 // alternate between bad conn and not bad conn 361 c.db.badConn = !c.db.badConn 362 return c.db.badConn 363 } else { 364 return false 365 } 366 } 367 368 func (c *fakeConn) isDirtyAndMark() bool { 369 if c.skipDirtySession { 370 return false 371 } 372 if c.currTx != nil { 373 c.dirtySession = true 374 return false 375 } 376 if c.dirtySession { 377 return true 378 } 379 c.dirtySession = true 380 return false 381 } 382 383 func (c *fakeConn) Begin() (driver.Tx, error) { 384 if c.isBad() { 385 return nil, fakeError{Wrapped: driver.ErrBadConn} 386 } 387 if c.currTx != nil { 388 return nil, errors.New("fakedb: already in a transaction") 389 } 390 c.touchMem() 391 c.currTx = &fakeTx{c: c} 392 return c.currTx, nil 393 } 394 395 var hookPostCloseConn struct { 396 sync.Mutex 397 fn func(*fakeConn, error) 398 } 399 400 func setHookpostCloseConn(fn func(*fakeConn, error)) { 401 hookPostCloseConn.Lock() 402 defer hookPostCloseConn.Unlock() 403 hookPostCloseConn.fn = fn 404 } 405 406 var testStrictClose *testing.T 407 408 // setStrictFakeConnClose sets the t to Errorf on when fakeConn.Close 409 // fails to close. If nil, the check is disabled. 410 func setStrictFakeConnClose(t *testing.T) { 411 testStrictClose = t 412 } 413 414 func (c *fakeConn) ResetSession(ctx context.Context) error { 415 c.dirtySession = false 416 c.currTx = nil 417 if c.isBad() { 418 return fakeError{Message: "Reset Session: bad conn", Wrapped: driver.ErrBadConn} 419 } 420 return nil 421 } 422 423 var _ driver.Validator = (*fakeConn)(nil) 424 425 func (c *fakeConn) IsValid() bool { 426 return !c.isBad() 427 } 428 429 func (c *fakeConn) Close() (err error) { 430 drv := fdriver.(*fakeDriver) 431 defer func() { 432 if err != nil && testStrictClose != nil { 433 testStrictClose.Errorf("failed to close a test fakeConn: %v", err) 434 } 435 hookPostCloseConn.Lock() 436 fn := hookPostCloseConn.fn 437 hookPostCloseConn.Unlock() 438 if fn != nil { 439 fn(c, err) 440 } 441 if err == nil { 442 drv.mu.Lock() 443 drv.closeCount++ 444 drv.mu.Unlock() 445 } 446 }() 447 c.touchMem() 448 if c.currTx != nil { 449 return errors.New("fakedb: can't close fakeConn; in a Transaction") 450 } 451 if c.db == nil { 452 return errors.New("fakedb: can't close fakeConn; already closed") 453 } 454 if c.stmtsMade > c.stmtsClosed { 455 return errors.New("fakedb: can't close; dangling statement(s)") 456 } 457 c.db = nil 458 return nil 459 } 460 461 func checkSubsetTypes(allowAny bool, args []driver.NamedValue) error { 462 for _, arg := range args { 463 switch arg.Value.(type) { 464 case int64, float64, bool, nil, []byte, string, time.Time: 465 default: 466 if !allowAny { 467 return fmt.Errorf("fakedb: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value) 468 } 469 } 470 } 471 return nil 472 } 473 474 func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) { 475 // Ensure that ExecContext is called if available. 476 panic("ExecContext was not called.") 477 } 478 479 func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { 480 // This is an optional interface, but it's implemented here 481 // just to check that all the args are of the proper types. 482 // ErrSkip is returned so the caller acts as if we didn't 483 // implement this at all. 484 err := checkSubsetTypes(c.db.allowAny, args) 485 if err != nil { 486 return nil, err 487 } 488 return nil, driver.ErrSkip 489 } 490 491 func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) { 492 // Ensure that ExecContext is called if available. 493 panic("QueryContext was not called.") 494 } 495 496 func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { 497 // This is an optional interface, but it's implemented here 498 // just to check that all the args are of the proper types. 499 // ErrSkip is returned so the caller acts as if we didn't 500 // implement this at all. 501 err := checkSubsetTypes(c.db.allowAny, args) 502 if err != nil { 503 return nil, err 504 } 505 return nil, driver.ErrSkip 506 } 507 508 func errf(msg string, args ...any) error { 509 return errors.New("fakedb: " + fmt.Sprintf(msg, args...)) 510 } 511 512 // parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=? 513 // (note that where columns must always contain ? marks, 514 // just a limitation for fakedb) 515 func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, error) { 516 if len(parts) != 3 { 517 stmt.Close() 518 return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts)) 519 } 520 stmt.table = parts[0] 521 522 stmt.colName = strings.Split(parts[1], ",") 523 for n, colspec := range strings.Split(parts[2], ",") { 524 if colspec == "" { 525 continue 526 } 527 nameVal := strings.Split(colspec, "=") 528 if len(nameVal) != 2 { 529 stmt.Close() 530 return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) 531 } 532 column, value := nameVal[0], nameVal[1] 533 _, ok := c.db.columnType(stmt.table, column) 534 if !ok { 535 stmt.Close() 536 return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column) 537 } 538 if !strings.HasPrefix(value, "?") { 539 stmt.Close() 540 return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark", 541 stmt.table, column) 542 } 543 stmt.placeholders++ 544 stmt.whereCol = append(stmt.whereCol, boundCol{Column: column, Placeholder: value, Ordinal: stmt.placeholders}) 545 } 546 return stmt, nil 547 } 548 549 // parts are table|col=type,col2=type2 550 func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (*fakeStmt, error) { 551 if len(parts) != 2 { 552 stmt.Close() 553 return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts)) 554 } 555 stmt.table = parts[0] 556 for n, colspec := range strings.Split(parts[1], ",") { 557 nameType := strings.Split(colspec, "=") 558 if len(nameType) != 2 { 559 stmt.Close() 560 return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) 561 } 562 stmt.colName = append(stmt.colName, nameType[0]) 563 stmt.colType = append(stmt.colType, nameType[1]) 564 } 565 return stmt, nil 566 } 567 568 // parts are table|col=?,col2=val 569 func (c *fakeConn) prepareInsert(ctx context.Context, stmt *fakeStmt, parts []string) (*fakeStmt, error) { 570 if len(parts) != 2 { 571 stmt.Close() 572 return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts)) 573 } 574 stmt.table = parts[0] 575 for n, colspec := range strings.Split(parts[1], ",") { 576 nameVal := strings.Split(colspec, "=") 577 if len(nameVal) != 2 { 578 stmt.Close() 579 return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) 580 } 581 column, value := nameVal[0], nameVal[1] 582 ctype, ok := c.db.columnType(stmt.table, column) 583 if !ok { 584 stmt.Close() 585 return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column) 586 } 587 stmt.colName = append(stmt.colName, column) 588 589 if !strings.HasPrefix(value, "?") { 590 var subsetVal any 591 // Convert to driver subset type 592 switch ctype { 593 case "string": 594 subsetVal = []byte(value) 595 case "blob": 596 subsetVal = []byte(value) 597 case "int32": 598 i, err := strconv.Atoi(value) 599 if err != nil { 600 stmt.Close() 601 return nil, errf("invalid conversion to int32 from %q", value) 602 } 603 subsetVal = int64(i) // int64 is a subset type, but not int32 604 case "table": // For testing cursor reads. 605 c.skipDirtySession = true 606 vparts := strings.Split(value, "!") 607 608 substmt, err := c.PrepareContext(ctx, fmt.Sprintf("SELECT|%s|%s|", vparts[0], strings.Join(vparts[1:], ","))) 609 if err != nil { 610 return nil, err 611 } 612 cursor, err := (substmt.(driver.StmtQueryContext)).QueryContext(ctx, []driver.NamedValue{}) 613 substmt.Close() 614 if err != nil { 615 return nil, err 616 } 617 subsetVal = cursor 618 default: 619 stmt.Close() 620 return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype) 621 } 622 stmt.colValue = append(stmt.colValue, subsetVal) 623 } else { 624 stmt.placeholders++ 625 stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype)) 626 stmt.colValue = append(stmt.colValue, value) 627 } 628 } 629 return stmt, nil 630 } 631 632 // hook to simulate broken connections 633 var hookPrepareBadConn func() bool 634 635 func (c *fakeConn) Prepare(query string) (driver.Stmt, error) { 636 panic("use PrepareContext") 637 } 638 639 func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { 640 c.numPrepare++ 641 if c.db == nil { 642 panic("nil c.db; conn = " + fmt.Sprintf("%#v", c)) 643 } 644 645 if c.stickyBad || (hookPrepareBadConn != nil && hookPrepareBadConn()) { 646 return nil, fakeError{Message: "Prepare: Sticky Bad", Wrapped: driver.ErrBadConn} 647 } 648 649 c.touchMem() 650 var firstStmt, prev *fakeStmt 651 for _, query := range strings.Split(query, ";") { 652 parts := strings.Split(query, "|") 653 if len(parts) < 1 { 654 return nil, errf("empty query") 655 } 656 stmt := &fakeStmt{q: query, c: c, memToucher: c} 657 if firstStmt == nil { 658 firstStmt = stmt 659 } 660 if len(parts) >= 3 { 661 switch parts[0] { 662 case "PANIC": 663 stmt.panic = parts[1] 664 parts = parts[2:] 665 case "WAIT": 666 wait, err := time.ParseDuration(parts[1]) 667 if err != nil { 668 return nil, errf("expected section after WAIT to be a duration, got %q %v", parts[1], err) 669 } 670 parts = parts[2:] 671 stmt.wait = wait 672 } 673 } 674 cmd := parts[0] 675 stmt.cmd = cmd 676 parts = parts[1:] 677 678 if c.waiter != nil { 679 c.waiter(ctx) 680 if err := ctx.Err(); err != nil { 681 return nil, err 682 } 683 } 684 685 if stmt.wait > 0 { 686 wait := time.NewTimer(stmt.wait) 687 select { 688 case <-wait.C: 689 case <-ctx.Done(): 690 wait.Stop() 691 return nil, ctx.Err() 692 } 693 } 694 695 c.incrStat(&c.stmtsMade) 696 var err error 697 switch cmd { 698 case "WIPE": 699 // Nothing 700 case "SELECT": 701 stmt, err = c.prepareSelect(stmt, parts) 702 case "CREATE": 703 stmt, err = c.prepareCreate(stmt, parts) 704 case "INSERT": 705 stmt, err = c.prepareInsert(ctx, stmt, parts) 706 case "NOSERT": 707 // Do all the prep-work like for an INSERT but don't actually insert the row. 708 // Used for some of the concurrent tests. 709 stmt, err = c.prepareInsert(ctx, stmt, parts) 710 default: 711 stmt.Close() 712 return nil, errf("unsupported command type %q", cmd) 713 } 714 if err != nil { 715 return nil, err 716 } 717 if prev != nil { 718 prev.next = stmt 719 } 720 prev = stmt 721 } 722 return firstStmt, nil 723 } 724 725 func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter { 726 if s.panic == "ColumnConverter" { 727 panic(s.panic) 728 } 729 if len(s.placeholderConverter) == 0 { 730 return driver.DefaultParameterConverter 731 } 732 return s.placeholderConverter[idx] 733 } 734 735 func (s *fakeStmt) Close() error { 736 if s.panic == "Close" { 737 panic(s.panic) 738 } 739 if s.c == nil { 740 panic("nil conn in fakeStmt.Close") 741 } 742 if s.c.db == nil { 743 panic("in fakeStmt.Close, conn's db is nil (already closed)") 744 } 745 s.touchMem() 746 if !s.closed { 747 s.c.incrStat(&s.c.stmtsClosed) 748 s.closed = true 749 } 750 if s.next != nil { 751 s.next.Close() 752 } 753 return nil 754 } 755 756 var errClosed = errors.New("fakedb: statement has been closed") 757 758 // hook to simulate broken connections 759 var hookExecBadConn func() bool 760 761 func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) { 762 panic("Using ExecContext") 763 } 764 765 var errFakeConnSessionDirty = errors.New("fakedb: session is dirty") 766 767 func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { 768 if s.panic == "Exec" { 769 panic(s.panic) 770 } 771 if s.closed { 772 return nil, errClosed 773 } 774 775 if s.c.stickyBad || (hookExecBadConn != nil && hookExecBadConn()) { 776 return nil, fakeError{Message: "Exec: Sticky Bad", Wrapped: driver.ErrBadConn} 777 } 778 if s.c.isDirtyAndMark() { 779 return nil, errFakeConnSessionDirty 780 } 781 782 err := checkSubsetTypes(s.c.db.allowAny, args) 783 if err != nil { 784 return nil, err 785 } 786 s.touchMem() 787 788 if s.wait > 0 { 789 time.Sleep(s.wait) 790 } 791 792 select { 793 default: 794 case <-ctx.Done(): 795 return nil, ctx.Err() 796 } 797 798 db := s.c.db 799 switch s.cmd { 800 case "WIPE": 801 db.wipe() 802 return driver.ResultNoRows, nil 803 case "CREATE": 804 if err := db.createTable(s.table, s.colName, s.colType); err != nil { 805 return nil, err 806 } 807 return driver.ResultNoRows, nil 808 case "INSERT": 809 return s.execInsert(args, true) 810 case "NOSERT": 811 // Do all the prep-work like for an INSERT but don't actually insert the row. 812 // Used for some of the concurrent tests. 813 return s.execInsert(args, false) 814 } 815 return nil, fmt.Errorf("fakedb: unimplemented statement Exec command type of %q", s.cmd) 816 } 817 818 // When doInsert is true, add the row to the table. 819 // When doInsert is false do prep-work and error checking, but don't 820 // actually add the row to the table. 821 func (s *fakeStmt) execInsert(args []driver.NamedValue, doInsert bool) (driver.Result, error) { 822 db := s.c.db 823 if len(args) != s.placeholders { 824 panic("error in pkg db; should only get here if size is correct") 825 } 826 db.mu.Lock() 827 t, ok := db.table(s.table) 828 db.mu.Unlock() 829 if !ok { 830 return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table) 831 } 832 833 t.mu.Lock() 834 defer t.mu.Unlock() 835 836 var cols []any 837 if doInsert { 838 cols = make([]any, len(t.colname)) 839 } 840 argPos := 0 841 for n, colname := range s.colName { 842 colidx := t.columnIndex(colname) 843 if colidx == -1 { 844 return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname) 845 } 846 var val any 847 if strvalue, ok := s.colValue[n].(string); ok && strings.HasPrefix(strvalue, "?") { 848 if strvalue == "?" { 849 val = args[argPos].Value 850 } else { 851 // Assign value from argument placeholder name. 852 for _, a := range args { 853 if a.Name == strvalue[1:] { 854 val = a.Value 855 break 856 } 857 } 858 } 859 argPos++ 860 } else { 861 val = s.colValue[n] 862 } 863 if doInsert { 864 cols[colidx] = val 865 } 866 } 867 868 if doInsert { 869 t.rows = append(t.rows, &row{cols: cols}) 870 } 871 return driver.RowsAffected(1), nil 872 } 873 874 // hook to simulate broken connections 875 var hookQueryBadConn func() bool 876 877 func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) { 878 panic("Use QueryContext") 879 } 880 881 func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { 882 if s.panic == "Query" { 883 panic(s.panic) 884 } 885 if s.closed { 886 return nil, errClosed 887 } 888 889 if s.c.stickyBad || (hookQueryBadConn != nil && hookQueryBadConn()) { 890 return nil, fakeError{Message: "Query: Sticky Bad", Wrapped: driver.ErrBadConn} 891 } 892 if s.c.isDirtyAndMark() { 893 return nil, errFakeConnSessionDirty 894 } 895 896 err := checkSubsetTypes(s.c.db.allowAny, args) 897 if err != nil { 898 return nil, err 899 } 900 901 s.touchMem() 902 db := s.c.db 903 if len(args) != s.placeholders { 904 panic("error in pkg db; should only get here if size is correct") 905 } 906 907 setMRows := make([][]*row, 0, 1) 908 setColumns := make([][]string, 0, 1) 909 setColType := make([][]string, 0, 1) 910 911 for { 912 db.mu.Lock() 913 t, ok := db.table(s.table) 914 db.mu.Unlock() 915 if !ok { 916 return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table) 917 } 918 919 if s.table == "magicquery" { 920 if len(s.whereCol) == 2 && s.whereCol[0].Column == "op" && s.whereCol[1].Column == "millis" { 921 if args[0].Value == "sleep" { 922 time.Sleep(time.Duration(args[1].Value.(int64)) * time.Millisecond) 923 } 924 } 925 } 926 if s.table == "tx_status" && s.colName[0] == "tx_status" { 927 txStatus := "autocommit" 928 if s.c.currTx != nil { 929 txStatus = "transaction" 930 } 931 cursor := &rowsCursor{ 932 parentMem: s.c, 933 posRow: -1, 934 rows: [][]*row{ 935 { 936 { 937 cols: []any{ 938 txStatus, 939 }, 940 }, 941 }, 942 }, 943 cols: [][]string{ 944 { 945 "tx_status", 946 }, 947 }, 948 colType: [][]string{ 949 { 950 "string", 951 }, 952 }, 953 errPos: -1, 954 } 955 return cursor, nil 956 } 957 958 t.mu.Lock() 959 960 colIdx := make(map[string]int) // select column name -> column index in table 961 for _, name := range s.colName { 962 idx := t.columnIndex(name) 963 if idx == -1 { 964 t.mu.Unlock() 965 return nil, fmt.Errorf("fakedb: unknown column name %q", name) 966 } 967 colIdx[name] = idx 968 } 969 970 mrows := []*row{} 971 rows: 972 for _, trow := range t.rows { 973 // Process the where clause, skipping non-match rows. This is lazy 974 // and just uses fmt.Sprintf("%v") to test equality. Good enough 975 // for test code. 976 for _, wcol := range s.whereCol { 977 idx := t.columnIndex(wcol.Column) 978 if idx == -1 { 979 t.mu.Unlock() 980 return nil, fmt.Errorf("fakedb: invalid where clause column %q", wcol) 981 } 982 tcol := trow.cols[idx] 983 if bs, ok := tcol.([]byte); ok { 984 // lazy hack to avoid sprintf %v on a []byte 985 tcol = string(bs) 986 } 987 var argValue any 988 if wcol.Placeholder == "?" { 989 argValue = args[wcol.Ordinal-1].Value 990 } else { 991 // Assign arg value from placeholder name. 992 for _, a := range args { 993 if a.Name == wcol.Placeholder[1:] { 994 argValue = a.Value 995 break 996 } 997 } 998 } 999 if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) { 1000 continue rows 1001 } 1002 } 1003 mrow := &row{cols: make([]any, len(s.colName))} 1004 for seli, name := range s.colName { 1005 mrow.cols[seli] = trow.cols[colIdx[name]] 1006 } 1007 mrows = append(mrows, mrow) 1008 } 1009 1010 var colType []string 1011 for _, column := range s.colName { 1012 colType = append(colType, t.coltype[t.columnIndex(column)]) 1013 } 1014 1015 t.mu.Unlock() 1016 1017 setMRows = append(setMRows, mrows) 1018 setColumns = append(setColumns, s.colName) 1019 setColType = append(setColType, colType) 1020 1021 if s.next == nil { 1022 break 1023 } 1024 s = s.next 1025 } 1026 1027 cursor := &rowsCursor{ 1028 parentMem: s.c, 1029 posRow: -1, 1030 rows: setMRows, 1031 cols: setColumns, 1032 colType: setColType, 1033 errPos: -1, 1034 } 1035 return cursor, nil 1036 } 1037 1038 func (s *fakeStmt) NumInput() int { 1039 if s.panic == "NumInput" { 1040 panic(s.panic) 1041 } 1042 return s.placeholders 1043 } 1044 1045 // hook to simulate broken connections 1046 var hookCommitBadConn func() bool 1047 1048 func (tx *fakeTx) Commit() error { 1049 tx.c.currTx = nil 1050 if hookCommitBadConn != nil && hookCommitBadConn() { 1051 return fakeError{Message: "Commit: Hook Bad Conn", Wrapped: driver.ErrBadConn} 1052 } 1053 tx.c.touchMem() 1054 return nil 1055 } 1056 1057 // hook to simulate broken connections 1058 var hookRollbackBadConn func() bool 1059 1060 func (tx *fakeTx) Rollback() error { 1061 tx.c.currTx = nil 1062 if hookRollbackBadConn != nil && hookRollbackBadConn() { 1063 return fakeError{Message: "Rollback: Hook Bad Conn", Wrapped: driver.ErrBadConn} 1064 } 1065 tx.c.touchMem() 1066 return nil 1067 } 1068 1069 type rowsCursor struct { 1070 parentMem memToucher 1071 cols [][]string 1072 colType [][]string 1073 posSet int 1074 posRow int 1075 rows [][]*row 1076 closed bool 1077 1078 // errPos and err are for making Next return early with error. 1079 errPos int 1080 err error 1081 1082 // a clone of slices to give out to clients, indexed by the 1083 // original slice's first byte address. we clone them 1084 // just so we're able to corrupt them on close. 1085 bytesClone map[*byte][]byte 1086 1087 // Every operation writes to line to enable the race detector 1088 // check for data races. 1089 // This is separate from the fakeConn.line to allow for drivers that 1090 // can start multiple queries on the same transaction at the same time. 1091 line int64 1092 1093 // closeErr is returned when rowsCursor.Close 1094 closeErr error 1095 } 1096 1097 func (rc *rowsCursor) touchMem() { 1098 rc.parentMem.touchMem() 1099 rc.line++ 1100 } 1101 1102 func (rc *rowsCursor) Close() error { 1103 rc.touchMem() 1104 rc.parentMem.touchMem() 1105 rc.closed = true 1106 return rc.closeErr 1107 } 1108 1109 func (rc *rowsCursor) Columns() []string { 1110 return rc.cols[rc.posSet] 1111 } 1112 1113 func (rc *rowsCursor) ColumnTypeScanType(index int) reflect.Type { 1114 return colTypeToReflectType(rc.colType[rc.posSet][index]) 1115 } 1116 1117 var rowsCursorNextHook func(dest []driver.Value) error 1118 1119 func (rc *rowsCursor) Next(dest []driver.Value) error { 1120 if rowsCursorNextHook != nil { 1121 return rowsCursorNextHook(dest) 1122 } 1123 1124 if rc.closed { 1125 return errors.New("fakedb: cursor is closed") 1126 } 1127 rc.touchMem() 1128 rc.posRow++ 1129 if rc.posRow == rc.errPos { 1130 return rc.err 1131 } 1132 if rc.posRow >= len(rc.rows[rc.posSet]) { 1133 return io.EOF // per interface spec 1134 } 1135 for i, v := range rc.rows[rc.posSet][rc.posRow].cols { 1136 // TODO(bradfitz): convert to subset types? naah, I 1137 // think the subset types should only be input to 1138 // driver, but the sql package should be able to handle 1139 // a wider range of types coming out of drivers. all 1140 // for ease of drivers, and to prevent drivers from 1141 // messing up conversions or doing them differently. 1142 dest[i] = v 1143 1144 if bs, ok := v.([]byte); ok { 1145 if rc.bytesClone == nil { 1146 rc.bytesClone = make(map[*byte][]byte) 1147 } 1148 clone, ok := rc.bytesClone[&bs[0]] 1149 if !ok { 1150 clone = make([]byte, len(bs)) 1151 copy(clone, bs) 1152 rc.bytesClone[&bs[0]] = clone 1153 } 1154 dest[i] = clone 1155 } 1156 } 1157 return nil 1158 } 1159 1160 func (rc *rowsCursor) HasNextResultSet() bool { 1161 rc.touchMem() 1162 return rc.posSet < len(rc.rows)-1 1163 } 1164 1165 func (rc *rowsCursor) NextResultSet() error { 1166 rc.touchMem() 1167 if rc.HasNextResultSet() { 1168 rc.posSet++ 1169 rc.posRow = -1 1170 return nil 1171 } 1172 return io.EOF // Per interface spec. 1173 } 1174 1175 // fakeDriverString is like driver.String, but indirects pointers like 1176 // DefaultValueConverter. 1177 // 1178 // This could be surprising behavior to retroactively apply to 1179 // driver.String now that Go1 is out, but this is convenient for 1180 // our TestPointerParamsAndScans. 1181 type fakeDriverString struct{} 1182 1183 func (fakeDriverString) ConvertValue(v any) (driver.Value, error) { 1184 switch c := v.(type) { 1185 case string, []byte: 1186 return v, nil 1187 case *string: 1188 if c == nil { 1189 return nil, nil 1190 } 1191 return *c, nil 1192 } 1193 return fmt.Sprintf("%v", v), nil 1194 } 1195 1196 type anyTypeConverter struct{} 1197 1198 func (anyTypeConverter) ConvertValue(v any) (driver.Value, error) { 1199 return v, nil 1200 } 1201 1202 func converterForType(typ string) driver.ValueConverter { 1203 switch typ { 1204 case "bool": 1205 return driver.Bool 1206 case "nullbool": 1207 return driver.Null{Converter: driver.Bool} 1208 case "byte", "int16": 1209 return driver.NotNull{Converter: driver.DefaultParameterConverter} 1210 case "int32": 1211 return driver.Int32 1212 case "nullbyte", "nullint32", "nullint16": 1213 return driver.Null{Converter: driver.DefaultParameterConverter} 1214 case "string": 1215 return driver.NotNull{Converter: fakeDriverString{}} 1216 case "nullstring": 1217 return driver.Null{Converter: fakeDriverString{}} 1218 case "int64": 1219 // TODO(coopernurse): add type-specific converter 1220 return driver.NotNull{Converter: driver.DefaultParameterConverter} 1221 case "nullint64": 1222 // TODO(coopernurse): add type-specific converter 1223 return driver.Null{Converter: driver.DefaultParameterConverter} 1224 case "float64": 1225 // TODO(coopernurse): add type-specific converter 1226 return driver.NotNull{Converter: driver.DefaultParameterConverter} 1227 case "nullfloat64": 1228 // TODO(coopernurse): add type-specific converter 1229 return driver.Null{Converter: driver.DefaultParameterConverter} 1230 case "datetime": 1231 return driver.NotNull{Converter: driver.DefaultParameterConverter} 1232 case "nulldatetime": 1233 return driver.Null{Converter: driver.DefaultParameterConverter} 1234 case "any": 1235 return anyTypeConverter{} 1236 } 1237 panic("invalid fakedb column type of " + typ) 1238 } 1239 1240 func colTypeToReflectType(typ string) reflect.Type { 1241 switch typ { 1242 case "bool": 1243 return reflect.TypeOf(false) 1244 case "nullbool": 1245 return reflect.TypeOf(NullBool{}) 1246 case "int16": 1247 return reflect.TypeOf(int16(0)) 1248 case "nullint16": 1249 return reflect.TypeOf(NullInt16{}) 1250 case "int32": 1251 return reflect.TypeOf(int32(0)) 1252 case "nullint32": 1253 return reflect.TypeOf(NullInt32{}) 1254 case "string": 1255 return reflect.TypeOf("") 1256 case "nullstring": 1257 return reflect.TypeOf(NullString{}) 1258 case "int64": 1259 return reflect.TypeOf(int64(0)) 1260 case "nullint64": 1261 return reflect.TypeOf(NullInt64{}) 1262 case "float64": 1263 return reflect.TypeOf(float64(0)) 1264 case "nullfloat64": 1265 return reflect.TypeOf(NullFloat64{}) 1266 case "datetime": 1267 return reflect.TypeOf(time.Time{}) 1268 case "any": 1269 return reflect.TypeOf(new(any)).Elem() 1270 } 1271 panic("invalid fakedb column type of " + typ) 1272 }