vitess.io/vitess@v0.16.2/go/vt/vtgate/plugin_mysql_server.go (about) 1 /* 2 Copyright 2019 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package vtgate 18 19 import ( 20 "context" 21 "fmt" 22 "net" 23 "os" 24 "os/signal" 25 "regexp" 26 "strings" 27 "sync" 28 "sync/atomic" 29 "syscall" 30 "time" 31 32 "github.com/spf13/pflag" 33 34 "vitess.io/vitess/go/vt/sqlparser" 35 "vitess.io/vitess/go/vt/vterrors" 36 37 "vitess.io/vitess/go/mysql" 38 "vitess.io/vitess/go/sqltypes" 39 "vitess.io/vitess/go/trace" 40 "vitess.io/vitess/go/vt/callerid" 41 "vitess.io/vitess/go/vt/callinfo" 42 "vitess.io/vitess/go/vt/log" 43 "vitess.io/vitess/go/vt/servenv" 44 "vitess.io/vitess/go/vt/vttls" 45 46 "github.com/google/uuid" 47 48 querypb "vitess.io/vitess/go/vt/proto/query" 49 vtgatepb "vitess.io/vitess/go/vt/proto/vtgate" 50 ) 51 52 var ( 53 mysqlServerPort = -1 54 mysqlServerBindAddress string 55 mysqlServerSocketPath string 56 mysqlTCPVersion = "tcp" 57 mysqlAuthServerImpl = "static" 58 mysqlAllowClearTextWithoutTLS bool 59 mysqlProxyProtocol bool 60 mysqlServerRequireSecureTransport bool 61 mysqlSslCert string 62 mysqlSslKey string 63 mysqlSslCa string 64 mysqlSslCrl string 65 mysqlSslServerCA string 66 mysqlTLSMinVersion string 67 68 mysqlConnReadTimeout time.Duration 69 mysqlConnWriteTimeout time.Duration 70 mysqlQueryTimeout time.Duration 71 mysqlSlowConnectWarnThreshold time.Duration 72 mysqlConnBufferPooling bool 73 74 mysqlDefaultWorkloadName = "OLTP" 75 mysqlDefaultWorkload int32 76 77 busyConnections int32 78 ) 79 80 func registerPluginFlags(fs *pflag.FlagSet) { 81 fs.IntVar(&mysqlServerPort, "mysql_server_port", mysqlServerPort, "If set, also listen for MySQL binary protocol connections on this port.") 82 fs.StringVar(&mysqlServerBindAddress, "mysql_server_bind_address", mysqlServerBindAddress, "Binds on this address when listening to MySQL binary protocol. Useful to restrict listening to 'localhost' only for instance.") 83 fs.StringVar(&mysqlServerSocketPath, "mysql_server_socket_path", mysqlServerSocketPath, "This option specifies the Unix socket file to use when listening for local connections. By default it will be empty and it won't listen to a unix socket") 84 fs.StringVar(&mysqlTCPVersion, "mysql_tcp_version", mysqlTCPVersion, "Select tcp, tcp4, or tcp6 to control the socket type.") 85 fs.StringVar(&mysqlAuthServerImpl, "mysql_auth_server_impl", mysqlAuthServerImpl, "Which auth server implementation to use. Options: none, ldap, clientcert, static, vault.") 86 fs.BoolVar(&mysqlAllowClearTextWithoutTLS, "mysql_allow_clear_text_without_tls", mysqlAllowClearTextWithoutTLS, "If set, the server will allow the use of a clear text password over non-SSL connections.") 87 fs.BoolVar(&mysqlProxyProtocol, "proxy_protocol", mysqlProxyProtocol, "Enable HAProxy PROXY protocol on MySQL listener socket") 88 fs.BoolVar(&mysqlServerRequireSecureTransport, "mysql_server_require_secure_transport", mysqlServerRequireSecureTransport, "Reject insecure connections but only if mysql_server_ssl_cert and mysql_server_ssl_key are provided") 89 fs.StringVar(&mysqlSslCert, "mysql_server_ssl_cert", mysqlSslCert, "Path to the ssl cert for mysql server plugin SSL") 90 fs.StringVar(&mysqlSslKey, "mysql_server_ssl_key", mysqlSslKey, "Path to ssl key for mysql server plugin SSL") 91 fs.StringVar(&mysqlSslCa, "mysql_server_ssl_ca", mysqlSslCa, "Path to ssl CA for mysql server plugin SSL. If specified, server will require and validate client certs.") 92 fs.StringVar(&mysqlSslCrl, "mysql_server_ssl_crl", mysqlSslCrl, "Path to ssl CRL for mysql server plugin SSL") 93 fs.StringVar(&mysqlTLSMinVersion, "mysql_server_tls_min_version", mysqlTLSMinVersion, "Configures the minimal TLS version negotiated when SSL is enabled. Defaults to TLSv1.2. Options: TLSv1.0, TLSv1.1, TLSv1.2, TLSv1.3.") 94 fs.StringVar(&mysqlSslServerCA, "mysql_server_ssl_server_ca", mysqlSslServerCA, "path to server CA in PEM format, which will be combine with server cert, return full certificate chain to clients") 95 fs.DurationVar(&mysqlSlowConnectWarnThreshold, "mysql_slow_connect_warn_threshold", mysqlSlowConnectWarnThreshold, "Warn if it takes more than the given threshold for a mysql connection to establish") 96 fs.DurationVar(&mysqlConnReadTimeout, "mysql_server_read_timeout", mysqlConnReadTimeout, "connection read timeout") 97 fs.DurationVar(&mysqlConnWriteTimeout, "mysql_server_write_timeout", mysqlConnWriteTimeout, "connection write timeout") 98 fs.DurationVar(&mysqlQueryTimeout, "mysql_server_query_timeout", mysqlQueryTimeout, "mysql query timeout") 99 fs.BoolVar(&mysqlConnBufferPooling, "mysql-server-pool-conn-read-buffers", mysqlConnBufferPooling, "If set, the server will pool incoming connection read buffers") 100 fs.StringVar(&mysqlDefaultWorkloadName, "mysql_default_workload", mysqlDefaultWorkloadName, "Default session workload (OLTP, OLAP, DBA)") 101 } 102 103 // vtgateHandler implements the Listener interface. 104 // It stores the Session in the ClientData of a Connection. 105 type vtgateHandler struct { 106 mysql.UnimplementedHandler 107 mu sync.Mutex 108 109 vtg *VTGate 110 connections map[*mysql.Conn]bool 111 } 112 113 func newVtgateHandler(vtg *VTGate) *vtgateHandler { 114 return &vtgateHandler{ 115 vtg: vtg, 116 connections: make(map[*mysql.Conn]bool), 117 } 118 } 119 120 func (vh *vtgateHandler) NewConnection(c *mysql.Conn) { 121 vh.mu.Lock() 122 defer vh.mu.Unlock() 123 vh.connections[c] = true 124 } 125 126 func (vh *vtgateHandler) numConnections() int { 127 vh.mu.Lock() 128 defer vh.mu.Unlock() 129 return len(vh.connections) 130 } 131 132 func (vh *vtgateHandler) ComResetConnection(c *mysql.Conn) { 133 ctx := context.Background() 134 session := vh.session(c) 135 if session.InTransaction { 136 defer atomic.AddInt32(&busyConnections, -1) 137 } 138 err := vh.vtg.CloseSession(ctx, session) 139 if err != nil { 140 log.Errorf("Error happened in transaction rollback: %v", err) 141 } 142 } 143 144 func (vh *vtgateHandler) ConnectionClosed(c *mysql.Conn) { 145 // Rollback if there is an ongoing transaction. Ignore error. 146 defer func() { 147 vh.mu.Lock() 148 defer vh.mu.Unlock() 149 delete(vh.connections, c) 150 }() 151 152 var ctx context.Context 153 var cancel context.CancelFunc 154 if mysqlQueryTimeout != 0 { 155 ctx, cancel = context.WithTimeout(context.Background(), mysqlQueryTimeout) 156 defer cancel() 157 } else { 158 ctx = context.Background() 159 } 160 session := vh.session(c) 161 if session.InTransaction { 162 defer atomic.AddInt32(&busyConnections, -1) 163 } 164 _ = vh.vtg.CloseSession(ctx, session) 165 } 166 167 // Regexp to extract parent span id over the sql query 168 var r = regexp.MustCompile(`/\*VT_SPAN_CONTEXT=(.*)\*/`) 169 170 // this function is here to make this logic easy to test by decoupling the logic from the `trace.NewSpan` and `trace.NewFromString` functions 171 func startSpanTestable(ctx context.Context, query, label string, 172 newSpan func(context.Context, string) (trace.Span, context.Context), 173 newSpanFromString func(context.Context, string, string) (trace.Span, context.Context, error)) (trace.Span, context.Context, error) { 174 _, comments := sqlparser.SplitMarginComments(query) 175 match := r.FindStringSubmatch(comments.Leading) 176 span, ctx := getSpan(ctx, match, newSpan, label, newSpanFromString) 177 178 trace.AnnotateSQL(span, sqlparser.Preview(query)) 179 180 return span, ctx, nil 181 } 182 183 func getSpan(ctx context.Context, match []string, newSpan func(context.Context, string) (trace.Span, context.Context), label string, newSpanFromString func(context.Context, string, string) (trace.Span, context.Context, error)) (trace.Span, context.Context) { 184 var span trace.Span 185 if len(match) != 0 { 186 var err error 187 span, ctx, err = newSpanFromString(ctx, match[1], label) 188 if err == nil { 189 return span, ctx 190 } 191 log.Warningf("Unable to parse VT_SPAN_CONTEXT: %s", err.Error()) 192 } 193 span, ctx = newSpan(ctx, label) 194 return span, ctx 195 } 196 197 func startSpan(ctx context.Context, query, label string) (trace.Span, context.Context, error) { 198 return startSpanTestable(ctx, query, label, trace.NewSpan, trace.NewFromString) 199 } 200 201 func (vh *vtgateHandler) ComQuery(c *mysql.Conn, query string, callback func(*sqltypes.Result) error) error { 202 ctx := context.Background() 203 var cancel context.CancelFunc 204 if mysqlQueryTimeout != 0 { 205 ctx, cancel = context.WithTimeout(ctx, mysqlQueryTimeout) 206 defer cancel() 207 } 208 209 span, ctx, err := startSpan(ctx, query, "vtgateHandler.ComQuery") 210 if err != nil { 211 return vterrors.Wrap(err, "failed to extract span") 212 } 213 defer span.Finish() 214 215 ctx = callinfo.MysqlCallInfo(ctx, c) 216 217 // Fill in the ImmediateCallerID with the UserData returned by 218 // the AuthServer plugin for that user. If nothing was 219 // returned, use the User. This lets the plugin map a MySQL 220 // user used for authentication to a Vitess User used for 221 // Table ACLs and Vitess authentication in general. 222 im := c.UserData.Get() 223 ef := callerid.NewEffectiveCallerID( 224 c.User, /* principal: who */ 225 c.RemoteAddr().String(), /* component: running client process */ 226 "VTGate MySQL Connector" /* subcomponent: part of the client */) 227 ctx = callerid.NewContext(ctx, ef, im) 228 229 session := vh.session(c) 230 if !session.InTransaction { 231 atomic.AddInt32(&busyConnections, 1) 232 } 233 defer func() { 234 if !session.InTransaction { 235 atomic.AddInt32(&busyConnections, -1) 236 } 237 }() 238 239 if session.Options.Workload == querypb.ExecuteOptions_OLAP { 240 err := vh.vtg.StreamExecute(ctx, session, query, make(map[string]*querypb.BindVariable), callback) 241 return mysql.NewSQLErrorFromError(err) 242 } 243 session, result, err := vh.vtg.Execute(ctx, session, query, make(map[string]*querypb.BindVariable)) 244 245 if err := mysql.NewSQLErrorFromError(err); err != nil { 246 return err 247 } 248 fillInTxStatusFlags(c, session) 249 return callback(result) 250 } 251 252 func fillInTxStatusFlags(c *mysql.Conn, session *vtgatepb.Session) { 253 if session.InTransaction { 254 c.StatusFlags |= mysql.ServerStatusInTrans 255 } else { 256 c.StatusFlags &= mysql.NoServerStatusInTrans 257 } 258 if session.Autocommit { 259 c.StatusFlags |= mysql.ServerStatusAutocommit 260 } else { 261 c.StatusFlags &= mysql.NoServerStatusAutocommit 262 } 263 } 264 265 // ComPrepare is the handler for command prepare. 266 func (vh *vtgateHandler) ComPrepare(c *mysql.Conn, query string, bindVars map[string]*querypb.BindVariable) ([]*querypb.Field, error) { 267 var ctx context.Context 268 var cancel context.CancelFunc 269 if mysqlQueryTimeout != 0 { 270 ctx, cancel = context.WithTimeout(context.Background(), mysqlQueryTimeout) 271 defer cancel() 272 } else { 273 ctx = context.Background() 274 } 275 276 ctx = callinfo.MysqlCallInfo(ctx, c) 277 278 // Fill in the ImmediateCallerID with the UserData returned by 279 // the AuthServer plugin for that user. If nothing was 280 // returned, use the User. This lets the plugin map a MySQL 281 // user used for authentication to a Vitess User used for 282 // Table ACLs and Vitess authentication in general. 283 im := c.UserData.Get() 284 ef := callerid.NewEffectiveCallerID( 285 c.User, /* principal: who */ 286 c.RemoteAddr().String(), /* component: running client process */ 287 "VTGate MySQL Connector" /* subcomponent: part of the client */) 288 ctx = callerid.NewContext(ctx, ef, im) 289 290 session := vh.session(c) 291 if !session.InTransaction { 292 atomic.AddInt32(&busyConnections, 1) 293 } 294 defer func() { 295 if !session.InTransaction { 296 atomic.AddInt32(&busyConnections, -1) 297 } 298 }() 299 300 session, fld, err := vh.vtg.Prepare(ctx, session, query, bindVars) 301 err = mysql.NewSQLErrorFromError(err) 302 if err != nil { 303 return nil, err 304 } 305 return fld, nil 306 } 307 308 func (vh *vtgateHandler) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error { 309 var ctx context.Context 310 var cancel context.CancelFunc 311 if mysqlQueryTimeout != 0 { 312 ctx, cancel = context.WithTimeout(context.Background(), mysqlQueryTimeout) 313 defer cancel() 314 } else { 315 ctx = context.Background() 316 } 317 318 ctx = callinfo.MysqlCallInfo(ctx, c) 319 320 // Fill in the ImmediateCallerID with the UserData returned by 321 // the AuthServer plugin for that user. If nothing was 322 // returned, use the User. This lets the plugin map a MySQL 323 // user used for authentication to a Vitess User used for 324 // Table ACLs and Vitess authentication in general. 325 im := c.UserData.Get() 326 ef := callerid.NewEffectiveCallerID( 327 c.User, /* principal: who */ 328 c.RemoteAddr().String(), /* component: running client process */ 329 "VTGate MySQL Connector" /* subcomponent: part of the client */) 330 ctx = callerid.NewContext(ctx, ef, im) 331 332 session := vh.session(c) 333 if !session.InTransaction { 334 atomic.AddInt32(&busyConnections, 1) 335 } 336 defer func() { 337 if !session.InTransaction { 338 atomic.AddInt32(&busyConnections, -1) 339 } 340 }() 341 342 if session.Options.Workload == querypb.ExecuteOptions_OLAP { 343 err := vh.vtg.StreamExecute(ctx, session, prepare.PrepareStmt, prepare.BindVars, callback) 344 return mysql.NewSQLErrorFromError(err) 345 } 346 _, qr, err := vh.vtg.Execute(ctx, session, prepare.PrepareStmt, prepare.BindVars) 347 if err != nil { 348 err = mysql.NewSQLErrorFromError(err) 349 return err 350 } 351 fillInTxStatusFlags(c, session) 352 353 return callback(qr) 354 } 355 356 func (vh *vtgateHandler) WarningCount(c *mysql.Conn) uint16 { 357 return uint16(len(vh.session(c).GetWarnings())) 358 } 359 360 // ComRegisterReplica is part of the mysql.Handler interface. 361 func (vh *vtgateHandler) ComRegisterReplica(c *mysql.Conn, replicaHost string, replicaPort uint16, replicaUser string, replicaPassword string) error { 362 return vterrors.VT12001("ComRegisterReplica for the VTGate handler") 363 } 364 365 // ComBinlogDump is part of the mysql.Handler interface. 366 func (vh *vtgateHandler) ComBinlogDump(c *mysql.Conn, logFile string, binlogPos uint32) error { 367 return vterrors.VT12001("ComBinlogDump for the VTGate handler") 368 } 369 370 // ComBinlogDumpGTID is part of the mysql.Handler interface. 371 func (vh *vtgateHandler) ComBinlogDumpGTID(c *mysql.Conn, logFile string, logPos uint64, gtidSet mysql.GTIDSet) error { 372 return vterrors.VT12001("ComBinlogDumpGTID for the VTGate handler") 373 } 374 375 func (vh *vtgateHandler) session(c *mysql.Conn) *vtgatepb.Session { 376 session, _ := c.ClientData.(*vtgatepb.Session) 377 if session == nil { 378 u, _ := uuid.NewUUID() 379 session = &vtgatepb.Session{ 380 Options: &querypb.ExecuteOptions{ 381 IncludedFields: querypb.ExecuteOptions_ALL, 382 Workload: querypb.ExecuteOptions_Workload(mysqlDefaultWorkload), 383 384 // The collation field of ExecuteOption is set right before an execution. 385 }, 386 Autocommit: true, 387 DDLStrategy: defaultDDLStrategy, 388 SessionUUID: u.String(), 389 EnableSystemSettings: sysVarSetEnabled, 390 } 391 if c.Capabilities&mysql.CapabilityClientFoundRows != 0 { 392 session.Options.ClientFoundRows = true 393 } 394 c.ClientData = session 395 } 396 return session 397 } 398 399 var mysqlListener *mysql.Listener 400 var mysqlUnixListener *mysql.Listener 401 var sigChan chan os.Signal 402 var vtgateHandle *vtgateHandler 403 404 // initTLSConfig inits tls config for the given mysql listener 405 func initTLSConfig(mysqlListener *mysql.Listener, mysqlSslCert, mysqlSslKey, mysqlSslCa, mysqlSslCrl, mysqlSslServerCA string, mysqlServerRequireSecureTransport bool, mysqlMinTLSVersion uint16) error { 406 serverConfig, err := vttls.ServerConfig(mysqlSslCert, mysqlSslKey, mysqlSslCa, mysqlSslCrl, mysqlSslServerCA, mysqlMinTLSVersion) 407 if err != nil { 408 log.Exitf("grpcutils.TLSServerConfig failed: %v", err) 409 return err 410 } 411 mysqlListener.TLSConfig.Store(serverConfig) 412 mysqlListener.RequireSecureTransport = mysqlServerRequireSecureTransport 413 sigChan = make(chan os.Signal, 1) 414 signal.Notify(sigChan, syscall.SIGHUP) 415 go func() { 416 for range sigChan { 417 serverConfig, err := vttls.ServerConfig(mysqlSslCert, mysqlSslKey, mysqlSslCa, mysqlSslCrl, mysqlSslServerCA, mysqlMinTLSVersion) 418 if err != nil { 419 log.Errorf("grpcutils.TLSServerConfig failed: %v", err) 420 } else { 421 log.Info("grpcutils.TLSServerConfig updated") 422 mysqlListener.TLSConfig.Store(serverConfig) 423 } 424 } 425 }() 426 return nil 427 } 428 429 // initiMySQLProtocol starts the mysql protocol. 430 // It should be called only once in a process. 431 func initMySQLProtocol() { 432 // Flag is not set, just return. 433 if mysqlServerPort < 0 && mysqlServerSocketPath == "" { 434 return 435 } 436 437 // If no VTGate was created, just return. 438 if rpcVTGate == nil { 439 return 440 } 441 442 // Initialize registered AuthServer implementations (or other plugins) 443 for _, initFn := range pluginInitializers { 444 initFn() 445 } 446 authServer := mysql.GetAuthServer(mysqlAuthServerImpl) 447 448 // Check mysql_default_workload 449 var ok bool 450 if mysqlDefaultWorkload, ok = querypb.ExecuteOptions_Workload_value[strings.ToUpper(mysqlDefaultWorkloadName)]; !ok { 451 log.Exitf("-mysql_default_workload must be one of [OLTP, OLAP, DBA, UNSPECIFIED]") 452 } 453 454 switch mysqlTCPVersion { 455 case "tcp", "tcp4", "tcp6": 456 // Valid flag value. 457 default: 458 log.Exitf("-mysql_tcp_version must be one of [tcp, tcp4, tcp6]") 459 } 460 461 // Create a Listener. 462 var err error 463 vtgateHandle = newVtgateHandler(rpcVTGate) 464 if mysqlServerPort >= 0 { 465 mysqlListener, err = mysql.NewListener( 466 mysqlTCPVersion, 467 net.JoinHostPort(mysqlServerBindAddress, fmt.Sprintf("%v", mysqlServerPort)), 468 authServer, 469 vtgateHandle, 470 mysqlConnReadTimeout, 471 mysqlConnWriteTimeout, 472 mysqlProxyProtocol, 473 mysqlConnBufferPooling, 474 ) 475 if err != nil { 476 log.Exitf("mysql.NewListener failed: %v", err) 477 } 478 mysqlListener.ServerVersion = servenv.MySQLServerVersion() 479 if mysqlSslCert != "" && mysqlSslKey != "" { 480 tlsVersion, err := vttls.TLSVersionToNumber(mysqlTLSMinVersion) 481 if err != nil { 482 log.Exitf("mysql.NewListener failed: %v", err) 483 } 484 485 _ = initTLSConfig(mysqlListener, mysqlSslCert, mysqlSslKey, mysqlSslCa, mysqlSslCrl, mysqlSslServerCA, mysqlServerRequireSecureTransport, tlsVersion) 486 } 487 mysqlListener.AllowClearTextWithoutTLS.Set(mysqlAllowClearTextWithoutTLS) 488 // Check for the connection threshold 489 if mysqlSlowConnectWarnThreshold != 0 { 490 log.Infof("setting mysql slow connection threshold to %v", mysqlSlowConnectWarnThreshold) 491 mysqlListener.SlowConnectWarnThreshold.Set(mysqlSlowConnectWarnThreshold) 492 } 493 // Start listening for tcp 494 go mysqlListener.Accept() 495 } 496 497 if mysqlServerSocketPath != "" { 498 // Let's create this unix socket with permissions to all users. In this way, 499 // clients can connect to vtgate mysql server without being vtgate user 500 oldMask := syscall.Umask(000) 501 mysqlUnixListener, err = newMysqlUnixSocket(mysqlServerSocketPath, authServer, vtgateHandle) 502 _ = syscall.Umask(oldMask) 503 if err != nil { 504 log.Exitf("mysql.NewListener failed: %v", err) 505 return 506 } 507 // Listen for unix socket 508 go mysqlUnixListener.Accept() 509 } 510 } 511 512 // newMysqlUnixSocket creates a new unix socket mysql listener. If a socket file already exists, attempts 513 // to clean it up. 514 func newMysqlUnixSocket(address string, authServer mysql.AuthServer, handler mysql.Handler) (*mysql.Listener, error) { 515 listener, err := mysql.NewListener( 516 "unix", 517 address, 518 authServer, 519 handler, 520 mysqlConnReadTimeout, 521 mysqlConnWriteTimeout, 522 false, 523 mysqlConnBufferPooling, 524 ) 525 526 switch err := err.(type) { 527 case nil: 528 return listener, nil 529 case *net.OpError: 530 log.Warningf("Found existent socket when trying to create new unix mysql listener: %s, attempting to clean up", address) 531 // err.Op should never be different from listen, just being extra careful 532 // in case in the future other errors are returned here 533 if err.Op != "listen" { 534 return nil, err 535 } 536 _, dialErr := net.Dial("unix", address) 537 if dialErr == nil { 538 log.Errorf("Existent socket '%s' is still accepting connections, aborting", address) 539 return nil, err 540 } 541 removeFileErr := os.Remove(address) 542 if removeFileErr != nil { 543 log.Errorf("Couldn't remove existent socket file: %s", address) 544 return nil, err 545 } 546 listener, listenerErr := mysql.NewListener( 547 "unix", 548 address, 549 authServer, 550 handler, 551 mysqlConnReadTimeout, 552 mysqlConnWriteTimeout, 553 false, 554 mysqlConnBufferPooling, 555 ) 556 return listener, listenerErr 557 default: 558 return nil, err 559 } 560 } 561 562 func shutdownMysqlProtocolAndDrain() { 563 if mysqlListener != nil { 564 mysqlListener.Close() 565 mysqlListener = nil 566 } 567 if mysqlUnixListener != nil { 568 mysqlUnixListener.Close() 569 mysqlUnixListener = nil 570 } 571 if sigChan != nil { 572 signal.Stop(sigChan) 573 } 574 575 if atomic.LoadInt32(&busyConnections) > 0 { 576 log.Infof("Waiting for all client connections to be idle (%d active)...", atomic.LoadInt32(&busyConnections)) 577 start := time.Now() 578 reported := start 579 for atomic.LoadInt32(&busyConnections) != 0 { 580 if time.Since(reported) > 2*time.Second { 581 log.Infof("Still waiting for client connections to be idle (%d active)...", atomic.LoadInt32(&busyConnections)) 582 reported = time.Now() 583 } 584 585 time.Sleep(1 * time.Millisecond) 586 } 587 } 588 } 589 590 func rollbackAtShutdown() { 591 defer log.Flush() 592 if vtgateHandle == nil { 593 // we still haven't been able to initialise the vtgateHandler, so we don't need to rollback anything 594 return 595 } 596 597 // Close all open connections. If they're waiting for reads, this will cause 598 // them to error out, which will automatically rollback open transactions. 599 func() { 600 if vtgateHandle != nil { 601 vtgateHandle.mu.Lock() 602 defer vtgateHandle.mu.Unlock() 603 for c := range vtgateHandle.connections { 604 if c != nil { 605 log.Infof("Rolling back transactions associated with connection ID: %v", c.ConnectionID) 606 c.Close() 607 } 608 } 609 } 610 }() 611 612 // If vtgate is instead busy executing a query, the number of open conns 613 // will be non-zero. Give another second for those queries to finish. 614 for i := 0; i < 100; i++ { 615 if vtgateHandle.numConnections() == 0 { 616 log.Infof("All connections have been rolled back.") 617 return 618 } 619 time.Sleep(10 * time.Millisecond) 620 } 621 log.Errorf("All connections did not go idle. Shutting down anyway.") 622 } 623 624 func mysqlSocketPath() string { 625 if mysqlServerSocketPath == "" { 626 return "" 627 } 628 return mysqlServerSocketPath 629 } 630 631 func init() { 632 servenv.OnParseFor("vtgate", registerPluginFlags) 633 servenv.OnParseFor("vtcombo", registerPluginFlags) 634 635 servenv.OnRun(initMySQLProtocol) 636 servenv.OnTermSync(shutdownMysqlProtocolAndDrain) 637 servenv.OnClose(rollbackAtShutdown) 638 } 639 640 var pluginInitializers []func() 641 642 // RegisterPluginInitializer lets plugins register themselves to be init'ed at servenv.OnRun-time 643 func RegisterPluginInitializer(initializer func()) { 644 pluginInitializers = append(pluginInitializers, initializer) 645 }