github.com/gocuntian/go@v0.0.0-20160610041250-fee02d270bf8/src/database/sql/fakedb_test.go (about)

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package sql
     6  
     7  import (
     8  	"database/sql/driver"
     9  	"errors"
    10  	"fmt"
    11  	"io"
    12  	"log"
    13  	"sort"
    14  	"strconv"
    15  	"strings"
    16  	"sync"
    17  	"testing"
    18  	"time"
    19  )
    20  
// Reference log.Printf so the "log" import stays in use even when no
// debugging output is enabled in this file.
var _ = log.Printf
    22  
// fakeDriver is a fake database that implements Go's driver.Driver
// interface, just for testing.
//
// It speaks a query language that's semantically similar to but
// syntactically different and simpler than SQL. The syntax is as
// follows:
//
//   WIPE
//   CREATE|<tablename>|<col>=<type>,<col>=<type>,...
//   INSERT|<tablename>|col=val,col2=val2,col3=?
//   NOSERT|<tablename>|col=val,col2=val2,col3=?
//     (prepared and executed like INSERT, but the row is not stored)
//   SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=?
//
// CREATE does not validate type names. A "?" placeholder in an INSERT
// requires one of the types understood by converterForType: "bool",
// "nullbool", "int32", "int64", "nullint64", "float64", "nullfloat64",
// "string", "nullstring" or "datetime". Pre-bound INSERT values are
// only supported for "string", "blob" and "int32" columns.
//
// Any of these can be preceded by PANIC|<method>|, to cause the
// named method on fakeStmt to panic.
//
// When opening a fakeDriver's database, it starts empty with no
// tables. All tables and data are stored in memory only.
type fakeDriver struct {
	mu         sync.Mutex // guards openCount, closeCount and dbs
	openCount  int        // conn opens
	closeCount int        // conn closes
	// When waitCh is non-nil, the next Open sends on waitingCh, blocks
	// until it receives from waitCh, and then sets both fields to nil.
	// NOTE(review): Open accesses these two fields without holding mu.
	waitCh     chan struct{}
	waitingCh  chan struct{}
	dbs        map[string]*fakeDB // databases by name, created on demand by getDB
}
    49  
// fakeDB is one named in-memory database. Every fakeConn opened with the
// same name shares the same *fakeDB (see fakeDriver.getDB).
type fakeDB struct {
	name string

	mu      sync.Mutex        // guards tables; callers of table() must hold it
	tables  map[string]*table // nil until the first CREATE; reset to nil by WIPE
	badConn bool              // flipped by fakeConn.isBad for "badConn" connections
}
    57  
    58  type table struct {
    59  	mu      sync.Mutex
    60  	colname []string
    61  	coltype []string
    62  	rows    []*row
    63  }
    64  
    65  func (t *table) columnIndex(name string) int {
    66  	for n, nname := range t.colname {
    67  		if name == nname {
    68  			return n
    69  		}
    70  	}
    71  	return -1
    72  }
    73  
    74  type row struct {
    75  	cols []interface{} // must be same size as its table colname + coltype
    76  }
    77  
// fakeConn is the driver.Conn handed out by fakeDriver.Open; it operates
// on a shared fakeDB.
type fakeConn struct {
	db *fakeDB // where to return ourselves to; set to nil by Close

	currTx *fakeTx // non-nil between Begin and Commit/Rollback

	// Stats for tests:
	mu          sync.Mutex // held by incrStat while bumping the counters below
	stmtsMade   int
	stmtsClosed int
	numPrepare  int

	// bad connection tests; see isBad()
	bad       bool
	stickyBad bool
}
    93  
    94  func (c *fakeConn) incrStat(v *int) {
    95  	c.mu.Lock()
    96  	*v++
    97  	c.mu.Unlock()
    98  }
    99  
// fakeTx is the driver.Tx returned by fakeConn.Begin. It keeps the owning
// connection so Commit and Rollback can clear that connection's currTx.
type fakeTx struct {
	c *fakeConn
}
   103  
