github.com/mattn/go@v0.0.0-20171011075504-07f7db3ea99f/src/database/sql/fakedb_test.go (about)

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package sql
     6  
     7  import (
     8  	"context"
     9  	"database/sql/driver"
    10  	"errors"
    11  	"fmt"
    12  	"io"
    13  	"log"
    14  	"reflect"
    15  	"sort"
    16  	"strconv"
    17  	"strings"
    18  	"sync"
    19  	"testing"
    20  	"time"
    21  )
    22  
// Reference log.Printf so the "log" import stays in use even when no
// debugging print statement in this file currently calls it.
var _ = log.Printf
    24  
// fakeDriver is a fake database that implements Go's driver.Driver
// interface, just for testing.
//
// It speaks a query language that's semantically similar to but
// syntactically different and simpler than SQL.  The syntax is as
// follows:
//
//   WIPE
//   CREATE|<tablename>|<col>=<type>,<col>=<type>,...
//     where types are: "string", [u]int{8,16,32,64}, "bool"
//   INSERT|<tablename>|col=val,col2=val2,col3=?
//   NOSERT|<tablename>|col=val,col2=val2,col3=?
//     (prepared and checked like INSERT, but no row is added)
//   SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=?
//   SELECT|<tablename>|projectcol1,projectcol2|filtercol=?param1,filtercol2=?param2
//
// Any of these can be preceded by PANIC|<method>|, to cause the
// named method on fakeStmt to panic.
//
// Any of these can be preceded by WAIT|<duration>|, to make preparing
// the statement wait, and executing it sleep, for the specified duration.
//
// Multiple of these can be combined when separated with a semicolon.
//
// When opening a fakeDriver's database, it starts empty with no
// tables. All tables and data are stored in memory only.
type fakeDriver struct {
	mu         sync.Mutex // guards 3 following fields
	openCount  int        // conn opens
	closeCount int        // conn closes
	// When waitCh is non-nil, Open sends on waitingCh and then blocks
	// receiving from waitCh before returning; both are reset to nil
	// afterwards.
	waitCh    chan struct{}
	waitingCh chan struct{}
	// dbs maps database names to their state; getDB populates it lazily.
	dbs map[string]*fakeDB
}
    57  
// fakeDB is one named in-memory database, created on demand by
// fakeDriver.getDB and shared by every fakeConn opened with that name.
type fakeDB struct {
	name string

	// mu is held by wipe, createTable and columnType while they access
	// tables; the table method requires callers to hold it.
	mu sync.Mutex
	// tables is nil until the first createTable and is reset to nil by wipe.
	tables map[string]*table
	// badConn is flipped by fakeConn.isBad to alternate its result.
	badConn bool
	// allowAny disables the argument type check in checkSubsetTypes.
	allowAny bool
}
    66  
// table is a single in-memory table. colname and coltype are parallel
// slices (createTable rejects differing lengths); rows holds the inserted
// data.
type table struct {
	mu      sync.Mutex // held by execInsert and QueryContext while they use the table
	colname []string
	coltype []string
	rows    []*row
}
    73  
    74  func (t *table) columnIndex(name string) int {
    75  	for n, nname := range t.colname {
    76  		if name == nname {
    77  			return n
    78  		}
    79  	}
    80  	return -1
    81  }
    82  
// row is one record of a table; cols[i] is the value of column i.
type row struct {
	cols []interface{} // must be same size as its table colname + coltype
}
    86  
// memToucher is implemented by values (fakeConn and rowsCursor in this
// file) that can perform a deliberate memory write for the race detector.
type memToucher interface {
	// touchMem reads & writes some memory, to help find data races.
	touchMem()
}
    91  
// fakeConn is the driver connection type returned by fakeDriver.Open.
type fakeConn struct {
	db *fakeDB // where to return ourselves to; set to nil by a successful Close

	// currTx is the transaction started by Begin, cleared again by
	// fakeTx.Commit and fakeTx.Rollback.
	currTx *fakeTx

	// Every operation writes to line to enable the race detector
	// check for data races.
	line int64

	// Stats for tests:
	mu          sync.Mutex // held by incrStat while it bumps a counter
	stmtsMade   int        // statements created by PrepareContext
	stmtsClosed int        // statements closed via fakeStmt.Close
	numPrepare  int        // calls to PrepareContext

	// bad connection tests; see isBad()
	bad       bool // set by Open for the "badConn" dsn option
	stickyBad bool // when set, isBad is always true and Prepare/Exec/Query return driver.ErrBadConn
}
   111  
   112  func (c *fakeConn) touchMem() {
   113  	c.line++
   114  }
   115  
   116  func (c *fakeConn) incrStat(v *int) {
   117  	c.mu.Lock()
   118  	*v++
   119  	c.mu.Unlock()
   120  }
   121  
// fakeTx is the transaction value handed out by fakeConn.Begin.
type fakeTx struct {
	c *fakeConn // connection the transaction was begun on
}
   125  
// boundCol describes one "column=?..." filter of a SELECT statement.
type boundCol struct {
	Column      string // name of the filtered column
	Placeholder string // "?" for positional, or "?name" for a named argument
	Ordinal     int    // 1-based position of the placeholder within the statement
}
   131  
