github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/sqle/dsess/session.go (about) 1 // Copyright 2020 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package dsess 16 17 import ( 18 "context" 19 "errors" 20 "fmt" 21 "strconv" 22 "strings" 23 "sync" 24 "time" 25 26 "github.com/dolthub/go-mysql-server/sql" 27 sqltypes "github.com/dolthub/go-mysql-server/sql/types" 28 "github.com/shopspring/decimal" 29 30 "github.com/dolthub/dolt/go/cmd/dolt/cli" 31 "github.com/dolthub/dolt/go/libraries/doltcore/branch_control" 32 "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" 33 "github.com/dolthub/dolt/go/libraries/doltcore/env" 34 "github.com/dolthub/dolt/go/libraries/doltcore/env/actions" 35 "github.com/dolthub/dolt/go/libraries/doltcore/ref" 36 "github.com/dolthub/dolt/go/libraries/doltcore/sqle/globalstate" 37 "github.com/dolthub/dolt/go/libraries/doltcore/sqle/writer" 38 "github.com/dolthub/dolt/go/libraries/doltcore/table/editor" 39 "github.com/dolthub/dolt/go/libraries/utils/config" 40 "github.com/dolthub/dolt/go/libraries/utils/filesys" 41 "github.com/dolthub/dolt/go/store/hash" 42 "github.com/dolthub/dolt/go/store/types" 43 ) 44 45 const ( 46 DbRevisionDelimiter = "/" 47 ) 48 49 var ErrSessionNotPersistable = errors.New("session is not persistable") 50 51 // DoltSession is the sql.Session implementation used by dolt. It is accessible through a *sql.Context instance 52 type DoltSession struct { 53 sql.Session 54 username string 55 email string 56 dbStates map[string]*DatabaseSessionState 57 dbCache *DatabaseCache 58 provider DoltDatabaseProvider 59 tempTables map[string][]sql.Table 60 globalsConf config.ReadWriteConfig 61 branchController *branch_control.Controller 62 statsProv sql.StatsProvider 63 mu *sync.Mutex 64 fs filesys.Filesys 65 66 // If non-nil, this will be returned from ValidateSession. 67 // Used by sqle/cluster to put a session into a terminal err state. 68 validateErr error 69 } 70 71 var _ sql.Session = (*DoltSession)(nil) 72 var _ sql.PersistableSession = (*DoltSession)(nil) 73 var _ sql.TransactionSession = (*DoltSession)(nil) 74 var _ branch_control.Context = (*DoltSession)(nil) 75 76 // DefaultSession creates a DoltSession with default values 77 func DefaultSession(pro DoltDatabaseProvider) *DoltSession { 78 return &DoltSession{ 79 Session: sql.NewBaseSession(), 80 username: "", 81 email: "", 82 dbStates: make(map[string]*DatabaseSessionState), 83 dbCache: newDatabaseCache(), 84 provider: pro, 85 tempTables: make(map[string][]sql.Table), 86 globalsConf: config.NewMapConfig(make(map[string]string)), 87 branchController: branch_control.CreateDefaultController(context.TODO()), // Default sessions are fine with the default controller 88 mu: &sync.Mutex{}, 89 fs: pro.FileSystem(), 90 } 91 } 92 93 // NewDoltSession creates a DoltSession object from a standard sql.Session and 0 or more Database objects. 94 func NewDoltSession( 95 sqlSess *sql.BaseSession, 96 pro DoltDatabaseProvider, 97 conf config.ReadWriteConfig, 98 branchController *branch_control.Controller, 99 statsProvider sql.StatsProvider, 100 ) (*DoltSession, error) { 101 username := conf.GetStringOrDefault(config.UserNameKey, "") 102 email := conf.GetStringOrDefault(config.UserEmailKey, "") 103 globals := config.NewPrefixConfig(conf, env.SqlServerGlobalsPrefix) 104 105 sess := &DoltSession{ 106 Session: sqlSess, 107 username: username, 108 email: email, 109 dbStates: make(map[string]*DatabaseSessionState), 110 dbCache: newDatabaseCache(), 111 provider: pro, 112 tempTables: make(map[string][]sql.Table), 113 globalsConf: globals, 114 branchController: branchController, 115 statsProv: statsProvider, 116 mu: &sync.Mutex{}, 117 fs: pro.FileSystem(), 118 } 119 120 return sess, nil 121 } 122 123 // Provider returns the RevisionDatabaseProvider for this session. 124 func (d *DoltSession) Provider() DoltDatabaseProvider { 125 return d.provider 126 } 127 128 // StatsProvider returns the sql.StatsProvider for this session. 129 func (d *DoltSession) StatsProvider() sql.StatsProvider { 130 return d.statsProv 131 } 132 133 // DSessFromSess retrieves a dolt session from a standard sql.Session 134 func DSessFromSess(sess sql.Session) *DoltSession { 135 return sess.(*DoltSession) 136 } 137 138 // lookupDbState is the private version of LookupDbState, returning a struct that has more information available than 139 // the interface returned by the public method. 140 func (d *DoltSession) lookupDbState(ctx *sql.Context, dbName string) (*branchState, bool, error) { 141 dbName = strings.ToLower(dbName) 142 143 var baseName, rev string 144 baseName, rev = SplitRevisionDbName(dbName) 145 146 d.mu.Lock() 147 dbState, dbStateFound := d.dbStates[baseName] 148 d.mu.Unlock() 149 150 if dbStateFound { 151 // If we got an unqualified name, use the current working set head 152 if rev == "" { 153 rev = dbState.checkedOutRevSpec 154 } 155 156 branchState, ok := dbState.heads[strings.ToLower(rev)] 157 158 if ok { 159 if dbState.Err != nil { 160 return nil, false, dbState.Err 161 } 162 163 return branchState, ok, nil 164 } 165 } 166 167 // No state for this db / branch combination yet, look it up from the provider. We use the unqualified DB name (no 168 // branch) if the current DB has not yet been loaded into this session. It will resolve to that DB's default branch 169 // in that case. 170 revisionQualifiedName := dbName 171 if rev != "" { 172 revisionQualifiedName = RevisionDbName(baseName, rev) 173 } 174 175 database, ok, err := d.provider.SessionDatabase(ctx, revisionQualifiedName) 176 if err != nil { 177 return nil, false, err 178 } 179 if !ok { 180 return nil, false, nil 181 } 182 183 // Add the initial state to the session for future reuse 184 if err := d.addDB(ctx, database); err != nil { 185 return nil, false, err 186 } 187 188 d.mu.Lock() 189 dbState, dbStateFound = d.dbStates[baseName] 190 d.mu.Unlock() 191 if !dbStateFound { 192 // should be impossible 193 return nil, false, sql.ErrDatabaseNotFound.New(dbName) 194 } 195 196 return dbState.heads[strings.ToLower(database.Revision())], true, nil 197 } 198 199 // RevisionDbName returns the name of the revision db for the base name and revision string given 200 func RevisionDbName(baseName string, rev string) string { 201 return baseName + DbRevisionDelimiter + rev 202 } 203 204 func SplitRevisionDbName(dbName string) (string, string) { 205 var baseName, rev string 206 parts := strings.SplitN(dbName, DbRevisionDelimiter, 2) 207 baseName = parts[0] 208 if len(parts) > 1 { 209 rev = parts[1] 210 } 211 return baseName, rev 212 } 213 214 // LookupDbState returns the session state for the database named. Unqualified database names, e.g. `mydb` get resolved 215 // to the currently checked out HEAD, which could be a branch, a commit, a tag, etc. Revision-qualified database names, 216 // e.g. `mydb/branch1` get resolved to the session state for the revision named. 217 // A note on unqualified database names: unqualified names will resolve to a) the head last checked out with 218 // `dolt_checkout`, or b) the database's default branch, if this session hasn't called `dolt_checkout` yet. 219 // Also returns a bool indicating whether the database was found, and an error if one occurred. 220 func (d *DoltSession) LookupDbState(ctx *sql.Context, dbName string) (SessionState, bool, error) { 221 s, ok, err := d.lookupDbState(ctx, dbName) 222 if err != nil { 223 return nil, false, err 224 } 225 226 return s, ok, nil 227 } 228 229 // RemoveDbState invalidates any cached db state in this session, for example, if a database is dropped. 230 func (d *DoltSession) RemoveDbState(_ *sql.Context, dbName string) error { 231 d.mu.Lock() 232 defer d.mu.Unlock() 233 delete(d.dbStates, strings.ToLower(dbName)) 234 // also clear out any db-level caches for this db 235 d.dbCache.Clear() 236 return nil 237 } 238 239 // RemoveBranchState removes the session state for a branch, for example, if a branch is deleted. 240 func (d *DoltSession) RemoveBranchState(ctx *sql.Context, dbName string, branchName string) error { 241 baseName, _ := SplitRevisionDbName(dbName) 242 243 checkedOutState, ok, err := d.lookupDbState(ctx, baseName) 244 if err != nil { 245 return err 246 } 247 if !ok { 248 return sql.ErrDatabaseNotFound.New(baseName) 249 } 250 251 d.mu.Lock() 252 delete(checkedOutState.dbState.heads, strings.ToLower(branchName)) 253 d.mu.Unlock() 254 255 db, ok := d.provider.BaseDatabase(ctx, baseName) 256 if !ok { 257 return sql.ErrDatabaseNotFound.New(baseName) 258 } 259 260 defaultHead, err := DefaultHead(baseName, db) 261 if err != nil { 262 return err 263 } 264 265 checkedOutState.dbState.checkedOutRevSpec = defaultHead 266 267 // also clear out any db-level caches for this db 268 d.dbCache.Clear() 269 return nil 270 } 271 272 // RenameBranchState replaces all references to a renamed branch with its new name 273 func (d *DoltSession) RenameBranchState(ctx *sql.Context, dbName string, oldBranchName, newBranchName string) error { 274 baseName, _ := SplitRevisionDbName(dbName) 275 276 checkedOutState, ok, err := d.lookupDbState(ctx, baseName) 277 if err != nil { 278 return err 279 } 280 if !ok { 281 return sql.ErrDatabaseNotFound.New(baseName) 282 } 283 284 d.mu.Lock() 285 branch, ok := checkedOutState.dbState.heads[strings.ToLower(oldBranchName)] 286 287 if !ok { 288 // nothing to rename 289 d.mu.Unlock() 290 return nil 291 } 292 293 delete(checkedOutState.dbState.heads, strings.ToLower(oldBranchName)) 294 branch.head = strings.ToLower(newBranchName) 295 checkedOutState.dbState.heads[strings.ToLower(newBranchName)] = branch 296 297 d.mu.Unlock() 298 299 // also clear out any db-level caches for this db 300 d.dbCache.Clear() 301 return nil 302 } 303 304 // SetValidateErr sets an error on this session to be returned from every call 305 // to ValidateSession. This is effectively a way to disable a session. 306 // 307 // Used by sql/cluster logic to make sessions on a server which has 308 // transitioned roles terminally error. 309 func (d *DoltSession) SetValidateErr(err error) { 310 d.validateErr = err 311 } 312 313 // ValidateSession validates a working set if there are a valid sessionState with non-nil working set. 314 // If there is no sessionState or its current working set not defined, then no need for validation, 315 // so no error is returned. 316 func (d *DoltSession) ValidateSession(ctx *sql.Context) error { 317 return d.validateErr 318 } 319 320 // StartTransaction refreshes the state of this session and starts a new transaction. 321 func (d *DoltSession) StartTransaction(ctx *sql.Context, tCharacteristic sql.TransactionCharacteristic) (sql.Transaction, error) { 322 // TODO: this is only necessary to support filter-branch, which needs to set a root directly and not have the 323 // session state altered when a transaction begins 324 if TransactionsDisabled(ctx) { 325 return DisabledTransaction{}, nil 326 } 327 328 // New transaction, clear all session state 329 d.clear() 330 331 // Take a snapshot of the current noms root for every database under management 332 doltDatabases := d.provider.DoltDatabases() 333 txDbs := make([]SqlDatabase, 0, len(doltDatabases)) 334 for _, db := range doltDatabases { 335 // TODO: this nil check is only necessary to support UserSpaceDatabase and clusterDatabase, come up with a better set of 336 // interfaces to capture these capabilities 337 ddb := db.DbData().Ddb 338 if ddb != nil { 339 rrd, ok := db.(RemoteReadReplicaDatabase) 340 if ok && rrd.ValidReplicaState(ctx) { 341 err := rrd.PullFromRemote(ctx) 342 if err != nil && !IgnoreReplicationErrors() { 343 return nil, fmt.Errorf("replication error: %w", err) 344 } else if err != nil { 345 WarnReplicationError(ctx, err) 346 } 347 } 348 349 // TODO: this check is relatively expensive, we should cache this value when it changes instead of looking it 350 // up on each transaction start 351 if _, v, ok := sql.SystemVariables.GetGlobal(ReadReplicaRemote); ok && v != "" { 352 err := ddb.Rebase(ctx) 353 if err != nil && !IgnoreReplicationErrors() { 354 return nil, err 355 } else if err != nil { 356 WarnReplicationError(ctx, err) 357 } 358 } 359 360 txDbs = append(txDbs, db) 361 } 362 } 363 364 tx, err := NewDoltTransaction(ctx, txDbs, tCharacteristic) 365 if err != nil { 366 return nil, err 367 } 368 369 // The engine sets the transaction after this call as well, but since we begin accessing data below, we need to set 370 // this now to avoid seeding the session state with stale data in some cases. The duplication is harmless since the 371 // code below cannot error. Additionally we clear any state that was cached by replication updates in the block above. 372 d.clear() 373 ctx.SetTransaction(tx) 374 375 // Set session vars for every DB in this session using their current branch head 376 for _, db := range doltDatabases { 377 // faulty settings can make it impossible to load particular DB branch states, so we ignore any errors in this 378 // loop and just decline to set the session vars. Throwing an error on transaction start in these cases makes it 379 // impossible for the user to correct any problems. 380 bs, ok, err := d.lookupDbState(ctx, db.Name()) 381 if err != nil || !ok { 382 continue 383 } 384 385 _ = d.setDbSessionVars(ctx, bs, false) 386 } 387 388 return tx, nil 389 } 390 391 // clear clears all DB state for this session 392 func (d *DoltSession) clear() { 393 d.mu.Lock() 394 defer d.mu.Unlock() 395 396 for _, dbState := range d.dbStates { 397 for head := range dbState.heads { 398 delete(dbState.heads, head) 399 } 400 } 401 } 402 403 func (d *DoltSession) newWorkingSetForHead(ctx *sql.Context, wsRef ref.WorkingSetRef, dbName string) (*doltdb.WorkingSet, error) { 404 dbData, _ := d.GetDbData(nil, dbName) 405 406 headSpec, _ := doltdb.NewCommitSpec("HEAD") 407 headRef, err := wsRef.ToHeadRef() 408 if err != nil { 409 return nil, err 410 } 411 412 optCmt, err := dbData.Ddb.Resolve(ctx, headSpec, headRef) 413 if err != nil { 414 return nil, err 415 } 416 headCommit, ok := optCmt.ToCommit() 417 if !ok { 418 return nil, doltdb.ErrGhostCommitEncountered 419 } 420 421 headRoot, err := headCommit.GetRootValue(ctx) 422 if err != nil { 423 return nil, err 424 } 425 426 return doltdb.EmptyWorkingSet(wsRef).WithWorkingRoot(headRoot).WithStagedRoot(headRoot), nil 427 } 428 429 // CommitTransaction commits the in-progress transaction. Depending on session settings, this may write only a new 430 // working set, or may additionally create a new dolt commit for the current HEAD. If more than one branch head has 431 // changes, the transaction is rejected. 432 func (d *DoltSession) CommitTransaction(ctx *sql.Context, tx sql.Transaction) (err error) { 433 // Any non-error path must set the ctx's transaction to nil even if no work was done, because the engine only clears 434 // out transaction state in some cases. Changes to only branch heads (creating a new branch, reset, etc.) have no 435 // changes to commit visible to the transaction logic, but they still need a new transaction on the next statement. 436 // See comment in |commitBranchState| 437 defer func() { 438 if err == nil { 439 ctx.SetTransaction(nil) 440 } 441 }() 442 443 if TransactionsDisabled(ctx) { 444 return nil 445 } 446 447 dirties := d.dirtyWorkingSets() 448 if len(dirties) == 0 { 449 return nil 450 } 451 452 if len(dirties) > 1 { 453 return ErrDirtyWorkingSets 454 } 455 456 performDoltCommitVar, err := d.Session.GetSessionVariable(ctx, DoltCommitOnTransactionCommit) 457 if err != nil { 458 return err 459 } 460 461 peformDoltCommitInt, ok := performDoltCommitVar.(int8) 462 if !ok { 463 return fmt.Errorf(fmt.Sprintf("Unexpected type for var %s: %T", DoltCommitOnTransactionCommit, performDoltCommitVar)) 464 } 465 466 dirtyBranchState := dirties[0] 467 if peformDoltCommitInt == 1 { 468 // if the dirty working set doesn't belong to the currently checked out branch, that's an error 469 err = d.validateDoltCommit(ctx, dirtyBranchState) 470 if err != nil { 471 return err 472 } 473 474 message := "Transaction commit" 475 doltCommitMessageVar, err := d.Session.GetSessionVariable(ctx, DoltCommitOnTransactionCommitMessage) 476 if err != nil { 477 return err 478 } 479 480 doltCommitMessageString, ok := doltCommitMessageVar.(string) 481 if !ok && doltCommitMessageVar != nil { 482 return fmt.Errorf(fmt.Sprintf("Unexpected type for var %s: %T", DoltCommitOnTransactionCommitMessage, doltCommitMessageVar)) 483 } 484 485 trimmedString := strings.TrimSpace(doltCommitMessageString) 486 if strings.TrimSpace(doltCommitMessageString) != "" { 487 message = trimmedString 488 } 489 490 var pendingCommit *doltdb.PendingCommit 491 pendingCommit, err = d.PendingCommitAllStaged(ctx, dirtyBranchState, actions.CommitStagedProps{ 492 Message: message, 493 Date: ctx.QueryTime(), 494 AllowEmpty: false, 495 Force: false, 496 Name: d.Username(), 497 Email: d.Email(), 498 }) 499 if err != nil { 500 return err 501 } 502 503 // Nothing to stage, so fall back to CommitWorkingSet logic instead 504 if pendingCommit == nil { 505 return d.commitWorkingSet(ctx, dirtyBranchState, tx) 506 } 507 508 _, err = d.DoltCommit(ctx, ctx.GetCurrentDatabase(), tx, pendingCommit) 509 return err 510 } else { 511 return d.commitWorkingSet(ctx, dirtyBranchState, tx) 512 } 513 } 514 515 func (d *DoltSession) validateDoltCommit(ctx *sql.Context, dirtyBranchState *branchState) error { 516 currDb := ctx.GetCurrentDatabase() 517 if currDb == "" { 518 return fmt.Errorf("cannot dolt_commit with no database selected") 519 } 520 currDbBaseName, rev := SplitRevisionDbName(currDb) 521 dirtyDbBaseName := dirtyBranchState.dbState.dbName 522 523 if strings.ToLower(currDbBaseName) != strings.ToLower(dirtyDbBaseName) { 524 return fmt.Errorf("no changes to dolt_commit on database %s", currDbBaseName) 525 } 526 527 d.mu.Lock() 528 dbState, ok := d.dbStates[strings.ToLower(currDbBaseName)] 529 d.mu.Unlock() 530 531 if !ok { 532 return fmt.Errorf("no database state found for %s", currDbBaseName) 533 } 534 535 if rev == "" { 536 rev = dbState.checkedOutRevSpec 537 } 538 539 if strings.ToLower(rev) != strings.ToLower(dirtyBranchState.head) { 540 return fmt.Errorf("no changes to dolt_commit on branch %s", rev) 541 } 542 543 return nil 544 } 545 546 var ErrDirtyWorkingSets = errors.New("Cannot commit changes on more than one branch / database") 547 548 // dirtyWorkingSets returns all dirty working sets for this session 549 func (d *DoltSession) dirtyWorkingSets() []*branchState { 550 var dirtyStates []*branchState 551 for _, state := range d.dbStates { 552 for _, branchState := range state.heads { 553 if branchState.dirty { 554 dirtyStates = append(dirtyStates, branchState) 555 } 556 } 557 } 558 559 return dirtyStates 560 } 561 562 // CommitWorkingSet commits the working set for the transaction given, without creating a new dolt commit. 563 // Clients should typically use CommitTransaction, which performs additional checks, instead of this method. 564 func (d *DoltSession) CommitWorkingSet(ctx *sql.Context, dbName string, tx sql.Transaction) error { 565 commitFunc := func(ctx *sql.Context, dtx *DoltTransaction, workingSet *doltdb.WorkingSet) (*doltdb.WorkingSet, *doltdb.Commit, error) { 566 ws, err := dtx.Commit(ctx, workingSet, dbName) 567 return ws, nil, err 568 } 569 570 _, err := d.commitCurrentHead(ctx, dbName, tx, commitFunc) 571 return err 572 } 573 574 // commitWorkingSet commits the working set for the branch state given, without creating a new dolt commit. 575 func (d *DoltSession) commitWorkingSet(ctx *sql.Context, branchState *branchState, tx sql.Transaction) error { 576 commitFunc := func(ctx *sql.Context, dtx *DoltTransaction, workingSet *doltdb.WorkingSet) (*doltdb.WorkingSet, *doltdb.Commit, error) { 577 ws, err := dtx.Commit(ctx, workingSet, branchState.RevisionDbName()) 578 return ws, nil, err 579 } 580 581 _, err := d.commitBranchState(ctx, branchState, tx, commitFunc) 582 return err 583 } 584 585 // DoltCommit commits the working set and a new dolt commit with the properties given. 586 // Clients should typically use CommitTransaction, which performs additional checks, instead of this method. 587 func (d *DoltSession) DoltCommit( 588 ctx *sql.Context, 589 dbName string, 590 tx sql.Transaction, 591 commit *doltdb.PendingCommit, 592 ) (*doltdb.Commit, error) { 593 commitFunc := func(ctx *sql.Context, dtx *DoltTransaction, workingSet *doltdb.WorkingSet) (*doltdb.WorkingSet, *doltdb.Commit, error) { 594 ws, commit, err := dtx.DoltCommit( 595 ctx, 596 workingSet.WithWorkingRoot(commit.Roots.Working).WithStagedRoot(commit.Roots.Staged), 597 commit, 598 dbName) 599 if err != nil { 600 return nil, nil, err 601 } 602 603 return ws, commit, err 604 } 605 606 return d.commitCurrentHead(ctx, dbName, tx, commitFunc) 607 } 608 609 // doCommitFunc is a function to write to the database, which involves updating the working set and potentially 610 // updating HEAD with a new commit 611 type doCommitFunc func(ctx *sql.Context, dtx *DoltTransaction, workingSet *doltdb.WorkingSet) (*doltdb.WorkingSet, *doltdb.Commit, error) 612 613 // commitBranchState performs a commit for the branch state given, using the doCommitFunc provided 614 func (d *DoltSession) commitBranchState( 615 ctx *sql.Context, 616 branchState *branchState, 617 tx sql.Transaction, 618 commitFunc doCommitFunc, 619 ) (*doltdb.Commit, error) { 620 dtx, ok := tx.(*DoltTransaction) 621 if !ok { 622 return nil, fmt.Errorf("expected a DoltTransaction") 623 } 624 625 _, newCommit, err := commitFunc(ctx, dtx, branchState.WorkingSet()) 626 if err != nil { 627 return nil, err 628 } 629 630 // Anything that commits a transaction needs its current transaction state cleared so that the next statement starts 631 // a new transaction. This should in principle be done by the engine, but it currently only understands explicit 632 // COMMIT statements. Any other statements that commit a transaction, including stored procedures, needs to do this 633 // themselves. 634 ctx.SetTransaction(nil) 635 return newCommit, nil 636 } 637 638 // commitCurrentHead commits the current HEAD for the database given, using the doCommitFunc provided 639 func (d *DoltSession) commitCurrentHead(ctx *sql.Context, dbName string, tx sql.Transaction, commitFunc doCommitFunc) (*doltdb.Commit, error) { 640 branchState, ok, err := d.lookupDbState(ctx, dbName) 641 if err != nil { 642 return nil, err 643 } else if !ok { 644 return nil, sql.ErrDatabaseNotFound.New(dbName) 645 } 646 647 return d.commitBranchState(ctx, branchState, tx, commitFunc) 648 } 649 650 // PendingCommitAllStaged returns a pending commit with all tables staged. Returns nil if there are no changes to stage. 651 func (d *DoltSession) PendingCommitAllStaged(ctx *sql.Context, branchState *branchState, props actions.CommitStagedProps) (*doltdb.PendingCommit, error) { 652 roots := branchState.roots() 653 654 var err error 655 roots, err = actions.StageAllTables(ctx, roots, true) 656 if err != nil { 657 return nil, err 658 } 659 660 return d.newPendingCommit(ctx, branchState, roots, props) 661 } 662 663 // NewPendingCommit returns a new |doltdb.PendingCommit| for the database named, using the roots given, adding any 664 // merge parent from an in progress merge as appropriate. The session working set is not updated with these new roots, 665 // but they are set in the returned |doltdb.PendingCommit|. If there are no changes staged, this method returns nil. 666 func (d *DoltSession) NewPendingCommit(ctx *sql.Context, dbName string, roots doltdb.Roots, props actions.CommitStagedProps) (*doltdb.PendingCommit, error) { 667 branchState, ok, err := d.lookupDbState(ctx, dbName) 668 if err != nil { 669 return nil, err 670 } 671 if !ok { 672 return nil, fmt.Errorf("session state for database %s not found", dbName) 673 } 674 675 return d.newPendingCommit(ctx, branchState, roots, props) 676 } 677 678 // newPendingCommit returns a new |doltdb.PendingCommit| for the database and head named by |branchState| 679 // See NewPendingCommit 680 func (d *DoltSession) newPendingCommit(ctx *sql.Context, branchState *branchState, roots doltdb.Roots, props actions.CommitStagedProps) (*doltdb.PendingCommit, error) { 681 headCommit := branchState.headCommit 682 headHash, _ := headCommit.HashOf() 683 684 if branchState.WorkingSet() == nil { 685 return nil, doltdb.ErrOperationNotSupportedInDetachedHead 686 } 687 688 var mergeParentCommits []*doltdb.Commit 689 if branchState.WorkingSet().MergeCommitParents() { 690 mergeParentCommits = []*doltdb.Commit{branchState.WorkingSet().MergeState().Commit()} 691 } else if props.Amend { 692 numParentsHeadForAmend := headCommit.NumParents() 693 for i := 0; i < numParentsHeadForAmend; i++ { 694 optCmt, err := headCommit.GetParent(ctx, i) 695 if err != nil { 696 return nil, err 697 } 698 parentCommit, ok := optCmt.ToCommit() 699 if !ok { 700 return nil, doltdb.ErrGhostCommitEncountered 701 } 702 703 mergeParentCommits = append(mergeParentCommits, parentCommit) 704 } 705 706 // TODO: This is not the correct way to write this commit as an amend. While this commit is running 707 // the branch head moves backwards and concurrency control here is not principled. 708 newRoots, err := actions.ResetSoftToRef(ctx, branchState.dbData, "HEAD~1") 709 if err != nil { 710 return nil, err 711 } 712 713 err = d.SetWorkingSet(ctx, ctx.GetCurrentDatabase(), branchState.WorkingSet().WithStagedRoot(newRoots.Staged)) 714 if err != nil { 715 return nil, err 716 } 717 718 roots.Head = newRoots.Head 719 } 720 721 pendingCommit, err := actions.GetCommitStaged(ctx, roots, branchState.WorkingSet(), mergeParentCommits, branchState.dbData.Ddb, props) 722 if err != nil { 723 if props.Amend { 724 _, err = actions.ResetSoftToRef(ctx, branchState.dbData, headHash.String()) 725 if err != nil { 726 return nil, err 727 } 728 } 729 if _, ok := err.(actions.NothingStaged); err != nil && !ok { 730 return nil, err 731 } 732 } 733 734 return pendingCommit, nil 735 } 736 737 // Rollback rolls the given transaction back 738 func (d *DoltSession) Rollback(ctx *sql.Context, tx sql.Transaction) error { 739 // Nothing to do here, we just throw away all our work and let a new transaction begin next statement 740 d.clear() 741 return nil 742 } 743 744 // CreateSavepoint creates a new savepoint for this transaction with the name given. A previously created savepoint 745 // with the same name will be overwritten. 746 func (d *DoltSession) CreateSavepoint(ctx *sql.Context, tx sql.Transaction, savepointName string) error { 747 if TransactionsDisabled(ctx) { 748 return nil 749 } 750 751 dtx, ok := tx.(*DoltTransaction) 752 if !ok { 753 return fmt.Errorf("expected a DoltTransaction") 754 } 755 756 roots := make(map[string]doltdb.RootValue) 757 for _, db := range d.provider.DoltDatabases() { 758 branchState, ok, err := d.lookupDbState(ctx, db.Name()) 759 if err != nil { 760 return err 761 } 762 if !ok { 763 return fmt.Errorf("session state for database %s not found", db.Name()) 764 } 765 baseName, _ := SplitRevisionDbName(db.Name()) 766 roots[strings.ToLower(baseName)] = branchState.WorkingSet().WorkingRoot() 767 } 768 769 dtx.CreateSavepoint(savepointName, roots) 770 return nil 771 } 772 773 // RollbackToSavepoint sets this session's root to the one saved in the savepoint name. It's an error if no savepoint 774 // with that name exists. 775 func (d *DoltSession) RollbackToSavepoint(ctx *sql.Context, tx sql.Transaction, savepointName string) error { 776 if TransactionsDisabled(ctx) { 777 return nil 778 } 779 780 dtx, ok := tx.(*DoltTransaction) 781 if !ok { 782 return fmt.Errorf("expected a DoltTransaction") 783 } 784 785 roots := dtx.RollbackToSavepoint(savepointName) 786 if roots == nil { 787 return sql.ErrSavepointDoesNotExist.New(savepointName) 788 } 789 790 for dbName, root := range roots { 791 err := d.SetWorkingRoot(ctx, dbName, root) 792 if err != nil { 793 return err 794 } 795 } 796 797 return nil 798 } 799 800 // ReleaseSavepoint removes the savepoint name from the transaction. It's an error if no savepoint with that name 801 // exists. 802 func (d *DoltSession) ReleaseSavepoint(ctx *sql.Context, tx sql.Transaction, savepointName string) error { 803 if TransactionsDisabled(ctx) { 804 return nil 805 } 806 807 dtx, ok := tx.(*DoltTransaction) 808 if !ok { 809 return fmt.Errorf("expected a DoltTransaction") 810 } 811 812 existed := dtx.ClearSavepoint(savepointName) 813 if !existed { 814 return sql.ErrSavepointDoesNotExist.New(savepointName) 815 } 816 817 return nil 818 } 819 820 // GetDoltDB returns the *DoltDB for a given database by name 821 func (d *DoltSession) GetDoltDB(ctx *sql.Context, dbName string) (*doltdb.DoltDB, bool) { 822 branchState, ok, err := d.lookupDbState(ctx, dbName) 823 if err != nil { 824 return nil, false 825 } 826 if !ok { 827 return nil, false 828 } 829 830 return branchState.dbData.Ddb, true 831 } 832 833 func (d *DoltSession) GetDbData(ctx *sql.Context, dbName string) (env.DbData, bool) { 834 branchState, ok, err := d.lookupDbState(ctx, dbName) 835 if err != nil { 836 return env.DbData{}, false 837 } 838 if !ok { 839 return env.DbData{}, false 840 } 841 842 return branchState.dbData, true 843 } 844 845 // GetRoots returns the current roots for a given database associated with the session 846 func (d *DoltSession) GetRoots(ctx *sql.Context, dbName string) (doltdb.Roots, bool) { 847 branchState, ok, err := d.lookupDbState(ctx, dbName) 848 if err != nil { 849 return doltdb.Roots{}, false 850 } 851 if !ok { 852 return doltdb.Roots{}, false 853 } 854 855 return branchState.roots(), true 856 } 857 858 // ResolveRootForRef returns the root value for the ref given, which refers to either a commit spec or is one of the 859 // special identifiers |WORKING| or |STAGED| 860 // Returns the root value associated with the identifier given, its commit time and its hash string. The hash string 861 // for special identifiers |WORKING| or |STAGED| would be itself, 'WORKING' or 'STAGED', respectively. 862 func (d *DoltSession) ResolveRootForRef(ctx *sql.Context, dbName, refStr string) (doltdb.RootValue, *types.Timestamp, string, error) { 863 if refStr == doltdb.Working || refStr == doltdb.Staged { 864 // TODO: get from working set / staged update time 865 now := types.Timestamp(time.Now()) 866 // TODO: no current database 867 roots, _ := d.GetRoots(ctx, ctx.GetCurrentDatabase()) 868 if refStr == doltdb.Working { 869 return roots.Working, &now, refStr, nil 870 } else if refStr == doltdb.Staged { 871 return roots.Staged, &now, refStr, nil 872 } 873 } 874 875 var root doltdb.RootValue 876 var commitTime *types.Timestamp 877 cs, err := doltdb.NewCommitSpec(refStr) 878 if err != nil { 879 return nil, nil, "", err 880 } 881 882 dbData, ok := d.GetDbData(ctx, dbName) 883 if !ok { 884 return nil, nil, "", sql.ErrDatabaseNotFound.New(dbName) 885 } 886 887 headRef, err := d.CWBHeadRef(ctx, dbName) 888 if err == doltdb.ErrOperationNotSupportedInDetachedHead { 889 // leave head ref nil, we may not need it (commit hash) 890 } else if err != nil { 891 return nil, nil, "", err 892 } 893 894 optCmt, err := dbData.Ddb.Resolve(ctx, cs, headRef) 895 if err != nil { 896 return nil, nil, "", err 897 } 898 cm, ok := optCmt.ToCommit() 899 if !ok { 900 return nil, nil, "", doltdb.ErrGhostCommitRuntimeFailure 901 } 902 903 root, err = cm.GetRootValue(ctx) 904 if err != nil { 905 return nil, nil, "", err 906 } 907 908 meta, err := cm.GetCommitMeta(ctx) 909 if err != nil { 910 return nil, nil, "", err 911 } 912 913 t := meta.Time() 914 commitTime = (*types.Timestamp)(&t) 915 916 commitHash, err := cm.HashOf() 917 if err != nil { 918 return nil, nil, "", err 919 } 920 921 return root, commitTime, commitHash.String(), nil 922 } 923 924 // SetWorkingRoot sets a new root value for the session for the database named. This is the primary mechanism by which data 925 // changes are communicated to the engine and persisted back to disk. All data changes should be followed by a call to 926 // update the session's root value via this method. 927 // The dbName given should generally be a revision-qualified database name. 928 // Data changes contained in the |newRoot| aren't persisted until this session is committed. 929 func (d *DoltSession) SetWorkingRoot(ctx *sql.Context, dbName string, newRoot doltdb.RootValue) error { 930 branchState, _, err := d.lookupDbState(ctx, dbName) 931 if err != nil { 932 return err 933 } 934 935 if branchState.WorkingSet() == nil { 936 return doltdb.ErrOperationNotSupportedInDetachedHead 937 } 938 939 if rootsEqual(branchState.roots().Working, newRoot) { 940 return nil 941 } 942 943 if branchState.readOnly { 944 return fmt.Errorf("cannot set root on read-only session") 945 } 946 branchState.workingSet = branchState.WorkingSet().WithWorkingRoot(newRoot) 947 948 return d.SetWorkingSet(ctx, dbName, branchState.WorkingSet()) 949 } 950 951 // SetRoots sets new roots for the session for the database named. Typically, clients should only set the working root, 952 // via setRoot. This method is for clients that need to update more of the session state, such as the dolt_ functions. 953 // Unlike setting the working root, this method always marks the database state dirty. 954 func (d *DoltSession) SetRoots(ctx *sql.Context, dbName string, roots doltdb.Roots) error { 955 sessionState, _, err := d.LookupDbState(ctx, dbName) 956 if err != nil { 957 return err 958 } 959 960 if sessionState.WorkingSet() == nil { 961 return doltdb.ErrOperationNotSupportedInDetachedHead 962 } 963 964 workingSet := sessionState.WorkingSet().WithWorkingRoot(roots.Working).WithStagedRoot(roots.Staged) 965 return d.SetWorkingSet(ctx, dbName, workingSet) 966 } 967 968 func (d *DoltSession) SetFileSystem(fs filesys.Filesys) { 969 d.fs = fs 970 } 971 972 func (d *DoltSession) GetFileSystem() filesys.Filesys { 973 return d.fs 974 } 975 976 // SetWorkingSet sets the working set for this session. 977 func (d *DoltSession) SetWorkingSet(ctx *sql.Context, dbName string, ws *doltdb.WorkingSet) error { 978 if ws == nil { 979 panic("attempted to set a nil working set for the session") 980 } 981 982 branchState, _, err := d.lookupDbState(ctx, dbName) 983 if err != nil { 984 return err 985 } 986 if ws.Ref() != branchState.WorkingSet().Ref() { 987 return fmt.Errorf("must switch working sets with SwitchWorkingSet") 988 } 989 branchState.workingSet = ws 990 991 err = d.setDbSessionVars(ctx, branchState, true) 992 if err != nil { 993 return err 994 } 995 996 if writeSess := branchState.WriteSession(); writeSess != nil { 997 err = writeSess.SetWorkingSet(ctx, ws) 998 if err != nil { 999 return err 1000 } 1001 } 1002 1003 branchState.dirty = true 1004 return nil 1005 } 1006 1007 // SwitchWorkingSet switches to a new working set for this session. Unlike SetWorkingSet, this method expresses no 1008 // intention to eventually persist any uncommitted changes. Rather, this method only changes the in memory state of 1009 // this session. It's equivalent to starting a new session with the working set reference provided. If the current 1010 // session is dirty, this method returns an error. Clients can only switch branches with a clean working set, and so 1011 // must either commit or rollback any changes before attempting to switch working sets. 1012 func (d *DoltSession) SwitchWorkingSet( 1013 ctx *sql.Context, 1014 dbName string, 1015 wsRef ref.WorkingSetRef, 1016 ) error { 1017 headRef, err := wsRef.ToHeadRef() 1018 if err != nil { 1019 return err 1020 } 1021 1022 d.mu.Lock() 1023 1024 baseName, _ := SplitRevisionDbName(dbName) 1025 dbState, ok := d.dbStates[strings.ToLower(baseName)] 1026 if !ok { 1027 d.mu.Unlock() 1028 return sql.ErrDatabaseNotFound.New(dbName) 1029 } 1030 dbState.checkedOutRevSpec = headRef.GetPath() 1031 1032 d.mu.Unlock() 1033 1034 // bootstrap the db state as necessary 1035 branchState, ok, err := d.lookupDbState(ctx, baseName+DbRevisionDelimiter+headRef.GetPath()) 1036 if err != nil { 1037 return err 1038 } 1039 1040 if !ok { 1041 return sql.ErrDatabaseNotFound.New(dbName) 1042 } 1043 1044 ctx.SetCurrentDatabase(baseName) 1045 1046 return d.setDbSessionVars(ctx, branchState, false) 1047 } 1048 1049 func (d *DoltSession) WorkingSet(ctx *sql.Context, dbName string) (*doltdb.WorkingSet, error) { 1050 // TODO: need to make sure we use a revision qualified DB name here 1051 sessionState, _, err := d.LookupDbState(ctx, dbName) 1052 if err != nil { 1053 return nil, err 1054 } 1055 if sessionState.WorkingSet() == nil { 1056 return nil, doltdb.ErrOperationNotSupportedInDetachedHead 1057 } 1058 return sessionState.WorkingSet(), nil 1059 } 1060 1061 // GetHeadCommit returns the parent commit of the current session. 1062 func (d *DoltSession) GetHeadCommit(ctx *sql.Context, dbName string) (*doltdb.Commit, error) { 1063 branchState, ok, err := d.lookupDbState(ctx, dbName) 1064 if err != nil { 1065 return nil, err 1066 } 1067 if !ok { 1068 return nil, sql.ErrDatabaseNotFound.New(dbName) 1069 } 1070 1071 return branchState.headCommit, nil 1072 } 1073 1074 // SetSessionVariable is defined on sql.Session. We intercept it here to interpret the special semantics of the system 1075 // vars that we define. Otherwise we pass it on to the base implementation. 1076 func (d *DoltSession) SetSessionVariable(ctx *sql.Context, key string, value interface{}) error { 1077 if ok, db := IsHeadRefKey(key); ok { 1078 v, ok := value.(string) 1079 if !ok { 1080 return doltdb.ErrInvalidBranchOrHash 1081 } 1082 return d.setHeadRefSessionVar(ctx, db, v) 1083 } 1084 if IsReadOnlyVersionKey(key) { 1085 return sql.ErrSystemVariableReadOnly.New(key) 1086 } 1087 1088 if strings.ToLower(key) == "foreign_key_checks" { 1089 return d.setForeignKeyChecksSessionVar(ctx, key, value) 1090 } 1091 1092 return d.Session.SetSessionVariable(ctx, key, value) 1093 } 1094 1095 func (d *DoltSession) setHeadRefSessionVar(ctx *sql.Context, db, value string) error { 1096 headRef, err := ref.Parse(value) 1097 if err != nil { 1098 return err 1099 } 1100 1101 ws, err := ref.WorkingSetRefForHead(headRef) 1102 if err != nil { 1103 return err 1104 } 1105 err = d.SwitchWorkingSet(ctx, db, ws) 1106 if errors.Is(err, doltdb.ErrWorkingSetNotFound) { 1107 return fmt.Errorf("%w; %s: '%s'", doltdb.ErrBranchNotFound, err, value) 1108 } 1109 return err 1110 } 1111 1112 func (d *DoltSession) setForeignKeyChecksSessionVar(ctx *sql.Context, key string, value interface{}) error { 1113 d.mu.Lock() 1114 defer d.mu.Unlock() 1115 1116 convertedVal, _, err := sqltypes.Int64.Convert(value) 1117 if err != nil { 1118 return err 1119 } 1120 intVal := int64(0) 1121 if convertedVal != nil { 1122 intVal = convertedVal.(int64) 1123 } 1124 1125 if intVal == 0 { 1126 for _, dbState := range d.dbStates { 1127 for _, branchState := range dbState.heads { 1128 if ws := branchState.WriteSession(); ws != nil { 1129 opts := ws.GetOptions() 1130 opts.ForeignKeyChecksDisabled = true 1131 ws.SetOptions(opts) 1132 } 1133 } 1134 } 1135 } else if intVal == 1 { 1136 for _, dbState := range d.dbStates { 1137 for _, branchState := range dbState.heads { 1138 if ws := branchState.WriteSession(); ws != nil { 1139 opts := ws.GetOptions() 1140 opts.ForeignKeyChecksDisabled = false 1141 ws.SetOptions(opts) 1142 } 1143 } 1144 } 1145 } else { 1146 return sql.ErrInvalidSystemVariableValue.New("foreign_key_checks", intVal) 1147 } 1148 1149 return d.Session.SetSessionVariable(ctx, key, value) 1150 } 1151 1152 // addDB adds the database given to this session. This establishes a starting root value for this session, as well as 1153 // other state tracking metadata. 1154 func (d *DoltSession) addDB(ctx *sql.Context, db SqlDatabase) error { 1155 revisionQualifiedName := strings.ToLower(db.RevisionQualifiedName()) 1156 baseName, _ := SplitRevisionDbName(revisionQualifiedName) 1157 1158 DefineSystemVariablesForDB(baseName) 1159 1160 tx, usingDoltTransaction := d.GetTransaction().(*DoltTransaction) 1161 1162 d.mu.Lock() 1163 defer d.mu.Unlock() 1164 sessionState, sessionStateExists := d.dbStates[baseName] 1165 1166 // Before computing initial state for the DB, check to see if we have it in the cache 1167 var dbState InitialDbState 1168 var dbStateCached bool 1169 if usingDoltTransaction { 1170 nomsRoot, ok := tx.GetInitialRoot(baseName) 1171 if ok && sessionStateExists { 1172 dbState, dbStateCached = d.dbCache.GetCachedInitialDbState(doltdb.DataCacheKey{Hash: nomsRoot}, revisionQualifiedName) 1173 } 1174 } 1175 1176 if !dbStateCached { 1177 var err error 1178 dbState, err = db.InitialDBState(ctx) 1179 if err != nil { 1180 return err 1181 } 1182 } 1183 1184 if !sessionStateExists { 1185 sessionState = newEmptyDatabaseSessionState() 1186 d.dbStates[baseName] = sessionState 1187 1188 var err error 1189 sessionState.tmpFileDir, err = dbState.DbData.Rsw.TempTableFilesDir() 1190 if err != nil { 1191 if errors.Is(err, env.ErrDoltRepositoryNotFound) { 1192 return env.ErrFailedToAccessDB.New(dbState.Db.Name()) 1193 } 1194 return err 1195 } 1196 1197 sessionState.dbName = baseName 1198 1199 baseDb, ok := d.provider.BaseDatabase(ctx, baseName) 1200 if !ok { 1201 return fmt.Errorf("unable to find database %s, this is a bug", baseName) 1202 } 1203 1204 // The checkedOutRevSpec should be the checked out branch of the database if available, or the revision 1205 // string otherwise 1206 sessionState.checkedOutRevSpec, err = DefaultHead(baseName, baseDb) 1207 if err != nil { 1208 return err 1209 } 1210 } 1211 1212 if !dbStateCached && usingDoltTransaction { 1213 nomsRoot, ok := tx.GetInitialRoot(baseName) 1214 if ok { 1215 d.dbCache.CacheInitialDbState(doltdb.DataCacheKey{Hash: nomsRoot}, revisionQualifiedName, dbState) 1216 } 1217 } 1218 1219 branchState := sessionState.NewEmptyBranchState(db.Revision(), db.RevisionType()) 1220 1221 // TODO: get rid of all repo state reader / writer stuff. Until we do, swap out the reader with one of our own, and 1222 // the writer with one that errors out 1223 // TODO: this no longer gets called at session creation time, so the error handling below never occurs when a 1224 // database is deleted out from under a running server 1225 branchState.dbData = dbState.DbData 1226 adapter := NewSessionStateAdapter(d, db.Name(), dbState.Remotes, dbState.Branches, dbState.Backups) 1227 branchState.dbData.Rsr = adapter 1228 branchState.dbData.Rsw = adapter 1229 branchState.readOnly = dbState.ReadOnly 1230 1231 // TODO: figure out how to cast this to dsqle.SqlDatabase without creating import cycles 1232 // Or better yet, get rid of EditOptions from the database, it's a session setting 1233 nbf := types.Format_Default 1234 if branchState.dbData.Ddb != nil { 1235 nbf = branchState.dbData.Ddb.Format() 1236 } 1237 editOpts := db.(interface{ EditOptions() editor.Options }).EditOptions() 1238 1239 if dbState.Err != nil { 1240 sessionState.Err = dbState.Err 1241 } else if dbState.WorkingSet != nil { 1242 branchState.workingSet = dbState.WorkingSet 1243 1244 // TODO: this is pretty clunky, there is a silly dependency between InitialDbState and globalstate.StateProvider 1245 // that's hard to express with the current types 1246 stateProvider, ok := db.(globalstate.GlobalStateProvider) 1247 if !ok { 1248 return fmt.Errorf("database does not contain global state store") 1249 } 1250 sessionState.globalState = stateProvider.GetGlobalState() 1251 1252 tracker, err := sessionState.globalState.AutoIncrementTracker(ctx) 1253 if err != nil { 1254 return err 1255 } 1256 branchState.writeSession = writer.NewWriteSession(nbf, branchState.WorkingSet(), tracker, editOpts) 1257 } 1258 1259 // WorkingSet is nil in the case of a read only, detached head DB 1260 if dbState.HeadCommit != nil { 1261 headRoot, err := dbState.HeadCommit.GetRootValue(ctx) 1262 if err != nil { 1263 return err 1264 } 1265 branchState.headRoot = headRoot 1266 } else if dbState.HeadRoot != nil { 1267 branchState.headRoot = dbState.HeadRoot 1268 } 1269 1270 branchState.headCommit = dbState.HeadCommit 1271 return nil 1272 } 1273 1274 func (d *DoltSession) DatabaseCache(ctx *sql.Context) *DatabaseCache { 1275 return d.dbCache 1276 } 1277 1278 func (d *DoltSession) AddTemporaryTable(ctx *sql.Context, db string, tbl sql.Table) { 1279 d.tempTables[strings.ToLower(db)] = append(d.tempTables[strings.ToLower(db)], tbl) 1280 } 1281 1282 func (d *DoltSession) DropTemporaryTable(ctx *sql.Context, db, name string) { 1283 tables := d.tempTables[strings.ToLower(db)] 1284 for i, tbl := range d.tempTables[strings.ToLower(db)] { 1285 if strings.ToLower(tbl.Name()) == strings.ToLower(name) { 1286 tables = append(tables[:i], tables[i+1:]...) 1287 break 1288 } 1289 } 1290 d.tempTables[strings.ToLower(db)] = tables 1291 } 1292 1293 func (d *DoltSession) GetTemporaryTable(ctx *sql.Context, db, name string) (sql.Table, bool) { 1294 for _, tbl := range d.tempTables[strings.ToLower(db)] { 1295 if strings.ToLower(tbl.Name()) == strings.ToLower(name) { 1296 return tbl, true 1297 } 1298 } 1299 return nil, false 1300 } 1301 1302 // GetAllTemporaryTables returns all temp tables for this session. 1303 func (d *DoltSession) GetAllTemporaryTables(ctx *sql.Context, db string) ([]sql.Table, error) { 1304 return d.tempTables[strings.ToLower(db)], nil 1305 } 1306 1307 // CWBHeadRef returns the branch ref for this session HEAD for the database named 1308 func (d *DoltSession) CWBHeadRef(ctx *sql.Context, dbName string) (ref.DoltRef, error) { 1309 branchState, ok, err := d.lookupDbState(ctx, dbName) 1310 if err != nil { 1311 return nil, err 1312 } 1313 if !ok { 1314 return nil, sql.ErrDatabaseNotFound.New(dbName) 1315 } 1316 1317 if branchState.revisionType != RevisionTypeBranch { 1318 return nil, doltdb.ErrOperationNotSupportedInDetachedHead 1319 } 1320 1321 return ref.NewBranchRef(branchState.head), nil 1322 } 1323 1324 // CurrentHead returns the current head for the db named, which must be unqualified. Used for bootstrap resolving the 1325 // correct session head when a database name from the client is unqualified. 1326 func (d *DoltSession) CurrentHead(ctx *sql.Context, dbName string) (string, bool, error) { 1327 baseName := strings.ToLower(dbName) 1328 1329 d.mu.Lock() 1330 dbState, ok := d.dbStates[baseName] 1331 d.mu.Unlock() 1332 1333 if ok { 1334 return dbState.checkedOutRevSpec, true, nil 1335 } 1336 1337 return "", false, nil 1338 } 1339 1340 func (d *DoltSession) Username() string { 1341 return d.username 1342 } 1343 1344 func (d *DoltSession) Email() string { 1345 return d.email 1346 } 1347 1348 // setDbSessionVars updates the three session vars that track the value of the session root hashes 1349 func (d *DoltSession) setDbSessionVars(ctx *sql.Context, state *branchState, force bool) error { 1350 // This check is important even when we are forcing an update, because it updates the idea of staleness 1351 varsStale := d.dbSessionVarsStale(ctx, state) 1352 if !varsStale && !force { 1353 return nil 1354 } 1355 1356 baseName := state.dbState.dbName 1357 1358 // Different DBs have different requirements for what state is set, so we are maximally permissive on what's expected 1359 // in the state object here 1360 if state.WorkingSet() != nil { 1361 headRef, err := state.WorkingSet().Ref().ToHeadRef() 1362 if err != nil { 1363 return err 1364 } 1365 1366 err = d.Session.SetSessionVariable(ctx, HeadRefKey(baseName), headRef.String()) 1367 if err != nil { 1368 return err 1369 } 1370 } 1371 1372 roots := state.roots() 1373 1374 if roots.Working != nil { 1375 h, err := roots.Working.HashOf() 1376 if err != nil { 1377 return err 1378 } 1379 err = d.Session.SetSessionVariable(ctx, WorkingKey(baseName), h.String()) 1380 if err != nil { 1381 return err 1382 } 1383 } 1384 1385 if roots.Staged != nil { 1386 h, err := roots.Staged.HashOf() 1387 if err != nil { 1388 return err 1389 } 1390 err = d.Session.SetSessionVariable(ctx, StagedKey(baseName), h.String()) 1391 if err != nil { 1392 return err 1393 } 1394 } 1395 1396 if state.headCommit != nil { 1397 h, err := state.headCommit.HashOf() 1398 if err != nil { 1399 return err 1400 } 1401 err = d.Session.SetSessionVariable(ctx, HeadKey(baseName), h.String()) 1402 if err != nil { 1403 return err 1404 } 1405 } 1406 1407 return nil 1408 } 1409 1410 // dbSessionVarsStale returns whether the session vars for the database with the state provided need to be updated in 1411 // the session 1412 func (d *DoltSession) dbSessionVarsStale(ctx *sql.Context, state *branchState) bool { 1413 dtx, ok := ctx.GetTransaction().(*DoltTransaction) 1414 if !ok { 1415 return true 1416 } 1417 1418 return d.dbCache.CacheSessionVars(state, dtx) 1419 } 1420 1421 func (d DoltSession) WithGlobals(conf config.ReadWriteConfig) *DoltSession { 1422 d.globalsConf = conf 1423 return &d 1424 } 1425 1426 // PersistGlobal implements sql.PersistableSession 1427 func (d *DoltSession) PersistGlobal(sysVarName string, value interface{}) error { 1428 if d.globalsConf == nil { 1429 return ErrSessionNotPersistable 1430 } 1431 1432 sysVar, _, err := validatePersistableSysVar(sysVarName) 1433 if err != nil { 1434 return err 1435 } 1436 1437 d.mu.Lock() 1438 defer d.mu.Unlock() 1439 return setPersistedValue(d.globalsConf, sysVar.GetName(), value) 1440 } 1441 1442 // RemovePersistedGlobal implements sql.PersistableSession 1443 func (d *DoltSession) RemovePersistedGlobal(sysVarName string) error { 1444 if d.globalsConf == nil { 1445 return ErrSessionNotPersistable 1446 } 1447 1448 sysVar, _, err := validatePersistableSysVar(sysVarName) 1449 if err != nil { 1450 return err 1451 } 1452 1453 d.mu.Lock() 1454 defer d.mu.Unlock() 1455 return d.globalsConf.Unset([]string{sysVar.GetName()}) 1456 } 1457 1458 // RemoveAllPersistedGlobals implements sql.PersistableSession 1459 func (d *DoltSession) RemoveAllPersistedGlobals() error { 1460 if d.globalsConf == nil { 1461 return ErrSessionNotPersistable 1462 } 1463 1464 allVars := make([]string, d.globalsConf.Size()) 1465 i := 0 1466 d.globalsConf.Iter(func(k, v string) bool { 1467 allVars[i] = k 1468 i++ 1469 return false 1470 }) 1471 1472 d.mu.Lock() 1473 defer d.mu.Unlock() 1474 return d.globalsConf.Unset(allVars) 1475 } 1476 1477 // RemoveAllPersistedGlobals implements sql.PersistableSession 1478 func (d *DoltSession) GetPersistedValue(k string) (interface{}, error) { 1479 if d.globalsConf == nil { 1480 return nil, ErrSessionNotPersistable 1481 } 1482 1483 return getPersistedValue(d.globalsConf, k) 1484 } 1485 1486 // SystemVariablesInConfig returns a list of System Variables associated with the session 1487 func (d *DoltSession) SystemVariablesInConfig() ([]sql.SystemVariable, error) { 1488 if d.globalsConf == nil { 1489 return nil, ErrSessionNotPersistable 1490 } 1491 sysVars, _, err := SystemVariablesInConfig(d.globalsConf) 1492 if err != nil { 1493 return nil, err 1494 } 1495 return sysVars, nil 1496 } 1497 1498 // GetBranch implements the interface branch_control.Context. 1499 func (d *DoltSession) GetBranch() (string, error) { 1500 // TODO: creating a new SQL context here is expensive 1501 ctx := sql.NewContext(context.Background(), sql.WithSession(d)) 1502 currentDb := d.Session.GetCurrentDatabase() 1503 1504 // no branch if there's no current db 1505 if currentDb == "" { 1506 return "", nil 1507 } 1508 1509 branchState, _, err := d.LookupDbState(ctx, currentDb) 1510 if err != nil { 1511 return "", err 1512 } 1513 1514 if branchState.WorkingSet() != nil { 1515 branchRef, err := branchState.WorkingSet().Ref().ToHeadRef() 1516 if err != nil { 1517 return "", err 1518 } 1519 return branchRef.GetPath(), nil 1520 } 1521 // A nil working set probably means that we're not on a branch (like we may be on a commit), so we return an empty string 1522 return "", nil 1523 } 1524 1525 // GetUser implements the interface branch_control.Context. 1526 func (d *DoltSession) GetUser() string { 1527 return d.Session.Client().User 1528 } 1529 1530 // GetHost implements the interface branch_control.Context. 1531 func (d *DoltSession) GetHost() string { 1532 return d.Session.Client().Address 1533 } 1534 1535 // GetController implements the interface branch_control.Context. 1536 func (d *DoltSession) GetController() *branch_control.Controller { 1537 return d.branchController 1538 } 1539 1540 // validatePersistedSysVar checks whether a system variable exists and is dynamic 1541 func validatePersistableSysVar(name string) (sql.SystemVariable, interface{}, error) { 1542 sysVar, val, ok := sql.SystemVariables.GetGlobal(name) 1543 if !ok { 1544 return nil, nil, sql.ErrUnknownSystemVariable.New(name) 1545 } 1546 if sysVar.IsReadOnly() { 1547 return nil, nil, sql.ErrSystemVariableReadOnly.New(name) 1548 } 1549 return sysVar, val, nil 1550 } 1551 1552 // getPersistedValue reads and converts a config value to the associated MysqlSystemVariable type 1553 func getPersistedValue(conf config.ReadableConfig, k string) (interface{}, error) { 1554 v, err := conf.GetString(k) 1555 if err != nil { 1556 return nil, err 1557 } 1558 1559 _, value, err := validatePersistableSysVar(k) 1560 if err != nil { 1561 return nil, err 1562 } 1563 1564 var res interface{} 1565 switch value.(type) { 1566 case int8: 1567 var tmp int64 1568 tmp, err = strconv.ParseInt(v, 10, 8) 1569 res = int8(tmp) 1570 case int, int16, int32, int64: 1571 res, err = strconv.ParseInt(v, 10, 64) 1572 case uint, uint8, uint16, uint32, uint64: 1573 res, err = strconv.ParseUint(v, 10, 64) 1574 case float32, float64: 1575 res, err = strconv.ParseFloat(v, 64) 1576 case bool: 1577 return nil, sql.ErrInvalidType.New(value) 1578 case string: 1579 return v, nil 1580 default: 1581 return nil, sql.ErrInvalidType.New(value) 1582 } 1583 1584 if err != nil { 1585 return nil, err 1586 } 1587 1588 return res, nil 1589 } 1590 1591 // setPersistedValue casts and persists a key value pair assuming thread safety 1592 func setPersistedValue(conf config.WritableConfig, key string, value interface{}) error { 1593 switch v := value.(type) { 1594 case int: 1595 return config.SetInt(conf, key, int64(v)) 1596 case int8: 1597 return config.SetInt(conf, key, int64(v)) 1598 case int16: 1599 return config.SetInt(conf, key, int64(v)) 1600 case int32: 1601 return config.SetInt(conf, key, int64(v)) 1602 case int64: 1603 return config.SetInt(conf, key, v) 1604 case uint: 1605 return config.SetUint(conf, key, uint64(v)) 1606 case uint8: 1607 return config.SetUint(conf, key, uint64(v)) 1608 case uint16: 1609 return config.SetUint(conf, key, uint64(v)) 1610 case uint32: 1611 return config.SetUint(conf, key, uint64(v)) 1612 case uint64: 1613 return config.SetUint(conf, key, v) 1614 case float32: 1615 return config.SetFloat(conf, key, float64(v)) 1616 case float64: 1617 return config.SetFloat(conf, key, v) 1618 case decimal.Decimal: 1619 f64, _ := v.Float64() 1620 return config.SetFloat(conf, key, f64) 1621 case string: 1622 return config.SetString(conf, key, v) 1623 case bool: 1624 if v { 1625 return config.SetInt(conf, key, 1) 1626 } else { 1627 return config.SetInt(conf, key, 0) 1628 } 1629 default: 1630 return sql.ErrInvalidType.New(v) 1631 } 1632 } 1633 1634 // SystemVariablesInConfig returns system variables from the persisted config 1635 // and a list of persisted keys that have no corresponding definition in 1636 // |sql.SystemVariables|. 1637 func SystemVariablesInConfig(conf config.ReadableConfig) ([]sql.SystemVariable, []string, error) { 1638 allVars := make([]sql.SystemVariable, conf.Size()) 1639 var missingKeys []string 1640 i := 0 1641 var err error 1642 var def interface{} 1643 conf.Iter(func(k, v string) bool { 1644 def, err = getPersistedValue(conf, k) 1645 if err != nil { 1646 if sql.ErrUnknownSystemVariable.Is(err) { 1647 err = nil 1648 missingKeys = append(missingKeys, k) 1649 return false 1650 } 1651 err = fmt.Errorf("key: '%s'; %w", k, err) 1652 return true 1653 } 1654 // getPersistedVal already checked for errors 1655 sysVar, _, _ := sql.SystemVariables.GetGlobal(k) 1656 sysVar.SetDefault(def) 1657 allVars[i] = sysVar 1658 i++ 1659 return false 1660 }) 1661 if err != nil { 1662 return nil, nil, err 1663 } 1664 return allVars, missingKeys, nil 1665 } 1666 1667 var initMu = sync.Mutex{} 1668 1669 func InitPersistedSystemVars(dEnv *env.DoltEnv) error { 1670 initMu.Lock() 1671 defer initMu.Unlock() 1672 1673 var globals config.ReadWriteConfig 1674 if localConf, ok := dEnv.Config.GetConfig(env.LocalConfig); ok { 1675 globals = config.NewPrefixConfig(localConf, env.SqlServerGlobalsPrefix) 1676 } else if globalConf, ok := dEnv.Config.GetConfig(env.GlobalConfig); ok { 1677 globals = config.NewPrefixConfig(globalConf, env.SqlServerGlobalsPrefix) 1678 } else { 1679 cli.Println("warning: no local or global Dolt configuration found; session is not persistable") 1680 globals = config.NewMapConfig(make(map[string]string)) 1681 } 1682 1683 persistedGlobalVars, missingKeys, err := SystemVariablesInConfig(globals) 1684 if err != nil { 1685 return err 1686 } 1687 for _, k := range missingKeys { 1688 cli.Printf("warning: persisted system variable %s was not loaded since its definition does not exist.\n", k) 1689 } 1690 sql.SystemVariables.AddSystemVariables(persistedGlobalVars) 1691 return nil 1692 } 1693 1694 // TransactionRoot returns the noms root for the given database in the current transaction 1695 func TransactionRoot(ctx *sql.Context, db SqlDatabase) (hash.Hash, error) { 1696 tx, ok := ctx.GetTransaction().(*DoltTransaction) 1697 // We don't have a real transaction in some cases (esp. PREPARE), in which case we need to use the tip of the data 1698 if !ok { 1699 return db.DbData().Ddb.NomsRoot(ctx) 1700 } 1701 1702 nomsRoot, ok := tx.GetInitialRoot(db.Name()) 1703 if !ok { 1704 return hash.Hash{}, fmt.Errorf("could not resolve initial root for database %s", db.Name()) 1705 } 1706 1707 return nomsRoot, nil 1708 } 1709 1710 // DefaultHead returns the head for the database given when one isn't specified 1711 func DefaultHead(baseName string, db SqlDatabase) (string, error) { 1712 head := "" 1713 1714 // First check the global variable for the default branch 1715 _, val, ok := sql.SystemVariables.GetGlobal(DefaultBranchKey(baseName)) 1716 if ok { 1717 head = val.(string) 1718 branchRef, err := ref.Parse(head) 1719 if err == nil { 1720 head = branchRef.GetPath() 1721 } else { 1722 head = "" 1723 // continue to below 1724 } 1725 } 1726 1727 // Fall back to the database's initially checked out branch 1728 if head == "" { 1729 rsr := db.DbData().Rsr 1730 if rsr != nil { 1731 headRef, err := rsr.CWBHeadRef() 1732 if err != nil { 1733 return "", err 1734 } 1735 head = headRef.GetPath() 1736 } 1737 } 1738 1739 if head == "" { 1740 head = db.Revision() 1741 } 1742 1743 return head, nil 1744 }