github.com/rakyll/go@v0.0.0-20170216000551-64c02460d703/src/database/sql/fakedb_test.go (about)

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package sql
     6  
     7  import (
     8  	"context"
     9  	"database/sql/driver"
    10  	"errors"
    11  	"fmt"
    12  	"io"
    13  	"log"
    14  	"reflect"
    15  	"sort"
    16  	"strconv"
    17  	"strings"
    18  	"sync"
    19  	"testing"
    20  	"time"
    21  )
    22  
// Reference log.Printf so the "log" import stays in use; nothing else in
// this file calls into package log.
var _ = log.Printf
    24  
    25  // fakeDriver is a fake database that implements Go's driver.Driver
    26  // interface, just for testing.
    27  //
    28  // It speaks a query language that's semantically similar to but
    29  // syntactically different and simpler than SQL.  The syntax is as
    30  // follows:
    31  //
    32  //   WIPE
    33  //   CREATE|<tablename>|<col>=<type>,<col>=<type>,...
    34  //     where types are: "string", [u]int{8,16,32,64}, "bool"
    35  //   INSERT|<tablename>|col=val,col2=val2,col3=?
    36  //   SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=?
    37  //   SELECT|<tablename>|projectcol1,projectcol2|filtercol=?param1,filtercol2=?param2
    38  //
    39  // Any of these can be preceded by PANIC|<method>|, to cause the
    40  // named method on fakeStmt to panic.
    41  //
    42  // Any of these can be proceeded by WAIT|<duration>|, to cause the
    43  // named method on fakeStmt to sleep for the specified duration.
    44  //
    45  // Multiple of these can be combined when separated with a semicolon.
    46  //
    47  // When opening a fakeDriver's database, it starts empty with no
    48  // tables. All tables and data are stored in memory only.
    49  type fakeDriver struct {
    50  	mu         sync.Mutex // guards 3 following fields
    51  	openCount  int        // conn opens
    52  	closeCount int        // conn closes
    53  	waitCh     chan struct{}
    54  	waitingCh  chan struct{}
    55  	dbs        map[string]*fakeDB
    56  }
    57  
// fakeDB is one named in-memory database. Every fakeConn opened with the
// same DSN name shares the same fakeDB (see fakeDriver.getDB).
type fakeDB struct {
	name string

	mu      sync.Mutex // guards tables
	tables  map[string]*table
	badConn bool // flipped by fakeConn.isBad for conns opened with ";badConn"
}
    65  
    66  type table struct {
    67  	mu      sync.Mutex
    68  	colname []string
    69  	coltype []string
    70  	rows    []*row
    71  }
    72  
    73  func (t *table) columnIndex(name string) int {
    74  	for n, nname := range t.colname {
    75  		if name == nname {
    76  			return n
    77  		}
    78  	}
    79  	return -1
    80  }
    81  
    82  type row struct {
    83  	cols []interface{} // must be same size as its table colname + coltype
    84  }
    85  
// fakeConn is the driver.Conn handed out by fakeDriver.Open. It records
// statement counts so that Close can refuse to close while prepared
// statements are still open.
type fakeConn struct {
	db *fakeDB // where to return ourselves to; set to nil by a successful Close

	currTx *fakeTx // the open transaction, if any

	// Stats for tests:
	mu          sync.Mutex // guards the counters updated through incrStat
	stmtsMade   int
	stmtsClosed int
	numPrepare  int

	// bad connection tests; see isBad()
	bad       bool // set by the ";badConn" DSN option (see fakeDriver.Open)
	stickyBad bool // forces driver.ErrBadConn from Begin, PrepareContext and statement Exec/Query
}
   101  
   102  func (c *fakeConn) incrStat(v *int) {
   103  	c.mu.Lock()
   104  	*v++
   105  	c.mu.Unlock()
   106  }
   107  
// fakeTx is the driver.Tx returned by fakeConn.Begin. Commit and Rollback
// clear the owning connection's currTx.
type fakeTx struct {
	c *fakeConn
}
   111  
// boundCol is one column=placeholder term of a SELECT's where clause.
type boundCol struct {
	Column      string // table column to compare against
	Placeholder string // "?" for a positional argument, "?name" for a named one
	Ordinal     int    // 1-based position of this term among the where placeholders
}
   117  
// fakeStmt is a prepared fakedb statement. A query holding several
// ';'-separated statements is prepared as a chain linked through next.
type fakeStmt struct {
	c *fakeConn
	q string // just for debugging

	cmd   string        // WIPE, CREATE, INSERT, NOSERT or SELECT
	table string        // table named by the statement
	panic string        // method to panic in, from a PANIC|<method>| prefix
	wait  time.Duration // delay from a WAIT|<duration>| prefix

	next *fakeStmt // used for returning multiple results.

	closed bool // set once Close has counted this statement as closed

	colName      []string      // used by CREATE, INSERT, SELECT (selected columns)
	colType      []string      // used by CREATE
	colValue     []interface{} // used by INSERT (pre-bound subset values, or "?"/"?name" placeholder strings)
	placeholders int           // used by INSERT/SELECT: number of ? params

	whereCol []boundCol // used by SELECT (all placeholders)

	placeholderConverter []driver.ValueConverter // used by INSERT
}
   140  
