github.com/hasnat/dolt/go@v0.0.0-20210628190320-9eb5d843fbb7/libraries/doltcore/sqle/transactions.go (about) 1 // Copyright 2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package sqle 16 17 import ( 18 "fmt" 19 "strings" 20 21 "github.com/dolthub/go-mysql-server/sql" 22 23 "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" 24 "github.com/dolthub/dolt/go/libraries/doltcore/env" 25 "github.com/dolthub/dolt/go/libraries/doltcore/merge" 26 "github.com/dolthub/dolt/go/libraries/doltcore/ref" 27 "github.com/dolthub/dolt/go/store/datas" 28 "github.com/dolthub/dolt/go/store/hash" 29 ) 30 31 const ( 32 maxTxCommitRetries = 5 33 ) 34 35 type DoltTransaction struct { 36 startRoot *doltdb.RootValue 37 workingSet ref.WorkingSetRef 38 dbData env.DbData 39 savepoints []savepoint 40 } 41 42 type savepoint struct { 43 name string 44 root *doltdb.RootValue 45 } 46 47 func NewDoltTransaction(startRoot *doltdb.RootValue, workingSet ref.WorkingSetRef, dbData env.DbData) *DoltTransaction { 48 return &DoltTransaction{ 49 startRoot: startRoot, 50 workingSet: workingSet, 51 dbData: dbData, 52 } 53 } 54 55 func (tx DoltTransaction) String() string { 56 // TODO: return more info (hashes need caching) 57 return "DoltTransaction" 58 } 59 60 // Commit attempts to merge newRoot into the working set 61 // Uses the same algorithm as merge.Merger: 62 // |ws.root| is the root 63 // |newRoot| is the mergeRoot 64 // |tx.startRoot| is ancRoot 65 // if working set == ancRoot, attempt a fast-forward merge 66 func (tx *DoltTransaction) Commit(ctx *sql.Context, newRoot *doltdb.RootValue) (*doltdb.RootValue, error) { 67 for i := 0; i < maxTxCommitRetries; i++ { 68 ws, err := tx.dbData.Ddb.ResolveWorkingSet(ctx, tx.workingSet) 69 if err == doltdb.ErrWorkingSetNotFound { 70 // initial commit 71 err = tx.dbData.Ddb.UpdateWorkingSet(ctx, tx.workingSet, newRoot, hash.Hash{}) 72 if err == datas.ErrOptimisticLockFailed { 73 continue 74 } 75 } 76 77 if err != nil { 78 return nil, err 79 } 80 81 root := ws.RootValue() 82 83 hash, err := ws.Struct().Hash(tx.dbData.Ddb.Format()) 84 if err != nil { 85 return nil, err 86 } 87 88 if rootsEqual(root, tx.startRoot) { 89 // ff merge 90 err = tx.dbData.Ddb.UpdateWorkingSet(ctx, tx.workingSet, newRoot, hash) 91 if err == datas.ErrOptimisticLockFailed { 92 continue 93 } else if err != nil { 94 return nil, err 95 } 96 97 return tx.updateRepoStateFile(ctx, newRoot) 98 } 99 100 mergedRoot, stats, err := merge.MergeRoots(ctx, root, newRoot, tx.startRoot) 101 if err != nil { 102 return nil, err 103 } 104 105 for table, mergeStats := range stats { 106 if mergeStats.Conflicts > 0 { 107 // TODO: surface duplicate key errors as appropriate 108 return nil, fmt.Errorf("conflict in table %s", table) 109 } 110 } 111 112 err = tx.dbData.Ddb.UpdateWorkingSet(ctx, tx.workingSet, mergedRoot, hash) 113 if err == datas.ErrOptimisticLockFailed { 114 continue 115 } else if err != nil { 116 return nil, err 117 } 118 119 // TODO: this is not thread safe, but will not be necessary after migrating all clients away from using the 120 // working set stored in repo_state.json, so should be good enough for now 121 return tx.updateRepoStateFile(ctx, mergedRoot) 122 } 123 124 // TODO: different error type for retries exhausted 125 return nil, datas.ErrOptimisticLockFailed 126 } 127 128 func (tx *DoltTransaction) updateRepoStateFile(ctx *sql.Context, mergedRoot *doltdb.RootValue) (*doltdb.RootValue, error) { 129 hash, err := mergedRoot.HashOf() 130 if err != nil { 131 return nil, err 132 } 133 134 err = tx.dbData.Rsw.SetWorkingHash(ctx, hash) 135 if err != nil { 136 return nil, err 137 } 138 139 return mergedRoot, err 140 } 141 142 // CreateSavepoint creates a new savepoint with the name and root value given. If a savepoint with the name given 143 // already exists, it's overwritten. 144 func (tx *DoltTransaction) CreateSavepoint(name string, root *doltdb.RootValue) { 145 existing := tx.findSavepoint(name) 146 if existing >= 0 { 147 tx.savepoints = append(tx.savepoints[:existing], tx.savepoints[existing+1:]...) 148 } 149 tx.savepoints = append(tx.savepoints, savepoint{name, root}) 150 } 151 152 // findSavepoint returns the index of the savepoint with the name given, or -1 if it doesn't exist 153 func (tx *DoltTransaction) findSavepoint(name string) int { 154 for i, s := range tx.savepoints { 155 if strings.ToLower(s.name) == strings.ToLower(name) { 156 return i 157 } 158 } 159 return -1 160 } 161 162 // RollbackToSavepoint returns the root value associated with the savepoint name given, or nil if no such savepoint can 163 // be found. All savepoints created after the one being rolled back to are no longer accessible. 164 func (tx *DoltTransaction) RollbackToSavepoint(name string) *doltdb.RootValue { 165 existing := tx.findSavepoint(name) 166 if existing >= 0 { 167 // Clear out any savepoints past this one 168 tx.savepoints = tx.savepoints[:existing+1] 169 return tx.savepoints[existing].root 170 } 171 return nil 172 } 173 174 // ClearSavepoint removes the savepoint with the name given and returns the root value recorded there, or nil if no 175 // savepoint exists with that name. 176 func (tx *DoltTransaction) ClearSavepoint(name string) *doltdb.RootValue { 177 existing := tx.findSavepoint(name) 178 var existingRoot *doltdb.RootValue 179 if existing >= 0 { 180 existingRoot = tx.savepoints[existing].root 181 tx.savepoints = append(tx.savepoints[:existing], tx.savepoints[existing+1:]...) 182 } 183 return existingRoot 184 } 185 186 func rootsEqual(left, right *doltdb.RootValue) bool { 187 if left == nil || right == nil { 188 return false 189 } 190 191 lh, err := left.HashOf() 192 if err != nil { 193 return false 194 } 195 196 rh, err := right.HashOf() 197 if err != nil { 198 return false 199 } 200 201 return lh == rh 202 }