github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/sqle/sqlutil/convert.go (about)

     1  // Copyright 2020 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package sqlutil
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  
    21  	"github.com/dolthub/go-mysql-server/sql"
    22  
    23  	"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
    24  	"github.com/dolthub/dolt/go/libraries/doltcore/schema"
    25  	"github.com/dolthub/dolt/go/libraries/doltcore/schema/typeinfo"
    26  	"github.com/dolthub/dolt/go/store/types"
    27  )
    28  
    29  // TODO: Many callers only care about field names and types, not the table or db names.
    30  // Those callers may be passing in "" for these values, or may be passing in incorrect values
    31  // that are currently unused.
    32  func FromDoltSchema(dbName, tableName string, sch schema.Schema) (sql.PrimaryKeySchema, error) {
    33  	cols := make(sql.Schema, sch.GetAllCols().Size())
    34  
    35  	var i int
    36  	_ = sch.GetAllCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
    37  		sqlType := col.TypeInfo.ToSqlType()
    38  		var extra string
    39  		if col.AutoIncrement {
    40  			extra = "auto_increment"
    41  		}
    42  
    43  		var deflt, generated, onUpdate *sql.ColumnDefaultValue
    44  		if col.Default != "" {
    45  			deflt = sql.NewUnresolvedColumnDefaultValue(col.Default)
    46  		}
    47  		if col.Generated != "" {
    48  			generated = sql.NewUnresolvedColumnDefaultValue(col.Generated)
    49  		}
    50  		if col.OnUpdate != "" {
    51  			onUpdate = sql.NewUnresolvedColumnDefaultValue(col.OnUpdate)
    52  		}
    53  
    54  		cols[i] = &sql.Column{
    55  			Name:           col.Name,
    56  			Type:           sqlType,
    57  			Default:        deflt,
    58  			Generated:      generated,
    59  			OnUpdate:       onUpdate,
    60  			Nullable:       col.IsNullable(),
    61  			DatabaseSource: dbName,
    62  			Source:         tableName,
    63  			PrimaryKey:     col.IsPartOfPK,
    64  			AutoIncrement:  col.AutoIncrement,
    65  			Comment:        col.Comment,
    66  			Virtual:        col.Virtual,
    67  			Extra:          extra,
    68  		}
    69  		i++
    70  		return false, nil
    71  	})
    72  
    73  	return sql.NewPrimaryKeySchema(cols, sch.GetPkOrdinals()...), nil
    74  }
    75  
    76  // ToDoltSchema returns a dolt Schema from the sql schema given, suitable for use in creating a table.
    77  func ToDoltSchema(
    78  	ctx context.Context,
    79  	root doltdb.RootValue,
    80  	tableName string,
    81  	sqlSchema sql.PrimaryKeySchema,
    82  	headRoot doltdb.RootValue,
    83  	collation sql.CollationID,
    84  ) (schema.Schema, error) {
    85  	var cols []schema.Column
    86  	var err error
    87  
    88  	// generate tags for all columns
    89  	var names []string
    90  	var kinds []types.NomsKind
    91  	for _, col := range sqlSchema.Schema {
    92  		names = append(names, col.Name)
    93  		ti, err := typeinfo.FromSqlType(col.Type)
    94  		if err != nil {
    95  			return nil, err
    96  		}
    97  		kinds = append(kinds, ti.NomsKind())
    98  	}
    99  
   100  	tags, err := doltdb.GenerateTagsForNewColumns(ctx, root, tableName, names, kinds, headRoot)
   101  	if err != nil {
   102  		return nil, err
   103  	}
   104  
   105  	if len(tags) != len(sqlSchema.Schema) {
   106  		return nil, fmt.Errorf("number of tags should equal number of columns")
   107  	}
   108  
   109  	for i, col := range sqlSchema.Schema {
   110  		convertedCol, err := ToDoltCol(tags[i], col)
   111  		if err != nil {
   112  			return nil, err
   113  		}
   114  		cols = append(cols, convertedCol)
   115  	}
   116  
   117  	colColl := schema.NewColCollection(cols...)
   118  
   119  	err = schema.ValidateForInsert(colColl)
   120  	if err != nil {
   121  		return nil, err
   122  	}
   123  
   124  	sch, err := schema.NewSchema(colColl,
   125  		sqlSchema.PkOrdinals,
   126  		schema.Collation(collation),
   127  		nil,
   128  		nil,
   129  	)
   130  	if err != nil {
   131  		return nil, err
   132  	}
   133  
   134  	return sch, nil
   135  }
   136  
   137  // ToDoltCol returns the dolt column corresponding to the SQL column given
   138  func ToDoltCol(tag uint64, col *sql.Column) (schema.Column, error) {
   139  	var constraints []schema.ColConstraint
   140  	if !col.Nullable || col.PrimaryKey {
   141  		constraints = append(constraints, schema.NotNullConstraint{})
   142  	}
   143  	typeInfo, err := typeinfo.FromSqlType(col.Type)
   144  	if err != nil {
   145  		return schema.Column{}, err
   146  	}
   147  
   148  	var defaultVal, generatedVal, onUpdateVal string
   149  	if col.Default != nil {
   150  		defaultVal = col.Default.String()
   151  	} else {
   152  		generatedVal = col.Generated.String()
   153  	}
   154  
   155  	if col.OnUpdate != nil {
   156  		onUpdateVal = col.OnUpdate.String()
   157  	}
   158  
   159  	c := schema.Column{
   160  		Name:          col.Name,
   161  		Tag:           tag,
   162  		Kind:          typeInfo.NomsKind(),
   163  		IsPartOfPK:    col.PrimaryKey,
   164  		TypeInfo:      typeInfo,
   165  		Default:       defaultVal,
   166  		Generated:     generatedVal,
   167  		OnUpdate:      onUpdateVal,
   168  		Virtual:       col.Virtual,
   169  		AutoIncrement: col.AutoIncrement,
   170  		Comment:       col.Comment,
   171  		Constraints:   constraints,
   172  	}
   173  
   174  	err = schema.ValidateColumn(c)
   175  	if err != nil {
   176  		return schema.Column{}, err
   177  	}
   178  
   179  	return c, nil
   180  }