github.com/dolthub/go-mysql-server@v0.18.0/sql/analyzer/resolve_create_select.go (about)

     1  package analyzer
     2  
import (
	"strings"

	"github.com/dolthub/go-mysql-server/sql"
	"github.com/dolthub/go-mysql-server/sql/plan"
	"github.com/dolthub/go-mysql-server/sql/transform"
)
     8  
     9  // todo this should be split into two rules. The first should be in
    10  // planbuilder and only bind the child select, strip/merge schemas.
    11  // a second rule should finalize analysis of the source/dest nodes
    12  // (skipping passthrough rule).
    13  func resolveCreateSelect(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) {
    14  	ct, ok := n.(*plan.CreateTable)
    15  	if !ok || ct.Select() == nil {
    16  		return n, transform.SameTree, nil
    17  	}
    18  
    19  	analyzedSelect, err := a.Analyze(ctx, ct.Select(), scope)
    20  	if err != nil {
    21  		return nil, transform.SameTree, err
    22  	}
    23  
    24  	// Get the correct schema of the CREATE TABLE based on the select query
    25  	inputSpec := ct.TableSpec()
    26  
    27  	// We don't want to carry any information about keys, constraints, defaults, etc. from a `create table as select`
    28  	// statement. When the underlying select node is a table, we must remove all such info from its schema. The only
    29  	// exception is NOT NULL constraints, which we leave alone.
    30  	selectSchema := stripSchema(analyzedSelect.Schema())
    31  	mergedSchema := mergeSchemas(inputSpec.Schema.Schema, selectSchema)
    32  	newSch := make(sql.Schema, len(mergedSchema))
    33  
    34  	for i, col := range mergedSchema {
    35  		tempCol := *col
    36  		tempCol.Source = ct.Name()
    37  		newSch[i] = &tempCol
    38  	}
    39  
    40  	pkOrdinals := make([]int, 0)
    41  	for i, col := range newSch {
    42  		if col.PrimaryKey {
    43  			pkOrdinals = append(pkOrdinals, i)
    44  		}
    45  	}
    46  
    47  	newSpec := inputSpec.WithSchema(sql.NewPrimaryKeySchema(newSch, pkOrdinals...))
    48  	// CREATE ... SELECT will always inherit the database collation for the table
    49  	newSpec.Collation = plan.GetDatabaseCollation(ctx, ct.Database())
    50  
    51  	newCreateTable := plan.NewCreateTable(ct.Database(), ct.Name(), ct.IfNotExists(), ct.Temporary(), newSpec)
    52  	analyzedCreate, err := a.Analyze(ctx, newCreateTable, scope)
    53  	if err != nil {
    54  		return nil, transform.SameTree, err
    55  	}
    56  
    57  	return plan.NewTableCopier(ct.Database(), StripPassthroughNodes(analyzedCreate), StripPassthroughNodes(analyzedSelect), plan.CopierProps{}), transform.NewTree, nil
    58  }
    59  
    60  // stripSchema removes all non-type information from a schema, such as the key info, default value, etc.
    61  func stripSchema(schema sql.Schema) sql.Schema {
    62  	sch := schema.Copy()
    63  	for i := range schema {
    64  		sch[i].Default = nil
    65  		sch[i].Generated = nil
    66  		sch[i].AutoIncrement = false
    67  		sch[i].PrimaryKey = false
    68  		sch[i].Source = ""
    69  		sch[i].Comment = ""
    70  	}
    71  	return sch
    72  }
    73  
    74  // mergeSchemas takes in the table spec of the CREATE TABLE and merges it with the schema used by the
    75  // select query. The ultimate structure for the new table will be [CREATE TABLE exclusive columns, columns with the same
    76  // name, SELECT exclusive columns]
    77  func mergeSchemas(inputSchema sql.Schema, selectSchema sql.Schema) sql.Schema {
    78  	if inputSchema == nil {
    79  		return selectSchema
    80  	}
    81  
    82  	// Get the matching columns between the two via name
    83  	matchingColumns := make([]*sql.Column, 0)
    84  	leftExclusive := make([]*sql.Column, 0)
    85  	for _, col := range inputSchema {
    86  		found := false
    87  		for _, col2 := range selectSchema {
    88  			if col.Name == col2.Name {
    89  				matchingColumns = append(matchingColumns, col)
    90  				found = true
    91  			}
    92  		}
    93  
    94  		if !found {
    95  			leftExclusive = append(leftExclusive, col)
    96  		}
    97  	}
    98  
    99  	rightExclusive := make([]*sql.Column, 0)
   100  	for _, col := range selectSchema {
   101  		found := false
   102  		for _, col2 := range inputSchema {
   103  			if col.Name == col2.Name {
   104  				found = true
   105  				break
   106  			}
   107  		}
   108  
   109  		if !found {
   110  			rightExclusive = append(rightExclusive, col)
   111  		}
   112  	}
   113  
   114  	return append(append(leftExclusive, matchingColumns...), rightExclusive...)
   115  }