github.com/dannin/go@v0.0.0-20161031215817-d35dfd405eaa/src/database/sql/fakedb_test.go (about)

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package sql
     6  
     7  import (
     8  	"context"
     9  	"database/sql/driver"
    10  	"errors"
    11  	"fmt"
    12  	"io"
    13  	"log"
    14  	"reflect"
    15  	"sort"
    16  	"strconv"
    17  	"strings"
    18  	"sync"
    19  	"testing"
    20  	"time"
    21  )
    22  
// Keep the log import referenced even when no log calls remain
// (presumably so ad-hoc debugging output can be added without editing imports).
var _ = log.Printf

// fakeDriver is a fake database that implements Go's driver.Driver
// interface, just for testing.
//
// It speaks a query language that's semantically similar to but
// syntactically different and simpler than SQL. The syntax is as
// follows:
//
//   WIPE
//   CREATE|<tablename>|<col>=<type>,<col>=<type>,...
//     where the types understood by converterForType are "bool",
//     "nullbool", "int32", "string", "nullstring", "int64",
//     "nullint64", "float64", "nullfloat64" and "datetime"
//   INSERT|<tablename>|col=val,col2=val2,col3=?
//     (pre-bound literal values are only supported for "string",
//     "blob" and "int32" columns)
//   NOSERT|<tablename>|col=val,col2=val2,col3=?
//     same preparation and checks as INSERT, but the row is not stored
//   SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=?
//   SELECT|<tablename>|projectcol1,projectcol2|filtercol=?param1,filtercol2=?param2
//
// Any of these can be preceded by PANIC|<method>|, to cause the
// named method on fakeStmt to panic.
//
// Multiple of these can be combined when separated with a semicolon.
//
// When opening a fakeDriver's database, it starts empty with no
// tables. All tables and data are stored in memory only.
type fakeDriver struct {
	mu         sync.Mutex // guards openCount, closeCount and dbs
	openCount  int        // conn opens
	closeCount int        // conn closes
	// If waitCh is non-nil, Open sends on waitingCh and then blocks until it
	// receives from waitCh; Open then resets both to nil.
	// NOTE(review): both are read and written without holding mu.
	waitCh     chan struct{}
	waitingCh  chan struct{}
	dbs        map[string]*fakeDB // databases by name, shared by all conns opened with that name
}
    54  
    55  type fakeDB struct {
    56  	name string
    57  
    58  	mu      sync.Mutex
    59  	tables  map[string]*table
    60  	badConn bool
    61  }
    62  
    63  type table struct {
    64  	mu      sync.Mutex
    65  	colname []string
    66  	coltype []string
    67  	rows    []*row
    68  }
    69  
    70  func (t *table) columnIndex(name string) int {
    71  	for n, nname := range t.colname {
    72  		if name == nname {
    73  			return n
    74  		}
    75  	}
    76  	return -1
    77  }
    78  
    79  type row struct {
    80  	cols []interface{} // must be same size as its table colname + coltype
    81  }
    82  
    83  type fakeConn struct {
    84  	db *fakeDB // where to return ourselves to
    85  
    86  	currTx *fakeTx
    87  
    88  	// Stats for tests:
    89  	mu          sync.Mutex
    90  	stmtsMade   int
    91  	stmtsClosed int
    92  	numPrepare  int
    93  
    94  	// bad connection tests; see isBad()
    95  	bad       bool
    96  	stickyBad bool
    97  }
    98  
    99  func (c *fakeConn) incrStat(v *int) {
   100  	c.mu.Lock()
   101  	*v++
   102  	c.mu.Unlock()
   103  }
   104  
   105  type fakeTx struct {
   106  	c *fakeConn
   107  }
   108  
   109  type boundCol struct {
   110  	Column      string
   111  	Placeholder string
   112  	Ordinal     int
   113  }
   114  
// fakeStmt is the driver.Stmt produced by fakeConn.Prepare. A query made of
// several ';'-separated statements is prepared as a chain of fakeStmts
// linked through next.
type fakeStmt struct {
	c *fakeConn
	q string // just for debugging

	cmd   string // WIPE, SELECT, CREATE, INSERT or NOSERT
	table string
	panic string // method name that should panic (from a PANIC|<method>| prefix)

	next *fakeStmt // used for returning multiple results.

	closed bool

	colName      []string      // used by CREATE, INSERT, SELECT (selected columns)
	colType      []string      // used by CREATE
	colValue     []interface{} // used by INSERT (pre-bound subset values, or placeholder strings starting with "?")
	placeholders int           // used by INSERT/SELECT: number of ? params

	whereCol []boundCol // used by SELECT (all placeholders)

	placeholderConverter []driver.ValueConverter // used by INSERT
}

