github.com/insionng/yougam@v0.0.0-20170714101924-2bc18d833463/libraries/pingcap/tidb/driver.go (about)

     1  // Copyright 2013 The ql Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSES/QL-LICENSE file.
     4  
     5  // Copyright 2015 PingCAP, Inc.
     6  //
     7  // Licensed under the Apache License, Version 2.0 (the "License");
     8  // you may not use this file except in compliance with the License.
     9  // You may obtain a copy of the License at
    10  //
    11  //     http://www.apache.org/licenses/LICENSE-2.0
    12  //
    13  // Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    15  // See the License for the specific language governing permissions and
    16  // limitations under the License.
    17  
    18  // database/sql/driver
    19  
    20  package tidb
    21  
    22  import (
    23  	"database/sql"
    24  	"database/sql/driver"
    25  	"io"
    26  	"net/url"
    27  	"path/filepath"
    28  	"strings"
    29  	"sync"
    30  
    31  	"github.com/insionng/yougam/libraries/juju/errors"
    32  	"github.com/insionng/yougam/libraries/pingcap/tidb/ast"
    33  	"github.com/insionng/yougam/libraries/pingcap/tidb/model"
    34  	"github.com/insionng/yougam/libraries/pingcap/tidb/sessionctx"
    35  	"github.com/insionng/yougam/libraries/pingcap/tidb/terror"
    36  	"github.com/insionng/yougam/libraries/pingcap/tidb/util/types"
    37  )
    38  
    39  const (
    40  	// DriverName is name of TiDB driver.
    41  	DriverName = "tidb"
    42  )
    43  
    44  var (
    45  	_ driver.Conn    = (*driverConn)(nil)
    46  	_ driver.Execer  = (*driverConn)(nil)
    47  	_ driver.Queryer = (*driverConn)(nil)
    48  	_ driver.Tx      = (*driverConn)(nil)
    49  
    50  	_ driver.Result = (*driverResult)(nil)
    51  	_ driver.Rows   = (*driverRows)(nil)
    52  	_ driver.Stmt   = (*driverStmt)(nil)
    53  	_ driver.Driver = (*sqlDriver)(nil)
    54  
    55  	txBeginSQL    = "BEGIN;"
    56  	txCommitSQL   = "COMMIT;"
    57  	txRollbackSQL = "ROLLBACK;"
    58  
    59  	errNoResult = errors.New("query statement does not produce a result set (no top level SELECT)")
    60  )
    61  
    62  type errList []error
    63  
    64  type driverParams struct {
    65  	storePath string
    66  	dbName    string
    67  	// when set to true `mysql.Time` isn't encoded as string but passed as `time.Time`
    68  	// this option is named for compatibility the same as in the mysql driver
    69  	// while we actually do not have additional parsing to do
    70  	parseTime bool
    71  }
    72  
    73  func (e *errList) append(err error) {
    74  	if err != nil {
    75  		*e = append(*e, err)
    76  	}
    77  }
    78  
    79  func (e errList) error() error {
    80  	if len(e) == 0 {
    81  		return nil
    82  	}
    83  
    84  	return e
    85  }
    86  
    87  func (e errList) Error() string {
    88  	a := make([]string, len(e))
    89  	for i, v := range e {
    90  		a[i] = v.Error()
    91  	}
    92  	return strings.Join(a, "\n")
    93  }
    94  
    95  func params(args []driver.Value) []interface{} {
    96  	r := make([]interface{}, len(args))
    97  	for i, v := range args {
    98  		r[i] = interface{}(v)
    99  	}
   100  	return r
   101  }
   102  
   103  var (
   104  	tidbDriver = &sqlDriver{}
   105  	driverOnce sync.Once
   106  )
   107  
   108  // RegisterDriver registers TiDB driver.
   109  // The name argument can be optionally prefixed by "engine://". In that case the
   110  // prefix is recognized as a storage engine name.
   111  //
   112  // The name argument can be optionally prefixed by "memory://". In that case
   113  // the prefix is stripped before interpreting it as a name of a memory-only,
   114  // volatile DB.
   115  //
   116  //  [0]: http://yougam/libraries/pkg/database/sql/driver/
   117  func RegisterDriver() {
   118  	driverOnce.Do(func() { sql.Register(DriverName, tidbDriver) })
   119  }
   120  
   121  // sqlDriver implements the interface required by database/sql/driver.
   122  type sqlDriver struct {
   123  	mu sync.Mutex
   124  }
   125  
   126  func (d *sqlDriver) lock() {
   127  	d.mu.Lock()
   128  }
   129  
   130  func (d *sqlDriver) unlock() {
   131  	d.mu.Unlock()
   132  }
   133  
   134  // parseDriverDSN cuts off DB name from dsn. It returns error if the dsn is not
   135  // valid.
   136  func parseDriverDSN(dsn string) (params *driverParams, err error) {
   137  	u, err := url.Parse(dsn)
   138  	if err != nil {
   139  		return nil, errors.Trace(err)
   140  	}
   141  	path := filepath.Join(u.Host, u.Path)
   142  	dbName := filepath.Clean(filepath.Base(path))
   143  	if dbName == "" || dbName == "." || dbName == string(filepath.Separator) {
   144  		return nil, errors.Errorf("invalid DB name %q", dbName)
   145  	}
   146  	// cut off dbName
   147  	path = filepath.Clean(filepath.Dir(path))
   148  	if path == "" || path == "." || path == string(filepath.Separator) {
   149  		return nil, errors.Errorf("invalid dsn %q", dsn)
   150  	}
   151  	u.Path, u.Host = path, ""
   152  	params = &driverParams{
   153  		storePath: u.String(),
   154  		dbName:    dbName,
   155  	}
   156  	// parse additional driver params
   157  	query := u.Query()
   158  	if parseTime := query.Get("parseTime"); parseTime == "true" {
   159  		params.parseTime = true
   160  	}
   161  
   162  	return params, nil
   163  }
   164  
   165  // Open returns a new connection to the database.
   166  //
   167  // The dsn must be a URL format 'engine://path/dbname?params'.
   168  // Engine is the storage name registered with RegisterStore.
   169  // Path is the storage specific format.
   170  // Params is key-value pairs split by '&', optional params are storage specific.
   171  // Examples:
   172  //    goleveldb://relative/path/test
   173  //    boltdb:///absolute/path/test
   174  //    hbase://zk1,zk2,zk3/hbasetbl/test?tso=zk
   175  //
   176  // Open may return a cached connection (one previously closed), but doing so is
   177  // unnecessary; the sql package maintains a pool of idle connections for
   178  // efficient re-use.
   179  //
   180  // The behavior of the mysql driver regarding time parsing can also be imitated
   181  // by passing ?parseTime
   182  //
   183  // The returned connection is only used by one goroutine at a time.
   184  func (d *sqlDriver) Open(dsn string) (driver.Conn, error) {
   185  	params, err := parseDriverDSN(dsn)
   186  	if err != nil {
   187  		return nil, errors.Trace(err)
   188  	}
   189  	store, err := NewStore(params.storePath)
   190  	if err != nil {
   191  		return nil, errors.Trace(err)
   192  	}
   193  
   194  	sess, err := CreateSession(store)
   195  	if err != nil {
   196  		return nil, errors.Trace(err)
   197  	}
   198  	s := sess.(*session)
   199  
   200  	d.lock()
   201  	defer d.unlock()
   202  
   203  	DBName := model.NewCIStr(params.dbName)
   204  	domain := sessionctx.GetDomain(s)
   205  	cs := &ast.CharsetOpt{
   206  		Chs: "utf8",
   207  		Col: "utf8_bin",
   208  	}
   209  	if !domain.InfoSchema().SchemaExists(DBName) {
   210  		err = domain.DDL().CreateSchema(s, DBName, cs)
   211  		if err != nil {
   212  			return nil, errors.Trace(err)
   213  		}
   214  	}
   215  	driver := &sqlDriver{}
   216  	return newDriverConn(s, driver, DBName.O, params)
   217  }
   218  