// fakeStmt is a prepared fakedb statement. PrepareContext builds one per
// semicolon-separated command and links them through next.
type fakeStmt struct {
	memToucher
	c *fakeConn
	q string // just for debugging

	cmd   string        // command word, e.g. "SELECT" or "INSERT"
	table string        // table the command operates on
	panic string        // name of the method that should panic (PANIC|<method>| prefix)
	wait  time.Duration // delay requested with the WAIT|<duration>| prefix

	next *fakeStmt // used for returning multiple results.

	closed bool // set by Close; Exec and Query then fail with errClosed

	colName      []string      // used by CREATE, INSERT, SELECT (selected columns)
	colType      []string      // used by CREATE
	colValue     []interface{} // used by INSERT (mix of strings and "?" for bound params)
	placeholders int           // used by INSERT/SELECT: number of ? params

	whereCol []boundCol // used by SELECT (all placeholders)

	placeholderConverter []driver.ValueConverter // used by INSERT
}
   155  
// fdriver is the single fakeDriver instance shared by the tests.
var fdriver driver.Driver = &fakeDriver{}

// init registers fdriver under the driver name "test".
func init() {
	Register("test", fdriver)
}
   161  
   162  func contains(list []string, y string) bool {
   163  	for _, x := range list {
   164  		if x == y {
   165  			return true
   166  		}
   167  	}
   168  	return false
   169  }
   170  
// Dummy wraps a driver.Driver by embedding; TestDrivers registers a zero
// Dummy (nil embedded driver) under the name "invalid".
type Dummy struct {
	driver.Driver
}
   174  
   175  func TestDrivers(t *testing.T) {
   176  	unregisterAllDrivers()
   177  	Register("test", fdriver)
   178  	Register("invalid", Dummy{})
   179  	all := Drivers()
   180  	if len(all) < 2 || !sort.StringsAreSorted(all) || !contains(all, "test") || !contains(all, "invalid") {
   181  		t.Fatalf("Drivers = %v, want sorted list with at least [invalid, test]", all)
   182  	}
   183  }
   184  
   185  // hook to simulate connection failures
   186  var hookOpenErr struct {
   187  	sync.Mutex
   188  	fn func() error
   189  }
   190  
   191  func setHookOpenErr(fn func() error) {
   192  	hookOpenErr.Lock()
   193  	defer hookOpenErr.Unlock()
   194  	hookOpenErr.fn = fn
   195  }
   196  
   197  // Supports dsn forms:
   198  //    <dbname>
   199  //    <dbname>;<opts>  (only currently supported option is `badConn`,
   200  //                      which causes driver.ErrBadConn to be returned on
   201  //                      every other conn.Begin())
   202  func (d *fakeDriver) Open(dsn string) (driver.Conn, error) {
   203  	hookOpenErr.Lock()
   204  	fn := hookOpenErr.fn
   205  	hookOpenErr.Unlock()
   206  	if fn != nil {
   207  		if err := fn(); err != nil {
   208  			return nil, err
   209  		}
   210  	}
   211  	parts := strings.Split(dsn, ";")
   212  	if len(parts) < 1 {
   213  		return nil, errors.New("fakedb: no database name")
   214  	}
   215  	name := parts[0]
   216  
   217  	db := d.getDB(name)
   218  
   219  	d.mu.Lock()
   220  	d.openCount++
   221  	d.mu.Unlock()
   222  	conn := &fakeConn{db: db}
   223  
   224  	if len(parts) >= 2 && parts[1] == "badConn" {
   225  		conn.bad = true
   226  	}
   227  	if d.waitCh != nil {
   228  		d.waitingCh <- struct{}{}
   229  		<-d.waitCh
   230  		d.waitCh = nil
   231  		d.waitingCh = nil
   232  	}
   233  	return conn, nil
   234  }
   235  
   236  func (d *fakeDriver) getDB(name string) *fakeDB {
   237  	d.mu.Lock()
   238  	defer d.mu.Unlock()
   239  	if d.dbs == nil {
   240  		d.dbs = make(map[string]*fakeDB)
   241  	}
   242  	db, ok := d.dbs[name]
   243  	if !ok {
   244  		db = &fakeDB{name: name}
   245  		d.dbs[name] = db
   246  	}
   247  	return db
   248  }
   249  
   250  func (db *fakeDB) wipe() {
   251  	db.mu.Lock()
   252  	defer db.mu.Unlock()
   253  	db.tables = nil
   254  }
   255  
   256  func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error {
   257  	db.mu.Lock()
   258  	defer db.mu.Unlock()
   259  	if db.tables == nil {
   260  		db.tables = make(map[string]*table)
   261  	}
   262  	if _, exist := db.tables[name]; exist {
   263  		return fmt.Errorf("table %q already exists", name)
   264  	}
   265  	if len(columnNames) != len(columnTypes) {
   266  		return fmt.Errorf("create table of %q len(names) != len(types): %d vs %d",
   267  			name, len(columnNames), len(columnTypes))
   268  	}
   269  	db.tables[name] = &table{colname: columnNames, coltype: columnTypes}
   270  	return nil
   271  }
   272  
   273  // must be called with db.mu lock held
   274  func (db *fakeDB) table(table string) (*table, bool) {
   275  	if db.tables == nil {
   276  		return nil, false
   277  	}
   278  	t, ok := db.tables[table]
   279  	return t, ok
   280  }
   281  
   282  func (db *fakeDB) columnType(table, column string) (typ string, ok bool) {
   283  	db.mu.Lock()
   284  	defer db.mu.Unlock()
   285  	t, ok := db.table(table)
   286  	if !ok {
   287  		return
   288  	}
   289  	for n, cname := range t.colname {
   290  		if cname == column {
   291  			return t.coltype[n], true
   292  		}
   293  	}
   294  	return "", false
   295  }
   296  
   297  func (c *fakeConn) isBad() bool {
   298  	if c.stickyBad {
   299  		return true
   300  	} else if c.bad {
   301  		// alternate between bad conn and not bad conn
   302  		c.db.badConn = !c.db.badConn
   303  		return c.db.badConn
   304  	} else {
   305  		return false
   306  	}
   307  }
   308  
   309  func (c *fakeConn) Begin() (driver.Tx, error) {
   310  	if c.isBad() {
   311  		return nil, driver.ErrBadConn
   312  	}
   313  	if c.currTx != nil {
   314  		return nil, errors.New("already in a transaction")
   315  	}
   316  	c.touchMem()
   317  	c.currTx = &fakeTx{c: c}
   318  	return c.currTx, nil
   319  }
   320  
   321  var hookPostCloseConn struct {
   322  	sync.Mutex
   323  	fn func(*fakeConn, error)
   324  }
   325  
   326  func setHookpostCloseConn(fn func(*fakeConn, error)) {
   327  	hookPostCloseConn.Lock()
   328  	defer hookPostCloseConn.Unlock()
   329  	hookPostCloseConn.fn = fn
   330  }
   331  
