github.com/whtcorpsinc/milevadb-prod@v0.0.0-20211104133533-f57f4be3b597/allegrosql/server/server.go (about)

     1  // The MIT License (MIT)
     2  //
     3  // Copyright (c) 2020 wandoulabs
     4  // Copyright (c) 2020 siddontang
     5  //
     6  // Permission is hereby granted, free of charge, to any person obtaining a copy of
     7  // this software and associated documentation files (the "Software"), to deal in
     8  // the Software without restriction, including without limitation the rights to
     9  // use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
    10  // the Software, and to permit persons to whom the Software is furnished to do so,
    11  // subject to the following conditions:
    12  //
    13  // The above copyright notice and this permission notice shall be included in all
    14  // copies or substantial portions of the Software.
    15  
    16  // Copyright 2020 WHTCORPS INC, Inc.
    17  //
    18  // Licensed under the Apache License, Version 2.0 (the "License");
    19  // you may not use this file except in compliance with the License.
    20  // You may obtain a copy of the License at
    21  //
    22  //     http://www.apache.org/licenses/LICENSE-2.0
    23  //
    24  // Unless required by applicable law or agreed to in writing, software
    25  // distributed under the License is distributed on an "AS IS" BASIS,
    26  // See the License for the specific language governing permissions and
    27  // limitations under the License.
    28  
    29  package server
    30  
    31  import (
    32  	"context"
    33  	"crypto/tls"
    34  	"flag"
    35  	"fmt"
    36  	"io"
    37  	"math/rand"
    38  	"net"
    39  	"net/http"
    40  	"unsafe"
    41  
    42  	// For pprof
    43  	_ "net/http/pprof"
    44  	"os"
    45  	"os/user"
    46  	"sync"
    47  	"sync/atomic"
    48  	"time"
    49  
    50  	"github.com/blacktear23/go-proxyprotodefCaus"
    51  	"github.com/whtcorpsinc/errors"
    52  	"github.com/whtcorpsinc/BerolinaSQL/allegrosql"
    53  	"github.com/whtcorpsinc/BerolinaSQL/terror"
    54  	"github.com/whtcorpsinc/milevadb/config"
    55  	"github.com/whtcorpsinc/milevadb/petri"
    56  	"github.com/whtcorpsinc/milevadb/errno"
    57  	"github.com/whtcorpsinc/milevadb/metrics"
    58  	"github.com/whtcorpsinc/milevadb/plugin"
    59  	"github.com/whtcorpsinc/milevadb/stochastikctx/variable"
    60  	"github.com/whtcorpsinc/milevadb/soliton"
    61  	"github.com/whtcorpsinc/milevadb/soliton/fastrand"
    62  	"github.com/whtcorpsinc/milevadb/soliton/logutil"
    63  	"github.com/whtcorpsinc/milevadb/soliton/sys/linux"
    64  	"github.com/whtcorpsinc/milevadb/soliton/timeutil"
    65  	"go.uber.org/zap"
    66  	"google.golang.org/grpc"
    67  )
    68  
    69  var (
    70  	baseConnID  uint32
    71  	serverPID   int
    72  	osUser      string
    73  	osVersion   string
    74  	runInGoTest bool
    75  )
    76  
    77  func init() {
    78  	serverPID = os.Getpid()
    79  	currentUser, err := user.Current()
    80  	if err != nil {
    81  		osUser = ""
    82  	} else {
    83  		osUser = currentUser.Name
    84  	}
    85  	osVersion, err = linux.OSVersion()
    86  	if err != nil {
    87  		osVersion = ""
    88  	}
    89  	runInGoTest = flag.Lookup("test.v") != nil || flag.Lookup("check.v") != nil
    90  }
    91  
    92  var (
    93  	errUnknownFieldType        = terror.ClassServer.New(errno.ErrUnknownFieldType, errno.MyALLEGROSQLErrName[errno.ErrUnknownFieldType])
    94  	errInvalidSequence         = terror.ClassServer.New(errno.ErrInvalidSequence, errno.MyALLEGROSQLErrName[errno.ErrInvalidSequence])
    95  	errInvalidType             = terror.ClassServer.New(errno.ErrInvalidType, errno.MyALLEGROSQLErrName[errno.ErrInvalidType])
    96  	errNotAllowedCommand       = terror.ClassServer.New(errno.ErrNotAllowedCommand, errno.MyALLEGROSQLErrName[errno.ErrNotAllowedCommand])
    97  	errAccessDenied            = terror.ClassServer.New(errno.ErrAccessDenied, errno.MyALLEGROSQLErrName[errno.ErrAccessDenied])
    98  	errConCount                = terror.ClassServer.New(errno.ErrConCount, errno.MyALLEGROSQLErrName[errno.ErrConCount])
    99  	errSecureTransportRequired = terror.ClassServer.New(errno.ErrSecureTransportRequired, errno.MyALLEGROSQLErrName[errno.ErrSecureTransportRequired])
   100  )
   101  
   102  // DefaultCapability is the capability of the server when it is created using the default configuration.
   103  // When server is configured with SSL, the server will have extra capabilities compared to DefaultCapability.
   104  const defaultCapability = allegrosql.ClientLongPassword | allegrosql.ClientLongFlag |
   105  	allegrosql.ClientConnectWithDB | allegrosql.ClientProtodefCaus41 |
   106  	allegrosql.ClientTransactions | allegrosql.ClientSecureConnection | allegrosql.ClientFoundRows |
   107  	allegrosql.ClientMultiStatements | allegrosql.ClientMultiResults | allegrosql.ClientLocalFiles |
   108  	allegrosql.ClientConnectAtts | allegrosql.ClientPluginAuth | allegrosql.ClientInteractive
   109  