// fdriver is the single fakeDriver instance, registered as "test" by init.
var fdriver driver.Driver = &fakeDriver{}
   142  
// init registers fdriver under the driver name "test".
func init() {
	Register("test", fdriver)
}
   146  
   147  func contains(list []string, y string) bool {
   148  	for _, x := range list {
   149  		if x == y {
   150  			return true
   151  		}
   152  	}
   153  	return false
   154  }
   155  
// Dummy wraps a driver.Driver; TestDrivers registers the zero Dummy{}
// (whose embedded Driver is nil) purely to check the Drivers listing.
type Dummy struct {
	driver.Driver
}
   159  
   160  func TestDrivers(t *testing.T) {
   161  	unregisterAllDrivers()
   162  	Register("test", fdriver)
   163  	Register("invalid", Dummy{})
   164  	all := Drivers()
   165  	if len(all) < 2 || !sort.StringsAreSorted(all) || !contains(all, "test") || !contains(all, "invalid") {
   166  		t.Fatalf("Drivers = %v, want sorted list with at least [invalid, test]", all)
   167  	}
   168  }
   169  
   170  // hook to simulate connection failures
   171  var hookOpenErr struct {
   172  	sync.Mutex
   173  	fn func() error
   174  }
   175  
   176  func setHookOpenErr(fn func() error) {
   177  	hookOpenErr.Lock()
   178  	defer hookOpenErr.Unlock()
   179  	hookOpenErr.fn = fn
   180  }
   181  
// Open returns a new fakeConn on the fakeDB named by dsn, creating that
// database the first time the name is seen.
//
// Supports dsn forms:
//    <dbname>
//    <dbname>;<opts>  (only currently supported option is `badConn`,
//                      which causes driver.ErrBadConn to be returned on
//                      every other conn.Begin())
//
// If hookOpenErr.fn is set and returns an error, Open returns that error
// before touching any state. If d.waitCh is set, Open signals on
// d.waitingCh, blocks until it can receive from waitCh, and then clears
// both channels.
func (d *fakeDriver) Open(dsn string) (driver.Conn, error) {
	hookOpenErr.Lock()
	fn := hookOpenErr.fn
	hookOpenErr.Unlock()
	if fn != nil {
		if err := fn(); err != nil {
			return nil, err
		}
	}
	parts := strings.Split(dsn, ";")
	// NOTE(review): strings.Split always returns at least one element, so
	// this never fires; an empty dsn opens the database named "".
	if len(parts) < 1 {
		return nil, errors.New("fakedb: no database name")
	}
	name := parts[0]

	db := d.getDB(name)

	d.mu.Lock()
	d.openCount++
	d.mu.Unlock()
	conn := &fakeConn{db: db}

	if len(parts) >= 2 && parts[1] == "badConn" {
		conn.bad = true
	}
	// Test rendezvous: announce that an Open is in flight and wait until
	// the test lets it finish. NOTE(review): waitCh/waitingCh are read and
	// cleared without d.mu.
	if d.waitCh != nil {
		d.waitingCh <- struct{}{}
		<-d.waitCh
		d.waitCh = nil
		d.waitingCh = nil
	}
	return conn, nil
}
   220  
   221  func (d *fakeDriver) getDB(name string) *fakeDB {
   222  	d.mu.Lock()
   223  	defer d.mu.Unlock()
   224  	if d.dbs == nil {
   225  		d.dbs = make(map[string]*fakeDB)
   226  	}
   227  	db, ok := d.dbs[name]
   228  	if !ok {
   229  		db = &fakeDB{name: name}
   230  		d.dbs[name] = db
   231  	}
   232  	return db
   233  }
   234  
   235  func (db *fakeDB) wipe() {
   236  	db.mu.Lock()
   237  	defer db.mu.Unlock()
   238  	db.tables = nil
   239  }
   240  
   241  func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error {
   242  	db.mu.Lock()
   243  	defer db.mu.Unlock()
   244  	if db.tables == nil {
   245  		db.tables = make(map[string]*table)
   246  	}
   247  	if _, exist := db.tables[name]; exist {
   248  		return fmt.Errorf("table %q already exists", name)
   249  	}
   250  	if len(columnNames) != len(columnTypes) {
   251  		return fmt.Errorf("create table of %q len(names) != len(types): %d vs %d",
   252  			name, len(columnNames), len(columnTypes))
   253  	}
   254  	db.tables[name] = &table{colname: columnNames, coltype: columnTypes}
   255  	return nil
   256  }
   257  
   258  // must be called with db.mu lock held
   259  func (db *fakeDB) table(table string) (*table, bool) {
   260  	if db.tables == nil {
   261  		return nil, false
   262  	}
   263  	t, ok := db.tables[table]
   264  	return t, ok
   265  }
   266  
   267  func (db *fakeDB) columnType(table, column string) (typ string, ok bool) {
   268  	db.mu.Lock()
   269  	defer db.mu.Unlock()
   270  	t, ok := db.table(table)
   271  	if !ok {
   272  		return
   273  	}
   274  	for n, cname := range t.colname {
   275  		if cname == column {
   276  			return t.coltype[n], true
   277  		}
   278  	}
   279  	return "", false
   280  }
   281  
   282  func (c *fakeConn) isBad() bool {
   283  	if c.stickyBad {
   284  		return true
   285  	} else if c.bad {
   286  		// alternate between bad conn and not bad conn
   287  		c.db.badConn = !c.db.badConn
   288  		return c.db.badConn
   289  	} else {
   290  		return false
   291  	}
   292  }
   293  
   294  func (c *fakeConn) Begin() (driver.Tx, error) {
   295  	if c.isBad() {
   296  		return nil, driver.ErrBadConn
   297  	}
   298  	if c.currTx != nil {
   299  		return nil, errors.New("already in a transaction")
   300  	}
   301  	c.currTx = &fakeTx{c: c}
   302  	return c.currTx, nil
   303  }
   304  
   305  var hookPostCloseConn struct {
   306  	sync.Mutex
   307  	fn func(*fakeConn, error)
   308  }
   309  
   310  func setHookpostCloseConn(fn func(*fakeConn, error)) {
   311  	hookPostCloseConn.Lock()
   312  	defer hookPostCloseConn.Unlock()
   313  	hookPostCloseConn.fn = fn
   314  }
   315  