// testStrictClose, when non-nil, is the test on which fakeConn.Close
// reports (via Errorf) any failure to close.
var testStrictClose *testing.T

// setStrictFakeConnClose sets the *testing.T whose Errorf is called
// whenever fakeConn.Close fails to close. If t is nil, the check is
// disabled.
func setStrictFakeConnClose(t *testing.T) {
	testStrictClose = t
}
   339  
   340  func (c *fakeConn) Close() (err error) {
   341  	drv := fdriver.(*fakeDriver)
   342  	defer func() {
   343  		if err != nil && testStrictClose != nil {
   344  			testStrictClose.Errorf("failed to close a test fakeConn: %v", err)
   345  		}
   346  		hookPostCloseConn.Lock()
   347  		fn := hookPostCloseConn.fn
   348  		hookPostCloseConn.Unlock()
   349  		if fn != nil {
   350  			fn(c, err)
   351  		}
   352  		if err == nil {
   353  			drv.mu.Lock()
   354  			drv.closeCount++
   355  			drv.mu.Unlock()
   356  		}
   357  	}()
   358  	c.touchMem()
   359  	if c.currTx != nil {
   360  		return errors.New("can't close fakeConn; in a Transaction")
   361  	}
   362  	if c.db == nil {
   363  		return errors.New("can't close fakeConn; already closed")
   364  	}
   365  	if c.stmtsMade > c.stmtsClosed {
   366  		return errors.New("can't close; dangling statement(s)")
   367  	}
   368  	c.db = nil
   369  	return nil
   370  }
   371  
   372  func checkSubsetTypes(allowAny bool, args []driver.NamedValue) error {
   373  	for _, arg := range args {
   374  		switch arg.Value.(type) {
   375  		case int64, float64, bool, nil, []byte, string, time.Time:
   376  		default:
   377  			if !allowAny {
   378  				return fmt.Errorf("fakedb_test: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value)
   379  			}
   380  		}
   381  	}
   382  	return nil
   383  }
   384  
// Exec always panics: fakeConn also provides ExecContext, and this stub
// makes a call through the non-context method fail loudly.
func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) {
	// Ensure that ExecContext is called if available.
	panic("ExecContext was not called.")
}
   389  
   390  func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
   391  	// This is an optional interface, but it's implemented here
   392  	// just to check that all the args are of the proper types.
   393  	// ErrSkip is returned so the caller acts as if we didn't
   394  	// implement this at all.
   395  	err := checkSubsetTypes(c.db.allowAny, args)
   396  	if err != nil {
   397  		return nil, err
   398  	}
   399  	return nil, driver.ErrSkip
   400  }
   401  
