github.com/insionng/yougam@v0.0.0-20170714101924-2bc18d833463/libraries/pingcap/tidb/server/driver_tidb.go (about)

     1  // Copyright 2015 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package server
    15  
    16  import (
    17  	"github.com/insionng/yougam/libraries/juju/errors"
    18  	"github.com/insionng/yougam/libraries/pingcap/tidb"
    19  	"github.com/insionng/yougam/libraries/pingcap/tidb/ast"
    20  	"github.com/insionng/yougam/libraries/pingcap/tidb/kv"
    21  	"github.com/insionng/yougam/libraries/pingcap/tidb/mysql"
    22  	"github.com/insionng/yougam/libraries/pingcap/tidb/util/types"
    23  )
    24  
// TiDBDriver implements IDriver.
// It keeps the storage handed to NewTiDBDriver; OpenCtx creates a tidb
// session on that storage for every context it opens.
type TiDBDriver struct {
	// store is the storage passed to tidb.CreateSession by OpenCtx.
	store kv.Storage
}
    29  
    30  // NewTiDBDriver creates a new TiDBDriver.
    31  func NewTiDBDriver(store kv.Storage) *TiDBDriver {
    32  	driver := &TiDBDriver{
    33  		store: store,
    34  	}
    35  	return driver
    36  }
    37  
// TiDBContext implements IContext.
// It wraps one tidb session and tracks the prepared statements created
// through that session.
type TiDBContext struct {
	// session is the underlying tidb session every IContext method delegates to.
	session      tidb.Session
	// currentDB is the database name given to OpenCtx and reported by CurrentDB.
	// NOTE(review): nothing in this file updates it after OpenCtx — confirm
	// whether a later "use" statement is expected to change it.
	currentDB    string
	// warningCount is reported by WarningCount; no code in this file assigns it.
	warningCount uint16
	// stmts maps a prepared statement ID to its TiDBStatement. Prepare adds
	// entries and TiDBStatement.Close removes them.
	stmts        map[int]*TiDBStatement
}
    45  