// testStrictClose, when non-nil, receives an Errorf from fakeConn.Close
// whenever a close attempt fails.
var testStrictClose *testing.T

// setStrictFakeConnClose makes fakeConn.Close report close failures via
// t.Errorf. Passing nil disables the check.
func setStrictFakeConnClose(t *testing.T) {
	testStrictClose = t
}
   323  
// Close closes the connection. It fails if a transaction is still open,
// if the conn was already closed, or if more statements were made than
// closed.
//
// Whatever the outcome, the deferred function reports a failure to
// testStrictClose (if set) and calls hookPostCloseConn (if set). On
// success it increments the driver's closeCount. err is a named result so
// that the deferred function can observe it.
func (c *fakeConn) Close() (err error) {
	drv := fdriver.(*fakeDriver)
	defer func() {
		if err != nil && testStrictClose != nil {
			testStrictClose.Errorf("failed to close a test fakeConn: %v", err)
		}
		hookPostCloseConn.Lock()
		fn := hookPostCloseConn.fn
		hookPostCloseConn.Unlock()
		if fn != nil {
			fn(c, err)
		}
		if err == nil {
			drv.mu.Lock()
			drv.closeCount++
			drv.mu.Unlock()
		}
	}()
	if c.currTx != nil {
		return errors.New("can't close fakeConn; in a Transaction")
	}
	if c.db == nil {
		return errors.New("can't close fakeConn; already closed")
	}
	// NOTE(review): these counters are updated under c.mu by incrStat but
	// read here without it.
	if c.stmtsMade > c.stmtsClosed {
		return errors.New("can't close; dangling statement(s)")
	}
	c.db = nil
	return nil
}
   354  
   355  func checkSubsetTypes(args []driver.NamedValue) error {
   356  	for _, arg := range args {
   357  		switch arg.Value.(type) {
   358  		case int64, float64, bool, nil, []byte, string, time.Time:
   359  		default:
   360  			return fmt.Errorf("fakedb_test: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value)
   361  		}
   362  	}
   363  	return nil
   364  }
   365  
   366  func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) {
   367  	// Ensure that ExecContext is called if available.
   368  	panic("ExecContext was not called.")
   369  }
   370  
   371  func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
   372  	// This is an optional interface, but it's implemented here
   373  	// just to check that all the args are of the proper types.
   374  	// ErrSkip is returned so the caller acts as if we didn't
   375  	// implement this at all.
   376  	err := checkSubsetTypes(args)
   377  	if err != nil {
   378  		return nil, err
   379  	}
   380  	return nil, driver.ErrSkip
   381  }
   382  
   383  func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) {
   384  	// Ensure that ExecContext is called if available.
   385  	panic("QueryContext was not called.")
   386  }
   387  
   388  func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
   389  	// This is an optional interface, but it's implemented here
   390  	// just to check that all the args are of the proper types.
   391  	// ErrSkip is returned so the caller acts as if we didn't
   392  	// implement this at all.
   393  	err := checkSubsetTypes(args)
   394  	if err != nil {
   395  		return nil, err
   396  	}
   397  	return nil, driver.ErrSkip
   398  }
   399  
   400  func errf(msg string, args ...interface{}) error {
   401  	return errors.New("fakedb: " + fmt.Sprintf(msg, args...))
   402  }
   403  
   404  // parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=?
   405  // (note that where columns must always contain ? marks,
   406  //  just a limitation for fakedb)
   407  func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
   408  	if len(parts) != 3 {
   409  		stmt.Close()
   410  		return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
   411  	}
   412  	stmt.table = parts[0]
   413  
   414  	stmt.colName = strings.Split(parts[1], ",")
   415  	for n, colspec := range strings.Split(parts[2], ",") {
   416  		if colspec == "" {
   417  			continue
   418  		}
   419  		nameVal := strings.Split(colspec, "=")
   420  		if len(nameVal) != 2 {
   421  			stmt.Close()
   422  			return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
   423  		}
   424  		column, value := nameVal[0], nameVal[1]
   425  		_, ok := c.db.columnType(stmt.table, column)
   426  		if !ok {
   427  			stmt.Close()
   428  			return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column)
   429  		}
   430  		if !strings.HasPrefix(value, "?") {
   431  			stmt.Close()
   432  			return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark",
   433  				stmt.table, column)
   434  		}
   435  		stmt.placeholders++
   436  		stmt.whereCol = append(stmt.whereCol, boundCol{Column: column, Placeholder: value, Ordinal: stmt.placeholders})
   437  	}
   438  	return stmt, nil
   439  }
   440  
   441  // parts are table|col=type,col2=type2
   442  func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
   443  	if len(parts) != 2 {
   444  		stmt.Close()
   445  		return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts))
   446  	}
   447  	stmt.table = parts[0]
   448  	for n, colspec := range strings.Split(parts[1], ",") {
   449  		nameType := strings.Split(colspec, "=")
   450  		if len(nameType) != 2 {
   451  			stmt.Close()
   452  			return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
   453  		}
   454  		stmt.colName = append(stmt.colName, nameType[0])
   455  		stmt.colType = append(stmt.colType, nameType[1])
   456  	}
   457  	return stmt, nil
   458  }
   459  
   460  // parts are table|col=?,col2=val
   461  func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
   462  	if len(parts) != 2 {
   463  		stmt.Close()
   464  		return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
   465  	}
   466  	stmt.table = parts[0]
   467  	for n, colspec := range strings.Split(parts[1], ",") {
   468  		nameVal := strings.Split(colspec, "=")
   469  		if len(nameVal) != 2 {
   470  			stmt.Close()
   471  			return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
   472  		}
   473  		column, value := nameVal[0], nameVal[1]
   474  		ctype, ok := c.db.columnType(stmt.table, column)
   475  		if !ok {
   476  			stmt.Close()
   477  			return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column)
   478  		}
   479  		stmt.colName = append(stmt.colName, column)
   480  
   481  		if !strings.HasPrefix(value, "?") {
   482  			var subsetVal interface{}
   483  			// Convert to driver subset type
   484  			switch ctype {
   485  			case "string":
   486  				subsetVal = []byte(value)
   487  			case "blob":
   488  				subsetVal = []byte(value)
   489  			case "int32":
   490  				i, err := strconv.Atoi(value)
   491  				if err != nil {
   492  					stmt.Close()
   493  					return nil, errf("invalid conversion to int32 from %q", value)
   494  				}
   495  				subsetVal = int64(i) // int64 is a subset type, but not int32
   496  			default:
   497  				stmt.Close()
   498  				return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype)
   499  			}
   500  			stmt.colValue = append(stmt.colValue, subsetVal)
   501  		} else {
   502  			stmt.placeholders++
   503  			stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype))
   504  			stmt.colValue = append(stmt.colValue, value)
   505  		}
   506  	}
   507  	return stmt, nil
   508  }
   509  