// Query always panics: fakeConn also provides QueryContext, and this stub
// makes a call through the non-context method fail loudly.
func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) {
	// Ensure that QueryContext is called if available.
	panic("QueryContext was not called.")
}
   406  
   407  func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
   408  	// This is an optional interface, but it's implemented here
   409  	// just to check that all the args are of the proper types.
   410  	// ErrSkip is returned so the caller acts as if we didn't
   411  	// implement this at all.
   412  	err := checkSubsetTypes(c.db.allowAny, args)
   413  	if err != nil {
   414  		return nil, err
   415  	}
   416  	return nil, driver.ErrSkip
   417  }
   418  
   419  func errf(msg string, args ...interface{}) error {
   420  	return errors.New("fakedb: " + fmt.Sprintf(msg, args...))
   421  }
   422  
   423  // parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=?
   424  // (note that where columns must always contain ? marks,
   425  //  just a limitation for fakedb)
   426  func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
   427  	if len(parts) != 3 {
   428  		stmt.Close()
   429  		return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
   430  	}
   431  	stmt.table = parts[0]
   432  
   433  	stmt.colName = strings.Split(parts[1], ",")
   434  	for n, colspec := range strings.Split(parts[2], ",") {
   435  		if colspec == "" {
   436  			continue
   437  		}
   438  		nameVal := strings.Split(colspec, "=")
   439  		if len(nameVal) != 2 {
   440  			stmt.Close()
   441  			return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
   442  		}
   443  		column, value := nameVal[0], nameVal[1]
   444  		_, ok := c.db.columnType(stmt.table, column)
   445  		if !ok {
   446  			stmt.Close()
   447  			return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column)
   448  		}
   449  		if !strings.HasPrefix(value, "?") {
   450  			stmt.Close()
   451  			return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark",
   452  				stmt.table, column)
   453  		}
   454  		stmt.placeholders++
   455  		stmt.whereCol = append(stmt.whereCol, boundCol{Column: column, Placeholder: value, Ordinal: stmt.placeholders})
   456  	}
   457  	return stmt, nil
   458  }
   459  
   460  // parts are table|col=type,col2=type2
   461  func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
   462  	if len(parts) != 2 {
   463  		stmt.Close()
   464  		return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts))
   465  	}
   466  	stmt.table = parts[0]
   467  	for n, colspec := range strings.Split(parts[1], ",") {
   468  		nameType := strings.Split(colspec, "=")
   469  		if len(nameType) != 2 {
   470  			stmt.Close()
   471  			return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
   472  		}
   473  		stmt.colName = append(stmt.colName, nameType[0])
   474  		stmt.colType = append(stmt.colType, nameType[1])
   475  	}
   476  	return stmt, nil
   477  }
   478  
   479  // parts are table|col=?,col2=val
   480  func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
   481  	if len(parts) != 2 {
   482  		stmt.Close()
   483  		return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
   484  	}
   485  	stmt.table = parts[0]
   486  	for n, colspec := range strings.Split(parts[1], ",") {
   487  		nameVal := strings.Split(colspec, "=")
   488  		if len(nameVal) != 2 {
   489  			stmt.Close()
   490  			return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
   491  		}
   492  		column, value := nameVal[0], nameVal[1]
   493  		ctype, ok := c.db.columnType(stmt.table, column)
   494  		if !ok {
   495  			stmt.Close()
   496  			return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column)
   497  		}
   498  		stmt.colName = append(stmt.colName, column)
   499  
   500  		if !strings.HasPrefix(value, "?") {
   501  			var subsetVal interface{}
   502  			// Convert to driver subset type
   503  			switch ctype {
   504  			case "string":
   505  				subsetVal = []byte(value)
   506  			case "blob":
   507  				subsetVal = []byte(value)
   508  			case "int32":
   509  				i, err := strconv.Atoi(value)
   510  				if err != nil {
   511  					stmt.Close()
   512  					return nil, errf("invalid conversion to int32 from %q", value)
   513  				}
   514  				subsetVal = int64(i) // int64 is a subset type, but not int32
   515  			default:
   516  				stmt.Close()
   517  				return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype)
   518  			}
   519  			stmt.colValue = append(stmt.colValue, subsetVal)
   520  		} else {
   521  			stmt.placeholders++
   522  			stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype))
   523  			stmt.colValue = append(stmt.colValue, value)
   524  		}
   525  	}
   526  	return stmt, nil
   527  }
   528  
// hookPrepareBadConn is a hook to simulate broken connections: when it is
// non-nil and returns true, PrepareContext fails with driver.ErrBadConn.
var hookPrepareBadConn func() bool
   531  
// Prepare always panics; statements must be prepared through
// PrepareContext.
func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
	panic("use PrepareContext")
}
   535  