// fakeStmt is the driver.Stmt produced by fakeConn.Prepare.
type fakeStmt struct {
	c *fakeConn
	q string // just for debugging

	cmd   string // "WIPE", "CREATE", "INSERT", "NOSERT" or "SELECT"
	table string
	panic string // method named by a "PANIC|<method>|" prefix; that method panics

	closed bool // set by Close; Exec and Query then return errClosed

	colName      []string      // used by CREATE, INSERT, SELECT (selected columns)
	colType      []string      // used by CREATE
	colValue     []interface{} // used by INSERT: pre-bound values ([]byte or int64) or the string "?" for bound params
	placeholders int           // used by INSERT/SELECT: number of ? params

	whereCol []string // used by SELECT (all placeholders)

	placeholderConverter []driver.ValueConverter // used by INSERT: one per ?, returned by ColumnConverter
}
   123  
// fdriver is the fakeDriver instance registered under the name "test".
// fakeConn.Close type-asserts it back to *fakeDriver to update closeCount.
var fdriver driver.Driver = &fakeDriver{}

// init registers fdriver as the "test" driver.
func init() {
	Register("test", fdriver)
}
   129  
   130  func contains(list []string, y string) bool {
   131  	for _, x := range list {
   132  		if x == y {
   133  			return true
   134  		}
   135  	}
   136  	return false
   137  }
   138  
// Dummy embeds a driver.Driver without setting it. TestDrivers registers a
// zero Dummy under the name "invalid" only to populate the driver list.
type Dummy struct {
	driver.Driver
}
   142  
   143  func TestDrivers(t *testing.T) {
   144  	unregisterAllDrivers()
   145  	Register("test", fdriver)
   146  	Register("invalid", Dummy{})
   147  	all := Drivers()
   148  	if len(all) < 2 || !sort.StringsAreSorted(all) || !contains(all, "test") || !contains(all, "invalid") {
   149  		t.Fatalf("Drivers = %v, want sorted list with at least [invalid, test]", all)
   150  	}
   151  }
   152  
   153  // hook to simulate connection failures
   154  var hookOpenErr struct {
   155  	sync.Mutex
   156  	fn func() error
   157  }
   158  
   159  func setHookOpenErr(fn func() error) {
   160  	hookOpenErr.Lock()
   161  	defer hookOpenErr.Unlock()
   162  	hookOpenErr.fn = fn
   163  }
   164  