// hook to simulate broken connections: when set and returning true,
// PrepareContext fails with driver.ErrBadConn.
var hookPrepareBadConn func() bool

// Prepare is never expected to run: fakeConn implements PrepareContext.
// It panics so that any caller skipping PrepareContext is caught.
func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
	panic("use PrepareContext")
}
   516  
   517  func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
   518  	c.numPrepare++
   519  	if c.db == nil {
   520  		panic("nil c.db; conn = " + fmt.Sprintf("%#v", c))
   521  	}
   522  
   523  	if c.stickyBad || (hookPrepareBadConn != nil && hookPrepareBadConn()) {
   524  		return nil, driver.ErrBadConn
   525  	}
   526  
   527  	var firstStmt, prev *fakeStmt
   528  	for _, query := range strings.Split(query, ";") {
   529  		parts := strings.Split(query, "|")
   530  		if len(parts) < 1 {
   531  			return nil, errf("empty query")
   532  		}
   533  		stmt := &fakeStmt{q: query, c: c}
   534  		if firstStmt == nil {
   535  			firstStmt = stmt
   536  		}
   537  		if len(parts) >= 3 {
   538  			switch parts[0] {
   539  			case "PANIC":
   540  				stmt.panic = parts[1]
   541  				parts = parts[2:]
   542  			case "WAIT":
   543  				wait, err := time.ParseDuration(parts[1])
   544  				if err != nil {
   545  					return nil, errf("expected section after WAIT to be a duration, got %q %v", parts[1], err)
   546  				}
   547  				parts = parts[2:]
   548  				stmt.wait = wait
   549  			}
   550  		}
   551  		cmd := parts[0]
   552  		stmt.cmd = cmd
   553  		parts = parts[1:]
   554  
   555  		if stmt.wait > 0 {
   556  			wait := time.NewTimer(stmt.wait)
   557  			select {
   558  			case <-wait.C:
   559  			case <-ctx.Done():
   560  				wait.Stop()
   561  				return nil, ctx.Err()
   562  			}
   563  		}
   564  
   565  		c.incrStat(&c.stmtsMade)
   566  		var err error
   567  		switch cmd {
   568  		case "WIPE":
   569  			// Nothing
   570  		case "SELECT":
   571  			stmt, err = c.prepareSelect(stmt, parts)
   572  		case "CREATE":
   573  			stmt, err = c.prepareCreate(stmt, parts)
   574  		case "INSERT":
   575  			stmt, err = c.prepareInsert(stmt, parts)
   576  		case "NOSERT":
   577  			// Do all the prep-work like for an INSERT but don't actually insert the row.
   578  			// Used for some of the concurrent tests.
   579  			stmt, err = c.prepareInsert(stmt, parts)
   580  		default:
   581  			stmt.Close()
   582  			return nil, errf("unsupported command type %q", cmd)
   583  		}
   584  		if err != nil {
   585  			return nil, err
   586  		}
   587  		if prev != nil {
   588  			prev.next = stmt
   589  		}
   590  		prev = stmt
   591  	}
   592  	return firstStmt, nil
   593  }
   594  
   595  func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
   596  	if s.panic == "ColumnConverter" {
   597  		panic(s.panic)
   598  	}
   599  	if len(s.placeholderConverter) == 0 {
   600  		return driver.DefaultParameterConverter
   601  	}
   602  	return s.placeholderConverter[idx]
   603  }
   604  
   605  func (s *fakeStmt) Close() error {
   606  	if s.panic == "Close" {
   607  		panic(s.panic)
   608  	}
   609  	if s.c == nil {
   610  		panic("nil conn in fakeStmt.Close")
   611  	}
   612  	if s.c.db == nil {
   613  		panic("in fakeStmt.Close, conn's db is nil (already closed)")
   614  	}
   615  	if !s.closed {
   616  		s.c.incrStat(&s.c.stmtsClosed)
   617  		s.closed = true
   618  	}
   619  	if s.next != nil {
   620  		s.next.Close()
   621  	}
   622  	return nil
   623  }
   624  
