github.com/hasnat/dolt/go@v0.0.0-20210628190320-9eb5d843fbb7/libraries/doltcore/merge/merge_schema.go (about)

     1  // Copyright 2020 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package merge
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"strings"
    21  
    22  	"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
    23  	"github.com/dolthub/dolt/go/libraries/doltcore/schema"
    24  )
    25  
    26  type conflictKind byte
    27  
    28  const (
    29  	TagCollision conflictKind = iota
    30  	NameCollision
    31  )
    32  
    33  type SchemaConflict struct {
    34  	TableName    string
    35  	ColConflicts []ColConflict
    36  	IdxConflicts []IdxConflict
    37  }
    38  
    39  var EmptySchConflicts = SchemaConflict{}
    40  
    41  func (sc SchemaConflict) Count() int {
    42  	return len(sc.ColConflicts) + len(sc.IdxConflicts)
    43  }
    44  
    45  func (sc SchemaConflict) AsError() error {
    46  	var b strings.Builder
    47  	b.WriteString(fmt.Sprintf("schema conflicts for table %s:\n", sc.TableName))
    48  	for _, c := range sc.ColConflicts {
    49  		b.WriteString(fmt.Sprintf("\t%s\n", c.String()))
    50  	}
    51  	for _, c := range sc.IdxConflicts {
    52  		b.WriteString(fmt.Sprintf("\t%s\n", c.String()))
    53  	}
    54  	return fmt.Errorf(b.String())
    55  }
    56  
    57  type ColConflict struct {
    58  	Kind         conflictKind
    59  	Ours, Theirs schema.Column
    60  }
    61  
    62  func (c ColConflict) String() string {
    63  	switch c.Kind {
    64  	case NameCollision:
    65  		return fmt.Sprintf("two columns with the name '%s'", c.Ours.Name)
    66  	case TagCollision:
    67  		return fmt.Sprintf("different column definitions for our column %s and their column %s", c.Ours.Name, c.Theirs.Name)
    68  	}
    69  	return ""
    70  }
    71  
    72  type IdxConflict struct {
    73  	Kind         conflictKind
    74  	Ours, Theirs schema.Index
    75  }
    76  
    77  func (c IdxConflict) String() string {
    78  	return ""
    79  }
    80  
    81  type FKConflict struct {
    82  	Kind         conflictKind
    83  	Ours, Theirs doltdb.ForeignKey
    84  }
    85  
    86  // SchemaMerge performs a three-way merge of ourSch, theirSch, and ancSch.
    87  func SchemaMerge(ourSch, theirSch, ancSch schema.Schema, tblName string) (sch schema.Schema, sc SchemaConflict, err error) {
    88  	// (sch - ancSch) ∪ (mergeSch - ancSch) ∪ (sch ∩ mergeSch)
    89  
    90  	sc = SchemaConflict{
    91  		TableName: tblName,
    92  	}
    93  
    94  	var mergedCC *schema.ColCollection
    95  	mergedCC, sc.ColConflicts, err = mergeColumns(ourSch.GetAllCols(), theirSch.GetAllCols(), ancSch.GetAllCols())
    96  	if err != nil {
    97  		return nil, EmptySchConflicts, err
    98  	}
    99  	if len(sc.ColConflicts) > 0 {
   100  		return nil, sc, nil
   101  	}
   102  
   103  	var mergedIdxs schema.IndexCollection
   104  	mergedIdxs, sc.IdxConflicts = mergeIndexes(mergedCC, ourSch, theirSch, ancSch)
   105  	if len(sc.IdxConflicts) > 0 {
   106  		return nil, sc, nil
   107  	}
   108  
   109  	sch, err = schema.SchemaFromCols(mergedCC)
   110  	if err != nil {
   111  		return nil, sc, err
   112  	}
   113  	_ = mergedIdxs.Iter(func(index schema.Index) (stop bool, err error) {
   114  		sch.Indexes().AddIndex(index)
   115  		return false, nil
   116  	})
   117  
   118  	return sch, sc, nil
   119  }
   120  
   121  // ForeignKeysMerge performs a three-way merge of (ourRoot, theirRoot, ancRoot) and using mergeRoot to validate FKs.
   122  func ForeignKeysMerge(ctx context.Context, mergedRoot, ourRoot, theirRoot, ancRoot *doltdb.RootValue) (*doltdb.ForeignKeyCollection, []FKConflict, error) {
   123  	ours, err := ourRoot.GetForeignKeyCollection(ctx)
   124  	if err != nil {
   125  		return nil, nil, err
   126  	}
   127  
   128  	theirs, err := theirRoot.GetForeignKeyCollection(ctx)
   129  	if err != nil {
   130  		return nil, nil, err
   131  	}
   132  
   133  	anc, err := ancRoot.GetForeignKeyCollection(ctx)
   134  	if err != nil {
   135  		return nil, nil, err
   136  	}
   137  
   138  	common, conflicts, err := foreignKeysInCommon(ours, theirs, anc)
   139  	if err != nil {
   140  		return nil, nil, err
   141  	}
   142  
   143  	ourNewFKs, err := fkCollSetDifference(ours, anc)
   144  	if err != nil {
   145  		return nil, nil, err
   146  	}
   147  
   148  	theirNewFKs, err := fkCollSetDifference(theirs, anc)
   149  	if err != nil {
   150  		return nil, nil, err
   151  	}
   152  
   153  	// check for conflicts between foreign keys added on each branch since the ancestor
   154  	_ = ourNewFKs.Iter(func(ourFK doltdb.ForeignKey) (stop bool, err error) {
   155  		theirFK, ok := theirNewFKs.GetByTags(ourFK.TableColumns, ourFK.ReferencedTableColumns)
   156  		if ok && !ourFK.DeepEquals(theirFK) {
   157  			// Foreign Keys are defined over the same tags,
   158  			// but are not exactly equal
   159  			conflicts = append(conflicts, FKConflict{
   160  				Kind:   TagCollision,
   161  				Ours:   ourFK,
   162  				Theirs: theirFK,
   163  			})
   164  		}
   165  
   166  		theirFK, ok = theirNewFKs.GetByNameCaseInsensitive(ourFK.Name)
   167  		if ok && !ourFK.EqualDefs(theirFK) {
   168  			// Two different Foreign Keys have the same name
   169  			conflicts = append(conflicts, FKConflict{
   170  				Kind:   NameCollision,
   171  				Ours:   ourFK,
   172  				Theirs: theirFK,
   173  			})
   174  		}
   175  		return false, err
   176  	})
   177  
   178  	err = ourNewFKs.Iter(func(ourFK doltdb.ForeignKey) (stop bool, err error) {
   179  		return false, common.AddKeys(ourFK)
   180  	})
   181  	if err != nil {
   182  		return nil, nil, err
   183  	}
   184  
   185  	err = theirNewFKs.Iter(func(theirFK doltdb.ForeignKey) (stop bool, err error) {
   186  		return false, common.AddKeys(theirFK)
   187  	})
   188  	if err != nil {
   189  		return nil, nil, err
   190  	}
   191  
   192  	common, err = pruneInvalidForeignKeys(ctx, common, mergedRoot)
   193  	if err != nil {
   194  		return nil, nil, err
   195  	}
   196  
   197  	return common, conflicts, err
   198  }
   199  
   200  func mergeColumns(ourCC, theirCC, ancCC *schema.ColCollection) (merged *schema.ColCollection, conflicts []ColConflict, err error) {
   201  	var common *schema.ColCollection
   202  	common, conflicts = columnsInCommon(ourCC, theirCC, ancCC)
   203  
   204  	ourNewCols := schema.ColCollectionSetDifference(ourCC, ancCC)
   205  	theirNewCols := schema.ColCollectionSetDifference(theirCC, ancCC)
   206  
   207  	// check for name conflicts between columns added on each branch since the ancestor
   208  	_ = ourNewCols.Iter(func(tag uint64, ourCol schema.Column) (stop bool, err error) {
   209  		theirCol, ok := theirNewCols.GetByNameCaseInsensitive(ourCol.Name)
   210  		if ok && ourCol.Tag != theirCol.Tag {
   211  			conflicts = append(conflicts, ColConflict{
   212  				Kind:   NameCollision,
   213  				Ours:   ourCol,
   214  				Theirs: theirCol,
   215  			})
   216  		}
   217  		return false, nil
   218  	})
   219  
   220  	if len(conflicts) > 0 {
   221  		return nil, conflicts, nil
   222  	}
   223  
   224  	// order of args here is important for correct column ordering in sch schema
   225  	// to be before any column in the intersection
   226  	// TODO: sch column ordering doesn't respect sql "MODIFY ... AFTER ..." statements
   227  	merged, err = schema.ColCollUnion(common, ourNewCols, theirNewCols)
   228  	if err != nil {
   229  		return nil, nil, err
   230  	}
   231  
   232  	return merged, conflicts, nil
   233  }
   234  
   235  func columnsInCommon(ourCC, theirCC, ancCC *schema.ColCollection) (common *schema.ColCollection, conflicts []ColConflict) {
   236  	common = schema.NewColCollection()
   237  	_ = ourCC.Iter(func(tag uint64, ourCol schema.Column) (stop bool, err error) {
   238  		theirCol, ok := theirCC.GetByTag(ourCol.Tag)
   239  		if !ok {
   240  			return false, nil
   241  		}
   242  
   243  		if ourCol.Equals(theirCol) {
   244  			common = common.Append(ourCol)
   245  			return false, nil
   246  		}
   247  
   248  		ancCol, ok := ancCC.GetByTag(ourCol.Tag)
   249  		if !ok {
   250  			// col added on our branch and their branch with different def
   251  			conflicts = append(conflicts, ColConflict{
   252  				Kind:   TagCollision,
   253  				Ours:   ourCol,
   254  				Theirs: theirCol,
   255  			})
   256  			return false, nil
   257  		}
   258  
   259  		if ancCol.Equals(theirCol) {
   260  			// col modified on our branch
   261  			col, ok := common.GetByNameCaseInsensitive(ourCol.Name)
   262  			if ok {
   263  				conflicts = append(conflicts, ColConflict{
   264  					Kind:   NameCollision,
   265  					Ours:   ourCol,
   266  					Theirs: col,
   267  				})
   268  			} else {
   269  				common = common.Append(ourCol)
   270  			}
   271  			return false, nil
   272  		}
   273  
   274  		if ancCol.Equals(ourCol) {
   275  			// col modified on their branch
   276  			col, ok := common.GetByNameCaseInsensitive(theirCol.Name)
   277  			if ok {
   278  				conflicts = append(conflicts, ColConflict{
   279  					Kind:   NameCollision,
   280  					Ours:   col,
   281  					Theirs: theirCol,
   282  				})
   283  			} else {
   284  				common = common.Append(theirCol)
   285  			}
   286  			return false, nil
   287  		}
   288  
   289  		// col modified on our branch and their branch with different def
   290  		conflicts = append(conflicts, ColConflict{
   291  			Kind:   TagCollision,
   292  			Ours:   ourCol,
   293  			Theirs: theirCol,
   294  		})
   295  		return false, nil
   296  	})
   297  
   298  	return common, conflicts
   299  }
   300  
   301  // assumes indexes are unique over their column sets
   302  func mergeIndexes(mergedCC *schema.ColCollection, ourSch, theirSch, ancSch schema.Schema) (merged schema.IndexCollection, conflicts []IdxConflict) {
   303  	merged, conflicts = indexesInCommon(mergedCC, ourSch.Indexes(), theirSch.Indexes(), ancSch.Indexes())
   304  
   305  	ourNewIdxs := indexCollSetDifference(ourSch.Indexes(), ancSch.Indexes(), mergedCC)
   306  	theirNewIdxs := indexCollSetDifference(theirSch.Indexes(), ancSch.Indexes(), mergedCC)
   307  
   308  	// check for conflicts between indexes added on each branch since the ancestor
   309  	_ = ourNewIdxs.Iter(func(ourIdx schema.Index) (stop bool, err error) {
   310  		theirIdx, ok := theirNewIdxs.GetByNameCaseInsensitive(ourIdx.Name())
   311  		if ok {
   312  			conflicts = append(conflicts, IdxConflict{
   313  				Kind:   NameCollision,
   314  				Ours:   ourIdx,
   315  				Theirs: theirIdx,
   316  			})
   317  		}
   318  		return false, nil
   319  	})
   320  
   321  	merged.AddIndex(ourNewIdxs.AllIndexes()...)
   322  	merged.AddIndex(theirNewIdxs.AllIndexes()...)
   323  
   324  	return merged, conflicts
   325  }
   326  
   327  func indexesInCommon(mergedCC *schema.ColCollection, ours, theirs, anc schema.IndexCollection) (common schema.IndexCollection, conflicts []IdxConflict) {
   328  	common = schema.NewIndexCollection(mergedCC)
   329  	_ = ours.Iter(func(ourIdx schema.Index) (stop bool, err error) {
   330  		idxTags := ourIdx.IndexedColumnTags()
   331  		for _, t := range idxTags {
   332  			// if column doesn't exist anymore, drop index
   333  			// however, it shouldn't be possible for an index
   334  			// over a dropped column to exist in the intersection
   335  			if _, ok := mergedCC.GetByTag(t); !ok {
   336  				return false, nil
   337  			}
   338  		}
   339  
   340  		theirIdx, ok := theirs.GetIndexByTags(idxTags...)
   341  		if !ok {
   342  			return false, nil
   343  		}
   344  
   345  		if ourIdx.Equals(theirIdx) {
   346  			common.AddIndex(ourIdx)
   347  			return false, nil
   348  		}
   349  
   350  		ancIdx, ok := anc.GetIndexByTags(idxTags...)
   351  
   352  		if !ok {
   353  			// index added on our branch and their branch with different defs, conflict
   354  			conflicts = append(conflicts, IdxConflict{
   355  				Kind:   TagCollision,
   356  				Ours:   ourIdx,
   357  				Theirs: theirIdx,
   358  			})
   359  			return false, nil
   360  		}
   361  
   362  		if ancIdx.Equals(theirIdx) {
   363  			// index modified on our branch
   364  			idx, ok := common.GetByNameCaseInsensitive(ourIdx.Name())
   365  			if ok {
   366  				conflicts = append(conflicts, IdxConflict{
   367  					Kind:   NameCollision,
   368  					Ours:   ourIdx,
   369  					Theirs: idx,
   370  				})
   371  			} else {
   372  				common.AddIndex(ourIdx)
   373  			}
   374  			return false, nil
   375  		}
   376  
   377  		if ancIdx.Equals(ourIdx) {
   378  			// index modified on their branch
   379  			idx, ok := common.GetByNameCaseInsensitive(theirIdx.Name())
   380  			if ok {
   381  				conflicts = append(conflicts, IdxConflict{
   382  					Kind:   NameCollision,
   383  					Ours:   idx,
   384  					Theirs: theirIdx,
   385  				})
   386  			} else {
   387  				common.AddIndex(theirIdx)
   388  			}
   389  			return false, nil
   390  		}
   391  
   392  		// index modified on our branch and their branch, conflict
   393  		conflicts = append(conflicts, IdxConflict{
   394  			Kind:   TagCollision,
   395  			Ours:   ourIdx,
   396  			Theirs: theirIdx,
   397  		})
   398  		return false, nil
   399  	})
   400  	return common, conflicts
   401  }
   402  
   403  func indexCollSetDifference(left, right schema.IndexCollection, cc *schema.ColCollection) (d schema.IndexCollection) {
   404  	d = schema.NewIndexCollection(cc)
   405  	_ = left.Iter(func(idx schema.Index) (stop bool, err error) {
   406  		idxTags := idx.IndexedColumnTags()
   407  		for _, t := range idxTags {
   408  			// if column doesn't exist anymore, drop index
   409  			if _, ok := cc.GetByTag(t); !ok {
   410  				return false, nil
   411  			}
   412  		}
   413  
   414  		_, ok := right.GetIndexByTags(idxTags...)
   415  		if !ok {
   416  			d.AddIndex(idx)
   417  		}
   418  		return false, nil
   419  	})
   420  	return d
   421  }
   422  
   423  func foreignKeysInCommon(ourFKs, theirFKs, ancFKs *doltdb.ForeignKeyCollection) (common *doltdb.ForeignKeyCollection, conflicts []FKConflict, err error) {
   424  	common, _ = doltdb.NewForeignKeyCollection()
   425  	err = ourFKs.Iter(func(ours doltdb.ForeignKey) (stop bool, err error) {
   426  		theirs, ok := theirFKs.GetByTags(ours.TableColumns, ours.ReferencedTableColumns)
   427  		if !ok {
   428  			return false, nil
   429  		}
   430  
   431  		if theirs.EqualDefs(ours) {
   432  			err = common.AddKeys(ours)
   433  			return false, err
   434  		}
   435  
   436  		anc, ok := ancFKs.GetByTags(ours.TableColumns, ours.ReferencedTableColumns)
   437  		if !ok {
   438  			// FKs added on both branch with different defs
   439  			conflicts = append(conflicts, FKConflict{
   440  				Kind:   TagCollision,
   441  				Ours:   ours,
   442  				Theirs: theirs,
   443  			})
   444  		}
   445  
   446  		if theirs.EqualDefs(anc) {
   447  			// FK modified on our branch since the ancestor
   448  			fk, ok := common.GetByNameCaseInsensitive(ours.Name)
   449  			if ok {
   450  				conflicts = append(conflicts, FKConflict{
   451  					Kind:   NameCollision,
   452  					Ours:   ours,
   453  					Theirs: fk,
   454  				})
   455  			} else {
   456  				err = common.AddKeys(ours)
   457  			}
   458  			return false, err
   459  		}
   460  
   461  		if ours.EqualDefs(anc) {
   462  			// FK modified on their branch since the ancestor
   463  			fk, ok := common.GetByNameCaseInsensitive(theirs.Name)
   464  			if ok {
   465  				conflicts = append(conflicts, FKConflict{
   466  					Kind:   NameCollision,
   467  					Ours:   fk,
   468  					Theirs: theirs,
   469  				})
   470  			} else {
   471  				err = common.AddKeys(theirs)
   472  			}
   473  			return false, err
   474  		}
   475  
   476  		// FKs modified on both branch with different defs
   477  		conflicts = append(conflicts, FKConflict{
   478  			Kind:   TagCollision,
   479  			Ours:   ours,
   480  			Theirs: theirs,
   481  		})
   482  		return false, nil
   483  	})
   484  
   485  	if err != nil {
   486  		return nil, nil, err
   487  	}
   488  
   489  	return common, conflicts, nil
   490  }
   491  
   492  func fkCollSetDifference(left, right *doltdb.ForeignKeyCollection) (d *doltdb.ForeignKeyCollection, err error) {
   493  	d, _ = doltdb.NewForeignKeyCollection()
   494  	err = left.Iter(func(fk doltdb.ForeignKey) (stop bool, err error) {
   495  		_, ok := right.GetByTags(fk.TableColumns, fk.ReferencedTableColumns)
   496  		if !ok {
   497  			err = d.AddKeys(fk)
   498  		}
   499  		return false, err
   500  	})
   501  
   502  	if err != nil {
   503  		return nil, err
   504  	}
   505  
   506  	return d, nil
   507  }
   508  
   509  // pruneInvalidForeignKeys removes from a ForeignKeyCollection any ForeignKey whose parent/child table/columns have been removed.
   510  func pruneInvalidForeignKeys(ctx context.Context, fkColl *doltdb.ForeignKeyCollection, mergedRoot *doltdb.RootValue) (pruned *doltdb.ForeignKeyCollection, err error) {
   511  	pruned, _ = doltdb.NewForeignKeyCollection()
   512  	err = fkColl.Iter(func(fk doltdb.ForeignKey) (stop bool, err error) {
   513  		parentTbl, ok, err := mergedRoot.GetTable(ctx, fk.ReferencedTableName)
   514  		if err != nil || !ok {
   515  			return false, err
   516  		}
   517  		parentSch, err := parentTbl.GetSchema(ctx)
   518  		if err != nil {
   519  			return false, err
   520  		}
   521  		for _, tag := range fk.ReferencedTableColumns {
   522  			if _, ok := parentSch.GetAllCols().GetByTag(tag); !ok {
   523  				return false, nil
   524  			}
   525  		}
   526  
   527  		childTbl, ok, err := mergedRoot.GetTable(ctx, fk.TableName)
   528  		if err != nil || !ok {
   529  			return false, err
   530  		}
   531  		childSch, err := childTbl.GetSchema(ctx)
   532  		if err != nil {
   533  			return false, err
   534  		}
   535  		for _, tag := range fk.TableColumns {
   536  			if _, ok := childSch.GetAllCols().GetByTag(tag); !ok {
   537  				return false, nil
   538  			}
   539  		}
   540  
   541  		err = pruned.AddKeys(fk)
   542  		return false, err
   543  	})
   544  
   545  	if err != nil {
   546  		return nil, err
   547  	}
   548  
   549  	return pruned, nil
   550  }