github.com/hasnat/dolt/go@v0.0.0-20210628190320-9eb5d843fbb7/libraries/doltcore/fkconstrain/validate.go (about)

     1  // Copyright 2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package fkconstrain
    16  
    17  import (
    18  	"context"
    19  	"time"
    20  
    21  	"github.com/dolthub/dolt/go/libraries/doltcore/diff"
    22  	"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
    23  	"github.com/dolthub/dolt/go/libraries/doltcore/row"
    24  	nomsdiff "github.com/dolthub/dolt/go/store/diff"
    25  	"github.com/dolthub/dolt/go/store/types"
    26  )
    27  
    28  func Validate(ctx context.Context, parentCommitRoot, root *doltdb.RootValue) error {
    29  	tblNames, err := root.GetTableNames(ctx)
    30  
    31  	if err != nil {
    32  		return err
    33  	}
    34  
    35  	tblNameToValidationInfo, err := getFKValidationInfo(ctx, root, tblNames)
    36  
    37  	if err != nil {
    38  		return err
    39  	}
    40  
    41  	for _, tblName := range tblNames {
    42  		// skip tables that don't have foreign key constraints
    43  		validationInfo, ok := tblNameToValidationInfo[tblName]
    44  
    45  		if !ok || len(validationInfo.allChecks) == 0 {
    46  			continue
    47  		}
    48  
    49  		// get iterator that, when possible, only iterates over changes
    50  		diffItr, err := getDiffItr(ctx, parentCommitRoot, root, tblName)
    51  
    52  		if err != nil {
    53  			return err
    54  		}
    55  
    56  		err = validateFKForDiffs(ctx, diffItr, validationInfo)
    57  
    58  		if err != nil {
    59  			return err
    60  		}
    61  	}
    62  
    63  	return nil
    64  }
    65  
    66  func validateFKForDiffs(ctx context.Context, itr diff.RowDiffer, info fkValidationInfo) error {
    67  	for {
    68  		diffs, ok, err := itr.GetDiffs(1, time.Minute)
    69  
    70  		if err != nil {
    71  			return err
    72  		}
    73  
    74  		if !ok {
    75  			return nil
    76  		}
    77  
    78  		d := diffs[0]
    79  		switch d.ChangeType {
    80  		case types.DiffChangeRemoved:
    81  			if len(info.referencedFK) == 0 {
    82  				break
    83  			}
    84  
    85  			tv, err := row.TaggedValuesFromTupleKeyAndValue(d.KeyValue.(types.Tuple), d.OldValue.(types.Tuple))
    86  
    87  			if err != nil {
    88  				return err
    89  			}
    90  
    91  			// when a row is removed we need to check that no rows were referencing that value
    92  			for _, check := range info.referencedFK {
    93  				err = check.Check(ctx, tv, nil)
    94  
    95  				if err != nil {
    96  					return err
    97  				}
    98  			}
    99  
   100  		case types.DiffChangeAdded:
   101  			if len(info.declaredFK) == 0 {
   102  				break
   103  			}
   104  
   105  			tv, err := row.TaggedValuesFromTupleKeyAndValue(d.KeyValue.(types.Tuple), d.NewValue.(types.Tuple))
   106  
   107  			if err != nil {
   108  				return err
   109  			}
   110  
   111  			// when a row is added we need to check that all the foreign key constraints declared on the table
   112  			// are satisfied
   113  			for _, check := range info.declaredFK {
   114  				err = check.Check(ctx, nil, tv)
   115  
   116  				if err != nil {
   117  					return err
   118  				}
   119  			}
   120  
   121  		case types.DiffChangeModified:
   122  			oldTV, newTV, colsChanged, err := parseDiff(d)
   123  
   124  			if err != nil {
   125  				return err
   126  			}
   127  
   128  			for _, check := range info.allChecks {
   129  				if check.ColsIntersectChanges(colsChanged) {
   130  					err = check.Check(ctx, oldTV, newTV)
   131  
   132  					if err != nil {
   133  						return err
   134  					}
   135  				}
   136  			}
   137  		}
   138  	}
   139  }
   140  
   141  func nextTagAndValue(itr *types.TupleIterator) (uint64, types.Value, error) {
   142  	_, tag, err := itr.NextUint64()
   143  
   144  	if err != nil {
   145  		return 0, nil, err
   146  	}
   147  
   148  	_, val, err := itr.Next()
   149  
   150  	if err != nil {
   151  		return 0, nil, err
   152  	}
   153  
   154  	return tag, val, nil
   155  }
   156  
   157  func parseDiff(d *nomsdiff.Difference) (oldTV, newTV row.TaggedValues, changes map[uint64]bool, err error) {
   158  	const MaxTag uint64 = (1 << 64) - 1
   159  
   160  	newTV = make(row.TaggedValues)
   161  	oldTV = make(row.TaggedValues)
   162  	changes = make(map[uint64]bool)
   163  
   164  	itr, err := d.KeyValue.(types.Tuple).Iterator()
   165  
   166  	if err != nil {
   167  		return nil, nil, nil, err
   168  	}
   169  
   170  	for itr.HasMore() {
   171  		tag, val, err := nextTagAndValue(itr)
   172  
   173  		if err != nil {
   174  			return nil, nil, nil, err
   175  		}
   176  
   177  		newTV[tag] = val
   178  		oldTV[tag] = val
   179  	}
   180  
   181  	oldVal := d.OldValue.(types.Tuple)
   182  	newVal := d.NewValue.(types.Tuple)
   183  
   184  	oldItr, err := oldVal.Iterator()
   185  
   186  	if err != nil {
   187  		return nil, nil, nil, err
   188  	}
   189  
   190  	newItr, err := newVal.Iterator()
   191  
   192  	if err != nil {
   193  		return nil, nil, nil, err
   194  	}
   195  
   196  	var currNewTag, currOldTag uint64
   197  	var currNewVal, currOldVal types.Value
   198  	for {
   199  		if currNewVal == nil {
   200  			if !newItr.HasMore() {
   201  				currNewTag = MaxTag
   202  			} else {
   203  				currNewTag, currNewVal, err = nextTagAndValue(newItr)
   204  				if err != nil {
   205  					return nil, nil, nil, err
   206  				}
   207  			}
   208  		}
   209  
   210  		if currOldVal == nil {
   211  			if !oldItr.HasMore() {
   212  				if currNewTag == MaxTag {
   213  					break
   214  				}
   215  
   216  				currOldTag = MaxTag
   217  			} else {
   218  				currOldTag, currOldVal, err = nextTagAndValue(oldItr)
   219  				if err != nil {
   220  					return nil, nil, nil, err
   221  				}
   222  			}
   223  		}
   224  
   225  		if currNewTag < currOldTag {
   226  			newTV[currNewTag] = currNewVal
   227  			oldTV[currNewTag] = types.NullValue
   228  			changes[currNewTag] = true
   229  			currNewVal = nil
   230  		} else if currOldTag < currNewTag {
   231  			newTV[currOldTag] = types.NullValue
   232  			oldTV[currOldTag] = currOldVal
   233  			changes[currOldTag] = true
   234  			currOldVal = nil
   235  		} else {
   236  			newTV[currNewTag] = currNewVal
   237  			oldTV[currOldTag] = currOldVal
   238  			changes[currNewTag] = !currOldVal.Equals(currNewVal)
   239  			currNewVal, currOldVal = nil, nil
   240  		}
   241  	}
   242  
   243  	return oldTV, newTV, changes, nil
   244  }
   245  