// driverConn is a connection to a database. It is not used concurrently by
// multiple goroutines.
//
// Conn is assumed to be stateful.
type driverConn struct {
	// s is the session every statement on this connection runs in.
	s      Session
	// driver is the sqlDriver whose mutex Close acquires.
	driver *sqlDriver
	// stmts caches prepared statements keyed by their query text; it is
	// filled by Prepare (directly or via getStmt) and emptied entry by entry
	// in driverStmt.Close.
	stmts  map[string]driver.Stmt
	// params are the DSN options the connection was opened with; they are
	// copied into every driverRows this connection returns.
	params *driverParams
}
   229  
   230  func newDriverConn(sess *session, d *sqlDriver, schema string, params *driverParams) (driver.Conn, error) {
   231  	r := &driverConn{
   232  		driver: d,
   233  		stmts:  map[string]driver.Stmt{},
   234  		s:      sess,
   235  		params: params,
   236  	}
   237  
   238  	_, err := r.s.Execute("use " + schema)
   239  	if err != nil {
   240  		return nil, errors.Trace(err)
   241  	}
   242  	return r, nil
   243  }
   244  
   245  // Prepare returns a prepared statement, bound to this connection.
   246  func (c *driverConn) Prepare(query string) (driver.Stmt, error) {
   247  	stmtID, paramCount, fields, err := c.s.PrepareStmt(query)
   248  	if err != nil {
   249  		return nil, err
   250  	}
   251  	s := &driverStmt{
   252  		conn:       c,
   253  		query:      query,
   254  		stmtID:     stmtID,
   255  		paramCount: paramCount,
   256  		isQuery:    fields != nil,
   257  	}
   258  	c.stmts[query] = s
   259  	return s, nil
   260  }
   261  
   262  // Close invalidates and potentially stops any current prepared statements and
   263  // transactions, marking this connection as no longer in use.
   264  //
   265  // Because the sql package maintains a free pool of connections and only calls
   266  // Close when there's a surplus of idle connections, it shouldn't be necessary
   267  // for drivers to do their own connection caching.
   268  func (c *driverConn) Close() error {
   269  	var err errList
   270  	for _, s := range c.stmts {
   271  		stmt := s.(*driverStmt)
   272  		err.append(stmt.conn.s.DropPreparedStmt(stmt.stmtID))
   273  	}
   274  
   275  	c.driver.lock()
   276  	defer c.driver.unlock()
   277  
   278  	return err.error()
   279  }
   280  
   281  // Begin starts and returns a new transaction.
   282  func (c *driverConn) Begin() (driver.Tx, error) {
   283  	if c.s == nil {
   284  		return nil, errors.Errorf("Need init first")
   285  	}
   286  
   287  	if _, err := c.s.Execute(txBeginSQL); err != nil {
   288  		return nil, errors.Trace(err)
   289  	}
   290  
   291  	return c, nil
   292  }
   293  
   294  func (c *driverConn) Commit() error {
   295  	if c.s == nil {
   296  		return terror.CommitNotInTransaction
   297  	}
   298  	_, err := c.s.Execute(txCommitSQL)
   299  
   300  	if err != nil {
   301  		return errors.Trace(err)
   302  	}
   303  
   304  	err = c.s.FinishTxn(false)
   305  	return errors.Trace(err)
   306  }
   307  
   308  func (c *driverConn) Rollback() error {
   309  	if c.s == nil {
   310  		return terror.RollbackNotInTransaction
   311  	}
   312  
   313  	if _, err := c.s.Execute(txRollbackSQL); err != nil {
   314  		return errors.Trace(err)
   315  	}
   316  
   317  	return nil
   318  }
   319  
   320  // Execer is an optional interface that may be implemented by a Conn.
   321  //
   322  // If a Conn does not implement Execer, the sql package's DB.Exec will first
   323  // prepare a query, execute the statement, and then close the statement.
   324  //
   325  // Exec may return driver.ErrSkip.
   326  func (c *driverConn) Exec(query string, args []driver.Value) (driver.Result, error) {
   327  	return c.driverExec(query, args)
   328  
   329  }
   330  
   331  func (c *driverConn) getStmt(query string) (stmt driver.Stmt, err error) {
   332  	stmt, ok := c.stmts[query]
   333  	if !ok {
   334  		stmt, err = c.Prepare(query)
   335  		if err != nil {
   336  			return nil, errors.Trace(err)
   337  		}
   338  	}
   339  	return
   340  }
   341  
   342  func (c *driverConn) driverExec(query string, args []driver.Value) (driver.Result, error) {
   343  	if len(args) == 0 {
   344  		if _, err := c.s.Execute(query); err != nil {
   345  			return nil, errors.Trace(err)
   346  		}
   347  		r := &driverResult{}
   348  		r.lastInsertID, r.rowsAffected = int64(c.s.LastInsertID()), int64(c.s.AffectedRows())
   349  		return r, nil
   350  	}
   351  	stmt, err := c.getStmt(query)
   352  	if err != nil {
   353  		return nil, errors.Trace(err)
   354  	}
   355  	return stmt.Exec(args)
   356  }
   357  
   358  // Queryer is an optional interface that may be implemented by a Conn.
   359  //
   360  // If a Conn does not implement Queryer, the sql package's DB.Query will first
   361  // prepare a query, execute the statement, and then close the statement.
   362  //
   363  // Query may return driver.ErrSkip.
   364  func (c *driverConn) Query(query string, args []driver.Value) (driver.Rows, error) {
   365  	return c.driverQuery(query, args)
   366  }
   367  
   368  func (c *driverConn) driverQuery(query string, args []driver.Value) (driver.Rows, error) {
   369  	if len(args) == 0 {
   370  		rss, err := c.s.Execute(query)
   371  		if err != nil {
   372  			return nil, errors.Trace(err)
   373  		}
   374  		if len(rss) == 0 {
   375  			return nil, errors.Trace(errNoResult)
   376  		}
   377  		return &driverRows{params: c.params, rs: rss[0]}, nil
   378  	}
   379  	stmt, err := c.getStmt(query)
   380  	if err != nil {
   381  		return nil, errors.Trace(err)
   382  	}
   383  	return stmt.Query(args)
   384  }
   385  
   386  // driverResult is the result of a query execution.
   387  type driverResult struct {
   388  	lastInsertID int64
   389  	rowsAffected int64
   390  }
   391  
   392  // LastInsertID returns the database's auto-generated ID after, for example, an
   393  // INSERT into a table with primary key.
   394  func (r *driverResult) LastInsertId() (int64, error) { // -golint
   395  	return r.lastInsertID, nil
   396  }
   397  
   398  // RowsAffected returns the number of rows affected by the query.
   399  func (r *driverResult) RowsAffected() (int64, error) {
   400  	return r.rowsAffected, nil
   401  }
   402  