// PrepareContext parses query, which may hold several fakedb commands
// separated by semicolons, and returns the first one as a driver.Stmt
// with the remaining statements linked through fakeStmt.next.
//
// It returns driver.ErrBadConn when the connection is sticky-bad or when
// hookPrepareBadConn reports true, and it panics if the connection has
// already been closed (c.db == nil). A WAIT prefix delays preparation of
// that statement; if ctx is done first, ctx.Err() is returned.
func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
	c.numPrepare++
	if c.db == nil {
		panic("nil c.db; conn = " + fmt.Sprintf("%#v", c))
	}

	if c.stickyBad || (hookPrepareBadConn != nil && hookPrepareBadConn()) {
		return nil, driver.ErrBadConn
	}

	c.touchMem()
	var firstStmt, prev *fakeStmt
	for _, query := range strings.Split(query, ";") {
		parts := strings.Split(query, "|")
		if len(parts) < 1 {
			return nil, errf("empty query")
		}
		stmt := &fakeStmt{q: query, c: c, memToucher: c}
		if firstStmt == nil {
			firstStmt = stmt
		}
		// An optional PANIC|<method>| or WAIT|<duration>| prefix is
		// stripped here so that parts[0] is the command word below.
		if len(parts) >= 3 {
			switch parts[0] {
			case "PANIC":
				stmt.panic = parts[1]
				parts = parts[2:]
			case "WAIT":
				wait, err := time.ParseDuration(parts[1])
				if err != nil {
					return nil, errf("expected section after WAIT to be a duration, got %q %v", parts[1], err)
				}
				parts = parts[2:]
				stmt.wait = wait
			}
		}
		cmd := parts[0]
		stmt.cmd = cmd
		parts = parts[1:]

		// Honor the WAIT prefix during preparation, but give up early
		// if the context is cancelled.
		if stmt.wait > 0 {
			wait := time.NewTimer(stmt.wait)
			select {
			case <-wait.C:
			case <-ctx.Done():
				wait.Stop()
				return nil, ctx.Err()
			}
		}

		// Count the statement as made before the command-specific
		// preparation; the failure paths below close stmt again, which
		// bumps stmtsClosed.
		c.incrStat(&c.stmtsMade)
		var err error
		switch cmd {
		case "WIPE":
			// Nothing
		case "SELECT":
			stmt, err = c.prepareSelect(stmt, parts)
		case "CREATE":
			stmt, err = c.prepareCreate(stmt, parts)
		case "INSERT":
			stmt, err = c.prepareInsert(stmt, parts)
		case "NOSERT":
			// Do all the prep-work like for an INSERT but don't actually insert the row.
			// Used for some of the concurrent tests.
			stmt, err = c.prepareInsert(stmt, parts)
		default:
			stmt.Close()
			return nil, errf("unsupported command type %q", cmd)
		}
		if err != nil {
			return nil, err
		}
		// Chain this statement behind the previous one.
		if prev != nil {
			prev.next = stmt
		}
		prev = stmt
	}
	return firstStmt, nil
}
   614  
   615  func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
   616  	if s.panic == "ColumnConverter" {
   617  		panic(s.panic)
   618  	}
   619  	if len(s.placeholderConverter) == 0 {
   620  		return driver.DefaultParameterConverter
   621  	}
   622  	return s.placeholderConverter[idx]
   623  }
   624  
   625  func (s *fakeStmt) Close() error {
   626  	if s.panic == "Close" {
   627  		panic(s.panic)
   628  	}
   629  	if s.c == nil {
   630  		panic("nil conn in fakeStmt.Close")
   631  	}
   632  	if s.c.db == nil {
   633  		panic("in fakeStmt.Close, conn's db is nil (already closed)")
   634  	}
   635  	s.touchMem()
   636  	if !s.closed {
   637  		s.c.incrStat(&s.c.stmtsClosed)
   638  		s.closed = true
   639  	}
   640  	if s.next != nil {
   641  		s.next.Close()
   642  	}
   643  	return nil
   644  }
   645  
// errClosed is returned by ExecContext and QueryContext when the
// statement has already been closed.
var errClosed = errors.New("fakedb: statement has been closed")

// hookExecBadConn is a hook to simulate broken connections: when it is
// non-nil and returns true, fakeStmt.ExecContext fails with
// driver.ErrBadConn.
var hookExecBadConn func() bool
   650  
