github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/sqle/cluster/controller.go

// Copyright 2022 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package cluster

import (
	"context"
	"crypto/ed25519"
	"crypto/rand"
	"crypto/tls"
	"crypto/x509"
	"errors"
	"fmt"
	"net/http"
	"net/url"
	"os"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/cenkalti/backoff/v4"
	"github.com/dolthub/go-mysql-server/sql"
	"github.com/dolthub/go-mysql-server/sql/mysql_db"
	gmstypes "github.com/dolthub/go-mysql-server/sql/types"
	"github.com/sirupsen/logrus"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/status"

	replicationapi "github.com/dolthub/dolt/go/gen/proto/dolt/services/replicationapi/v1alpha1"
	"github.com/dolthub/dolt/go/libraries/doltcore/branch_control"
	"github.com/dolthub/dolt/go/libraries/doltcore/creds"
	"github.com/dolthub/dolt/go/libraries/doltcore/dbfactory"
	"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
	"github.com/dolthub/dolt/go/libraries/doltcore/env"
	"github.com/dolthub/dolt/go/libraries/doltcore/remotesrv"
	"github.com/dolthub/dolt/go/libraries/doltcore/servercfg"
	"github.com/dolthub/dolt/go/libraries/doltcore/sqle"
	"github.com/dolthub/dolt/go/libraries/doltcore/sqle/clusterdb"
	"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
	"github.com/dolthub/dolt/go/libraries/utils/config"
	"github.com/dolthub/dolt/go/libraries/utils/filesys"
	"github.com/dolthub/dolt/go/libraries/utils/jwtauth"
	"github.com/dolthub/dolt/go/store/types"
)

type Role string

const RolePrimary Role = "primary"
const RoleStandby Role = "standby"
const RoleDetectedBrokenConfig Role = "detected_broken_config"

const PersistentConfigPrefix = "sqlserver.cluster"
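
// The three roles, as used throughout this file:
//
//   - RolePrimary: the server accepts writes and replicates committed heads,
//     mysql_db contents (users and grants), and branch control data to its
//     configured standby remotes.
//   - RoleStandby: the server is read-only and receives replicated data from
//     the current primary.
//   - RoleDetectedBrokenConfig: assumed when the interceptors or other signals
//     indicate the cluster is misconfigured; it is handled like an immediate
//     transition to standby.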

// State for any ongoing DROP DATABASE replication attempts we have
// outstanding. When we create a database, we cancel all ongoing DROP DATABASE
// replication attempts.
type databaseDropReplication struct {
	ctx    context.Context
	cancel func()
	wg     *sync.WaitGroup
}

type Controller struct {
	cfg           servercfg.ClusterConfig
	persistentCfg config.ReadWriteConfig
	role          Role
	epoch         int
	systemVars    sqlvars
	mu            sync.Mutex
	commithooks   []*commithook
	sinterceptor  serverinterceptor
	cinterceptor  clientinterceptor
	lgr           *logrus.Logger

	standbyCallback IsStandbyCallback
	iterSessions    IterSessions
	killQuery       func(uint32)
	killConnection  func(uint32) error

	jwks      *jwtauth.MultiJWKS
	tlsCfg    *tls.Config
	grpcCreds credentials.PerRPCCredentials
	pub       ed25519.PublicKey
	priv      ed25519.PrivateKey

	replicationClients []*replicationServiceClient

	mysqlDb          *mysql_db.MySQLDb
	mysqlDbPersister *replicatingMySQLDbPersister
	mysqlDbReplicas  []*mysqlDbReplica

	branchControlController *branch_control.Controller
	branchControlFilesys    filesys.Filesys
	bcReplication           *branchControlReplication

	dropDatabase             func(*sql.Context, string) error
	outstandingDropDatabases map[string]*databaseDropReplication
	remoteSrvDBCache         remotesrv.DBCache
}

type sqlvars interface {
	AddSystemVariables(sysVars []sql.SystemVariable)
	GetGlobal(name string) (sql.SystemVariable, interface{}, bool)
}

// Our IsStandbyCallback gets called with |true| or |false| when the server
// becomes a standby or a primary respectively. Standby replicas should be read
// only.
type IsStandbyCallback func(bool)

type procedurestore interface {
	Register(sql.ExternalStoredProcedureDetails)
}

const (
	// Since we fetch the keys from the other replicas we’re going to use a fixed string here.
	DoltClusterRemoteApiAudience = "dolt-cluster-remote-api.dolthub.com"
)

func NewController(lgr *logrus.Logger, cfg servercfg.ClusterConfig, pCfg config.ReadWriteConfig) (*Controller, error) {
	if cfg == nil {
		return nil, nil
	}
	pCfg = config.NewPrefixConfig(pCfg, PersistentConfigPrefix)
	role, epoch, err := applyBootstrapClusterConfig(lgr, cfg, pCfg)
	if err != nil {
		return nil, err
	}
	ret := &Controller{
		cfg:           cfg,
		persistentCfg: pCfg,
		role:          role,
		epoch:         epoch,
		commithooks:   make([]*commithook, 0),
		lgr:           lgr,
	}
	roleSetter := func(role string, epoch int) {
		ret.setRoleAndEpoch(role, epoch, roleTransitionOptions{
			graceful: false,
		})
	}
	ret.sinterceptor.lgr = lgr.WithFields(logrus.Fields{})
	ret.sinterceptor.setRole(role, epoch)
	ret.sinterceptor.roleSetter = roleSetter
	ret.cinterceptor.lgr = lgr.WithFields(logrus.Fields{})
	ret.cinterceptor.setRole(role, epoch)
	ret.cinterceptor.roleSetter = roleSetter

	ret.tlsCfg, err = ret.outboundTlsConfig()
	if err != nil {
		return nil, err
	}

	ret.pub, ret.priv, err = ed25519.GenerateKey(rand.Reader)
	if err != nil {
		return nil, err
	}

	keyID := creds.PubKeyToKID(ret.pub)
	keyIDStr := creds.B32CredsEncoding.EncodeToString(keyID)
	ret.grpcCreds = &creds.RPCCreds{
		PrivKey:    ret.priv,
		Audience:   DoltClusterRemoteApiAudience,
		Issuer:     creds.ClientIssuer,
		KeyID:      keyIDStr,
		RequireTLS: false,
	}

	ret.jwks = ret.standbyRemotesJWKS()
	ret.sinterceptor.keyProvider = ret.jwks
	ret.sinterceptor.jwtExpected = JWTExpectations()

	ret.replicationClients, err = ret.replicationServiceClients(context.Background())
	if err != nil {
		return nil, err
	}
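
	// One replication worker per standby remote for mysql_db (users and
	// grants) contents. Each worker retries failed pushes with its own
	// exponential backoff: 1s initial interval, capped at 1m, with no
	// overall deadline.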
	ret.mysqlDbReplicas = make([]*mysqlDbReplica, len(ret.replicationClients))
	for i := range ret.mysqlDbReplicas {
		bo := backoff.NewExponentialBackOff()
		bo.InitialInterval = time.Second
		bo.MaxInterval = time.Minute
		bo.MaxElapsedTime = 0
		ret.mysqlDbReplicas[i] = &mysqlDbReplica{
			lgr:     lgr.WithFields(logrus.Fields{}),
			client:  ret.replicationClients[i],
			backoff: bo,
		}
		ret.mysqlDbReplicas[i].cond = sync.NewCond(&ret.mysqlDbReplicas[i].mu)
	}

	ret.outstandingDropDatabases = make(map[string]*databaseDropReplication)

	return ret, nil
}

func (c *Controller) Run() {
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		c.jwks.Run()
	}()
	wg.Add(1)
	go func() {
		defer wg.Done()
		c.mysqlDbPersister.Run()
	}()
	wg.Add(1)
	go func() {
		defer wg.Done()
		c.bcReplication.Run()
	}()
	wg.Wait()
	for _, client := range c.replicationClients {
		client.closer()
	}
}

func (c *Controller) GracefulStop() error {
	c.jwks.GracefulStop()
	c.mysqlDbPersister.GracefulStop()
	c.bcReplication.GracefulStop()
	return nil
}

func (c *Controller) ManageSystemVariables(variables sqlvars) {
	if c == nil {
		return
	}
	c.mu.Lock()
	defer c.mu.Unlock()
	c.systemVars = variables
	c.refreshSystemVars()
}

func (c *Controller) ApplyStandbyReplicationConfig(ctx context.Context, bt *sql.BackgroundThreads, mrEnv *env.MultiRepoEnv, dbs ...dsess.SqlDatabase) error {
	if c == nil {
		return nil
	}
	c.mu.Lock()
	defer c.mu.Unlock()
	for _, db := range dbs {
		denv := mrEnv.GetEnv(db.Name())
		if denv == nil {
			continue
		}
		c.lgr.Tracef("cluster/controller: applying commit hooks for %s with role %s", db.Name(), string(c.role))
		hooks, err := c.applyCommitHooks(ctx, db.Name(), bt, denv)
		if err != nil {
			return err
		}
		c.commithooks = append(c.commithooks, hooks...)
	}
	return nil
}

type IterSessions func(func(sql.Session) (bool, error)) error

func (c *Controller) SetIsStandbyCallback(callback IsStandbyCallback) {
	if c == nil {
		return
	}
	c.mu.Lock()
	defer c.mu.Unlock()
	c.standbyCallback = callback
	c.setProviderIsStandby(c.role != RolePrimary)
}

func (c *Controller) ManageQueryConnections(iterSessions IterSessions, killQuery func(uint32), killConnection func(uint32) error) {
	if c == nil {
		return
	}
	c.mu.Lock()
	defer c.mu.Unlock()
	c.iterSessions = iterSessions
	c.killQuery = killQuery
	c.killConnection = killConnection
}

func (c *Controller) applyCommitHooks(ctx context.Context, name string, bt *sql.BackgroundThreads, denv *env.DoltEnv) ([]*commithook, error) {
	ttfdir, err := denv.TempTableFilesDir()
	if err != nil {
		return nil, err
	}
	remotes, err := denv.GetRemotes()
	if err != nil {
		return nil, err
	}
	dialprovider := c.gRPCDialProvider(denv)
	var hooks []*commithook
	for _, r := range c.cfg.StandbyRemotes() {
		remoteUrl := strings.Replace(r.RemoteURLTemplate(), dsess.URLTemplateDatabasePlaceholder, name, -1)
		remote, ok := remotes.Get(r.Name())
		if !ok {
			remote = env.NewRemote(r.Name(), remoteUrl, nil)
			err := denv.AddRemote(remote)
			if err != nil {
				return nil, fmt.Errorf("sqle: cluster: standby replication: could not create remote %s for database %s: %w", r.Name(), name, err)
			}
		}
		commitHook := newCommitHook(c.lgr, r.Name(), remote.Url, name, c.role, func(ctx context.Context) (*doltdb.DoltDB, error) {
			return remote.GetRemoteDB(ctx, types.Format_Default, dialprovider)
		}, denv.DoltDB, ttfdir)
		denv.DoltDB.PrependCommitHook(ctx, commitHook)
		if err := commitHook.Run(bt); err != nil {
			return nil, err
		}
		hooks = append(hooks, commitHook)
	}
	return hooks, nil
}

func (c *Controller) gRPCDialProvider(denv *env.DoltEnv) dbfactory.GRPCDialProvider {
	return grpcDialProvider{env.NewGRPCDialProviderFromDoltEnv(denv), &c.cinterceptor, c.tlsCfg, c.grpcCreds}
}

func (c *Controller) RegisterStoredProcedures(store procedurestore) {
	if c == nil {
		return
	}
	store.Register(newAssumeRoleProcedure(c))
	store.Register(newTransitionToStandbyProcedure(c))
}

// Incoming drop database replication requests need a way to drop a database in
// the sqle.DatabaseProvider. This is our callback for that functionality.
func (c *Controller) SetDropDatabase(dropDatabase func(*sql.Context, string) error) {
	if c == nil {
		return
	}
	c.mu.Lock()
	defer c.mu.Unlock()
	c.dropDatabase = dropDatabase
}

// DropDatabaseHook gets called when the database provider drops a
// database. This is how we learn that we need to replicate a drop database.
func (c *Controller) DropDatabaseHook() func(*sql.Context, string) {
	return c.dropDatabaseHook
}

func (c *Controller) dropDatabaseHook(_ *sql.Context, dbname string) {
	c.mu.Lock()
	defer c.mu.Unlock()

	// We always clean up the commithooks associated with that database.
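	// Hooks for the dropped database are signaled with databaseWasDropped()
	// so their background work can wind down, and the remaining hooks are
	// compacted in place into the front of the slice.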

	j := 0
	for i := 0; i < len(c.commithooks); i++ {
		if c.commithooks[i].dbname == dbname {
			c.commithooks[i].databaseWasDropped()
			continue
		}
		if j != i {
			c.commithooks[j] = c.commithooks[i]
		}
		j += 1
	}
	c.commithooks = c.commithooks[:j]

	if c.role != RolePrimary {
		return
	}

	// If we are the primary, we will replicate the drop to our standby replicas.

	ctx, cancel := context.WithCancel(context.Background())
	wg := &sync.WaitGroup{}
	wg.Add(len(c.replicationClients))
	state := &databaseDropReplication{
		ctx:    ctx,
		cancel: cancel,
		wg:     wg,
	}
	c.outstandingDropDatabases[dbname] = state

	for _, client := range c.replicationClients {
		client := client
		go c.replicateDropDatabase(state, client, dbname)
	}
}

func (c *Controller) cancelDropDatabaseReplication(dbname string) {
	c.mu.Lock()
	defer c.mu.Unlock()
	if s := c.outstandingDropDatabases[dbname]; s != nil {
		s.cancel()
		s.wg.Wait()
	}
}

func (c *Controller) replicateDropDatabase(s *databaseDropReplication, client *replicationServiceClient, dbname string) {
	defer s.wg.Done()
	bo := backoff.NewExponentialBackOff()
	bo.InitialInterval = time.Millisecond
	bo.MaxInterval = time.Minute
	bo.MaxElapsedTime = 0
	for {
		if s.ctx.Err() != nil {
			return
		}
		ctx, cancel := context.WithTimeout(s.ctx, 15*time.Second)
		_, err := client.client.DropDatabase(ctx, &replicationapi.DropDatabaseRequest{
			Name: dbname,
		})
		cancel()
		if err == nil {
			c.lgr.Tracef("successfully replicated drop of [%s] to %s", dbname, client.remote)
			return
		}
		if status.Code(err) == codes.FailedPrecondition {
			c.lgr.Warnf("drop of [%s] to %s will not be replicated; FailedPrecondition", dbname, client.remote)
			return
		}
		c.lgr.Warnf("failed to replicate drop of [%s] to %s: %v", dbname, client.remote, err)
		if s.ctx.Err() != nil {
			return
		}
		d := bo.NextBackOff()
		c.lgr.Tracef("sleeping %v before next drop attempt for database [%s] at %s", d, dbname, client.remote)
		select {
		case <-time.After(d):
		case <-s.ctx.Done():
			return
		}
	}
}

func (c *Controller) ClusterDatabase() sql.Database {
	if c == nil {
		return nil
	}
	return clusterdb.NewClusterDatabase(c)
}

func (c *Controller) RemoteSrvListenAddr() string {
	if c == nil {
		return ""
	}
	return fmt.Sprintf("%s:%d", c.cfg.RemotesAPIConfig().Address(), c.cfg.RemotesAPIConfig().Port())
}

func (c *Controller) ServerOptions() []grpc.ServerOption {
	return c.sinterceptor.Options()
}

func (c *Controller) refreshSystemVars() {
	role, epoch := string(c.role), c.epoch
	vars := []sql.SystemVariable{
		&sql.MysqlSystemVariable{
			Name:    dsess.DoltClusterRoleVariable,
			Dynamic: false,
			Scope:   sql.GetMysqlScope(sql.SystemVariableScope_Persist),
			Type:    gmstypes.NewSystemStringType(dsess.DoltClusterRoleVariable),
			Default: role,
		},
		&sql.MysqlSystemVariable{
			Name:    dsess.DoltClusterRoleEpochVariable,
			Dynamic: false,
			Scope:   sql.GetMysqlScope(sql.SystemVariableScope_Persist),
			Type:    gmstypes.NewSystemIntType(dsess.DoltClusterRoleEpochVariable, 0, 9223372036854775807, false),
			Default: epoch,
		},
	}
	c.systemVars.AddSystemVariables(vars)
}
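
// The role and epoch are surfaced to SQL clients as read-only (Dynamic: false)
// global system variables, so a client can inspect the current state by
// selecting @@GLOBAL.dolt_cluster_role and @@GLOBAL.dolt_cluster_role_epoch
// (the exact names are whatever dsess.DoltClusterRoleVariable and
// dsess.DoltClusterRoleEpochVariable are defined as).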

func (c *Controller) persistVariables() error {
	toset := make(map[string]string)
	toset[dsess.DoltClusterRoleVariable] = string(c.role)
	toset[dsess.DoltClusterRoleEpochVariable] = strconv.Itoa(c.epoch)
	return c.persistentCfg.SetStrings(toset)
}

func applyBootstrapClusterConfig(lgr *logrus.Logger, cfg servercfg.ClusterConfig, pCfg config.ReadWriteConfig) (Role, int, error) {
	toset := make(map[string]string)
	persistentRole := pCfg.GetStringOrDefault(dsess.DoltClusterRoleVariable, "")
	var roleFromPersistentConfig bool
	persistentEpoch := pCfg.GetStringOrDefault(dsess.DoltClusterRoleEpochVariable, "")
	if persistentRole == "" {
		if cfg.BootstrapRole() != "" {
			lgr.Tracef("cluster/controller: persisted cluster role was empty, applying bootstrap_role %s", cfg.BootstrapRole())
			persistentRole = cfg.BootstrapRole()
		} else {
			lgr.Trace("cluster/controller: persisted cluster role was empty, bootstrap_role was empty: defaulted to primary")
			persistentRole = "primary"
		}
		toset[dsess.DoltClusterRoleVariable] = persistentRole
	} else {
		roleFromPersistentConfig = true
		lgr.Tracef("cluster/controller: persisted cluster role is %s", persistentRole)
	}
	if persistentEpoch == "" {
		persistentEpoch = strconv.Itoa(cfg.BootstrapEpoch())
		lgr.Tracef("cluster/controller: persisted cluster role epoch is empty, took bootstrap_epoch: %s", persistentEpoch)
		toset[dsess.DoltClusterRoleEpochVariable] = persistentEpoch
	} else {
		lgr.Tracef("cluster/controller: persisted cluster role epoch is %s", persistentEpoch)
	}
	if persistentRole != string(RolePrimary) && persistentRole != string(RoleStandby) {
		isallowed := persistentRole == string(RoleDetectedBrokenConfig) && roleFromPersistentConfig
		if !isallowed {
			return "", 0, fmt.Errorf("persisted role %s.%s = %s must be \"primary\" or \"standby\"", PersistentConfigPrefix, dsess.DoltClusterRoleVariable, persistentRole)
		}
	}
	epochi, err := strconv.Atoi(persistentEpoch)
	if err != nil {
		return "", 0, fmt.Errorf("persisted role epoch %s.%s = %s must be an integer", PersistentConfigPrefix, dsess.DoltClusterRoleEpochVariable, persistentEpoch)
	}
	if len(toset) > 0 {
		err := pCfg.SetStrings(toset)
		if err != nil {
			return "", 0, err
		}
	}
	return Role(persistentRole), epochi, nil
}

type roleTransitionOptions struct {
	// If true, all standby replicas must be caught up in order to
	// transition from primary to standby.
	graceful bool

	// If non-zero and |graceful| is true, will allow a transition from
	// primary to standby to succeed only if this many standby replicas
	// are known to be caught up at the finalization of the replication
	// hooks.
	minCaughtUpStandbys int

	// If non-nil, this connection will be saved if and when the
	// transition needs to terminate existing connections.
	saveConnID *int
}
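
// As an illustrative sketch only (no caller in this file passes exactly these
// values), a graceful failover driven from a SQL session might look like:
//
//	connID := int(ctx.Session.ID())
//	res, err := c.setRoleAndEpoch("standby", currentEpoch+1, roleTransitionOptions{
//		graceful:            true,
//		minCaughtUpStandbys: 1,
//		saveConnID:          &connID,
//	})
//
// whereas the interceptors always request abrupt transitions (graceful: false),
// as in the roleSetter closure in NewController.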

type roleTransitionResult struct {
	// true if the role changed as a result of this call.
	changedRole bool

	// filled in with graceful transition results if this was a graceful
	// transition and it was successful.
	gracefulTransitionResults []graceTransitionResult
}
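
// setRoleAndEpoch validates the requested (role, epoch) pair against the
// current state, performs the transition itself (graceful or immediate, to
// primary or to standby), and then fans the new role out to the system
// variables, both interceptors, the commit hooks, the mysql_db persister and
// branch control replication before persisting it. It reports whether the
// role actually changed and, for a successful graceful transition, the
// per-database catch-up results.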
func (c *Controller) setRoleAndEpoch(role string, epoch int, opts roleTransitionOptions) (roleTransitionResult, error) {
	graceful := opts.graceful
	saveConnID := -1
	if opts.saveConnID != nil {
		saveConnID = *opts.saveConnID
	}

	c.mu.Lock()
	defer c.mu.Unlock()
	if epoch == c.epoch && role == string(c.role) {
		return roleTransitionResult{false, nil}, nil
	}

	if role != string(RolePrimary) && role != string(RoleStandby) && role != string(RoleDetectedBrokenConfig) {
		return roleTransitionResult{false, nil}, fmt.Errorf("error assuming role '%s'; valid roles are 'primary' and 'standby'", role)
	}

	if epoch < c.epoch {
		return roleTransitionResult{false, nil}, fmt.Errorf("error assuming role '%s' at epoch %d; already at epoch %d", role, epoch, c.epoch)
	}
	if epoch == c.epoch {
		// This is allowed for non-graceful transitions to 'standby', which only occur from interceptors and
		// other signals that the cluster is misconfigured.
		isallowed := !graceful && (role == string(RoleStandby) || role == string(RoleDetectedBrokenConfig))
		if !isallowed {
			return roleTransitionResult{false, nil}, fmt.Errorf("error assuming role '%s' at epoch %d; already at epoch %d with different role, '%s'", role, epoch, c.epoch, c.role)
		}
	}

	changedrole := role != string(c.role)
	var gracefulResults []graceTransitionResult

	if changedrole {
		var err error
		if role == string(RoleStandby) {
			if graceful {
				beforeRole, beforeEpoch := c.role, c.epoch
				gracefulResults, err = c.gracefulTransitionToStandby(saveConnID, opts.minCaughtUpStandbys)
				if err == nil && (beforeRole != c.role || beforeEpoch != c.epoch) {
					// The role or epoch moved out from under us while we were unlocked and transitioning to standby.
					err = fmt.Errorf("error assuming role '%s' at epoch %d: the role configuration changed while we were replicating to our standbys. Please try again", role, epoch)
				}
				if err != nil {
					c.setProviderIsStandby(c.role != RolePrimary)
					c.killRunningQueries(saveConnID)
					return roleTransitionResult{false, nil}, err
				}
			} else {
				c.immediateTransitionToStandby()
			}
		} else if role == string(RoleDetectedBrokenConfig) {
			c.immediateTransitionToStandby()
		} else {
			c.transitionToPrimary(saveConnID)
		}
	}

	c.role = Role(role)
	c.epoch = epoch

	c.refreshSystemVars()
	c.cinterceptor.setRole(c.role, c.epoch)
	c.sinterceptor.setRole(c.role, c.epoch)
	if changedrole {
		for _, h := range c.commithooks {
			h.setRole(c.role)
		}
		c.mysqlDbPersister.setRole(c.role)
		c.bcReplication.setRole(c.role)
	}
	_ = c.persistVariables()
	return roleTransitionResult{
		changedRole:               changedrole,
		gracefulTransitionResults: gracefulResults,
	}, nil
}

func (c *Controller) roleAndEpoch() (Role, int) {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.role, c.epoch
}

func (c *Controller) registerCommitHook(hook *commithook) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.commithooks = append(c.commithooks, hook)
}

func (c *Controller) GetClusterStatus() []clusterdb.ReplicaStatus {
	if c == nil {
		return []clusterdb.ReplicaStatus{}
	}
	c.mu.Lock()
	epoch, role := c.epoch, c.role
	commithooks := make([]*commithook, len(c.commithooks))
	copy(commithooks, c.commithooks)
	c.mu.Unlock()
	ret := make([]clusterdb.ReplicaStatus, len(commithooks))
	for i, c := range commithooks {
		lag, lastUpdate, currentErrorStr := c.status()
		ret[i] = clusterdb.ReplicaStatus{
			Database:       c.dbname,
			Remote:         c.remotename,
			Role:           string(role),
			Epoch:          epoch,
			ReplicationLag: lag,
			LastUpdate:     lastUpdate,
			CurrentError:   currentErrorStr,
		}
	}
	return ret
}

func (c *Controller) recordSuccessfulRemoteSrvCommit(name string) {
	c.lgr.Tracef("standby replica received push and updated database %s", name)
	c.mu.Lock()
	commithooks := make([]*commithook, len(c.commithooks))
	copy(commithooks, c.commithooks)
	c.mu.Unlock()
	for _, c := range commithooks {
		if c.dbname == name {
			c.recordSuccessfulRemoteSrvCommit()
		}
	}
}

func (c *Controller) RemoteSrvServerArgs(ctxFactory func(context.Context) (*sql.Context, error), args remotesrv.ServerArgs) (remotesrv.ServerArgs, error) {
	c.mu.Lock()
	defer c.mu.Unlock()
	listenaddr := c.RemoteSrvListenAddr()
	args.HttpListenAddr = listenaddr
	args.GrpcListenAddr = listenaddr
	args.Options = c.ServerOptions()
	var err error
	args.FS, args.DBCache, err = sqle.RemoteSrvFSAndDBCache(ctxFactory, sqle.CreateUnknownDatabases)
	if err != nil {
		return remotesrv.ServerArgs{}, err
	}
	args.DBCache = remotesrvStoreCache{args.DBCache, c}
	c.remoteSrvDBCache = args.DBCache

	keyID := creds.PubKeyToKID(c.pub)
	keyIDStr := creds.B32CredsEncoding.EncodeToString(keyID)
	args.HttpInterceptor = JWKSHandlerInterceptor(keyIDStr, c.pub)

	return args, nil
}

func (c *Controller) HookMySQLDbPersister(persister MySQLDbPersister, mysqlDb *mysql_db.MySQLDb) MySQLDbPersister {
	if c != nil {
		c.mysqlDb = mysqlDb
		c.mysqlDbPersister = &replicatingMySQLDbPersister{
			base:     persister,
			replicas: c.mysqlDbReplicas,
		}
		c.mysqlDbPersister.setRole(c.role)
		persister = c.mysqlDbPersister
	}
	return persister
}
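
// HookBranchControlPersistence mirrors HookMySQLDbPersister, but for branch
// control data: on a primary, every save of the serialized branch control
// contents is pushed to each standby by its own exponentially backed-off
// worker, and SQL callers wait on the resulting
// doltdb.ReplicationStatusController via dsess.WaitForReplicationController.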
func (c *Controller) HookBranchControlPersistence(controller *branch_control.Controller, fs filesys.Filesys) {
	if c != nil {
		c.branchControlController = controller
		c.branchControlFilesys = fs

		replicas := make([]*branchControlReplica, len(c.replicationClients))
		for i := range replicas {
			bo := backoff.NewExponentialBackOff()
			bo.InitialInterval = time.Second
			bo.MaxInterval = time.Minute
			bo.MaxElapsedTime = 0
			replicas[i] = &branchControlReplica{
				backoff: bo,
				client:  c.replicationClients[i],
				lgr:     c.lgr.WithFields(logrus.Fields{}),
			}
			replicas[i].cond = sync.NewCond(&replicas[i].mu)
		}
		c.bcReplication = &branchControlReplication{
			replicas:     replicas,
			bcController: controller,
		}
		c.bcReplication.setRole(c.role)

		controller.SavedCallback = func(ctx context.Context) {
			contents := controller.Serialized.Load()
			if contents != nil {
				var rsc doltdb.ReplicationStatusController
				c.bcReplication.UpdateBranchControlContents(ctx, *contents, &rsc)
				if sqlCtx, ok := ctx.(*sql.Context); ok {
					dsess.WaitForReplicationController(sqlCtx, rsc)
				}
			}
		}
	}
}

func (c *Controller) RegisterGrpcServices(ctxFactory func(context.Context) (*sql.Context, error), srv *grpc.Server) {
	replicationapi.RegisterReplicationServiceServer(srv, &replicationServiceServer{
		ctxFactory:           ctxFactory,
		mysqlDb:              c.mysqlDb,
		branchControl:        c.branchControlController,
		branchControlFilesys: c.branchControlFilesys,
		dropDatabase:         c.dropDatabase,
		lgr:                  c.lgr.WithFields(logrus.Fields{}),
	})
}

// TODO: make the deadline here configurable or something.
const waitForHooksToReplicateTimeout = 10 * time.Second

type graceTransitionResult struct {
	caughtUp  bool
	database  string
	remote    string
	remoteUrl string
}

// The order of operations is:
// * Set all databases in database_provider to read-only.
// * Kill all running queries in GMS.
// * Replicate all databases to their standby remotes.
//   - If success, return success.
//   - If failure, set all databases in database_provider back to their original state. Return failure.
//
// saveConnID is potentially a connID of the caller to
// dolt_assume_cluster_role(), which should not be killed with the other
// connections. That connection will be transitioned to a terminal error state
// after returning the results of dolt_assume_cluster_role().
//
// called with c.mu held
func (c *Controller) gracefulTransitionToStandby(saveConnID, minCaughtUpStandbys int) ([]graceTransitionResult, error) {
	c.setProviderIsStandby(true)
	c.killRunningQueries(saveConnID)

	var hookStates, mysqlStates, bcStates []graceTransitionResult
	var hookErr, mysqlErr, bcErr error

	// We concurrently wait for hooks, mysql and dolt_branch_control replication to true up.
	// If we encounter any errors while doing this, we fail the graceful transition.

	var wg sync.WaitGroup
	wg.Add(3)
	go func() {
		defer wg.Done()
		// waitForHooksToReplicate will release the lock while it
		// blocks, but will return with the lock held.
		hookStates, hookErr = c.waitForHooksToReplicate(waitForHooksToReplicateTimeout)
	}()
	go func() {
		defer wg.Done()
		mysqlStates, mysqlErr = c.mysqlDbPersister.waitForReplication(waitForHooksToReplicateTimeout)
	}()
	go func() {
		defer wg.Done()
		bcStates, bcErr = c.bcReplication.waitForReplication(waitForHooksToReplicateTimeout)
	}()
	wg.Wait()

	if hookErr != nil {
		return nil, hookErr
	}
	if mysqlErr != nil {
		return nil, mysqlErr
	}
	if bcErr != nil {
		return nil, bcErr
	}

	if len(hookStates) != len(c.commithooks) {
		c.lgr.Warnf("cluster/controller: failed to transition to standby; the set of replicated databases changed during the transition.")
		return nil, errors.New("cluster/controller: failed to transition to standby; the set of replicated databases changed during the transition.")
	}

	res := make([]graceTransitionResult, 0, len(hookStates)+len(mysqlStates)+len(bcStates))
	res = append(res, hookStates...)
	res = append(res, mysqlStates...)
	res = append(res, bcStates...)

	if minCaughtUpStandbys == 0 {
		for _, state := range res {
			if !state.caughtUp {
				c.lgr.Warnf("cluster/controller: failed to replicate all databases to all standbys; not transitioning to standby.")
				return nil, fmt.Errorf("cluster/controller: failed to transition from primary to standby gracefully; could not replicate databases to standby in a timely manner.")
			}
		}
		c.lgr.Tracef("cluster/controller: successfully replicated all databases to all standbys; transitioning to standby.")
	} else {
		databases := make(map[string]struct{})
		replicas := make(map[string]int)
		for _, r := range res {
			databases[r.database] = struct{}{}
			url, err := url.Parse(r.remoteUrl)
			if err != nil {
				return nil, fmt.Errorf("cluster/controller: could not parse remote_url (%s) for remote %s on database %s: %w", r.remoteUrl, r.remote, r.database, err)
			}
			if _, ok := replicas[url.Host]; !ok {
				replicas[url.Host] = 0
			}
			if r.caughtUp {
				replicas[url.Host] = replicas[url.Host] + 1
			}
		}
		numCaughtUp := 0
		for _, v := range replicas {
			if v == len(databases) {
				numCaughtUp += 1
			}
		}
		if numCaughtUp < minCaughtUpStandbys {
			return nil, fmt.Errorf("cluster/controller: failed to transition from primary to standby gracefully; could not ensure %d replicas were caught up on all %d databases. Only caught up %d standbys fully.", minCaughtUpStandbys, len(databases), numCaughtUp)
		}
		c.lgr.Tracef("cluster/controller: successfully replicated all databases to %d out of %d standbys; transitioning to standby.", numCaughtUp, len(replicas))
	}

	return res, nil
}

func allCaughtUp(res []graceTransitionResult) bool {
	for _, r := range res {
		if !r.caughtUp {
			return false
		}
	}
	return true
}

// The order of operations is:
// * Set all databases in database_provider to read-only.
// * Kill all running queries in GMS.
// * Return success. NOTE: we do not attempt to replicate to the standby.
//
// called with c.mu held
func (c *Controller) immediateTransitionToStandby() error {
	c.setProviderIsStandby(true)
	c.killRunningQueries(-1)
	return nil
}
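
// Immediate transitions to standby come from the interceptors and other
// signals that the cluster is misconfigured (see setRoleAndEpoch). Unlike the
// graceful path, nothing is replicated first, so writes that had not yet
// reached a standby are not guaranteed to be present on the next primary.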

// The order of operations is:
// * Set all databases in database_provider back to their original mode: read-write or read only.
// * Kill all running queries in GMS.
// * Return success.
//
// saveConnID is potentially the connID of the caller to
// dolt_assume_cluster_role().
//
// called with c.mu held
func (c *Controller) transitionToPrimary(saveConnID int) error {
	c.setProviderIsStandby(false)
	c.killRunningQueries(saveConnID)
	return nil
}

// Kills all running queries in the managed GMS engine.
// called with c.mu held
func (c *Controller) killRunningQueries(saveConnID int) {
	if c.iterSessions != nil {
		c.iterSessions(func(session sql.Session) (stop bool, err error) {
			if int(session.ID()) != saveConnID {
				c.killQuery(session.ID())
				c.killConnection(session.ID())
			}
			return
		})
	}
}

// called with c.mu held
func (c *Controller) setProviderIsStandby(standby bool) {
	if c.standbyCallback != nil {
		c.standbyCallback(standby)
	}
}

// Called during a graceful transition from primary to standby. Waits until all
// commithooks report nextHead == lastPushedHead.
//
// Returns a `graceTransitionResult` for each `commithook` which existed at the
// start of the call. The entry's `caughtUp` field will be `true` if that
// `commithook` was caught up as part of this wait, and `false` otherwise.
//
// called with c.mu held
func (c *Controller) waitForHooksToReplicate(timeout time.Duration) ([]graceTransitionResult, error) {
	commithooks := make([]*commithook, len(c.commithooks))
	copy(commithooks, c.commithooks)
	res := make([]graceTransitionResult, len(commithooks))
	for i := range res {
		res[i].database = commithooks[i].dbname
		res[i].remote = commithooks[i].remotename
		res[i].remoteUrl = commithooks[i].remoteurl
	}
	var wg sync.WaitGroup
	wg.Add(len(commithooks))
	for li, lch := range commithooks {
		i := li
		ch := lch
		ok := ch.setWaitNotify(func() {
			// called with ch.mu locked.
			if !res[i].caughtUp && ch.isCaughtUp() {
				res[i].caughtUp = true
				wg.Done()
			}
		})
		if !ok {
			for j := li - 1; j >= 0; j-- {
				commithooks[j].setWaitNotify(nil)
			}
			c.lgr.Warnf("cluster/controller: failed to wait for graceful transition to standby; there were concurrent attempts to transition.")
			return nil, errors.New("cluster/controller: failed to transition from primary to standby gracefully; did not gain exclusive access to commithooks.")
		}
	}
	c.mu.Unlock()
	done := make(chan struct{})
	go func() {
		wg.Wait()
		close(done)
	}()
	select {
	case <-done:
	case <-time.After(timeout):
	}
	c.mu.Lock()
	for _, ch := range commithooks {
		ch.setWaitNotify(nil)
	}

	// Make certain we don't leak the wg.Wait goroutine in the failure case.
	// At this point, none of the callbacks will ever be called again and
	// ch.setWaitNotify grabs a lock and so establishes the happens-before.
	for _, b := range res {
		if !b.caughtUp {
			wg.Done()
		}
	}
	<-done

	return res, nil
}

// Within a cluster, if remotesapi is configured with a tls_ca, we take the
// following semantics:
// * The configured tls_ca file holds a set of PEM encoded x509 certificates,
// all of which are trusted roots for the outbound connections the
// remotestorage client establishes.
// * The certificate chain presented by the server must validate to a root
// which was present in tls_ca. In particular, every certificate in the chain
// must be within its validity window, the signatures must be valid, key usage
// and isCa must be correctly set for the roots and the intermediates, and the
// leaf must have extended key usage server auth.
// * On the other hand, no verification is done against the SAN or the Subject
// of the certificate.
//
// We use these TLS semantics for both connections to the gRPC endpoint which
// is the actual remotesapi, and for connections to any HTTPS endpoints to
// which the gRPC service returns URLs. For now, this works perfectly for our
// use case, but it's tightly coupled to `cluster:` deployment topologies and
// the like.
//
// If tls_ca is not set then default TLS handling is performed. In particular,
// if the remotesapi endpoint is HTTPS, then the system roots are used and
// ServerName is verified against the presented URL SANs of the certificates.
//
// This tls Config is used for fetching JWKS, for outbound GRPC connections and
// for outbound https connections on the URLs that the GRPC services return.
func (c *Controller) outboundTlsConfig() (*tls.Config, error) {
	tlsCA := c.cfg.RemotesAPIConfig().TLSCA()
	if tlsCA == "" {
		return nil, nil
	}
	urlmatches := c.cfg.RemotesAPIConfig().ServerNameURLMatches()
	dnsmatches := c.cfg.RemotesAPIConfig().ServerNameDNSMatches()
	pem, err := os.ReadFile(tlsCA)
	if err != nil {
		return nil, err
	}
	roots := x509.NewCertPool()
	if ok := roots.AppendCertsFromPEM(pem); !ok {
		return nil, errors.New("error loading ca roots from " + tlsCA)
	}
	verifyFunc := func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
		certs := make([]*x509.Certificate, len(rawCerts))
		var err error
		for i, asn1Data := range rawCerts {
			certs[i], err = x509.ParseCertificate(asn1Data)
			if err != nil {
				return err
			}
		}
		keyUsages := []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}
		opts := x509.VerifyOptions{
			Roots:         roots,
			CurrentTime:   time.Now(),
			Intermediates: x509.NewCertPool(),
			KeyUsages:     keyUsages,
		}
		for _, cert := range certs[1:] {
			opts.Intermediates.AddCert(cert)
		}
		_, err = certs[0].Verify(opts)
		if err != nil {
			return err
		}
		if len(urlmatches) > 0 {
			found := false
			for _, n := range urlmatches {
				for _, cn := range certs[0].URIs {
					if n == cn.String() {
						found = true
					}
					break
				}
				if found {
					break
				}
			}
			if !found {
				return errors.New("expected certificate to match something in server_name_urls, but it did not")
			}
		}
		if len(dnsmatches) > 0 {
			found := false
			for _, n := range dnsmatches {
				for _, cn := range certs[0].DNSNames {
					if n == cn {
						found = true
					}
					break
				}
				if found {
					break
				}
			}
			if !found {
				return errors.New("expected certificate to match something in server_name_dns, but it did not")
			}
		}
		return nil
	}
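
	// Note that because InsecureSkipVerify is set on the returned config,
	// crypto/tls skips its built-in chain and hostname checks and invokes
	// verifyFunc with verifiedChains always nil; the verification above is
	// the only certificate validation performed for these outbound
	// connections.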
	return &tls.Config{
		// We have to InsecureSkipVerify because ServerName is always
		// set by the grpc dial provider and golang tls.Config does not
		// have good support for performing certificate validation
		// without server name validation.
		InsecureSkipVerify: true,

		VerifyPeerCertificate: verifyFunc,

		NextProtos: []string{"h2"},
	}, nil
}

func (c *Controller) standbyRemotesJWKS() *jwtauth.MultiJWKS {
	client := &http.Client{
		Transport: &http.Transport{
			TLSClientConfig:   c.tlsCfg,
			ForceAttemptHTTP2: true,
		},
	}
	urls := make([]string, len(c.cfg.StandbyRemotes()))
	for i, r := range c.cfg.StandbyRemotes() {
		urls[i] = strings.Replace(r.RemoteURLTemplate(), dsess.URLTemplateDatabasePlaceholder, ".well-known/jwks.json", -1)
	}
	return jwtauth.NewMultiJWKS(c.lgr.WithFields(logrus.Fields{"component": "jwks-key-provider"}), urls, client)
}

type replicationServiceClient struct {
	remote string
	url    string
	tls    bool
	client replicationapi.ReplicationServiceClient
	closer func() error
}

func (c *Controller) replicationServiceDialOptions() []grpc.DialOption {
	var ret []grpc.DialOption
	if c.tlsCfg == nil {
		ret = append(ret, grpc.WithInsecure())
	} else {
		ret = append(ret, grpc.WithTransportCredentials(credentials.NewTLS(c.tlsCfg)))
	}

	ret = append(ret, grpc.WithStreamInterceptor(c.cinterceptor.Stream()))
	ret = append(ret, grpc.WithUnaryInterceptor(c.cinterceptor.Unary()))

	ret = append(ret, grpc.WithPerRPCCredentials(c.grpcCreds))

	return ret
}

func (c *Controller) replicationServiceClients(ctx context.Context) ([]*replicationServiceClient, error) {
	var ret []*replicationServiceClient
	for _, r := range c.cfg.StandbyRemotes() {
		urlStr := strings.Replace(r.RemoteURLTemplate(), dsess.URLTemplateDatabasePlaceholder, "", -1)
		url, err := url.Parse(urlStr)
		if err != nil {
			return nil, fmt.Errorf("could not parse remote url template [%s] for remote %s: %w", r.RemoteURLTemplate(), r.Name(), err)
		}
		grpcTarget := "dns:" + url.Hostname() + ":" + url.Port()
		cc, err := grpc.DialContext(ctx, grpcTarget, c.replicationServiceDialOptions()...)
		if err != nil {
			return nil, fmt.Errorf("could not dial grpc endpoint [%s] for remote %s: %w", grpcTarget, r.Name(), err)
		}
		client := replicationapi.NewReplicationServiceClient(cc)
		ret = append(ret, &replicationServiceClient{
			remote: r.Name(),
			url:    grpcTarget,
			tls:    c.tlsCfg != nil,
			client: client,
			closer: cc.Close,
		})
	}
	return ret, nil
}

// Generally r.url is a gRPC dial endpoint, something like "dns:53.78.2.1:3832".
//
// We want to match these endpoints up with Dolt remote URLs, which will typically be something like http://53.78.2.1:3832.
func (r *replicationServiceClient) httpUrl() string {
	prefix := "https://"
	if !r.tls {
		prefix = "http://"
	}
	return prefix + strings.TrimPrefix(r.url, "dns:")
}
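
// For example (illustrative values only), a replicationServiceClient with url
// "dns:standby-0:50051" and tls set would map to "https://standby-0:50051",
// which can then be compared against the databases' configured remote URLs.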