// driverRows is an iterator over an executed query's results.
type driverRows struct {
	// rs is the record set being iterated. It is nil for a statement that
	// produced no result, in which case Columns is empty and Next returns
	// io.EOF immediately.
	rs     ast.RecordSet
	// params carries the connection's DSN options; Next consults parseTime.
	params *driverParams
}
   408  
   409  // Columns returns the names of the columns. The number of columns of the
   410  // result is inferred from the length of the slice.  If a particular column
   411  // name isn't known, an empty string should be returned for that entry.
   412  func (r *driverRows) Columns() []string {
   413  	if r.rs == nil {
   414  		return []string{}
   415  	}
   416  	fs, _ := r.rs.Fields()
   417  	names := make([]string, len(fs))
   418  	for i, f := range fs {
   419  		names[i] = f.ColumnAsName.O
   420  	}
   421  	return names
   422  }
   423  
   424  // Close closes the rows iterator.
   425  func (r *driverRows) Close() error {
   426  	if r.rs != nil {
   427  		return r.rs.Close()
   428  	}
   429  	return nil
   430  }
   431  
   432  // Next is called to populate the next row of data into the provided slice. The
   433  // provided slice will be the same size as the Columns() are wide.
   434  //
   435  // The dest slice may be populated only with a driver Value type, but excluding
   436  // string.  All string values must be converted to []byte.
   437  //
   438  // Next should return io.EOF when there are no more rows.
   439  func (r *driverRows) Next(dest []driver.Value) error {
   440  	if r.rs == nil {
   441  		return io.EOF
   442  	}
   443  	row, err := r.rs.Next()
   444  	if err != nil {
   445  		return errors.Trace(err)
   446  	}
   447  	if row == nil {
   448  		return io.EOF
   449  	}
   450  	if len(row.Data) != len(dest) {
   451  		return errors.Errorf("field count mismatch: got %d, need %d", len(row.Data), len(dest))
   452  	}
   453  	for i, xi := range row.Data {
   454  		switch xi.Kind() {
   455  		case types.KindNull:
   456  			dest[i] = nil
   457  		case types.KindInt64:
   458  			dest[i] = xi.GetInt64()
   459  		case types.KindUint64:
   460  			dest[i] = xi.GetUint64()
   461  		case types.KindFloat32:
   462  			dest[i] = xi.GetFloat32()
   463  		case types.KindFloat64:
   464  			dest[i] = xi.GetFloat64()
   465  		case types.KindString:
   466  			dest[i] = xi.GetString()
   467  		case types.KindBytes:
   468  			dest[i] = xi.GetBytes()
   469  		case types.KindMysqlBit:
   470  			dest[i] = xi.GetMysqlBit().ToString()
   471  		case types.KindMysqlDecimal:
   472  			dest[i] = xi.GetMysqlDecimal().String()
   473  		case types.KindMysqlDuration:
   474  			dest[i] = xi.GetMysqlDuration().String()
   475  		case types.KindMysqlEnum:
   476  			dest[i] = xi.GetMysqlEnum().String()
   477  		case types.KindMysqlHex:
   478  			dest[i] = xi.GetMysqlHex().ToString()
   479  		case types.KindMysqlSet:
   480  			dest[i] = xi.GetMysqlSet().String()
   481  		case types.KindMysqlTime:
   482  			t := xi.GetMysqlTime()
   483  			if !r.params.parseTime {
   484  				dest[i] = t.String()
   485  			} else {
   486  				dest[i] = t.Time
   487  			}
   488  		default:
   489  			return errors.Errorf("unable to handle type %T", xi.GetValue())
   490  		}
   491  	}
   492  	return nil
   493  }
   494  
