github.com/matrixorigin/matrixone@v0.7.0/pkg/frontend/routine_manager.go (about) 1 // Copyright 2021 Matrix Origin 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package frontend 16 17 import ( 18 "context" 19 "crypto/tls" 20 "crypto/x509" 21 "fmt" 22 "os" 23 "sync" 24 "time" 25 26 "github.com/matrixorigin/matrixone/pkg/util/metric" 27 28 "github.com/fagongzi/goetty/v2" 29 "github.com/matrixorigin/matrixone/pkg/common/moerr" 30 "github.com/matrixorigin/matrixone/pkg/config" 31 "github.com/matrixorigin/matrixone/pkg/defines" 32 "github.com/matrixorigin/matrixone/pkg/logutil" 33 "github.com/matrixorigin/matrixone/pkg/util/trace" 34 ) 35 36 type RoutineManager struct { 37 mu sync.Mutex 38 ctx context.Context 39 clients map[goetty.IOSession]*Routine 40 pu *config.ParameterUnit 41 skipCheckUser bool 42 tlsConfig *tls.Config 43 autoIncrCaches defines.AutoIncrCaches 44 } 45 46 func (rm *RoutineManager) GetAutoIncrCache() defines.AutoIncrCaches { 47 rm.mu.Lock() 48 defer rm.mu.Unlock() 49 return rm.autoIncrCaches 50 } 51 52 func (rm *RoutineManager) SetSkipCheckUser(b bool) { 53 rm.mu.Lock() 54 defer rm.mu.Unlock() 55 rm.skipCheckUser = b 56 } 57 58 func (rm *RoutineManager) GetSkipCheckUser() bool { 59 rm.mu.Lock() 60 defer rm.mu.Unlock() 61 return rm.skipCheckUser 62 } 63 64 func (rm *RoutineManager) getParameterUnit() *config.ParameterUnit { 65 rm.mu.Lock() 66 defer rm.mu.Unlock() 67 return rm.pu 68 } 69 70 func (rm *RoutineManager) getCtx() context.Context { 71 rm.mu.Lock() 72 defer rm.mu.Unlock() 73 return rm.ctx 74 } 75 76 func (rm *RoutineManager) setRoutine(rs goetty.IOSession, r *Routine) { 77 rm.mu.Lock() 78 defer rm.mu.Unlock() 79 rm.clients[rs] = r 80 } 81 82 func (rm *RoutineManager) getRoutine(rs goetty.IOSession) *Routine { 83 rm.mu.Lock() 84 defer rm.mu.Unlock() 85 return rm.clients[rs] 86 } 87 88 func (rm *RoutineManager) getTlsConfig() *tls.Config { 89 rm.mu.Lock() 90 defer rm.mu.Unlock() 91 return rm.tlsConfig 92 } 93 94 func (rm *RoutineManager) Created(rs goetty.IOSession) { 95 logutil.Debugf("get the connection from %s", rs.RemoteAddress()) 96 pu := rm.getParameterUnit() 97 pro := NewMysqlClientProtocol(nextConnectionID(), rs, int(pu.SV.MaxBytesInOutbufToFlush), pu.SV) 98 pro.SetSkipCheckUser(rm.GetSkipCheckUser()) 99 exe := NewMysqlCmdExecutor() 100 exe.SetRoutineManager(rm) 101 exe.ChooseDoQueryFunc(pu.SV.EnableDoComQueryInProgress) 102 103 routine := NewRoutine(rm.getCtx(), pro, exe, pu.SV, rs) 104 105 // XXX MPOOL pass in a nil mpool. 106 // XXX MPOOL can choose to use a Mid sized mpool, if, we know 107 // this mpool will be deleted. Maybe in the following Closed method. 108 ses := NewSession(routine.getProtocol(), nil, pu, GSysVariables, true) 109 ses.SetRequestContext(routine.getCancelRoutineCtx()) 110 ses.SetFromRealUser(true) 111 ses.setSkipCheckPrivilege(rm.GetSkipCheckUser()) 112 113 // Add autoIncrCaches in session structure. 114 ses.SetAutoIncrCaches(rm.autoIncrCaches) 115 116 routine.setSession(ses) 117 pro.SetSession(ses) 118 119 logDebugf(pro.GetConciseProfile(), "have done some preparation for the connection %s", rs.RemoteAddress()) 120 121 hsV10pkt := pro.makeHandshakeV10Payload() 122 err := pro.writePackets(hsV10pkt) 123 if err != nil { 124 logErrorf(pro.GetConciseProfile(), "failed to handshake with server, quiting routine... %s", err) 125 routine.killConnection(true) 126 return 127 } 128 129 logDebugf(pro.GetConciseProfile(), "have sent handshake packet to connection %s", rs.RemoteAddress()) 130 rm.setRoutine(rs, routine) 131 } 132 133 /* 134 When the io is closed, the Closed will be called. 135 */ 136 func (rm *RoutineManager) Closed(rs goetty.IOSession) { 137 logutil.Debugf("clean resource of the connection %d:%s", rs.ID(), rs.RemoteAddress()) 138 defer func() { 139 logutil.Debugf("resource of the connection %d:%s has been cleaned", rs.ID(), rs.RemoteAddress()) 140 }() 141 var rt *Routine 142 var ok bool 143 144 rm.mu.Lock() 145 rt, ok = rm.clients[rs] 146 if ok { 147 delete(rm.clients, rs) 148 } 149 rm.mu.Unlock() 150 151 if rt != nil { 152 ses := rt.getSession() 153 if ses != nil { 154 rt.decreaseCount(func() { 155 account := ses.GetTenantInfo() 156 accountName := sysAccountName 157 if account != nil { 158 accountName = account.GetTenant() 159 } 160 metric.ConnectionCounter(accountName).Dec() 161 }) 162 logDebugf(ses.GetConciseProfile(), "the io session was closed.") 163 } 164 rt.cleanup() 165 } 166 } 167 168 /* 169 kill a connection or query. 170 if killConnection is true, the query will be canceled first, then the network will be closed. 171 if killConnection is false, only the query will be canceled. the connection keeps intact. 172 */ 173 func (rm *RoutineManager) kill(ctx context.Context, killConnection bool, idThatKill, id uint64, statementId string) error { 174 var rt *Routine = nil 175 rm.mu.Lock() 176 for _, value := range rm.clients { 177 if uint64(value.getConnectionID()) == id { 178 rt = value 179 break 180 } 181 } 182 rm.mu.Unlock() 183 184 killMyself := idThatKill == id 185 if rt != nil { 186 if killConnection { 187 logutil.Infof("kill connection %d", id) 188 rt.killConnection(killMyself) 189 } else { 190 logutil.Infof("kill query %s on the connection %d", statementId, id) 191 rt.killQuery(killMyself, statementId) 192 } 193 } else { 194 return moerr.NewInternalError(ctx, "Unknown connection id %d", id) 195 } 196 return nil 197 } 198 199 func getConnectionInfo(rs goetty.IOSession) string { 200 conn := rs.RawConn() 201 if conn != nil { 202 return fmt.Sprintf("connection from %s to %s", conn.RemoteAddr(), conn.LocalAddr()) 203 } 204 return fmt.Sprintf("connection from %s", rs.RemoteAddress()) 205 } 206 207 func (rm *RoutineManager) Handler(rs goetty.IOSession, msg interface{}, received uint64) error { 208 logutil.Debugf("get request from %d:%s", rs.ID(), rs.RemoteAddress()) 209 defer func() { 210 logutil.Debugf("request from %d:%s has been processed", rs.ID(), rs.RemoteAddress()) 211 }() 212 var err error 213 var isTlsHeader bool 214 ctx, span := trace.Start(rm.getCtx(), "RoutineManager.Handler") 215 defer span.End() 216 connectionInfo := getConnectionInfo(rs) 217 routine := rm.getRoutine(rs) 218 if routine == nil { 219 err = moerr.NewInternalError(ctx, "routine does not exist") 220 logutil.Errorf("%s error:%v", connectionInfo, err) 221 return err 222 } 223 routine.setInProcessRequest(true) 224 defer routine.setInProcessRequest(false) 225 protocol := routine.getProtocol() 226 protoProfile := protocol.GetConciseProfile() 227 packet, ok := msg.(*Packet) 228 229 protocol.SetSequenceID(uint8(packet.SequenceID + 1)) 230 var seq = protocol.GetSequenceId() 231 if !ok { 232 err = moerr.NewInternalError(ctx, "message is not Packet") 233 logErrorf(protoProfile, "error:%v", err) 234 return err 235 } 236 237 length := packet.Length 238 payload := packet.Payload 239 for uint32(length) == MaxPayloadSize { 240 msg, err = protocol.GetTcpConnection().Read(goetty.ReadOptions{}) 241 if err != nil { 242 logErrorf(protoProfile, "read message failed. error:%s", err) 243 return err 244 } 245 246 packet, ok = msg.(*Packet) 247 if !ok { 248 err = moerr.NewInternalError(ctx, "message is not Packet") 249 logErrorf(protoProfile, "error:%v", err) 250 return err 251 } 252 253 protocol.SetSequenceID(uint8(packet.SequenceID + 1)) 254 seq = protocol.GetSequenceId() 255 payload = append(payload, packet.Payload...) 256 length = packet.Length 257 } 258 259 // finish handshake process 260 if !protocol.IsEstablished() { 261 logDebugf(protoProfile, "HANDLE HANDSHAKE") 262 263 /* 264 di := MakeDebugInfo(payload,80,8) 265 logutil.Infof("RP[%v] Payload80[%v]",rs.RemoteAddr(),di) 266 */ 267 ses := routine.getSession() 268 if protocol.GetCapability()&CLIENT_SSL != 0 && !protocol.IsTlsEstablished() { 269 logDebugf(protoProfile, "setup ssl") 270 isTlsHeader, err = protocol.HandleHandshake(ctx, payload) 271 if err != nil { 272 logErrorf(protoProfile, "error:%v", err) 273 return err 274 } 275 if isTlsHeader { 276 logDebugf(protoProfile, "upgrade to TLS") 277 // do upgradeTls 278 tlsConn := tls.Server(rs.RawConn(), rm.getTlsConfig()) 279 logDebugf(protoProfile, "get TLS conn ok") 280 newCtx, cancelFun := context.WithTimeout(ctx, 20*time.Second) 281 if err = tlsConn.HandshakeContext(newCtx); err != nil { 282 logErrorf(protoProfile, "before cancel() error:%v", err) 283 cancelFun() 284 logErrorf(protoProfile, "after cancel() error:%v", err) 285 return err 286 } 287 cancelFun() 288 logDebugf(protoProfile, "TLS handshake ok") 289 rs.UseConn(tlsConn) 290 logDebugf(protoProfile, "TLS handshake finished") 291 292 // tls upgradeOk 293 protocol.SetTlsEstablished() 294 } else { 295 // client don't ask server to upgrade TLS 296 protocol.SetTlsEstablished() 297 protocol.SetEstablished() 298 } 299 } else { 300 logDebugf(protoProfile, "handleHandshake") 301 _, err = protocol.HandleHandshake(ctx, payload) 302 if err != nil { 303 logErrorf(protoProfile, "error:%v", err) 304 return err 305 } 306 protocol.SetEstablished() 307 } 308 309 dbName := protocol.GetDatabaseName() 310 if ses != nil && dbName != "" { 311 ses.SetDatabaseName(dbName) 312 } 313 return nil 314 } 315 316 req := routine.getProtocol().GetRequest(payload) 317 req.seq = seq 318 319 //handle request 320 err = routine.handleRequest(req) 321 if err != nil { 322 logErrorf(protoProfile, "error:%v", err) 323 return err 324 } 325 326 return nil 327 } 328 329 // clientCount returns the count of the clients 330 func (rm *RoutineManager) clientCount() int { 331 var count int 332 rm.mu.Lock() 333 count = len(rm.clients) 334 rm.mu.Unlock() 335 return count 336 } 337 338 func NewRoutineManager(ctx context.Context, pu *config.ParameterUnit) (*RoutineManager, error) { 339 rm := &RoutineManager{ 340 ctx: ctx, 341 clients: make(map[goetty.IOSession]*Routine), 342 pu: pu, 343 } 344 345 // Initialize auto incre cache. 346 rm.autoIncrCaches.AutoIncrCaches = make(map[string]defines.AutoIncrCache) 347 rm.autoIncrCaches.Mu = &rm.mu 348 349 if pu.SV.EnableTls { 350 err := initTlsConfig(rm, pu.SV) 351 if err != nil { 352 return nil, err 353 } 354 } 355 return rm, nil 356 } 357 358 func initTlsConfig(rm *RoutineManager, SV *config.FrontendParameters) error { 359 if len(SV.TlsCertFile) == 0 || len(SV.TlsKeyFile) == 0 { 360 return moerr.NewInternalError(rm.ctx, "init TLS config error : cert file or key file is empty") 361 } 362 363 var tlsCert tls.Certificate 364 var err error 365 tlsCert, err = tls.LoadX509KeyPair(SV.TlsCertFile, SV.TlsKeyFile) 366 if err != nil { 367 return moerr.NewInternalError(rm.ctx, "init TLS config error :load x509 failed") 368 } 369 370 clientAuthPolicy := tls.NoClientCert 371 var certPool *x509.CertPool 372 if len(SV.TlsCaFile) > 0 { 373 var caCert []byte 374 caCert, err = os.ReadFile(SV.TlsCaFile) 375 if err != nil { 376 return moerr.NewInternalError(rm.ctx, "init TLS config error :read TlsCaFile failed") 377 } 378 certPool = x509.NewCertPool() 379 if certPool.AppendCertsFromPEM(caCert) { 380 clientAuthPolicy = tls.VerifyClientCertIfGiven 381 } 382 } 383 384 // This excludes ciphers listed in tls.InsecureCipherSuites() and can be used to filter out more 385 // var cipherSuites []uint16 386 // var cipherNames []string 387 // for _, sc := range tls.CipherSuites() { 388 // cipherSuites = append(cipherSuites, sc.ID) 389 // switch sc.ID { 390 // case tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA, 391 // tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305: 392 // logutil.Info("Disabling weak cipherSuite", zap.String("cipherSuite", sc.Name)) 393 // default: 394 // cipherNames = append(cipherNames, sc.Name) 395 // cipherSuites = append(cipherSuites, sc.ID) 396 // } 397 // } 398 // logutil.Info("Enabled ciphersuites", zap.Strings("cipherNames", cipherNames)) 399 400 rm.tlsConfig = &tls.Config{ 401 Certificates: []tls.Certificate{tlsCert}, 402 ClientCAs: certPool, 403 ClientAuth: clientAuthPolicy, 404 // MinVersion: tls.VersionTLS13, 405 // CipherSuites: cipherSuites, 406 } 407 logutil.Info("init TLS config finished") 408 return nil 409 }