github.com/ice-blockchain/go/src@v0.0.0-20240403114104-1564d284e521/database/sql/fakedb_test.go (about) 1 // Copyright 2011 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package sql 6 7 import ( 8 "context" 9 "database/sql/driver" 10 "errors" 11 "fmt" 12 "io" 13 "reflect" 14 "sort" 15 "strconv" 16 "strings" 17 "sync" 18 "sync/atomic" 19 "testing" 20 "time" 21 ) 22 23 // fakeDriver is a fake database that implements Go's driver.Driver 24 // interface, just for testing. 25 // 26 // It speaks a query language that's semantically similar to but 27 // syntactically different and simpler than SQL. The syntax is as 28 // follows: 29 // 30 // WIPE 31 // CREATE|<tablename>|<col>=<type>,<col>=<type>,... 32 // where types are: "string", [u]int{8,16,32,64}, "bool" 33 // INSERT|<tablename>|col=val,col2=val2,col3=? 34 // SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=? 35 // SELECT|<tablename>|projectcol1,projectcol2|filtercol=?param1,filtercol2=?param2 36 // 37 // Any of these can be preceded by PANIC|<method>|, to cause the 38 // named method on fakeStmt to panic. 39 // 40 // Any of these can be proceeded by WAIT|<duration>|, to cause the 41 // named method on fakeStmt to sleep for the specified duration. 42 // 43 // Multiple of these can be combined when separated with a semicolon. 44 // 45 // When opening a fakeDriver's database, it starts empty with no 46 // tables. All tables and data are stored in memory only. 47 type fakeDriver struct { 48 mu sync.Mutex // guards 3 following fields 49 openCount int // conn opens 50 closeCount int // conn closes 51 waitCh chan struct{} 52 waitingCh chan struct{} 53 dbs map[string]*fakeDB 54 } 55 56 type fakeConnector struct { 57 name string 58 59 waiter func(context.Context) 60 closed bool 61 } 62 63 func (c *fakeConnector) Connect(context.Context) (driver.Conn, error) { 64 conn, err := fdriver.Open(c.name) 65 conn.(*fakeConn).waiter = c.waiter 66 return conn, err 67 } 68 69 func (c *fakeConnector) Driver() driver.Driver { 70 return fdriver 71 } 72 73 func (c *fakeConnector) Close() error { 74 if c.closed { 75 return errors.New("fakedb: connector is closed") 76 } 77 c.closed = true 78 return nil 79 } 80 81 type fakeDriverCtx struct { 82 fakeDriver 83 } 84 85 var _ driver.DriverContext = &fakeDriverCtx{} 86 87 func (cc *fakeDriverCtx) OpenConnector(name string) (driver.Connector, error) { 88 return &fakeConnector{name: name}, nil 89 } 90 91 type fakeDB struct { 92 name string 93 94 useRawBytes atomic.Bool 95 96 mu sync.Mutex 97 tables map[string]*table 98 badConn bool 99 allowAny bool 100 } 101 102 type fakeError struct { 103 Message string 104 Wrapped error 105 } 106 107 func (err fakeError) Error() string { 108 return err.Message 109 } 110 111 func (err fakeError) Unwrap() error { 112 return err.Wrapped 113 } 114 115 type table struct { 116 mu sync.Mutex 117 colname []string 118 coltype []string 119 rows []*row 120 } 121 122 func (t *table) columnIndex(name string) int { 123 for n, nname := range t.colname { 124 if name == nname { 125 return n 126 } 127 } 128 return -1 129 } 130 131 type row struct { 132 cols []any // must be same size as its table colname + coltype 133 } 134 135 type memToucher interface { 136 // touchMem reads & writes some memory, to help find data races. 137 touchMem() 138 } 139 140 type fakeConn struct { 141 db *fakeDB // where to return ourselves to 142 143 currTx *fakeTx 144 145 // Every operation writes to line to enable the race detector 146 // check for data races. 147 line int64 148 149 // Stats for tests: 150 mu sync.Mutex 151 stmtsMade int 152 stmtsClosed int 153 numPrepare int 154 155 // bad connection tests; see isBad() 156 bad bool 157 stickyBad bool 158 159 skipDirtySession bool // tests that use Conn should set this to true. 160 161 // dirtySession tests ResetSession, true if a query has executed 162 // until ResetSession is called. 163 dirtySession bool 164 165 // The waiter is called before each query. May be used in place of the "WAIT" 166 // directive. 167 waiter func(context.Context) 168 } 169 170 func (c *fakeConn) touchMem() { 171 c.line++ 172 } 173 174 func (c *fakeConn) incrStat(v *int) { 175 c.mu.Lock() 176 *v++ 177 c.mu.Unlock() 178 } 179 180 type fakeTx struct { 181 c *fakeConn 182 } 183 184 type boundCol struct { 185 Column string 186 Placeholder string 187 Ordinal int 188 } 189 190 type fakeStmt struct { 191 memToucher 192 c *fakeConn 193 q string // just for debugging 194 195 cmd string 196 table string 197 panic string 198 wait time.Duration 199 200 next *fakeStmt // used for returning multiple results. 201 202 closed bool 203 204 colName []string // used by CREATE, INSERT, SELECT (selected columns) 205 colType []string // used by CREATE 206 colValue []any // used by INSERT (mix of strings and "?" for bound params) 207 placeholders int // used by INSERT/SELECT: number of ? params 208 209 whereCol []boundCol // used by SELECT (all placeholders) 210 211 placeholderConverter []driver.ValueConverter // used by INSERT 212 } 213 214 var fdriver driver.Driver = &fakeDriver{} 215 216 func init() { 217 Register("test", fdriver) 218 } 219 220 func contains(list []string, y string) bool { 221 for _, x := range list { 222 if x == y { 223 return true 224 } 225 } 226 return false 227 } 228 229 type Dummy struct { 230 driver.Driver 231 } 232 233 func TestDrivers(t *testing.T) { 234 unregisterAllDrivers() 235 Register("test", fdriver) 236 Register("invalid", Dummy{}) 237 all := Drivers() 238 if len(all) < 2 || !sort.StringsAreSorted(all) || !contains(all, "test") || !contains(all, "invalid") { 239 t.Fatalf("Drivers = %v, want sorted list with at least [invalid, test]", all) 240 } 241 } 242 243 // hook to simulate connection failures 244 var hookOpenErr struct { 245 sync.Mutex 246 fn func() error 247 } 248 249 func setHookOpenErr(fn func() error) { 250 hookOpenErr.Lock() 251 defer hookOpenErr.Unlock() 252 hookOpenErr.fn = fn 253 } 254 255 // Supports dsn forms: 256 // 257 // <dbname> 258 // <dbname>;<opts> (only currently supported option is `badConn`, 259 // which causes driver.ErrBadConn to be returned on 260 // every other conn.Begin()) 261 func (d *fakeDriver) Open(dsn string) (driver.Conn, error) { 262 hookOpenErr.Lock() 263 fn := hookOpenErr.fn 264 hookOpenErr.Unlock() 265 if fn != nil { 266 if err := fn(); err != nil { 267 return nil, err 268 } 269 } 270 parts := strings.Split(dsn, ";") 271 if len(parts) < 1 { 272 return nil, errors.New("fakedb: no database name") 273 } 274 name := parts[0] 275 276 db := d.getDB(name) 277 278 d.mu.Lock() 279 d.openCount++ 280 d.mu.Unlock() 281 conn := &fakeConn{db: db} 282 283 if len(parts) >= 2 && parts[1] == "badConn" { 284 conn.bad = true 285 } 286 if d.waitCh != nil { 287 d.waitingCh <- struct{}{} 288 <-d.waitCh 289 d.waitCh = nil 290 d.waitingCh = nil 291 } 292 return conn, nil 293 } 294 295 func (d *fakeDriver) getDB(name string) *fakeDB { 296 d.mu.Lock() 297 defer d.mu.Unlock() 298 if d.dbs == nil { 299 d.dbs = make(map[string]*fakeDB) 300 } 301 db, ok := d.dbs[name] 302 if !ok { 303 db = &fakeDB{name: name} 304 d.dbs[name] = db 305 } 306 return db 307 } 308 309 func (db *fakeDB) wipe() { 310 db.mu.Lock() 311 defer db.mu.Unlock() 312 db.tables = nil 313 } 314 315 func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error { 316 db.mu.Lock() 317 defer db.mu.Unlock() 318 if db.tables == nil { 319 db.tables = make(map[string]*table) 320 } 321 if _, exist := db.tables[name]; exist { 322 return fmt.Errorf("fakedb: table %q already exists", name) 323 } 324 if len(columnNames) != len(columnTypes) { 325 return fmt.Errorf("fakedb: create table of %q len(names) != len(types): %d vs %d", 326 name, len(columnNames), len(columnTypes)) 327 } 328 db.tables[name] = &table{colname: columnNames, coltype: columnTypes} 329 return nil 330 } 331 332 // must be called with db.mu lock held 333 func (db *fakeDB) table(table string) (*table, bool) { 334 if db.tables == nil { 335 return nil, false 336 } 337 t, ok := db.tables[table] 338 return t, ok 339 } 340 341 func (db *fakeDB) columnType(table, column string) (typ string, ok bool) { 342 db.mu.Lock() 343 defer db.mu.Unlock() 344 t, ok := db.table(table) 345 if !ok { 346 return 347 } 348 for n, cname := range t.colname { 349 if cname == column { 350 return t.coltype[n], true 351 } 352 } 353 return "", false 354 } 355 356 func (c *fakeConn) isBad() bool { 357 if c.stickyBad { 358 return true 359 } else if c.bad { 360 if c.db == nil { 361 return false 362 } 363 // alternate between bad conn and not bad conn 364 c.db.badConn = !c.db.badConn 365 return c.db.badConn 366 } else { 367 return false 368 } 369 } 370 371 func (c *fakeConn) isDirtyAndMark() bool { 372 if c.skipDirtySession { 373 return false 374 } 375 if c.currTx != nil { 376 c.dirtySession = true 377 return false 378 } 379 if c.dirtySession { 380 return true 381 } 382 c.dirtySession = true 383 return false 384 } 385 386 func (c *fakeConn) Begin() (driver.Tx, error) { 387 if c.isBad() { 388 return nil, fakeError{Wrapped: driver.ErrBadConn} 389 } 390 if c.currTx != nil { 391 return nil, errors.New("fakedb: already in a transaction") 392 } 393 c.touchMem() 394 c.currTx = &fakeTx{c: c} 395 return c.currTx, nil 396 } 397 398 var hookPostCloseConn struct { 399 sync.Mutex 400 fn func(*fakeConn, error) 401 } 402 403 func setHookpostCloseConn(fn func(*fakeConn, error)) { 404 hookPostCloseConn.Lock() 405 defer hookPostCloseConn.Unlock() 406 hookPostCloseConn.fn = fn 407 } 408 409 var testStrictClose *testing.T 410 411 // setStrictFakeConnClose sets the t to Errorf on when fakeConn.Close 412 // fails to close. If nil, the check is disabled. 413 func setStrictFakeConnClose(t *testing.T) { 414 testStrictClose = t 415 } 416 417 func (c *fakeConn) ResetSession(ctx context.Context) error { 418 c.dirtySession = false 419 c.currTx = nil 420 if c.isBad() { 421 return fakeError{Message: "Reset Session: bad conn", Wrapped: driver.ErrBadConn} 422 } 423 return nil 424 } 425 426 var _ driver.Validator = (*fakeConn)(nil) 427 428 func (c *fakeConn) IsValid() bool { 429 return !c.isBad() 430 } 431 432 func (c *fakeConn) Close() (err error) { 433 drv := fdriver.(*fakeDriver) 434 defer func() { 435 if err != nil && testStrictClose != nil { 436 testStrictClose.Errorf("failed to close a test fakeConn: %v", err) 437 } 438 hookPostCloseConn.Lock() 439 fn := hookPostCloseConn.fn 440 hookPostCloseConn.Unlock() 441 if fn != nil { 442 fn(c, err) 443 } 444 if err == nil { 445 drv.mu.Lock() 446 drv.closeCount++ 447 drv.mu.Unlock() 448 } 449 }() 450 c.touchMem() 451 if c.currTx != nil { 452 return errors.New("fakedb: can't close fakeConn; in a Transaction") 453 } 454 if c.db == nil { 455 return errors.New("fakedb: can't close fakeConn; already closed") 456 } 457 if c.stmtsMade > c.stmtsClosed { 458 return errors.New("fakedb: can't close; dangling statement(s)") 459 } 460 c.db = nil 461 return nil 462 } 463 464 func checkSubsetTypes(allowAny bool, args []driver.NamedValue) error { 465 for _, arg := range args { 466 switch arg.Value.(type) { 467 case int64, float64, bool, nil, []byte, string, time.Time: 468 default: 469 if !allowAny { 470 return fmt.Errorf("fakedb: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value) 471 } 472 } 473 } 474 return nil 475 } 476 477 func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) { 478 // Ensure that ExecContext is called if available. 479 panic("ExecContext was not called.") 480 } 481 482 func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { 483 // This is an optional interface, but it's implemented here 484 // just to check that all the args are of the proper types. 485 // ErrSkip is returned so the caller acts as if we didn't 486 // implement this at all. 487 err := checkSubsetTypes(c.db.allowAny, args) 488 if err != nil { 489 return nil, err 490 } 491 return nil, driver.ErrSkip 492 } 493 494 func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) { 495 // Ensure that ExecContext is called if available. 496 panic("QueryContext was not called.") 497 } 498 499 func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { 500 // This is an optional interface, but it's implemented here 501 // just to check that all the args are of the proper types. 502 // ErrSkip is returned so the caller acts as if we didn't 503 // implement this at all. 504 err := checkSubsetTypes(c.db.allowAny, args) 505 if err != nil { 506 return nil, err 507 } 508 return nil, driver.ErrSkip 509 } 510 511 func errf(msg string, args ...any) error { 512 return errors.New("fakedb: " + fmt.Sprintf(msg, args...)) 513 } 514 515 // parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=? 516 // (note that where columns must always contain ? marks, 517 // just a limitation for fakedb) 518 func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, error) { 519 if len(parts) != 3 { 520 stmt.Close() 521 return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts)) 522 } 523 stmt.table = parts[0] 524 525 stmt.colName = strings.Split(parts[1], ",") 526 for n, colspec := range strings.Split(parts[2], ",") { 527 if colspec == "" { 528 continue 529 } 530 nameVal := strings.Split(colspec, "=") 531 if len(nameVal) != 2 { 532 stmt.Close() 533 return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) 534 } 535 column, value := nameVal[0], nameVal[1] 536 _, ok := c.db.columnType(stmt.table, column) 537 if !ok { 538 stmt.Close() 539 return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column) 540 } 541 if !strings.HasPrefix(value, "?") { 542 stmt.Close() 543 return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark", 544 stmt.table, column) 545 } 546 stmt.placeholders++ 547 stmt.whereCol = append(stmt.whereCol, boundCol{Column: column, Placeholder: value, Ordinal: stmt.placeholders}) 548 } 549 return stmt, nil 550 } 551 552 // parts are table|col=type,col2=type2 553 func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (*fakeStmt, error) { 554 if len(parts) != 2 { 555 stmt.Close() 556 return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts)) 557 } 558 stmt.table = parts[0] 559 for n, colspec := range strings.Split(parts[1], ",") { 560 nameType := strings.Split(colspec, "=") 561 if len(nameType) != 2 { 562 stmt.Close() 563 return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) 564 } 565 stmt.colName = append(stmt.colName, nameType[0]) 566 stmt.colType = append(stmt.colType, nameType[1]) 567 } 568 return stmt, nil 569 } 570 571 // parts are table|col=?,col2=val 572 func (c *fakeConn) prepareInsert(ctx context.Context, stmt *fakeStmt, parts []string) (*fakeStmt, error) { 573 if len(parts) != 2 { 574 stmt.Close() 575 return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts)) 576 } 577 stmt.table = parts[0] 578 for n, colspec := range strings.Split(parts[1], ",") { 579 nameVal := strings.Split(colspec, "=") 580 if len(nameVal) != 2 { 581 stmt.Close() 582 return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) 583 } 584 column, value := nameVal[0], nameVal[1] 585 ctype, ok := c.db.columnType(stmt.table, column) 586 if !ok { 587 stmt.Close() 588 return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column) 589 } 590 stmt.colName = append(stmt.colName, column) 591 592 if !strings.HasPrefix(value, "?") { 593 var subsetVal any 594 // Convert to driver subset type 595 switch ctype { 596 case "string": 597 subsetVal = []byte(value) 598 case "blob": 599 subsetVal = []byte(value) 600 case "int32": 601 i, err := strconv.Atoi(value) 602 if err != nil { 603 stmt.Close() 604 return nil, errf("invalid conversion to int32 from %q", value) 605 } 606 subsetVal = int64(i) // int64 is a subset type, but not int32 607 case "table": // For testing cursor reads. 608 c.skipDirtySession = true 609 vparts := strings.Split(value, "!") 610 611 substmt, err := c.PrepareContext(ctx, fmt.Sprintf("SELECT|%s|%s|", vparts[0], strings.Join(vparts[1:], ","))) 612 if err != nil { 613 return nil, err 614 } 615 cursor, err := (substmt.(driver.StmtQueryContext)).QueryContext(ctx, []driver.NamedValue{}) 616 substmt.Close() 617 if err != nil { 618 return nil, err 619 } 620 subsetVal = cursor 621 default: 622 stmt.Close() 623 return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype) 624 } 625 stmt.colValue = append(stmt.colValue, subsetVal) 626 } else { 627 stmt.placeholders++ 628 stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype)) 629 stmt.colValue = append(stmt.colValue, value) 630 } 631 } 632 return stmt, nil 633 } 634 635 // hook to simulate broken connections 636 var hookPrepareBadConn func() bool 637 638 func (c *fakeConn) Prepare(query string) (driver.Stmt, error) { 639 panic("use PrepareContext") 640 } 641 642 func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { 643 c.numPrepare++ 644 if c.db == nil { 645 panic("nil c.db; conn = " + fmt.Sprintf("%#v", c)) 646 } 647 648 if c.stickyBad || (hookPrepareBadConn != nil && hookPrepareBadConn()) { 649 return nil, fakeError{Message: "Prepare: Sticky Bad", Wrapped: driver.ErrBadConn} 650 } 651 652 c.touchMem() 653 var firstStmt, prev *fakeStmt 654 for _, query := range strings.Split(query, ";") { 655 parts := strings.Split(query, "|") 656 if len(parts) < 1 { 657 return nil, errf("empty query") 658 } 659 stmt := &fakeStmt{q: query, c: c, memToucher: c} 660 if firstStmt == nil { 661 firstStmt = stmt 662 } 663 if len(parts) >= 3 { 664 switch parts[0] { 665 case "PANIC": 666 stmt.panic = parts[1] 667 parts = parts[2:] 668 case "WAIT": 669 wait, err := time.ParseDuration(parts[1]) 670 if err != nil { 671 return nil, errf("expected section after WAIT to be a duration, got %q %v", parts[1], err) 672 } 673 parts = parts[2:] 674 stmt.wait = wait 675 } 676 } 677 cmd := parts[0] 678 stmt.cmd = cmd 679 parts = parts[1:] 680 681 if c.waiter != nil { 682 c.waiter(ctx) 683 if err := ctx.Err(); err != nil { 684 return nil, err 685 } 686 } 687 688 if stmt.wait > 0 { 689 wait := time.NewTimer(stmt.wait) 690 select { 691 case <-wait.C: 692 case <-ctx.Done(): 693 wait.Stop() 694 return nil, ctx.Err() 695 } 696 } 697 698 c.incrStat(&c.stmtsMade) 699 var err error 700 switch cmd { 701 case "WIPE": 702 // Nothing 703 case "USE_RAWBYTES": 704 c.db.useRawBytes.Store(true) 705 case "SELECT": 706 stmt, err = c.prepareSelect(stmt, parts) 707 case "CREATE": 708 stmt, err = c.prepareCreate(stmt, parts) 709 case "INSERT": 710 stmt, err = c.prepareInsert(ctx, stmt, parts) 711 case "NOSERT": 712 // Do all the prep-work like for an INSERT but don't actually insert the row. 713 // Used for some of the concurrent tests. 714 stmt, err = c.prepareInsert(ctx, stmt, parts) 715 default: 716 stmt.Close() 717 return nil, errf("unsupported command type %q", cmd) 718 } 719 if err != nil { 720 return nil, err 721 } 722 if prev != nil { 723 prev.next = stmt 724 } 725 prev = stmt 726 } 727 return firstStmt, nil 728 } 729 730 func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter { 731 if s.panic == "ColumnConverter" { 732 panic(s.panic) 733 } 734 if len(s.placeholderConverter) == 0 { 735 return driver.DefaultParameterConverter 736 } 737 return s.placeholderConverter[idx] 738 } 739 740 func (s *fakeStmt) Close() error { 741 if s.panic == "Close" { 742 panic(s.panic) 743 } 744 if s.c == nil { 745 panic("nil conn in fakeStmt.Close") 746 } 747 if s.c.db == nil { 748 panic("in fakeStmt.Close, conn's db is nil (already closed)") 749 } 750 s.touchMem() 751 if !s.closed { 752 s.c.incrStat(&s.c.stmtsClosed) 753 s.closed = true 754 } 755 if s.next != nil { 756 s.next.Close() 757 } 758 return nil 759 } 760 761 var errClosed = errors.New("fakedb: statement has been closed") 762 763 // hook to simulate broken connections 764 var hookExecBadConn func() bool 765 766 func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) { 767 panic("Using ExecContext") 768 } 769 770 var errFakeConnSessionDirty = errors.New("fakedb: session is dirty") 771 772 func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { 773 if s.panic == "Exec" { 774 panic(s.panic) 775 } 776 if s.closed { 777 return nil, errClosed 778 } 779 780 if s.c.stickyBad || (hookExecBadConn != nil && hookExecBadConn()) { 781 return nil, fakeError{Message: "Exec: Sticky Bad", Wrapped: driver.ErrBadConn} 782 } 783 if s.c.isDirtyAndMark() { 784 return nil, errFakeConnSessionDirty 785 } 786 787 err := checkSubsetTypes(s.c.db.allowAny, args) 788 if err != nil { 789 return nil, err 790 } 791 s.touchMem() 792 793 if s.wait > 0 { 794 time.Sleep(s.wait) 795 } 796 797 select { 798 default: 799 case <-ctx.Done(): 800 return nil, ctx.Err() 801 } 802 803 db := s.c.db 804 switch s.cmd { 805 case "WIPE": 806 db.wipe() 807 return driver.ResultNoRows, nil 808 case "USE_RAWBYTES": 809 s.c.db.useRawBytes.Store(true) 810 return driver.ResultNoRows, nil 811 case "CREATE": 812 if err := db.createTable(s.table, s.colName, s.colType); err != nil { 813 return nil, err 814 } 815 return driver.ResultNoRows, nil 816 case "INSERT": 817 return s.execInsert(args, true) 818 case "NOSERT": 819 // Do all the prep-work like for an INSERT but don't actually insert the row. 820 // Used for some of the concurrent tests. 821 return s.execInsert(args, false) 822 } 823 return nil, fmt.Errorf("fakedb: unimplemented statement Exec command type of %q", s.cmd) 824 } 825 826 // When doInsert is true, add the row to the table. 827 // When doInsert is false do prep-work and error checking, but don't 828 // actually add the row to the table. 829 func (s *fakeStmt) execInsert(args []driver.NamedValue, doInsert bool) (driver.Result, error) { 830 db := s.c.db 831 if len(args) != s.placeholders { 832 panic("error in pkg db; should only get here if size is correct") 833 } 834 db.mu.Lock() 835 t, ok := db.table(s.table) 836 db.mu.Unlock() 837 if !ok { 838 return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table) 839 } 840 841 t.mu.Lock() 842 defer t.mu.Unlock() 843 844 var cols []any 845 if doInsert { 846 cols = make([]any, len(t.colname)) 847 } 848 argPos := 0 849 for n, colname := range s.colName { 850 colidx := t.columnIndex(colname) 851 if colidx == -1 { 852 return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname) 853 } 854 var val any 855 if strvalue, ok := s.colValue[n].(string); ok && strings.HasPrefix(strvalue, "?") { 856 if strvalue == "?" { 857 val = args[argPos].Value 858 } else { 859 // Assign value from argument placeholder name. 860 for _, a := range args { 861 if a.Name == strvalue[1:] { 862 val = a.Value 863 break 864 } 865 } 866 } 867 argPos++ 868 } else { 869 val = s.colValue[n] 870 } 871 if doInsert { 872 cols[colidx] = val 873 } 874 } 875 876 if doInsert { 877 t.rows = append(t.rows, &row{cols: cols}) 878 } 879 return driver.RowsAffected(1), nil 880 } 881 882 // hook to simulate broken connections 883 var hookQueryBadConn func() bool 884 885 func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) { 886 panic("Use QueryContext") 887 } 888 889 func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { 890 if s.panic == "Query" { 891 panic(s.panic) 892 } 893 if s.closed { 894 return nil, errClosed 895 } 896 897 if s.c.stickyBad || (hookQueryBadConn != nil && hookQueryBadConn()) { 898 return nil, fakeError{Message: "Query: Sticky Bad", Wrapped: driver.ErrBadConn} 899 } 900 if s.c.isDirtyAndMark() { 901 return nil, errFakeConnSessionDirty 902 } 903 904 err := checkSubsetTypes(s.c.db.allowAny, args) 905 if err != nil { 906 return nil, err 907 } 908 909 s.touchMem() 910 db := s.c.db 911 if len(args) != s.placeholders { 912 panic("error in pkg db; should only get here if size is correct") 913 } 914 915 setMRows := make([][]*row, 0, 1) 916 setColumns := make([][]string, 0, 1) 917 setColType := make([][]string, 0, 1) 918 919 for { 920 db.mu.Lock() 921 t, ok := db.table(s.table) 922 db.mu.Unlock() 923 if !ok { 924 return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table) 925 } 926 927 if s.table == "magicquery" { 928 if len(s.whereCol) == 2 && s.whereCol[0].Column == "op" && s.whereCol[1].Column == "millis" { 929 if args[0].Value == "sleep" { 930 time.Sleep(time.Duration(args[1].Value.(int64)) * time.Millisecond) 931 } 932 } 933 } 934 if s.table == "tx_status" && s.colName[0] == "tx_status" { 935 txStatus := "autocommit" 936 if s.c.currTx != nil { 937 txStatus = "transaction" 938 } 939 cursor := &rowsCursor{ 940 db: s.c.db, 941 parentMem: s.c, 942 posRow: -1, 943 rows: [][]*row{ 944 { 945 { 946 cols: []any{ 947 txStatus, 948 }, 949 }, 950 }, 951 }, 952 cols: [][]string{ 953 { 954 "tx_status", 955 }, 956 }, 957 colType: [][]string{ 958 { 959 "string", 960 }, 961 }, 962 errPos: -1, 963 } 964 return cursor, nil 965 } 966 967 t.mu.Lock() 968 969 colIdx := make(map[string]int) // select column name -> column index in table 970 for _, name := range s.colName { 971 idx := t.columnIndex(name) 972 if idx == -1 { 973 t.mu.Unlock() 974 return nil, fmt.Errorf("fakedb: unknown column name %q", name) 975 } 976 colIdx[name] = idx 977 } 978 979 mrows := []*row{} 980 rows: 981 for _, trow := range t.rows { 982 // Process the where clause, skipping non-match rows. This is lazy 983 // and just uses fmt.Sprintf("%v") to test equality. Good enough 984 // for test code. 985 for _, wcol := range s.whereCol { 986 idx := t.columnIndex(wcol.Column) 987 if idx == -1 { 988 t.mu.Unlock() 989 return nil, fmt.Errorf("fakedb: invalid where clause column %q", wcol) 990 } 991 tcol := trow.cols[idx] 992 if bs, ok := tcol.([]byte); ok { 993 // lazy hack to avoid sprintf %v on a []byte 994 tcol = string(bs) 995 } 996 var argValue any 997 if wcol.Placeholder == "?" { 998 argValue = args[wcol.Ordinal-1].Value 999 } else { 1000 // Assign arg value from placeholder name. 1001 for _, a := range args { 1002 if a.Name == wcol.Placeholder[1:] { 1003 argValue = a.Value 1004 break 1005 } 1006 } 1007 } 1008 if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) { 1009 continue rows 1010 } 1011 } 1012 mrow := &row{cols: make([]any, len(s.colName))} 1013 for seli, name := range s.colName { 1014 mrow.cols[seli] = trow.cols[colIdx[name]] 1015 } 1016 mrows = append(mrows, mrow) 1017 } 1018 1019 var colType []string 1020 for _, column := range s.colName { 1021 colType = append(colType, t.coltype[t.columnIndex(column)]) 1022 } 1023 1024 t.mu.Unlock() 1025 1026 setMRows = append(setMRows, mrows) 1027 setColumns = append(setColumns, s.colName) 1028 setColType = append(setColType, colType) 1029 1030 if s.next == nil { 1031 break 1032 } 1033 s = s.next 1034 } 1035 1036 cursor := &rowsCursor{ 1037 db: s.c.db, 1038 parentMem: s.c, 1039 posRow: -1, 1040 rows: setMRows, 1041 cols: setColumns, 1042 colType: setColType, 1043 errPos: -1, 1044 } 1045 return cursor, nil 1046 } 1047 1048 func (s *fakeStmt) NumInput() int { 1049 if s.panic == "NumInput" { 1050 panic(s.panic) 1051 } 1052 return s.placeholders 1053 } 1054 1055 // hook to simulate broken connections 1056 var hookCommitBadConn func() bool 1057 1058 func (tx *fakeTx) Commit() error { 1059 tx.c.currTx = nil 1060 if hookCommitBadConn != nil && hookCommitBadConn() { 1061 return fakeError{Message: "Commit: Hook Bad Conn", Wrapped: driver.ErrBadConn} 1062 } 1063 tx.c.touchMem() 1064 return nil 1065 } 1066 1067 // hook to simulate broken connections 1068 var hookRollbackBadConn func() bool 1069 1070 func (tx *fakeTx) Rollback() error { 1071 tx.c.currTx = nil 1072 if hookRollbackBadConn != nil && hookRollbackBadConn() { 1073 return fakeError{Message: "Rollback: Hook Bad Conn", Wrapped: driver.ErrBadConn} 1074 } 1075 tx.c.touchMem() 1076 return nil 1077 } 1078 1079 type rowsCursor struct { 1080 db *fakeDB 1081 parentMem memToucher 1082 cols [][]string 1083 colType [][]string 1084 posSet int 1085 posRow int 1086 rows [][]*row 1087 closed bool 1088 1089 // errPos and err are for making Next return early with error. 1090 errPos int 1091 err error 1092 1093 // a clone of slices to give out to clients, indexed by the 1094 // original slice's first byte address. we clone them 1095 // just so we're able to corrupt them on close. 1096 bytesClone map[*byte][]byte 1097 1098 // Every operation writes to line to enable the race detector 1099 // check for data races. 1100 // This is separate from the fakeConn.line to allow for drivers that 1101 // can start multiple queries on the same transaction at the same time. 1102 line int64 1103 1104 // closeErr is returned when rowsCursor.Close 1105 closeErr error 1106 } 1107 1108 func (rc *rowsCursor) touchMem() { 1109 rc.parentMem.touchMem() 1110 rc.line++ 1111 } 1112 1113 func (rc *rowsCursor) Close() error { 1114 rc.touchMem() 1115 rc.parentMem.touchMem() 1116 rc.closed = true 1117 return rc.closeErr 1118 } 1119 1120 func (rc *rowsCursor) Columns() []string { 1121 return rc.cols[rc.posSet] 1122 } 1123 1124 func (rc *rowsCursor) ColumnTypeScanType(index int) reflect.Type { 1125 return colTypeToReflectType(rc.colType[rc.posSet][index]) 1126 } 1127 1128 var rowsCursorNextHook func(dest []driver.Value) error 1129 1130 func (rc *rowsCursor) Next(dest []driver.Value) error { 1131 if rowsCursorNextHook != nil { 1132 return rowsCursorNextHook(dest) 1133 } 1134 1135 if rc.closed { 1136 return errors.New("fakedb: cursor is closed") 1137 } 1138 rc.touchMem() 1139 rc.posRow++ 1140 if rc.posRow == rc.errPos { 1141 return rc.err 1142 } 1143 if rc.posRow >= len(rc.rows[rc.posSet]) { 1144 return io.EOF // per interface spec 1145 } 1146 for i, v := range rc.rows[rc.posSet][rc.posRow].cols { 1147 // TODO(bradfitz): convert to subset types? naah, I 1148 // think the subset types should only be input to 1149 // driver, but the sql package should be able to handle 1150 // a wider range of types coming out of drivers. all 1151 // for ease of drivers, and to prevent drivers from 1152 // messing up conversions or doing them differently. 1153 dest[i] = v 1154 1155 if bs, ok := v.([]byte); ok && !rc.db.useRawBytes.Load() { 1156 if rc.bytesClone == nil { 1157 rc.bytesClone = make(map[*byte][]byte) 1158 } 1159 clone, ok := rc.bytesClone[&bs[0]] 1160 if !ok { 1161 clone = make([]byte, len(bs)) 1162 copy(clone, bs) 1163 rc.bytesClone[&bs[0]] = clone 1164 } 1165 dest[i] = clone 1166 } 1167 } 1168 return nil 1169 } 1170 1171 func (rc *rowsCursor) HasNextResultSet() bool { 1172 rc.touchMem() 1173 return rc.posSet < len(rc.rows)-1 1174 } 1175 1176 func (rc *rowsCursor) NextResultSet() error { 1177 rc.touchMem() 1178 if rc.HasNextResultSet() { 1179 rc.posSet++ 1180 rc.posRow = -1 1181 return nil 1182 } 1183 return io.EOF // Per interface spec. 1184 } 1185 1186 // fakeDriverString is like driver.String, but indirects pointers like 1187 // DefaultValueConverter. 1188 // 1189 // This could be surprising behavior to retroactively apply to 1190 // driver.String now that Go1 is out, but this is convenient for 1191 // our TestPointerParamsAndScans. 1192 type fakeDriverString struct{} 1193 1194 func (fakeDriverString) ConvertValue(v any) (driver.Value, error) { 1195 switch c := v.(type) { 1196 case string, []byte: 1197 return v, nil 1198 case *string: 1199 if c == nil { 1200 return nil, nil 1201 } 1202 return *c, nil 1203 } 1204 return fmt.Sprintf("%v", v), nil 1205 } 1206 1207 type anyTypeConverter struct{} 1208 1209 func (anyTypeConverter) ConvertValue(v any) (driver.Value, error) { 1210 return v, nil 1211 } 1212 1213 func converterForType(typ string) driver.ValueConverter { 1214 switch typ { 1215 case "bool": 1216 return driver.Bool 1217 case "nullbool": 1218 return driver.Null{Converter: driver.Bool} 1219 case "byte", "int16": 1220 return driver.NotNull{Converter: driver.DefaultParameterConverter} 1221 case "int32": 1222 return driver.Int32 1223 case "nullbyte", "nullint32", "nullint16": 1224 return driver.Null{Converter: driver.DefaultParameterConverter} 1225 case "string": 1226 return driver.NotNull{Converter: fakeDriverString{}} 1227 case "nullstring": 1228 return driver.Null{Converter: fakeDriverString{}} 1229 case "int64": 1230 // TODO(coopernurse): add type-specific converter 1231 return driver.NotNull{Converter: driver.DefaultParameterConverter} 1232 case "nullint64": 1233 // TODO(coopernurse): add type-specific converter 1234 return driver.Null{Converter: driver.DefaultParameterConverter} 1235 case "float64": 1236 // TODO(coopernurse): add type-specific converter 1237 return driver.NotNull{Converter: driver.DefaultParameterConverter} 1238 case "nullfloat64": 1239 // TODO(coopernurse): add type-specific converter 1240 return driver.Null{Converter: driver.DefaultParameterConverter} 1241 case "datetime": 1242 return driver.NotNull{Converter: driver.DefaultParameterConverter} 1243 case "nulldatetime": 1244 return driver.Null{Converter: driver.DefaultParameterConverter} 1245 case "any": 1246 return anyTypeConverter{} 1247 } 1248 panic("invalid fakedb column type of " + typ) 1249 } 1250 1251 func colTypeToReflectType(typ string) reflect.Type { 1252 switch typ { 1253 case "bool": 1254 return reflect.TypeFor[bool]() 1255 case "nullbool": 1256 return reflect.TypeFor[NullBool]() 1257 case "int16": 1258 return reflect.TypeFor[int16]() 1259 case "nullint16": 1260 return reflect.TypeFor[NullInt16]() 1261 case "int32": 1262 return reflect.TypeFor[int32]() 1263 case "nullint32": 1264 return reflect.TypeFor[NullInt32]() 1265 case "string": 1266 return reflect.TypeFor[string]() 1267 case "nullstring": 1268 return reflect.TypeFor[NullString]() 1269 case "int64": 1270 return reflect.TypeFor[int64]() 1271 case "nullint64": 1272 return reflect.TypeFor[NullInt64]() 1273 case "float64": 1274 return reflect.TypeFor[float64]() 1275 case "nullfloat64": 1276 return reflect.TypeFor[NullFloat64]() 1277 case "datetime": 1278 return reflect.TypeFor[time.Time]() 1279 case "any": 1280 return reflect.TypeFor[any]() 1281 } 1282 panic("invalid fakedb column type of " + typ) 1283 }