// errClosed is returned by ExecContext and QueryContext on a closed fakeStmt.
var errClosed = errors.New("fakedb: statement has been closed")

// hook to simulate broken connections: when set and returning true,
// fakeStmt.ExecContext fails with driver.ErrBadConn.
var hookExecBadConn func() bool

// Exec is never expected to run: fakeStmt implements ExecContext. It
// panics so that any caller skipping ExecContext is caught.
func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) {
	panic("Using ExecContext")
}
// ExecContext executes a WIPE, CREATE, INSERT or NOSERT statement.
//
// Before executing, it panics for PANIC|Exec|, returns errClosed for a
// closed statement, returns driver.ErrBadConn for a sticky-bad conn or
// when hookExecBadConn reports true, and rejects non-subset argument
// types. A WAIT duration is then slept in full before ctx is checked.
// Only this statement runs; chained statements (s.next) are ignored.
func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
	if s.panic == "Exec" {
		panic(s.panic)
	}
	if s.closed {
		return nil, errClosed
	}

	if s.c.stickyBad || (hookExecBadConn != nil && hookExecBadConn()) {
		return nil, driver.ErrBadConn
	}

	err := checkSubsetTypes(args)
	if err != nil {
		return nil, err
	}

	if s.wait > 0 {
		// NOTE(review): unlike the WAIT handling in PrepareContext, this
		// sleep does not end early when ctx is done.
		time.Sleep(s.wait)
	}

	// Non-blocking check: give up if ctx was canceled or timed out.
	select {
	default:
	case <-ctx.Done():
		return nil, ctx.Err()
	}

	db := s.c.db
	switch s.cmd {
	case "WIPE":
		db.wipe()
		return driver.ResultNoRows, nil
	case "CREATE":
		if err := db.createTable(s.table, s.colName, s.colType); err != nil {
			return nil, err
		}
		return driver.ResultNoRows, nil
	case "INSERT":
		return s.execInsert(args, true)
	case "NOSERT":
		// Do all the prep-work like for an INSERT but don't actually insert the row.
		// Used for some of the concurrent tests.
		return s.execInsert(args, false)
	}
	// Any other command (for example SELECT) cannot be executed.
	fmt.Printf("EXEC statement, cmd=%q: %#v\n", s.cmd, s)
	return nil, fmt.Errorf("unimplemented statement Exec command type of %q", s.cmd)
}
   680  