// fdriver is the single fakeDriver instance, registered as "test".
var fdriver driver.Driver = &fakeDriver{}

func init() {
	Register("test", fdriver)
}
   142  
   143  func contains(list []string, y string) bool {
   144  	for _, x := range list {
   145  		if x == y {
   146  			return true
   147  		}
   148  	}
   149  	return false
   150  }
   151  
   152  type Dummy struct {
   153  	driver.Driver
   154  }
   155  
   156  func TestDrivers(t *testing.T) {
   157  	unregisterAllDrivers()
   158  	Register("test", fdriver)
   159  	Register("invalid", Dummy{})
   160  	all := Drivers()
   161  	if len(all) < 2 || !sort.StringsAreSorted(all) || !contains(all, "test") || !contains(all, "invalid") {
   162  		t.Fatalf("Drivers = %v, want sorted list with at least [invalid, test]", all)
   163  	}
   164  }
   165  
   166  // hook to simulate connection failures
   167  var hookOpenErr struct {
   168  	sync.Mutex
   169  	fn func() error
   170  }
   171  
   172  func setHookOpenErr(fn func() error) {
   173  	hookOpenErr.Lock()
   174  	defer hookOpenErr.Unlock()
   175  	hookOpenErr.fn = fn
   176  }
   177  
   178  // Supports dsn forms:
   179  //    <dbname>
   180  //    <dbname>;<opts>  (only currently supported option is `badConn`,
   181  //                      which causes driver.ErrBadConn to be returned on
   182  //                      every other conn.Begin())
   183  func (d *fakeDriver) Open(dsn string) (driver.Conn, error) {
   184  	hookOpenErr.Lock()
   185  	fn := hookOpenErr.fn
   186  	hookOpenErr.Unlock()
   187  	if fn != nil {
   188  		if err := fn(); err != nil {
   189  			return nil, err
   190  		}
   191  	}
   192  	parts := strings.Split(dsn, ";")
   193  	if len(parts) < 1 {
   194  		return nil, errors.New("fakedb: no database name")
   195  	}
   196  	name := parts[0]
   197  
   198  	db := d.getDB(name)
   199  
   200  	d.mu.Lock()
   201  	d.openCount++
   202  	d.mu.Unlock()
   203  	conn := &fakeConn{db: db}
   204  
   205  	if len(parts) >= 2 && parts[1] == "badConn" {
   206  		conn.bad = true
   207  	}
   208  	if d.waitCh != nil {
   209  		d.waitingCh <- struct{}{}
   210  		<-d.waitCh
   211  		d.waitCh = nil
   212  		d.waitingCh = nil
   213  	}
   214  	return conn, nil
   215  }
   216  
   217  func (d *fakeDriver) getDB(name string) *fakeDB {
   218  	d.mu.Lock()
   219  	defer d.mu.Unlock()
   220  	if d.dbs == nil {
   221  		d.dbs = make(map[string]*fakeDB)
   222  	}
   223  	db, ok := d.dbs[name]
   224  	if !ok {
   225  		db = &fakeDB{name: name}
   226  		d.dbs[name] = db
   227  	}
   228  	return db
   229  }
   230  
   231  func (db *fakeDB) wipe() {
   232  	db.mu.Lock()
   233  	defer db.mu.Unlock()
   234  	db.tables = nil
   235  }
   236  
   237  func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error {
   238  	db.mu.Lock()
   239  	defer db.mu.Unlock()
   240  	if db.tables == nil {
   241  		db.tables = make(map[string]*table)
   242  	}
   243  	if _, exist := db.tables[name]; exist {
   244  		return fmt.Errorf("table %q already exists", name)
   245  	}
   246  	if len(columnNames) != len(columnTypes) {
   247  		return fmt.Errorf("create table of %q len(names) != len(types): %d vs %d",
   248  			name, len(columnNames), len(columnTypes))
   249  	}
   250  	db.tables[name] = &table{colname: columnNames, coltype: columnTypes}
   251  	return nil
   252  }
   253  
   254  // must be called with db.mu lock held
   255  func (db *fakeDB) table(table string) (*table, bool) {
   256  	if db.tables == nil {
   257  		return nil, false
   258  	}
   259  	t, ok := db.tables[table]
   260  	return t, ok
   261  }
   262  
   263  func (db *fakeDB) columnType(table, column string) (typ string, ok bool) {
   264  	db.mu.Lock()
   265  	defer db.mu.Unlock()
   266  	t, ok := db.table(table)
   267  	if !ok {
   268  		return
   269  	}
   270  	for n, cname := range t.colname {
   271  		if cname == column {
   272  			return t.coltype[n], true
   273  		}
   274  	}
   275  	return "", false
   276  }
   277  
   278  func (c *fakeConn) isBad() bool {
   279  	if c.stickyBad {
   280  		return true
   281  	} else if c.bad {
   282  		// alternate between bad conn and not bad conn
   283  		c.db.badConn = !c.db.badConn
   284  		return c.db.badConn
   285  	} else {
   286  		return false
   287  	}
   288  }
   289  
   290  func (c *fakeConn) Begin() (driver.Tx, error) {
   291  	if c.isBad() {
   292  		return nil, driver.ErrBadConn
   293  	}
   294  	if c.currTx != nil {
   295  		return nil, errors.New("already in a transaction")
   296  	}
   297  	c.currTx = &fakeTx{c: c}
   298  	return c.currTx, nil
   299  }
   300  
   301  var hookPostCloseConn struct {
   302  	sync.Mutex
   303  	fn func(*fakeConn, error)
   304  }
   305  
   306  func setHookpostCloseConn(fn func(*fakeConn, error)) {
   307  	hookPostCloseConn.Lock()
   308  	defer hookPostCloseConn.Unlock()
   309  	hookPostCloseConn.fn = fn
   310  }
   311  
   312  var testStrictClose *testing.T
   313  
   314  // setStrictFakeConnClose sets the t to Errorf on when fakeConn.Close
   315  // fails to close. If nil, the check is disabled.
   316  func setStrictFakeConnClose(t *testing.T) {
   317  	testStrictClose = t
   318  }
   319  
   320  func (c *fakeConn) Close() (err error) {
   321  	drv := fdriver.(*fakeDriver)
   322  	defer func() {
   323  		if err != nil && testStrictClose != nil {
   324  			testStrictClose.Errorf("failed to close a test fakeConn: %v", err)
   325  		}
   326  		hookPostCloseConn.Lock()
   327  		fn := hookPostCloseConn.fn
   328  		hookPostCloseConn.Unlock()
   329  		if fn != nil {
   330  			fn(c, err)
   331  		}
   332  		if err == nil {
   333  			drv.mu.Lock()
   334  			drv.closeCount++
   335  			drv.mu.Unlock()
   336  		}
   337  	}()
   338  	if c.currTx != nil {
   339  		return errors.New("can't close fakeConn; in a Transaction")
   340  	}
   341  	if c.db == nil {
   342  		return errors.New("can't close fakeConn; already closed")
   343  	}
   344  	if c.stmtsMade > c.stmtsClosed {
   345  		return errors.New("can't close; dangling statement(s)")
   346  	}
   347  	c.db = nil
   348  	return nil
   349  }
   350  
   351  func checkSubsetTypes(args []driver.NamedValue) error {
   352  	for _, arg := range args {
   353  		switch arg.Value.(type) {
   354  		case int64, float64, bool, nil, []byte, string, time.Time:
   355  		default:
   356  			return fmt.Errorf("fakedb_test: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value)
   357  		}
   358  	}
   359  	return nil
   360  }
   361  
   362  func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) {
   363  	// Ensure that ExecContext is called if available.
   364  	panic("ExecContext was not called.")
   365  }
   366  
   367  func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
   368  	// This is an optional interface, but it's implemented here
   369  	// just to check that all the args are of the proper types.
   370  	// ErrSkip is returned so the caller acts as if we didn't
   371  	// implement this at all.
   372  	err := checkSubsetTypes(args)
   373  	if err != nil {
   374  		return nil, err
   375  	}
   376  	return nil, driver.ErrSkip
   377  }
   378  
   379  func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) {
   380  	// Ensure that ExecContext is called if available.
   381  	panic("QueryContext was not called.")
   382  }
   383  
   384  func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
   385  	// This is an optional interface, but it's implemented here
   386  	// just to check that all the args are of the proper types.
   387  	// ErrSkip is returned so the caller acts as if we didn't
   388  	// implement this at all.
   389  	err := checkSubsetTypes(args)
   390  	if err != nil {
   391  		return nil, err
   392  	}
   393  	return nil, driver.ErrSkip
   394  }
   395  
   396  func errf(msg string, args ...interface{}) error {
   397  	return errors.New("fakedb: " + fmt.Sprintf(msg, args...))
   398  }
   399  
   400  // parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=?
   401  // (note that where columns must always contain ? marks,
   402  //  just a limitation for fakedb)
   403  func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
   404  	if len(parts) != 3 {
   405  		stmt.Close()
   406  		return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
   407  	}
   408  	stmt.table = parts[0]
   409  
   410  	stmt.colName = strings.Split(parts[1], ",")
   411  	for n, colspec := range strings.Split(parts[2], ",") {
   412  		if colspec == "" {
   413  			continue
   414  		}
   415  		nameVal := strings.Split(colspec, "=")
   416  		if len(nameVal) != 2 {
   417  			stmt.Close()
   418  			return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
   419  		}
   420  		column, value := nameVal[0], nameVal[1]
   421  		_, ok := c.db.columnType(stmt.table, column)
   422  		if !ok {
   423  			stmt.Close()
   424  			return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column)
   425  		}
   426  		if !strings.HasPrefix(value, "?") {
   427  			stmt.Close()
   428  			return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark",
   429  				stmt.table, column)
   430  		}
   431  		stmt.placeholders++
   432  		stmt.whereCol = append(stmt.whereCol, boundCol{Column: column, Placeholder: value, Ordinal: stmt.placeholders})
   433  	}
   434  	return stmt, nil
   435  }
   436  
   437  // parts are table|col=type,col2=type2
   438  func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
   439  	if len(parts) != 2 {
   440  		stmt.Close()
   441  		return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts))
   442  	}
   443  	stmt.table = parts[0]
   444  	for n, colspec := range strings.Split(parts[1], ",") {
   445  		nameType := strings.Split(colspec, "=")
   446  		if len(nameType) != 2 {
   447  			stmt.Close()
   448  			return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
   449  		}
   450  		stmt.colName = append(stmt.colName, nameType[0])
   451  		stmt.colType = append(stmt.colType, nameType[1])
   452  	}
   453  	return stmt, nil
   454  }
   455  
   456  // parts are table|col=?,col2=val
   457  func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
   458  	if len(parts) != 2 {
   459  		stmt.Close()
   460  		return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
   461  	}
   462  	stmt.table = parts[0]
   463  	for n, colspec := range strings.Split(parts[1], ",") {
   464  		nameVal := strings.Split(colspec, "=")
   465  		if len(nameVal) != 2 {
   466  			stmt.Close()
   467  			return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
   468  		}
   469  		column, value := nameVal[0], nameVal[1]
   470  		ctype, ok := c.db.columnType(stmt.table, column)
   471  		if !ok {
   472  			stmt.Close()
   473  			return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column)
   474  		}
   475  		stmt.colName = append(stmt.colName, column)
   476  
   477  		if !strings.HasPrefix(value, "?") {
   478  			var subsetVal interface{}
   479  			// Convert to driver subset type
   480  			switch ctype {
   481  			case "string":
   482  				subsetVal = []byte(value)
   483  			case "blob":
   484  				subsetVal = []byte(value)
   485  			case "int32":
   486  				i, err := strconv.Atoi(value)
   487  				if err != nil {
   488  					stmt.Close()
   489  					return nil, errf("invalid conversion to int32 from %q", value)
   490  				}
   491  				subsetVal = int64(i) // int64 is a subset type, but not int32
   492  			default:
   493  				stmt.Close()
   494  				return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype)
   495  			}
   496  			stmt.colValue = append(stmt.colValue, subsetVal)
   497  		} else {
   498  			stmt.placeholders++
   499  			stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype))
   500  			stmt.colValue = append(stmt.colValue, value)
   501  		}
   502  	}
   503  	return stmt, nil
   504  }
   505  
   506  // hook to simulate broken connections
   507  var hookPrepareBadConn func() bool
   508  
   509  func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
   510  	c.numPrepare++
   511  	if c.db == nil {
   512  		panic("nil c.db; conn = " + fmt.Sprintf("%#v", c))
   513  	}
   514  
   515  	if c.stickyBad || (hookPrepareBadConn != nil && hookPrepareBadConn()) {
   516  		return nil, driver.ErrBadConn
   517  	}
   518  
   519  	var firstStmt, prev *fakeStmt
   520  	for _, query := range strings.Split(query, ";") {
   521  		parts := strings.Split(query, "|")
   522  		if len(parts) < 1 {
   523  			return nil, errf("empty query")
   524  		}
   525  		stmt := &fakeStmt{q: query, c: c}
   526  		if firstStmt == nil {
   527  			firstStmt = stmt
   528  		}
   529  		if len(parts) >= 3 && parts[0] == "PANIC" {
   530  			stmt.panic = parts[1]
   531  			parts = parts[2:]
   532  		}
   533  		cmd := parts[0]
   534  		stmt.cmd = cmd
   535  		parts = parts[1:]
   536  
   537  		c.incrStat(&c.stmtsMade)
   538  		var err error
   539  		switch cmd {
   540  		case "WIPE":
   541  			// Nothing
   542  		case "SELECT":
   543  			stmt, err = c.prepareSelect(stmt, parts)
   544  		case "CREATE":
   545  			stmt, err = c.prepareCreate(stmt, parts)
   546  		case "INSERT":
   547  			stmt, err = c.prepareInsert(stmt, parts)
   548  		case "NOSERT":
   549  			// Do all the prep-work like for an INSERT but don't actually insert the row.
   550  			// Used for some of the concurrent tests.
   551  			stmt, err = c.prepareInsert(stmt, parts)
   552  		default:
   553  			stmt.Close()
   554  			return nil, errf("unsupported command type %q", cmd)
   555  		}
   556  		if err != nil {
   557  			return nil, err
   558  		}
   559  		if prev != nil {
   560  			prev.next = stmt
   561  		}
   562  		prev = stmt
   563  	}
   564  	return firstStmt, nil
   565  }
   566  
   567  func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
   568  	if s.panic == "ColumnConverter" {
   569  		panic(s.panic)
   570  	}
   571  	if len(s.placeholderConverter) == 0 {
   572  		return driver.DefaultParameterConverter
   573  	}
   574  	return s.placeholderConverter[idx]
   575  }
   576  
   577  func (s *fakeStmt) Close() error {
   578  	if s.panic == "Close" {
   579  		panic(s.panic)
   580  	}
   581  	if s.c == nil {
   582  		panic("nil conn in fakeStmt.Close")
   583  	}
   584  	if s.c.db == nil {
   585  		panic("in fakeStmt.Close, conn's db is nil (already closed)")
   586  	}
   587  	if !s.closed {
   588  		s.c.incrStat(&s.c.stmtsClosed)
   589  		s.closed = true
   590  	}
   591  	if s.next != nil {
   592  		s.next.Close()
   593  	}
   594  	return nil
   595  }
   596  
   597  var errClosed = errors.New("fakedb: statement has been closed")
   598  
   599  // hook to simulate broken connections
   600  var hookExecBadConn func() bool
   601  
   602  func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) {
   603  	panic("Using ExecContext")
   604  }
   605  func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
   606  	if s.panic == "Exec" {
   607  		panic(s.panic)
   608  	}
   609  	if s.closed {
   610  		return nil, errClosed
   611  	}
   612  
   613  	if s.c.stickyBad || (hookExecBadConn != nil && hookExecBadConn()) {
   614  		return nil, driver.ErrBadConn
   615  	}
   616  
   617  	err := checkSubsetTypes(args)
   618  	if err != nil {
   619  		return nil, err
   620  	}
   621  
   622  	db := s.c.db
   623  	switch s.cmd {
   624  	case "WIPE":
   625  		db.wipe()
   626  		return driver.ResultNoRows, nil
   627  	case "CREATE":
   628  		if err := db.createTable(s.table, s.colName, s.colType); err != nil {
   629  			return nil, err
   630  		}
   631  		return driver.ResultNoRows, nil
   632  	case "INSERT":
   633  		return s.execInsert(args, true)
   634  	case "NOSERT":
   635  		// Do all the prep-work like for an INSERT but don't actually insert the row.
   636  		// Used for some of the concurrent tests.
   637  		return s.execInsert(args, false)
   638  	}
   639  	fmt.Printf("EXEC statement, cmd=%q: %#v\n", s.cmd, s)
   640  	return nil, fmt.Errorf("unimplemented statement Exec command type of %q", s.cmd)
   641  }
   642  