// Server is the MyALLEGROSQL protodefCaus server.
//
// It owns the client-facing listeners, the registry of connected clients and
// the optional status HTTP / gRPC servers, all of which are torn down by Close.
type Server struct {
	// cfg is the configuration passed to NewServer.
	cfg *config.Config
	// tlsConfig holds a *tls.Config. It is written with atomic.StorePointer
	// (NewServer, UFIDelateTLSConfig) and read with atomic.LoadPointer
	// (getTLSConfig). It is nil when no TLS configuration was loaded.
	tlsConfig unsafe.Pointer // *tls.Config
	// driver is the IDriver passed to NewServer.
	// NOTE(review): presumably used to open stochastik contexts for clients — confirm.
	driver IDriver
	// listener accepts MyALLEGROSQL protodefCaus clients: TCP on host:port, or a
	// unix socket when no host is configured, optionally wrapped for the PROXY
	// protodefCaus. Close resets it to nil.
	listener net.Listener
	// socket is the unix socket listener whose connections are forwarded to
	// the TCP listener. It is only opened when both a host and a socket are configured.
	socket net.Listener
	// rwlock guards clients. Close also holds it while tearing down the listeners.
	rwlock sync.RWMutex
	// concurrentLimiter is created from cfg.TokenLimit and used by getToken/releaseToken.
	concurrentLimiter *TokenLimiter
	// clients holds the registered connections, keyed by connection ID.
	clients map[uint32]*clientConn
	// capability is defaultCapability, plus allegrosql.ClientSSL when TLS is configured.
	capability uint32
	// dom is the petri set through SetPetri.
	dom *petri.Petri

	// statusAddr and statusListener are not referenced in this part of the file.
	// NOTE(review): presumably populated by the status HTTP server code — verify.
	statusAddr     string
	statusListener net.Listener
	// statusServer is closed (and reset to nil) by Close.
	statusServer *http.Server
	// grpcServer is stopped (and reset to nil) by Close.
	grpcServer *grpc.Server
}
   128  
   129  // ConnectionCount gets current connection count.
   130  func (s *Server) ConnectionCount() int {
   131  	s.rwlock.RLock()
   132  	cnt := len(s.clients)
   133  	s.rwlock.RUnlock()
   134  	return cnt
   135  }
   136  
   137  func (s *Server) getToken() *Token {
   138  	start := time.Now()
   139  	tok := s.concurrentLimiter.Get()
   140  	// Note that data smaller than one microsecond is ignored, because that case can be viewed as non-causet.
   141  	metrics.GetTokenDurationHistogram.Observe(float64(time.Since(start).Nanoseconds() / 1e3))
   142  	return tok
   143  }
   144  
   145  func (s *Server) releaseToken(token *Token) {
   146  	s.concurrentLimiter.Put(token)
   147  }
   148  
   149  // SetPetri use to set the server petri.
   150  func (s *Server) SetPetri(dom *petri.Petri) {
   151  	s.dom = dom
   152  }
   153  
   154  // newConn creates a new *clientConn from a net.Conn.
   155  // It allocates a connection ID and random salt data for authentication.
   156  func (s *Server) newConn(conn net.Conn) *clientConn {
   157  	cc := newClientConn(s)
   158  	if s.cfg.Performance.TCPKeepAlive {
   159  		if tcpConn, ok := conn.(*net.TCPConn); ok {
   160  			if err := tcpConn.SetKeepAlive(true); err != nil {
   161  				logutil.BgLogger().Error("failed to set tcp keep alive option", zap.Error(err))
   162  			}
   163  		}
   164  	}
   165  	cc.setConn(conn)
   166  	cc.salt = fastrand.Buf(20)
   167  	return cc
   168  }
   169  
   170  func (s *Server) isUnixSocket() bool {
   171  	return s.cfg.Socket != ""
   172  }
   173  
   174  func (s *Server) forwardUnixSocketToTCP() {
   175  	addr := fmt.Sprintf("%s:%d", s.cfg.Host, s.cfg.Port)
   176  	for {
   177  		if s.listener == nil {
   178  			return // server shutdown has started
   179  		}
   180  		if uconn, err := s.socket.Accept(); err == nil {
   181  			logutil.BgLogger().Info("server socket forwarding", zap.String("from", s.cfg.Socket), zap.String("to", addr))
   182  			go s.handleForwardedConnection(uconn, addr)
   183  		} else {
   184  			if s.listener != nil {
   185  				logutil.BgLogger().Error("server failed to forward", zap.String("from", s.cfg.Socket), zap.String("to", addr), zap.Error(err))
   186  			}
   187  		}
   188  	}
   189  }
   190  
   191  func (s *Server) handleForwardedConnection(uconn net.Conn, addr string) {
   192  	defer terror.Call(uconn.Close)
   193  	if tconn, err := net.Dial("tcp", addr); err == nil {
   194  		go func() {
   195  			if _, err := io.Copy(uconn, tconn); err != nil {
   196  				logutil.BgLogger().Warn("copy server to socket failed", zap.Error(err))
   197  			}
   198  		}()
   199  		if _, err := io.Copy(tconn, uconn); err != nil {
   200  			logutil.BgLogger().Warn("socket forward copy failed", zap.Error(err))
   201  		}
   202  	} else {
   203  		logutil.BgLogger().Warn("socket forward failed: could not connect", zap.String("addr", addr), zap.Error(err))
   204  	}
   205  }
   206  
   207  // NewServer creates a new Server.
   208  func NewServer(cfg *config.Config, driver IDriver) (*Server, error) {
   209  	s := &Server{
   210  		cfg:               cfg,
   211  		driver:            driver,
   212  		concurrentLimiter: NewTokenLimiter(cfg.TokenLimit),
   213  		clients:           make(map[uint32]*clientConn),
   214  	}
   215  
   216  	tlsConfig, err := soliton.LoadTLSCertificates(s.cfg.Security.SSLCA, s.cfg.Security.SSLKey, s.cfg.Security.SSLCert)
   217  	if err != nil {
   218  		logutil.BgLogger().Error("secure connection cert/key/ca load fail", zap.Error(err))
   219  	}
   220  	if tlsConfig != nil {
   221  		setSSLVariable(s.cfg.Security.SSLCA, s.cfg.Security.SSLKey, s.cfg.Security.SSLCert)
   222  		atomic.StorePointer(&s.tlsConfig, unsafe.Pointer(tlsConfig))
   223  		logutil.BgLogger().Info("allegrosql protodefCaus server secure connection is enabled", zap.Bool("client verification enabled", len(variable.SysVars["ssl_ca"].Value) > 0))
   224  	} else if cfg.Security.RequireSecureTransport {
   225  		return nil, errSecureTransportRequired.FastGenByArgs()
   226  	}
   227  
   228  	setSystemTimeZoneVariable()
   229  
   230  	s.capability = defaultCapability
   231  	if s.tlsConfig != nil {
   232  		s.capability |= allegrosql.ClientSSL
   233  	}
   234  
   235  	if s.cfg.Host != "" && (s.cfg.Port != 0 || runInGoTest) {
   236  		addr := fmt.Sprintf("%s:%d", s.cfg.Host, s.cfg.Port)
   237  		if s.listener, err = net.Listen("tcp", addr); err == nil {
   238  			logutil.BgLogger().Info("server is running MyALLEGROSQL protodefCaus", zap.String("addr", addr))
   239  			if cfg.Socket != "" {
   240  				if s.socket, err = net.Listen("unix", s.cfg.Socket); err == nil {
   241  					logutil.BgLogger().Info("server redirecting", zap.String("from", s.cfg.Socket), zap.String("to", addr))
   242  					go s.forwardUnixSocketToTCP()
   243  				}
   244  			}
   245  			if runInGoTest && s.cfg.Port == 0 {
   246  				s.cfg.Port = uint(s.listener.Addr().(*net.TCPAddr).Port)
   247  			}
   248  		}
   249  	} else if cfg.Socket != "" {
   250  		if s.listener, err = net.Listen("unix", cfg.Socket); err == nil {
   251  			logutil.BgLogger().Info("server is running MyALLEGROSQL protodefCaus", zap.String("socket", cfg.Socket))
   252  		}
   253  	} else {
   254  		err = errors.New("Server not configured to listen on either -socket or -host and -port")
   255  	}
   256  
   257  	if cfg.ProxyProtodefCaus.Networks != "" {
   258  		pplistener, errProxy := proxyprotodefCaus.NewListener(s.listener, cfg.ProxyProtodefCaus.Networks,
   259  			int(cfg.ProxyProtodefCaus.HeaderTimeout))
   260  		if errProxy != nil {
   261  			logutil.BgLogger().Error("ProxyProtodefCaus networks parameter invalid")
   262  			return nil, errors.Trace(errProxy)
   263  		}
   264  		logutil.BgLogger().Info("server is running MyALLEGROSQL protodefCaus (through PROXY protodefCaus)", zap.String("host", s.cfg.Host))
   265  		s.listener = pplistener
   266  	}
   267  
   268  	if s.cfg.Status.ReportStatus && err == nil {
   269  		err = s.listenStatusHTTPServer()
   270  	}
   271  	if err != nil {
   272  		return nil, errors.Trace(err)
   273  	}
   274  
   275  	// Init rand seed for randomBuf()
   276  	rand.Seed(time.Now().UTC().UnixNano())
   277  	return s, nil
   278  }
   279  
   280  func setSSLVariable(ca, key, cert string) {
   281  	variable.SysVars["have_openssl"].Value = "YES"
   282  	variable.SysVars["have_ssl"].Value = "YES"
   283  	variable.SysVars["ssl_cert"].Value = cert
   284  	variable.SysVars["ssl_key"].Value = key
   285  	variable.SysVars["ssl_ca"].Value = ca
   286  }
   287  
   288  // Run runs the server.
   289  func (s *Server) Run() error {
   290  	metrics.ServerEventCounter.WithLabelValues(metrics.EventStart).Inc()
   291  
   292  	// Start HTTP API to report milevadb info such as TPS.
   293  	if s.cfg.Status.ReportStatus {
   294  		s.startStatusHTTP()
   295  	}
   296  	for {
   297  		conn, err := s.listener.Accept()
   298  		if err != nil {
   299  			if opErr, ok := err.(*net.OpError); ok {
   300  				if opErr.Err.Error() == "use of closed network connection" {
   301  					return nil
   302  				}
   303  			}
   304  
   305  			// If we got PROXY protodefCaus error, we should continue accept.
   306  			if proxyprotodefCaus.IsProxyProtodefCausError(err) {
   307  				logutil.BgLogger().Error("PROXY protodefCaus failed", zap.Error(err))
   308  				continue
   309  			}
   310  
   311  			logutil.BgLogger().Error("accept failed", zap.Error(err))
   312  			return errors.Trace(err)
   313  		}
   314  
   315  		clientConn := s.newConn(conn)
   316  
   317  		err = plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error {
   318  			authPlugin := plugin.DeclareAuditManifest(p.Manifest)
   319  			if authPlugin.OnConnectionEvent != nil {
   320  				host, err := clientConn.PeerHost("")
   321  				if err != nil {
   322  					logutil.BgLogger().Error("get peer host failed", zap.Error(err))
   323  					terror.Log(clientConn.Close())
   324  					return errors.Trace(err)
   325  				}
   326  				err = authPlugin.OnConnectionEvent(context.Background(), plugin.PreAuth, &variable.ConnectionInfo{Host: host})
   327  				if err != nil {
   328  					logutil.BgLogger().Info("do connection event failed", zap.Error(err))
   329  					terror.Log(clientConn.Close())
   330  					return errors.Trace(err)
   331  				}
   332  			}
   333  			return nil
   334  		})
   335  		if err != nil {
   336  			continue
   337  		}
   338  
   339  		go s.onConn(clientConn)
   340  	}
   341  }
   342  
   343  // Close closes the server.
   344  func (s *Server) Close() {
   345  	s.rwlock.Lock()
   346  	defer s.rwlock.Unlock()
   347  
   348  	if s.listener != nil {
   349  		err := s.listener.Close()
   350  		terror.Log(errors.Trace(err))
   351  		s.listener = nil
   352  	}
   353  	if s.socket != nil {
   354  		err := s.socket.Close()
   355  		terror.Log(errors.Trace(err))
   356  		s.socket = nil
   357  	}
   358  	if s.statusServer != nil {
   359  		err := s.statusServer.Close()
   360  		terror.Log(errors.Trace(err))
   361  		s.statusServer = nil
   362  	}
   363  	if s.grpcServer != nil {
   364  		s.grpcServer.Stop()
   365  		s.grpcServer = nil
   366  	}
   367  	metrics.ServerEventCounter.WithLabelValues(metrics.EventClose).Inc()
   368  }
   369  
   370  // onConn runs in its own goroutine, handles queries from this connection.
   371  func (s *Server) onConn(conn *clientConn) {
   372  	ctx := logutil.WithConnID(context.Background(), conn.connectionID)
   373  	if err := conn.handshake(ctx); err != nil {
   374  		if plugin.IsEnable(plugin.Audit) && conn.ctx != nil {
   375  			conn.ctx.GetStochastikVars().ConnectionInfo = conn.connectInfo()
   376  			err = plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error {
   377  				authPlugin := plugin.DeclareAuditManifest(p.Manifest)
   378  				if authPlugin.OnConnectionEvent != nil {
   379  					pluginCtx := context.WithValue(context.Background(), plugin.RejectReasonCtxValue{}, err.Error())
   380  					return authPlugin.OnConnectionEvent(pluginCtx, plugin.Reject, conn.ctx.GetStochastikVars().ConnectionInfo)
   381  				}
   382  				return nil
   383  			})
   384  			terror.Log(err)
   385  		}
   386  		// Some keep alive services will send request to MilevaDB and disconnect immediately.
   387  		// So we only record metrics.
   388  		metrics.HandShakeErrorCounter.Inc()
   389  		err = conn.Close()
   390  		terror.Log(errors.Trace(err))
   391  		return
   392  	}
   393  
   394  	logutil.Logger(ctx).Debug("new connection", zap.String("remoteAddr", conn.bufReadConn.RemoteAddr().String()))
   395  
   396  	defer func() {
   397  		logutil.Logger(ctx).Debug("connection closed")
   398  	}()
   399  	s.rwlock.Lock()
   400  	s.clients[conn.connectionID] = conn
   401  	connections := len(s.clients)
   402  	s.rwlock.Unlock()
   403  	metrics.ConnGauge.Set(float64(connections))
   404  
   405  	stochastikVars := conn.ctx.GetStochastikVars()
   406  	if plugin.IsEnable(plugin.Audit) {
   407  		stochastikVars.ConnectionInfo = conn.connectInfo()
   408  	}
   409  	err := plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error {
   410  		authPlugin := plugin.DeclareAuditManifest(p.Manifest)
   411  		if authPlugin.OnConnectionEvent != nil {
   412  			return authPlugin.OnConnectionEvent(context.Background(), plugin.Connected, stochastikVars.ConnectionInfo)
   413  		}
   414  		return nil
   415  	})
   416  	if err != nil {
   417  		return
   418  	}
   419  
   420  	connectedTime := time.Now()
   421  	conn.Run(ctx)
   422  
   423  	err = plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error {
   424  		authPlugin := plugin.DeclareAuditManifest(p.Manifest)
   425  		if authPlugin.OnConnectionEvent != nil {
   426  			stochastikVars.ConnectionInfo.Duration = float64(time.Since(connectedTime)) / float64(time.Millisecond)
   427  			err := authPlugin.OnConnectionEvent(context.Background(), plugin.Disconnect, stochastikVars.ConnectionInfo)
   428  			if err != nil {
   429  				logutil.BgLogger().Warn("do connection event failed", zap.String("plugin", authPlugin.Name), zap.Error(err))
   430  			}
   431  		}
   432  		return nil
   433  	})
   434  	if err != nil {
   435  		return
   436  	}
   437  }
   438  
   439  func (cc *clientConn) connectInfo() *variable.ConnectionInfo {
   440  	connType := "Socket"
   441  	if cc.server.isUnixSocket() {
   442  		connType = "UnixSocket"
   443  	} else if cc.tlsConn != nil {
   444  		connType = "SSL/TLS"
   445  	}
   446  	connInfo := &variable.ConnectionInfo{
   447  		ConnectionID:      cc.connectionID,
   448  		ConnectionType:    connType,
   449  		Host:              cc.peerHost,
   450  		ClientIP:          cc.peerHost,
   451  		ClientPort:        cc.peerPort,
   452  		ServerID:          1,
   453  		ServerPort:        int(cc.server.cfg.Port),
   454  		User:              cc.user,
   455  		ServerOSLoginUser: osUser,
   456  		OSVersion:         osVersion,
   457  		ServerVersion:     allegrosql.MilevaDBReleaseVersion,
   458  		SSLVersion:        "v1.2.0", // for current go version
   459  		PID:               serverPID,
   460  		EDB:                cc.dbname,
   461  	}
   462  	return connInfo
   463  }
   464  
   465  func (s *Server) checkConnectionCount() error {
   466  	// When the value of MaxServerConnections is 0, the number of connections is unlimited.
   467  	if int(s.cfg.MaxServerConnections) == 0 {
   468  		return nil
   469  	}
   470  
   471  	s.rwlock.RLock()
   472  	conns := len(s.clients)
   473  	s.rwlock.RUnlock()
   474  
   475  	if conns >= int(s.cfg.MaxServerConnections) {
   476  		logutil.BgLogger().Error("too many connections",
   477  			zap.Uint32("max connections", s.cfg.MaxServerConnections), zap.Error(errConCount))
   478  		return errConCount
   479  	}
   480  	return nil
   481  }
   482  
   483  // ShowProcessList implements the StochastikManager interface.
   484  func (s *Server) ShowProcessList() map[uint64]*soliton.ProcessInfo {
   485  	s.rwlock.RLock()
   486  	defer s.rwlock.RUnlock()
   487  	rs := make(map[uint64]*soliton.ProcessInfo, len(s.clients))
   488  	for _, client := range s.clients {
   489  		if atomic.LoadInt32(&client.status) == connStatusWaitShutdown {
   490  			continue
   491  		}
   492  		if pi := client.ctx.ShowProcess(); pi != nil {
   493  			rs[pi.ID] = pi
   494  		}
   495  	}
   496  	return rs
   497  }
   498  
   499  // GetProcessInfo implements the StochastikManager interface.
   500  func (s *Server) GetProcessInfo(id uint64) (*soliton.ProcessInfo, bool) {
   501  	s.rwlock.RLock()
   502  	conn, ok := s.clients[uint32(id)]
   503  	s.rwlock.RUnlock()
   504  	if !ok || atomic.LoadInt32(&conn.status) == connStatusWaitShutdown {
   505  		return &soliton.ProcessInfo{}, false
   506  	}
   507  	return conn.ctx.ShowProcess(), ok
   508  }
   509  
   510  // Kill implements the StochastikManager interface.
   511  func (s *Server) Kill(connectionID uint64, query bool) {
   512  	logutil.BgLogger().Info("kill", zap.Uint64("connID", connectionID), zap.Bool("query", query))
   513  	metrics.ServerEventCounter.WithLabelValues(metrics.EventKill).Inc()
   514  
   515  	s.rwlock.RLock()
   516  	defer s.rwlock.RUnlock()
   517  	conn, ok := s.clients[uint32(connectionID)]
   518  	if !ok {
   519  		return
   520  	}
   521  
   522  	if !query {
   523  		// Mark the client connection status as WaitShutdown, when the goroutine detect
   524  		// this, it will end the dispatch loop and exit.
   525  		atomic.StoreInt32(&conn.status, connStatusWaitShutdown)
   526  	}
   527  	killConn(conn)
   528  }
   529  
   530  // UFIDelateTLSConfig implements the StochastikManager interface.
   531  func (s *Server) UFIDelateTLSConfig(cfg *tls.Config) {
   532  	atomic.StorePointer(&s.tlsConfig, unsafe.Pointer(cfg))
   533  }
   534  
   535  func (s *Server) getTLSConfig() *tls.Config {
   536  	return (*tls.Config)(atomic.LoadPointer(&s.tlsConfig))
   537  }
   538  
   539  func killConn(conn *clientConn) {
   540  	sessVars := conn.ctx.GetStochastikVars()
   541  	atomic.StoreUint32(&sessVars.Killed, 1)
   542  }
   543  
   544  // KillAllConnections kills all connections when server is not gracefully shutdown.
   545  func (s *Server) KillAllConnections() {
   546  	logutil.BgLogger().Info("[server] kill all connections.")
   547  
   548  	s.rwlock.RLock()
   549  	defer s.rwlock.RUnlock()
   550  	for _, conn := range s.clients {
   551  		atomic.StoreInt32(&conn.status, connStatusShutdown)
   552  		if err := conn.closeWithoutLock(); err != nil {
   553  			terror.Log(err)
   554  		}
   555  		killConn(conn)
   556  	}
   557  }
   558  
   559  var gracefulCloseConnectionsTimeout = 15 * time.Second
   560  
   561  // TryGracefulDown will try to gracefully close all connection first with timeout. if timeout, will close all connection directly.
   562  func (s *Server) TryGracefulDown() {
   563  	ctx, cancel := context.WithTimeout(context.Background(), gracefulCloseConnectionsTimeout)
   564  	defer cancel()
   565  	done := make(chan struct{})
   566  	go func() {
   567  		s.GracefulDown(ctx, done)
   568  	}()
   569  	select {
   570  	case <-ctx.Done():
   571  		s.KillAllConnections()
   572  	case <-done:
   573  		return
   574  	}
   575  }
   576  
   577  // GracefulDown waits all clients to close.
   578  func (s *Server) GracefulDown(ctx context.Context, done chan struct{}) {
   579  	logutil.Logger(ctx).Info("[server] graceful shutdown.")
   580  	metrics.ServerEventCounter.WithLabelValues(metrics.EventGracefulDown).Inc()
   581  
   582  	count := s.ConnectionCount()
   583  	for i := 0; count > 0; i++ {
   584  		s.kickIdleConnection()
   585  
   586  		count = s.ConnectionCount()
   587  		if count == 0 {
   588  			break
   589  		}
   590  		// Print information for every 30s.
   591  		if i%30 == 0 {
   592  			logutil.Logger(ctx).Info("graceful shutdown...", zap.Int("conn count", count))
   593  		}
   594  		ticker := time.After(time.Second)
   595  		select {
   596  		case <-ctx.Done():
   597  			return
   598  		case <-ticker:
   599  		}
   600  	}
   601  	close(done)
   602  }
   603  
   604  func (s *Server) kickIdleConnection() {
   605  	var conns []*clientConn
   606  	s.rwlock.RLock()
   607  	for _, cc := range s.clients {
   608  		if cc.ShutdownOrNotify() {
   609  			// Shutdowned conn will be closed by us, and notified conn will exist themselves.
   610  			conns = append(conns, cc)
   611  		}
   612  	}
   613  	s.rwlock.RUnlock()
   614  
   615  	for _, cc := range conns {
   616  		err := cc.Close()
   617  		if err != nil {
   618  			logutil.BgLogger().Error("close connection", zap.Error(err))
   619  		}
   620  	}
   621  }
   622  
   623  // setSysTimeZoneOnce is used for parallel run tests. When several servers are running,
   624  // only the first will actually do setSystemTimeZoneVariable, thus we can avoid data race.
   625  var setSysTimeZoneOnce = &sync.Once{}
   626  
   627  func setSystemTimeZoneVariable() {
   628  	setSysTimeZoneOnce.Do(func() {
   629  		tz, err := timeutil.GetSystemTZ()
   630  		if err != nil {
   631  			logutil.BgLogger().Error(
   632  				"Error getting SystemTZ, use default value instead",
   633  				zap.Error(err),
   634  				zap.String("default system_time_zone", variable.SysVars["system_time_zone"].Value))
   635  			return
   636  		}
   637  		variable.SysVars["system_time_zone"].Value = tz
   638  	})
   639  }
   640  
   641  // Server error codes.
   642  const (
   643  	codeUnknownFieldType = 1
   644  	codeInvalidSequence  = 3
   645  	codeInvalidType      = 4
   646  )