// Open implements driver.Driver. It returns a new fakeConn attached to the
// fakeDB named by dsn, creating that database the first time the name is
// used.
//
// Supports dsn forms:
//    <dbname>
//    <dbname>;<opts>  (only currently supported option is `badConn`,
//                      which causes driver.ErrBadConn to be returned on
//                      every other conn.Begin())
//
// If hookOpenErr.fn is set and returns an error, Open returns that error
// before doing anything else.
func (d *fakeDriver) Open(dsn string) (driver.Conn, error) {
	hookOpenErr.Lock()
	fn := hookOpenErr.fn
	hookOpenErr.Unlock()
	if fn != nil {
		if err := fn(); err != nil {
			return nil, err
		}
	}
	// NOTE(review): strings.Split with a non-empty separator always returns
	// at least one element, so this check never fires; an empty dsn just
	// names the database "".
	parts := strings.Split(dsn, ";")
	if len(parts) < 1 {
		return nil, errors.New("fakedb: no database name")
	}
	name := parts[0]

	db := d.getDB(name)

	d.mu.Lock()
	d.openCount++
	d.mu.Unlock()
	conn := &fakeConn{db: db}

	if len(parts) >= 2 && parts[1] == "badConn" {
		conn.bad = true
	}
	// One-shot rendezvous for tests: announce on waitingCh, block until
	// released through waitCh, then disarm both channels.
	// NOTE(review): waitCh/waitingCh are read and written here without d.mu.
	if d.waitCh != nil {
		d.waitingCh <- struct{}{}
		<-d.waitCh
		d.waitCh = nil
		d.waitingCh = nil
	}
	return conn, nil
}
   203  
   204  func (d *fakeDriver) getDB(name string) *fakeDB {
   205  	d.mu.Lock()
   206  	defer d.mu.Unlock()
   207  	if d.dbs == nil {
   208  		d.dbs = make(map[string]*fakeDB)
   209  	}
   210  	db, ok := d.dbs[name]
   211  	if !ok {
   212  		db = &fakeDB{name: name}
   213  		d.dbs[name] = db
   214  	}
   215  	return db
   216  }
   217  
   218  func (db *fakeDB) wipe() {
   219  	db.mu.Lock()
   220  	defer db.mu.Unlock()
   221  	db.tables = nil
   222  }
   223  
   224  func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error {
   225  	db.mu.Lock()
   226  	defer db.mu.Unlock()
   227  	if db.tables == nil {
   228  		db.tables = make(map[string]*table)
   229  	}
   230  	if _, exist := db.tables[name]; exist {
   231  		return fmt.Errorf("table %q already exists", name)
   232  	}
   233  	if len(columnNames) != len(columnTypes) {
   234  		return fmt.Errorf("create table of %q len(names) != len(types): %d vs %d",
   235  			name, len(columnNames), len(columnTypes))
   236  	}
   237  	db.tables[name] = &table{colname: columnNames, coltype: columnTypes}
   238  	return nil
   239  }
   240  
   241  // must be called with db.mu lock held
   242  func (db *fakeDB) table(table string) (*table, bool) {
   243  	if db.tables == nil {
   244  		return nil, false
   245  	}
   246  	t, ok := db.tables[table]
   247  	return t, ok
   248  }
   249  
   250  func (db *fakeDB) columnType(table, column string) (typ string, ok bool) {
   251  	db.mu.Lock()
   252  	defer db.mu.Unlock()
   253  	t, ok := db.table(table)
   254  	if !ok {
   255  		return
   256  	}
   257  	for n, cname := range t.colname {
   258  		if cname == column {
   259  			return t.coltype[n], true
   260  		}
   261  	}
   262  	return "", false
   263  }
   264  
   265  func (c *fakeConn) isBad() bool {
   266  	if c.stickyBad {
   267  		return true
   268  	} else if c.bad {
   269  		// alternate between bad conn and not bad conn
   270  		c.db.badConn = !c.db.badConn
   271  		return c.db.badConn
   272  	} else {
   273  		return false
   274  	}
   275  }
   276  
   277  func (c *fakeConn) Begin() (driver.Tx, error) {
   278  	if c.isBad() {
   279  		return nil, driver.ErrBadConn
   280  	}
   281  	if c.currTx != nil {
   282  		return nil, errors.New("already in a transaction")
   283  	}
   284  	c.currTx = &fakeTx{c: c}
   285  	return c.currTx, nil
   286  }
   287  
   288  var hookPostCloseConn struct {
   289  	sync.Mutex
   290  	fn func(*fakeConn, error)
   291  }
   292  
   293  func setHookpostCloseConn(fn func(*fakeConn, error)) {
   294  	hookPostCloseConn.Lock()
   295  	defer hookPostCloseConn.Unlock()
   296  	hookPostCloseConn.fn = fn
   297  }
   298  
// testStrictClose, when non-nil, is the test that fakeConn.Close reports
// its failures to via Errorf.
var testStrictClose *testing.T

// setStrictFakeConnClose makes fakeConn.Close call t.Errorf whenever it
// fails to close a connection. Passing nil disables the check.
func setStrictFakeConnClose(t *testing.T) {
	testStrictClose = t
}
   306  