// Exec always panics; statements must be executed through ExecContext.
func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) {
	panic("Using ExecContext")
}
   654  func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
   655  	if s.panic == "Exec" {
   656  		panic(s.panic)
   657  	}
   658  	if s.closed {
   659  		return nil, errClosed
   660  	}
   661  
   662  	if s.c.stickyBad || (hookExecBadConn != nil && hookExecBadConn()) {
   663  		return nil, driver.ErrBadConn
   664  	}
   665  
   666  	err := checkSubsetTypes(s.c.db.allowAny, args)
   667  	if err != nil {
   668  		return nil, err
   669  	}
   670  	s.touchMem()
   671  
   672  	if s.wait > 0 {
   673  		time.Sleep(s.wait)
   674  	}
   675  
   676  	select {
   677  	default:
   678  	case <-ctx.Done():
   679  		return nil, ctx.Err()
   680  	}
   681  
   682  	db := s.c.db
   683  	switch s.cmd {
   684  	case "WIPE":
   685  		db.wipe()
   686  		return driver.ResultNoRows, nil
   687  	case "CREATE":
   688  		if err := db.createTable(s.table, s.colName, s.colType); err != nil {
   689  			return nil, err
   690  		}
   691  		return driver.ResultNoRows, nil
   692  	case "INSERT":
   693  		return s.execInsert(args, true)
   694  	case "NOSERT":
   695  		// Do all the prep-work like for an INSERT but don't actually insert the row.
   696  		// Used for some of the concurrent tests.
   697  		return s.execInsert(args, false)
   698  	}
   699  	fmt.Printf("EXEC statement, cmd=%q: %#v\n", s.cmd, s)
   700  	return nil, fmt.Errorf("unimplemented statement Exec command type of %q", s.cmd)
   701  }
   702  
   703  // When doInsert is true, add the row to the table.
   704  // When doInsert is false do prep-work and error checking, but don't
   705  // actually add the row to the table.
   706  func (s *fakeStmt) execInsert(args []driver.NamedValue, doInsert bool) (driver.Result, error) {
   707  	db := s.c.db
   708  	if len(args) != s.placeholders {
   709  		panic("error in pkg db; should only get here if size is correct")
   710  	}
   711  	db.mu.Lock()
   712  	t, ok := db.table(s.table)
   713  	db.mu.Unlock()
   714  	if !ok {
   715  		return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
   716  	}
   717  
   718  	t.mu.Lock()
   719  	defer t.mu.Unlock()
   720  
   721  	var cols []interface{}
   722  	if doInsert {
   723  		cols = make([]interface{}, len(t.colname))
   724  	}
   725  	argPos := 0
   726  	for n, colname := range s.colName {
   727  		colidx := t.columnIndex(colname)
   728  		if colidx == -1 {
   729  			return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname)
   730  		}
   731  		var val interface{}
   732  		if strvalue, ok := s.colValue[n].(string); ok && strings.HasPrefix(strvalue, "?") {
   733  			if strvalue == "?" {
   734  				val = args[argPos].Value
   735  			} else {
   736  				// Assign value from argument placeholder name.
   737  				for _, a := range args {
   738  					if a.Name == strvalue[1:] {
   739  						val = a.Value
   740  						break
   741  					}
   742  				}
   743  			}
   744  			argPos++
   745  		} else {
   746  			val = s.colValue[n]
   747  		}
   748  		if doInsert {
   749  			cols[colidx] = val
   750  		}
   751  	}
   752  
   753  	if doInsert {
   754  		t.rows = append(t.rows, &row{cols: cols})
   755  	}
   756  	return driver.RowsAffected(1), nil
   757  }
   758  
// hookQueryBadConn is a hook to simulate broken connections: when it is
// non-nil and returns true, fakeStmt.QueryContext fails with
// driver.ErrBadConn.
var hookQueryBadConn func() bool
   761  
// Query always panics; statements must be queried through QueryContext.
func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
	panic("Use QueryContext")
}
   765  
   766  func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
   767  	if s.panic == "Query" {
   768  		panic(s.panic)
   769  	}
   770  	if s.closed {
   771  		return nil, errClosed
   772  	}
   773  
   774  	if s.c.stickyBad || (hookQueryBadConn != nil && hookQueryBadConn()) {
   775  		return nil, driver.ErrBadConn
   776  	}
   777  
   778  	err := checkSubsetTypes(s.c.db.allowAny, args)
   779  	if err != nil {
   780  		return nil, err
   781  	}
   782  
   783  	s.touchMem()
   784  	db := s.c.db
   785  	if len(args) != s.placeholders {
   786  		panic("error in pkg db; should only get here if size is correct")
   787  	}
   788  
   789  	setMRows := make([][]*row, 0, 1)
   790  	setColumns := make([][]string, 0, 1)
   791  	setColType := make([][]string, 0, 1)
   792  
   793  	for {
   794  		db.mu.Lock()
   795  		t, ok := db.table(s.table)
   796  		db.mu.Unlock()
   797  		if !ok {
   798  			return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
   799  		}
   800  
   801  		if s.table == "magicquery" {
   802  			if len(s.whereCol) == 2 && s.whereCol[0].Column == "op" && s.whereCol[1].Column == "millis" {
   803  				if args[0].Value == "sleep" {
   804  					time.Sleep(time.Duration(args[1].Value.(int64)) * time.Millisecond)
   805  				}
   806  			}
   807  		}
   808  
   809  		t.mu.Lock()
   810  
   811  		colIdx := make(map[string]int) // select column name -> column index in table
   812  		for _, name := range s.colName {
   813  			idx := t.columnIndex(name)
   814  			if idx == -1 {
   815  				t.mu.Unlock()
   816  				return nil, fmt.Errorf("fakedb: unknown column name %q", name)
   817  			}
   818  			colIdx[name] = idx
   819  		}
   820  
   821  		mrows := []*row{}
   822  	rows:
   823  		for _, trow := range t.rows {
   824  			// Process the where clause, skipping non-match rows. This is lazy
   825  			// and just uses fmt.Sprintf("%v") to test equality. Good enough
   826  			// for test code.
   827  			for _, wcol := range s.whereCol {
   828  				idx := t.columnIndex(wcol.Column)
   829  				if idx == -1 {
   830  					t.mu.Unlock()
   831  					return nil, fmt.Errorf("db: invalid where clause column %q", wcol)
   832  				}
   833  				tcol := trow.cols[idx]
   834  				if bs, ok := tcol.([]byte); ok {
   835  					// lazy hack to avoid sprintf %v on a []byte
   836  					tcol = string(bs)
   837  				}
   838  				var argValue interface{}
   839  				if wcol.Placeholder == "?" {
   840  					argValue = args[wcol.Ordinal-1].Value
   841  				} else {
   842  					// Assign arg value from placeholder name.
   843  					for _, a := range args {
   844  						if a.Name == wcol.Placeholder[1:] {
   845  							argValue = a.Value
   846  							break
   847  						}
   848  					}
   849  				}
   850  				if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) {
   851  					continue rows
   852  				}
   853  			}
   854  			mrow := &row{cols: make([]interface{}, len(s.colName))}
   855  			for seli, name := range s.colName {
   856  				mrow.cols[seli] = trow.cols[colIdx[name]]
   857  			}
   858  			mrows = append(mrows, mrow)
   859  		}
   860  
   861  		var colType []string
   862  		for _, column := range s.colName {
   863  			colType = append(colType, t.coltype[t.columnIndex(column)])
   864  		}
   865  
   866  		t.mu.Unlock()
   867  
   868  		setMRows = append(setMRows, mrows)
   869  		setColumns = append(setColumns, s.colName)
   870  		setColType = append(setColType, colType)
   871  
   872  		if s.next == nil {
   873  			break
   874  		}
   875  		s = s.next
   876  	}
   877  
   878  	cursor := &rowsCursor{
   879  		parentMem: s.c,
   880  		posRow:    -1,
   881  		rows:      setMRows,
   882  		cols:      setColumns,
   883  		colType:   setColType,
   884  		errPos:    -1,
   885  	}
   886  	return cursor, nil
   887  }
   888  
   889  func (s *fakeStmt) NumInput() int {
   890  	if s.panic == "NumInput" {
   891  		panic(s.panic)
   892  	}
   893  	return s.placeholders
   894  }
   895  
   896  // hook to simulate broken connections
   897  var hookCommitBadConn func() bool
   898  
   899  func (tx *fakeTx) Commit() error {
   900  	tx.c.currTx = nil
   901  	if hookCommitBadConn != nil && hookCommitBadConn() {
   902  		return driver.ErrBadConn
   903  	}
   904  	tx.c.touchMem()
   905  	return nil
   906  }
   907  
   908  // hook to simulate broken connections
   909  var hookRollbackBadConn func() bool
   910  
   911  func (tx *fakeTx) Rollback() error {
   912  	tx.c.currTx = nil
   913  	if hookRollbackBadConn != nil && hookRollbackBadConn() {
   914  		return driver.ErrBadConn
   915  	}
   916  	tx.c.touchMem()
   917  	return nil
   918  }
   919  