// TiDBStatement implements IStatement.
// It represents one prepared statement owned by a TiDBContext.
type TiDBStatement struct {
	// id is the statement ID returned by the session's PrepareStmt.
	id          uint32
	// numParams is the parameter count returned by the session's PrepareStmt.
	numParams   int
	// boundParams holds, per parameter index, the bytes accumulated through
	// AppendParam; Reset sets every entry back to nil.
	boundParams [][]byte
	// ctx is the context that prepared this statement and owns its session.
	ctx         *TiDBContext
}
    53  
    54  // ID implements IStatement ID method.
    55  func (ts *TiDBStatement) ID() int {
    56  	return int(ts.id)
    57  }
    58  
    59  // Execute implements IStatement Execute method.
    60  func (ts *TiDBStatement) Execute(args ...interface{}) (rs ResultSet, err error) {
    61  	tidbRecordset, err := ts.ctx.session.ExecutePreparedStmt(ts.id, args...)
    62  	if err != nil {
    63  		return nil, err
    64  	}
    65  	if tidbRecordset == nil {
    66  		return
    67  	}
    68  	rs = &tidbResultSet{
    69  		recordSet: tidbRecordset,
    70  	}
    71  	return
    72  }
    73  
    74  // AppendParam implements IStatement AppendParam method.
    75  func (ts *TiDBStatement) AppendParam(paramID int, data []byte) error {
    76  	if paramID >= len(ts.boundParams) {
    77  		return mysql.NewErr(mysql.ErrWrongArguments, "stmt_send_longdata")
    78  	}
    79  	ts.boundParams[paramID] = append(ts.boundParams[paramID], data...)
    80  	return nil
    81  }
    82  
    83  // NumParams implements IStatement NumParams method.
    84  func (ts *TiDBStatement) NumParams() int {
    85  	return ts.numParams
    86  }
    87  
    88  // BoundParams implements IStatement BoundParams method.
    89  func (ts *TiDBStatement) BoundParams() [][]byte {
    90  	return ts.boundParams
    91  }
    92  
    93  // Reset implements IStatement Reset method.
    94  func (ts *TiDBStatement) Reset() {
    95  	for i := range ts.boundParams {
    96  		ts.boundParams[i] = nil
    97  	}
    98  }
    99  
   100  // Close implements IStatement Close method.
   101  func (ts *TiDBStatement) Close() error {
   102  	//TODO close at tidb level
   103  	err := ts.ctx.session.DropPreparedStmt(ts.id)
   104  	if err != nil {
   105  		return errors.Trace(err)
   106  	}
   107  	delete(ts.ctx.stmts, int(ts.id))
   108  	return nil
   109  }
   110  
   111  // OpenCtx implements IDriver.
   112  func (qd *TiDBDriver) OpenCtx(connID uint64, capability uint32, collation uint8, dbname string) (IContext, error) {
   113  	session, _ := tidb.CreateSession(qd.store)
   114  	session.SetClientCapability(capability)
   115  	session.SetConnectionID(connID)
   116  	if dbname != "" {
   117  		_, err := session.Execute("use " + dbname)
   118  		if err != nil {
   119  			return nil, err
   120  		}
   121  	}
   122  	tc := &TiDBContext{
   123  		session:   session,
   124  		currentDB: dbname,
   125  		stmts:     make(map[int]*TiDBStatement),
   126  	}
   127  	return tc, nil
   128  }
   129  
   130  // Status implements IContext Status method.
   131  func (tc *TiDBContext) Status() uint16 {
   132  	return tc.session.Status()
   133  }
   134  
   135  // LastInsertID implements IContext LastInsertID method.
   136  func (tc *TiDBContext) LastInsertID() uint64 {
   137  	return tc.session.LastInsertID()
   138  }
   139  
   140  // AffectedRows implements IContext AffectedRows method.
   141  func (tc *TiDBContext) AffectedRows() uint64 {
   142  	return tc.session.AffectedRows()
   143  }
   144  
   145  // CurrentDB implements IContext CurrentDB method.
   146  func (tc *TiDBContext) CurrentDB() string {
   147  	return tc.currentDB
   148  }
   149  
   150  // WarningCount implements IContext WarningCount method.
   151  func (tc *TiDBContext) WarningCount() uint16 {
   152  	return tc.warningCount
   153  }
   154  
   155  // Execute implements IContext Execute method.
   156  func (tc *TiDBContext) Execute(sql string) (rs ResultSet, err error) {
   157  	rsList, err := tc.session.Execute(sql)
   158  	if err != nil {
   159  		return
   160  	}
   161  	if len(rsList) == 0 { // result ok
   162  		return
   163  	}
   164  	rs = &tidbResultSet{
   165  		recordSet: rsList[0],
   166  	}
   167  	return
   168  }
   169  
   170  // Close implements IContext Close method.
   171  func (tc *TiDBContext) Close() (err error) {
   172  	return tc.session.Close()
   173  }
   174  
   175  // Auth implements IContext Auth method.
   176  func (tc *TiDBContext) Auth(user string, auth []byte, salt []byte) bool {
   177  	return tc.session.Auth(user, auth, salt)
   178  }
   179  
   180  // FieldList implements IContext FieldList method.
   181  func (tc *TiDBContext) FieldList(table string) (colums []*ColumnInfo, err error) {
   182  	rs, err := tc.Execute("SELECT * FROM " + table + " LIMIT 0")
   183  	if err != nil {
   184  		return nil, errors.Trace(err)
   185  	}
   186  	colums, err = rs.Columns()
   187  	if err != nil {
   188  		return nil, errors.Trace(err)
   189  	}
   190  	return
   191  }
   192  
   193  // GetStatement implements IContext GetStatement method.
   194  func (tc *TiDBContext) GetStatement(stmtID int) IStatement {
   195  	tcStmt := tc.stmts[stmtID]
   196  	if tcStmt != nil {
   197  		return tcStmt
   198  	}
   199  	return nil
   200  }
   201  
   202  // Prepare implements IContext Prepare method.
   203  func (tc *TiDBContext) Prepare(sql string) (statement IStatement, columns, params []*ColumnInfo, err error) {
   204  	stmtID, paramCount, fields, err := tc.session.PrepareStmt(sql)
   205  	if err != nil {
   206  		return
   207  	}
   208  	stmt := &TiDBStatement{
   209  		id:          stmtID,
   210  		numParams:   paramCount,
   211  		boundParams: make([][]byte, paramCount),
   212  		ctx:         tc,
   213  	}
   214  	statement = stmt
   215  	columns = make([]*ColumnInfo, len(fields))
   216  	for i := range fields {
   217  		columns[i] = convertColumnInfo(fields[i])
   218  	}
   219  	params = make([]*ColumnInfo, paramCount)
   220  	for i := range params {
   221  		params[i] = &ColumnInfo{
   222  			Type: mysql.TypeBlob,
   223  		}
   224  	}
   225  	tc.stmts[int(stmtID)] = stmt
   226  	return
   227  }
   228  
