github.com/hasnat/dolt/go@v0.0.0-20210628190320-9eb5d843fbb7/libraries/doltcore/schema/alterschema/addcolumn.go (about)

     1  // Copyright 2019 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package alterschema
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"strings"
    21  
    22  	"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
    23  	"github.com/dolthub/dolt/go/libraries/doltcore/row"
    24  	"github.com/dolthub/dolt/go/libraries/doltcore/schema"
    25  	"github.com/dolthub/dolt/go/libraries/doltcore/schema/typeinfo"
    26  	"github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil"
    27  	"github.com/dolthub/dolt/go/store/types"
    28  )
    29  
    30  // Nullable represents whether a column can have a null value.
    31  type Nullable bool
    32  
    33  const (
    34  	NotNull Nullable = false
    35  	Null    Nullable = true
    36  )
    37  
    38  // Clone of sql.ColumnOrder to avoid a dependency on sql here
    39  type ColumnOrder struct {
    40  	First bool
    41  	After string
    42  }
    43  
    44  // Adds a new column to the schema given and returns the new table value. Non-null column additions rewrite the entire
    45  // table, since we must write a value for each row. If the column is not nullable, a default value must be provided.
    46  //
    47  // Returns an error if the column added conflicts with the existing schema in tag or name.
    48  func AddColumnToTable(ctx context.Context, root *doltdb.RootValue, tbl *doltdb.Table, tblName string, tag uint64, newColName string, typeInfo typeinfo.TypeInfo, nullable Nullable, defaultVal, comment string, order *ColumnOrder) (*doltdb.Table, error) {
    49  	sch, err := tbl.GetSchema(ctx)
    50  	if err != nil {
    51  		return nil, err
    52  	}
    53  
    54  	if schema.IsKeyless(sch) {
    55  		return nil, ErrKeylessAltTbl
    56  	}
    57  
    58  	if err := validateNewColumn(ctx, root, tbl, tblName, tag, newColName, typeInfo, nullable, defaultVal); err != nil {
    59  		return nil, err
    60  	}
    61  
    62  	newSchema, err := addColumnToSchema(sch, tag, newColName, typeInfo, nullable, order, defaultVal, comment)
    63  	if err != nil {
    64  		return nil, err
    65  	}
    66  
    67  	return updateTableWithNewSchema(ctx, tblName, tbl, tag, newSchema, defaultVal)
    68  }
    69  
    70  // updateTableWithNewSchema updates the existing table with a new schema and new values for the new column as necessary,
    71  // and returns the new table.
    72  func updateTableWithNewSchema(ctx context.Context, tblName string, tbl *doltdb.Table, tag uint64, newSchema schema.Schema, defaultVal string) (*doltdb.Table, error) {
    73  	var err error
    74  	tbl, err = tbl.UpdateSchema(ctx, newSchema)
    75  	if err != nil {
    76  		return nil, err
    77  	}
    78  
    79  	tbl, err = applyDefaultValue(ctx, tblName, tbl, tag, newSchema)
    80  	if err != nil {
    81  		return nil, err
    82  	}
    83  
    84  	return tbl, nil
    85  }
    86  
    87  // addColumnToSchema creates a new schema with a column as specified by the params.
    88  func addColumnToSchema(sch schema.Schema, tag uint64, newColName string, typeInfo typeinfo.TypeInfo, nullable Nullable, order *ColumnOrder, defaultVal, comment string) (schema.Schema, error) {
    89  	newCol, err := createColumn(nullable, newColName, tag, typeInfo, defaultVal, comment)
    90  	if err != nil {
    91  		return nil, err
    92  	}
    93  
    94  	var newCols []schema.Column
    95  	if order != nil && order.First {
    96  		newCols = append(newCols, newCol)
    97  	}
    98  	sch.GetAllCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
    99  		newCols = append(newCols, col)
   100  		if order != nil && order.After == col.Name {
   101  			newCols = append(newCols, newCol)
   102  		}
   103  		return false, nil
   104  	})
   105  	if order == nil {
   106  		newCols = append(newCols, newCol)
   107  	}
   108  
   109  	collection := schema.NewColCollection(newCols...)
   110  
   111  	err = schema.ValidateForInsert(collection)
   112  	if err != nil {
   113  		return nil, err
   114  	}
   115  
   116  	newSch, err := schema.SchemaFromCols(collection)
   117  	if err != nil {
   118  		return nil, err
   119  	}
   120  	newSch.Indexes().AddIndex(sch.Indexes().AllIndexes()...)
   121  
   122  	return newSch, nil
   123  }
   124  
   125  func createColumn(nullable Nullable, newColName string, tag uint64, typeInfo typeinfo.TypeInfo, defaultVal, comment string) (schema.Column, error) {
   126  	if nullable {
   127  		return schema.NewColumnWithTypeInfo(newColName, tag, typeInfo, false, defaultVal, false, comment)
   128  	} else {
   129  		return schema.NewColumnWithTypeInfo(newColName, tag, typeInfo, false, defaultVal, false, comment, schema.NotNullConstraint{})
   130  	}
   131  }
   132  
   133  // ValidateNewColumn returns an error if the column as specified cannot be added to the schema given.
   134  func validateNewColumn(ctx context.Context, root *doltdb.RootValue, tbl *doltdb.Table, tblName string, tag uint64, newColName string, typeInfo typeinfo.TypeInfo, nullable Nullable, defaultVal string) error {
   135  	if typeInfo == nil {
   136  		return fmt.Errorf(`typeinfo may not be nil`)
   137  	}
   138  
   139  	sch, err := tbl.GetSchema(ctx)
   140  
   141  	if err != nil {
   142  		return err
   143  	}
   144  
   145  	cols := sch.GetAllCols()
   146  	err = cols.Iter(func(currColTag uint64, currCol schema.Column) (stop bool, err error) {
   147  		if currColTag == tag {
   148  			return false, schema.ErrTagPrevUsed(tag, newColName, tblName)
   149  		} else if strings.ToLower(currCol.Name) == strings.ToLower(newColName) {
   150  			return true, fmt.Errorf("A column with the name %s already exists in table %s.", newColName, tblName)
   151  		}
   152  
   153  		return false, nil
   154  	})
   155  
   156  	if err != nil {
   157  		return err
   158  	}
   159  
   160  	_, tblName, found, err := root.GetTableByColTag(ctx, tag)
   161  	if err != nil {
   162  		return err
   163  	}
   164  	if found {
   165  		return schema.ErrTagPrevUsed(tag, newColName, tblName)
   166  	}
   167  
   168  	return nil
   169  }
   170  
   171  func applyDefaultValue(ctx context.Context, tblName string, tbl *doltdb.Table, tag uint64, newSchema schema.Schema) (*doltdb.Table, error) {
   172  	rowData, err := tbl.GetRowData(ctx)
   173  	if err != nil {
   174  		return nil, err
   175  	}
   176  
   177  	me := rowData.Edit()
   178  
   179  	newSqlSchema, err := sqlutil.FromDoltSchema(tblName, newSchema)
   180  	if err != nil {
   181  		return nil, err
   182  	}
   183  
   184  	columnIndex := -1
   185  	for i, colTag := range newSchema.GetAllCols().Tags {
   186  		if colTag == tag {
   187  			columnIndex = i
   188  			break
   189  		}
   190  	}
   191  	if columnIndex == -1 {
   192  		return nil, fmt.Errorf("could not find tag `%d` in new schema", tag)
   193  	}
   194  
   195  	err = rowData.Iter(ctx, func(k, v types.Value) (stop bool, err error) {
   196  		oldRow, err := row.FromNoms(newSchema, k.(types.Tuple), v.(types.Tuple))
   197  		if err != nil {
   198  			return true, err
   199  		}
   200  		newRow, err := sqlutil.ApplyDefaults(ctx, tbl.ValueReadWriter(), newSchema, newSqlSchema, []int{columnIndex}, oldRow)
   201  		if err != nil {
   202  			return true, err
   203  		}
   204  		me.Set(newRow.NomsMapKey(newSchema), newRow.NomsMapValue(newSchema))
   205  		return false, nil
   206  	})
   207  	if err != nil {
   208  		return nil, err
   209  	}
   210  
   211  	newRowData, err := me.Map(ctx)
   212  	if err != nil {
   213  		return nil, err
   214  	}
   215  
   216  	return tbl.UpdateRows(ctx, newRowData)
   217  }