// rowsCursor is the driver.Rows implementation returned by
// fakeStmt.QueryContext. It holds one or more result sets; cols, colType
// and rows are indexed by result set.
type rowsCursor struct {
	parentMem memToucher // touched on Close; QueryContext sets it to the statement's conn
	cols      [][]string
	colType   [][]string
	posSet    int // index of the current result set
	posRow    int // index of the current row in that set; -1 before the first Next
	rows      [][]*row
	closed    bool

	// errPos and err are for making Next return early with error.
	errPos int
	err    error

	// a clone of slices to give out to clients, indexed by the
	// original slice's first byte address.  we clone them
	// just so we're able to corrupt them on close.
	bytesClone map[*byte][]byte

	// Every operation writes to line to enable the race detector
	// check for data races.
	// This is separate from the fakeConn.line to allow for drivers that
	// can start multiple queries on the same transaction at the same time.
	line int64
}
   944  
   945  func (rc *rowsCursor) touchMem() {
   946  	rc.line++
   947  }
   948  
   949  func (rc *rowsCursor) Close() error {
   950  	if !rc.closed {
   951  		for _, bs := range rc.bytesClone {
   952  			bs[0] = 255 // first byte corrupted
   953  		}
   954  	}
   955  	rc.touchMem()
   956  	rc.parentMem.touchMem()
   957  	rc.closed = true
   958  	return nil
   959  }
   960  
   961  func (rc *rowsCursor) Columns() []string {
   962  	return rc.cols[rc.posSet]
   963  }
   964  
   965  func (rc *rowsCursor) ColumnTypeScanType(index int) reflect.Type {
   966  	return colTypeToReflectType(rc.colType[rc.posSet][index])
   967  }
   968  
