github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/sqle/dprocedures/dolt_verify_constraints.go (about)

     1  // Copyright 2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package dprocedures
    16  
    17  import (
    18  	"fmt"
    19  	"strings"
    20  
    21  	"github.com/dolthub/go-mysql-server/sql"
    22  
    23  	"github.com/dolthub/dolt/go/cmd/dolt/cli"
    24  	"github.com/dolthub/dolt/go/libraries/doltcore/branch_control"
    25  	"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
    26  	"github.com/dolthub/dolt/go/libraries/doltcore/merge"
    27  	"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
    28  	"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
    29  	"github.com/dolthub/dolt/go/libraries/utils/argparser"
    30  	"github.com/dolthub/dolt/go/libraries/utils/set"
    31  )
    32  
    33  // doltVerifyConstraints is the stored procedure version for the CLI command `dolt constraints verify`.
    34  func doltVerifyConstraints(ctx *sql.Context, args ...string) (sql.RowIter, error) {
    35  	res, err := doDoltConstraintsVerify(ctx, args)
    36  	if err != nil {
    37  		return nil, err
    38  	}
    39  	return rowToIter(int64(res)), nil
    40  }
    41  
    42  func doDoltConstraintsVerify(ctx *sql.Context, args []string) (int, error) {
    43  	if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
    44  		return 1, err
    45  	}
    46  
    47  	dbName := ctx.GetCurrentDatabase()
    48  	dSess := dsess.DSessFromSess(ctx.Session)
    49  	workingSet, err := dSess.WorkingSet(ctx, dbName)
    50  	if err != nil {
    51  		return 1, err
    52  	}
    53  	workingRoot := workingSet.WorkingRoot()
    54  	headCommit, err := dSess.GetHeadCommit(ctx, dbName)
    55  	if err != nil {
    56  		return 1, err
    57  	}
    58  
    59  	apr, err := cli.CreateVerifyConstraintsArgParser("doltVerifyConstraints").Parse(args)
    60  	if err != nil {
    61  		return 1, err
    62  	}
    63  
    64  	verifyAll := apr.Contains(cli.AllFlag)
    65  	outputOnly := apr.Contains(cli.OutputOnlyFlag)
    66  
    67  	var comparingRoot doltdb.RootValue
    68  	if verifyAll {
    69  		comparingRoot, err = doltdb.EmptyRootValue(ctx, workingRoot.VRW(), workingRoot.NodeStore())
    70  		if err != nil {
    71  			return 1, err
    72  		}
    73  	} else {
    74  		comparingRoot, err = headCommit.GetRootValue(ctx)
    75  		if err != nil {
    76  			return 1, err
    77  		}
    78  	}
    79  
    80  	tableSet, err := parseTablesToCheck(ctx, workingRoot, apr)
    81  	if err != nil {
    82  		return 1, err
    83  	}
    84  
    85  	// Check for all non-FK constraint violations
    86  	newRoot, tablesWithViolations, err := calculateViolations(ctx, workingRoot, comparingRoot, tableSet)
    87  	if err != nil {
    88  		return 1, err
    89  	}
    90  
    91  	if !outputOnly {
    92  		err = dSess.SetWorkingRoot(ctx, dbName, newRoot)
    93  		if err != nil {
    94  			return 1, err
    95  		}
    96  	}
    97  
    98  	if tablesWithViolations.Size() == 0 {
    99  		// no violations were found
   100  		return 0, nil
   101  	}
   102  
   103  	// TODO: We only return 1 or 0 to indicate if there were any constraint violations or not. This isn't
   104  	//       super useful to customers, and not how the CLI command works. It would be better to return
   105  	//       results that indicate the total number of violations found for the specified tables, and
   106  	//       potentially also a human readable message.
   107  	return 1, nil
   108  }
   109  
   110  // calculateViolations calculates all constraint violations between |workingRoot| and |comparingRoot| for the
   111  // tables in |tableSet|. Returns the new root with the violations, and a set of table names that have violations.
   112  // Note that constraint violations detected for ALL existing tables will be stored in the dolt_constraint_violations
   113  // tables, but the returned set of table names will be a subset of |tableSet|.
   114  func calculateViolations(ctx *sql.Context, workingRoot, comparingRoot doltdb.RootValue, tableSet *set.StrSet) (doltdb.RootValue, *set.StrSet, error) {
   115  	var recordViolationsForTables map[string]struct{} = nil
   116  	if tableSet.Size() > 0 {
   117  		recordViolationsForTables = make(map[string]struct{})
   118  		for _, table := range tableSet.AsSlice() {
   119  			table = strings.ToLower(table)
   120  			recordViolationsForTables[table] = struct{}{}
   121  		}
   122  	}
   123  
   124  	mergeOpts := merge.MergeOpts{
   125  		IsCherryPick:              false,
   126  		KeepSchemaConflicts:       true,
   127  		ReverifyAllConstraints:    true,
   128  		RecordViolationsForTables: recordViolationsForTables,
   129  	}
   130  	mergeResults, err := merge.MergeRoots(ctx, comparingRoot, workingRoot, comparingRoot, workingRoot, comparingRoot,
   131  		editor.Options{}, mergeOpts)
   132  	if err != nil {
   133  		return nil, nil, fmt.Errorf("error calculating constraint violations: %w", err)
   134  	}
   135  
   136  	tablesWithViolations := set.NewStrSet(nil)
   137  	for _, tableName := range tableSet.AsSlice() {
   138  		table, ok, err := mergeResults.Root.GetTable(ctx, doltdb.TableName{Name: tableName})
   139  		if err != nil {
   140  			return nil, nil, err
   141  		}
   142  		if !ok {
   143  			return nil, nil, fmt.Errorf("table %s not found", tableName)
   144  		}
   145  		artifacts, err := table.GetArtifacts(ctx)
   146  		if err != nil {
   147  			return nil, nil, err
   148  		}
   149  		constraintViolationCount, err := artifacts.ConstraintViolationCount(ctx)
   150  		if err != nil {
   151  			return nil, nil, err
   152  		}
   153  		if constraintViolationCount > 0 {
   154  			tablesWithViolations.Add(tableName)
   155  		}
   156  	}
   157  
   158  	return mergeResults.Root, tablesWithViolations, nil
   159  }
   160  
   161  // parseTablesToCheck returns a set of table names to check for constraint violations. If no tables are specified, then
   162  // all tables in the root are returned.
   163  func parseTablesToCheck(ctx *sql.Context, workingRoot doltdb.RootValue, apr *argparser.ArgParseResults) (*set.StrSet, error) {
   164  	tableSet := set.NewStrSet(nil)
   165  	for _, val := range apr.Args {
   166  		_, tableName, ok, err := doltdb.GetTableInsensitive(ctx, workingRoot, val)
   167  		if err != nil {
   168  			return nil, err
   169  		}
   170  		if !ok {
   171  			return nil, sql.ErrTableNotFound.New(tableName)
   172  		}
   173  		tableSet.Add(tableName)
   174  	}
   175  
   176  	// If no tables were explicitly specified, then check all tables
   177  	if tableSet.Size() == 0 {
   178  		names, err := workingRoot.GetTableNames(ctx, doltdb.DefaultSchemaName)
   179  		if err != nil {
   180  			return nil, err
   181  		}
   182  		tableSet.Add(names...)
   183  	}
   184  
   185  	return tableSet, nil
   186  }