// execInsert executes a prepared INSERT (doInsert true) or NOSERT
// (doInsert false) with the given arguments.
// When doInsert is true, add the row to the table.
// When doInsert is false do prep-work and error checking, but don't
// actually add the row to the table.
//
// Each entry of s.colValue is either a pre-bound subset value or a
// placeholder string starting with "?". A bare "?" takes args[argPos],
// where argPos counts every placeholder seen so far, named ones included.
// A named "?xxx" takes the value of the first argument whose Name equals the
// full placeholder text (leading "?" included), or nil if none matches.
// NOTE(review): confirm callers pass argument names in that form.
// Table columns not listed in the statement are stored as nil.
// On success it always reports one row affected.
func (s *fakeStmt) execInsert(args []driver.NamedValue, doInsert bool) (driver.Result, error) {
	db := s.c.db
	if len(args) != s.placeholders {
		panic("error in pkg db; should only get here if size is correct")
	}
	db.mu.Lock()
	t, ok := db.table(s.table)
	db.mu.Unlock()
	if !ok {
		return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
	}

	// Hold the table lock for the column lookups and the append below.
	t.mu.Lock()
	defer t.mu.Unlock()

	var cols []interface{}
	if doInsert {
		cols = make([]interface{}, len(t.colname))
	}
	argPos := 0
	for n, colname := range s.colName {
		// Re-resolve by name: the table may have been recreated since Prepare.
		colidx := t.columnIndex(colname)
		if colidx == -1 {
			return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname)
		}
		var val interface{}
		if strvalue, ok := s.colValue[n].(string); ok && strings.HasPrefix(strvalue, "?") {
			if strvalue == "?" {
				val = args[argPos].Value
			} else {
				// Assign value from argument placeholder name.
				for _, a := range args {
					if a.Name == strvalue {
						val = a.Value
						break
					}
				}
			}
			argPos++
		} else {
			val = s.colValue[n]
		}
		if doInsert {
			cols[colidx] = val
		}
	}

	if doInsert {
		t.rows = append(t.rows, &row{cols: cols})
	}
	return driver.RowsAffected(1), nil
}
   698  
   699  // hook to simulate broken connections
   700  var hookQueryBadConn func() bool
   701  
   702  func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
   703  	panic("Use QueryContext")
   704  }
   705  
   706  func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
   707  	if s.panic == "Query" {
   708  		panic(s.panic)
   709  	}
   710  	if s.closed {
   711  		return nil, errClosed
   712  	}
   713  
   714  	if s.c.stickyBad || (hookQueryBadConn != nil && hookQueryBadConn()) {
   715  		return nil, driver.ErrBadConn
   716  	}
   717  
   718  	err := checkSubsetTypes(args)
   719  	if err != nil {
   720  		return nil, err
   721  	}
   722  
   723  	db := s.c.db
   724  	if len(args) != s.placeholders {
   725  		panic("error in pkg db; should only get here if size is correct")
   726  	}
   727  
   728  	setMRows := make([][]*row, 0, 1)
   729  	setColumns := make([][]string, 0, 1)
   730  	setColType := make([][]string, 0, 1)
   731  
   732  	for {
   733  		db.mu.Lock()
   734  		t, ok := db.table(s.table)
   735  		db.mu.Unlock()
   736  		if !ok {
   737  			return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
   738  		}
   739  
   740  		if s.table == "magicquery" {
   741  			if len(s.whereCol) == 2 && s.whereCol[0].Column == "op" && s.whereCol[1].Column == "millis" {
   742  				if args[0].Value == "sleep" {
   743  					time.Sleep(time.Duration(args[1].Value.(int64)) * time.Millisecond)
   744  				}
   745  			}
   746  		}
   747  
   748  		t.mu.Lock()
   749  
   750  		colIdx := make(map[string]int) // select column name -> column index in table
   751  		for _, name := range s.colName {
   752  			idx := t.columnIndex(name)
   753  			if idx == -1 {
   754  				t.mu.Unlock()
   755  				return nil, fmt.Errorf("fakedb: unknown column name %q", name)
   756  			}
   757  			colIdx[name] = idx
   758  		}
   759  
   760  		mrows := []*row{}
   761  	rows:
   762  		for _, trow := range t.rows {
   763  			// Process the where clause, skipping non-match rows. This is lazy
   764  			// and just uses fmt.Sprintf("%v") to test equality. Good enough
   765  			// for test code.
   766  			for _, wcol := range s.whereCol {
   767  				idx := t.columnIndex(wcol.Column)
   768  				if idx == -1 {
   769  					t.mu.Unlock()
   770  					return nil, fmt.Errorf("db: invalid where clause column %q", wcol)
   771  				}
   772  				tcol := trow.cols[idx]
   773  				if bs, ok := tcol.([]byte); ok {
   774  					// lazy hack to avoid sprintf %v on a []byte
   775  					tcol = string(bs)
   776  				}
   777  				var argValue interface{}
   778  				if wcol.Placeholder == "?" {
   779  					argValue = args[wcol.Ordinal-1].Value
   780  				} else {
   781  					// Assign arg value from placeholder name.
   782  					for _, a := range args {
   783  						if a.Name == wcol.Placeholder {
   784  							argValue = a.Value
   785  							break
   786  						}
   787  					}
   788  				}
   789  				if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) {
   790  					continue rows
   791  				}
   792  			}
   793  			mrow := &row{cols: make([]interface{}, len(s.colName))}
   794  			for seli, name := range s.colName {
   795  				mrow.cols[seli] = trow.cols[colIdx[name]]
   796  			}
   797  			mrows = append(mrows, mrow)
   798  		}
   799  
   800  		var colType []string
   801  		for _, column := range s.colName {
   802  			colType = append(colType, t.coltype[t.columnIndex(column)])
   803  		}
   804  
   805  		t.mu.Unlock()
   806  
   807  		setMRows = append(setMRows, mrows)
   808  		setColumns = append(setColumns, s.colName)
   809  		setColType = append(setColType, colType)
   810  
   811  		if s.next == nil {
   812  			break
   813  		}
   814  		s = s.next
   815  	}
   816  
   817  	cursor := &rowsCursor{
   818  		posRow:  -1,
   819  		rows:    setMRows,
   820  		cols:    setColumns,
   821  		colType: setColType,
   822  		errPos:  -1,
   823  	}
   824  	return cursor, nil
   825  }
   826  
   827  func (s *fakeStmt) NumInput() int {
   828  	if s.panic == "NumInput" {
   829  		panic(s.panic)
   830  	}
   831  	return s.placeholders
   832  }
   833  
   834  // hook to simulate broken connections
   835  var hookCommitBadConn func() bool
   836  
   837  func (tx *fakeTx) Commit() error {
   838  	tx.c.currTx = nil
   839  	if hookCommitBadConn != nil && hookCommitBadConn() {
   840  		return driver.ErrBadConn
   841  	}
   842  	return nil
   843  }
   844  
   845  // hook to simulate broken connections
   846  var hookRollbackBadConn func() bool
   847  
   848  func (tx *fakeTx) Rollback() error {
   849  	tx.c.currTx = nil
   850  	if hookRollbackBadConn != nil && hookRollbackBadConn() {
   851  		return driver.ErrBadConn
   852  	}
   853  	return nil
   854  }
   855  
   856  type rowsCursor struct {
   857  	cols    [][]string
   858  	colType [][]string
   859  	posSet  int
   860  	posRow  int
   861  	rows    [][]*row
   862  	closed  bool
   863  
   864  	// errPos and err are for making Next return early with error.
   865  	errPos int
   866  	err    error
   867  
   868  	// a clone of slices to give out to clients, indexed by the
   869  	// the original slice's first byte address.  we clone them
   870  	// just so we're able to corrupt them on close.
   871  	bytesClone map[*byte][]byte
   872  }
   873  
   874  func (rc *rowsCursor) Close() error {
   875  	if !rc.closed {
   876  		for _, bs := range rc.bytesClone {
   877  			bs[0] = 255 // first byte corrupted
   878  		}
   879  	}
   880  	rc.closed = true
   881  	return nil
   882  }
   883  
   884  func (rc *rowsCursor) Columns() []string {
   885  	return rc.cols[rc.posSet]
   886  }
   887  
   888  func (rc *rowsCursor) ColumnTypeScanType(index int) reflect.Type {
   889  	return colTypeToReflectType(rc.colType[rc.posSet][index])
   890  }
   891  
   892  var rowsCursorNextHook func(dest []driver.Value) error
   893  
   894  func (rc *rowsCursor) Next(dest []driver.Value) error {
   895  	if rowsCursorNextHook != nil {
   896  		return rowsCursorNextHook(dest)
   897  	}
   898  
   899  	if rc.closed {
   900  		return errors.New("fakedb: cursor is closed")
   901  	}
   902  	rc.posRow++
   903  	if rc.posRow == rc.errPos {
   904  		return rc.err
   905  	}
   906  	if rc.posRow >= len(rc.rows[rc.posSet]) {
   907  		return io.EOF // per interface spec
   908  	}
   909  	for i, v := range rc.rows[rc.posSet][rc.posRow].cols {
   910  		// TODO(bradfitz): convert to subset types? naah, I
   911  		// think the subset types should only be input to
   912  		// driver, but the sql package should be able to handle
   913  		// a wider range of types coming out of drivers. all
   914  		// for ease of drivers, and to prevent drivers from
   915  		// messing up conversions or doing them differently.
   916  		dest[i] = v
   917  
   918  		if bs, ok := v.([]byte); ok {
   919  			if rc.bytesClone == nil {
   920  				rc.bytesClone = make(map[*byte][]byte)
   921  			}
   922  			clone, ok := rc.bytesClone[&bs[0]]
   923  			if !ok {
   924  				clone = make([]byte, len(bs))
   925  				copy(clone, bs)
   926  				rc.bytesClone[&bs[0]] = clone
   927  			}
   928  			dest[i] = clone
   929  		}
   930  	}
   931  	return nil
   932  }
   933  
   934  func (rc *rowsCursor) HasNextResultSet() bool {
   935  	return rc.posSet < len(rc.rows)-1
   936  }
   937  
   938  func (rc *rowsCursor) NextResultSet() error {
   939  	if rc.HasNextResultSet() {
   940  		rc.posSet++
   941  		rc.posRow = -1
   942  		return nil
   943  	}
   944  	return io.EOF // Per interface spec.
   945  }
   946  
   947  // fakeDriverString is like driver.String, but indirects pointers like
   948  // DefaultValueConverter.
   949  //
   950  // This could be surprising behavior to retroactively apply to
   951  // driver.String now that Go1 is out, but this is convenient for
   952  // our TestPointerParamsAndScans.
   953  //
   954  type fakeDriverString struct{}
   955  
   956  func (fakeDriverString) ConvertValue(v interface{}) (driver.Value, error) {
   957  	switch c := v.(type) {
   958  	case string, []byte:
   959  		return v, nil
   960  	case *string:
   961  		if c == nil {
   962  			return nil, nil
   963  		}
   964  		return *c, nil
   965  	}
   966  	return fmt.Sprintf("%v", v), nil
   967  }
   968  
   969  func converterForType(typ string) driver.ValueConverter {
   970  	switch typ {
   971  	case "bool":
   972  		return driver.Bool
   973  	case "nullbool":
   974  		return driver.Null{Converter: driver.Bool}
   975  	case "int32":
   976  		return driver.Int32
   977  	case "string":
   978  		return driver.NotNull{Converter: fakeDriverString{}}
   979  	case "nullstring":
   980  		return driver.Null{Converter: fakeDriverString{}}
   981  	case "int64":
   982  		// TODO(coopernurse): add type-specific converter
   983  		return driver.NotNull{Converter: driver.DefaultParameterConverter}
   984  	case "nullint64":
   985  		// TODO(coopernurse): add type-specific converter
   986  		return driver.Null{Converter: driver.DefaultParameterConverter}
   987  	case "float64":
   988  		// TODO(coopernurse): add type-specific converter
   989  		return driver.NotNull{Converter: driver.DefaultParameterConverter}
   990  	case "nullfloat64":
   991  		// TODO(coopernurse): add type-specific converter
   992  		return driver.Null{Converter: driver.DefaultParameterConverter}
   993  	case "datetime":
   994  		return driver.DefaultParameterConverter
   995  	}
   996  	panic("invalid fakedb column type of " + typ)
   997  }
   998  
   999  func colTypeToReflectType(typ string) reflect.Type {
  1000  	switch typ {
  1001  	case "bool":
  1002  		return reflect.TypeOf(false)
  1003  	case "nullbool":
  1004  		return reflect.TypeOf(NullBool{})
  1005  	case "int32":
  1006  		return reflect.TypeOf(int32(0))
  1007  	case "string":
  1008  		return reflect.TypeOf("")
  1009  	case "nullstring":
  1010  		return reflect.TypeOf(NullString{})
  1011  	case "int64":
  1012  		return reflect.TypeOf(int64(0))
  1013  	case "nullint64":
  1014  		return reflect.TypeOf(NullInt64{})
  1015  	case "float64":
  1016  		return reflect.TypeOf(float64(0))
  1017  	case "nullfloat64":
  1018  		return reflect.TypeOf(NullFloat64{})
  1019  	case "datetime":
  1020  		return reflect.TypeOf(time.Time{})
  1021  	}
  1022  	panic("invalid fakedb column type of " + typ)
  1023  }