// rowsCursorNextHook, when non-nil, replaces the body of rowsCursor.Next:
// Next returns whatever the hook returns without touching the cursor.
var rowsCursorNextHook func(dest []driver.Value) error
   970  
   971  func (rc *rowsCursor) Next(dest []driver.Value) error {
   972  	if rowsCursorNextHook != nil {
   973  		return rowsCursorNextHook(dest)
   974  	}
   975  
   976  	if rc.closed {
   977  		return errors.New("fakedb: cursor is closed")
   978  	}
   979  	rc.touchMem()
   980  	rc.posRow++
   981  	if rc.posRow == rc.errPos {
   982  		return rc.err
   983  	}
   984  	if rc.posRow >= len(rc.rows[rc.posSet]) {
   985  		return io.EOF // per interface spec
   986  	}
   987  	for i, v := range rc.rows[rc.posSet][rc.posRow].cols {
   988  		// TODO(bradfitz): convert to subset types? naah, I
   989  		// think the subset types should only be input to
   990  		// driver, but the sql package should be able to handle
   991  		// a wider range of types coming out of drivers. all
   992  		// for ease of drivers, and to prevent drivers from
   993  		// messing up conversions or doing them differently.
   994  		dest[i] = v
   995  
   996  		if bs, ok := v.([]byte); ok {
   997  			if rc.bytesClone == nil {
   998  				rc.bytesClone = make(map[*byte][]byte)
   999  			}
  1000  			clone, ok := rc.bytesClone[&bs[0]]
  1001  			if !ok {
  1002  				clone = make([]byte, len(bs))
  1003  				copy(clone, bs)
  1004  				rc.bytesClone[&bs[0]] = clone
  1005  			}
  1006  			dest[i] = clone
  1007  		}
  1008  	}
  1009  	return nil
  1010  }
  1011  
  1012  func (rc *rowsCursor) HasNextResultSet() bool {
  1013  	rc.touchMem()
  1014  	return rc.posSet < len(rc.rows)-1
  1015  }
  1016  
  1017  func (rc *rowsCursor) NextResultSet() error {
  1018  	rc.touchMem()
  1019  	if rc.HasNextResultSet() {
  1020  		rc.posSet++
  1021  		rc.posRow = -1
  1022  		return nil
  1023  	}
  1024  	return io.EOF // Per interface spec.
  1025  }
  1026  
  1027  // fakeDriverString is like driver.String, but indirects pointers like
  1028  // DefaultValueConverter.
  1029  //
  1030  // This could be surprising behavior to retroactively apply to
  1031  // driver.String now that Go1 is out, but this is convenient for
  1032  // our TestPointerParamsAndScans.
  1033  //
  1034  type fakeDriverString struct{}
  1035  
  1036  func (fakeDriverString) ConvertValue(v interface{}) (driver.Value, error) {
  1037  	switch c := v.(type) {
  1038  	case string, []byte:
  1039  		return v, nil
  1040  	case *string:
  1041  		if c == nil {
  1042  			return nil, nil
  1043  		}
  1044  		return *c, nil
  1045  	}
  1046  	return fmt.Sprintf("%v", v), nil
  1047  }
  1048  
  1049  type anyTypeConverter struct{}
  1050  
  1051  func (anyTypeConverter) ConvertValue(v interface{}) (driver.Value, error) {
  1052  	return v, nil
  1053  }
  1054  
  1055  func converterForType(typ string) driver.ValueConverter {
  1056  	switch typ {
  1057  	case "bool":
  1058  		return driver.Bool
  1059  	case "nullbool":
  1060  		return driver.Null{Converter: driver.Bool}
  1061  	case "int32":
  1062  		return driver.Int32
  1063  	case "string":
  1064  		return driver.NotNull{Converter: fakeDriverString{}}
  1065  	case "nullstring":
  1066  		return driver.Null{Converter: fakeDriverString{}}
  1067  	case "int64":
  1068  		// TODO(coopernurse): add type-specific converter
  1069  		return driver.NotNull{Converter: driver.DefaultParameterConverter}
  1070  	case "nullint64":
  1071  		// TODO(coopernurse): add type-specific converter
  1072  		return driver.Null{Converter: driver.DefaultParameterConverter}
  1073  	case "float64":
  1074  		// TODO(coopernurse): add type-specific converter
  1075  		return driver.NotNull{Converter: driver.DefaultParameterConverter}
  1076  	case "nullfloat64":
  1077  		// TODO(coopernurse): add type-specific converter
  1078  		return driver.Null{Converter: driver.DefaultParameterConverter}
  1079  	case "datetime":
  1080  		return driver.DefaultParameterConverter
  1081  	case "any":
  1082  		return anyTypeConverter{}
  1083  	}
  1084  	panic("invalid fakedb column type of " + typ)
  1085  }
  1086  
  1087  func colTypeToReflectType(typ string) reflect.Type {
  1088  	switch typ {
  1089  	case "bool":
  1090  		return reflect.TypeOf(false)
  1091  	case "nullbool":
  1092  		return reflect.TypeOf(NullBool{})
  1093  	case "int32":
  1094  		return reflect.TypeOf(int32(0))
  1095  	case "string":
  1096  		return reflect.TypeOf("")
  1097  	case "nullstring":
  1098  		return reflect.TypeOf(NullString{})
  1099  	case "int64":
  1100  		return reflect.TypeOf(int64(0))
  1101  	case "nullint64":
  1102  		return reflect.TypeOf(NullInt64{})
  1103  	case "float64":
  1104  		return reflect.TypeOf(float64(0))
  1105  	case "nullfloat64":
  1106  		return reflect.TypeOf(NullFloat64{})
  1107  	case "datetime":
  1108  		return reflect.TypeOf(time.Time{})
  1109  	case "any":
  1110  		return reflect.TypeOf(new(interface{})).Elem()
  1111  	}
  1112  	panic("invalid fakedb column type of " + typ)
  1113  }