github.com/hasnat/dolt/go@v0.0.0-20210628190320-9eb5d843fbb7/libraries/doltcore/fkconstrain/validate.go (about) 1 // Copyright 2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package fkconstrain 16 17 import ( 18 "context" 19 "time" 20 21 "github.com/dolthub/dolt/go/libraries/doltcore/diff" 22 "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" 23 "github.com/dolthub/dolt/go/libraries/doltcore/row" 24 nomsdiff "github.com/dolthub/dolt/go/store/diff" 25 "github.com/dolthub/dolt/go/store/types" 26 ) 27 28 func Validate(ctx context.Context, parentCommitRoot, root *doltdb.RootValue) error { 29 tblNames, err := root.GetTableNames(ctx) 30 31 if err != nil { 32 return err 33 } 34 35 tblNameToValidationInfo, err := getFKValidationInfo(ctx, root, tblNames) 36 37 if err != nil { 38 return err 39 } 40 41 for _, tblName := range tblNames { 42 // skip tables that don't have foreign key constraints 43 validationInfo, ok := tblNameToValidationInfo[tblName] 44 45 if !ok || len(validationInfo.allChecks) == 0 { 46 continue 47 } 48 49 // get iterator that, when possible, only iterates over changes 50 diffItr, err := getDiffItr(ctx, parentCommitRoot, root, tblName) 51 52 if err != nil { 53 return err 54 } 55 56 err = validateFKForDiffs(ctx, diffItr, validationInfo) 57 58 if err != nil { 59 return err 60 } 61 } 62 63 return nil 64 } 65 66 func validateFKForDiffs(ctx context.Context, itr diff.RowDiffer, info fkValidationInfo) error { 67 for { 68 diffs, ok, err := itr.GetDiffs(1, time.Minute) 69 70 if err != nil { 71 return err 72 } 73 74 if !ok { 75 return nil 76 } 77 78 d := diffs[0] 79 switch d.ChangeType { 80 case types.DiffChangeRemoved: 81 if len(info.referencedFK) == 0 { 82 break 83 } 84 85 tv, err := row.TaggedValuesFromTupleKeyAndValue(d.KeyValue.(types.Tuple), d.OldValue.(types.Tuple)) 86 87 if err != nil { 88 return err 89 } 90 91 // when a row is removed we need to check that no rows were referencing that value 92 for _, check := range info.referencedFK { 93 err = check.Check(ctx, tv, nil) 94 95 if err != nil { 96 return err 97 } 98 } 99 100 case types.DiffChangeAdded: 101 if len(info.declaredFK) == 0 { 102 break 103 } 104 105 tv, err := row.TaggedValuesFromTupleKeyAndValue(d.KeyValue.(types.Tuple), d.NewValue.(types.Tuple)) 106 107 if err != nil { 108 return err 109 } 110 111 // when a row is added we need to check that all the foreign key constraints declared on the table 112 // are satisfied 113 for _, check := range info.declaredFK { 114 err = check.Check(ctx, nil, tv) 115 116 if err != nil { 117 return err 118 } 119 } 120 121 case types.DiffChangeModified: 122 oldTV, newTV, colsChanged, err := parseDiff(d) 123 124 if err != nil { 125 return err 126 } 127 128 for _, check := range info.allChecks { 129 if check.ColsIntersectChanges(colsChanged) { 130 err = check.Check(ctx, oldTV, newTV) 131 132 if err != nil { 133 return err 134 } 135 } 136 } 137 } 138 } 139 } 140 141 func nextTagAndValue(itr *types.TupleIterator) (uint64, types.Value, error) { 142 _, tag, err := itr.NextUint64() 143 144 if err != nil { 145 return 0, nil, err 146 } 147 148 _, val, err := itr.Next() 149 150 if err != nil { 151 return 0, nil, err 152 } 153 154 return tag, val, nil 155 } 156 157 func parseDiff(d *nomsdiff.Difference) (oldTV, newTV row.TaggedValues, changes map[uint64]bool, err error) { 158 const MaxTag uint64 = (1 << 64) - 1 159 160 newTV = make(row.TaggedValues) 161 oldTV = make(row.TaggedValues) 162 changes = make(map[uint64]bool) 163 164 itr, err := d.KeyValue.(types.Tuple).Iterator() 165 166 if err != nil { 167 return nil, nil, nil, err 168 } 169 170 for itr.HasMore() { 171 tag, val, err := nextTagAndValue(itr) 172 173 if err != nil { 174 return nil, nil, nil, err 175 } 176 177 newTV[tag] = val 178 oldTV[tag] = val 179 } 180 181 oldVal := d.OldValue.(types.Tuple) 182 newVal := d.NewValue.(types.Tuple) 183 184 oldItr, err := oldVal.Iterator() 185 186 if err != nil { 187 return nil, nil, nil, err 188 } 189 190 newItr, err := newVal.Iterator() 191 192 if err != nil { 193 return nil, nil, nil, err 194 } 195 196 var currNewTag, currOldTag uint64 197 var currNewVal, currOldVal types.Value 198 for { 199 if currNewVal == nil { 200 if !newItr.HasMore() { 201 currNewTag = MaxTag 202 } else { 203 currNewTag, currNewVal, err = nextTagAndValue(newItr) 204 if err != nil { 205 return nil, nil, nil, err 206 } 207 } 208 } 209 210 if currOldVal == nil { 211 if !oldItr.HasMore() { 212 if currNewTag == MaxTag { 213 break 214 } 215 216 currOldTag = MaxTag 217 } else { 218 currOldTag, currOldVal, err = nextTagAndValue(oldItr) 219 if err != nil { 220 return nil, nil, nil, err 221 } 222 } 223 } 224 225 if currNewTag < currOldTag { 226 newTV[currNewTag] = currNewVal 227 oldTV[currNewTag] = types.NullValue 228 changes[currNewTag] = true 229 currNewVal = nil 230 } else if currOldTag < currNewTag { 231 newTV[currOldTag] = types.NullValue 232 oldTV[currOldTag] = currOldVal 233 changes[currOldTag] = true 234 currOldVal = nil 235 } else { 236 newTV[currNewTag] = currNewVal 237 oldTV[currOldTag] = currOldVal 238 changes[currNewTag] = !currOldVal.Equals(currNewVal) 239 currNewVal, currOldVal = nil, nil 240 } 241 } 242 243 return oldTV, newTV, changes, nil 244 } 245 246 type fkValidationInfo struct { 247 declaredFK []declaredFKCheck 248 referencedFK []referencedFKCheck 249 allChecks []fkCheck 250 } 251 252 func getFKValidationInfo(ctx context.Context, root *doltdb.RootValue, tblNames []string) (map[string]fkValidationInfo, error) { 253 tblToValInfo := make(map[string]fkValidationInfo) 254 255 fkColl, err := root.GetForeignKeyCollection(ctx) 256 if err != nil { 257 return nil, err 258 } 259 260 for _, tblName := range tblNames { 261 declaredFk, referencedByFk := fkColl.KeysForTable(tblName) 262 263 declaredFKChecks := make([]declaredFKCheck, 0, len(declaredFk)) 264 referencedFKChecks := make([]referencedFKCheck, 0, len(referencedByFk)) 265 allFKChecks := make([]fkCheck, 0, len(declaredFKChecks)+len(referencedFKChecks)) 266 267 for _, dfk := range declaredFk { 268 chk, err := newDeclaredFKCheck(ctx, root, dfk) 269 270 if err != nil { 271 return nil, err 272 } 273 274 declaredFKChecks = append(declaredFKChecks, chk) 275 allFKChecks = append(allFKChecks, chk) 276 } 277 278 for _, rfk := range referencedByFk { 279 chk, err := newRefFKCheck(ctx, root, rfk) 280 281 if err != nil { 282 return nil, err 283 } 284 285 referencedFKChecks = append(referencedFKChecks, chk) 286 allFKChecks = append(allFKChecks, chk) 287 } 288 289 tblToValInfo[tblName] = fkValidationInfo{ 290 declaredFK: declaredFKChecks, 291 referencedFK: referencedFKChecks, 292 allChecks: allFKChecks, 293 } 294 } 295 296 return tblToValInfo, nil 297 }