github.com/hasnat/dolt/go@v0.0.0-20210628190320-9eb5d843fbb7/libraries/doltcore/schema/alterschema/modifycolumn.go (about)

     1  // Copyright 2019 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package alterschema
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"strings"
    21  
    22  	"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
    23  	"github.com/dolthub/dolt/go/libraries/doltcore/row"
    24  	"github.com/dolthub/dolt/go/libraries/doltcore/schema"
    25  	"github.com/dolthub/dolt/go/libraries/doltcore/schema/encoding"
    26  	"github.com/dolthub/dolt/go/libraries/doltcore/schema/typeinfo"
    27  	"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
    28  	"github.com/dolthub/dolt/go/store/types"
    29  )
    30  
    31  // ModifyColumn modifies the column with the name given, replacing it with the new definition provided. A column with
    32  // the name given must exist in the schema of the table.
    33  func ModifyColumn(
    34  	ctx context.Context,
    35  	tbl *doltdb.Table,
    36  	existingCol schema.Column,
    37  	newCol schema.Column,
    38  	order *ColumnOrder,
    39  ) (*doltdb.Table, error) {
    40  	sch, err := tbl.GetSchema(ctx)
    41  	if err != nil {
    42  		return nil, err
    43  	}
    44  
    45  	if strings.ToLower(existingCol.Name) == strings.ToLower(newCol.Name) {
    46  		newCol.Name = existingCol.Name
    47  	}
    48  	if err := validateModifyColumn(ctx, tbl, existingCol, newCol); err != nil {
    49  		return nil, err
    50  	}
    51  
    52  	// Modify statements won't include key info, so fill it in from the old column
    53  	if existingCol.IsPartOfPK {
    54  		newCol.IsPartOfPK = true
    55  		foundNotNullConstraint := false
    56  		for _, constraint := range newCol.Constraints {
    57  			if _, ok := constraint.(schema.NotNullConstraint); ok {
    58  				foundNotNullConstraint = true
    59  				break
    60  			}
    61  		}
    62  		if !foundNotNullConstraint {
    63  			newCol.Constraints = append(newCol.Constraints, schema.NotNullConstraint{})
    64  		}
    65  	}
    66  
    67  	newSchema, err := replaceColumnInSchema(sch, existingCol, newCol, order)
    68  	if err != nil {
    69  		return nil, err
    70  	}
    71  
    72  	updatedTable, err := updateTableWithModifiedColumn(ctx, tbl, sch, newSchema, existingCol, newCol)
    73  	if err != nil {
    74  		return nil, err
    75  	}
    76  
    77  	return updatedTable, nil
    78  }
    79  
    80  // validateModifyColumn returns an error if the column as specified cannot be added to the schema given.
    81  func validateModifyColumn(ctx context.Context, tbl *doltdb.Table, existingCol schema.Column, modifiedCol schema.Column) error {
    82  	sch, err := tbl.GetSchema(ctx)
    83  	if err != nil {
    84  		return err
    85  	}
    86  
    87  	if existingCol.Name != modifiedCol.Name {
    88  		cols := sch.GetAllCols()
    89  		err = cols.Iter(func(currColTag uint64, currCol schema.Column) (stop bool, err error) {
    90  			if currColTag == modifiedCol.Tag {
    91  				return false, nil
    92  			} else if strings.ToLower(currCol.Name) == strings.ToLower(modifiedCol.Name) {
    93  				return true, fmt.Errorf("A column with the name %s already exists.", modifiedCol.Name)
    94  			}
    95  
    96  			return false, nil
    97  		})
    98  		if err != nil {
    99  			return err
   100  		}
   101  	}
   102  
   103  	return nil
   104  }
   105  
   106  // updateTableWithModifiedColumn updates the existing table with the new schema. If the column type has changed, then
   107  // the data is updated.
   108  func updateTableWithModifiedColumn(ctx context.Context, tbl *doltdb.Table, oldSch, newSch schema.Schema, oldCol, modifiedCol schema.Column) (*doltdb.Table, error) {
   109  	vrw := tbl.ValueReadWriter()
   110  	newSchemaVal, err := encoding.MarshalSchemaAsNomsValue(ctx, vrw, newSch)
   111  	if err != nil {
   112  		return nil, err
   113  	}
   114  
   115  	rowData, err := tbl.GetRowData(ctx)
   116  	if err != nil {
   117  		return nil, err
   118  	}
   119  
   120  	if !oldCol.TypeInfo.Equals(modifiedCol.TypeInfo) {
   121  		if schema.IsKeyless(newSch) {
   122  			return nil, fmt.Errorf("keyless table column type alteration is not yet supported")
   123  		}
   124  		rowData, err = updateRowDataWithNewType(ctx, rowData, tbl.ValueReadWriter(), oldSch, newSch, oldCol, modifiedCol)
   125  		if err != nil {
   126  			return nil, err
   127  		}
   128  	} else if !modifiedCol.IsNullable() {
   129  		err = rowData.Iter(ctx, func(key, value types.Value) (stop bool, err error) {
   130  			r, err := row.FromNoms(newSch, key.(types.Tuple), value.(types.Tuple))
   131  			if err != nil {
   132  				return false, err
   133  			}
   134  			val, ok := r.GetColVal(modifiedCol.Tag)
   135  			if !ok || val == nil || val == types.NullValue {
   136  				return true, fmt.Errorf("cannot change column to NOT NULL when one or more values is NULL")
   137  			}
   138  			return false, nil
   139  		})
   140  		if err != nil {
   141  			return nil, err
   142  		}
   143  	}
   144  
   145  	indexData, err := tbl.GetIndexData(ctx)
   146  	if err != nil {
   147  		return nil, err
   148  	}
   149  	var autoVal types.Value
   150  	if schema.HasAutoIncrement(newSch) {
   151  		autoVal, err = tbl.GetAutoIncrementValue(ctx)
   152  		if err != nil {
   153  			return nil, err
   154  		}
   155  	}
   156  	updatedTable, err := doltdb.NewTable(ctx, vrw, newSchemaVal, rowData, indexData, autoVal)
   157  	if err != nil {
   158  		return nil, err
   159  	}
   160  
   161  	if !oldCol.TypeInfo.Equals(modifiedCol.TypeInfo) {
   162  		// If we're modifying the primary key then all indexes are affected. Otherwise we just want to update the
   163  		// touched ones.
   164  		if modifiedCol.IsPartOfPK {
   165  			for _, index := range newSch.Indexes().AllIndexes() {
   166  				indexRowData, err := editor.RebuildIndex(ctx, updatedTable, index.Name())
   167  				if err != nil {
   168  					return nil, err
   169  				}
   170  				updatedTable, err = updatedTable.SetIndexRowData(ctx, index.Name(), indexRowData)
   171  				if err != nil {
   172  					return nil, err
   173  				}
   174  			}
   175  		} else {
   176  			for _, index := range newSch.Indexes().IndexesWithTag(modifiedCol.Tag) {
   177  				indexRowData, err := editor.RebuildIndex(ctx, updatedTable, index.Name())
   178  				if err != nil {
   179  					return nil, err
   180  				}
   181  				updatedTable, err = updatedTable.SetIndexRowData(ctx, index.Name(), indexRowData)
   182  				if err != nil {
   183  					return nil, err
   184  				}
   185  			}
   186  		}
   187  	}
   188  
   189  	return updatedTable, nil
   190  }
   191  
   192  // updateRowDataWithNewType returns a new map of row data containing the updated rows from the changed schema column type.
   193  func updateRowDataWithNewType(
   194  	ctx context.Context,
   195  	rowData types.Map,
   196  	vrw types.ValueReadWriter,
   197  	oldSch, newSch schema.Schema,
   198  	oldCol, newCol schema.Column,
   199  ) (types.Map, error) {
   200  	// If there are no rows then we can immediately return. All type conversions are valid for tables without rows, but
   201  	// when rows are present then it is no longer true. GetTypeConverter assumes that there are rows present, so it
   202  	// will return a failure on a type conversion that should work for the empty table.
   203  	if rowData.Len() == 0 {
   204  		return rowData, nil
   205  	}
   206  	convFunc, _, err := typeinfo.GetTypeConverter(ctx, oldCol.TypeInfo, newCol.TypeInfo)
   207  	if err != nil {
   208  		return types.EmptyMap, err
   209  	}
   210  
   211  	if !newCol.IsNullable() {
   212  		originalConvFunc := convFunc
   213  		convFunc = func(ctx context.Context, vrw types.ValueReadWriter, v types.Value) (types.Value, error) {
   214  			if v == nil || v == types.NullValue {
   215  				return nil, fmt.Errorf("cannot change column to NOT NULL when one or more values is NULL")
   216  			}
   217  			return originalConvFunc(ctx, vrw, v)
   218  		}
   219  	}
   220  
   221  	var lastKey types.Value
   222  	mapEditor := rowData.Edit()
   223  	err = rowData.Iter(ctx, func(key, value types.Value) (stop bool, err error) {
   224  		r, err := row.FromNoms(oldSch, key.(types.Tuple), value.(types.Tuple))
   225  		if err != nil {
   226  			return true, err
   227  		}
   228  		taggedVals, err := r.TaggedValues()
   229  		if err != nil {
   230  			return true, err
   231  		}
   232  		// We skip the "ok" check as nil is returned if the value does not exist, and we still want to check nil.
   233  		// The underscore is important, otherwise a missing value would result in a panic.
   234  		val, _ := taggedVals[oldCol.Tag]
   235  		delete(taggedVals, oldCol.Tag) // If there was no value then delete is a no-op so this is safe
   236  		newVal, err := convFunc(ctx, vrw, val)
   237  		if err != nil {
   238  			return true, err
   239  		}
   240  		// convFunc returns types.NullValue rather than nil so it's always safe to compare
   241  		if newVal.Equals(val) {
   242  			newRowKey, err := r.NomsMapKey(newSch).Value(ctx)
   243  			if err != nil {
   244  				return true, err
   245  			}
   246  			if newCol.IsPartOfPK && newRowKey.Equals(lastKey) {
   247  				return true, fmt.Errorf("pk violation when altering column type and rewriting values")
   248  			}
   249  			lastKey = newRowKey
   250  			return false, nil
   251  		} else if newVal != types.NullValue {
   252  			taggedVals[newCol.Tag] = newVal
   253  		}
   254  		r, err = row.New(rowData.Format(), newSch, taggedVals)
   255  		if err != nil {
   256  			return true, err
   257  		}
   258  
   259  		newRowKey, err := r.NomsMapKey(newSch).Value(ctx)
   260  		if err != nil {
   261  			return true, err
   262  		}
   263  		if newCol.IsPartOfPK {
   264  			mapEditor.Remove(key)
   265  			if newRowKey.Equals(lastKey) {
   266  				return true, fmt.Errorf("pk violation when altering column type and rewriting values")
   267  			}
   268  		}
   269  		lastKey = newRowKey
   270  		mapEditor.Set(newRowKey, r.NomsMapValue(newSch))
   271  		return false, nil
   272  	})
   273  	if err != nil {
   274  		return types.EmptyMap, err
   275  	}
   276  	return mapEditor.Map(ctx)
   277  }
   278  
   279  // replaceColumnInSchema replaces the column with the name given with its new definition, optionally reordering it.
   280  func replaceColumnInSchema(sch schema.Schema, oldCol schema.Column, newCol schema.Column, order *ColumnOrder) (schema.Schema, error) {
   281  	// If no order is specified, insert in the same place as the existing column
   282  	if order == nil {
   283  		prevColumn := ""
   284  		sch.GetAllCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
   285  			if col.Name == oldCol.Name {
   286  				if prevColumn == "" {
   287  					order = &ColumnOrder{First: true}
   288  				}
   289  				return true, nil
   290  			} else {
   291  				prevColumn = col.Name
   292  			}
   293  			return false, nil
   294  		})
   295  
   296  		if order == nil {
   297  			if prevColumn != "" {
   298  				order = &ColumnOrder{After: prevColumn}
   299  			} else {
   300  				return nil, fmt.Errorf("Couldn't find column %s", oldCol.Name)
   301  			}
   302  		}
   303  	}
   304  
   305  	var newCols []schema.Column
   306  	if order.First {
   307  		newCols = append(newCols, newCol)
   308  	}
   309  	sch.GetAllCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
   310  		if col.Name != oldCol.Name {
   311  			newCols = append(newCols, col)
   312  		}
   313  
   314  		if order.After == col.Name {
   315  			newCols = append(newCols, newCol)
   316  		}
   317  
   318  		return false, nil
   319  	})
   320  
   321  	collection := schema.NewColCollection(newCols...)
   322  
   323  	err := schema.ValidateForInsert(collection)
   324  	if err != nil {
   325  		return nil, err
   326  	}
   327  
   328  	newSch, err := schema.SchemaFromCols(collection)
   329  	if err != nil {
   330  		return nil, err
   331  	}
   332  	for _, index := range sch.Indexes().AllIndexes() {
   333  		tags := index.IndexedColumnTags()
   334  		for i := range tags {
   335  			if tags[i] == oldCol.Tag {
   336  				tags[i] = newCol.Tag
   337  			}
   338  		}
   339  		_, err = newSch.Indexes().AddIndexByColTags(index.Name(), tags, schema.IndexProperties{
   340  			IsUnique:      index.IsUnique(),
   341  			IsUserDefined: index.IsUserDefined(),
   342  			Comment:       index.Comment(),
   343  		})
   344  		if err != nil {
   345  			return nil, err
   346  		}
   347  	}
   348  	return newSch, nil
   349  }