github.com/geraldss/go/src@v0.0.0-20210511222824-ac7d0ebfc235/database/sql/fakedb_test.go (about)

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package sql
     6  
     7  import (
     8  	"context"
     9  	"database/sql/driver"
    10  	"errors"
    11  	"fmt"
    12  	"io"
    13  	"reflect"
    14  	"sort"
    15  	"strconv"
    16  	"strings"
    17  	"sync"
    18  	"testing"
    19  	"time"
    20  )
    21  
    22  // fakeDriver is a fake database that implements Go's driver.Driver
    23  // interface, just for testing.
    24  //
    25  // It speaks a query language that's semantically similar to but
    26  // syntactically different and simpler than SQL.  The syntax is as
    27  // follows:
    28  //
    29  //   WIPE
    30  //   CREATE|<tablename>|<col>=<type>,<col>=<type>,...
    31  //     where types are: "string", [u]int{8,16,32,64}, "bool"
    32  //   INSERT|<tablename>|col=val,col2=val2,col3=?
    33  //   SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=?
    34  //   SELECT|<tablename>|projectcol1,projectcol2|filtercol=?param1,filtercol2=?param2
    35  //
    36  // Any of these can be preceded by PANIC|<method>|, to cause the
    37  // named method on fakeStmt to panic.
    38  //
    39  // Any of these can be proceeded by WAIT|<duration>|, to cause the
    40  // named method on fakeStmt to sleep for the specified duration.
    41  //
    42  // Multiple of these can be combined when separated with a semicolon.
    43  //
    44  // When opening a fakeDriver's database, it starts empty with no
    45  // tables. All tables and data are stored in memory only.
    46  type fakeDriver struct {
    47  	mu         sync.Mutex // guards 3 following fields
    48  	openCount  int        // conn opens
    49  	closeCount int        // conn closes
    50  	waitCh     chan struct{}
    51  	waitingCh  chan struct{}
    52  	dbs        map[string]*fakeDB
    53  }
    54  
    55  type fakeConnector struct {
    56  	name string
    57  
    58  	waiter func(context.Context)
    59  }
    60  
    61  func (c *fakeConnector) Connect(context.Context) (driver.Conn, error) {
    62  	conn, err := fdriver.Open(c.name)
    63  	conn.(*fakeConn).waiter = c.waiter
    64  	return conn, err
    65  }
    66  
    67  func (c *fakeConnector) Driver() driver.Driver {
    68  	return fdriver
    69  }
    70  
    71  type fakeDriverCtx struct {
    72  	fakeDriver
    73  }
    74  
    75  var _ driver.DriverContext = &fakeDriverCtx{}
    76  
    77  func (cc *fakeDriverCtx) OpenConnector(name string) (driver.Connector, error) {
    78  	return &fakeConnector{name: name}, nil
    79  }
    80  
    81  type fakeDB struct {
    82  	name string
    83  
    84  	mu       sync.Mutex
    85  	tables   map[string]*table
    86  	badConn  bool
    87  	allowAny bool
    88  }
    89  
    90  type table struct {
    91  	mu      sync.Mutex
    92  	colname []string
    93  	coltype []string
    94  	rows    []*row
    95  }
    96  
    97  func (t *table) columnIndex(name string) int {
    98  	for n, nname := range t.colname {
    99  		if name == nname {
   100  			return n
   101  		}
   102  	}
   103  	return -1
   104  }
   105  
   106  type row struct {
   107  	cols []interface{} // must be same size as its table colname + coltype
   108  }
   109  
   110  type memToucher interface {
   111  	// touchMem reads & writes some memory, to help find data races.
   112  	touchMem()
   113  }
   114  
   115  type fakeConn struct {
   116  	db *fakeDB // where to return ourselves to
   117  
   118  	currTx *fakeTx
   119  
   120  	// Every operation writes to line to enable the race detector
   121  	// check for data races.
   122  	line int64
   123  
   124  	// Stats for tests:
   125  	mu          sync.Mutex
   126  	stmtsMade   int
   127  	stmtsClosed int
   128  	numPrepare  int
   129  
   130  	// bad connection tests; see isBad()
   131  	bad       bool
   132  	stickyBad bool
   133  
   134  	skipDirtySession bool // tests that use Conn should set this to true.
   135  
   136  	// dirtySession tests ResetSession, true if a query has executed
   137  	// until ResetSession is called.
   138  	dirtySession bool
   139  
   140  	// The waiter is called before each query. May be used in place of the "WAIT"
   141  	// directive.
   142  	waiter func(context.Context)
   143  }
   144  
   145  func (c *fakeConn) touchMem() {
   146  	c.line++
   147  }
   148  
   149  func (c *fakeConn) incrStat(v *int) {
   150  	c.mu.Lock()
   151  	*v++
   152  	c.mu.Unlock()
   153  }
   154  
   155  type fakeTx struct {
   156  	c *fakeConn
   157  }
   158  
   159  type boundCol struct {
   160  	Column      string
   161  	Placeholder string
   162  	Ordinal     int
   163  }
   164  
   165  type fakeStmt struct {
   166  	memToucher
   167  	c *fakeConn
   168  	q string // just for debugging
   169  
   170  	cmd   string
   171  	table string
   172  	panic string
   173  	wait  time.Duration
   174  
   175  	next *fakeStmt // used for returning multiple results.
   176  
   177  	closed bool
   178  
   179  	colName      []string      // used by CREATE, INSERT, SELECT (selected columns)
   180  	colType      []string      // used by CREATE
   181  	colValue     []interface{} // used by INSERT (mix of strings and "?" for bound params)
   182  	placeholders int           // used by INSERT/SELECT: number of ? params
   183  
   184  	whereCol []boundCol // used by SELECT (all placeholders)
   185  
   186  	placeholderConverter []driver.ValueConverter // used by INSERT
   187  }
   188  
   189  var fdriver driver.Driver = &fakeDriver{}
   190  
   191  func init() {
   192  	Register("test", fdriver)
   193  }
   194  
   195  func contains(list []string, y string) bool {
   196  	for _, x := range list {
   197  		if x == y {
   198  			return true
   199  		}
   200  	}
   201  	return false
   202  }
   203  
   204  type Dummy struct {
   205  	driver.Driver
   206  }
   207  
   208  func TestDrivers(t *testing.T) {
   209  	unregisterAllDrivers()
   210  	Register("test", fdriver)
   211  	Register("invalid", Dummy{})
   212  	all := Drivers()
   213  	if len(all) < 2 || !sort.StringsAreSorted(all) || !contains(all, "test") || !contains(all, "invalid") {
   214  		t.Fatalf("Drivers = %v, want sorted list with at least [invalid, test]", all)
   215  	}
   216  }
   217  
   218  // hook to simulate connection failures
   219  var hookOpenErr struct {
   220  	sync.Mutex
   221  	fn func() error
   222  }
   223  
   224  func setHookOpenErr(fn func() error) {
   225  	hookOpenErr.Lock()
   226  	defer hookOpenErr.Unlock()
   227  	hookOpenErr.fn = fn
   228  }
   229  
   230  // Supports dsn forms:
   231  //    <dbname>
   232  //    <dbname>;<opts>  (only currently supported option is `badConn`,
   233  //                      which causes driver.ErrBadConn to be returned on
   234  //                      every other conn.Begin())
   235  func (d *fakeDriver) Open(dsn string) (driver.Conn, error) {
   236  	hookOpenErr.Lock()
   237  	fn := hookOpenErr.fn
   238  	hookOpenErr.Unlock()
   239  	if fn != nil {
   240  		if err := fn(); err != nil {
   241  			return nil, err
   242  		}
   243  	}
   244  	parts := strings.Split(dsn, ";")
   245  	if len(parts) < 1 {
   246  		return nil, errors.New("fakedb: no database name")
   247  	}
   248  	name := parts[0]
   249  
   250  	db := d.getDB(name)
   251  
   252  	d.mu.Lock()
   253  	d.openCount++
   254  	d.mu.Unlock()
   255  	conn := &fakeConn{db: db}
   256  
   257  	if len(parts) >= 2 && parts[1] == "badConn" {
   258  		conn.bad = true
   259  	}
   260  	if d.waitCh != nil {
   261  		d.waitingCh <- struct{}{}
   262  		<-d.waitCh
   263  		d.waitCh = nil
   264  		d.waitingCh = nil
   265  	}
   266  	return conn, nil
   267  }
   268  
   269  func (d *fakeDriver) getDB(name string) *fakeDB {
   270  	d.mu.Lock()
   271  	defer d.mu.Unlock()
   272  	if d.dbs == nil {
   273  		d.dbs = make(map[string]*fakeDB)
   274  	}
   275  	db, ok := d.dbs[name]
   276  	if !ok {
   277  		db = &fakeDB{name: name}
   278  		d.dbs[name] = db
   279  	}
   280  	return db
   281  }
   282  
   283  func (db *fakeDB) wipe() {
   284  	db.mu.Lock()
   285  	defer db.mu.Unlock()
   286  	db.tables = nil
   287  }
   288  
   289  func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error {
   290  	db.mu.Lock()
   291  	defer db.mu.Unlock()
   292  	if db.tables == nil {
   293  		db.tables = make(map[string]*table)
   294  	}
   295  	if _, exist := db.tables[name]; exist {
   296  		return fmt.Errorf("fakedb: table %q already exists", name)
   297  	}
   298  	if len(columnNames) != len(columnTypes) {
   299  		return fmt.Errorf("fakedb: create table of %q len(names) != len(types): %d vs %d",
   300  			name, len(columnNames), len(columnTypes))
   301  	}
   302  	db.tables[name] = &table{colname: columnNames, coltype: columnTypes}
   303  	return nil
   304  }
   305  
   306  // must be called with db.mu lock held
   307  func (db *fakeDB) table(table string) (*table, bool) {
   308  	if db.tables == nil {
   309  		return nil, false
   310  	}
   311  	t, ok := db.tables[table]
   312  	return t, ok
   313  }
   314  
   315  func (db *fakeDB) columnType(table, column string) (typ string, ok bool) {
   316  	db.mu.Lock()
   317  	defer db.mu.Unlock()
   318  	t, ok := db.table(table)
   319  	if !ok {
   320  		return
   321  	}
   322  	for n, cname := range t.colname {
   323  		if cname == column {
   324  			return t.coltype[n], true
   325  		}
   326  	}
   327  	return "", false
   328  }
   329  
   330  func (c *fakeConn) isBad() bool {
   331  	if c.stickyBad {
   332  		return true
   333  	} else if c.bad {
   334  		if c.db == nil {
   335  			return false
   336  		}
   337  		// alternate between bad conn and not bad conn
   338  		c.db.badConn = !c.db.badConn
   339  		return c.db.badConn
   340  	} else {
   341  		return false
   342  	}
   343  }
   344  
   345  func (c *fakeConn) isDirtyAndMark() bool {
   346  	if c.skipDirtySession {
   347  		return false
   348  	}
   349  	if c.currTx != nil {
   350  		c.dirtySession = true
   351  		return false
   352  	}
   353  	if c.dirtySession {
   354  		return true
   355  	}
   356  	c.dirtySession = true
   357  	return false
   358  }
   359  
   360  func (c *fakeConn) Begin() (driver.Tx, error) {
   361  	if c.isBad() {
   362  		return nil, driver.ErrBadConn
   363  	}
   364  	if c.currTx != nil {
   365  		return nil, errors.New("fakedb: already in a transaction")
   366  	}
   367  	c.touchMem()
   368  	c.currTx = &fakeTx{c: c}
   369  	return c.currTx, nil
   370  }
   371  
   372  var hookPostCloseConn struct {
   373  	sync.Mutex
   374  	fn func(*fakeConn, error)
   375  }
   376  
   377  func setHookpostCloseConn(fn func(*fakeConn, error)) {
   378  	hookPostCloseConn.Lock()
   379  	defer hookPostCloseConn.Unlock()
   380  	hookPostCloseConn.fn = fn
   381  }
   382  
   383  var testStrictClose *testing.T
   384  
   385  // setStrictFakeConnClose sets the t to Errorf on when fakeConn.Close
   386  // fails to close. If nil, the check is disabled.
   387  func setStrictFakeConnClose(t *testing.T) {
   388  	testStrictClose = t
   389  }
   390  
   391  func (c *fakeConn) ResetSession(ctx context.Context) error {
   392  	c.dirtySession = false
   393  	c.currTx = nil
   394  	if c.isBad() {
   395  		return driver.ErrBadConn
   396  	}
   397  	return nil
   398  }
   399  
   400  var _ driver.Validator = (*fakeConn)(nil)
   401  
   402  func (c *fakeConn) IsValid() bool {
   403  	return !c.isBad()
   404  }
   405  
   406  func (c *fakeConn) Close() (err error) {
   407  	drv := fdriver.(*fakeDriver)
   408  	defer func() {
   409  		if err != nil && testStrictClose != nil {
   410  			testStrictClose.Errorf("failed to close a test fakeConn: %v", err)
   411  		}
   412  		hookPostCloseConn.Lock()
   413  		fn := hookPostCloseConn.fn
   414  		hookPostCloseConn.Unlock()
   415  		if fn != nil {
   416  			fn(c, err)
   417  		}
   418  		if err == nil {
   419  			drv.mu.Lock()
   420  			drv.closeCount++
   421  			drv.mu.Unlock()
   422  		}
   423  	}()
   424  	c.touchMem()
   425  	if c.currTx != nil {
   426  		return errors.New("fakedb: can't close fakeConn; in a Transaction")
   427  	}
   428  	if c.db == nil {
   429  		return errors.New("fakedb: can't close fakeConn; already closed")
   430  	}
   431  	if c.stmtsMade > c.stmtsClosed {
   432  		return errors.New("fakedb: can't close; dangling statement(s)")
   433  	}
   434  	c.db = nil
   435  	return nil
   436  }
   437  
   438  func checkSubsetTypes(allowAny bool, args []driver.NamedValue) error {
   439  	for _, arg := range args {
   440  		switch arg.Value.(type) {
   441  		case int64, float64, bool, nil, []byte, string, time.Time:
   442  		default:
   443  			if !allowAny {
   444  				return fmt.Errorf("fakedb: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value)
   445  			}
   446  		}
   447  	}
   448  	return nil
   449  }
   450  
   451  func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) {
   452  	// Ensure that ExecContext is called if available.
   453  	panic("ExecContext was not called.")
   454  }
   455  
   456  func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
   457  	// This is an optional interface, but it's implemented here
   458  	// just to check that all the args are of the proper types.
   459  	// ErrSkip is returned so the caller acts as if we didn't
   460  	// implement this at all.
   461  	err := checkSubsetTypes(c.db.allowAny, args)
   462  	if err != nil {
   463  		return nil, err
   464  	}
   465  	return nil, driver.ErrSkip
   466  }
   467  
   468  func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) {
   469  	// Ensure that ExecContext is called if available.
   470  	panic("QueryContext was not called.")
   471  }
   472  
   473  func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
   474  	// This is an optional interface, but it's implemented here
   475  	// just to check that all the args are of the proper types.
   476  	// ErrSkip is returned so the caller acts as if we didn't
   477  	// implement this at all.
   478  	err := checkSubsetTypes(c.db.allowAny, args)
   479  	if err != nil {
   480  		return nil, err
   481  	}
   482  	return nil, driver.ErrSkip
   483  }
   484  
   485  func errf(msg string, args ...interface{}) error {
   486  	return errors.New("fakedb: " + fmt.Sprintf(msg, args...))
   487  }
   488  
   489  // parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=?
   490  // (note that where columns must always contain ? marks,
   491  //  just a limitation for fakedb)
   492  func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
   493  	if len(parts) != 3 {
   494  		stmt.Close()
   495  		return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
   496  	}
   497  	stmt.table = parts[0]
   498  
   499  	stmt.colName = strings.Split(parts[1], ",")
   500  	for n, colspec := range strings.Split(parts[2], ",") {
   501  		if colspec == "" {
   502  			continue
   503  		}
   504  		nameVal := strings.Split(colspec, "=")
   505  		if len(nameVal) != 2 {
   506  			stmt.Close()
   507  			return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
   508  		}
   509  		column, value := nameVal[0], nameVal[1]
   510  		_, ok := c.db.columnType(stmt.table, column)
   511  		if !ok {
   512  			stmt.Close()
   513  			return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column)
   514  		}
   515  		if !strings.HasPrefix(value, "?") {
   516  			stmt.Close()
   517  			return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark",
   518  				stmt.table, column)
   519  		}
   520  		stmt.placeholders++
   521  		stmt.whereCol = append(stmt.whereCol, boundCol{Column: column, Placeholder: value, Ordinal: stmt.placeholders})
   522  	}
   523  	return stmt, nil
   524  }
   525  
   526  // parts are table|col=type,col2=type2
   527  func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
   528  	if len(parts) != 2 {
   529  		stmt.Close()
   530  		return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts))
   531  	}
   532  	stmt.table = parts[0]
   533  	for n, colspec := range strings.Split(parts[1], ",") {
   534  		nameType := strings.Split(colspec, "=")
   535  		if len(nameType) != 2 {
   536  			stmt.Close()
   537  			return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
   538  		}
   539  		stmt.colName = append(stmt.colName, nameType[0])
   540  		stmt.colType = append(stmt.colType, nameType[1])
   541  	}
   542  	return stmt, nil
   543  }
   544  
   545  // parts are table|col=?,col2=val
   546  func (c *fakeConn) prepareInsert(ctx context.Context, stmt *fakeStmt, parts []string) (*fakeStmt, error) {
   547  	if len(parts) != 2 {
   548  		stmt.Close()
   549  		return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
   550  	}
   551  	stmt.table = parts[0]
   552  	for n, colspec := range strings.Split(parts[1], ",") {
   553  		nameVal := strings.Split(colspec, "=")
   554  		if len(nameVal) != 2 {
   555  			stmt.Close()
   556  			return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
   557  		}
   558  		column, value := nameVal[0], nameVal[1]
   559  		ctype, ok := c.db.columnType(stmt.table, column)
   560  		if !ok {
   561  			stmt.Close()
   562  			return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column)
   563  		}
   564  		stmt.colName = append(stmt.colName, column)
   565  
   566  		if !strings.HasPrefix(value, "?") {
   567  			var subsetVal interface{}
   568  			// Convert to driver subset type
   569  			switch ctype {
   570  			case "string":
   571  				subsetVal = []byte(value)
   572  			case "blob":
   573  				subsetVal = []byte(value)
   574  			case "int32":
   575  				i, err := strconv.Atoi(value)
   576  				if err != nil {
   577  					stmt.Close()
   578  					return nil, errf("invalid conversion to int32 from %q", value)
   579  				}
   580  				subsetVal = int64(i) // int64 is a subset type, but not int32
   581  			case "table": // For testing cursor reads.
   582  				c.skipDirtySession = true
   583  				vparts := strings.Split(value, "!")
   584  
   585  				substmt, err := c.PrepareContext(ctx, fmt.Sprintf("SELECT|%s|%s|", vparts[0], strings.Join(vparts[1:], ",")))
   586  				if err != nil {
   587  					return nil, err
   588  				}
   589  				cursor, err := (substmt.(driver.StmtQueryContext)).QueryContext(ctx, []driver.NamedValue{})
   590  				substmt.Close()
   591  				if err != nil {
   592  					return nil, err
   593  				}
   594  				subsetVal = cursor
   595  			default:
   596  				stmt.Close()
   597  				return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype)
   598  			}
   599  			stmt.colValue = append(stmt.colValue, subsetVal)
   600  		} else {
   601  			stmt.placeholders++
   602  			stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype))
   603  			stmt.colValue = append(stmt.colValue, value)
   604  		}
   605  	}
   606  	return stmt, nil
   607  }
   608  
   609  // hook to simulate broken connections
   610  var hookPrepareBadConn func() bool
   611  
   612  func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
   613  	panic("use PrepareContext")
   614  }
   615  
   616  func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
   617  	c.numPrepare++
   618  	if c.db == nil {
   619  		panic("nil c.db; conn = " + fmt.Sprintf("%#v", c))
   620  	}
   621  
   622  	if c.stickyBad || (hookPrepareBadConn != nil && hookPrepareBadConn()) {
   623  		return nil, driver.ErrBadConn
   624  	}
   625  
   626  	c.touchMem()
   627  	var firstStmt, prev *fakeStmt
   628  	for _, query := range strings.Split(query, ";") {
   629  		parts := strings.Split(query, "|")
   630  		if len(parts) < 1 {
   631  			return nil, errf("empty query")
   632  		}
   633  		stmt := &fakeStmt{q: query, c: c, memToucher: c}
   634  		if firstStmt == nil {
   635  			firstStmt = stmt
   636  		}
   637  		if len(parts) >= 3 {
   638  			switch parts[0] {
   639  			case "PANIC":
   640  				stmt.panic = parts[1]
   641  				parts = parts[2:]
   642  			case "WAIT":
   643  				wait, err := time.ParseDuration(parts[1])
   644  				if err != nil {
   645  					return nil, errf("expected section after WAIT to be a duration, got %q %v", parts[1], err)
   646  				}
   647  				parts = parts[2:]
   648  				stmt.wait = wait
   649  			}
   650  		}
   651  		cmd := parts[0]
   652  		stmt.cmd = cmd
   653  		parts = parts[1:]
   654  
   655  		if c.waiter != nil {
   656  			c.waiter(ctx)
   657  		}
   658  
   659  		if stmt.wait > 0 {
   660  			wait := time.NewTimer(stmt.wait)
   661  			select {
   662  			case <-wait.C:
   663  			case <-ctx.Done():
   664  				wait.Stop()
   665  				return nil, ctx.Err()
   666  			}
   667  		}
   668  
   669  		c.incrStat(&c.stmtsMade)
   670  		var err error
   671  		switch cmd {
   672  		case "WIPE":
   673  			// Nothing
   674  		case "SELECT":
   675  			stmt, err = c.prepareSelect(stmt, parts)
   676  		case "CREATE":
   677  			stmt, err = c.prepareCreate(stmt, parts)
   678  		case "INSERT":
   679  			stmt, err = c.prepareInsert(ctx, stmt, parts)
   680  		case "NOSERT":
   681  			// Do all the prep-work like for an INSERT but don't actually insert the row.
   682  			// Used for some of the concurrent tests.
   683  			stmt, err = c.prepareInsert(ctx, stmt, parts)
   684  		default:
   685  			stmt.Close()
   686  			return nil, errf("unsupported command type %q", cmd)
   687  		}
   688  		if err != nil {
   689  			return nil, err
   690  		}
   691  		if prev != nil {
   692  			prev.next = stmt
   693  		}
   694  		prev = stmt
   695  	}
   696  	return firstStmt, nil
   697  }
   698  
   699  func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
   700  	if s.panic == "ColumnConverter" {
   701  		panic(s.panic)
   702  	}
   703  	if len(s.placeholderConverter) == 0 {
   704  		return driver.DefaultParameterConverter
   705  	}
   706  	return s.placeholderConverter[idx]
   707  }
   708  
   709  func (s *fakeStmt) Close() error {
   710  	if s.panic == "Close" {
   711  		panic(s.panic)
   712  	}
   713  	if s.c == nil {
   714  		panic("nil conn in fakeStmt.Close")
   715  	}
   716  	if s.c.db == nil {
   717  		panic("in fakeStmt.Close, conn's db is nil (already closed)")
   718  	}
   719  	s.touchMem()
   720  	if !s.closed {
   721  		s.c.incrStat(&s.c.stmtsClosed)
   722  		s.closed = true
   723  	}
   724  	if s.next != nil {
   725  		s.next.Close()
   726  	}
   727  	return nil
   728  }
   729  
   730  var errClosed = errors.New("fakedb: statement has been closed")
   731  
   732  // hook to simulate broken connections
   733  var hookExecBadConn func() bool
   734  
   735  func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) {
   736  	panic("Using ExecContext")
   737  }
   738  
   739  var errFakeConnSessionDirty = errors.New("fakedb: session is dirty")
   740  
   741  func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
   742  	if s.panic == "Exec" {
   743  		panic(s.panic)
   744  	}
   745  	if s.closed {
   746  		return nil, errClosed
   747  	}
   748  
   749  	if s.c.stickyBad || (hookExecBadConn != nil && hookExecBadConn()) {
   750  		return nil, driver.ErrBadConn
   751  	}
   752  	if s.c.isDirtyAndMark() {
   753  		return nil, errFakeConnSessionDirty
   754  	}
   755  
   756  	err := checkSubsetTypes(s.c.db.allowAny, args)
   757  	if err != nil {
   758  		return nil, err
   759  	}
   760  	s.touchMem()
   761  
   762  	if s.wait > 0 {
   763  		time.Sleep(s.wait)
   764  	}
   765  
   766  	select {
   767  	default:
   768  	case <-ctx.Done():
   769  		return nil, ctx.Err()
   770  	}
   771  
   772  	db := s.c.db
   773  	switch s.cmd {
   774  	case "WIPE":
   775  		db.wipe()
   776  		return driver.ResultNoRows, nil
   777  	case "CREATE":
   778  		if err := db.createTable(s.table, s.colName, s.colType); err != nil {
   779  			return nil, err
   780  		}
   781  		return driver.ResultNoRows, nil
   782  	case "INSERT":
   783  		return s.execInsert(args, true)
   784  	case "NOSERT":
   785  		// Do all the prep-work like for an INSERT but don't actually insert the row.
   786  		// Used for some of the concurrent tests.
   787  		return s.execInsert(args, false)
   788  	}
   789  	return nil, fmt.Errorf("fakedb: unimplemented statement Exec command type of %q", s.cmd)
   790  }
   791  
   792  // When doInsert is true, add the row to the table.
   793  // When doInsert is false do prep-work and error checking, but don't
   794  // actually add the row to the table.
   795  func (s *fakeStmt) execInsert(args []driver.NamedValue, doInsert bool) (driver.Result, error) {
   796  	db := s.c.db
   797  	if len(args) != s.placeholders {
   798  		panic("error in pkg db; should only get here if size is correct")
   799  	}
   800  	db.mu.Lock()
   801  	t, ok := db.table(s.table)
   802  	db.mu.Unlock()
   803  	if !ok {
   804  		return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
   805  	}
   806  
   807  	t.mu.Lock()
   808  	defer t.mu.Unlock()
   809  
   810  	var cols []interface{}
   811  	if doInsert {
   812  		cols = make([]interface{}, len(t.colname))
   813  	}
   814  	argPos := 0
   815  	for n, colname := range s.colName {
   816  		colidx := t.columnIndex(colname)
   817  		if colidx == -1 {
   818  			return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname)
   819  		}
   820  		var val interface{}
   821  		if strvalue, ok := s.colValue[n].(string); ok && strings.HasPrefix(strvalue, "?") {
   822  			if strvalue == "?" {
   823  				val = args[argPos].Value
   824  			} else {
   825  				// Assign value from argument placeholder name.
   826  				for _, a := range args {
   827  					if a.Name == strvalue[1:] {
   828  						val = a.Value
   829  						break
   830  					}
   831  				}
   832  			}
   833  			argPos++
   834  		} else {
   835  			val = s.colValue[n]
   836  		}
   837  		if doInsert {
   838  			cols[colidx] = val
   839  		}
   840  	}
   841  
   842  	if doInsert {
   843  		t.rows = append(t.rows, &row{cols: cols})
   844  	}
   845  	return driver.RowsAffected(1), nil
   846  }
   847  
   848  // hook to simulate broken connections
   849  var hookQueryBadConn func() bool
   850  
   851  func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
   852  	panic("Use QueryContext")
   853  }
   854  
   855  func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
   856  	if s.panic == "Query" {
   857  		panic(s.panic)
   858  	}
   859  	if s.closed {
   860  		return nil, errClosed
   861  	}
   862  
   863  	if s.c.stickyBad || (hookQueryBadConn != nil && hookQueryBadConn()) {
   864  		return nil, driver.ErrBadConn
   865  	}
   866  	if s.c.isDirtyAndMark() {
   867  		return nil, errFakeConnSessionDirty
   868  	}
   869  
   870  	err := checkSubsetTypes(s.c.db.allowAny, args)
   871  	if err != nil {
   872  		return nil, err
   873  	}
   874  
   875  	s.touchMem()
   876  	db := s.c.db
   877  	if len(args) != s.placeholders {
   878  		panic("error in pkg db; should only get here if size is correct")
   879  	}
   880  
   881  	setMRows := make([][]*row, 0, 1)
   882  	setColumns := make([][]string, 0, 1)
   883  	setColType := make([][]string, 0, 1)
   884  
   885  	for {
   886  		db.mu.Lock()
   887  		t, ok := db.table(s.table)
   888  		db.mu.Unlock()
   889  		if !ok {
   890  			return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
   891  		}
   892  
   893  		if s.table == "magicquery" {
   894  			if len(s.whereCol) == 2 && s.whereCol[0].Column == "op" && s.whereCol[1].Column == "millis" {
   895  				if args[0].Value == "sleep" {
   896  					time.Sleep(time.Duration(args[1].Value.(int64)) * time.Millisecond)
   897  				}
   898  			}
   899  		}
   900  		if s.table == "tx_status" && s.colName[0] == "tx_status" {
   901  			txStatus := "autocommit"
   902  			if s.c.currTx != nil {
   903  				txStatus = "transaction"
   904  			}
   905  			cursor := &rowsCursor{
   906  				parentMem: s.c,
   907  				posRow:    -1,
   908  				rows: [][]*row{
   909  					[]*row{
   910  						{
   911  							cols: []interface{}{
   912  								txStatus,
   913  							},
   914  						},
   915  					},
   916  				},
   917  				cols: [][]string{
   918  					[]string{
   919  						"tx_status",
   920  					},
   921  				},
   922  				colType: [][]string{
   923  					[]string{
   924  						"string",
   925  					},
   926  				},
   927  				errPos: -1,
   928  			}
   929  			return cursor, nil
   930  		}
   931  
   932  		t.mu.Lock()
   933  
   934  		colIdx := make(map[string]int) // select column name -> column index in table
   935  		for _, name := range s.colName {
   936  			idx := t.columnIndex(name)
   937  			if idx == -1 {
   938  				t.mu.Unlock()
   939  				return nil, fmt.Errorf("fakedb: unknown column name %q", name)
   940  			}
   941  			colIdx[name] = idx
   942  		}
   943  
   944  		mrows := []*row{}
   945  	rows:
   946  		for _, trow := range t.rows {
   947  			// Process the where clause, skipping non-match rows. This is lazy
   948  			// and just uses fmt.Sprintf("%v") to test equality. Good enough
   949  			// for test code.
   950  			for _, wcol := range s.whereCol {
   951  				idx := t.columnIndex(wcol.Column)
   952  				if idx == -1 {
   953  					t.mu.Unlock()
   954  					return nil, fmt.Errorf("fakedb: invalid where clause column %q", wcol)
   955  				}
   956  				tcol := trow.cols[idx]
   957  				if bs, ok := tcol.([]byte); ok {
   958  					// lazy hack to avoid sprintf %v on a []byte
   959  					tcol = string(bs)
   960  				}
   961  				var argValue interface{}
   962  				if wcol.Placeholder == "?" {
   963  					argValue = args[wcol.Ordinal-1].Value
   964  				} else {
   965  					// Assign arg value from placeholder name.
   966  					for _, a := range args {
   967  						if a.Name == wcol.Placeholder[1:] {
   968  							argValue = a.Value
   969  							break
   970  						}
   971  					}
   972  				}
   973  				if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) {
   974  					continue rows
   975  				}
   976  			}
   977  			mrow := &row{cols: make([]interface{}, len(s.colName))}
   978  			for seli, name := range s.colName {
   979  				mrow.cols[seli] = trow.cols[colIdx[name]]
   980  			}
   981  			mrows = append(mrows, mrow)
   982  		}
   983  
   984  		var colType []string
   985  		for _, column := range s.colName {
   986  			colType = append(colType, t.coltype[t.columnIndex(column)])
   987  		}
   988  
   989  		t.mu.Unlock()
   990  
   991  		setMRows = append(setMRows, mrows)
   992  		setColumns = append(setColumns, s.colName)
   993  		setColType = append(setColType, colType)
   994  
   995  		if s.next == nil {
   996  			break
   997  		}
   998  		s = s.next
   999  	}
  1000  
  1001  	cursor := &rowsCursor{
  1002  		parentMem: s.c,
  1003  		posRow:    -1,
  1004  		rows:      setMRows,
  1005  		cols:      setColumns,
  1006  		colType:   setColType,
  1007  		errPos:    -1,
  1008  	}
  1009  	return cursor, nil
  1010  }
  1011  
  1012  func (s *fakeStmt) NumInput() int {
  1013  	if s.panic == "NumInput" {
  1014  		panic(s.panic)
  1015  	}
  1016  	return s.placeholders
  1017  }
  1018  
  1019  // hook to simulate broken connections
  1020  var hookCommitBadConn func() bool
  1021  
  1022  func (tx *fakeTx) Commit() error {
  1023  	tx.c.currTx = nil
  1024  	if hookCommitBadConn != nil && hookCommitBadConn() {
  1025  		return driver.ErrBadConn
  1026  	}
  1027  	tx.c.touchMem()
  1028  	return nil
  1029  }
  1030  
  1031  // hook to simulate broken connections
  1032  var hookRollbackBadConn func() bool
  1033  
  1034  func (tx *fakeTx) Rollback() error {
  1035  	tx.c.currTx = nil
  1036  	if hookRollbackBadConn != nil && hookRollbackBadConn() {
  1037  		return driver.ErrBadConn
  1038  	}
  1039  	tx.c.touchMem()
  1040  	return nil
  1041  }
  1042  
  1043  type rowsCursor struct {
  1044  	parentMem memToucher
  1045  	cols      [][]string
  1046  	colType   [][]string
  1047  	posSet    int
  1048  	posRow    int
  1049  	rows      [][]*row
  1050  	closed    bool
  1051  
  1052  	// errPos and err are for making Next return early with error.
  1053  	errPos int
  1054  	err    error
  1055  
  1056  	// a clone of slices to give out to clients, indexed by the
  1057  	// original slice's first byte address.  we clone them
  1058  	// just so we're able to corrupt them on close.
  1059  	bytesClone map[*byte][]byte
  1060  
  1061  	// Every operation writes to line to enable the race detector
  1062  	// check for data races.
  1063  	// This is separate from the fakeConn.line to allow for drivers that
  1064  	// can start multiple queries on the same transaction at the same time.
  1065  	line int64
  1066  }
  1067  
  1068  func (rc *rowsCursor) touchMem() {
  1069  	rc.parentMem.touchMem()
  1070  	rc.line++
  1071  }
  1072  
  1073  func (rc *rowsCursor) Close() error {
  1074  	rc.touchMem()
  1075  	rc.parentMem.touchMem()
  1076  	rc.closed = true
  1077  	return nil
  1078  }
  1079  
  1080  func (rc *rowsCursor) Columns() []string {
  1081  	return rc.cols[rc.posSet]
  1082  }
  1083  
  1084  func (rc *rowsCursor) ColumnTypeScanType(index int) reflect.Type {
  1085  	return colTypeToReflectType(rc.colType[rc.posSet][index])
  1086  }
  1087  
  1088  var rowsCursorNextHook func(dest []driver.Value) error
  1089  
  1090  func (rc *rowsCursor) Next(dest []driver.Value) error {
  1091  	if rowsCursorNextHook != nil {
  1092  		return rowsCursorNextHook(dest)
  1093  	}
  1094  
  1095  	if rc.closed {
  1096  		return errors.New("fakedb: cursor is closed")
  1097  	}
  1098  	rc.touchMem()
  1099  	rc.posRow++
  1100  	if rc.posRow == rc.errPos {
  1101  		return rc.err
  1102  	}
  1103  	if rc.posRow >= len(rc.rows[rc.posSet]) {
  1104  		return io.EOF // per interface spec
  1105  	}
  1106  	for i, v := range rc.rows[rc.posSet][rc.posRow].cols {
  1107  		// TODO(bradfitz): convert to subset types? naah, I
  1108  		// think the subset types should only be input to
  1109  		// driver, but the sql package should be able to handle
  1110  		// a wider range of types coming out of drivers. all
  1111  		// for ease of drivers, and to prevent drivers from
  1112  		// messing up conversions or doing them differently.
  1113  		dest[i] = v
  1114  
  1115  		if bs, ok := v.([]byte); ok {
  1116  			if rc.bytesClone == nil {
  1117  				rc.bytesClone = make(map[*byte][]byte)
  1118  			}
  1119  			clone, ok := rc.bytesClone[&bs[0]]
  1120  			if !ok {
  1121  				clone = make([]byte, len(bs))
  1122  				copy(clone, bs)
  1123  				rc.bytesClone[&bs[0]] = clone
  1124  			}
  1125  			dest[i] = clone
  1126  		}
  1127  	}
  1128  	return nil
  1129  }
  1130  
  1131  func (rc *rowsCursor) HasNextResultSet() bool {
  1132  	rc.touchMem()
  1133  	return rc.posSet < len(rc.rows)-1
  1134  }
  1135  
  1136  func (rc *rowsCursor) NextResultSet() error {
  1137  	rc.touchMem()
  1138  	if rc.HasNextResultSet() {
  1139  		rc.posSet++
  1140  		rc.posRow = -1
  1141  		return nil
  1142  	}
  1143  	return io.EOF // Per interface spec.
  1144  }
  1145  
  1146  // fakeDriverString is like driver.String, but indirects pointers like
  1147  // DefaultValueConverter.
  1148  //
  1149  // This could be surprising behavior to retroactively apply to
  1150  // driver.String now that Go1 is out, but this is convenient for
  1151  // our TestPointerParamsAndScans.
  1152  //
  1153  type fakeDriverString struct{}
  1154  
  1155  func (fakeDriverString) ConvertValue(v interface{}) (driver.Value, error) {
  1156  	switch c := v.(type) {
  1157  	case string, []byte:
  1158  		return v, nil
  1159  	case *string:
  1160  		if c == nil {
  1161  			return nil, nil
  1162  		}
  1163  		return *c, nil
  1164  	}
  1165  	return fmt.Sprintf("%v", v), nil
  1166  }
  1167  
  1168  type anyTypeConverter struct{}
  1169  
  1170  func (anyTypeConverter) ConvertValue(v interface{}) (driver.Value, error) {
  1171  	return v, nil
  1172  }
  1173  
  1174  func converterForType(typ string) driver.ValueConverter {
  1175  	switch typ {
  1176  	case "bool":
  1177  		return driver.Bool
  1178  	case "nullbool":
  1179  		return driver.Null{Converter: driver.Bool}
  1180  	case "int32":
  1181  		return driver.Int32
  1182  	case "nullint32":
  1183  		return driver.Null{Converter: driver.DefaultParameterConverter}
  1184  	case "string":
  1185  		return driver.NotNull{Converter: fakeDriverString{}}
  1186  	case "nullstring":
  1187  		return driver.Null{Converter: fakeDriverString{}}
  1188  	case "int64":
  1189  		// TODO(coopernurse): add type-specific converter
  1190  		return driver.NotNull{Converter: driver.DefaultParameterConverter}
  1191  	case "nullint64":
  1192  		// TODO(coopernurse): add type-specific converter
  1193  		return driver.Null{Converter: driver.DefaultParameterConverter}
  1194  	case "float64":
  1195  		// TODO(coopernurse): add type-specific converter
  1196  		return driver.NotNull{Converter: driver.DefaultParameterConverter}
  1197  	case "nullfloat64":
  1198  		// TODO(coopernurse): add type-specific converter
  1199  		return driver.Null{Converter: driver.DefaultParameterConverter}
  1200  	case "datetime":
  1201  		return driver.NotNull{Converter: driver.DefaultParameterConverter}
  1202  	case "nulldatetime":
  1203  		return driver.Null{Converter: driver.DefaultParameterConverter}
  1204  	case "any":
  1205  		return anyTypeConverter{}
  1206  	}
  1207  	panic("invalid fakedb column type of " + typ)
  1208  }
  1209  
  1210  func colTypeToReflectType(typ string) reflect.Type {
  1211  	switch typ {
  1212  	case "bool":
  1213  		return reflect.TypeOf(false)
  1214  	case "nullbool":
  1215  		return reflect.TypeOf(NullBool{})
  1216  	case "int32":
  1217  		return reflect.TypeOf(int32(0))
  1218  	case "nullint32":
  1219  		return reflect.TypeOf(NullInt32{})
  1220  	case "string":
  1221  		return reflect.TypeOf("")
  1222  	case "nullstring":
  1223  		return reflect.TypeOf(NullString{})
  1224  	case "int64":
  1225  		return reflect.TypeOf(int64(0))
  1226  	case "nullint64":
  1227  		return reflect.TypeOf(NullInt64{})
  1228  	case "float64":
  1229  		return reflect.TypeOf(float64(0))
  1230  	case "nullfloat64":
  1231  		return reflect.TypeOf(NullFloat64{})
  1232  	case "datetime":
  1233  		return reflect.TypeOf(time.Time{})
  1234  	case "any":
  1235  		return reflect.TypeOf(new(interface{})).Elem()
  1236  	}
  1237  	panic("invalid fakedb column type of " + typ)
  1238  }