github.com/matrixorigin/matrixone@v1.2.0/pkg/frontend/routine_manager.go (about) 1 // Copyright 2021 Matrix Origin 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package frontend 16 17 import ( 18 "context" 19 "crypto/tls" 20 "crypto/x509" 21 "fmt" 22 "math" 23 "os" 24 "sync" 25 "sync/atomic" 26 "time" 27 28 "github.com/fagongzi/goetty/v2" 29 "go.uber.org/zap" 30 31 "github.com/matrixorigin/matrixone/pkg/common/moerr" 32 "github.com/matrixorigin/matrixone/pkg/config" 33 "github.com/matrixorigin/matrixone/pkg/defines" 34 "github.com/matrixorigin/matrixone/pkg/logutil" 35 "github.com/matrixorigin/matrixone/pkg/pb/query" 36 "github.com/matrixorigin/matrixone/pkg/queryservice" 37 "github.com/matrixorigin/matrixone/pkg/util/metric" 38 v2 "github.com/matrixorigin/matrixone/pkg/util/metric/v2" 39 "github.com/matrixorigin/matrixone/pkg/util/trace" 40 ) 41 42 type RoutineManager struct { 43 mu sync.RWMutex 44 ctx context.Context 45 clients map[goetty.IOSession]*Routine 46 // routinesByID keeps the routines by connection ID. 47 routinesByConnID map[uint32]*Routine 48 tlsConfig *tls.Config 49 accountRoutine *AccountRoutineManager 50 baseService BaseService 51 sessionManager *queryservice.SessionManager 52 // reportSystemStatusTime is the time when report system status last time. 53 reportSystemStatusTime atomic.Pointer[time.Time] 54 } 55 56 type AccountRoutineManager struct { 57 ctx context.Context 58 killQueueMu sync.RWMutex 59 killIdQueue map[int64]KillRecord 60 accountRoutineMu sync.RWMutex 61 accountId2Routine map[int64]map[*Routine]uint64 62 } 63 64 type KillRecord struct { 65 killTime time.Time 66 version uint64 67 } 68 69 func NewKillRecord(killtime time.Time, version uint64) KillRecord { 70 return KillRecord{ 71 killTime: killtime, 72 version: version, 73 } 74 } 75 76 func (ar *AccountRoutineManager) recordRountine(tenantID int64, rt *Routine, version uint64) { 77 if tenantID == sysAccountID || rt == nil { 78 return 79 } 80 81 ar.accountRoutineMu.Lock() 82 defer ar.accountRoutineMu.Unlock() 83 if _, ok := ar.accountId2Routine[tenantID]; !ok { 84 ar.accountId2Routine[tenantID] = make(map[*Routine]uint64) 85 } 86 ar.accountId2Routine[tenantID][rt] = version 87 } 88 89 func (ar *AccountRoutineManager) deleteRoutine(tenantID int64, rt *Routine) { 90 if tenantID == sysAccountID || rt == nil { 91 return 92 } 93 94 ar.accountRoutineMu.Lock() 95 defer ar.accountRoutineMu.Unlock() 96 _, ok := ar.accountId2Routine[tenantID] 97 if ok { 98 delete(ar.accountId2Routine[tenantID], rt) 99 } 100 if len(ar.accountId2Routine[tenantID]) == 0 { 101 delete(ar.accountId2Routine, tenantID) 102 } 103 } 104 105 func (ar *AccountRoutineManager) EnKillQueue(tenantID int64, version uint64) { 106 if tenantID == sysAccountID { 107 return 108 } 109 110 KillRecord := NewKillRecord(time.Now(), version) 111 ar.killQueueMu.Lock() 112 defer ar.killQueueMu.Unlock() 113 ar.killIdQueue[tenantID] = KillRecord 114 115 } 116 117 func (ar *AccountRoutineManager) AlterRoutineStatue(tenantID int64, status string) { 118 if tenantID == sysAccountID { 119 return 120 } 121 122 ar.accountRoutineMu.Lock() 123 defer ar.accountRoutineMu.Unlock() 124 if rts, ok := ar.accountId2Routine[tenantID]; ok { 125 for rt := range rts { 126 if status == "restricted" { 127 rt.setResricted(true) 128 } else { 129 rt.setResricted(false) 130 } 131 } 132 } 133 } 134 135 func (ar *AccountRoutineManager) deepCopyKillQueue() map[int64]KillRecord { 136 ar.killQueueMu.RLock() 137 defer ar.killQueueMu.RUnlock() 138 139 tempKillQueue := make(map[int64]KillRecord, len(ar.killIdQueue)) 140 for account, record := range ar.killIdQueue { 141 tempKillQueue[account] = record 142 } 143 return tempKillQueue 144 } 145 146 func (ar *AccountRoutineManager) deepCopyRoutineMap() map[int64]map[*Routine]uint64 { 147 ar.accountRoutineMu.RLock() 148 defer ar.accountRoutineMu.RUnlock() 149 150 tempRoutineMap := make(map[int64]map[*Routine]uint64, len(ar.accountId2Routine)) 151 for account, rountine := range ar.accountId2Routine { 152 tempRountines := make(map[*Routine]uint64, len(rountine)) 153 for rt, version := range rountine { 154 tempRountines[rt] = version 155 } 156 tempRoutineMap[account] = tempRountines 157 } 158 return tempRoutineMap 159 } 160 161 func (rm *RoutineManager) getCtx() context.Context { 162 return rm.ctx 163 } 164 165 func (rm *RoutineManager) setRoutine(rs goetty.IOSession, id uint32, r *Routine) { 166 rm.mu.Lock() 167 defer rm.mu.Unlock() 168 rm.clients[rs] = r 169 rm.routinesByConnID[id] = r 170 } 171 172 func (rm *RoutineManager) getRoutine(rs goetty.IOSession) *Routine { 173 rm.mu.RLock() 174 defer rm.mu.RUnlock() 175 return rm.clients[rs] 176 } 177 178 func (rm *RoutineManager) getRoutineByConnID(id uint32) *Routine { 179 rm.mu.RLock() 180 defer rm.mu.RUnlock() 181 r, ok := rm.routinesByConnID[id] 182 if ok { 183 return r 184 } 185 return nil 186 } 187 188 func (rm *RoutineManager) deleteRoutine(rs goetty.IOSession) *Routine { 189 var rt *Routine 190 var ok bool 191 rm.mu.Lock() 192 defer rm.mu.Unlock() 193 if rt, ok = rm.clients[rs]; ok { 194 delete(rm.clients, rs) 195 } 196 if rt != nil { 197 connID := rt.getConnectionID() 198 if _, ok = rm.routinesByConnID[connID]; ok { 199 delete(rm.routinesByConnID, connID) 200 } 201 } 202 return rt 203 } 204 205 func (rm *RoutineManager) getTlsConfig() *tls.Config { 206 return rm.tlsConfig 207 } 208 209 func (rm *RoutineManager) getConnID() (uint32, error) { 210 // Only works in unit test. 211 if getGlobalPu().HAKeeperClient == nil { 212 return nextConnectionID(), nil 213 } 214 ctx, cancel := context.WithTimeout(rm.ctx, time.Second*2) 215 defer cancel() 216 connID, err := getGlobalPu().HAKeeperClient.AllocateIDByKey(ctx, ConnIDAllocKey) 217 if err != nil { 218 return 0, err 219 } 220 // Convert uint64 to uint32 to adapt MySQL protocol. 221 return uint32(connID), nil 222 } 223 224 func (rm *RoutineManager) setBaseService(baseService BaseService) { 225 rm.mu.Lock() 226 defer rm.mu.Unlock() 227 rm.baseService = baseService 228 } 229 230 func (rm *RoutineManager) setSessionMgr(sessionMgr *queryservice.SessionManager) { 231 rm.mu.Lock() 232 defer rm.mu.Unlock() 233 rm.sessionManager = sessionMgr 234 } 235 236 func (rm *RoutineManager) GetAccountRoutineManager() *AccountRoutineManager { 237 return rm.accountRoutine 238 } 239 240 func (rm *RoutineManager) Created(rs goetty.IOSession) { 241 logutil.Debugf("get the connection from %s", rs.RemoteAddress()) 242 createdStart := time.Now() 243 connID, err := rm.getConnID() 244 if err != nil { 245 logutil.Errorf("failed to get connection ID from HAKeeper: %v", err) 246 return 247 } 248 pro := NewMysqlClientProtocol(connID, rs, int(getGlobalPu().SV.MaxBytesInOutbufToFlush), getGlobalPu().SV) 249 routine := NewRoutine(rm.getCtx(), pro, getGlobalPu().SV, rs) 250 v2.CreatedRoutineCounter.Inc() 251 252 cancelCtx := routine.getCancelRoutineCtx() 253 if rm.baseService != nil { 254 cancelCtx = context.WithValue(cancelCtx, defines.NodeIDKey{}, rm.baseService.ID()) 255 } 256 257 // XXX MPOOL pass in a nil mpool. 258 // XXX MPOOL can choose to use a Mid sized mpool, if, we know 259 // this mpool will be deleted. Maybe in the following Closed method. 260 ses := NewSession(cancelCtx, routine.getProtocol(), nil, GSysVariables, true, nil) 261 ses.SetFromRealUser(true) 262 ses.setRoutineManager(rm) 263 ses.setRoutine(routine) 264 ses.clientAddr = pro.Peer() 265 266 ses.timestampMap[TSCreatedStart] = createdStart 267 defer func() { 268 ses.timestampMap[TSCreatedEnd] = time.Now() 269 v2.CreatedDurationHistogram.Observe(ses.timestampMap[TSCreatedEnd].Sub(ses.timestampMap[TSCreatedStart]).Seconds()) 270 }() 271 272 routine.setSession(ses) 273 pro.SetSession(ses) 274 275 logDebugf(pro.GetDebugString(), "have done some preparation for the connection %s", rs.RemoteAddress()) 276 277 // With proxy module enabled, we try to update salt value and label info from proxy. 278 if getGlobalPu().SV.ProxyEnabled { 279 pro.receiveExtraInfo(rs) 280 } 281 282 hsV10pkt := pro.makeHandshakeV10Payload() 283 err = pro.writePackets(hsV10pkt, true) 284 if err != nil { 285 logError(pro.ses, pro.GetDebugString(), 286 "Failed to handshake with server, quitting routine...", 287 zap.Error(err)) 288 routine.killConnection(true) 289 return 290 } 291 292 logDebugf(pro.GetDebugString(), "have sent handshake packet to connection %s", rs.RemoteAddress()) 293 rm.setRoutine(rs, pro.connectionID, routine) 294 } 295 296 /* 297 When the io is closed, the Closed will be called. 298 */ 299 func (rm *RoutineManager) Closed(rs goetty.IOSession) { 300 logutil.Debugf("clean resource of the connection %d:%s", rs.ID(), rs.RemoteAddress()) 301 defer func() { 302 v2.CloseRoutineCounter.Inc() 303 logutil.Debugf("resource of the connection %d:%s has been cleaned", rs.ID(), rs.RemoteAddress()) 304 }() 305 rt := rm.deleteRoutine(rs) 306 307 if rt != nil { 308 ses := rt.getSession() 309 if ses != nil { 310 rt.decreaseCount(func() { 311 account := ses.GetTenantInfo() 312 accountName := sysAccountName 313 if account != nil { 314 accountName = account.GetTenant() 315 } 316 metric.ConnectionCounter(accountName).Dec() 317 rm.accountRoutine.deleteRoutine(int64(account.GetTenantID()), rt) 318 }) 319 rm.sessionManager.RemoveSession(ses) 320 logDebugf(ses.GetDebugString(), "the io session was closed.") 321 } 322 rt.cleanup() 323 } 324 } 325 326 /* 327 kill a connection or query. 328 if killConnection is true, the query will be canceled first, then the network will be closed. 329 if killConnection is false, only the query will be canceled. the connection keeps intact. 330 */ 331 func (rm *RoutineManager) kill(ctx context.Context, killConnection bool, idThatKill, id uint64, statementId string) error { 332 rt := rm.getRoutineByConnID(uint32(id)) 333 334 killMyself := idThatKill == id 335 if rt != nil { 336 if killConnection { 337 logutil.Infof("kill connection %d", id) 338 rt.killConnection(killMyself) 339 rm.accountRoutine.deleteRoutine(int64(rt.ses.GetTenantInfo().GetTenantID()), rt) 340 } else { 341 logutil.Infof("kill query %s on the connection %d", statementId, id) 342 rt.killQuery(killMyself, statementId) 343 } 344 } else { 345 return moerr.NewInternalError(ctx, "Unknown connection id %d", id) 346 } 347 return nil 348 } 349 350 func getConnectionInfo(rs goetty.IOSession) string { 351 conn := rs.RawConn() 352 if conn != nil { 353 return fmt.Sprintf("connection from %s to %s", conn.RemoteAddr(), conn.LocalAddr()) 354 } 355 return fmt.Sprintf("connection from %s", rs.RemoteAddress()) 356 } 357 358 func (rm *RoutineManager) Handler(rs goetty.IOSession, msg interface{}, received uint64) error { 359 logutil.Debugf("get request from %d:%s", rs.ID(), rs.RemoteAddress()) 360 defer func() { 361 logutil.Debugf("request from %d:%s has been processed", rs.ID(), rs.RemoteAddress()) 362 }() 363 var err error 364 var isTlsHeader bool 365 ctx, span := trace.Start(rm.getCtx(), "RoutineManager.Handler", 366 trace.WithKind(trace.SpanKindStatement)) 367 defer span.End() 368 connectionInfo := getConnectionInfo(rs) 369 routine := rm.getRoutine(rs) 370 if routine == nil { 371 err = moerr.NewInternalError(ctx, "routine does not exist") 372 logutil.Errorf("%s error:%v", connectionInfo, err) 373 return err 374 } 375 routine.updateGoroutineId() 376 routine.setInProcessRequest(true) 377 defer routine.setInProcessRequest(false) 378 protocol := routine.getProtocol() 379 protoInfo := protocol.GetDebugString() 380 packet, ok := msg.(*Packet) 381 382 protocol.SetSequenceID(uint8(packet.SequenceID + 1)) 383 var seq = protocol.GetSequenceId() 384 if !ok { 385 err = moerr.NewInternalError(ctx, "message is not Packet") 386 logError(routine.ses, routine.ses.GetDebugString(), 387 "Error occurred", 388 zap.Error(err)) 389 return err 390 } 391 392 ses := routine.getSession() 393 ts := ses.timestampMap 394 395 length := packet.Length 396 payload := packet.Payload 397 for uint32(length) == MaxPayloadSize { 398 msg, err = protocol.GetTcpConnection().Read(goetty.ReadOptions{}) 399 if err != nil { 400 logError(routine.ses, routine.ses.GetDebugString(), 401 "Failed to read message", 402 zap.Error(err)) 403 return err 404 } 405 406 packet, ok = msg.(*Packet) 407 if !ok { 408 err = moerr.NewInternalError(ctx, "message is not Packet") 409 logError(routine.ses, routine.ses.GetDebugString(), 410 "An error occurred", 411 zap.Error(err)) 412 return err 413 } 414 415 protocol.SetSequenceID(uint8(packet.SequenceID + 1)) 416 seq = protocol.GetSequenceId() 417 payload = append(payload, packet.Payload...) 418 length = packet.Length 419 } 420 421 // finish handshake process 422 if !protocol.IsEstablished() { 423 tempCtx, tempCancel := context.WithTimeout(ctx, getGlobalPu().SV.SessionTimeout.Duration) 424 defer tempCancel() 425 ts[TSEstablishStart] = time.Now() 426 logDebugf(protoInfo, "HANDLE HANDSHAKE") 427 428 /* 429 di := MakeDebugInfo(payload,80,8) 430 logutil.Infof("RP[%v] Payload80[%v]",rs.RemoteAddr(),di) 431 */ 432 if protocol.GetCapability()&CLIENT_SSL != 0 && !protocol.IsTlsEstablished() { 433 logDebugf(protoInfo, "setup ssl") 434 isTlsHeader, err = protocol.HandleHandshake(tempCtx, payload) 435 if err != nil { 436 logError(routine.ses, routine.ses.GetDebugString(), 437 "An error occurred", 438 zap.Error(err)) 439 return err 440 } 441 if isTlsHeader { 442 ts[TSUpgradeTLSStart] = time.Now() 443 logDebugf(protoInfo, "upgrade to TLS") 444 // do upgradeTls 445 tlsConn := tls.Server(rs.RawConn(), rm.getTlsConfig()) 446 logDebugf(protoInfo, "get TLS conn ok") 447 tlsCtx, cancelFun := context.WithTimeout(tempCtx, 20*time.Second) 448 if err = tlsConn.HandshakeContext(tlsCtx); err != nil { 449 logError(routine.ses, routine.ses.GetDebugString(), 450 "Error occurred before cancel()", 451 zap.Error(err)) 452 cancelFun() 453 logError(routine.ses, routine.ses.GetDebugString(), 454 "Error occurred after cancel()", 455 zap.Error(err)) 456 return err 457 } 458 cancelFun() 459 logDebug(routine.ses, protoInfo, "TLS handshake ok") 460 rs.UseConn(tlsConn) 461 logDebug(routine.ses, protoInfo, "TLS handshake finished") 462 463 // tls upgradeOk 464 protocol.SetTlsEstablished() 465 ts[TSUpgradeTLSEnd] = time.Now() 466 v2.UpgradeTLSDurationHistogram.Observe(ts[TSUpgradeTLSEnd].Sub(ts[TSUpgradeTLSStart]).Seconds()) 467 } else { 468 // client don't ask server to upgrade TLS 469 if err := protocol.Authenticate(tempCtx); err != nil { 470 return err 471 } 472 protocol.SetTlsEstablished() 473 protocol.SetEstablished() 474 } 475 } else { 476 logDebugf(protoInfo, "handleHandshake") 477 _, err = protocol.HandleHandshake(tempCtx, payload) 478 if err != nil { 479 logError(routine.ses, routine.ses.GetDebugString(), 480 "Error occurred", 481 zap.Error(err)) 482 return err 483 } 484 if err = protocol.Authenticate(tempCtx); err != nil { 485 return err 486 } 487 protocol.SetEstablished() 488 } 489 ts[TSEstablishEnd] = time.Now() 490 v2.EstablishDurationHistogram.Observe(ts[TSEstablishEnd].Sub(ts[TSEstablishStart]).Seconds()) 491 logInfof(ses.GetDebugString(), fmt.Sprintf("mo accept connection, time cost of Created: %s, Establish: %s, UpgradeTLS: %s, Authenticate: %s, SendErrPacket: %s, SendOKPacket: %s, CheckTenant: %s, CheckUser: %s, CheckRole: %s, CheckDbName: %s, InitGlobalSysVar: %s", 492 ts[TSCreatedEnd].Sub(ts[TSCreatedStart]).String(), 493 ts[TSEstablishEnd].Sub(ts[TSEstablishStart]).String(), 494 ts[TSUpgradeTLSEnd].Sub(ts[TSUpgradeTLSStart]).String(), 495 ts[TSAuthenticateEnd].Sub(ts[TSAuthenticateStart]).String(), 496 ts[TSSendErrPacketEnd].Sub(ts[TSSendErrPacketStart]).String(), 497 ts[TSSendOKPacketEnd].Sub(ts[TSSendOKPacketStart]).String(), 498 ts[TSCheckTenantEnd].Sub(ts[TSCheckTenantStart]).String(), 499 ts[TSCheckUserEnd].Sub(ts[TSCheckUserStart]).String(), 500 ts[TSCheckRoleEnd].Sub(ts[TSCheckRoleStart]).String(), 501 ts[TSCheckDbNameEnd].Sub(ts[TSCheckDbNameStart]).String(), 502 ts[TSInitGlobalSysVarEnd].Sub(ts[TSInitGlobalSysVarStart]).String())) 503 504 dbName := protocol.GetDatabaseName() 505 if dbName != "" { 506 ses.SetDatabaseName(dbName) 507 } 508 rm.sessionManager.AddSession(ses) 509 return nil 510 } 511 512 req := protocol.GetRequest(payload) 513 req.seq = seq 514 515 //handle request 516 err = routine.handleRequest(req) 517 if err != nil { 518 if !skipClientQuit(err.Error()) { 519 logError(routine.ses, routine.ses.GetDebugString(), 520 "Error occurred", 521 zap.Error(err)) 522 } 523 return err 524 } 525 526 return nil 527 } 528 529 // clientCount returns the count of the clients 530 func (rm *RoutineManager) clientCount() int { 531 var count int 532 rm.mu.RLock() 533 defer rm.mu.RUnlock() 534 count = len(rm.clients) 535 return count 536 } 537 538 func (rm *RoutineManager) cleanKillQueue() { 539 ar := rm.accountRoutine 540 ar.killQueueMu.Lock() 541 defer ar.killQueueMu.Unlock() 542 for toKillAccount, killRecord := range ar.killIdQueue { 543 if time.Since(killRecord.killTime) > time.Duration(getGlobalPu().SV.CleanKillQueueInterval)*time.Minute { 544 delete(ar.killIdQueue, toKillAccount) 545 } 546 } 547 } 548 549 func (rm *RoutineManager) KillRoutineConnections() { 550 ar := rm.accountRoutine 551 tempKillQueue := ar.deepCopyKillQueue() 552 accountId2RoutineMap := ar.deepCopyRoutineMap() 553 554 for account, killRecord := range tempKillQueue { 555 if rtMap, ok := accountId2RoutineMap[account]; ok { 556 for rt, version := range rtMap { 557 if rt != nil && ((version+1)%math.MaxUint64)-1 <= killRecord.version { 558 //kill connect of this routine 559 rt.killConnection(false) 560 ar.deleteRoutine(account, rt) 561 } 562 } 563 } 564 } 565 566 rm.cleanKillQueue() 567 } 568 569 func (rm *RoutineManager) MigrateConnectionTo(ctx context.Context, req *query.MigrateConnToRequest) error { 570 routine := rm.getRoutineByConnID(req.ConnID) 571 if routine == nil { 572 return moerr.NewInternalError(ctx, "cannot get routine to migrate connection %d", req.ConnID) 573 } 574 return routine.migrateConnectionTo(ctx, req) 575 } 576 577 func (rm *RoutineManager) MigrateConnectionFrom(req *query.MigrateConnFromRequest, resp *query.MigrateConnFromResponse) error { 578 routine := rm.getRoutineByConnID(req.ConnID) 579 if routine == nil { 580 return moerr.NewInternalError(rm.ctx, "cannot get routine to migrate connection %d", req.ConnID) 581 } 582 return routine.migrateConnectionFrom(resp) 583 } 584 585 func NewRoutineManager(ctx context.Context) (*RoutineManager, error) { 586 accountRoutine := &AccountRoutineManager{ 587 killQueueMu: sync.RWMutex{}, 588 accountId2Routine: make(map[int64]map[*Routine]uint64), 589 accountRoutineMu: sync.RWMutex{}, 590 killIdQueue: make(map[int64]KillRecord), 591 ctx: ctx, 592 } 593 rm := &RoutineManager{ 594 ctx: ctx, 595 clients: make(map[goetty.IOSession]*Routine), 596 routinesByConnID: make(map[uint32]*Routine), 597 accountRoutine: accountRoutine, 598 } 599 if getGlobalPu().SV.EnableTls { 600 err := initTlsConfig(rm, getGlobalPu().SV) 601 if err != nil { 602 return nil, err 603 } 604 } 605 606 // add kill connect routine 607 go func() { 608 for { 609 select { 610 case <-rm.ctx.Done(): 611 return 612 default: 613 } 614 rm.KillRoutineConnections() 615 time.Sleep(time.Duration(time.Duration(getGlobalPu().SV.KillRountinesInterval) * time.Second)) 616 } 617 }() 618 619 return rm, nil 620 } 621 622 func initTlsConfig(rm *RoutineManager, SV *config.FrontendParameters) error { 623 if len(SV.TlsCertFile) == 0 || len(SV.TlsKeyFile) == 0 { 624 return moerr.NewInternalError(rm.ctx, "init TLS config error : cert file or key file is empty") 625 } 626 627 cfg, err := ConstructTLSConfig(rm.ctx, SV.TlsCaFile, SV.TlsCertFile, SV.TlsKeyFile) 628 if err != nil { 629 return moerr.NewInternalError(rm.ctx, "init TLS config error: %v", err) 630 } 631 632 rm.tlsConfig = cfg 633 logutil.Info("init TLS config finished") 634 return nil 635 } 636 637 // ConstructTLSConfig creates the TLS config. 638 func ConstructTLSConfig(ctx context.Context, caFile, certFile, keyFile string) (*tls.Config, error) { 639 var err error 640 var tlsCert tls.Certificate 641 642 tlsCert, err = tls.LoadX509KeyPair(certFile, keyFile) 643 if err != nil { 644 return nil, moerr.NewInternalError(ctx, "construct TLS config error: load x509 failed") 645 } 646 647 clientAuthPolicy := tls.NoClientCert 648 var certPool *x509.CertPool 649 if len(caFile) > 0 { 650 var caCert []byte 651 caCert, err = os.ReadFile(caFile) 652 if err != nil { 653 return nil, moerr.NewInternalError(ctx, "construct TLS config error: read TLS ca failed") 654 } 655 certPool = x509.NewCertPool() 656 if certPool.AppendCertsFromPEM(caCert) { 657 clientAuthPolicy = tls.VerifyClientCertIfGiven 658 } 659 } 660 661 return &tls.Config{ 662 Certificates: []tls.Certificate{tlsCert}, 663 ClientCAs: certPool, 664 ClientAuth: clientAuthPolicy, 665 }, nil 666 }