github.com/runner-mei/ql@v1.1.0/driver.go (about)

     1  // Copyright 2014 The ql Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // database/sql/driver
     6  
     7  package ql
     8  
     9  import (
    10  	"bytes"
    11  	"database/sql"
    12  	"database/sql/driver"
    13  	"errors"
    14  	"fmt"
    15  	"io"
    16  	"math/big"
    17  	"net/url"
    18  	"os"
    19  	"path/filepath"
    20  	"strconv"
    21  	"strings"
    22  	"sync"
    23  	"time"
    24  )
    25  
    26  var (
    27  	_ driver.Conn    = (*driverConn)(nil)
    28  	_ driver.Driver  = (*sqlDriver)(nil)
    29  	_ driver.Execer  = (*driverConn)(nil)
    30  	_ driver.Queryer = (*driverConn)(nil)
    31  	_ driver.Result  = (*driverResult)(nil)
    32  	_ driver.Rows    = (*driverRows)(nil)
    33  	_ driver.Stmt    = (*driverStmt)(nil)
    34  	_ driver.Tx      = (*driverConn)(nil)
    35  
    36  	txBegin    = MustCompile("BEGIN TRANSACTION;")
    37  	txCommit   = MustCompile("COMMIT;")
    38  	txRollback = MustCompile("ROLLBACK;")
    39  
    40  	errNoResult = errors.New("query statement does not produce a result set (no top level SELECT)")
    41  )
    42  
    43  type errList []error
    44  
    45  func (e *errList) append(err error) {
    46  	if err != nil {
    47  		*e = append(*e, err)
    48  	}
    49  }
    50  
    51  func (e errList) error() error {
    52  	if len(e) == 0 {
    53  		return nil
    54  	}
    55  
    56  	return e
    57  }
    58  
    59  func (e errList) Error() string {
    60  	a := make([]string, len(e))
    61  	for i, v := range e {
    62  		a[i] = v.Error()
    63  	}
    64  	return strings.Join(a, "\n")
    65  }
    66  
    67  func params(args []driver.Value) []interface{} {
    68  	r := make([]interface{}, len(args))
    69  	for i, v := range args {
    70  		r[i] = interface{}(v)
    71  	}
    72  	return r
    73  }
    74  
    75  var (
    76  	fileDriver     = &sqlDriver{dbs: map[string]*driverDB{}}
    77  	fileDriverOnce sync.Once
    78  	memDriver      = &sqlDriver{isMem: true, dbs: map[string]*driverDB{}}
    79  	memDriverOnce  sync.Once
    80  )
    81  
    82  // RegisterDriver registers a QL database/sql/driver[0] named "ql". The name
    83  // parameter of
    84  //
    85  //	sql.Open("ql", name)
    86  //
    87  // is interpreted as a path name to a named DB file which will be created if
    88  // not present. The underlying QL database data are persisted on db.Close().
    89  // RegisterDriver can be safely called multiple times, it'll register the
    90  // driver only once.
    91  //
    92  // The name argument can be optionally prefixed by "file://". In that case the
    93  // prefix is stripped before interpreting it as a file name.
    94  //
    95  // The name argument can be optionally prefixed by "memory://". In that case
    96  // the prefix is stripped before interpreting it as a name of a memory-only,
    97  // volatile DB.
    98  //
    99  //  [0]: http://golang.org/pkg/database/sql/driver/
   100  func RegisterDriver() {
   101  	fileDriverOnce.Do(func() { sql.Register("ql", fileDriver) })
   102  }
   103  
   104  // RegisterMemDriver registers a QL memory database/sql/driver[0] named
   105  // "ql-mem".  The name parameter of
   106  //
   107  //	sql.Open("ql-mem", name)
   108  //
   109  // is interpreted as an unique memory DB name which will be created if not
   110  // present. The underlying QL memory database data are not persisted on
   111  // db.Close(). RegisterMemDriver can be safely called multiple times, it'll
   112  // register the driver only once.
   113  //
   114  //  [0]: http://golang.org/pkg/database/sql/driver/
   115  func RegisterMemDriver() {
   116  	memDriverOnce.Do(func() { sql.Register("ql-mem", memDriver) })
   117  }
   118  
   119  type driverDB struct {
   120  	db       *DB
   121  	name     string
   122  	refcount int
   123  }
   124  
   125  func newDriverDB(db *DB, name string) *driverDB {
   126  	return &driverDB{db: db, name: name, refcount: 1}
   127  }
   128  
   129  // sqlDriver implements the interface required by database/sql/driver.
   130  type sqlDriver struct {
   131  	dbs   map[string]*driverDB
   132  	isMem bool
   133  	mu    sync.Mutex
   134  }
   135  
   136  func (d *sqlDriver) lock() func() {
   137  	d.mu.Lock()
   138  	return d.mu.Unlock
   139  }
   140  
   141  // Open returns a new connection to the database.  The name is a string in a
   142  // driver-specific format.
   143  //
   144  // Open may return a cached connection (one previously closed), but doing so is
   145  // unnecessary; the sql package maintains a pool of idle connections for
   146  // efficient re-use.
   147  //
   148  // The returned connection is only used by one goroutine at a time.
   149  //
   150  // The name supported URL parameters:
   151  //
   152  //	headroom	Size of the WAL headroom. See https://github.com/cznic/ql/issues/140.
   153  func (d *sqlDriver) Open(name string) (driver.Conn, error) {
   154  	switch {
   155  	case d == fileDriver:
   156  		if !strings.Contains(name, "://") && !strings.HasPrefix(name, "file") {
   157  			name = "file://" + name
   158  		}
   159  	case d == memDriver:
   160  		if !strings.Contains(name, "://") && !strings.HasPrefix(name, "memory") {
   161  			name = "memory://" + name
   162  		}
   163  	default:
   164  		return nil, fmt.Errorf("open: unexpected/unsupported instance of driver.Driver: %p", d)
   165  	}
   166  
   167  	name = filepath.ToSlash(name) // Ensure / separated URLs on Windows
   168  	uri, err := url.Parse(name)
   169  	if err != nil {
   170  		return nil, err
   171  	}
   172  
   173  	switch uri.Scheme {
   174  	case "file":
   175  		// ok
   176  	case "memory":
   177  		d = memDriver
   178  	default:
   179  		return nil, fmt.Errorf("open: unexpected/unsupported scheme: %s", uri.Scheme)
   180  	}
   181  
   182  	name = filepath.Clean(filepath.Join(uri.Host, uri.Path))
   183  	if d == fileDriver && (name == "" || name == "." || name == string(os.PathSeparator)) {
   184  		return nil, fmt.Errorf("invalid DB name %q", name)
   185  	}
   186  
   187  	var headroom int64
   188  	if a := uri.Query()["headroom"]; len(a) != 0 {
   189  		if headroom, err = strconv.ParseInt(a[0], 10, 64); err != nil {
   190  			return nil, err
   191  		}
   192  	}
   193  
   194  	defer d.lock()()
   195  	db := d.dbs[name]
   196  	if db == nil {
   197  		var err error
   198  		var db0 *DB
   199  		switch d.isMem {
   200  		case true:
   201  			db0, err = OpenMem()
   202  		default:
   203  			db0, err = OpenFile(name, &Options{CanCreate: true, Headroom: headroom})
   204  		}
   205  		if err != nil {
   206  			return nil, err
   207  		}
   208  
   209  		db = newDriverDB(db0, name)
   210  		d.dbs[name] = db
   211  		return newDriverConn(d, db), nil
   212  	}
   213  
   214  	db.refcount++
   215  	return newDriverConn(d, db), nil
   216  }
   217  
   218  // driverConn is a connection to a database. It is not used concurrently by
   219  // multiple goroutines.
   220  //
   221  // Conn is assumed to be stateful.
   222  type driverConn struct {
   223  	ctx    *TCtx
   224  	db     *driverDB
   225  	driver *sqlDriver
   226  	stop   map[*driverStmt]struct{}
   227  	tnl    int
   228  }
   229  
   230  func newDriverConn(d *sqlDriver, ddb *driverDB) driver.Conn {
   231  	r := &driverConn{
   232  		db:     ddb,
   233  		driver: d,
   234  		stop:   map[*driverStmt]struct{}{},
   235  	}
   236  	return r
   237  }
   238  
   239  // Prepare returns a prepared statement, bound to this connection.
   240  func (c *driverConn) Prepare(query string) (driver.Stmt, error) {
   241  	list, err := Compile(query)
   242  	if err != nil {
   243  		return nil, err
   244  	}
   245  
   246  	s := &driverStmt{conn: c, stmt: list}
   247  	c.stop[s] = struct{}{}
   248  	return s, nil
   249  }
   250  
   251  // Close invalidates and potentially stops any current prepared statements and
   252  // transactions, marking this connection as no longer in use.
   253  //
   254  // Because the sql package maintains a free pool of connections and only calls
   255  // Close when there's a surplus of idle connections, it shouldn't be necessary
   256  // for drivers to do their own connection caching.
   257  func (c *driverConn) Close() error {
   258  	var err errList
   259  	for s := range c.stop {
   260  		err.append(s.Close())
   261  	}
   262  	defer c.driver.lock()()
   263  	dbs, name := c.driver.dbs, c.db.name
   264  	v := dbs[name]
   265  	v.refcount--
   266  	if v.refcount == 0 {
   267  		err.append(c.db.db.Close())
   268  		delete(dbs, name)
   269  	}
   270  	return err.error()
   271  }
   272  
   273  // Begin starts and returns a new transaction.
   274  func (c *driverConn) Begin() (driver.Tx, error) {
   275  	if c.ctx == nil {
   276  		c.ctx = NewRWCtx()
   277  	}
   278  
   279  	if _, _, err := c.db.db.Execute(c.ctx, txBegin); err != nil {
   280  		return nil, err
   281  	}
   282  
   283  	c.tnl++
   284  	return c, nil
   285  }
   286  
   287  func (c *driverConn) Commit() error {
   288  	if c.tnl == 0 || c.ctx == nil {
   289  		return errCommitNotInTransaction
   290  	}
   291  
   292  	if _, _, err := c.db.db.Execute(c.ctx, txCommit); err != nil {
   293  		return err
   294  	}
   295  
   296  	c.tnl--
   297  	if c.tnl == 0 {
   298  		c.ctx = nil
   299  	}
   300  	return nil
   301  }
   302  
   303  func (c *driverConn) Rollback() error {
   304  	if c.tnl == 0 || c.ctx == nil {
   305  		return errRollbackNotInTransaction
   306  	}
   307  
   308  	if _, _, err := c.db.db.Execute(c.ctx, txRollback); err != nil {
   309  		return err
   310  	}
   311  
   312  	c.tnl--
   313  	if c.tnl == 0 {
   314  		c.ctx = nil
   315  	}
   316  	return nil
   317  }
   318  
   319  // Execer is an optional interface that may be implemented by a Conn.
   320  //
   321  // If a Conn does not implement Execer, the sql package's DB.Exec will first
   322  // prepare a query, execute the statement, and then close the statement.
   323  //
   324  // Exec may return driver.ErrSkip.
   325  func (c *driverConn) Exec(query string, args []driver.Value) (driver.Result, error) {
   326  	list, err := Compile(query)
   327  	if err != nil {
   328  		return nil, err
   329  	}
   330  
   331  	return driverExec(c.db, c.ctx, list, args)
   332  }
   333  
   334  func driverExec(db *driverDB, ctx *TCtx, list List, args []driver.Value) (driver.Result, error) {
   335  	if _, _, err := db.db.Execute(ctx, list, params(args)...); err != nil {
   336  		return nil, err
   337  	}
   338  
   339  	if len(list.l) == 1 {
   340  		switch list.l[0].(type) {
   341  		case *createTableStmt, *dropTableStmt, *alterTableAddStmt,
   342  			*alterTableDropColumnStmt, *truncateTableStmt:
   343  			return driver.ResultNoRows, nil
   344  		}
   345  	}
   346  
   347  	r := &driverResult{}
   348  	if ctx != nil {
   349  		r.lastInsertID, r.rowsAffected = ctx.LastInsertID, ctx.RowsAffected
   350  	}
   351  	return r, nil
   352  }
   353  
   354  // Queryer is an optional interface that may be implemented by a Conn.
   355  //
   356  // If a Conn does not implement Queryer, the sql package's DB.Query will first
   357  // prepare a query, execute the statement, and then close the statement.
   358  //
   359  // Query may return driver.ErrSkip.
   360  func (c *driverConn) Query(query string, args []driver.Value) (driver.Rows, error) {
   361  	list, err := Compile(query)
   362  	if err != nil {
   363  		return nil, err
   364  	}
   365  
   366  	return driverQuery(c.db, c.ctx, list, args)
   367  }
   368  
   369  func driverQuery(db *driverDB, ctx *TCtx, list List, args []driver.Value) (driver.Rows, error) {
   370  	rss, _, err := db.db.Execute(ctx, list, params(args)...)
   371  	if err != nil {
   372  		return nil, err
   373  	}
   374  
   375  	switch n := len(rss); n {
   376  	case 0:
   377  		return nil, errNoResult
   378  	case 1:
   379  		return newdriverRows(rss[len(rss)-1]), nil
   380  	default:
   381  		return nil, fmt.Errorf("query produced %d result sets, expected only one", n)
   382  	}
   383  }
   384  
   385  // driverResult is the result of a query execution.
   386  type driverResult struct {
   387  	lastInsertID int64
   388  	rowsAffected int64
   389  }
   390  
   391  // LastInsertId returns the database's auto-generated ID after, for example, an
   392  // INSERT into a table with primary key.
   393  func (r *driverResult) LastInsertId() (int64, error) { // -golint
   394  	return r.lastInsertID, nil
   395  }
   396  
   397  // RowsAffected returns the number of rows affected by the query.
   398  func (r *driverResult) RowsAffected() (int64, error) {
   399  	return r.rowsAffected, nil
   400  }
   401  
   402  // driverRows is an iterator over an executed query's results.
   403  type driverRows struct {
   404  	rs   Recordset
   405  	done chan int
   406  	rows chan interface{}
   407  }
   408  
   409  func newdriverRows(rs Recordset) *driverRows {
   410  	r := &driverRows{
   411  		rs:   rs,
   412  		done: make(chan int),
   413  		rows: make(chan interface{}, 500),
   414  	}
   415  	go func() {
   416  		err := io.EOF
   417  		if e := r.rs.Do(false, func(data []interface{}) (bool, error) {
   418  			select {
   419  			case r.rows <- data:
   420  				return true, nil
   421  			case <-r.done:
   422  				return false, nil
   423  			}
   424  		}); e != nil {
   425  			err = e
   426  		}
   427  
   428  		select {
   429  		case r.rows <- err:
   430  		case <-r.done:
   431  		}
   432  	}()
   433  	return r
   434  }
   435  
   436  // Columns returns the names of the columns. The number of columns of the
   437  // result is inferred from the length of the slice.  If a particular column
   438  // name isn't known, an empty string should be returned for that entry.
   439  func (r *driverRows) Columns() []string {
   440  	f, _ := r.rs.Fields()
   441  	return f
   442  }
   443  
   444  // Close closes the rows iterator.
   445  func (r *driverRows) Close() error {
   446  	close(r.done)
   447  	return nil
   448  }
   449  
   450  // Next is called to populate the next row of data into the provided slice. The
   451  // provided slice will be the same size as the Columns() are wide.
   452  //
   453  // The dest slice may be populated only with a driver Value type, but excluding
   454  // string.  All string values must be converted to []byte.
   455  //
   456  // Next should return io.EOF when there are no more rows.
   457  func (r *driverRows) Next(dest []driver.Value) error {
   458  	select {
   459  	case rx := <-r.rows:
   460  		switch x := rx.(type) {
   461  		case error:
   462  			return x
   463  		case []interface{}:
   464  			if g, e := len(x), len(dest); g != e {
   465  				return fmt.Errorf("field count mismatch: got %d, need %d", g, e)
   466  			}
   467  
   468  			for i, xi := range x {
   469  				switch v := xi.(type) {
   470  				case nil, int64, float64, bool, []byte, time.Time:
   471  					dest[i] = v
   472  				case complex64, complex128, *big.Int, *big.Rat:
   473  					var buf bytes.Buffer
   474  					fmt.Fprintf(&buf, "%v", v)
   475  					dest[i] = buf.Bytes()
   476  				case int8:
   477  					dest[i] = int64(v)
   478  				case int16:
   479  					dest[i] = int64(v)
   480  				case int32:
   481  					dest[i] = int64(v)
   482  				case int:
   483  					dest[i] = int64(v)
   484  				case uint8:
   485  					dest[i] = int64(v)
   486  				case uint16:
   487  					dest[i] = int64(v)
   488  				case uint32:
   489  					dest[i] = int64(v)
   490  				case uint64:
   491  					dest[i] = int64(v)
   492  				case uint:
   493  					dest[i] = int64(v)
   494  				case time.Duration:
   495  					dest[i] = int64(v)
   496  				case string:
   497  					dest[i] = []byte(v)
   498  				default:
   499  					return fmt.Errorf("internal error 004")
   500  				}
   501  			}
   502  			return nil
   503  		default:
   504  			return fmt.Errorf("internal error 005")
   505  		}
   506  	case <-r.done:
   507  		return io.EOF
   508  	}
   509  }
   510  
   511  // driverStmt is a prepared statement. It is bound to a driverConn and not used
   512  // by multiple goroutines concurrently.
   513  type driverStmt struct {
   514  	conn *driverConn
   515  	stmt List
   516  }
   517  
   518  // Close closes the statement.
   519  //
   520  // As of Go 1.1, a Stmt will not be closed if it's in use by any queries.
   521  func (s *driverStmt) Close() error {
   522  	delete(s.conn.stop, s)
   523  	return nil
   524  }
   525  
   526  // NumInput returns the number of placeholder parameters.
   527  //
   528  // If NumInput returns >= 0, the sql package will sanity check argument counts
   529  // from callers and return errors to the caller before the statement's Exec or
   530  // Query methods are called.
   531  //
   532  // NumInput may also return -1, if the driver doesn't know its number of
   533  // placeholders. In that case, the sql package will not sanity check Exec or
   534  // Query argument counts.
   535  func (s *driverStmt) NumInput() int {
   536  	if x := s.stmt; len(x.l) == 1 {
   537  		return x.params
   538  	}
   539  
   540  	return -1
   541  }
   542  
   543  // Exec executes a query that doesn't return rows, such as an INSERT or UPDATE.
   544  func (s *driverStmt) Exec(args []driver.Value) (driver.Result, error) {
   545  	c := s.conn
   546  	return driverExec(c.db, c.ctx, s.stmt, args)
   547  }
   548  
   549  // Exec executes a query that may return rows, such as a SELECT.
   550  func (s *driverStmt) Query(args []driver.Value) (driver.Rows, error) {
   551  	c := s.conn
   552  	return driverQuery(c.db, c.ctx, s.stmt, args)
   553  }