vitess.io/vitess@v0.16.2/go/vt/vitessdriver/driver.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package vitessdriver
    18  
    19  import (
    20  	"context"
    21  	"database/sql"
    22  	"database/sql/driver"
    23  	"encoding/base64"
    24  	"encoding/json"
    25  	"errors"
    26  	"fmt"
    27  
    28  	"google.golang.org/grpc"
    29  	"google.golang.org/protobuf/proto"
    30  
    31  	"vitess.io/vitess/go/sqltypes"
    32  	querypb "vitess.io/vitess/go/vt/proto/query"
    33  	vtgatepb "vitess.io/vitess/go/vt/proto/vtgate"
    34  	"vitess.io/vitess/go/vt/vtgate/grpcvtgateconn"
    35  	"vitess.io/vitess/go/vt/vtgate/vtgateconn"
    36  )
    37  
    38  var (
    39  	errNoIntermixing        = errors.New("named and positional arguments intermixing disallowed")
    40  	errIsolationUnsupported = errors.New("isolation levels are not supported")
    41  )
    42  
    43  // Type-check interfaces.
    44  var (
    45  	_ driver.QueryerContext   = &conn{}
    46  	_ driver.ExecerContext    = &conn{}
    47  	_ driver.StmtQueryContext = &stmt{}
    48  	_ driver.StmtExecContext  = &stmt{}
    49  )
    50  
    51  func init() {
    52  	sql.Register("vitess", drv{})
    53  }
    54  
    55  // Open is a Vitess helper function for sql.Open().
    56  //
    57  // It opens a database connection to vtgate running at "address".
    58  func Open(address, target string) (*sql.DB, error) {
    59  	c := Configuration{
    60  		Address: address,
    61  		Target:  target,
    62  	}
    63  	return OpenWithConfiguration(c)
    64  }
    65  
    66  // OpenForStreaming is the same as Open() but uses streaming RPCs to retrieve
    67  // the results.
    68  //
    69  // The streaming mode is recommended for large results.
    70  func OpenForStreaming(address, target string) (*sql.DB, error) {
    71  	c := Configuration{
    72  		Address:   address,
    73  		Target:    target,
    74  		Streaming: true,
    75  	}
    76  	return OpenWithConfiguration(c)
    77  }
    78  
    79  // OpenWithConfiguration is the generic Vitess helper function for sql.Open().
    80  //
    81  // It allows to pass in a Configuration struct to control all possible
    82  // settings of the Vitess Go SQL driver.
    83  func OpenWithConfiguration(c Configuration) (*sql.DB, error) {
    84  	c.setDefaults()
    85  
    86  	json, err := c.toJSON()
    87  	if err != nil {
    88  		return nil, err
    89  	}
    90  
    91  	if len(c.GRPCDialOptions) != 0 {
    92  		vtgateconn.RegisterDialer(c.Protocol, grpcvtgateconn.DialWithOpts(context.TODO(), c.GRPCDialOptions...))
    93  	}
    94  
    95  	return sql.Open(c.DriverName, json)
    96  }
    97  
    98  type drv struct {
    99  }
   100  
   101  // Open implements the database/sql/driver.Driver interface.
   102  //
   103  // For "name", the Vitess driver requires that a JSON object is passed in.
   104  //
   105  // Instead of using this call and passing in a hand-crafted JSON string, it's
   106  // recommended to use the public Vitess helper functions like
   107  // Open(), OpenShard() or OpenWithConfiguration() instead. These will generate
   108  // the required JSON string behind the scenes for you.
   109  //
   110  // Example for a JSON string:
   111  //
   112  //	{"protocol": "grpc", "address": "localhost:1111", "target": "@primary"}
   113  //
   114  // For a description of the available fields, see the Configuration struct.
   115  func (d drv) Open(name string) (driver.Conn, error) {
   116  	c := &conn{}
   117  	err := json.Unmarshal([]byte(name), c)
   118  	if err != nil {
   119  		return nil, err
   120  	}
   121  
   122  	c.setDefaults()
   123  
   124  	if c.convert, err = newConverter(&c.Configuration); err != nil {
   125  		return nil, err
   126  	}
   127  
   128  	if err = c.dial(); err != nil {
   129  		return nil, err
   130  	}
   131  
   132  	return c, nil
   133  }
   134  
