github.com/matrixorigin/matrixone@v0.7.0/pkg/frontend/routine_manager.go (about)

     1  // Copyright 2021 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package frontend
    16  
    17  import (
    18  	"context"
    19  	"crypto/tls"
    20  	"crypto/x509"
    21  	"fmt"
    22  	"os"
    23  	"sync"
    24  	"time"
    25  
    26  	"github.com/matrixorigin/matrixone/pkg/util/metric"
    27  
    28  	"github.com/fagongzi/goetty/v2"
    29  	"github.com/matrixorigin/matrixone/pkg/common/moerr"
    30  	"github.com/matrixorigin/matrixone/pkg/config"
    31  	"github.com/matrixorigin/matrixone/pkg/defines"
    32  	"github.com/matrixorigin/matrixone/pkg/logutil"
    33  	"github.com/matrixorigin/matrixone/pkg/util/trace"
    34  )
    35  
// RoutineManager owns the per-connection routines of the frontend server and
// implements the callbacks (Created, Handler, Closed) invoked for each IO session.
type RoutineManager struct {
	// mu guards the fields below. NewRoutineManager also installs it as
	// autoIncrCaches.Mu, so the auto-increment caches share this lock.
	mu sync.Mutex
	// ctx is the base context handed to new routines and used as the parent
	// of the trace span in Handler.
	ctx context.Context
	// clients maps each IO session to the routine serving it; entries are
	// added by Created and removed by Closed.
	clients map[goetty.IOSession]*Routine
	// pu holds the server parameters used when setting up connections.
	pu *config.ParameterUnit
	// skipCheckUser is propagated to the protocol and session of every new
	// connection in Created.
	skipCheckUser bool
	// tlsConfig is the server TLS configuration; it is only set by
	// initTlsConfig when TLS is enabled.
	tlsConfig *tls.Config
	// autoIncrCaches is shared with every session created by this manager.
	autoIncrCaches defines.AutoIncrCaches
}
    45  
    46  func (rm *RoutineManager) GetAutoIncrCache() defines.AutoIncrCaches {
    47  	rm.mu.Lock()
    48  	defer rm.mu.Unlock()
    49  	return rm.autoIncrCaches
    50  }
    51  
    52  func (rm *RoutineManager) SetSkipCheckUser(b bool) {
    53  	rm.mu.Lock()
    54  	defer rm.mu.Unlock()
    55  	rm.skipCheckUser = b
    56  }
    57  
    58  func (rm *RoutineManager) GetSkipCheckUser() bool {
    59  	rm.mu.Lock()
    60  	defer rm.mu.Unlock()
    61  	return rm.skipCheckUser
    62  }
    63  
    64  func (rm *RoutineManager) getParameterUnit() *config.ParameterUnit {
    65  	rm.mu.Lock()
    66  	defer rm.mu.Unlock()
    67  	return rm.pu
    68  }
    69  
    70  func (rm *RoutineManager) getCtx() context.Context {
    71  	rm.mu.Lock()
    72  	defer rm.mu.Unlock()
    73  	return rm.ctx
    74  }
    75  
    76  func (rm *RoutineManager) setRoutine(rs goetty.IOSession, r *Routine) {
    77  	rm.mu.Lock()
    78  	defer rm.mu.Unlock()
    79  	rm.clients[rs] = r
    80  }
    81  
    82  func (rm *RoutineManager) getRoutine(rs goetty.IOSession) *Routine {
    83  	rm.mu.Lock()
    84  	defer rm.mu.Unlock()
    85  	return rm.clients[rs]
    86  }
    87  
    88  func (rm *RoutineManager) getTlsConfig() *tls.Config {
    89  	rm.mu.Lock()
    90  	defer rm.mu.Unlock()
    91  	return rm.tlsConfig
    92  }
    93  
    94  func (rm *RoutineManager) Created(rs goetty.IOSession) {
    95  	logutil.Debugf("get the connection from %s", rs.RemoteAddress())
    96  	pu := rm.getParameterUnit()
    97  	pro := NewMysqlClientProtocol(nextConnectionID(), rs, int(pu.SV.MaxBytesInOutbufToFlush), pu.SV)
    98  	pro.SetSkipCheckUser(rm.GetSkipCheckUser())
    99  	exe := NewMysqlCmdExecutor()
   100  	exe.SetRoutineManager(rm)
   101  	exe.ChooseDoQueryFunc(pu.SV.EnableDoComQueryInProgress)
   102  
   103  	routine := NewRoutine(rm.getCtx(), pro, exe, pu.SV, rs)
   104  
   105  	// XXX MPOOL pass in a nil mpool.
   106  	// XXX MPOOL can choose to use a Mid sized mpool, if, we know
   107  	// this mpool will be deleted.  Maybe in the following Closed method.
   108  	ses := NewSession(routine.getProtocol(), nil, pu, GSysVariables, true)
   109  	ses.SetRequestContext(routine.getCancelRoutineCtx())
   110  	ses.SetFromRealUser(true)
   111  	ses.setSkipCheckPrivilege(rm.GetSkipCheckUser())
   112  
   113  	// Add  autoIncrCaches in session structure.
   114  	ses.SetAutoIncrCaches(rm.autoIncrCaches)
   115  
   116  	routine.setSession(ses)
   117  	pro.SetSession(ses)
   118  
   119  	logDebugf(pro.GetConciseProfile(), "have done some preparation for the connection %s", rs.RemoteAddress())
   120  
   121  	hsV10pkt := pro.makeHandshakeV10Payload()
   122  	err := pro.writePackets(hsV10pkt)
   123  	if err != nil {
   124  		logErrorf(pro.GetConciseProfile(), "failed to handshake with server, quiting routine... %s", err)
   125  		routine.killConnection(true)
   126  		return
   127  	}
   128  
   129  	logDebugf(pro.GetConciseProfile(), "have sent handshake packet to connection %s", rs.RemoteAddress())
   130  	rm.setRoutine(rs, routine)
   131  }
   132  
   133  /*
   134  When the io is closed, the Closed will be called.
   135  */
   136  func (rm *RoutineManager) Closed(rs goetty.IOSession) {
   137  	logutil.Debugf("clean resource of the connection %d:%s", rs.ID(), rs.RemoteAddress())
   138  	defer func() {
   139  		logutil.Debugf("resource of the connection %d:%s has been cleaned", rs.ID(), rs.RemoteAddress())
   140  	}()
   141  	var rt *Routine
   142  	var ok bool
   143  
   144  	rm.mu.Lock()
   145  	rt, ok = rm.clients[rs]
   146  	if ok {
   147  		delete(rm.clients, rs)
   148  	}
   149  	rm.mu.Unlock()
   150  
   151  	if rt != nil {
   152  		ses := rt.getSession()
   153  		if ses != nil {
   154  			rt.decreaseCount(func() {
   155  				account := ses.GetTenantInfo()
   156  				accountName := sysAccountName
   157  				if account != nil {
   158  					accountName = account.GetTenant()
   159  				}
   160  				metric.ConnectionCounter(accountName).Dec()
   161  			})
   162  			logDebugf(ses.GetConciseProfile(), "the io session was closed.")
   163  		}
   164  		rt.cleanup()
   165  	}
   166  }
   167  
   168  /*
   169  kill a connection or query.
   170  if killConnection is true, the query will be canceled first, then the network will be closed.
   171  if killConnection is false, only the query will be canceled. the connection keeps intact.
   172  */
   173  func (rm *RoutineManager) kill(ctx context.Context, killConnection bool, idThatKill, id uint64, statementId string) error {
   174  	var rt *Routine = nil
   175  	rm.mu.Lock()
   176  	for _, value := range rm.clients {
   177  		if uint64(value.getConnectionID()) == id {
   178  			rt = value
   179  			break
   180  		}
   181  	}
   182  	rm.mu.Unlock()
   183  
   184  	killMyself := idThatKill == id
   185  	if rt != nil {
   186  		if killConnection {
   187  			logutil.Infof("kill connection %d", id)
   188  			rt.killConnection(killMyself)
   189  		} else {
   190  			logutil.Infof("kill query %s on the connection %d", statementId, id)
   191  			rt.killQuery(killMyself, statementId)
   192  		}
   193  	} else {
   194  		return moerr.NewInternalError(ctx, "Unknown connection id %d", id)
   195  	}
   196  	return nil
   197  }
   198  
   199  func getConnectionInfo(rs goetty.IOSession) string {
   200  	conn := rs.RawConn()
   201  	if conn != nil {
   202  		return fmt.Sprintf("connection from %s to %s", conn.RemoteAddr(), conn.LocalAddr())
   203  	}
   204  	return fmt.Sprintf("connection from %s", rs.RemoteAddress())
   205  }
   206  
   207  func (rm *RoutineManager) Handler(rs goetty.IOSession, msg interface{}, received uint64) error {
   208  	logutil.Debugf("get request from %d:%s", rs.ID(), rs.RemoteAddress())
   209  	defer func() {
   210  		logutil.Debugf("request from %d:%s has been processed", rs.ID(), rs.RemoteAddress())
   211  	}()
   212  	var err error
   213  	var isTlsHeader bool
   214  	ctx, span := trace.Start(rm.getCtx(), "RoutineManager.Handler")
   215  	defer span.End()
   216  	connectionInfo := getConnectionInfo(rs)
   217  	routine := rm.getRoutine(rs)
   218  	if routine == nil {
   219  		err = moerr.NewInternalError(ctx, "routine does not exist")
   220  		logutil.Errorf("%s error:%v", connectionInfo, err)
   221  		return err
   222  	}
   223  	routine.setInProcessRequest(true)
   224  	defer routine.setInProcessRequest(false)
   225  	protocol := routine.getProtocol()
   226  	protoProfile := protocol.GetConciseProfile()
   227  	packet, ok := msg.(*Packet)
   228  
   229  	protocol.SetSequenceID(uint8(packet.SequenceID + 1))
   230  	var seq = protocol.GetSequenceId()
   231  	if !ok {
   232  		err = moerr.NewInternalError(ctx, "message is not Packet")
   233  		logErrorf(protoProfile, "error:%v", err)
   234  		return err
   235  	}
   236  
   237  	length := packet.Length
   238  	payload := packet.Payload
   239  	for uint32(length) == MaxPayloadSize {
   240  		msg, err = protocol.GetTcpConnection().Read(goetty.ReadOptions{})
   241  		if err != nil {
   242  			logErrorf(protoProfile, "read message failed. error:%s", err)
   243  			return err
   244  		}
   245  
   246  		packet, ok = msg.(*Packet)
   247  		if !ok {
   248  			err = moerr.NewInternalError(ctx, "message is not Packet")
   249  			logErrorf(protoProfile, "error:%v", err)
   250  			return err
   251  		}
   252  
   253  		protocol.SetSequenceID(uint8(packet.SequenceID + 1))
   254  		seq = protocol.GetSequenceId()
   255  		payload = append(payload, packet.Payload...)
   256  		length = packet.Length
   257  	}
   258  
   259  	// finish handshake process
   260  	if !protocol.IsEstablished() {
   261  		logDebugf(protoProfile, "HANDLE HANDSHAKE")
   262  
   263  		/*
   264  			di := MakeDebugInfo(payload,80,8)
   265  			logutil.Infof("RP[%v] Payload80[%v]",rs.RemoteAddr(),di)
   266  		*/
   267  		ses := routine.getSession()
   268  		if protocol.GetCapability()&CLIENT_SSL != 0 && !protocol.IsTlsEstablished() {
   269  			logDebugf(protoProfile, "setup ssl")
   270  			isTlsHeader, err = protocol.HandleHandshake(ctx, payload)
   271  			if err != nil {
   272  				logErrorf(protoProfile, "error:%v", err)
   273  				return err
   274  			}
   275  			if isTlsHeader {
   276  				logDebugf(protoProfile, "upgrade to TLS")
   277  				// do upgradeTls
   278  				tlsConn := tls.Server(rs.RawConn(), rm.getTlsConfig())
   279  				logDebugf(protoProfile, "get TLS conn ok")
   280  				newCtx, cancelFun := context.WithTimeout(ctx, 20*time.Second)
   281  				if err = tlsConn.HandshakeContext(newCtx); err != nil {
   282  					logErrorf(protoProfile, "before cancel() error:%v", err)
   283  					cancelFun()
   284  					logErrorf(protoProfile, "after cancel() error:%v", err)
   285  					return err
   286  				}
   287  				cancelFun()
   288  				logDebugf(protoProfile, "TLS handshake ok")
   289  				rs.UseConn(tlsConn)
   290  				logDebugf(protoProfile, "TLS handshake finished")
   291  
   292  				// tls upgradeOk
   293  				protocol.SetTlsEstablished()
   294  			} else {
   295  				// client don't ask server to upgrade TLS
   296  				protocol.SetTlsEstablished()
   297  				protocol.SetEstablished()
   298  			}
   299  		} else {
   300  			logDebugf(protoProfile, "handleHandshake")
   301  			_, err = protocol.HandleHandshake(ctx, payload)
   302  			if err != nil {
   303  				logErrorf(protoProfile, "error:%v", err)
   304  				return err
   305  			}
   306  			protocol.SetEstablished()
   307  		}
   308  
   309  		dbName := protocol.GetDatabaseName()
   310  		if ses != nil && dbName != "" {
   311  			ses.SetDatabaseName(dbName)
   312  		}
   313  		return nil
   314  	}
   315  
   316  	req := routine.getProtocol().GetRequest(payload)
   317  	req.seq = seq
   318  
   319  	//handle request
   320  	err = routine.handleRequest(req)
   321  	if err != nil {
   322  		logErrorf(protoProfile, "error:%v", err)
   323  		return err
   324  	}
   325  
   326  	return nil
   327  }
   328  
   329  // clientCount returns the count of the clients
   330  func (rm *RoutineManager) clientCount() int {
   331  	var count int
   332  	rm.mu.Lock()
   333  	count = len(rm.clients)
   334  	rm.mu.Unlock()
   335  	return count
   336  }
   337  
   338  func NewRoutineManager(ctx context.Context, pu *config.ParameterUnit) (*RoutineManager, error) {
   339  	rm := &RoutineManager{
   340  		ctx:     ctx,
   341  		clients: make(map[goetty.IOSession]*Routine),
   342  		pu:      pu,
   343  	}
   344  
   345  	// Initialize auto incre cache.
   346  	rm.autoIncrCaches.AutoIncrCaches = make(map[string]defines.AutoIncrCache)
   347  	rm.autoIncrCaches.Mu = &rm.mu
   348  
   349  	if pu.SV.EnableTls {
   350  		err := initTlsConfig(rm, pu.SV)
   351  		if err != nil {
   352  			return nil, err
   353  		}
   354  	}
   355  	return rm, nil
   356  }
   357  
   358  func initTlsConfig(rm *RoutineManager, SV *config.FrontendParameters) error {
   359  	if len(SV.TlsCertFile) == 0 || len(SV.TlsKeyFile) == 0 {
   360  		return moerr.NewInternalError(rm.ctx, "init TLS config error : cert file or key file is empty")
   361  	}
   362  
   363  	var tlsCert tls.Certificate
   364  	var err error
   365  	tlsCert, err = tls.LoadX509KeyPair(SV.TlsCertFile, SV.TlsKeyFile)
   366  	if err != nil {
   367  		return moerr.NewInternalError(rm.ctx, "init TLS config error :load x509 failed")
   368  	}
   369  
   370  	clientAuthPolicy := tls.NoClientCert
   371  	var certPool *x509.CertPool
   372  	if len(SV.TlsCaFile) > 0 {
   373  		var caCert []byte
   374  		caCert, err = os.ReadFile(SV.TlsCaFile)
   375  		if err != nil {
   376  			return moerr.NewInternalError(rm.ctx, "init TLS config error :read TlsCaFile failed")
   377  		}
   378  		certPool = x509.NewCertPool()
   379  		if certPool.AppendCertsFromPEM(caCert) {
   380  			clientAuthPolicy = tls.VerifyClientCertIfGiven
   381  		}
   382  	}
   383  
   384  	// This excludes ciphers listed in tls.InsecureCipherSuites() and can be used to filter out more
   385  	// var cipherSuites []uint16
   386  	// var cipherNames []string
   387  	// for _, sc := range tls.CipherSuites() {
   388  	// cipherSuites = append(cipherSuites, sc.ID)
   389  	// switch sc.ID {
   390  	// case tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA,
   391  	// 	tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305:
   392  	// logutil.Info("Disabling weak cipherSuite", zap.String("cipherSuite", sc.Name))
   393  	// default:
   394  	// cipherNames = append(cipherNames, sc.Name)
   395  	// cipherSuites = append(cipherSuites, sc.ID)
   396  	// }
   397  	// }
   398  	// logutil.Info("Enabled ciphersuites", zap.Strings("cipherNames", cipherNames))
   399  
   400  	rm.tlsConfig = &tls.Config{
   401  		Certificates: []tls.Certificate{tlsCert},
   402  		ClientCAs:    certPool,
   403  		ClientAuth:   clientAuthPolicy,
   404  		// MinVersion:   tls.VersionTLS13,
   405  		// CipherSuites: cipherSuites,
   406  	}
   407  	logutil.Info("init TLS config finished")
   408  	return nil
   409  }