vitess.io/vitess@v0.16.2/go/vt/vitessdriver/driver.go (about) 1 /* 2 Copyright 2019 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package vitessdriver 18 19 import ( 20 "context" 21 "database/sql" 22 "database/sql/driver" 23 "encoding/base64" 24 "encoding/json" 25 "errors" 26 "fmt" 27 28 "google.golang.org/grpc" 29 "google.golang.org/protobuf/proto" 30 31 "vitess.io/vitess/go/sqltypes" 32 querypb "vitess.io/vitess/go/vt/proto/query" 33 vtgatepb "vitess.io/vitess/go/vt/proto/vtgate" 34 "vitess.io/vitess/go/vt/vtgate/grpcvtgateconn" 35 "vitess.io/vitess/go/vt/vtgate/vtgateconn" 36 ) 37 38 var ( 39 errNoIntermixing = errors.New("named and positional arguments intermixing disallowed") 40 errIsolationUnsupported = errors.New("isolation levels are not supported") 41 ) 42 43 // Type-check interfaces. 44 var ( 45 _ driver.QueryerContext = &conn{} 46 _ driver.ExecerContext = &conn{} 47 _ driver.StmtQueryContext = &stmt{} 48 _ driver.StmtExecContext = &stmt{} 49 ) 50 51 func init() { 52 sql.Register("vitess", drv{}) 53 } 54 55 // Open is a Vitess helper function for sql.Open(). 56 // 57 // It opens a database connection to vtgate running at "address". 58 func Open(address, target string) (*sql.DB, error) { 59 c := Configuration{ 60 Address: address, 61 Target: target, 62 } 63 return OpenWithConfiguration(c) 64 } 65 66 // OpenForStreaming is the same as Open() but uses streaming RPCs to retrieve 67 // the results. 68 // 69 // The streaming mode is recommended for large results. 70 func OpenForStreaming(address, target string) (*sql.DB, error) { 71 c := Configuration{ 72 Address: address, 73 Target: target, 74 Streaming: true, 75 } 76 return OpenWithConfiguration(c) 77 } 78 79 // OpenWithConfiguration is the generic Vitess helper function for sql.Open(). 80 // 81 // It allows to pass in a Configuration struct to control all possible 82 // settings of the Vitess Go SQL driver. 83 func OpenWithConfiguration(c Configuration) (*sql.DB, error) { 84 c.setDefaults() 85 86 json, err := c.toJSON() 87 if err != nil { 88 return nil, err 89 } 90 91 if len(c.GRPCDialOptions) != 0 { 92 vtgateconn.RegisterDialer(c.Protocol, grpcvtgateconn.DialWithOpts(context.TODO(), c.GRPCDialOptions...)) 93 } 94 95 return sql.Open(c.DriverName, json) 96 } 97 98 type drv struct { 99 } 100 101 // Open implements the database/sql/driver.Driver interface. 102 // 103 // For "name", the Vitess driver requires that a JSON object is passed in. 104 // 105 // Instead of using this call and passing in a hand-crafted JSON string, it's 106 // recommended to use the public Vitess helper functions like 107 // Open(), OpenShard() or OpenWithConfiguration() instead. These will generate 108 // the required JSON string behind the scenes for you. 109 // 110 // Example for a JSON string: 111 // 112 // {"protocol": "grpc", "address": "localhost:1111", "target": "@primary"} 113 // 114 // For a description of the available fields, see the Configuration struct. 115 func (d drv) Open(name string) (driver.Conn, error) { 116 c := &conn{} 117 err := json.Unmarshal([]byte(name), c) 118 if err != nil { 119 return nil, err 120 } 121 122 c.setDefaults() 123 124 if c.convert, err = newConverter(&c.Configuration); err != nil { 125 return nil, err 126 } 127 128 if err = c.dial(); err != nil { 129 return nil, err 130 } 131 132 return c, nil 133 } 134 135 // Configuration holds all Vitess driver settings. 136 // 137 // Fields with documented default values do not have to be set explicitly. 138 type Configuration struct { 139 // Protocol is the name of the vtgate RPC client implementation. 140 // Note: In open-source "grpc" is the recommended implementation. 141 // 142 // Default: "grpc" 143 Protocol string 144 145 // Address must point to a vtgate instance. 146 // 147 // Format: hostname:port 148 Address string 149 150 // Target specifies the default target. 151 Target string 152 153 // Streaming is true when streaming RPCs are used. 154 // Recommended for large results. 155 // Default: false 156 Streaming bool 157 158 // DefaultLocation is the timezone string that will be used 159 // when converting DATETIME and DATE into time.Time. 160 // This setting has no effect if ConvertDatetime is not set. 161 // Default: UTC 162 DefaultLocation string 163 164 // GRPCDialOptions registers a new vtgateconn dialer with these dial options using the 165 // protocol as the key. This may overwrite the default grpcvtgateconn dial option 166 // if a custom one hasn't been specified in the config. 167 // 168 // Default: none 169 GRPCDialOptions []grpc.DialOption `json:"-"` 170 171 // Driver is the name registered with the database/sql package. This override 172 // is here in case you have wrapped the driver for stats or other interceptors. 173 // 174 // Default: "vitess" 175 DriverName string `json:"-"` 176 177 // SessionToken is a protobuf encoded vtgatepb.Session represented as base64, which 178 // can be used to distribute a transaction over the wire. 179 SessionToken string 180 } 181 182 // toJSON converts Configuration to the JSON string which is required by the 183 // Vitess driver. Default values for empty fields will be set. 184 func (c Configuration) toJSON() (string, error) { 185 jsonBytes, err := json.Marshal(c) 186 if err != nil { 187 return "", err 188 } 189 return string(jsonBytes), nil 190 } 191 192 // setDefaults sets the default values for empty fields. 193 func (c *Configuration) setDefaults() { 194 // if no protocol is provided default to grpc so the driver is in control 195 // of the connection protocol and not the flag vtgateconn.VtgateProtocol 196 if c.Protocol == "" { 197 c.Protocol = "grpc" 198 } 199 200 if c.DriverName == "" { 201 c.DriverName = "vitess" 202 } 203 } 204 205 type conn struct { 206 Configuration 207 convert *converter 208 conn *vtgateconn.VTGateConn 209 session *vtgateconn.VTGateSession 210 } 211 212 func (c *conn) dial() error { 213 var err error 214 c.conn, err = vtgateconn.DialProtocol(context.Background(), c.Protocol, c.Address) 215 if err != nil { 216 return err 217 } 218 if c.Configuration.SessionToken != "" { 219 sessionFromToken, err := sessionTokenToSession(c.Configuration.SessionToken) 220 if err != nil { 221 return err 222 } 223 c.session = c.conn.SessionFromPb(sessionFromToken) 224 } else { 225 c.session = c.conn.Session(c.Target, nil) 226 } 227 return nil 228 } 229 230 func (c *conn) Ping(ctx context.Context) error { 231 if c.Streaming { 232 return errors.New("Ping not allowed for streaming connections") 233 } 234 235 _, err := c.ExecContext(ctx, "select 1", nil) 236 return err 237 } 238 239 func (c *conn) Prepare(query string) (driver.Stmt, error) { 240 return &stmt{c: c, query: query}, nil 241 } 242 243 func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { 244 return c.Prepare(query) 245 } 246 247 func (c *conn) Close() error { 248 c.conn.Close() 249 return nil 250 } 251 252 // DistributedTxFromSessionToken allows users to send serialized sessions over the wire and 253 // reconnect to an existing transaction. Setting the sessionToken and address on the 254 // supplied configuration is the minimum required 255 // WARNING: the original Tx must already have already done work on all shards to be affected, 256 // otherwise the ShardSessions will not be sent through in the session token, and thus will 257 // never be committed in the source. The returned validation function checks to make sure that 258 // the new transaction work has not added any new ShardSessions. 259 func DistributedTxFromSessionToken(ctx context.Context, c Configuration) (*sql.Tx, func() error, error) { 260 if c.SessionToken == "" { 261 return nil, nil, errors.New("c.SessionToken is required") 262 } 263 264 session, err := sessionTokenToSession(c.SessionToken) 265 if err != nil { 266 return nil, nil, err 267 } 268 269 // if there isn't 1 or more shards already referenced, no work in this Tx can be committed 270 originalShardSessionCount := len(session.ShardSessions) 271 if originalShardSessionCount == 0 { 272 return nil, nil, errors.New("there must be at least 1 ShardSession") 273 } 274 275 db, err := OpenWithConfiguration(c) 276 if err != nil { 277 return nil, nil, err 278 } 279 280 // this should return the only connection associated with the db 281 tx, err := db.BeginTx(ctx, nil) 282 if err != nil { 283 return nil, nil, err 284 } 285 286 // this is designed to be run after all new work has been done in the tx, similar to 287 // where you would traditionally run a tx.Commit, to help prevent you from silently 288 // losing transactional data. 289 validationFunc := func() error { 290 var sessionToken string 291 sessionToken, err = SessionTokenFromTx(ctx, tx) 292 if err != nil { 293 return err 294 } 295 296 session, err = sessionTokenToSession(sessionToken) 297 if err != nil { 298 return err 299 } 300 301 if len(session.ShardSessions) > originalShardSessionCount { 302 return fmt.Errorf("mismatched ShardSession count: originally %d, now %d", 303 originalShardSessionCount, len(session.ShardSessions), 304 ) 305 } 306 307 return nil 308 } 309 310 return tx, validationFunc, nil 311 } 312 313 // SessionTokenFromTx serializes the sessionFromToken on the tx, which can be reconstituted 314 // into a *sql.Tx using DistributedTxFromSessionToken 315 func SessionTokenFromTx(ctx context.Context, tx *sql.Tx) (string, error) { 316 var sessionToken string 317 318 err := tx.QueryRowContext(ctx, "vt_session_token").Scan(&sessionToken) 319 if err != nil { 320 return "", err 321 } 322 323 session, err := sessionTokenToSession(sessionToken) 324 if err != nil { 325 return "", err 326 } 327 328 // if there isn't 1 or more shards already referenced, no work in this Tx can be committed 329 originalShardSessionCount := len(session.ShardSessions) 330 if originalShardSessionCount == 0 { 331 return "", errors.New("there must be at least 1 ShardSession") 332 } 333 334 return sessionToken, nil 335 } 336 337 func newSessionTokenRow(session *vtgatepb.Session, c *converter) (driver.Rows, error) { 338 sessionToken, err := sessionToSessionToken(session) 339 if err != nil { 340 return nil, err 341 } 342 343 qr := sqltypes.Result{ 344 Fields: []*querypb.Field{{ 345 Name: "vt_session_token", 346 Type: sqltypes.VarBinary, 347 }}, 348 Rows: [][]sqltypes.Value{{ 349 sqltypes.NewVarBinary(sessionToken), 350 }}, 351 } 352 353 return newRows(&qr, c), nil 354 } 355 356 func sessionToSessionToken(session *vtgatepb.Session) (string, error) { 357 b, err := proto.Marshal(session) 358 if err != nil { 359 return "", err 360 } 361 362 return base64.StdEncoding.EncodeToString(b), nil 363 } 364 365 func sessionTokenToSession(sessionToken string) (*vtgatepb.Session, error) { 366 b, err := base64.StdEncoding.DecodeString(sessionToken) 367 if err != nil { 368 return nil, err 369 } 370 371 session := &vtgatepb.Session{} 372 err = proto.Unmarshal(b, session) 373 if err != nil { 374 return nil, err 375 } 376 377 return session, nil 378 } 379 380 func (c *conn) Begin() (driver.Tx, error) { 381 // if we're loading from an existing session, we need to avoid starting a new transaction 382 if c.Configuration.SessionToken != "" { 383 return c, nil 384 } 385 386 if _, err := c.Exec("begin", nil); err != nil { 387 return nil, err 388 } 389 return c, nil 390 } 391 392 func (c *conn) BeginTx(_ context.Context, opts driver.TxOptions) (driver.Tx, error) { 393 // We don't use the context. The function signature accepts the context 394 // to signal to the driver that it's allowed to call Rollback on Cancel. 395 if opts.Isolation != driver.IsolationLevel(0) || opts.ReadOnly { 396 return nil, errIsolationUnsupported 397 } 398 return c.Begin() 399 } 400 401 func (c *conn) Commit() error { 402 // if we're loading from an existing session, disallow committing/rolling back the transaction 403 // this isn't a technical limitation, but is enforced to prevent misuse, so that only 404 // the original creator of the transaction can commit/rollback 405 if c.Configuration.SessionToken != "" { 406 return errors.New("calling Commit from a distributed tx is not allowed") 407 } 408 409 _, err := c.Exec("commit", nil) 410 return err 411 } 412 413 func (c *conn) Rollback() error { 414 // if we're loading from an existing session, disallow committing/rolling back the transaction 415 // this isn't a technical limitation, but is enforced to prevent misuse, so that only 416 // the original creator of the transaction can commit/rollback 417 if c.Configuration.SessionToken != "" { 418 return errors.New("calling Rollback from a distributed tx is not allowed") 419 } 420 421 _, err := c.Exec("rollback", nil) 422 return err 423 } 424 425 func (c *conn) Exec(query string, args []driver.Value) (driver.Result, error) { 426 ctx := context.TODO() 427 428 if c.Streaming { 429 return nil, errors.New("Exec not allowed for streaming connections") 430 } 431 bindVars, err := c.convert.buildBindVars(args) 432 if err != nil { 433 return nil, err 434 } 435 436 qr, err := c.session.Execute(ctx, query, bindVars) 437 if err != nil { 438 return nil, err 439 } 440 return result{int64(qr.InsertID), int64(qr.RowsAffected)}, nil 441 } 442 443 func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { 444 if c.Streaming { 445 return nil, errors.New("Exec not allowed for streaming connections") 446 } 447 448 bv, err := c.convert.bindVarsFromNamedValues(args) 449 if err != nil { 450 return nil, err 451 } 452 qr, err := c.session.Execute(ctx, query, bv) 453 if err != nil { 454 return nil, err 455 } 456 return result{int64(qr.InsertID), int64(qr.RowsAffected)}, nil 457 } 458 459 func (c *conn) Query(query string, args []driver.Value) (driver.Rows, error) { 460 ctx := context.TODO() 461 bindVars, err := c.convert.buildBindVars(args) 462 if err != nil { 463 return nil, err 464 } 465 466 if c.Streaming { 467 stream, err := c.session.StreamExecute(ctx, query, bindVars) 468 if err != nil { 469 return nil, err 470 } 471 return newStreamingRows(stream, c.convert), nil 472 } 473 474 qr, err := c.session.Execute(ctx, query, bindVars) 475 if err != nil { 476 return nil, err 477 } 478 return newRows(qr, c.convert), nil 479 } 480 481 func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { 482 // special case for serializing the current sessionFromToken state 483 if query == "vt_session_token" { 484 return newSessionTokenRow(c.session.SessionPb(), c.convert) 485 } 486 487 bv, err := c.convert.bindVarsFromNamedValues(args) 488 if err != nil { 489 return nil, err 490 } 491 492 if c.Streaming { 493 stream, err := c.session.StreamExecute(ctx, query, bv) 494 if err != nil { 495 return nil, err 496 } 497 return newStreamingRows(stream, c.convert), nil 498 } 499 500 qr, err := c.session.Execute(ctx, query, bv) 501 if err != nil { 502 return nil, err 503 } 504 return newRows(qr, c.convert), nil 505 } 506 507 type stmt struct { 508 c *conn 509 query string 510 } 511 512 func (s *stmt) Close() error { 513 return nil 514 } 515 516 func (s *stmt) NumInput() int { 517 // -1 = Golang sql won't sanity check argument counts before Exec or Query. 518 return -1 519 } 520 521 func (s *stmt) Exec(args []driver.Value) (driver.Result, error) { 522 return s.c.Exec(s.query, args) 523 } 524 525 func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { 526 return s.c.ExecContext(ctx, s.query, args) 527 } 528 529 func (s *stmt) Query(args []driver.Value) (driver.Rows, error) { 530 return s.c.Query(s.query, args) 531 } 532 533 func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { 534 return s.c.QueryContext(ctx, s.query, args) 535 } 536 537 type result struct { 538 insertid, rowsaffected int64 539 } 540 541 func (r result) LastInsertId() (int64, error) { 542 return r.insertid, nil 543 } 544 545 func (r result) RowsAffected() (int64, error) { 546 return r.rowsaffected, nil 547 }