// Close implements driver.Conn. It fails while a transaction is open, when
// c was already closed, or while statements prepared on c remain unclosed;
// on success it detaches c from its fakeDB.
//
// The deferred function runs on every return path, after err is set: it
// reports failures to testStrictClose (if set), passes the outcome to
// hookPostCloseConn.fn (if set), and counts successful closes in the
// driver's closeCount.
func (c *fakeConn) Close() (err error) {
	drv := fdriver.(*fakeDriver)
	defer func() {
		if err != nil && testStrictClose != nil {
			testStrictClose.Errorf("failed to close a test fakeConn: %v", err)
		}
		hookPostCloseConn.Lock()
		fn := hookPostCloseConn.fn
		hookPostCloseConn.Unlock()
		if fn != nil {
			fn(c, err)
		}
		if err == nil {
			drv.mu.Lock()
			drv.closeCount++
			drv.mu.Unlock()
		}
	}()
	if c.currTx != nil {
		return errors.New("can't close fakeConn; in a Transaction")
	}
	if c.db == nil {
		return errors.New("can't close fakeConn; already closed")
	}
	// NOTE(review): these counters are read without c.mu, although
	// incrStat updates them while holding it.
	if c.stmtsMade > c.stmtsClosed {
		return errors.New("can't close; dangling statement(s)")
	}
	c.db = nil
	return nil
}
   337  
   338  func checkSubsetTypes(args []driver.Value) error {
   339  	for n, arg := range args {
   340  		switch arg.(type) {
   341  		case int64, float64, bool, nil, []byte, string, time.Time:
   342  		default:
   343  			return fmt.Errorf("fakedb_test: invalid argument #%d: %v, type %T", n+1, arg, arg)
   344  		}
   345  	}
   346  	return nil
   347  }
   348  
   349  func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) {
   350  	// This is an optional interface, but it's implemented here
   351  	// just to check that all the args are of the proper types.
   352  	// ErrSkip is returned so the caller acts as if we didn't
   353  	// implement this at all.
   354  	err := checkSubsetTypes(args)
   355  	if err != nil {
   356  		return nil, err
   357  	}
   358  	return nil, driver.ErrSkip
   359  }
   360  
   361  func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) {
   362  	// This is an optional interface, but it's implemented here
   363  	// just to check that all the args are of the proper types.
   364  	// ErrSkip is returned so the caller acts as if we didn't
   365  	// implement this at all.
   366  	err := checkSubsetTypes(args)
   367  	if err != nil {
   368  		return nil, err
   369  	}
   370  	return nil, driver.ErrSkip
   371  }
   372  
   373  func errf(msg string, args ...interface{}) error {
   374  	return errors.New("fakedb: " + fmt.Sprintf(msg, args...))
   375  }
   376  