// When doInsert is true, add the row to the table.
// When doInsert is false do prep-work and error checking, but don't
// actually add the row to the table.
//
// execInsert builds a row from the statement's prepared columns.
// Pre-bound values are used as stored. A "?" placeholder takes the
// argument at the current positional index, and a "?name" placeholder
// takes the argument whose Name matches. It panics if len(args) differs
// from s.placeholders. It fails if the table or one of its columns no
// longer exists. The table's mutex is held while the row is built and
// appended.
func (s *fakeStmt) execInsert(args []driver.NamedValue, doInsert bool) (driver.Result, error) {
	db := s.c.db
	if len(args) != s.placeholders {
		panic("error in pkg db; should only get here if size is correct")
	}
	db.mu.Lock()
	t, ok := db.table(s.table)
	db.mu.Unlock()
	if !ok {
		return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
	}

	t.mu.Lock()
	defer t.mu.Unlock()

	var cols []interface{}
	if doInsert {
		cols = make([]interface{}, len(t.colname))
	}
	argPos := 0
	for n, colname := range s.colName {
		colidx := t.columnIndex(colname)
		if colidx == -1 {
			return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname)
		}
		var val interface{}
		if strvalue, ok := s.colValue[n].(string); ok && strings.HasPrefix(strvalue, "?") {
			if strvalue == "?" {
				val = args[argPos].Value
			} else {
				// Assign value from argument placeholder name.
				// NOTE(review): if no argument has this name, val stays nil
				// and nil is stored without an error.
				for _, a := range args {
					if a.Name == strvalue[1:] {
						val = a.Value
						break
					}
				}
			}
			// argPos advances for named placeholders as well, so a "?"
			// following a "?name" skips one argument position.
			argPos++
		} else {
			val = s.colValue[n]
		}
		if doInsert {
			cols[colidx] = val
		}
	}

	if doInsert {
		t.rows = append(t.rows, &row{cols: cols})
	}
	return driver.RowsAffected(1), nil
}
   736  
// hook to simulate broken connections: when set and returning true,
// fakeStmt.QueryContext fails with driver.ErrBadConn.
var hookQueryBadConn func() bool