// tidbResultSet wraps an ast.RecordSet and exposes its rows and field
// metadata through the Next, Close and Columns methods.
type tidbResultSet struct {
	// recordSet is the wrapped record set every method delegates to.
	recordSet ast.RecordSet
}
   232  
   233  func (trs *tidbResultSet) Next() ([]types.Datum, error) {
   234  	row, err := trs.recordSet.Next()
   235  	if err != nil {
   236  		return nil, errors.Trace(err)
   237  	}
   238  	if row != nil {
   239  		return row.Data, nil
   240  	}
   241  	return nil, nil
   242  }
   243  
   244  func (trs *tidbResultSet) Close() error {
   245  	return trs.recordSet.Close()
   246  }
   247  
   248  func (trs *tidbResultSet) Columns() ([]*ColumnInfo, error) {
   249  	fields, err := trs.recordSet.Fields()
   250  	if err != nil {
   251  		return nil, errors.Trace(err)
   252  	}
   253  	var columns []*ColumnInfo
   254  	for _, v := range fields {
   255  		columns = append(columns, convertColumnInfo(v))
   256  	}
   257  	return columns, nil
   258  }
   259  
   260  func convertColumnInfo(fld *ast.ResultField) (ci *ColumnInfo) {
   261  	ci = new(ColumnInfo)
   262  	ci.Name = fld.ColumnAsName.O
   263  	ci.OrgName = fld.Column.Name.O
   264  	ci.Table = fld.TableAsName.O
   265  	if fld.Table != nil {
   266  		ci.OrgTable = fld.Table.Name.O
   267  	}
   268  	ci.Schema = fld.DBName.O
   269  	ci.Flag = uint16(fld.Column.Flag)
   270  	ci.Charset = uint16(mysql.CharsetIDs[fld.Column.Charset])
   271  	if fld.Column.Flen == types.UnspecifiedLength {
   272  		ci.ColumnLength = 0
   273  	} else {
   274  		ci.ColumnLength = uint32(fld.Column.Flen)
   275  	}
   276  	if fld.Column.Decimal == types.UnspecifiedLength {
   277  		ci.Decimal = 0
   278  	} else {
   279  		ci.Decimal = uint8(fld.Column.Decimal)
   280  	}
   281  	ci.Type = uint8(fld.Column.Tp)
   282  
   283  	// Keep things compatible for old clients.
   284  	// Refer to mysql-server/sql/protocol.cc send_result_set_metadata()
   285  	if ci.Type == mysql.TypeVarchar {
   286  		ci.Type = mysql.TypeVarString
   287  	}
   288  	return
   289  }