// prepareSelect fills in stmt for a SELECT command and returns it as the
// prepared statement.
//
// parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=?
// (note that where columns must always contain ? marks,
//  just a limitation for fakedb)
//
// Every non-empty where spec must name an existing column of the table
// and use "?" as its value; each one adds a placeholder. The selected
// columns are not checked here (Query rejects unknown ones when it runs).
// On any error stmt is closed before the error is returned.
func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
	if len(parts) != 3 {
		stmt.Close()
		return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
	}
	stmt.table = parts[0]
	stmt.colName = strings.Split(parts[1], ",")
	for n, colspec := range strings.Split(parts[2], ",") {
		// An empty where clause splits into a single "" spec; skip it.
		if colspec == "" {
			continue
		}
		nameVal := strings.Split(colspec, "=")
		if len(nameVal) != 2 {
			stmt.Close()
			return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
		}
		column, value := nameVal[0], nameVal[1]
		_, ok := c.db.columnType(stmt.table, column)
		if !ok {
			stmt.Close()
			return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column)
		}
		if value != "?" {
			stmt.Close()
			return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark",
				stmt.table, column)
		}
		stmt.whereCol = append(stmt.whereCol, column)
		stmt.placeholders++
	}
	return stmt, nil
}
   412  
   413  // parts are table|col=type,col2=type2
   414  func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
   415  	if len(parts) != 2 {
   416  		stmt.Close()
   417  		return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts))
   418  	}
   419  	stmt.table = parts[0]
   420  	for n, colspec := range strings.Split(parts[1], ",") {
   421  		nameType := strings.Split(colspec, "=")
   422  		if len(nameType) != 2 {
   423  			stmt.Close()
   424  			return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
   425  		}
   426  		stmt.colName = append(stmt.colName, nameType[0])
   427  		stmt.colType = append(stmt.colType, nameType[1])
   428  	}
   429  	return stmt, nil
   430  }
   431  
   432  // parts are table|col=?,col2=val
   433  func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
   434  	if len(parts) != 2 {
   435  		stmt.Close()
   436  		return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
   437  	}
   438  	stmt.table = parts[0]
   439  	for n, colspec := range strings.Split(parts[1], ",") {
   440  		nameVal := strings.Split(colspec, "=")
   441  		if len(nameVal) != 2 {
   442  			stmt.Close()
   443  			return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
   444  		}
   445  		column, value := nameVal[0], nameVal[1]
   446  		ctype, ok := c.db.columnType(stmt.table, column)
   447  		if !ok {
   448  			stmt.Close()
   449  			return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column)
   450  		}
   451  		stmt.colName = append(stmt.colName, column)
   452  
   453  		if value != "?" {
   454  			var subsetVal interface{}
   455  			// Convert to driver subset type
   456  			switch ctype {
   457  			case "string":
   458  				subsetVal = []byte(value)
   459  			case "blob":
   460  				subsetVal = []byte(value)
   461  			case "int32":
   462  				i, err := strconv.Atoi(value)
   463  				if err != nil {
   464  					stmt.Close()
   465  					return nil, errf("invalid conversion to int32 from %q", value)
   466  				}
   467  				subsetVal = int64(i) // int64 is a subset type, but not int32
   468  			default:
   469  				stmt.Close()
   470  				return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype)
   471  			}
   472  			stmt.colValue = append(stmt.colValue, subsetVal)
   473  		} else {
   474  			stmt.placeholders++
   475  			stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype))
   476  			stmt.colValue = append(stmt.colValue, "?")
   477  		}
   478  	}
   479  	return stmt, nil
   480  }
   481  
   482  // hook to simulate broken connections
   483  var hookPrepareBadConn func() bool
   484  
   485  func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
   486  	c.numPrepare++
   487  	if c.db == nil {
   488  		panic("nil c.db; conn = " + fmt.Sprintf("%#v", c))
   489  	}
   490  
   491  	if c.stickyBad || (hookPrepareBadConn != nil && hookPrepareBadConn()) {
   492  		return nil, driver.ErrBadConn
   493  	}
   494  
   495  	parts := strings.Split(query, "|")
   496  	if len(parts) < 1 {
   497  		return nil, errf("empty query")
   498  	}
   499  	stmt := &fakeStmt{q: query, c: c}
   500  	if len(parts) >= 3 && parts[0] == "PANIC" {
   501  		stmt.panic = parts[1]
   502  		parts = parts[2:]
   503  	}
   504  	cmd := parts[0]
   505  	stmt.cmd = cmd
   506  	parts = parts[1:]
   507  
   508  	c.incrStat(&c.stmtsMade)
   509  	switch cmd {
   510  	case "WIPE":
   511  		// Nothing
   512  	case "SELECT":
   513  		return c.prepareSelect(stmt, parts)
   514  	case "CREATE":
   515  		return c.prepareCreate(stmt, parts)
   516  	case "INSERT":
   517  		return c.prepareInsert(stmt, parts)
   518  	case "NOSERT":
   519  		// Do all the prep-work like for an INSERT but don't actually insert the row.
   520  		// Used for some of the concurrent tests.
   521  		return c.prepareInsert(stmt, parts)
   522  	default:
   523  		stmt.Close()
   524  		return nil, errf("unsupported command type %q", cmd)
   525  	}
   526  	return stmt, nil
   527  }
   528  
   529  func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
   530  	if s.panic == "ColumnConverter" {
   531  		panic(s.panic)
   532  	}
   533  	if len(s.placeholderConverter) == 0 {
   534  		return driver.DefaultParameterConverter
   535  	}
   536  	return s.placeholderConverter[idx]
   537  }
   538  
   539  func (s *fakeStmt) Close() error {
   540  	if s.panic == "Close" {
   541  		panic(s.panic)
   542  	}
   543  	if s.c == nil {
   544  		panic("nil conn in fakeStmt.Close")
   545  	}
   546  	if s.c.db == nil {
   547  		panic("in fakeStmt.Close, conn's db is nil (already closed)")
   548  	}
   549  	if !s.closed {
   550  		s.c.incrStat(&s.c.stmtsClosed)
   551  		s.closed = true
   552  	}
   553  	return nil
   554  }
   555  
   556  var errClosed = errors.New("fakedb: statement has been closed")
   557  
   558  // hook to simulate broken connections
   559  var hookExecBadConn func() bool
   560  
   561  func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) {
   562  	if s.panic == "Exec" {
   563  		panic(s.panic)
   564  	}
   565  	if s.closed {
   566  		return nil, errClosed
   567  	}
   568  
   569  	if s.c.stickyBad || (hookExecBadConn != nil && hookExecBadConn()) {
   570  		return nil, driver.ErrBadConn
   571  	}
   572  
   573  	err := checkSubsetTypes(args)
   574  	if err != nil {
   575  		return nil, err
   576  	}
   577  
   578  	db := s.c.db
   579  	switch s.cmd {
   580  	case "WIPE":
   581  		db.wipe()
   582  		return driver.ResultNoRows, nil
   583  	case "CREATE":
   584  		if err := db.createTable(s.table, s.colName, s.colType); err != nil {
   585  			return nil, err
   586  		}
   587  		return driver.ResultNoRows, nil
   588  	case "INSERT":
   589  		return s.execInsert(args, true)
   590  	case "NOSERT":
   591  		// Do all the prep-work like for an INSERT but don't actually insert the row.
   592  		// Used for some of the concurrent tests.
   593  		return s.execInsert(args, false)
   594  	}
   595  	fmt.Printf("EXEC statement, cmd=%q: %#v\n", s.cmd, s)
   596  	return nil, fmt.Errorf("unimplemented statement Exec command type of %q", s.cmd)
   597  }
   598  
   599  // When doInsert is true, add the row to the table.
   600  // When doInsert is false do prep-work and error checking, but don't
   601  // actually add the row to the table.
   602  func (s *fakeStmt) execInsert(args []driver.Value, doInsert bool) (driver.Result, error) {
   603  	db := s.c.db
   604  	if len(args) != s.placeholders {
   605  		panic("error in pkg db; should only get here if size is correct")
   606  	}
   607  	db.mu.Lock()
   608  	t, ok := db.table(s.table)
   609  	db.mu.Unlock()
   610  	if !ok {
   611  		return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
   612  	}
   613  
   614  	t.mu.Lock()
   615  	defer t.mu.Unlock()
   616  
   617  	var cols []interface{}
   618  	if doInsert {
   619  		cols = make([]interface{}, len(t.colname))
   620  	}
   621  	argPos := 0
   622  	for n, colname := range s.colName {
   623  		colidx := t.columnIndex(colname)
   624  		if colidx == -1 {
   625  			return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname)
   626  		}
   627  		var val interface{}
   628  		if strvalue, ok := s.colValue[n].(string); ok && strvalue == "?" {
   629  			val = args[argPos]
   630  			argPos++
   631  		} else {
   632  			val = s.colValue[n]
   633  		}
   634  		if doInsert {
   635  			cols[colidx] = val
   636  		}
   637  	}
   638  
   639  	if doInsert {
   640  		t.rows = append(t.rows, &row{cols: cols})
   641  	}
   642  	return driver.RowsAffected(1), nil
   643  }
   644  
