github.com/insionng/yougam@v0.0.0-20170714101924-2bc18d833463/libraries/pingcap/tidb/driver.go (about) 1 // Copyright 2013 The ql Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSES/QL-LICENSE file. 4 5 // Copyright 2015 PingCAP, Inc. 6 // 7 // Licensed under the Apache License, Version 2.0 (the "License"); 8 // you may not use this file except in compliance with the License. 9 // You may obtain a copy of the License at 10 // 11 // http://www.apache.org/licenses/LICENSE-2.0 12 // 13 // Unless required by applicable law or agreed to in writing, software 14 // distributed under the License is distributed on an "AS IS" BASIS, 15 // See the License for the specific language governing permissions and 16 // limitations under the License. 17 18 // database/sql/driver 19 20 package tidb 21 22 import ( 23 "database/sql" 24 "database/sql/driver" 25 "io" 26 "net/url" 27 "path/filepath" 28 "strings" 29 "sync" 30 31 "github.com/insionng/yougam/libraries/juju/errors" 32 "github.com/insionng/yougam/libraries/pingcap/tidb/ast" 33 "github.com/insionng/yougam/libraries/pingcap/tidb/model" 34 "github.com/insionng/yougam/libraries/pingcap/tidb/sessionctx" 35 "github.com/insionng/yougam/libraries/pingcap/tidb/terror" 36 "github.com/insionng/yougam/libraries/pingcap/tidb/util/types" 37 ) 38 39 const ( 40 // DriverName is name of TiDB driver. 41 DriverName = "tidb" 42 ) 43 44 var ( 45 _ driver.Conn = (*driverConn)(nil) 46 _ driver.Execer = (*driverConn)(nil) 47 _ driver.Queryer = (*driverConn)(nil) 48 _ driver.Tx = (*driverConn)(nil) 49 50 _ driver.Result = (*driverResult)(nil) 51 _ driver.Rows = (*driverRows)(nil) 52 _ driver.Stmt = (*driverStmt)(nil) 53 _ driver.Driver = (*sqlDriver)(nil) 54 55 txBeginSQL = "BEGIN;" 56 txCommitSQL = "COMMIT;" 57 txRollbackSQL = "ROLLBACK;" 58 59 errNoResult = errors.New("query statement does not produce a result set (no top level SELECT)") 60 ) 61 62 type errList []error 63 64 type driverParams struct { 65 storePath string 66 dbName string 67 // when set to true `mysql.Time` isn't encoded as string but passed as `time.Time` 68 // this option is named for compatibility the same as in the mysql driver 69 // while we actually do not have additional parsing to do 70 parseTime bool 71 } 72 73 func (e *errList) append(err error) { 74 if err != nil { 75 *e = append(*e, err) 76 } 77 } 78 79 func (e errList) error() error { 80 if len(e) == 0 { 81 return nil 82 } 83 84 return e 85 } 86 87 func (e errList) Error() string { 88 a := make([]string, len(e)) 89 for i, v := range e { 90 a[i] = v.Error() 91 } 92 return strings.Join(a, "\n") 93 } 94 95 func params(args []driver.Value) []interface{} { 96 r := make([]interface{}, len(args)) 97 for i, v := range args { 98 r[i] = interface{}(v) 99 } 100 return r 101 } 102 103 var ( 104 tidbDriver = &sqlDriver{} 105 driverOnce sync.Once 106 ) 107 108 // RegisterDriver registers TiDB driver. 109 // The name argument can be optionally prefixed by "engine://". In that case the 110 // prefix is recognized as a storage engine name. 111 // 112 // The name argument can be optionally prefixed by "memory://". In that case 113 // the prefix is stripped before interpreting it as a name of a memory-only, 114 // volatile DB. 115 // 116 // [0]: http://yougam/libraries/pkg/database/sql/driver/ 117 func RegisterDriver() { 118 driverOnce.Do(func() { sql.Register(DriverName, tidbDriver) }) 119 } 120 121 // sqlDriver implements the interface required by database/sql/driver. 122 type sqlDriver struct { 123 mu sync.Mutex 124 } 125 126 func (d *sqlDriver) lock() { 127 d.mu.Lock() 128 } 129 130 func (d *sqlDriver) unlock() { 131 d.mu.Unlock() 132 } 133 134 // parseDriverDSN cuts off DB name from dsn. It returns error if the dsn is not 135 // valid. 136 func parseDriverDSN(dsn string) (params *driverParams, err error) { 137 u, err := url.Parse(dsn) 138 if err != nil { 139 return nil, errors.Trace(err) 140 } 141 path := filepath.Join(u.Host, u.Path) 142 dbName := filepath.Clean(filepath.Base(path)) 143 if dbName == "" || dbName == "." || dbName == string(filepath.Separator) { 144 return nil, errors.Errorf("invalid DB name %q", dbName) 145 } 146 // cut off dbName 147 path = filepath.Clean(filepath.Dir(path)) 148 if path == "" || path == "." || path == string(filepath.Separator) { 149 return nil, errors.Errorf("invalid dsn %q", dsn) 150 } 151 u.Path, u.Host = path, "" 152 params = &driverParams{ 153 storePath: u.String(), 154 dbName: dbName, 155 } 156 // parse additional driver params 157 query := u.Query() 158 if parseTime := query.Get("parseTime"); parseTime == "true" { 159 params.parseTime = true 160 } 161 162 return params, nil 163 } 164 165 // Open returns a new connection to the database. 166 // 167 // The dsn must be a URL format 'engine://path/dbname?params'. 168 // Engine is the storage name registered with RegisterStore. 169 // Path is the storage specific format. 170 // Params is key-value pairs split by '&', optional params are storage specific. 171 // Examples: 172 // goleveldb://relative/path/test 173 // boltdb:///absolute/path/test 174 // hbase://zk1,zk2,zk3/hbasetbl/test?tso=zk 175 // 176 // Open may return a cached connection (one previously closed), but doing so is 177 // unnecessary; the sql package maintains a pool of idle connections for 178 // efficient re-use. 179 // 180 // The behavior of the mysql driver regarding time parsing can also be imitated 181 // by passing ?parseTime 182 // 183 // The returned connection is only used by one goroutine at a time. 184 func (d *sqlDriver) Open(dsn string) (driver.Conn, error) { 185 params, err := parseDriverDSN(dsn) 186 if err != nil { 187 return nil, errors.Trace(err) 188 } 189 store, err := NewStore(params.storePath) 190 if err != nil { 191 return nil, errors.Trace(err) 192 } 193 194 sess, err := CreateSession(store) 195 if err != nil { 196 return nil, errors.Trace(err) 197 } 198 s := sess.(*session) 199 200 d.lock() 201 defer d.unlock() 202 203 DBName := model.NewCIStr(params.dbName) 204 domain := sessionctx.GetDomain(s) 205 cs := &ast.CharsetOpt{ 206 Chs: "utf8", 207 Col: "utf8_bin", 208 } 209 if !domain.InfoSchema().SchemaExists(DBName) { 210 err = domain.DDL().CreateSchema(s, DBName, cs) 211 if err != nil { 212 return nil, errors.Trace(err) 213 } 214 } 215 driver := &sqlDriver{} 216 return newDriverConn(s, driver, DBName.O, params) 217 } 218 219 // driverConn is a connection to a database. It is not used concurrently by 220 // multiple goroutines. 221 // 222 // Conn is assumed to be stateful. 223 type driverConn struct { 224 s Session 225 driver *sqlDriver 226 stmts map[string]driver.Stmt 227 params *driverParams 228 } 229 230 func newDriverConn(sess *session, d *sqlDriver, schema string, params *driverParams) (driver.Conn, error) { 231 r := &driverConn{ 232 driver: d, 233 stmts: map[string]driver.Stmt{}, 234 s: sess, 235 params: params, 236 } 237 238 _, err := r.s.Execute("use " + schema) 239 if err != nil { 240 return nil, errors.Trace(err) 241 } 242 return r, nil 243 } 244 245 // Prepare returns a prepared statement, bound to this connection. 246 func (c *driverConn) Prepare(query string) (driver.Stmt, error) { 247 stmtID, paramCount, fields, err := c.s.PrepareStmt(query) 248 if err != nil { 249 return nil, err 250 } 251 s := &driverStmt{ 252 conn: c, 253 query: query, 254 stmtID: stmtID, 255 paramCount: paramCount, 256 isQuery: fields != nil, 257 } 258 c.stmts[query] = s 259 return s, nil 260 } 261 262 // Close invalidates and potentially stops any current prepared statements and 263 // transactions, marking this connection as no longer in use. 264 // 265 // Because the sql package maintains a free pool of connections and only calls 266 // Close when there's a surplus of idle connections, it shouldn't be necessary 267 // for drivers to do their own connection caching. 268 func (c *driverConn) Close() error { 269 var err errList 270 for _, s := range c.stmts { 271 stmt := s.(*driverStmt) 272 err.append(stmt.conn.s.DropPreparedStmt(stmt.stmtID)) 273 } 274 275 c.driver.lock() 276 defer c.driver.unlock() 277 278 return err.error() 279 } 280 281 // Begin starts and returns a new transaction. 282 func (c *driverConn) Begin() (driver.Tx, error) { 283 if c.s == nil { 284 return nil, errors.Errorf("Need init first") 285 } 286 287 if _, err := c.s.Execute(txBeginSQL); err != nil { 288 return nil, errors.Trace(err) 289 } 290 291 return c, nil 292 } 293 294 func (c *driverConn) Commit() error { 295 if c.s == nil { 296 return terror.CommitNotInTransaction 297 } 298 _, err := c.s.Execute(txCommitSQL) 299 300 if err != nil { 301 return errors.Trace(err) 302 } 303 304 err = c.s.FinishTxn(false) 305 return errors.Trace(err) 306 } 307 308 func (c *driverConn) Rollback() error { 309 if c.s == nil { 310 return terror.RollbackNotInTransaction 311 } 312 313 if _, err := c.s.Execute(txRollbackSQL); err != nil { 314 return errors.Trace(err) 315 } 316 317 return nil 318 } 319 320 // Execer is an optional interface that may be implemented by a Conn. 321 // 322 // If a Conn does not implement Execer, the sql package's DB.Exec will first 323 // prepare a query, execute the statement, and then close the statement. 324 // 325 // Exec may return driver.ErrSkip. 326 func (c *driverConn) Exec(query string, args []driver.Value) (driver.Result, error) { 327 return c.driverExec(query, args) 328 329 } 330 331 func (c *driverConn) getStmt(query string) (stmt driver.Stmt, err error) { 332 stmt, ok := c.stmts[query] 333 if !ok { 334 stmt, err = c.Prepare(query) 335 if err != nil { 336 return nil, errors.Trace(err) 337 } 338 } 339 return 340 } 341 342 func (c *driverConn) driverExec(query string, args []driver.Value) (driver.Result, error) { 343 if len(args) == 0 { 344 if _, err := c.s.Execute(query); err != nil { 345 return nil, errors.Trace(err) 346 } 347 r := &driverResult{} 348 r.lastInsertID, r.rowsAffected = int64(c.s.LastInsertID()), int64(c.s.AffectedRows()) 349 return r, nil 350 } 351 stmt, err := c.getStmt(query) 352 if err != nil { 353 return nil, errors.Trace(err) 354 } 355 return stmt.Exec(args) 356 } 357 358 // Queryer is an optional interface that may be implemented by a Conn. 359 // 360 // If a Conn does not implement Queryer, the sql package's DB.Query will first 361 // prepare a query, execute the statement, and then close the statement. 362 // 363 // Query may return driver.ErrSkip. 364 func (c *driverConn) Query(query string, args []driver.Value) (driver.Rows, error) { 365 return c.driverQuery(query, args) 366 } 367 368 func (c *driverConn) driverQuery(query string, args []driver.Value) (driver.Rows, error) { 369 if len(args) == 0 { 370 rss, err := c.s.Execute(query) 371 if err != nil { 372 return nil, errors.Trace(err) 373 } 374 if len(rss) == 0 { 375 return nil, errors.Trace(errNoResult) 376 } 377 return &driverRows{params: c.params, rs: rss[0]}, nil 378 } 379 stmt, err := c.getStmt(query) 380 if err != nil { 381 return nil, errors.Trace(err) 382 } 383 return stmt.Query(args) 384 } 385 386 // driverResult is the result of a query execution. 387 type driverResult struct { 388 lastInsertID int64 389 rowsAffected int64 390 } 391 392 // LastInsertID returns the database's auto-generated ID after, for example, an 393 // INSERT into a table with primary key. 394 func (r *driverResult) LastInsertId() (int64, error) { // -golint 395 return r.lastInsertID, nil 396 } 397 398 // RowsAffected returns the number of rows affected by the query. 399 func (r *driverResult) RowsAffected() (int64, error) { 400 return r.rowsAffected, nil 401 } 402 403 // driverRows is an iterator over an executed query's results. 404 type driverRows struct { 405 rs ast.RecordSet 406 params *driverParams 407 } 408 409 // Columns returns the names of the columns. The number of columns of the 410 // result is inferred from the length of the slice. If a particular column 411 // name isn't known, an empty string should be returned for that entry. 412 func (r *driverRows) Columns() []string { 413 if r.rs == nil { 414 return []string{} 415 } 416 fs, _ := r.rs.Fields() 417 names := make([]string, len(fs)) 418 for i, f := range fs { 419 names[i] = f.ColumnAsName.O 420 } 421 return names 422 } 423 424 // Close closes the rows iterator. 425 func (r *driverRows) Close() error { 426 if r.rs != nil { 427 return r.rs.Close() 428 } 429 return nil 430 } 431 432 // Next is called to populate the next row of data into the provided slice. The 433 // provided slice will be the same size as the Columns() are wide. 434 // 435 // The dest slice may be populated only with a driver Value type, but excluding 436 // string. All string values must be converted to []byte. 437 // 438 // Next should return io.EOF when there are no more rows. 439 func (r *driverRows) Next(dest []driver.Value) error { 440 if r.rs == nil { 441 return io.EOF 442 } 443 row, err := r.rs.Next() 444 if err != nil { 445 return errors.Trace(err) 446 } 447 if row == nil { 448 return io.EOF 449 } 450 if len(row.Data) != len(dest) { 451 return errors.Errorf("field count mismatch: got %d, need %d", len(row.Data), len(dest)) 452 } 453 for i, xi := range row.Data { 454 switch xi.Kind() { 455 case types.KindNull: 456 dest[i] = nil 457 case types.KindInt64: 458 dest[i] = xi.GetInt64() 459 case types.KindUint64: 460 dest[i] = xi.GetUint64() 461 case types.KindFloat32: 462 dest[i] = xi.GetFloat32() 463 case types.KindFloat64: 464 dest[i] = xi.GetFloat64() 465 case types.KindString: 466 dest[i] = xi.GetString() 467 case types.KindBytes: 468 dest[i] = xi.GetBytes() 469 case types.KindMysqlBit: 470 dest[i] = xi.GetMysqlBit().ToString() 471 case types.KindMysqlDecimal: 472 dest[i] = xi.GetMysqlDecimal().String() 473 case types.KindMysqlDuration: 474 dest[i] = xi.GetMysqlDuration().String() 475 case types.KindMysqlEnum: 476 dest[i] = xi.GetMysqlEnum().String() 477 case types.KindMysqlHex: 478 dest[i] = xi.GetMysqlHex().ToString() 479 case types.KindMysqlSet: 480 dest[i] = xi.GetMysqlSet().String() 481 case types.KindMysqlTime: 482 t := xi.GetMysqlTime() 483 if !r.params.parseTime { 484 dest[i] = t.String() 485 } else { 486 dest[i] = t.Time 487 } 488 default: 489 return errors.Errorf("unable to handle type %T", xi.GetValue()) 490 } 491 } 492 return nil 493 } 494 495 // driverStmt is a prepared statement. It is bound to a driverConn and not used 496 // by multiple goroutines concurrently. 497 type driverStmt struct { 498 conn *driverConn 499 query string 500 stmtID uint32 501 paramCount int 502 isQuery bool 503 } 504 505 // Close closes the statement. 506 // 507 // As of Go 1.1, a Stmt will not be closed if it's in use by any queries. 508 func (s *driverStmt) Close() error { 509 s.conn.s.DropPreparedStmt(s.stmtID) 510 delete(s.conn.stmts, s.query) 511 return nil 512 } 513 514 // NumInput returns the number of placeholder parameters. 515 // 516 // If NumInput returns >= 0, the sql package will sanity check argument counts 517 // from callers and return errors to the caller before the statement's Exec or 518 // Query methods are called. 519 // 520 // NumInput may also return -1, if the driver doesn't know its number of 521 // placeholders. In that case, the sql package will not sanity check Exec or 522 // Query argument counts. 523 func (s *driverStmt) NumInput() int { 524 return s.paramCount 525 } 526 527 // Exec executes a query that doesn't return rows, such as an INSERT or UPDATE. 528 func (s *driverStmt) Exec(args []driver.Value) (driver.Result, error) { 529 c := s.conn 530 _, err := c.s.ExecutePreparedStmt(s.stmtID, params(args)...) 531 if err != nil { 532 return nil, errors.Trace(err) 533 } 534 r := &driverResult{} 535 if s != nil { 536 r.lastInsertID, r.rowsAffected = int64(c.s.LastInsertID()), int64(c.s.AffectedRows()) 537 } 538 return r, nil 539 } 540 541 // Exec executes a query that may return rows, such as a SELECT. 542 func (s *driverStmt) Query(args []driver.Value) (driver.Rows, error) { 543 c := s.conn 544 rs, err := c.s.ExecutePreparedStmt(s.stmtID, params(args)...) 545 if err != nil { 546 return nil, errors.Trace(err) 547 } 548 if rs == nil { 549 if s.isQuery { 550 return nil, errors.Trace(errNoResult) 551 } 552 // The statement is not a query. 553 return &driverRows{}, nil 554 } 555 return &driverRows{params: s.conn.params, rs: rs}, nil 556 } 557 558 func init() { 559 RegisterDriver() 560 }