// Configuration holds all Vitess driver settings.
//
// Fields with documented default values do not have to be set explicitly.
//
// Fields tagged `json:"-"` (GRPCDialOptions and DriverName) are not included
// in the JSON string produced by toJSON, so they never reach drv.Open; they
// are only consumed by OpenWithConfiguration.
type Configuration struct {
	// Protocol is the name of the vtgate RPC client implementation.
	// Note: In open-source "grpc" is the recommended implementation.
	//
	// Default: "grpc"
	Protocol string

	// Address must point to a vtgate instance.
	//
	// Format: hostname:port
	Address string

	// Target specifies the default target. It is used to create the
	// connection's session when SessionToken is empty.
	Target string

	// Streaming is true when streaming RPCs are used.
	// Recommended for large results.
	// Exec/ExecContext and Ping return an error on streaming connections
	// (and Begin, Commit and Rollback are issued through Exec).
	// Default: false
	Streaming bool

	// DefaultLocation is the timezone string that will be used
	// when converting DATETIME and DATE into time.Time.
	// NOTE(review): earlier docs said this had no effect unless a
	// "ConvertDatetime" setting was enabled, but no such field exists in this
	// struct — confirm how the converter applies DefaultLocation.
	// Default: UTC
	DefaultLocation string

	// GRPCDialOptions, when non-empty, makes OpenWithConfiguration register a
	// new vtgateconn dialer with these dial options using Protocol as the key.
	// This registration is process-wide and may overwrite the default
	// grpcvtgateconn dial option if a custom one hasn't been specified in the
	// config.
	//
	// Default: none
	GRPCDialOptions []grpc.DialOption `json:"-"`

	// DriverName is the name registered with the database/sql package that
	// OpenWithConfiguration passes to sql.Open. This override is here in case
	// you have wrapped the driver for stats or other interceptors.
	//
	// Default: "vitess"
	DriverName string `json:"-"`

	// SessionToken is a protobuf encoded vtgatepb.Session represented as
	// standard base64, which can be used to distribute a transaction over the
	// wire. When set, the connection resumes that session instead of creating
	// a new one for Target, Begin does not issue "begin", and Commit/Rollback
	// are refused.
	SessionToken string
}
   181  
   182  // toJSON converts Configuration to the JSON string which is required by the
   183  // Vitess driver. Default values for empty fields will be set.
   184  func (c Configuration) toJSON() (string, error) {
   185  	jsonBytes, err := json.Marshal(c)
   186  	if err != nil {
   187  		return "", err
   188  	}
   189  	return string(jsonBytes), nil
   190  }
   191  
   192  // setDefaults sets the default values for empty fields.
   193  func (c *Configuration) setDefaults() {
   194  	// if no protocol is provided default to grpc so the driver is in control
   195  	// of the connection protocol and not the flag vtgateconn.VtgateProtocol
   196  	if c.Protocol == "" {
   197  		c.Protocol = "grpc"
   198  	}
   199  
   200  	if c.DriverName == "" {
   201  		c.DriverName = "vitess"
   202  	}
   203  }
   204  
   205  type conn struct {
   206  	Configuration
   207  	convert *converter
   208  	conn    *vtgateconn.VTGateConn
   209  	session *vtgateconn.VTGateSession
   210  }
   211  
   212  func (c *conn) dial() error {
   213  	var err error
   214  	c.conn, err = vtgateconn.DialProtocol(context.Background(), c.Protocol, c.Address)
   215  	if err != nil {
   216  		return err
   217  	}
   218  	if c.Configuration.SessionToken != "" {
   219  		sessionFromToken, err := sessionTokenToSession(c.Configuration.SessionToken)
   220  		if err != nil {
   221  			return err
   222  		}
   223  		c.session = c.conn.SessionFromPb(sessionFromToken)
   224  	} else {
   225  		c.session = c.conn.Session(c.Target, nil)
   226  	}
   227  	return nil
   228  }
   229  
   230  func (c *conn) Ping(ctx context.Context) error {
   231  	if c.Streaming {
   232  		return errors.New("Ping not allowed for streaming connections")
   233  	}
   234  
   235  	_, err := c.ExecContext(ctx, "select 1", nil)
   236  	return err
   237  }
   238  
   239  func (c *conn) Prepare(query string) (driver.Stmt, error) {
   240  	return &stmt{c: c, query: query}, nil
   241  }
   242  
   243  func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
   244  	return c.Prepare(query)
   245  }
   246  
   247  func (c *conn) Close() error {
   248  	c.conn.Close()
   249  	return nil
   250  }
   251  
   252  // DistributedTxFromSessionToken allows users to send serialized sessions over the wire and
   253  // reconnect to an existing transaction. Setting the sessionToken and address on the
   254  // supplied configuration is the minimum required
   255  // WARNING: the original Tx must already have already done work on all shards to be affected,
   256  // otherwise the ShardSessions will not be sent through in the session token, and thus will
   257  // never be committed in the source. The returned validation function checks to make sure that
   258  // the new transaction work has not added any new ShardSessions.
   259  func DistributedTxFromSessionToken(ctx context.Context, c Configuration) (*sql.Tx, func() error, error) {
   260  	if c.SessionToken == "" {
   261  		return nil, nil, errors.New("c.SessionToken is required")
   262  	}
   263  
   264  	session, err := sessionTokenToSession(c.SessionToken)
   265  	if err != nil {
   266  		return nil, nil, err
   267  	}
   268  
   269  	// if there isn't 1 or more shards already referenced, no work in this Tx can be committed
   270  	originalShardSessionCount := len(session.ShardSessions)
   271  	if originalShardSessionCount == 0 {
   272  		return nil, nil, errors.New("there must be at least 1 ShardSession")
   273  	}
   274  
   275  	db, err := OpenWithConfiguration(c)
   276  	if err != nil {
   277  		return nil, nil, err
   278  	}
   279  
   280  	// this should return the only connection associated with the db
   281  	tx, err := db.BeginTx(ctx, nil)
   282  	if err != nil {
   283  		return nil, nil, err
   284  	}
   285  
   286  	// this is designed to be run after all new work has been done in the tx, similar to
   287  	// where you would traditionally run a tx.Commit, to help prevent you from silently
   288  	// losing transactional data.
   289  	validationFunc := func() error {
   290  		var sessionToken string
   291  		sessionToken, err = SessionTokenFromTx(ctx, tx)
   292  		if err != nil {
   293  			return err
   294  		}
   295  
   296  		session, err = sessionTokenToSession(sessionToken)
   297  		if err != nil {
   298  			return err
   299  		}
   300  
   301  		if len(session.ShardSessions) > originalShardSessionCount {
   302  			return fmt.Errorf("mismatched ShardSession count: originally %d, now %d",
   303  				originalShardSessionCount, len(session.ShardSessions),
   304  			)
   305  		}
   306  
   307  		return nil
   308  	}
   309  
   310  	return tx, validationFunc, nil
   311  }
   312  
   313  // SessionTokenFromTx serializes the sessionFromToken on the tx, which can be reconstituted
   314  // into a *sql.Tx using DistributedTxFromSessionToken
   315  func SessionTokenFromTx(ctx context.Context, tx *sql.Tx) (string, error) {
   316  	var sessionToken string
   317  
   318  	err := tx.QueryRowContext(ctx, "vt_session_token").Scan(&sessionToken)
   319  	if err != nil {
   320  		return "", err
   321  	}
   322  
   323  	session, err := sessionTokenToSession(sessionToken)
   324  	if err != nil {
   325  		return "", err
   326  	}
   327  
   328  	// if there isn't 1 or more shards already referenced, no work in this Tx can be committed
   329  	originalShardSessionCount := len(session.ShardSessions)
   330  	if originalShardSessionCount == 0 {
   331  		return "", errors.New("there must be at least 1 ShardSession")
   332  	}
   333  
   334  	return sessionToken, nil
   335  }
   336  
   337  func newSessionTokenRow(session *vtgatepb.Session, c *converter) (driver.Rows, error) {
   338  	sessionToken, err := sessionToSessionToken(session)
   339  	if err != nil {
   340  		return nil, err
   341  	}
   342  
   343  	qr := sqltypes.Result{
   344  		Fields: []*querypb.Field{{
   345  			Name: "vt_session_token",
   346  			Type: sqltypes.VarBinary,
   347  		}},
   348  		Rows: [][]sqltypes.Value{{
   349  			sqltypes.NewVarBinary(sessionToken),
   350  		}},
   351  	}
   352  
   353  	return newRows(&qr, c), nil
   354  }
   355  
   356  func sessionToSessionToken(session *vtgatepb.Session) (string, error) {
   357  	b, err := proto.Marshal(session)
   358  	if err != nil {
   359  		return "", err
   360  	}
   361  
   362  	return base64.StdEncoding.EncodeToString(b), nil
   363  }
   364  
   365  func sessionTokenToSession(sessionToken string) (*vtgatepb.Session, error) {
   366  	b, err := base64.StdEncoding.DecodeString(sessionToken)
   367  	if err != nil {
   368  		return nil, err
   369  	}
   370  
   371  	session := &vtgatepb.Session{}
   372  	err = proto.Unmarshal(b, session)
   373  	if err != nil {
   374  		return nil, err
   375  	}
   376  
   377  	return session, nil
   378  }
   379  
   380  func (c *conn) Begin() (driver.Tx, error) {
   381  	// if we're loading from an existing session, we need to avoid starting a new transaction
   382  	if c.Configuration.SessionToken != "" {
   383  		return c, nil
   384  	}
   385  
   386  	if _, err := c.Exec("begin", nil); err != nil {
   387  		return nil, err
   388  	}
   389  	return c, nil
   390  }
   391  
   392  func (c *conn) BeginTx(_ context.Context, opts driver.TxOptions) (driver.Tx, error) {
   393  	// We don't use the context. The function signature accepts the context
   394  	// to signal to the driver that it's allowed to call Rollback on Cancel.
   395  	if opts.Isolation != driver.IsolationLevel(0) || opts.ReadOnly {
   396  		return nil, errIsolationUnsupported
   397  	}
   398  	return c.Begin()
   399  }
   400  
   401  func (c *conn) Commit() error {
   402  	// if we're loading from an existing session, disallow committing/rolling back the transaction
   403  	// this isn't a technical limitation, but is enforced to prevent misuse, so that only
   404  	// the original creator of the transaction can commit/rollback
   405  	if c.Configuration.SessionToken != "" {
   406  		return errors.New("calling Commit from a distributed tx is not allowed")
   407  	}
   408  
   409  	_, err := c.Exec("commit", nil)
   410  	return err
   411  }
   412  
   413  func (c *conn) Rollback() error {
   414  	// if we're loading from an existing session, disallow committing/rolling back the transaction
   415  	// this isn't a technical limitation, but is enforced to prevent misuse, so that only
   416  	// the original creator of the transaction can commit/rollback
   417  	if c.Configuration.SessionToken != "" {
   418  		return errors.New("calling Rollback from a distributed tx is not allowed")
   419  	}
   420  
   421  	_, err := c.Exec("rollback", nil)
   422  	return err
   423  }
   424  
   425  func (c *conn) Exec(query string, args []driver.Value) (driver.Result, error) {
   426  	ctx := context.TODO()
   427  
   428  	if c.Streaming {
   429  		return nil, errors.New("Exec not allowed for streaming connections")
   430  	}
   431  	bindVars, err := c.convert.buildBindVars(args)
   432  	if err != nil {
   433  		return nil, err
   434  	}
   435  
   436  	qr, err := c.session.Execute(ctx, query, bindVars)
   437  	if err != nil {
   438  		return nil, err
   439  	}
   440  	return result{int64(qr.InsertID), int64(qr.RowsAffected)}, nil
   441  }
   442  
   443  func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
   444  	if c.Streaming {
   445  		return nil, errors.New("Exec not allowed for streaming connections")
   446  	}
   447  
   448  	bv, err := c.convert.bindVarsFromNamedValues(args)
   449  	if err != nil {
   450  		return nil, err
   451  	}
   452  	qr, err := c.session.Execute(ctx, query, bv)
   453  	if err != nil {
   454  		return nil, err
   455  	}
   456  	return result{int64(qr.InsertID), int64(qr.RowsAffected)}, nil
   457  }
   458  
   459  func (c *conn) Query(query string, args []driver.Value) (driver.Rows, error) {
   460  	ctx := context.TODO()
   461  	bindVars, err := c.convert.buildBindVars(args)
   462  	if err != nil {
   463  		return nil, err
   464  	}
   465  
   466  	if c.Streaming {
   467  		stream, err := c.session.StreamExecute(ctx, query, bindVars)
   468  		if err != nil {
   469  			return nil, err
   470  		}
   471  		return newStreamingRows(stream, c.convert), nil
   472  	}
   473  
   474  	qr, err := c.session.Execute(ctx, query, bindVars)
   475  	if err != nil {
   476  		return nil, err
   477  	}
   478  	return newRows(qr, c.convert), nil
   479  }
   480  
   481  func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
   482  	// special case for serializing the current sessionFromToken state
   483  	if query == "vt_session_token" {
   484  		return newSessionTokenRow(c.session.SessionPb(), c.convert)
   485  	}
   486  
   487  	bv, err := c.convert.bindVarsFromNamedValues(args)
   488  	if err != nil {
   489  		return nil, err
   490  	}
   491  
   492  	if c.Streaming {
   493  		stream, err := c.session.StreamExecute(ctx, query, bv)
   494  		if err != nil {
   495  			return nil, err
   496  		}
   497  		return newStreamingRows(stream, c.convert), nil
   498  	}
   499  
   500  	qr, err := c.session.Execute(ctx, query, bv)
   501  	if err != nil {
   502  		return nil, err
   503  	}
   504  	return newRows(qr, c.convert), nil
   505  }
   506  
   507  type stmt struct {
   508  	c     *conn
   509  	query string
   510  }
   511  
   512  func (s *stmt) Close() error {
   513  	return nil
   514  }
   515  
   516  func (s *stmt) NumInput() int {
   517  	// -1 = Golang sql won't sanity check argument counts before Exec or Query.
   518  	return -1
   519  }
   520  
   521  func (s *stmt) Exec(args []driver.Value) (driver.Result, error) {
   522  	return s.c.Exec(s.query, args)
   523  }
   524  
   525  func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
   526  	return s.c.ExecContext(ctx, s.query, args)
   527  }
   528  
   529  func (s *stmt) Query(args []driver.Value) (driver.Rows, error) {
   530  	return s.c.Query(s.query, args)
   531  }
   532  
   533  func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
   534  	return s.c.QueryContext(ctx, s.query, args)
   535  }
   536  
   537  type result struct {
   538  	insertid, rowsaffected int64
   539  }
   540  
   541  func (r result) LastInsertId() (int64, error) {
   542  	return r.insertid, nil
   543  }
   544  
   545  func (r result) RowsAffected() (int64, error) {
   546  	return r.rowsaffected, nil
   547  }