github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/sqle/dprocedures/dolt_verify_constraints.go (about) 1 // Copyright 2022 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package dprocedures 16 17 import ( 18 "fmt" 19 "strings" 20 21 "github.com/dolthub/go-mysql-server/sql" 22 23 "github.com/dolthub/dolt/go/cmd/dolt/cli" 24 "github.com/dolthub/dolt/go/libraries/doltcore/branch_control" 25 "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" 26 "github.com/dolthub/dolt/go/libraries/doltcore/merge" 27 "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess" 28 "github.com/dolthub/dolt/go/libraries/doltcore/table/editor" 29 "github.com/dolthub/dolt/go/libraries/utils/argparser" 30 "github.com/dolthub/dolt/go/libraries/utils/set" 31 ) 32 33 // doltVerifyConstraints is the stored procedure version for the CLI command `dolt constraints verify`. 34 func doltVerifyConstraints(ctx *sql.Context, args ...string) (sql.RowIter, error) { 35 res, err := doDoltConstraintsVerify(ctx, args) 36 if err != nil { 37 return nil, err 38 } 39 return rowToIter(int64(res)), nil 40 } 41 42 func doDoltConstraintsVerify(ctx *sql.Context, args []string) (int, error) { 43 if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil { 44 return 1, err 45 } 46 47 dbName := ctx.GetCurrentDatabase() 48 dSess := dsess.DSessFromSess(ctx.Session) 49 workingSet, err := dSess.WorkingSet(ctx, dbName) 50 if err != nil { 51 return 1, err 52 } 53 workingRoot := workingSet.WorkingRoot() 54 headCommit, err := dSess.GetHeadCommit(ctx, dbName) 55 if err != nil { 56 return 1, err 57 } 58 59 apr, err := cli.CreateVerifyConstraintsArgParser("doltVerifyConstraints").Parse(args) 60 if err != nil { 61 return 1, err 62 } 63 64 verifyAll := apr.Contains(cli.AllFlag) 65 outputOnly := apr.Contains(cli.OutputOnlyFlag) 66 67 var comparingRoot doltdb.RootValue 68 if verifyAll { 69 comparingRoot, err = doltdb.EmptyRootValue(ctx, workingRoot.VRW(), workingRoot.NodeStore()) 70 if err != nil { 71 return 1, err 72 } 73 } else { 74 comparingRoot, err = headCommit.GetRootValue(ctx) 75 if err != nil { 76 return 1, err 77 } 78 } 79 80 tableSet, err := parseTablesToCheck(ctx, workingRoot, apr) 81 if err != nil { 82 return 1, err 83 } 84 85 // Check for all non-FK constraint violations 86 newRoot, tablesWithViolations, err := calculateViolations(ctx, workingRoot, comparingRoot, tableSet) 87 if err != nil { 88 return 1, err 89 } 90 91 if !outputOnly { 92 err = dSess.SetWorkingRoot(ctx, dbName, newRoot) 93 if err != nil { 94 return 1, err 95 } 96 } 97 98 if tablesWithViolations.Size() == 0 { 99 // no violations were found 100 return 0, nil 101 } 102 103 // TODO: We only return 1 or 0 to indicate if there were any constraint violations or not. This isn't 104 // super useful to customers, and not how the CLI command works. It would be better to return 105 // results that indicate the total number of violations found for the specified tables, and 106 // potentially also a human readable message. 107 return 1, nil 108 } 109 110 // calculateViolations calculates all constraint violations between |workingRoot| and |comparingRoot| for the 111 // tables in |tableSet|. Returns the new root with the violations, and a set of table names that have violations. 112 // Note that constraint violations detected for ALL existing tables will be stored in the dolt_constraint_violations 113 // tables, but the returned set of table names will be a subset of |tableSet|. 114 func calculateViolations(ctx *sql.Context, workingRoot, comparingRoot doltdb.RootValue, tableSet *set.StrSet) (doltdb.RootValue, *set.StrSet, error) { 115 var recordViolationsForTables map[string]struct{} = nil 116 if tableSet.Size() > 0 { 117 recordViolationsForTables = make(map[string]struct{}) 118 for _, table := range tableSet.AsSlice() { 119 table = strings.ToLower(table) 120 recordViolationsForTables[table] = struct{}{} 121 } 122 } 123 124 mergeOpts := merge.MergeOpts{ 125 IsCherryPick: false, 126 KeepSchemaConflicts: true, 127 ReverifyAllConstraints: true, 128 RecordViolationsForTables: recordViolationsForTables, 129 } 130 mergeResults, err := merge.MergeRoots(ctx, comparingRoot, workingRoot, comparingRoot, workingRoot, comparingRoot, 131 editor.Options{}, mergeOpts) 132 if err != nil { 133 return nil, nil, fmt.Errorf("error calculating constraint violations: %w", err) 134 } 135 136 tablesWithViolations := set.NewStrSet(nil) 137 for _, tableName := range tableSet.AsSlice() { 138 table, ok, err := mergeResults.Root.GetTable(ctx, doltdb.TableName{Name: tableName}) 139 if err != nil { 140 return nil, nil, err 141 } 142 if !ok { 143 return nil, nil, fmt.Errorf("table %s not found", tableName) 144 } 145 artifacts, err := table.GetArtifacts(ctx) 146 if err != nil { 147 return nil, nil, err 148 } 149 constraintViolationCount, err := artifacts.ConstraintViolationCount(ctx) 150 if err != nil { 151 return nil, nil, err 152 } 153 if constraintViolationCount > 0 { 154 tablesWithViolations.Add(tableName) 155 } 156 } 157 158 return mergeResults.Root, tablesWithViolations, nil 159 } 160 161 // parseTablesToCheck returns a set of table names to check for constraint violations. If no tables are specified, then 162 // all tables in the root are returned. 163 func parseTablesToCheck(ctx *sql.Context, workingRoot doltdb.RootValue, apr *argparser.ArgParseResults) (*set.StrSet, error) { 164 tableSet := set.NewStrSet(nil) 165 for _, val := range apr.Args { 166 _, tableName, ok, err := doltdb.GetTableInsensitive(ctx, workingRoot, val) 167 if err != nil { 168 return nil, err 169 } 170 if !ok { 171 return nil, sql.ErrTableNotFound.New(tableName) 172 } 173 tableSet.Add(tableName) 174 } 175 176 // If no tables were explicitly specified, then check all tables 177 if tableSet.Size() == 0 { 178 names, err := workingRoot.GetTableNames(ctx, doltdb.DefaultSchemaName) 179 if err != nil { 180 return nil, err 181 } 182 tableSet.Add(names...) 183 } 184 185 return tableSet, nil 186 }