github.com/ltltlt/go-source-code@v0.0.0-20190830023027-95be009773aa/database/sql/fakedb_test.go (about)

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package sql
     6  
     7  import (
     8  	"context"
     9  	"database/sql/driver"
    10  	"errors"
    11  	"fmt"
    12  	"io"
    13  	"log"
    14  	"reflect"
    15  	"sort"
    16  	"strconv"
    17  	"strings"
    18  	"sync"
    19  	"testing"
    20  	"time"
    21  )
    22  
    23  var _ = log.Printf
    24  
    25  // fakeDriver is a fake database that implements Go's driver.Driver
    26  // interface, just for testing.
    27  //
    28  // It speaks a query language that's semantically similar to but
    29  // syntactically different and simpler than SQL.  The syntax is as
    30  // follows:
    31  //
    32  //   WIPE
    33  //   CREATE|<tablename>|<col>=<type>,<col>=<type>,...
    34  //     where types are: "string", [u]int{8,16,32,64}, "bool"
    35  //   INSERT|<tablename>|col=val,col2=val2,col3=?
    36  //   SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=?
    37  //   SELECT|<tablename>|projectcol1,projectcol2|filtercol=?param1,filtercol2=?param2
    38  //
    39  // Any of these can be preceded by PANIC|<method>|, to cause the
    40  // named method on fakeStmt to panic.
    41  //
    42  // Any of these can be proceeded by WAIT|<duration>|, to cause the
    43  // named method on fakeStmt to sleep for the specified duration.
    44  //
    45  // Multiple of these can be combined when separated with a semicolon.
    46  //
    47  // When opening a fakeDriver's database, it starts empty with no
    48  // tables. All tables and data are stored in memory only.
    49  type fakeDriver struct {
    50  	mu         sync.Mutex // guards 3 following fields
    51  	openCount  int        // conn opens
    52  	closeCount int        // conn closes
    53  	waitCh     chan struct{}
    54  	waitingCh  chan struct{}
    55  	dbs        map[string]*fakeDB
    56  }
    57  
    58  type fakeConnector struct {
    59  	name string
    60  
    61  	waiter func(context.Context)
    62  }
    63  
    64  func (c *fakeConnector) Connect(context.Context) (driver.Conn, error) {
    65  	conn, err := fdriver.Open(c.name)
    66  	conn.(*fakeConn).waiter = c.waiter
    67  	return conn, err
    68  }
    69  
    70  func (c *fakeConnector) Driver() driver.Driver {
    71  	return fdriver
    72  }
    73  
    74  type fakeDriverCtx struct {
    75  	fakeDriver
    76  }
    77  
    78  var _ driver.DriverContext = &fakeDriverCtx{}
    79  
    80  func (cc *fakeDriverCtx) OpenConnector(name string) (driver.Connector, error) {
    81  	return &fakeConnector{name: name}, nil
    82  }
    83  
    84  type fakeDB struct {
    85  	name string
    86  
    87  	mu       sync.Mutex
    88  	tables   map[string]*table
    89  	badConn  bool
    90  	allowAny bool
    91  }
    92  
    93  type table struct {
    94  	mu      sync.Mutex
    95  	colname []string
    96  	coltype []string
    97  	rows    []*row
    98  }
    99  
   100  func (t *table) columnIndex(name string) int {
   101  	for n, nname := range t.colname {
   102  		if name == nname {
   103  			return n
   104  		}
   105  	}
   106  	return -1
   107  }
   108  
   109  type row struct {
   110  	cols []interface{} // must be same size as its table colname + coltype
   111  }
   112  
   113  type memToucher interface {
   114  	// touchMem reads & writes some memory, to help find data races.
   115  	touchMem()
   116  }
   117  
   118  type fakeConn struct {
   119  	db *fakeDB // where to return ourselves to
   120  
   121  	currTx *fakeTx
   122  
   123  	// Every operation writes to line to enable the race detector
   124  	// check for data races.
   125  	line int64
   126  
   127  	// Stats for tests:
   128  	mu          sync.Mutex
   129  	stmtsMade   int
   130  	stmtsClosed int
   131  	numPrepare  int
   132  
   133  	// bad connection tests; see isBad()
   134  	bad       bool
   135  	stickyBad bool
   136  
   137  	skipDirtySession bool // tests that use Conn should set this to true.
   138  
   139  	// dirtySession tests ResetSession, true if a query has executed
   140  	// until ResetSession is called.
   141  	dirtySession bool
   142  
   143  	// The waiter is called before each query. May be used in place of the "WAIT"
   144  	// directive.
   145  	waiter func(context.Context)
   146  }
   147  
   148  func (c *fakeConn) touchMem() {
   149  	c.line++
   150  }
   151  
   152  func (c *fakeConn) incrStat(v *int) {
   153  	c.mu.Lock()
   154  	*v++
   155  	c.mu.Unlock()
   156  }
   157  
   158  type fakeTx struct {
   159  	c *fakeConn
   160  }
   161  
   162  type boundCol struct {
   163  	Column      string
   164  	Placeholder string
   165  	Ordinal     int
   166  }
   167  
   168  type fakeStmt struct {
   169  	memToucher
   170  	c *fakeConn
   171  	q string // just for debugging
   172  
   173  	cmd   string
   174  	table string
   175  	panic string
   176  	wait  time.Duration
   177  
   178  	next *fakeStmt // used for returning multiple results.
   179  
   180  	closed bool
   181  
   182  	colName      []string      // used by CREATE, INSERT, SELECT (selected columns)
   183  	colType      []string      // used by CREATE
   184  	colValue     []interface{} // used by INSERT (mix of strings and "?" for bound params)
   185  	placeholders int           // used by INSERT/SELECT: number of ? params
   186  
   187  	whereCol []boundCol // used by SELECT (all placeholders)
   188  
   189  	placeholderConverter []driver.ValueConverter // used by INSERT
   190  }
   191  
   192  var fdriver driver.Driver = &fakeDriver{}
   193  
   194  func init() {
   195  	Register("test", fdriver)
   196  }
   197  
   198  func contains(list []string, y string) bool {
   199  	for _, x := range list {
   200  		if x == y {
   201  			return true
   202  		}
   203  	}
   204  	return false
   205  }
   206  
   207  type Dummy struct {
   208  	driver.Driver
   209  }
   210  
   211  func TestDrivers(t *testing.T) {
   212  	unregisterAllDrivers()
   213  	Register("test", fdriver)
   214  	Register("invalid", Dummy{})
   215  	all := Drivers()
   216  	if len(all) < 2 || !sort.StringsAreSorted(all) || !contains(all, "test") || !contains(all, "invalid") {
   217  		t.Fatalf("Drivers = %v, want sorted list with at least [invalid, test]", all)
   218  	}
   219  }
   220  
   221  // hook to simulate connection failures
   222  var hookOpenErr struct {
   223  	sync.Mutex
   224  	fn func() error
   225  }
   226  
   227  func setHookOpenErr(fn func() error) {
   228  	hookOpenErr.Lock()
   229  	defer hookOpenErr.Unlock()
   230  	hookOpenErr.fn = fn
   231  }
   232  
   233  // Supports dsn forms:
   234  //    <dbname>
   235  //    <dbname>;<opts>  (only currently supported option is `badConn`,
   236  //                      which causes driver.ErrBadConn to be returned on
   237  //                      every other conn.Begin())
   238  func (d *fakeDriver) Open(dsn string) (driver.Conn, error) {
   239  	hookOpenErr.Lock()
   240  	fn := hookOpenErr.fn
   241  	hookOpenErr.Unlock()
   242  	if fn != nil {
   243  		if err := fn(); err != nil {
   244  			return nil, err
   245  		}
   246  	}
   247  	parts := strings.Split(dsn, ";")
   248  	if len(parts) < 1 {
   249  		return nil, errors.New("fakedb: no database name")
   250  	}
   251  	name := parts[0]
   252  
   253  	db := d.getDB(name)
   254  
   255  	d.mu.Lock()
   256  	d.openCount++
   257  	d.mu.Unlock()
   258  	conn := &fakeConn{db: db}
   259  
   260  	if len(parts) >= 2 && parts[1] == "badConn" {
   261  		conn.bad = true
   262  	}
   263  	if d.waitCh != nil {
   264  		d.waitingCh <- struct{}{}
   265  		<-d.waitCh
   266  		d.waitCh = nil
   267  		d.waitingCh = nil
   268  	}
   269  	return conn, nil
   270  }
   271  
   272  func (d *fakeDriver) getDB(name string) *fakeDB {
   273  	d.mu.Lock()
   274  	defer d.mu.Unlock()
   275  	if d.dbs == nil {
   276  		d.dbs = make(map[string]*fakeDB)
   277  	}
   278  	db, ok := d.dbs[name]
   279  	if !ok {
   280  		db = &fakeDB{name: name}
   281  		d.dbs[name] = db
   282  	}
   283  	return db
   284  }
   285  
   286  func (db *fakeDB) wipe() {
   287  	db.mu.Lock()
   288  	defer db.mu.Unlock()
   289  	db.tables = nil
   290  }
   291  
   292  func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error {
   293  	db.mu.Lock()
   294  	defer db.mu.Unlock()
   295  	if db.tables == nil {
   296  		db.tables = make(map[string]*table)
   297  	}
   298  	if _, exist := db.tables[name]; exist {
   299  		return fmt.Errorf("table %q already exists", name)
   300  	}
   301  	if len(columnNames) != len(columnTypes) {
   302  		return fmt.Errorf("create table of %q len(names) != len(types): %d vs %d",
   303  			name, len(columnNames), len(columnTypes))
   304  	}
   305  	db.tables[name] = &table{colname: columnNames, coltype: columnTypes}
   306  	return nil
   307  }
   308  
   309  // must be called with db.mu lock held
   310  func (db *fakeDB) table(table string) (*table, bool) {
   311  	if db.tables == nil {
   312  		return nil, false
   313  	}
   314  	t, ok := db.tables[table]
   315  	return t, ok
   316  }
   317  
   318  func (db *fakeDB) columnType(table, column string) (typ string, ok bool) {
   319  	db.mu.Lock()
   320  	defer db.mu.Unlock()
   321  	t, ok := db.table(table)
   322  	if !ok {
   323  		return
   324  	}
   325  	for n, cname := range t.colname {
   326  		if cname == column {
   327  			return t.coltype[n], true
   328  		}
   329  	}
   330  	return "", false
   331  }
   332  
   333  func (c *fakeConn) isBad() bool {
   334  	if c.stickyBad {
   335  		return true
   336  	} else if c.bad {
   337  		if c.db == nil {
   338  			return false
   339  		}
   340  		// alternate between bad conn and not bad conn
   341  		c.db.badConn = !c.db.badConn
   342  		return c.db.badConn
   343  	} else {
   344  		return false
   345  	}
   346  }
   347  
   348  func (c *fakeConn) isDirtyAndMark() bool {
   349  	if c.skipDirtySession {
   350  		return false
   351  	}
   352  	if c.currTx != nil {
   353  		c.dirtySession = true
   354  		return false
   355  	}
   356  	if c.dirtySession {
   357  		return true
   358  	}
   359  	c.dirtySession = true
   360  	return false
   361  }
   362  
   363  func (c *fakeConn) Begin() (driver.Tx, error) {
   364  	if c.isBad() {
   365  		return nil, driver.ErrBadConn
   366  	}
   367  	if c.currTx != nil {
   368  		return nil, errors.New("already in a transaction")
   369  	}
   370  	c.touchMem()
   371  	c.currTx = &fakeTx{c: c}
   372  	return c.currTx, nil
   373  }
   374  
   375  var hookPostCloseConn struct {
   376  	sync.Mutex
   377  	fn func(*fakeConn, error)
   378  }
   379  
   380  func setHookpostCloseConn(fn func(*fakeConn, error)) {
   381  	hookPostCloseConn.Lock()
   382  	defer hookPostCloseConn.Unlock()
   383  	hookPostCloseConn.fn = fn
   384  }
   385  
   386  var testStrictClose *testing.T
   387  
   388  // setStrictFakeConnClose sets the t to Errorf on when fakeConn.Close
   389  // fails to close. If nil, the check is disabled.
   390  func setStrictFakeConnClose(t *testing.T) {
   391  	testStrictClose = t
   392  }
   393  
   394  func (c *fakeConn) ResetSession(ctx context.Context) error {
   395  	c.dirtySession = false
   396  	if c.isBad() {
   397  		return driver.ErrBadConn
   398  	}
   399  	return nil
   400  }
   401  
   402  func (c *fakeConn) Close() (err error) {
   403  	drv := fdriver.(*fakeDriver)
   404  	defer func() {
   405  		if err != nil && testStrictClose != nil {
   406  			testStrictClose.Errorf("failed to close a test fakeConn: %v", err)
   407  		}
   408  		hookPostCloseConn.Lock()
   409  		fn := hookPostCloseConn.fn
   410  		hookPostCloseConn.Unlock()
   411  		if fn != nil {
   412  			fn(c, err)
   413  		}
   414  		if err == nil {
   415  			drv.mu.Lock()
   416  			drv.closeCount++
   417  			drv.mu.Unlock()
   418  		}
   419  	}()
   420  	c.touchMem()
   421  	if c.currTx != nil {
   422  		return errors.New("can't close fakeConn; in a Transaction")
   423  	}
   424  	if c.db == nil {
   425  		return errors.New("can't close fakeConn; already closed")
   426  	}
   427  	if c.stmtsMade > c.stmtsClosed {
   428  		return errors.New("can't close; dangling statement(s)")
   429  	}
   430  	c.db = nil
   431  	return nil
   432  }
   433  
   434  func checkSubsetTypes(allowAny bool, args []driver.NamedValue) error {
   435  	for _, arg := range args {
   436  		switch arg.Value.(type) {
   437  		case int64, float64, bool, nil, []byte, string, time.Time:
   438  		default:
   439  			if !allowAny {
   440  				return fmt.Errorf("fakedb_test: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value)
   441  			}
   442  		}
   443  	}
   444  	return nil
   445  }
   446  
   447  func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) {
   448  	// Ensure that ExecContext is called if available.
   449  	panic("ExecContext was not called.")
   450  }
   451  
   452  func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
   453  	// This is an optional interface, but it's implemented here
   454  	// just to check that all the args are of the proper types.
   455  	// ErrSkip is returned so the caller acts as if we didn't
   456  	// implement this at all.
   457  	err := checkSubsetTypes(c.db.allowAny, args)
   458  	if err != nil {
   459  		return nil, err
   460  	}
   461  	return nil, driver.ErrSkip
   462  }
   463  
   464  func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) {
   465  	// Ensure that ExecContext is called if available.
   466  	panic("QueryContext was not called.")
   467  }
   468  
   469  func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
   470  	// This is an optional interface, but it's implemented here
   471  	// just to check that all the args are of the proper types.
   472  	// ErrSkip is returned so the caller acts as if we didn't
   473  	// implement this at all.
   474  	err := checkSubsetTypes(c.db.allowAny, args)
   475  	if err != nil {
   476  		return nil, err
   477  	}
   478  	return nil, driver.ErrSkip
   479  }
   480  
   481  func errf(msg string, args ...interface{}) error {
   482  	return errors.New("fakedb: " + fmt.Sprintf(msg, args...))
   483  }
   484  
   485  // parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=?
   486  // (note that where columns must always contain ? marks,
   487  //  just a limitation for fakedb)
   488  func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
   489  	if len(parts) != 3 {
   490  		stmt.Close()
   491  		return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
   492  	}
   493  	stmt.table = parts[0]
   494  
   495  	stmt.colName = strings.Split(parts[1], ",")
   496  	for n, colspec := range strings.Split(parts[2], ",") {
   497  		if colspec == "" {
   498  			continue
   499  		}
   500  		nameVal := strings.Split(colspec, "=")
   501  		if len(nameVal) != 2 {
   502  			stmt.Close()
   503  			return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
   504  		}
   505  		column, value := nameVal[0], nameVal[1]
   506  		_, ok := c.db.columnType(stmt.table, column)
   507  		if !ok {
   508  			stmt.Close()
   509  			return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column)
   510  		}
   511  		if !strings.HasPrefix(value, "?") {
   512  			stmt.Close()
   513  			return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark",
   514  				stmt.table, column)
   515  		}
   516  		stmt.placeholders++
   517  		stmt.whereCol = append(stmt.whereCol, boundCol{Column: column, Placeholder: value, Ordinal: stmt.placeholders})
   518  	}
   519  	return stmt, nil
   520  }
   521  
   522  // parts are table|col=type,col2=type2
   523  func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
   524  	if len(parts) != 2 {
   525  		stmt.Close()
   526  		return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts))
   527  	}
   528  	stmt.table = parts[0]
   529  	for n, colspec := range strings.Split(parts[1], ",") {
   530  		nameType := strings.Split(colspec, "=")
   531  		if len(nameType) != 2 {
   532  			stmt.Close()
   533  			return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
   534  		}
   535  		stmt.colName = append(stmt.colName, nameType[0])
   536  		stmt.colType = append(stmt.colType, nameType[1])
   537  	}
   538  	return stmt, nil
   539  }
   540  
   541  // parts are table|col=?,col2=val
   542  func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
   543  	if len(parts) != 2 {
   544  		stmt.Close()
   545  		return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
   546  	}
   547  	stmt.table = parts[0]
   548  	for n, colspec := range strings.Split(parts[1], ",") {
   549  		nameVal := strings.Split(colspec, "=")
   550  		if len(nameVal) != 2 {
   551  			stmt.Close()
   552  			return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
   553  		}
   554  		column, value := nameVal[0], nameVal[1]
   555  		ctype, ok := c.db.columnType(stmt.table, column)
   556  		if !ok {
   557  			stmt.Close()
   558  			return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column)
   559  		}
   560  		stmt.colName = append(stmt.colName, column)
   561  
   562  		if !strings.HasPrefix(value, "?") {
   563  			var subsetVal interface{}
   564  			// Convert to driver subset type
   565  			switch ctype {
   566  			case "string":
   567  				subsetVal = []byte(value)
   568  			case "blob":
   569  				subsetVal = []byte(value)
   570  			case "int32":
   571  				i, err := strconv.Atoi(value)
   572  				if err != nil {
   573  					stmt.Close()
   574  					return nil, errf("invalid conversion to int32 from %q", value)
   575  				}
   576  				subsetVal = int64(i) // int64 is a subset type, but not int32
   577  			default:
   578  				stmt.Close()
   579  				return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype)
   580  			}
   581  			stmt.colValue = append(stmt.colValue, subsetVal)
   582  		} else {
   583  			stmt.placeholders++
   584  			stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype))
   585  			stmt.colValue = append(stmt.colValue, value)
   586  		}
   587  	}
   588  	return stmt, nil
   589  }
   590  
   591  // hook to simulate broken connections
   592  var hookPrepareBadConn func() bool
   593  
   594  func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
   595  	panic("use PrepareContext")
   596  }
   597  
   598  func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
   599  	c.numPrepare++
   600  	if c.db == nil {
   601  		panic("nil c.db; conn = " + fmt.Sprintf("%#v", c))
   602  	}
   603  
   604  	if c.stickyBad || (hookPrepareBadConn != nil && hookPrepareBadConn()) {
   605  		return nil, driver.ErrBadConn
   606  	}
   607  
   608  	c.touchMem()
   609  	var firstStmt, prev *fakeStmt
   610  	for _, query := range strings.Split(query, ";") {
   611  		parts := strings.Split(query, "|")
   612  		if len(parts) < 1 {
   613  			return nil, errf("empty query")
   614  		}
   615  		stmt := &fakeStmt{q: query, c: c, memToucher: c}
   616  		if firstStmt == nil {
   617  			firstStmt = stmt
   618  		}
   619  		if len(parts) >= 3 {
   620  			switch parts[0] {
   621  			case "PANIC":
   622  				stmt.panic = parts[1]
   623  				parts = parts[2:]
   624  			case "WAIT":
   625  				wait, err := time.ParseDuration(parts[1])
   626  				if err != nil {
   627  					return nil, errf("expected section after WAIT to be a duration, got %q %v", parts[1], err)
   628  				}
   629  				parts = parts[2:]
   630  				stmt.wait = wait
   631  			}
   632  		}
   633  		cmd := parts[0]
   634  		stmt.cmd = cmd
   635  		parts = parts[1:]
   636  
   637  		if c.waiter != nil {
   638  			c.waiter(ctx)
   639  		}
   640  
   641  		if stmt.wait > 0 {
   642  			wait := time.NewTimer(stmt.wait)
   643  			select {
   644  			case <-wait.C:
   645  			case <-ctx.Done():
   646  				wait.Stop()
   647  				return nil, ctx.Err()
   648  			}
   649  		}
   650  
   651  		c.incrStat(&c.stmtsMade)
   652  		var err error
   653  		switch cmd {
   654  		case "WIPE":
   655  			// Nothing
   656  		case "SELECT":
   657  			stmt, err = c.prepareSelect(stmt, parts)
   658  		case "CREATE":
   659  			stmt, err = c.prepareCreate(stmt, parts)
   660  		case "INSERT":
   661  			stmt, err = c.prepareInsert(stmt, parts)
   662  		case "NOSERT":
   663  			// Do all the prep-work like for an INSERT but don't actually insert the row.
   664  			// Used for some of the concurrent tests.
   665  			stmt, err = c.prepareInsert(stmt, parts)
   666  		default:
   667  			stmt.Close()
   668  			return nil, errf("unsupported command type %q", cmd)
   669  		}
   670  		if err != nil {
   671  			return nil, err
   672  		}
   673  		if prev != nil {
   674  			prev.next = stmt
   675  		}
   676  		prev = stmt
   677  	}
   678  	return firstStmt, nil
   679  }
   680  
   681  func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
   682  	if s.panic == "ColumnConverter" {
   683  		panic(s.panic)
   684  	}
   685  	if len(s.placeholderConverter) == 0 {
   686  		return driver.DefaultParameterConverter
   687  	}
   688  	return s.placeholderConverter[idx]
   689  }
   690  
   691  func (s *fakeStmt) Close() error {
   692  	if s.panic == "Close" {
   693  		panic(s.panic)
   694  	}
   695  	if s.c == nil {
   696  		panic("nil conn in fakeStmt.Close")
   697  	}
   698  	if s.c.db == nil {
   699  		panic("in fakeStmt.Close, conn's db is nil (already closed)")
   700  	}
   701  	s.touchMem()
   702  	if !s.closed {
   703  		s.c.incrStat(&s.c.stmtsClosed)
   704  		s.closed = true
   705  	}
   706  	if s.next != nil {
   707  		s.next.Close()
   708  	}
   709  	return nil
   710  }
   711  
   712  var errClosed = errors.New("fakedb: statement has been closed")
   713  
   714  // hook to simulate broken connections
   715  var hookExecBadConn func() bool
   716  
   717  func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) {
   718  	panic("Using ExecContext")
   719  }
   720  func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
   721  	if s.panic == "Exec" {
   722  		panic(s.panic)
   723  	}
   724  	if s.closed {
   725  		return nil, errClosed
   726  	}
   727  
   728  	if s.c.stickyBad || (hookExecBadConn != nil && hookExecBadConn()) {
   729  		return nil, driver.ErrBadConn
   730  	}
   731  	if s.c.isDirtyAndMark() {
   732  		return nil, errors.New("session is dirty")
   733  	}
   734  
   735  	err := checkSubsetTypes(s.c.db.allowAny, args)
   736  	if err != nil {
   737  		return nil, err
   738  	}
   739  	s.touchMem()
   740  
   741  	if s.wait > 0 {
   742  		time.Sleep(s.wait)
   743  	}
   744  
   745  	select {
   746  	default:
   747  	case <-ctx.Done():
   748  		return nil, ctx.Err()
   749  	}
   750  
   751  	db := s.c.db
   752  	switch s.cmd {
   753  	case "WIPE":
   754  		db.wipe()
   755  		return driver.ResultNoRows, nil
   756  	case "CREATE":
   757  		if err := db.createTable(s.table, s.colName, s.colType); err != nil {
   758  			return nil, err
   759  		}
   760  		return driver.ResultNoRows, nil
   761  	case "INSERT":
   762  		return s.execInsert(args, true)
   763  	case "NOSERT":
   764  		// Do all the prep-work like for an INSERT but don't actually insert the row.
   765  		// Used for some of the concurrent tests.
   766  		return s.execInsert(args, false)
   767  	}
   768  	fmt.Printf("EXEC statement, cmd=%q: %#v\n", s.cmd, s)
   769  	return nil, fmt.Errorf("unimplemented statement Exec command type of %q", s.cmd)
   770  }
   771  
   772  // When doInsert is true, add the row to the table.
   773  // When doInsert is false do prep-work and error checking, but don't
   774  // actually add the row to the table.
   775  func (s *fakeStmt) execInsert(args []driver.NamedValue, doInsert bool) (driver.Result, error) {
   776  	db := s.c.db
   777  	if len(args) != s.placeholders {
   778  		panic("error in pkg db; should only get here if size is correct")
   779  	}
   780  	db.mu.Lock()
   781  	t, ok := db.table(s.table)
   782  	db.mu.Unlock()
   783  	if !ok {
   784  		return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
   785  	}
   786  
   787  	t.mu.Lock()
   788  	defer t.mu.Unlock()
   789  
   790  	var cols []interface{}
   791  	if doInsert {
   792  		cols = make([]interface{}, len(t.colname))
   793  	}
   794  	argPos := 0
   795  	for n, colname := range s.colName {
   796  		colidx := t.columnIndex(colname)
   797  		if colidx == -1 {
   798  			return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname)
   799  		}
   800  		var val interface{}
   801  		if strvalue, ok := s.colValue[n].(string); ok && strings.HasPrefix(strvalue, "?") {
   802  			if strvalue == "?" {
   803  				val = args[argPos].Value
   804  			} else {
   805  				// Assign value from argument placeholder name.
   806  				for _, a := range args {
   807  					if a.Name == strvalue[1:] {
   808  						val = a.Value
   809  						break
   810  					}
   811  				}
   812  			}
   813  			argPos++
   814  		} else {
   815  			val = s.colValue[n]
   816  		}
   817  		if doInsert {
   818  			cols[colidx] = val
   819  		}
   820  	}
   821  
   822  	if doInsert {
   823  		t.rows = append(t.rows, &row{cols: cols})
   824  	}
   825  	return driver.RowsAffected(1), nil
   826  }
   827  
   828  // hook to simulate broken connections
   829  var hookQueryBadConn func() bool
   830  
   831  func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
   832  	panic("Use QueryContext")
   833  }
   834  
   835  func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
   836  	if s.panic == "Query" {
   837  		panic(s.panic)
   838  	}
   839  	if s.closed {
   840  		return nil, errClosed
   841  	}
   842  
   843  	if s.c.stickyBad || (hookQueryBadConn != nil && hookQueryBadConn()) {
   844  		return nil, driver.ErrBadConn
   845  	}
   846  	if s.c.isDirtyAndMark() {
   847  		return nil, errors.New("session is dirty")
   848  	}
   849  
   850  	err := checkSubsetTypes(s.c.db.allowAny, args)
   851  	if err != nil {
   852  		return nil, err
   853  	}
   854  
   855  	s.touchMem()
   856  	db := s.c.db
   857  	if len(args) != s.placeholders {
   858  		panic("error in pkg db; should only get here if size is correct")
   859  	}
   860  
   861  	setMRows := make([][]*row, 0, 1)
   862  	setColumns := make([][]string, 0, 1)
   863  	setColType := make([][]string, 0, 1)
   864  
   865  	for {
   866  		db.mu.Lock()
   867  		t, ok := db.table(s.table)
   868  		db.mu.Unlock()
   869  		if !ok {
   870  			return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
   871  		}
   872  
   873  		if s.table == "magicquery" {
   874  			if len(s.whereCol) == 2 && s.whereCol[0].Column == "op" && s.whereCol[1].Column == "millis" {
   875  				if args[0].Value == "sleep" {
   876  					time.Sleep(time.Duration(args[1].Value.(int64)) * time.Millisecond)
   877  				}
   878  			}
   879  		}
   880  
   881  		t.mu.Lock()
   882  
   883  		colIdx := make(map[string]int) // select column name -> column index in table
   884  		for _, name := range s.colName {
   885  			idx := t.columnIndex(name)
   886  			if idx == -1 {
   887  				t.mu.Unlock()
   888  				return nil, fmt.Errorf("fakedb: unknown column name %q", name)
   889  			}
   890  			colIdx[name] = idx
   891  		}
   892  
   893  		mrows := []*row{}
   894  	rows:
   895  		for _, trow := range t.rows {
   896  			// Process the where clause, skipping non-match rows. This is lazy
   897  			// and just uses fmt.Sprintf("%v") to test equality. Good enough
   898  			// for test code.
   899  			for _, wcol := range s.whereCol {
   900  				idx := t.columnIndex(wcol.Column)
   901  				if idx == -1 {
   902  					t.mu.Unlock()
   903  					return nil, fmt.Errorf("db: invalid where clause column %q", wcol)
   904  				}
   905  				tcol := trow.cols[idx]
   906  				if bs, ok := tcol.([]byte); ok {
   907  					// lazy hack to avoid sprintf %v on a []byte
   908  					tcol = string(bs)
   909  				}
   910  				var argValue interface{}
   911  				if wcol.Placeholder == "?" {
   912  					argValue = args[wcol.Ordinal-1].Value
   913  				} else {
   914  					// Assign arg value from placeholder name.
   915  					for _, a := range args {
   916  						if a.Name == wcol.Placeholder[1:] {
   917  							argValue = a.Value
   918  							break
   919  						}
   920  					}
   921  				}
   922  				if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) {
   923  					continue rows
   924  				}
   925  			}
   926  			mrow := &row{cols: make([]interface{}, len(s.colName))}
   927  			for seli, name := range s.colName {
   928  				mrow.cols[seli] = trow.cols[colIdx[name]]
   929  			}
   930  			mrows = append(mrows, mrow)
   931  		}
   932  
   933  		var colType []string
   934  		for _, column := range s.colName {
   935  			colType = append(colType, t.coltype[t.columnIndex(column)])
   936  		}
   937  
   938  		t.mu.Unlock()
   939  
   940  		setMRows = append(setMRows, mrows)
   941  		setColumns = append(setColumns, s.colName)
   942  		setColType = append(setColType, colType)
   943  
   944  		if s.next == nil {
   945  			break
   946  		}
   947  		s = s.next
   948  	}
   949  
   950  	cursor := &rowsCursor{
   951  		parentMem: s.c,
   952  		posRow:    -1,
   953  		rows:      setMRows,
   954  		cols:      setColumns,
   955  		colType:   setColType,
   956  		errPos:    -1,
   957  	}
   958  	return cursor, nil
   959  }
   960  
   961  func (s *fakeStmt) NumInput() int {
   962  	if s.panic == "NumInput" {
   963  		panic(s.panic)
   964  	}
   965  	return s.placeholders
   966  }
   967  
   968  // hook to simulate broken connections
   969  var hookCommitBadConn func() bool
   970  
   971  func (tx *fakeTx) Commit() error {
   972  	tx.c.currTx = nil
   973  	if hookCommitBadConn != nil && hookCommitBadConn() {
   974  		return driver.ErrBadConn
   975  	}
   976  	tx.c.touchMem()
   977  	return nil
   978  }
   979  
   980  // hook to simulate broken connections
   981  var hookRollbackBadConn func() bool
   982  
   983  func (tx *fakeTx) Rollback() error {
   984  	tx.c.currTx = nil
   985  	if hookRollbackBadConn != nil && hookRollbackBadConn() {
   986  		return driver.ErrBadConn
   987  	}
   988  	tx.c.touchMem()
   989  	return nil
   990  }
   991  
   992  type rowsCursor struct {
   993  	parentMem memToucher
   994  	cols      [][]string
   995  	colType   [][]string
   996  	posSet    int
   997  	posRow    int
   998  	rows      [][]*row
   999  	closed    bool
  1000  
  1001  	// errPos and err are for making Next return early with error.
  1002  	errPos int
  1003  	err    error
  1004  
  1005  	// a clone of slices to give out to clients, indexed by the
  1006  	// the original slice's first byte address.  we clone them
  1007  	// just so we're able to corrupt them on close.
  1008  	bytesClone map[*byte][]byte
  1009  
  1010  	// Every operation writes to line to enable the race detector
  1011  	// check for data races.
  1012  	// This is separate from the fakeConn.line to allow for drivers that
  1013  	// can start multiple queries on the same transaction at the same time.
  1014  	line int64
  1015  }
  1016  
  1017  func (rc *rowsCursor) touchMem() {
  1018  	rc.parentMem.touchMem()
  1019  	rc.line++
  1020  }
  1021  
  1022  func (rc *rowsCursor) Close() error {
  1023  	rc.touchMem()
  1024  	rc.parentMem.touchMem()
  1025  	rc.closed = true
  1026  	return nil
  1027  }
  1028  
  1029  func (rc *rowsCursor) Columns() []string {
  1030  	return rc.cols[rc.posSet]
  1031  }
  1032  
  1033  func (rc *rowsCursor) ColumnTypeScanType(index int) reflect.Type {
  1034  	return colTypeToReflectType(rc.colType[rc.posSet][index])
  1035  }
  1036  
  1037  var rowsCursorNextHook func(dest []driver.Value) error
  1038  
  1039  func (rc *rowsCursor) Next(dest []driver.Value) error {
  1040  	if rowsCursorNextHook != nil {
  1041  		return rowsCursorNextHook(dest)
  1042  	}
  1043  
  1044  	if rc.closed {
  1045  		return errors.New("fakedb: cursor is closed")
  1046  	}
  1047  	rc.touchMem()
  1048  	rc.posRow++
  1049  	if rc.posRow == rc.errPos {
  1050  		return rc.err
  1051  	}
  1052  	if rc.posRow >= len(rc.rows[rc.posSet]) {
  1053  		return io.EOF // per interface spec
  1054  	}
  1055  	for i, v := range rc.rows[rc.posSet][rc.posRow].cols {
  1056  		// TODO(bradfitz): convert to subset types? naah, I
  1057  		// think the subset types should only be input to
  1058  		// driver, but the sql package should be able to handle
  1059  		// a wider range of types coming out of drivers. all
  1060  		// for ease of drivers, and to prevent drivers from
  1061  		// messing up conversions or doing them differently.
  1062  		dest[i] = v
  1063  
  1064  		if bs, ok := v.([]byte); ok {
  1065  			if rc.bytesClone == nil {
  1066  				rc.bytesClone = make(map[*byte][]byte)
  1067  			}
  1068  			clone, ok := rc.bytesClone[&bs[0]]
  1069  			if !ok {
  1070  				clone = make([]byte, len(bs))
  1071  				copy(clone, bs)
  1072  				rc.bytesClone[&bs[0]] = clone
  1073  			}
  1074  			dest[i] = clone
  1075  		}
  1076  	}
  1077  	return nil
  1078  }
  1079  
  1080  func (rc *rowsCursor) HasNextResultSet() bool {
  1081  	rc.touchMem()
  1082  	return rc.posSet < len(rc.rows)-1
  1083  }
  1084  
  1085  func (rc *rowsCursor) NextResultSet() error {
  1086  	rc.touchMem()
  1087  	if rc.HasNextResultSet() {
  1088  		rc.posSet++
  1089  		rc.posRow = -1
  1090  		return nil
  1091  	}
  1092  	return io.EOF // Per interface spec.
  1093  }
  1094  
  1095  // fakeDriverString is like driver.String, but indirects pointers like
  1096  // DefaultValueConverter.
  1097  //
  1098  // This could be surprising behavior to retroactively apply to
  1099  // driver.String now that Go1 is out, but this is convenient for
  1100  // our TestPointerParamsAndScans.
  1101  //
  1102  type fakeDriverString struct{}
  1103  
  1104  func (fakeDriverString) ConvertValue(v interface{}) (driver.Value, error) {
  1105  	switch c := v.(type) {
  1106  	case string, []byte:
  1107  		return v, nil
  1108  	case *string:
  1109  		if c == nil {
  1110  			return nil, nil
  1111  		}
  1112  		return *c, nil
  1113  	}
  1114  	return fmt.Sprintf("%v", v), nil
  1115  }
  1116  
  1117  type anyTypeConverter struct{}
  1118  
  1119  func (anyTypeConverter) ConvertValue(v interface{}) (driver.Value, error) {
  1120  	return v, nil
  1121  }
  1122  
  1123  func converterForType(typ string) driver.ValueConverter {
  1124  	switch typ {
  1125  	case "bool":
  1126  		return driver.Bool
  1127  	case "nullbool":
  1128  		return driver.Null{Converter: driver.Bool}
  1129  	case "int32":
  1130  		return driver.Int32
  1131  	case "string":
  1132  		return driver.NotNull{Converter: fakeDriverString{}}
  1133  	case "nullstring":
  1134  		return driver.Null{Converter: fakeDriverString{}}
  1135  	case "int64":
  1136  		// TODO(coopernurse): add type-specific converter
  1137  		return driver.NotNull{Converter: driver.DefaultParameterConverter}
  1138  	case "nullint64":
  1139  		// TODO(coopernurse): add type-specific converter
  1140  		return driver.Null{Converter: driver.DefaultParameterConverter}
  1141  	case "float64":
  1142  		// TODO(coopernurse): add type-specific converter
  1143  		return driver.NotNull{Converter: driver.DefaultParameterConverter}
  1144  	case "nullfloat64":
  1145  		// TODO(coopernurse): add type-specific converter
  1146  		return driver.Null{Converter: driver.DefaultParameterConverter}
  1147  	case "datetime":
  1148  		return driver.DefaultParameterConverter
  1149  	case "any":
  1150  		return anyTypeConverter{}
  1151  	}
  1152  	panic("invalid fakedb column type of " + typ)
  1153  }
  1154  
  1155  func colTypeToReflectType(typ string) reflect.Type {
  1156  	switch typ {
  1157  	case "bool":
  1158  		return reflect.TypeOf(false)
  1159  	case "nullbool":
  1160  		return reflect.TypeOf(NullBool{})
  1161  	case "int32":
  1162  		return reflect.TypeOf(int32(0))
  1163  	case "string":
  1164  		return reflect.TypeOf("")
  1165  	case "nullstring":
  1166  		return reflect.TypeOf(NullString{})
  1167  	case "int64":
  1168  		return reflect.TypeOf(int64(0))
  1169  	case "nullint64":
  1170  		return reflect.TypeOf(NullInt64{})
  1171  	case "float64":
  1172  		return reflect.TypeOf(float64(0))
  1173  	case "nullfloat64":
  1174  		return reflect.TypeOf(NullFloat64{})
  1175  	case "datetime":
  1176  		return reflect.TypeOf(time.Time{})
  1177  	case "any":
  1178  		return reflect.TypeOf(new(interface{})).Elem()
  1179  	}
  1180  	panic("invalid fakedb column type of " + typ)
  1181  }