// hook to simulate broken connections
var hookQueryBadConn func() bool

// Query implements driver.Stmt for SELECT statements. It returns a
// rowsCursor over new rows holding the s.colName projection of every
// table row whose where columns match args.
//
// It panics if the statement was prepared with PANIC|Query|. It returns
// errClosed for a closed statement, and driver.ErrBadConn for sticky-bad
// connections or when hookQueryBadConn reports a failure.
//
// A table named "magicquery" is special: if its where columns are exactly
// op and millis and the op argument is "sleep", Query sleeps for millis
// milliseconds (which must be an int64) before scanning. Presumably this
// is used by tests that need a slow query.
func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
	if s.panic == "Query" {
		panic(s.panic)
	}
	if s.closed {
		return nil, errClosed
	}

	if s.c.stickyBad || (hookQueryBadConn != nil && hookQueryBadConn()) {
		return nil, driver.ErrBadConn
	}

	err := checkSubsetTypes(args)
	if err != nil {
		return nil, err
	}

	db := s.c.db
	if len(args) != s.placeholders {
		panic("error in pkg db; should only get here if size is correct")
	}

	db.mu.Lock()
	t, ok := db.table(s.table)
	db.mu.Unlock()
	if !ok {
		return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
	}

	if s.table == "magicquery" {
		if len(s.whereCol) == 2 && s.whereCol[0] == "op" && s.whereCol[1] == "millis" {
			if args[0] == "sleep" {
				time.Sleep(time.Duration(args[1].(int64)) * time.Millisecond)
			}
		}
	}

	t.mu.Lock()
	defer t.mu.Unlock()

	// Selected columns are only validated here, at execution time.
	colIdx := make(map[string]int) // select column name -> column index in table
	for _, name := range s.colName {
		idx := t.columnIndex(name)
		if idx == -1 {
			return nil, fmt.Errorf("fakedb: unknown column name %q", name)
		}
		colIdx[name] = idx
	}

	mrows := []*row{}
