github.com/insionng/yougam@v0.0.0-20170714101924-2bc18d833463/libraries/pingcap/tidb/server/driver_tidb.go (about) 1 // Copyright 2015 PingCAP, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 package server 15 16 import ( 17 "github.com/insionng/yougam/libraries/juju/errors" 18 "github.com/insionng/yougam/libraries/pingcap/tidb" 19 "github.com/insionng/yougam/libraries/pingcap/tidb/ast" 20 "github.com/insionng/yougam/libraries/pingcap/tidb/kv" 21 "github.com/insionng/yougam/libraries/pingcap/tidb/mysql" 22 "github.com/insionng/yougam/libraries/pingcap/tidb/util/types" 23 ) 24 25 // TiDBDriver implements IDriver. 26 type TiDBDriver struct { 27 store kv.Storage 28 } 29 30 // NewTiDBDriver creates a new TiDBDriver. 31 func NewTiDBDriver(store kv.Storage) *TiDBDriver { 32 driver := &TiDBDriver{ 33 store: store, 34 } 35 return driver 36 } 37 38 // TiDBContext implements IContext. 39 type TiDBContext struct { 40 session tidb.Session 41 currentDB string 42 warningCount uint16 43 stmts map[int]*TiDBStatement 44 } 45 46 // TiDBStatement implements IStatement. 47 type TiDBStatement struct { 48 id uint32 49 numParams int 50 boundParams [][]byte 51 ctx *TiDBContext 52 } 53 54 // ID implements IStatement ID method. 55 func (ts *TiDBStatement) ID() int { 56 return int(ts.id) 57 } 58 59 // Execute implements IStatement Execute method. 60 func (ts *TiDBStatement) Execute(args ...interface{}) (rs ResultSet, err error) { 61 tidbRecordset, err := ts.ctx.session.ExecutePreparedStmt(ts.id, args...) 62 if err != nil { 63 return nil, err 64 } 65 if tidbRecordset == nil { 66 return 67 } 68 rs = &tidbResultSet{ 69 recordSet: tidbRecordset, 70 } 71 return 72 } 73 74 // AppendParam implements IStatement AppendParam method. 75 func (ts *TiDBStatement) AppendParam(paramID int, data []byte) error { 76 if paramID >= len(ts.boundParams) { 77 return mysql.NewErr(mysql.ErrWrongArguments, "stmt_send_longdata") 78 } 79 ts.boundParams[paramID] = append(ts.boundParams[paramID], data...) 80 return nil 81 } 82 83 // NumParams implements IStatement NumParams method. 84 func (ts *TiDBStatement) NumParams() int { 85 return ts.numParams 86 } 87 88 // BoundParams implements IStatement BoundParams method. 89 func (ts *TiDBStatement) BoundParams() [][]byte { 90 return ts.boundParams 91 } 92 93 // Reset implements IStatement Reset method. 94 func (ts *TiDBStatement) Reset() { 95 for i := range ts.boundParams { 96 ts.boundParams[i] = nil 97 } 98 } 99 100 // Close implements IStatement Close method. 101 func (ts *TiDBStatement) Close() error { 102 //TODO close at tidb level 103 err := ts.ctx.session.DropPreparedStmt(ts.id) 104 if err != nil { 105 return errors.Trace(err) 106 } 107 delete(ts.ctx.stmts, int(ts.id)) 108 return nil 109 } 110 111 // OpenCtx implements IDriver. 112 func (qd *TiDBDriver) OpenCtx(connID uint64, capability uint32, collation uint8, dbname string) (IContext, error) { 113 session, _ := tidb.CreateSession(qd.store) 114 session.SetClientCapability(capability) 115 session.SetConnectionID(connID) 116 if dbname != "" { 117 _, err := session.Execute("use " + dbname) 118 if err != nil { 119 return nil, err 120 } 121 } 122 tc := &TiDBContext{ 123 session: session, 124 currentDB: dbname, 125 stmts: make(map[int]*TiDBStatement), 126 } 127 return tc, nil 128 } 129 130 // Status implements IContext Status method. 131 func (tc *TiDBContext) Status() uint16 { 132 return tc.session.Status() 133 } 134 135 // LastInsertID implements IContext LastInsertID method. 136 func (tc *TiDBContext) LastInsertID() uint64 { 137 return tc.session.LastInsertID() 138 } 139 140 // AffectedRows implements IContext AffectedRows method. 141 func (tc *TiDBContext) AffectedRows() uint64 { 142 return tc.session.AffectedRows() 143 } 144 145 // CurrentDB implements IContext CurrentDB method. 146 func (tc *TiDBContext) CurrentDB() string { 147 return tc.currentDB 148 } 149 150 // WarningCount implements IContext WarningCount method. 151 func (tc *TiDBContext) WarningCount() uint16 { 152 return tc.warningCount 153 } 154 155 // Execute implements IContext Execute method. 156 func (tc *TiDBContext) Execute(sql string) (rs ResultSet, err error) { 157 rsList, err := tc.session.Execute(sql) 158 if err != nil { 159 return 160 } 161 if len(rsList) == 0 { // result ok 162 return 163 } 164 rs = &tidbResultSet{ 165 recordSet: rsList[0], 166 } 167 return 168 } 169 170 // Close implements IContext Close method. 171 func (tc *TiDBContext) Close() (err error) { 172 return tc.session.Close() 173 } 174 175 // Auth implements IContext Auth method. 176 func (tc *TiDBContext) Auth(user string, auth []byte, salt []byte) bool { 177 return tc.session.Auth(user, auth, salt) 178 } 179 180 // FieldList implements IContext FieldList method. 181 func (tc *TiDBContext) FieldList(table string) (colums []*ColumnInfo, err error) { 182 rs, err := tc.Execute("SELECT * FROM " + table + " LIMIT 0") 183 if err != nil { 184 return nil, errors.Trace(err) 185 } 186 colums, err = rs.Columns() 187 if err != nil { 188 return nil, errors.Trace(err) 189 } 190 return 191 } 192 193 // GetStatement implements IContext GetStatement method. 194 func (tc *TiDBContext) GetStatement(stmtID int) IStatement { 195 tcStmt := tc.stmts[stmtID] 196 if tcStmt != nil { 197 return tcStmt 198 } 199 return nil 200 } 201 202 // Prepare implements IContext Prepare method. 203 func (tc *TiDBContext) Prepare(sql string) (statement IStatement, columns, params []*ColumnInfo, err error) { 204 stmtID, paramCount, fields, err := tc.session.PrepareStmt(sql) 205 if err != nil { 206 return 207 } 208 stmt := &TiDBStatement{ 209 id: stmtID, 210 numParams: paramCount, 211 boundParams: make([][]byte, paramCount), 212 ctx: tc, 213 } 214 statement = stmt 215 columns = make([]*ColumnInfo, len(fields)) 216 for i := range fields { 217 columns[i] = convertColumnInfo(fields[i]) 218 } 219 params = make([]*ColumnInfo, paramCount) 220 for i := range params { 221 params[i] = &ColumnInfo{ 222 Type: mysql.TypeBlob, 223 } 224 } 225 tc.stmts[int(stmtID)] = stmt 226 return 227 } 228 229 type tidbResultSet struct { 230 recordSet ast.RecordSet 231 } 232 233 func (trs *tidbResultSet) Next() ([]types.Datum, error) { 234 row, err := trs.recordSet.Next() 235 if err != nil { 236 return nil, errors.Trace(err) 237 } 238 if row != nil { 239 return row.Data, nil 240 } 241 return nil, nil 242 } 243 244 func (trs *tidbResultSet) Close() error { 245 return trs.recordSet.Close() 246 } 247 248 func (trs *tidbResultSet) Columns() ([]*ColumnInfo, error) { 249 fields, err := trs.recordSet.Fields() 250 if err != nil { 251 return nil, errors.Trace(err) 252 } 253 var columns []*ColumnInfo 254 for _, v := range fields { 255 columns = append(columns, convertColumnInfo(v)) 256 } 257 return columns, nil 258 } 259 260 func convertColumnInfo(fld *ast.ResultField) (ci *ColumnInfo) { 261 ci = new(ColumnInfo) 262 ci.Name = fld.ColumnAsName.O 263 ci.OrgName = fld.Column.Name.O 264 ci.Table = fld.TableAsName.O 265 if fld.Table != nil { 266 ci.OrgTable = fld.Table.Name.O 267 } 268 ci.Schema = fld.DBName.O 269 ci.Flag = uint16(fld.Column.Flag) 270 ci.Charset = uint16(mysql.CharsetIDs[fld.Column.Charset]) 271 if fld.Column.Flen == types.UnspecifiedLength { 272 ci.ColumnLength = 0 273 } else { 274 ci.ColumnLength = uint32(fld.Column.Flen) 275 } 276 if fld.Column.Decimal == types.UnspecifiedLength { 277 ci.Decimal = 0 278 } else { 279 ci.Decimal = uint8(fld.Column.Decimal) 280 } 281 ci.Type = uint8(fld.Column.Tp) 282 283 // Keep things compatible for old clients. 284 // Refer to mysql-server/sql/protocol.cc send_result_set_metadata() 285 if ci.Type == mysql.TypeVarchar { 286 ci.Type = mysql.TypeVarString 287 } 288 return 289 }