// Query is never expected to run: fakeStmt implements QueryContext. It
// panics so that any caller skipping QueryContext is caught.
func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
	panic("Use QueryContext")
}
   743  
   744  func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
   745  	if s.panic == "Query" {
   746  		panic(s.panic)
   747  	}
   748  	if s.closed {
   749  		return nil, errClosed
   750  	}
   751  
   752  	if s.c.stickyBad || (hookQueryBadConn != nil && hookQueryBadConn()) {
   753  		return nil, driver.ErrBadConn
   754  	}
   755  
   756  	err := checkSubsetTypes(args)
   757  	if err != nil {
   758  		return nil, err
   759  	}
   760  
   761  	db := s.c.db
   762  	if len(args) != s.placeholders {
   763  		panic("error in pkg db; should only get here if size is correct")
   764  	}
   765  
   766  	setMRows := make([][]*row, 0, 1)
   767  	setColumns := make([][]string, 0, 1)
   768  	setColType := make([][]string, 0, 1)
   769  
   770  	for {
   771  		db.mu.Lock()
   772  		t, ok := db.table(s.table)
   773  		db.mu.Unlock()
   774  		if !ok {
   775  			return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
   776  		}
   777  
   778  		if s.table == "magicquery" {
   779  			if len(s.whereCol) == 2 && s.whereCol[0].Column == "op" && s.whereCol[1].Column == "millis" {
   780  				if args[0].Value == "sleep" {
   781  					time.Sleep(time.Duration(args[1].Value.(int64)) * time.Millisecond)
   782  				}
   783  			}
   784  		}
   785  
   786  		t.mu.Lock()
   787  
   788  		colIdx := make(map[string]int) // select column name -> column index in table
   789  		for _, name := range s.colName {
   790  			idx := t.columnIndex(name)
   791  			if idx == -1 {
   792  				t.mu.Unlock()
   793  				return nil, fmt.Errorf("fakedb: unknown column name %q", name)
   794  			}
   795  			colIdx[name] = idx
   796  		}
   797  
   798  		mrows := []*row{}
   799  	rows:
   800  		for _, trow := range t.rows {
   801  			// Process the where clause, skipping non-match rows. This is lazy
   802  			// and just uses fmt.Sprintf("%v") to test equality. Good enough
   803  			// for test code.
   804  			for _, wcol := range s.whereCol {
   805  				idx := t.columnIndex(wcol.Column)
   806  				if idx == -1 {
   807  					t.mu.Unlock()
   808  					return nil, fmt.Errorf("db: invalid where clause column %q", wcol)
   809  				}
   810  				tcol := trow.cols[idx]
   811  				if bs, ok := tcol.([]byte); ok {
   812  					// lazy hack to avoid sprintf %v on a []byte
   813  					tcol = string(bs)
   814  				}
   815  				var argValue interface{}
   816  				if wcol.Placeholder == "?" {
   817  					argValue = args[wcol.Ordinal-1].Value
   818  				} else {
   819  					// Assign arg value from placeholder name.
   820  					for _, a := range args {
   821  						if a.Name == wcol.Placeholder[1:] {
   822  							argValue = a.Value
   823  							break
   824  						}
   825  					}
   826  				}
   827  				if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) {
   828  					continue rows
   829  				}
   830  			}
   831  			mrow := &row{cols: make([]interface{}, len(s.colName))}
   832  			for seli, name := range s.colName {
   833  				mrow.cols[seli] = trow.cols[colIdx[name]]
   834  			}
   835  			mrows = append(mrows, mrow)
   836  		}
   837  
   838  		var colType []string
   839  		for _, column := range s.colName {
   840  			colType = append(colType, t.coltype[t.columnIndex(column)])
   841  		}
   842  
   843  		t.mu.Unlock()
   844  
   845  		setMRows = append(setMRows, mrows)
   846  		setColumns = append(setColumns, s.colName)
   847  		setColType = append(setColType, colType)
   848  
   849  		if s.next == nil {
   850  			break
   851  		}
   852  		s = s.next
   853  	}
   854  
   855  	cursor := &rowsCursor{
   856  		posRow:  -1,
   857  		rows:    setMRows,
   858  		cols:    setColumns,
   859  		colType: setColType,
   860  		errPos:  -1,
   861  	}
   862  	return cursor, nil
   863  }
   864  
   865  func (s *fakeStmt) NumInput() int {
   866  	if s.panic == "NumInput" {
   867  		panic(s.panic)
   868  	}
   869  	return s.placeholders
   870  }
   871  
   872  // hook to simulate broken connections
   873  var hookCommitBadConn func() bool
   874  
   875  func (tx *fakeTx) Commit() error {
   876  	tx.c.currTx = nil
   877  	if hookCommitBadConn != nil && hookCommitBadConn() {
   878  		return driver.ErrBadConn
   879  	}
   880  	return nil
   881  }
   882  
   883  // hook to simulate broken connections
   884  var hookRollbackBadConn func() bool
   885  
   886  func (tx *fakeTx) Rollback() error {
   887  	tx.c.currTx = nil
   888  	if hookRollbackBadConn != nil && hookRollbackBadConn() {
   889  		return driver.ErrBadConn
   890  	}
   891  	return nil
   892  }
   893  
