vitess.io/vitess@v0.16.2/go/vt/vtgate/plugin_mysql_server.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package vtgate
    18  
    19  import (
    20  	"context"
    21  	"fmt"
    22  	"net"
    23  	"os"
    24  	"os/signal"
    25  	"regexp"
    26  	"strings"
    27  	"sync"
    28  	"sync/atomic"
    29  	"syscall"
    30  	"time"
    31  
    32  	"github.com/spf13/pflag"
    33  
    34  	"vitess.io/vitess/go/vt/sqlparser"
    35  	"vitess.io/vitess/go/vt/vterrors"
    36  
    37  	"vitess.io/vitess/go/mysql"
    38  	"vitess.io/vitess/go/sqltypes"
    39  	"vitess.io/vitess/go/trace"
    40  	"vitess.io/vitess/go/vt/callerid"
    41  	"vitess.io/vitess/go/vt/callinfo"
    42  	"vitess.io/vitess/go/vt/log"
    43  	"vitess.io/vitess/go/vt/servenv"
    44  	"vitess.io/vitess/go/vt/vttls"
    45  
    46  	"github.com/google/uuid"
    47  
    48  	querypb "vitess.io/vitess/go/vt/proto/query"
    49  	vtgatepb "vitess.io/vitess/go/vt/proto/vtgate"
    50  )
    51  
    52  var (
    53  	mysqlServerPort                   = -1
    54  	mysqlServerBindAddress            string
    55  	mysqlServerSocketPath             string
    56  	mysqlTCPVersion                   = "tcp"
    57  	mysqlAuthServerImpl               = "static"
    58  	mysqlAllowClearTextWithoutTLS     bool
    59  	mysqlProxyProtocol                bool
    60  	mysqlServerRequireSecureTransport bool
    61  	mysqlSslCert                      string
    62  	mysqlSslKey                       string
    63  	mysqlSslCa                        string
    64  	mysqlSslCrl                       string
    65  	mysqlSslServerCA                  string
    66  	mysqlTLSMinVersion                string
    67  
    68  	mysqlConnReadTimeout          time.Duration
    69  	mysqlConnWriteTimeout         time.Duration
    70  	mysqlQueryTimeout             time.Duration
    71  	mysqlSlowConnectWarnThreshold time.Duration
    72  	mysqlConnBufferPooling        bool
    73  
    74  	mysqlDefaultWorkloadName = "OLTP"
    75  	mysqlDefaultWorkload     int32
    76  
    77  	busyConnections int32
    78  )
    79  
    80  func registerPluginFlags(fs *pflag.FlagSet) {
    81  	fs.IntVar(&mysqlServerPort, "mysql_server_port", mysqlServerPort, "If set, also listen for MySQL binary protocol connections on this port.")
    82  	fs.StringVar(&mysqlServerBindAddress, "mysql_server_bind_address", mysqlServerBindAddress, "Binds on this address when listening to MySQL binary protocol. Useful to restrict listening to 'localhost' only for instance.")
    83  	fs.StringVar(&mysqlServerSocketPath, "mysql_server_socket_path", mysqlServerSocketPath, "This option specifies the Unix socket file to use when listening for local connections. By default it will be empty and it won't listen to a unix socket")
    84  	fs.StringVar(&mysqlTCPVersion, "mysql_tcp_version", mysqlTCPVersion, "Select tcp, tcp4, or tcp6 to control the socket type.")
    85  	fs.StringVar(&mysqlAuthServerImpl, "mysql_auth_server_impl", mysqlAuthServerImpl, "Which auth server implementation to use. Options: none, ldap, clientcert, static, vault.")
    86  	fs.BoolVar(&mysqlAllowClearTextWithoutTLS, "mysql_allow_clear_text_without_tls", mysqlAllowClearTextWithoutTLS, "If set, the server will allow the use of a clear text password over non-SSL connections.")
    87  	fs.BoolVar(&mysqlProxyProtocol, "proxy_protocol", mysqlProxyProtocol, "Enable HAProxy PROXY protocol on MySQL listener socket")
    88  	fs.BoolVar(&mysqlServerRequireSecureTransport, "mysql_server_require_secure_transport", mysqlServerRequireSecureTransport, "Reject insecure connections but only if mysql_server_ssl_cert and mysql_server_ssl_key are provided")
    89  	fs.StringVar(&mysqlSslCert, "mysql_server_ssl_cert", mysqlSslCert, "Path to the ssl cert for mysql server plugin SSL")
    90  	fs.StringVar(&mysqlSslKey, "mysql_server_ssl_key", mysqlSslKey, "Path to ssl key for mysql server plugin SSL")
    91  	fs.StringVar(&mysqlSslCa, "mysql_server_ssl_ca", mysqlSslCa, "Path to ssl CA for mysql server plugin SSL. If specified, server will require and validate client certs.")
    92  	fs.StringVar(&mysqlSslCrl, "mysql_server_ssl_crl", mysqlSslCrl, "Path to ssl CRL for mysql server plugin SSL")
    93  	fs.StringVar(&mysqlTLSMinVersion, "mysql_server_tls_min_version", mysqlTLSMinVersion, "Configures the minimal TLS version negotiated when SSL is enabled. Defaults to TLSv1.2. Options: TLSv1.0, TLSv1.1, TLSv1.2, TLSv1.3.")
    94  	fs.StringVar(&mysqlSslServerCA, "mysql_server_ssl_server_ca", mysqlSslServerCA, "path to server CA in PEM format, which will be combine with server cert, return full certificate chain to clients")
    95  	fs.DurationVar(&mysqlSlowConnectWarnThreshold, "mysql_slow_connect_warn_threshold", mysqlSlowConnectWarnThreshold, "Warn if it takes more than the given threshold for a mysql connection to establish")
    96  	fs.DurationVar(&mysqlConnReadTimeout, "mysql_server_read_timeout", mysqlConnReadTimeout, "connection read timeout")
    97  	fs.DurationVar(&mysqlConnWriteTimeout, "mysql_server_write_timeout", mysqlConnWriteTimeout, "connection write timeout")
    98  	fs.DurationVar(&mysqlQueryTimeout, "mysql_server_query_timeout", mysqlQueryTimeout, "mysql query timeout")
    99  	fs.BoolVar(&mysqlConnBufferPooling, "mysql-server-pool-conn-read-buffers", mysqlConnBufferPooling, "If set, the server will pool incoming connection read buffers")
   100  	fs.StringVar(&mysqlDefaultWorkloadName, "mysql_default_workload", mysqlDefaultWorkloadName, "Default session workload (OLTP, OLAP, DBA)")
   101  }
   102  
   103  // vtgateHandler implements the Listener interface.
   104  // It stores the Session in the ClientData of a Connection.
   105  type vtgateHandler struct {
   106  	mysql.UnimplementedHandler
   107  	mu sync.Mutex
   108  
   109  	vtg         *VTGate
   110  	connections map[*mysql.Conn]bool
   111  }
   112  
   113  func newVtgateHandler(vtg *VTGate) *vtgateHandler {
   114  	return &vtgateHandler{
   115  		vtg:         vtg,
   116  		connections: make(map[*mysql.Conn]bool),
   117  	}
   118  }
   119  
   120  func (vh *vtgateHandler) NewConnection(c *mysql.Conn) {
   121  	vh.mu.Lock()
   122  	defer vh.mu.Unlock()
   123  	vh.connections[c] = true
   124  }
   125  
   126  func (vh *vtgateHandler) numConnections() int {
   127  	vh.mu.Lock()
   128  	defer vh.mu.Unlock()
   129  	return len(vh.connections)
   130  }
   131  
   132  func (vh *vtgateHandler) ComResetConnection(c *mysql.Conn) {
   133  	ctx := context.Background()
   134  	session := vh.session(c)
   135  	if session.InTransaction {
   136  		defer atomic.AddInt32(&busyConnections, -1)
   137  	}
   138  	err := vh.vtg.CloseSession(ctx, session)
   139  	if err != nil {
   140  		log.Errorf("Error happened in transaction rollback: %v", err)
   141  	}
   142  }
   143  
   144  func (vh *vtgateHandler) ConnectionClosed(c *mysql.Conn) {
   145  	// Rollback if there is an ongoing transaction. Ignore error.
   146  	defer func() {
   147  		vh.mu.Lock()
   148  		defer vh.mu.Unlock()
   149  		delete(vh.connections, c)
   150  	}()
   151  
   152  	var ctx context.Context
   153  	var cancel context.CancelFunc
   154  	if mysqlQueryTimeout != 0 {
   155  		ctx, cancel = context.WithTimeout(context.Background(), mysqlQueryTimeout)
   156  		defer cancel()
   157  	} else {
   158  		ctx = context.Background()
   159  	}
   160  	session := vh.session(c)
   161  	if session.InTransaction {
   162  		defer atomic.AddInt32(&busyConnections, -1)
   163  	}
   164  	_ = vh.vtg.CloseSession(ctx, session)
   165  }
   166  
   167  // Regexp to extract parent span id over the sql query
   168  var r = regexp.MustCompile(`/\*VT_SPAN_CONTEXT=(.*)\*/`)
   169  
   170  // this function is here to make this logic easy to test by decoupling the logic from the `trace.NewSpan` and `trace.NewFromString` functions
   171  func startSpanTestable(ctx context.Context, query, label string,
   172  	newSpan func(context.Context, string) (trace.Span, context.Context),
   173  	newSpanFromString func(context.Context, string, string) (trace.Span, context.Context, error)) (trace.Span, context.Context, error) {
   174  	_, comments := sqlparser.SplitMarginComments(query)
   175  	match := r.FindStringSubmatch(comments.Leading)
   176  	span, ctx := getSpan(ctx, match, newSpan, label, newSpanFromString)
   177  
   178  	trace.AnnotateSQL(span, sqlparser.Preview(query))
   179  
   180  	return span, ctx, nil
   181  }
   182  
   183  func getSpan(ctx context.Context, match []string, newSpan func(context.Context, string) (trace.Span, context.Context), label string, newSpanFromString func(context.Context, string, string) (trace.Span, context.Context, error)) (trace.Span, context.Context) {
   184  	var span trace.Span
   185  	if len(match) != 0 {
   186  		var err error
   187  		span, ctx, err = newSpanFromString(ctx, match[1], label)
   188  		if err == nil {
   189  			return span, ctx
   190  		}
   191  		log.Warningf("Unable to parse VT_SPAN_CONTEXT: %s", err.Error())
   192  	}
   193  	span, ctx = newSpan(ctx, label)
   194  	return span, ctx
   195  }
   196  
   197  func startSpan(ctx context.Context, query, label string) (trace.Span, context.Context, error) {
   198  	return startSpanTestable(ctx, query, label, trace.NewSpan, trace.NewFromString)
   199  }
   200  
   201  func (vh *vtgateHandler) ComQuery(c *mysql.Conn, query string, callback func(*sqltypes.Result) error) error {
   202  	ctx := context.Background()
   203  	var cancel context.CancelFunc
   204  	if mysqlQueryTimeout != 0 {
   205  		ctx, cancel = context.WithTimeout(ctx, mysqlQueryTimeout)
   206  		defer cancel()
   207  	}
   208  
   209  	span, ctx, err := startSpan(ctx, query, "vtgateHandler.ComQuery")
   210  	if err != nil {
   211  		return vterrors.Wrap(err, "failed to extract span")
   212  	}
   213  	defer span.Finish()
   214  
   215  	ctx = callinfo.MysqlCallInfo(ctx, c)
   216  
   217  	// Fill in the ImmediateCallerID with the UserData returned by
   218  	// the AuthServer plugin for that user. If nothing was
   219  	// returned, use the User. This lets the plugin map a MySQL
   220  	// user used for authentication to a Vitess User used for
   221  	// Table ACLs and Vitess authentication in general.
   222  	im := c.UserData.Get()
   223  	ef := callerid.NewEffectiveCallerID(
   224  		c.User,                  /* principal: who */
   225  		c.RemoteAddr().String(), /* component: running client process */
   226  		"VTGate MySQL Connector" /* subcomponent: part of the client */)
   227  	ctx = callerid.NewContext(ctx, ef, im)
   228  
   229  	session := vh.session(c)
   230  	if !session.InTransaction {
   231  		atomic.AddInt32(&busyConnections, 1)
   232  	}
   233  	defer func() {
   234  		if !session.InTransaction {
   235  			atomic.AddInt32(&busyConnections, -1)
   236  		}
   237  	}()
   238  
   239  	if session.Options.Workload == querypb.ExecuteOptions_OLAP {
   240  		err := vh.vtg.StreamExecute(ctx, session, query, make(map[string]*querypb.BindVariable), callback)
   241  		return mysql.NewSQLErrorFromError(err)
   242  	}
   243  	session, result, err := vh.vtg.Execute(ctx, session, query, make(map[string]*querypb.BindVariable))
   244  
   245  	if err := mysql.NewSQLErrorFromError(err); err != nil {
   246  		return err
   247  	}
   248  	fillInTxStatusFlags(c, session)
   249  	return callback(result)
   250  }
   251  
   252  func fillInTxStatusFlags(c *mysql.Conn, session *vtgatepb.Session) {
   253  	if session.InTransaction {
   254  		c.StatusFlags |= mysql.ServerStatusInTrans
   255  	} else {
   256  		c.StatusFlags &= mysql.NoServerStatusInTrans
   257  	}
   258  	if session.Autocommit {
   259  		c.StatusFlags |= mysql.ServerStatusAutocommit
   260  	} else {
   261  		c.StatusFlags &= mysql.NoServerStatusAutocommit
   262  	}
   263  }
   264  
   265  // ComPrepare is the handler for command prepare.
   266  func (vh *vtgateHandler) ComPrepare(c *mysql.Conn, query string, bindVars map[string]*querypb.BindVariable) ([]*querypb.Field, error) {
   267  	var ctx context.Context
   268  	var cancel context.CancelFunc
   269  	if mysqlQueryTimeout != 0 {
   270  		ctx, cancel = context.WithTimeout(context.Background(), mysqlQueryTimeout)
   271  		defer cancel()
   272  	} else {
   273  		ctx = context.Background()
   274  	}
   275  
   276  	ctx = callinfo.MysqlCallInfo(ctx, c)
   277  
   278  	// Fill in the ImmediateCallerID with the UserData returned by
   279  	// the AuthServer plugin for that user. If nothing was
   280  	// returned, use the User. This lets the plugin map a MySQL
   281  	// user used for authentication to a Vitess User used for
   282  	// Table ACLs and Vitess authentication in general.
   283  	im := c.UserData.Get()
   284  	ef := callerid.NewEffectiveCallerID(
   285  		c.User,                  /* principal: who */
   286  		c.RemoteAddr().String(), /* component: running client process */
   287  		"VTGate MySQL Connector" /* subcomponent: part of the client */)
   288  	ctx = callerid.NewContext(ctx, ef, im)
   289  
   290  	session := vh.session(c)
   291  	if !session.InTransaction {
   292  		atomic.AddInt32(&busyConnections, 1)
   293  	}
   294  	defer func() {
   295  		if !session.InTransaction {
   296  			atomic.AddInt32(&busyConnections, -1)
   297  		}
   298  	}()
   299  
   300  	session, fld, err := vh.vtg.Prepare(ctx, session, query, bindVars)
   301  	err = mysql.NewSQLErrorFromError(err)
   302  	if err != nil {
   303  		return nil, err
   304  	}
   305  	return fld, nil
   306  }
   307  
   308  func (vh *vtgateHandler) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error {
   309  	var ctx context.Context
   310  	var cancel context.CancelFunc
   311  	if mysqlQueryTimeout != 0 {
   312  		ctx, cancel = context.WithTimeout(context.Background(), mysqlQueryTimeout)
   313  		defer cancel()
   314  	} else {
   315  		ctx = context.Background()
   316  	}
   317  
   318  	ctx = callinfo.MysqlCallInfo(ctx, c)
   319  
   320  	// Fill in the ImmediateCallerID with the UserData returned by
   321  	// the AuthServer plugin for that user. If nothing was
   322  	// returned, use the User. This lets the plugin map a MySQL
   323  	// user used for authentication to a Vitess User used for
   324  	// Table ACLs and Vitess authentication in general.
   325  	im := c.UserData.Get()
   326  	ef := callerid.NewEffectiveCallerID(
   327  		c.User,                  /* principal: who */
   328  		c.RemoteAddr().String(), /* component: running client process */
   329  		"VTGate MySQL Connector" /* subcomponent: part of the client */)
   330  	ctx = callerid.NewContext(ctx, ef, im)
   331  
   332  	session := vh.session(c)
   333  	if !session.InTransaction {
   334  		atomic.AddInt32(&busyConnections, 1)
   335  	}
   336  	defer func() {
   337  		if !session.InTransaction {
   338  			atomic.AddInt32(&busyConnections, -1)
   339  		}
   340  	}()
   341  
   342  	if session.Options.Workload == querypb.ExecuteOptions_OLAP {
   343  		err := vh.vtg.StreamExecute(ctx, session, prepare.PrepareStmt, prepare.BindVars, callback)
   344  		return mysql.NewSQLErrorFromError(err)
   345  	}
   346  	_, qr, err := vh.vtg.Execute(ctx, session, prepare.PrepareStmt, prepare.BindVars)
   347  	if err != nil {
   348  		err = mysql.NewSQLErrorFromError(err)
   349  		return err
   350  	}
   351  	fillInTxStatusFlags(c, session)
   352  
   353  	return callback(qr)
   354  }
   355  
   356  func (vh *vtgateHandler) WarningCount(c *mysql.Conn) uint16 {
   357  	return uint16(len(vh.session(c).GetWarnings()))
   358  }
   359  
   360  // ComRegisterReplica is part of the mysql.Handler interface.
   361  func (vh *vtgateHandler) ComRegisterReplica(c *mysql.Conn, replicaHost string, replicaPort uint16, replicaUser string, replicaPassword string) error {
   362  	return vterrors.VT12001("ComRegisterReplica for the VTGate handler")
   363  }
   364  
   365  // ComBinlogDump is part of the mysql.Handler interface.
   366  func (vh *vtgateHandler) ComBinlogDump(c *mysql.Conn, logFile string, binlogPos uint32) error {
   367  	return vterrors.VT12001("ComBinlogDump for the VTGate handler")
   368  }
   369  
   370  // ComBinlogDumpGTID is part of the mysql.Handler interface.
   371  func (vh *vtgateHandler) ComBinlogDumpGTID(c *mysql.Conn, logFile string, logPos uint64, gtidSet mysql.GTIDSet) error {
   372  	return vterrors.VT12001("ComBinlogDumpGTID for the VTGate handler")
   373  }
   374  
   375  func (vh *vtgateHandler) session(c *mysql.Conn) *vtgatepb.Session {
   376  	session, _ := c.ClientData.(*vtgatepb.Session)
   377  	if session == nil {
   378  		u, _ := uuid.NewUUID()
   379  		session = &vtgatepb.Session{
   380  			Options: &querypb.ExecuteOptions{
   381  				IncludedFields: querypb.ExecuteOptions_ALL,
   382  				Workload:       querypb.ExecuteOptions_Workload(mysqlDefaultWorkload),
   383  
   384  				// The collation field of ExecuteOption is set right before an execution.
   385  			},
   386  			Autocommit:           true,
   387  			DDLStrategy:          defaultDDLStrategy,
   388  			SessionUUID:          u.String(),
   389  			EnableSystemSettings: sysVarSetEnabled,
   390  		}
   391  		if c.Capabilities&mysql.CapabilityClientFoundRows != 0 {
   392  			session.Options.ClientFoundRows = true
   393  		}
   394  		c.ClientData = session
   395  	}
   396  	return session
   397  }
   398  
   399  var mysqlListener *mysql.Listener
   400  var mysqlUnixListener *mysql.Listener
   401  var sigChan chan os.Signal
   402  var vtgateHandle *vtgateHandler
   403  
   404  // initTLSConfig inits tls config for the given mysql listener
   405  func initTLSConfig(mysqlListener *mysql.Listener, mysqlSslCert, mysqlSslKey, mysqlSslCa, mysqlSslCrl, mysqlSslServerCA string, mysqlServerRequireSecureTransport bool, mysqlMinTLSVersion uint16) error {
   406  	serverConfig, err := vttls.ServerConfig(mysqlSslCert, mysqlSslKey, mysqlSslCa, mysqlSslCrl, mysqlSslServerCA, mysqlMinTLSVersion)
   407  	if err != nil {
   408  		log.Exitf("grpcutils.TLSServerConfig failed: %v", err)
   409  		return err
   410  	}
   411  	mysqlListener.TLSConfig.Store(serverConfig)
   412  	mysqlListener.RequireSecureTransport = mysqlServerRequireSecureTransport
   413  	sigChan = make(chan os.Signal, 1)
   414  	signal.Notify(sigChan, syscall.SIGHUP)
   415  	go func() {
   416  		for range sigChan {
   417  			serverConfig, err := vttls.ServerConfig(mysqlSslCert, mysqlSslKey, mysqlSslCa, mysqlSslCrl, mysqlSslServerCA, mysqlMinTLSVersion)
   418  			if err != nil {
   419  				log.Errorf("grpcutils.TLSServerConfig failed: %v", err)
   420  			} else {
   421  				log.Info("grpcutils.TLSServerConfig updated")
   422  				mysqlListener.TLSConfig.Store(serverConfig)
   423  			}
   424  		}
   425  	}()
   426  	return nil
   427  }
   428  
   429  // initiMySQLProtocol starts the mysql protocol.
   430  // It should be called only once in a process.
   431  func initMySQLProtocol() {
   432  	// Flag is not set, just return.
   433  	if mysqlServerPort < 0 && mysqlServerSocketPath == "" {
   434  		return
   435  	}
   436  
   437  	// If no VTGate was created, just return.
   438  	if rpcVTGate == nil {
   439  		return
   440  	}
   441  
   442  	// Initialize registered AuthServer implementations (or other plugins)
   443  	for _, initFn := range pluginInitializers {
   444  		initFn()
   445  	}
   446  	authServer := mysql.GetAuthServer(mysqlAuthServerImpl)
   447  
   448  	// Check mysql_default_workload
   449  	var ok bool
   450  	if mysqlDefaultWorkload, ok = querypb.ExecuteOptions_Workload_value[strings.ToUpper(mysqlDefaultWorkloadName)]; !ok {
   451  		log.Exitf("-mysql_default_workload must be one of [OLTP, OLAP, DBA, UNSPECIFIED]")
   452  	}
   453  
   454  	switch mysqlTCPVersion {
   455  	case "tcp", "tcp4", "tcp6":
   456  		// Valid flag value.
   457  	default:
   458  		log.Exitf("-mysql_tcp_version must be one of [tcp, tcp4, tcp6]")
   459  	}
   460  
   461  	// Create a Listener.
   462  	var err error
   463  	vtgateHandle = newVtgateHandler(rpcVTGate)
   464  	if mysqlServerPort >= 0 {
   465  		mysqlListener, err = mysql.NewListener(
   466  			mysqlTCPVersion,
   467  			net.JoinHostPort(mysqlServerBindAddress, fmt.Sprintf("%v", mysqlServerPort)),
   468  			authServer,
   469  			vtgateHandle,
   470  			mysqlConnReadTimeout,
   471  			mysqlConnWriteTimeout,
   472  			mysqlProxyProtocol,
   473  			mysqlConnBufferPooling,
   474  		)
   475  		if err != nil {
   476  			log.Exitf("mysql.NewListener failed: %v", err)
   477  		}
   478  		mysqlListener.ServerVersion = servenv.MySQLServerVersion()
   479  		if mysqlSslCert != "" && mysqlSslKey != "" {
   480  			tlsVersion, err := vttls.TLSVersionToNumber(mysqlTLSMinVersion)
   481  			if err != nil {
   482  				log.Exitf("mysql.NewListener failed: %v", err)
   483  			}
   484  
   485  			_ = initTLSConfig(mysqlListener, mysqlSslCert, mysqlSslKey, mysqlSslCa, mysqlSslCrl, mysqlSslServerCA, mysqlServerRequireSecureTransport, tlsVersion)
   486  		}
   487  		mysqlListener.AllowClearTextWithoutTLS.Set(mysqlAllowClearTextWithoutTLS)
   488  		// Check for the connection threshold
   489  		if mysqlSlowConnectWarnThreshold != 0 {
   490  			log.Infof("setting mysql slow connection threshold to %v", mysqlSlowConnectWarnThreshold)
   491  			mysqlListener.SlowConnectWarnThreshold.Set(mysqlSlowConnectWarnThreshold)
   492  		}
   493  		// Start listening for tcp
   494  		go mysqlListener.Accept()
   495  	}
   496  
   497  	if mysqlServerSocketPath != "" {
   498  		// Let's create this unix socket with permissions to all users. In this way,
   499  		// clients can connect to vtgate mysql server without being vtgate user
   500  		oldMask := syscall.Umask(000)
   501  		mysqlUnixListener, err = newMysqlUnixSocket(mysqlServerSocketPath, authServer, vtgateHandle)
   502  		_ = syscall.Umask(oldMask)
   503  		if err != nil {
   504  			log.Exitf("mysql.NewListener failed: %v", err)
   505  			return
   506  		}
   507  		// Listen for unix socket
   508  		go mysqlUnixListener.Accept()
   509  	}
   510  }
   511  
   512  // newMysqlUnixSocket creates a new unix socket mysql listener. If a socket file already exists, attempts
   513  // to clean it up.
   514  func newMysqlUnixSocket(address string, authServer mysql.AuthServer, handler mysql.Handler) (*mysql.Listener, error) {
   515  	listener, err := mysql.NewListener(
   516  		"unix",
   517  		address,
   518  		authServer,
   519  		handler,
   520  		mysqlConnReadTimeout,
   521  		mysqlConnWriteTimeout,
   522  		false,
   523  		mysqlConnBufferPooling,
   524  	)
   525  
   526  	switch err := err.(type) {
   527  	case nil:
   528  		return listener, nil
   529  	case *net.OpError:
   530  		log.Warningf("Found existent socket when trying to create new unix mysql listener: %s, attempting to clean up", address)
   531  		// err.Op should never be different from listen, just being extra careful
   532  		// in case in the future other errors are returned here
   533  		if err.Op != "listen" {
   534  			return nil, err
   535  		}
   536  		_, dialErr := net.Dial("unix", address)
   537  		if dialErr == nil {
   538  			log.Errorf("Existent socket '%s' is still accepting connections, aborting", address)
   539  			return nil, err
   540  		}
   541  		removeFileErr := os.Remove(address)
   542  		if removeFileErr != nil {
   543  			log.Errorf("Couldn't remove existent socket file: %s", address)
   544  			return nil, err
   545  		}
   546  		listener, listenerErr := mysql.NewListener(
   547  			"unix",
   548  			address,
   549  			authServer,
   550  			handler,
   551  			mysqlConnReadTimeout,
   552  			mysqlConnWriteTimeout,
   553  			false,
   554  			mysqlConnBufferPooling,
   555  		)
   556  		return listener, listenerErr
   557  	default:
   558  		return nil, err
   559  	}
   560  }
   561  
   562  func shutdownMysqlProtocolAndDrain() {
   563  	if mysqlListener != nil {
   564  		mysqlListener.Close()
   565  		mysqlListener = nil
   566  	}
   567  	if mysqlUnixListener != nil {
   568  		mysqlUnixListener.Close()
   569  		mysqlUnixListener = nil
   570  	}
   571  	if sigChan != nil {
   572  		signal.Stop(sigChan)
   573  	}
   574  
   575  	if atomic.LoadInt32(&busyConnections) > 0 {
   576  		log.Infof("Waiting for all client connections to be idle (%d active)...", atomic.LoadInt32(&busyConnections))
   577  		start := time.Now()
   578  		reported := start
   579  		for atomic.LoadInt32(&busyConnections) != 0 {
   580  			if time.Since(reported) > 2*time.Second {
   581  				log.Infof("Still waiting for client connections to be idle (%d active)...", atomic.LoadInt32(&busyConnections))
   582  				reported = time.Now()
   583  			}
   584  
   585  			time.Sleep(1 * time.Millisecond)
   586  		}
   587  	}
   588  }
   589  
   590  func rollbackAtShutdown() {
   591  	defer log.Flush()
   592  	if vtgateHandle == nil {
   593  		// we still haven't been able to initialise the vtgateHandler, so we don't need to rollback anything
   594  		return
   595  	}
   596  
   597  	// Close all open connections. If they're waiting for reads, this will cause
   598  	// them to error out, which will automatically rollback open transactions.
   599  	func() {
   600  		if vtgateHandle != nil {
   601  			vtgateHandle.mu.Lock()
   602  			defer vtgateHandle.mu.Unlock()
   603  			for c := range vtgateHandle.connections {
   604  				if c != nil {
   605  					log.Infof("Rolling back transactions associated with connection ID: %v", c.ConnectionID)
   606  					c.Close()
   607  				}
   608  			}
   609  		}
   610  	}()
   611  
   612  	// If vtgate is instead busy executing a query, the number of open conns
   613  	// will be non-zero. Give another second for those queries to finish.
   614  	for i := 0; i < 100; i++ {
   615  		if vtgateHandle.numConnections() == 0 {
   616  			log.Infof("All connections have been rolled back.")
   617  			return
   618  		}
   619  		time.Sleep(10 * time.Millisecond)
   620  	}
   621  	log.Errorf("All connections did not go idle. Shutting down anyway.")
   622  }
   623  
   624  func mysqlSocketPath() string {
   625  	if mysqlServerSocketPath == "" {
   626  		return ""
   627  	}
   628  	return mysqlServerSocketPath
   629  }
   630  
   631  func init() {
   632  	servenv.OnParseFor("vtgate", registerPluginFlags)
   633  	servenv.OnParseFor("vtcombo", registerPluginFlags)
   634  
   635  	servenv.OnRun(initMySQLProtocol)
   636  	servenv.OnTermSync(shutdownMysqlProtocolAndDrain)
   637  	servenv.OnClose(rollbackAtShutdown)
   638  }
   639  
   640  var pluginInitializers []func()
   641  
   642  // RegisterPluginInitializer lets plugins register themselves to be init'ed at servenv.OnRun-time
   643  func RegisterPluginInitializer(initializer func()) {
   644  	pluginInitializers = append(pluginInitializers, initializer)
   645  }