// driverStmt is a prepared statement. It is bound to a driverConn and not used
// by multiple goroutines concurrently.
type driverStmt struct {
	// conn is the connection whose session prepared the statement.
	conn       *driverConn
	// query is the original SQL text; it is also the key in conn.stmts.
	query      string
	// stmtID is the session-side handle returned by PrepareStmt.
	stmtID     uint32
	// paramCount is the placeholder count reported by PrepareStmt.
	paramCount int
	// isQuery is true when PrepareStmt reported result fields.
	isQuery    bool
}
   504  
   505  // Close closes the statement.
   506  //
   507  // As of Go 1.1, a Stmt will not be closed if it's in use by any queries.
   508  func (s *driverStmt) Close() error {
   509  	s.conn.s.DropPreparedStmt(s.stmtID)
   510  	delete(s.conn.stmts, s.query)
   511  	return nil
   512  }
   513  
   514  // NumInput returns the number of placeholder parameters.
   515  //
   516  // If NumInput returns >= 0, the sql package will sanity check argument counts
   517  // from callers and return errors to the caller before the statement's Exec or
   518  // Query methods are called.
   519  //
   520  // NumInput may also return -1, if the driver doesn't know its number of
   521  // placeholders. In that case, the sql package will not sanity check Exec or
   522  // Query argument counts.
   523  func (s *driverStmt) NumInput() int {
   524  	return s.paramCount
   525  }
   526  
   527  // Exec executes a query that doesn't return rows, such as an INSERT or UPDATE.
   528  func (s *driverStmt) Exec(args []driver.Value) (driver.Result, error) {
   529  	c := s.conn
   530  	_, err := c.s.ExecutePreparedStmt(s.stmtID, params(args)...)
   531  	if err != nil {
   532  		return nil, errors.Trace(err)
   533  	}
   534  	r := &driverResult{}
   535  	if s != nil {
   536  		r.lastInsertID, r.rowsAffected = int64(c.s.LastInsertID()), int64(c.s.AffectedRows())
   537  	}
   538  	return r, nil
   539  }
   540  
   541  // Exec executes a query that may return rows, such as a SELECT.
   542  func (s *driverStmt) Query(args []driver.Value) (driver.Rows, error) {
   543  	c := s.conn
   544  	rs, err := c.s.ExecutePreparedStmt(s.stmtID, params(args)...)
   545  	if err != nil {
   546  		return nil, errors.Trace(err)
   547  	}
   548  	if rs == nil {
   549  		if s.isQuery {
   550  			return nil, errors.Trace(errNoResult)
   551  		}
   552  		// The statement is not a query.
   553  		return &driverRows{}, nil
   554  	}
   555  	return &driverRows{params: s.conn.params, rs: rs}, nil
   556  }
   557  
   558  func init() {
   559  	RegisterDriver()
   560  }