// rowsCursor is the driver.Rows returned by fakeStmt.QueryContext. It
// holds one or more result sets: posSet selects the current set and
// posRow the current row within it (-1 before the first Next).
type rowsCursor struct {
	cols    [][]string // column names, one slice per result set
	colType [][]string // column types, one slice per result set
	posSet  int
	posRow  int
	rows    [][]*row
	closed  bool

	// errPos and err are for making Next return early with error.
	errPos int
	err    error

	// a clone of slices to give out to clients, indexed by the
	// original slice's first byte address.  we clone them
	// just so we're able to corrupt them on close.
	bytesClone map[*byte][]byte
}
   911  
   912  func (rc *rowsCursor) Close() error {
   913  	if !rc.closed {
   914  		for _, bs := range rc.bytesClone {
   915  			bs[0] = 255 // first byte corrupted
   916  		}
   917  	}
   918  	rc.closed = true
   919  	return nil
   920  }
   921  
   922  func (rc *rowsCursor) Columns() []string {
   923  	return rc.cols[rc.posSet]
   924  }
   925  
   926  func (rc *rowsCursor) ColumnTypeScanType(index int) reflect.Type {
   927  	return colTypeToReflectType(rc.colType[rc.posSet][index])
   928  }
   929  
   930  var rowsCursorNextHook func(dest []driver.Value) error
   931  
   932  func (rc *rowsCursor) Next(dest []driver.Value) error {
   933  	if rowsCursorNextHook != nil {
   934  		return rowsCursorNextHook(dest)
   935  	}
   936  
   937  	if rc.closed {
   938  		return errors.New("fakedb: cursor is closed")
   939  	}
   940  	rc.posRow++
   941  	if rc.posRow == rc.errPos {
   942  		return rc.err
   943  	}
   944  	if rc.posRow >= len(rc.rows[rc.posSet]) {
   945  		return io.EOF // per interface spec
   946  	}
   947  	for i, v := range rc.rows[rc.posSet][rc.posRow].cols {
   948  		// TODO(bradfitz): convert to subset types? naah, I
   949  		// think the subset types should only be input to
   950  		// driver, but the sql package should be able to handle
   951  		// a wider range of types coming out of drivers. all
   952  		// for ease of drivers, and to prevent drivers from
   953  		// messing up conversions or doing them differently.
   954  		dest[i] = v
   955  
   956  		if bs, ok := v.([]byte); ok {
   957  			if rc.bytesClone == nil {
   958  				rc.bytesClone = make(map[*byte][]byte)
   959  			}
   960  			clone, ok := rc.bytesClone[&bs[0]]
   961  			if !ok {
   962  				clone = make([]byte, len(bs))
   963  				copy(clone, bs)
   964  				rc.bytesClone[&bs[0]] = clone
   965  			}
   966  			dest[i] = clone
   967  		}
   968  	}
   969  	return nil
   970  }
   971  
   972  func (rc *rowsCursor) HasNextResultSet() bool {
   973  	return rc.posSet < len(rc.rows)-1
   974  }
   975  
   976  func (rc *rowsCursor) NextResultSet() error {
   977  	if rc.HasNextResultSet() {
   978  		rc.posSet++
   979  		rc.posRow = -1
   980  		return nil
   981  	}
   982  	return io.EOF // Per interface spec.
   983  }
   984  
   985  // fakeDriverString is like driver.String, but indirects pointers like
   986  // DefaultValueConverter.
   987  //
   988  // This could be surprising behavior to retroactively apply to
   989  // driver.String now that Go1 is out, but this is convenient for
   990  // our TestPointerParamsAndScans.
   991  //
   992  type fakeDriverString struct{}
   993  
   994  func (fakeDriverString) ConvertValue(v interface{}) (driver.Value, error) {
   995  	switch c := v.(type) {
   996  	case string, []byte:
   997  		return v, nil
   998  	case *string:
   999  		if c == nil {
  1000  			return nil, nil
  1001  		}
  1002  		return *c, nil
  1003  	}
  1004  	return fmt.Sprintf("%v", v), nil
  1005  }
  1006  
  1007  func converterForType(typ string) driver.ValueConverter {
  1008  	switch typ {
  1009  	case "bool":
  1010  		return driver.Bool
  1011  	case "nullbool":
  1012  		return driver.Null{Converter: driver.Bool}
  1013  	case "int32":
  1014  		return driver.Int32
  1015  	case "string":
  1016  		return driver.NotNull{Converter: fakeDriverString{}}
  1017  	case "nullstring":
  1018  		return driver.Null{Converter: fakeDriverString{}}
  1019  	case "int64":
  1020  		// TODO(coopernurse): add type-specific converter
  1021  		return driver.NotNull{Converter: driver.DefaultParameterConverter}
  1022  	case "nullint64":
  1023  		// TODO(coopernurse): add type-specific converter
  1024  		return driver.Null{Converter: driver.DefaultParameterConverter}
  1025  	case "float64":
  1026  		// TODO(coopernurse): add type-specific converter
  1027  		return driver.NotNull{Converter: driver.DefaultParameterConverter}
  1028  	case "nullfloat64":
  1029  		// TODO(coopernurse): add type-specific converter
  1030  		return driver.Null{Converter: driver.DefaultParameterConverter}
  1031  	case "datetime":
  1032  		return driver.DefaultParameterConverter
  1033  	}
  1034  	panic("invalid fakedb column type of " + typ)
  1035  }
  1036  
  1037  func colTypeToReflectType(typ string) reflect.Type {
  1038  	switch typ {
  1039  	case "bool":
  1040  		return reflect.TypeOf(false)
  1041  	case "nullbool":
  1042  		return reflect.TypeOf(NullBool{})
  1043  	case "int32":
  1044  		return reflect.TypeOf(int32(0))
  1045  	case "string":
  1046  		return reflect.TypeOf("")
  1047  	case "nullstring":
  1048  		return reflect.TypeOf(NullString{})
  1049  	case "int64":
  1050  		return reflect.TypeOf(int64(0))
  1051  	case "nullint64":
  1052  		return reflect.TypeOf(NullInt64{})
  1053  	case "float64":
  1054  		return reflect.TypeOf(float64(0))
  1055  	case "nullfloat64":
  1056  		return reflect.TypeOf(NullFloat64{})
  1057  	case "datetime":
  1058  		return reflect.TypeOf(time.Time{})
  1059  	}
  1060  	panic("invalid fakedb column type of " + typ)
  1061  }