// fkValidationInfo groups the foreign key checks that apply to a single table.
type fkValidationInfo struct {
	// declaredFK holds one check per foreign key declared on the table; these
	// are the checks run when a row is added.
	declaredFK   []declaredFKCheck
	// referencedFK holds one check per foreign key that references the table;
	// these are the checks run when a row is removed.
	referencedFK []referencedFKCheck
	// allChecks holds every entry of declaredFK followed by every entry of
	// referencedFK; modified rows are matched against this list.
	allChecks    []fkCheck
}
   251  
   252  func getFKValidationInfo(ctx context.Context, root *doltdb.RootValue, tblNames []string) (map[string]fkValidationInfo, error) {
   253  	tblToValInfo := make(map[string]fkValidationInfo)
   254  
   255  	fkColl, err := root.GetForeignKeyCollection(ctx)
   256  	if err != nil {
   257  		return nil, err
   258  	}
   259  
   260  	for _, tblName := range tblNames {
   261  		declaredFk, referencedByFk := fkColl.KeysForTable(tblName)
   262  
   263  		declaredFKChecks := make([]declaredFKCheck, 0, len(declaredFk))
   264  		referencedFKChecks := make([]referencedFKCheck, 0, len(referencedByFk))
   265  		allFKChecks := make([]fkCheck, 0, len(declaredFKChecks)+len(referencedFKChecks))
   266  
   267  		for _, dfk := range declaredFk {
   268  			chk, err := newDeclaredFKCheck(ctx, root, dfk)
   269  
   270  			if err != nil {
   271  				return nil, err
   272  			}
   273  
   274  			declaredFKChecks = append(declaredFKChecks, chk)
   275  			allFKChecks = append(allFKChecks, chk)
   276  		}
   277  
   278  		for _, rfk := range referencedByFk {
   279  			chk, err := newRefFKCheck(ctx, root, rfk)
   280  
   281  			if err != nil {
   282  				return nil, err
   283  			}
   284  
   285  			referencedFKChecks = append(referencedFKChecks, chk)
   286  			allFKChecks = append(allFKChecks, chk)
   287  		}
   288  
   289  		tblToValInfo[tblName] = fkValidationInfo{
   290  			declaredFK:   declaredFKChecks,
   291  			referencedFK: referencedFKChecks,
   292  			allChecks:    allFKChecks,
   293  		}
   294  	}
   295  
   296  	return tblToValInfo, nil
   297  }