github.com/hasnat/dolt/go@v0.0.0-20210628190320-9eb5d843fbb7/libraries/doltcore/fkconstrain/checks.go (about)

     1  // Copyright 2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package fkconstrain
    16  
import (
	"context"
	"fmt"
	"io"

	"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
	"github.com/dolthub/dolt/go/libraries/doltcore/row"
	"github.com/dolthub/dolt/go/libraries/doltcore/schema"
	"github.com/dolthub/dolt/go/store/types"
)
    26  
    27  type fkCheck interface {
    28  	ColsIntersectChanges(changes map[uint64]bool) bool
    29  	Check(ctx context.Context, oldTV, newTV row.TaggedValues) error
    30  }
    31  
    32  type check struct {
    33  	nbf                 *types.NomsBinFormat
    34  	fk                  doltdb.ForeignKey
    35  	declaredIndex       schema.Index
    36  	declaredIndexRows   types.Map
    37  	referencedIndex     schema.Index
    38  	referencedIndexRows types.Map
    39  
    40  	colTags           []uint64
    41  	declTagsToRefTags map[uint64]uint64
    42  	refTagsToDeclTags map[uint64]uint64
    43  }
    44  
    45  func newCheck(ctx context.Context, root *doltdb.RootValue, colTags []uint64, fk doltdb.ForeignKey) (check, error) {
    46  	declTable, declSch, err := getTableAndSchema(ctx, root, fk.TableName)
    47  
    48  	if err != nil {
    49  		return check{}, err
    50  	}
    51  
    52  	refTable, refSch, err := getTableAndSchema(ctx, root, fk.ReferencedTableName)
    53  
    54  	if err != nil {
    55  		return check{}, err
    56  	}
    57  
    58  	refIdx := refSch.Indexes().GetByName(fk.ReferencedTableIndex)
    59  	refIdxRowData, err := refTable.GetIndexRowData(ctx, fk.ReferencedTableIndex)
    60  
    61  	if err != nil {
    62  		return check{}, err
    63  	}
    64  
    65  	declIdx := declSch.Indexes().GetByName(fk.TableIndex)
    66  	declIdxRowData, err := declTable.GetIndexRowData(ctx, fk.TableIndex)
    67  
    68  	if err != nil {
    69  		return check{}, err
    70  	}
    71  
    72  	declTagsToRefTags := make(map[uint64]uint64)
    73  	refTagsToDeclTags := make(map[uint64]uint64)
    74  	for i, declTag := range fk.TableColumns {
    75  		refTag := fk.ReferencedTableColumns[i]
    76  		declTagsToRefTags[declTag] = refTag
    77  		refTagsToDeclTags[refTag] = declTag
    78  	}
    79  
    80  	return check{
    81  		nbf:                 root.VRW().Format(),
    82  		fk:                  fk,
    83  		declaredIndex:       declIdx,
    84  		declaredIndexRows:   declIdxRowData,
    85  		referencedIndex:     refIdx,
    86  		referencedIndexRows: refIdxRowData,
    87  		colTags:             colTags,
    88  		declTagsToRefTags:   declTagsToRefTags,
    89  		refTagsToDeclTags:   refTagsToDeclTags,
    90  	}, nil
    91  }
    92  
    93  func (chk check) ColsIntersectChanges(changes map[uint64]bool) bool {
    94  	for _, tag := range chk.colTags {
    95  		if changes[tag] {
    96  			return true
    97  		}
    98  	}
    99  
   100  	return false
   101  }
   102  
   103  func (chk check) NewErrForKey(key types.Tuple) error {
   104  	return &GenericForeignKeyError{
   105  		tableName:           chk.fk.TableName,
   106  		referencedTableName: chk.fk.ReferencedTableName,
   107  		fkName:              chk.fk.Name,
   108  		keyStr:              key.String(),
   109  	}
   110  }
   111  
   112  type declaredFKCheck struct {
   113  	check
   114  }
   115  
   116  func newDeclaredFKCheck(ctx context.Context, root *doltdb.RootValue, fk doltdb.ForeignKey) (declaredFKCheck, error) {
   117  	chk, err := newCheck(ctx, root, fk.TableColumns, fk)
   118  
   119  	if err != nil {
   120  		return declaredFKCheck{}, err
   121  	}
   122  
   123  	return declaredFKCheck{
   124  		check: chk,
   125  	}, nil
   126  }
   127  
   128  // Check checks that the new tagged values coming from the declared table are present in the referenced index
   129  func (declFKC declaredFKCheck) Check(ctx context.Context, _, newTV row.TaggedValues) error {
   130  	indexColTags := declFKC.referencedIndex.IndexedColumnTags()
   131  	keyTupVals := make([]types.Value, len(indexColTags)*2)
   132  	for i, refTag := range indexColTags {
   133  		declTag := declFKC.refTagsToDeclTags[refTag]
   134  		keyTupVals[i*2] = types.Uint(refTag)
   135  
   136  		if val, ok := newTV[declTag]; ok && !types.IsNull(val) {
   137  			keyTupVals[i*2+1] = val
   138  		} else {
   139  			// full key is not present.  skip check
   140  			return nil
   141  		}
   142  	}
   143  
   144  	key, err := types.NewTuple(declFKC.nbf, keyTupVals...)
   145  
   146  	if err != nil {
   147  		return err
   148  	}
   149  
   150  	found, err := indexHasKey(ctx, declFKC.referencedIndexRows, key)
   151  
   152  	if err != nil {
   153  		return err
   154  	}
   155  
   156  	if !found {
   157  		return declFKC.NewErrForKey(key)
   158  	}
   159  
   160  	return nil
   161  }
   162  
   163  type referencedFKCheck struct {
   164  	check
   165  }
   166  
   167  func newRefFKCheck(ctx context.Context, root *doltdb.RootValue, fk doltdb.ForeignKey) (referencedFKCheck, error) {
   168  	chk, err := newCheck(ctx, root, fk.ReferencedTableColumns, fk)
   169  
   170  	if err != nil {
   171  		return referencedFKCheck{}, err
   172  	}
   173  
   174  	return referencedFKCheck{
   175  		check: chk,
   176  	}, nil
   177  }
   178  
   179  // Check checks that either the value coming from the old tagged values is present in a new row in the referenced index
   180  // or the value is no longer referenced by rows in the declared index.
   181  func (refFKC referencedFKCheck) Check(ctx context.Context, oldTV, _ row.TaggedValues) error {
   182  	indexColTags := refFKC.referencedIndex.IndexedColumnTags()
   183  	keyTupVals := make([]types.Value, len(refFKC.fk.ReferencedTableColumns)*2)
   184  	for i, tag := range indexColTags {
   185  		keyTupVals[i*2] = types.Uint(tag)
   186  
   187  		if val, ok := oldTV[tag]; ok && !types.IsNull(val) {
   188  			keyTupVals[i*2+1] = val
   189  		} else {
   190  			// full key is not present.  skip check
   191  			return nil
   192  		}
   193  	}
   194  
   195  	key, err := types.NewTuple(refFKC.nbf, keyTupVals...)
   196  
   197  	if err != nil {
   198  		return err
   199  	}
   200  
   201  	found, err := indexHasKey(ctx, refFKC.referencedIndexRows, key)
   202  
   203  	if err != nil {
   204  		return err
   205  	}
   206  
   207  	if found {
   208  		return nil
   209  	}
   210  
   211  	// If there is not a new value with the old key then make sure no rows in the table point to the old value
   212  	declIndexTags := refFKC.declaredIndex.IndexedColumnTags()
   213  	keyTupVals = make([]types.Value, len(indexColTags)*2)
   214  	for i, declTag := range declIndexTags {
   215  		refTag := refFKC.declTagsToRefTags[declTag]
   216  		keyTupVals[i*2] = types.Uint(declTag)
   217  
   218  		if val, ok := oldTV[refTag]; ok {
   219  			keyTupVals[i*2+1] = val
   220  		} else {
   221  			keyTupVals[i*2+1] = types.NullValue
   222  		}
   223  	}
   224  
   225  	key, err = types.NewTuple(refFKC.nbf, keyTupVals...)
   226  
   227  	if err != nil {
   228  		return err
   229  	}
   230  
   231  	found, err = indexHasKey(ctx, refFKC.declaredIndexRows, key)
   232  
   233  	if err != nil {
   234  		return err
   235  	}
   236  
   237  	if found {
   238  		// found a row referencing a key that no longer exists
   239  		return refFKC.NewErrForKey(key)
   240  	}
   241  
   242  	return nil
   243  }
   244  
   245  func indexHasKey(ctx context.Context, indexRows types.Map, key types.Tuple) (bool, error) {
   246  	itr, err := indexRows.IteratorFrom(ctx, key)
   247  
   248  	if err != nil {
   249  		return false, err
   250  	}
   251  
   252  	refKey, _, err := itr.NextTuple(ctx)
   253  
   254  	if err == io.EOF {
   255  		return false, nil
   256  	} else if err != nil {
   257  		return false, err
   258  	}
   259  
   260  	return refKey.StartsWith(key), nil
   261  }