rows:
	for _, trow := range t.rows {
		// Process the where clause, skipping non-match rows. This is lazy
		// and just uses fmt.Sprintf("%v") to test equality. Good enough
		// for test code.
		// Where-column indexes are resolved per row, so an invalid where
		// column is only reported when the table has at least one row.
		for widx, wcol := range s.whereCol {
			idx := t.columnIndex(wcol)
			if idx == -1 {
				return nil, fmt.Errorf("db: invalid where clause column %q", wcol)
			}
			tcol := trow.cols[idx]
			if bs, ok := tcol.([]byte); ok {
				// lazy hack to avoid sprintf %v on a []byte
				tcol = string(bs)
			}
			if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", args[widx]) {
				continue rows
			}
		}
		mrow := &row{cols: make([]interface{}, len(s.colName))}
		for seli, name := range s.colName {
			mrow.cols[seli] = trow.cols[colIdx[name]]
		}
		mrows = append(mrows, mrow)
	}

	cursor := &rowsCursor{
		pos:    -1,
		rows:   mrows,
		cols:   s.colName,
		errPos: -1,
	}
	return cursor, nil
}
   732  
   733  func (s *fakeStmt) NumInput() int {
   734  	if s.panic == "NumInput" {
   735  		panic(s.panic)
   736  	}
   737  	return s.placeholders
   738  }
   739  
   740  // hook to simulate broken connections
   741  var hookCommitBadConn func() bool
   742  
   743  func (tx *fakeTx) Commit() error {
   744  	tx.c.currTx = nil
   745  	if hookCommitBadConn != nil && hookCommitBadConn() {
   746  		return driver.ErrBadConn
   747  	}
   748  	return nil
   749  }
   750  
   751  // hook to simulate broken connections
   752  var hookRollbackBadConn func() bool
   753  
   754  func (tx *fakeTx) Rollback() error {
   755  	tx.c.currTx = nil
   756  	if hookRollbackBadConn != nil && hookRollbackBadConn() {
   757  		return driver.ErrBadConn
   758  	}
   759  	return nil
   760  }
   761  
   762  type rowsCursor struct {
   763  	cols   []string
   764  	pos    int
   765  	rows   []*row
   766  	closed bool
   767  
   768  	// errPos and err are for making Next return early with error.
   769  	errPos int
   770  	err    error
   771  
   772  	// a clone of slices to give out to clients, indexed by the
   773  	// the original slice's first byte address.  we clone them
   774  	// just so we're able to corrupt them on close.
   775  	bytesClone map[*byte][]byte
   776  }
   777  
   778  func (rc *rowsCursor) Close() error {
   779  	if !rc.closed {
   780  		for _, bs := range rc.bytesClone {
   781  			bs[0] = 255 // first byte corrupted
   782  		}
   783  	}
   784  	rc.closed = true
   785  	return nil
   786  }
   787  
   788  func (rc *rowsCursor) Columns() []string {
   789  	return rc.cols
   790  }
   791  
   792  var rowsCursorNextHook func(dest []driver.Value) error
   793  
   794  func (rc *rowsCursor) Next(dest []driver.Value) error {
   795  	if rowsCursorNextHook != nil {
   796  		return rowsCursorNextHook(dest)
   797  	}
   798  
   799  	if rc.closed {
   800  		return errors.New("fakedb: cursor is closed")
   801  	}
   802  	rc.pos++
   803  	if rc.pos == rc.errPos {
   804  		return rc.err
   805  	}
   806  	if rc.pos >= len(rc.rows) {
   807  		return io.EOF // per interface spec
   808  	}
   809  	for i, v := range rc.rows[rc.pos].cols {
   810  		// TODO(bradfitz): convert to subset types? naah, I
   811  		// think the subset types should only be input to
   812  		// driver, but the sql package should be able to handle
   813  		// a wider range of types coming out of drivers. all
   814  		// for ease of drivers, and to prevent drivers from
   815  		// messing up conversions or doing them differently.
   816  		dest[i] = v
   817  
   818  		if bs, ok := v.([]byte); ok {
   819  			if rc.bytesClone == nil {
   820  				rc.bytesClone = make(map[*byte][]byte)
   821  			}
   822  			clone, ok := rc.bytesClone[&bs[0]]
   823  			if !ok {
   824  				clone = make([]byte, len(bs))
   825  				copy(clone, bs)
   826  				rc.bytesClone[&bs[0]] = clone
   827  			}
   828  			dest[i] = clone
   829  		}
   830  	}
   831  	return nil
   832  }
   833  
   834  // fakeDriverString is like driver.String, but indirects pointers like
   835  // DefaultValueConverter.
   836  //
   837  // This could be surprising behavior to retroactively apply to
   838  // driver.String now that Go1 is out, but this is convenient for
   839  // our TestPointerParamsAndScans.
   840  //
   841  type fakeDriverString struct{}
   842  
   843  func (fakeDriverString) ConvertValue(v interface{}) (driver.Value, error) {
   844  	switch c := v.(type) {
   845  	case string, []byte:
   846  		return v, nil
   847  	case *string:
   848  		if c == nil {
   849  			return nil, nil
   850  		}
   851  		return *c, nil
   852  	}
   853  	return fmt.Sprintf("%v", v), nil
   854  }
   855  
   856  func converterForType(typ string) driver.ValueConverter {
   857  	switch typ {
   858  	case "bool":
   859  		return driver.Bool
   860  	case "nullbool":
   861  		return driver.Null{Converter: driver.Bool}
   862  	case "int32":
   863  		return driver.Int32
   864  	case "string":
   865  		return driver.NotNull{Converter: fakeDriverString{}}
   866  	case "nullstring":
   867  		return driver.Null{Converter: fakeDriverString{}}
   868  	case "int64":
   869  		// TODO(coopernurse): add type-specific converter
   870  		return driver.NotNull{Converter: driver.DefaultParameterConverter}
   871  	case "nullint64":
   872  		// TODO(coopernurse): add type-specific converter
   873  		return driver.Null{Converter: driver.DefaultParameterConverter}
   874  	case "float64":
   875  		// TODO(coopernurse): add type-specific converter
   876  		return driver.NotNull{Converter: driver.DefaultParameterConverter}
   877  	case "nullfloat64":
   878  		// TODO(coopernurse): add type-specific converter
   879  		return driver.Null{Converter: driver.DefaultParameterConverter}
   880  	case "datetime":
   881  		return driver.DefaultParameterConverter
   882  	}
   883  	panic("invalid fakedb column type of " + typ)
   884  }