github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/sqle/dolt_patch_table_function.go (about)

     1  // Copyright 2023 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package sqle
    16  
    17  import (
    18  	"bytes"
    19  	"fmt"
    20  	"io"
    21  	"sort"
    22  	"strings"
    23  
    24  	"github.com/dolthub/go-mysql-server/sql"
    25  	"github.com/dolthub/go-mysql-server/sql/expression"
    26  	"github.com/dolthub/go-mysql-server/sql/rowexec"
    27  	sqltypes "github.com/dolthub/go-mysql-server/sql/types"
    28  	"github.com/dolthub/vitess/go/mysql"
    29  	"golang.org/x/exp/slices"
    30  
    31  	"github.com/dolthub/dolt/go/cmd/dolt/errhand"
    32  	"github.com/dolthub/dolt/go/libraries/doltcore/diff"
    33  	"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
    34  	"github.com/dolthub/dolt/go/libraries/doltcore/env"
    35  	"github.com/dolthub/dolt/go/libraries/doltcore/schema"
    36  	"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
    37  	"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dtables"
    38  	"github.com/dolthub/dolt/go/libraries/doltcore/sqle/index"
    39  	"github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlfmt"
    40  	"github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil"
    41  	"github.com/dolthub/dolt/go/store/types"
    42  )
    43  
// Compile-time checks that *PatchTableFunction satisfies every go-mysql-server
// interface it is used through: table function, executable source relation,
// index-addressable / indexed table (for diff_type lookups), and table node.
var _ sql.TableFunction = (*PatchTableFunction)(nil)
var _ sql.ExecSourceRel = (*PatchTableFunction)(nil)
var _ sql.IndexAddressable = (*PatchTableFunction)(nil)
var _ sql.IndexedTable = (*PatchTableFunction)(nil)
var _ sql.TableNode = (*PatchTableFunction)(nil)
    49  
    50  const (
    51  	diffTypeSchema = "schema"
    52  	diffTypeData   = "data"
    53  )
    54  
    55  var schemaChangePartitionKey = []byte(diffTypeSchema)
    56  var dataChangePartitionKey = []byte(diffTypeData)
    57  var schemaAndDataChangePartitionKey = []byte("all")
    58  
    59  const (
    60  	orderColumnName           = "statement_order"
    61  	fromColumnName            = "from_commit_hash"
    62  	toColumnName              = "to_commit_hash"
    63  	tableNameColumnName       = "table_name"
    64  	diffTypeColumnName        = "diff_type"
    65  	statementColumnName       = "statement"
    66  	patchTableDefaultRowCount = 100
    67  )
    68  
// PatchTableFunction implements the DOLT_PATCH table function. For two revisions it yields
// one row per SQL statement (schema and/or data) that transforms the first into the second.
//
// Fields:
//   - ctx: context captured by NewInstance; argument expressions are evaluated with it
//     (see evaluateArguments and CheckPrivileges).
//   - fromCommitExpr, toCommitExpr: revision arguments in the two-argument form.
//   - dotCommitExpr: the single "from..to" revision argument; when non-nil it takes
//     precedence over fromCommitExpr/toCommitExpr.
//   - tableNameExpr: optional table-name argument restricting output to one table; nil
//     means all tables.
//   - database: the database the function runs against; PartitionRows requires it to be a
//     dsess.SqlDatabase.
type PatchTableFunction struct {
	ctx *sql.Context

	fromCommitExpr sql.Expression
	toCommitExpr   sql.Expression
	dotCommitExpr  sql.Expression
	tableNameExpr  sql.Expression
	database       sql.Database
}
    78  
    79  func (p *PatchTableFunction) DataLength(ctx *sql.Context) (uint64, error) {
    80  	numBytesPerRow := schema.SchemaAvgLength(p.Schema())
    81  	numRows, _, err := p.RowCount(ctx)
    82  	if err != nil {
    83  		return 0, err
    84  	}
    85  	return numBytesPerRow * numRows, nil
    86  }
    87  
    88  func (p *PatchTableFunction) RowCount(_ *sql.Context) (uint64, bool, error) {
    89  	return patchTableDefaultRowCount, false, nil
    90  }
    91  
    92  func (p *PatchTableFunction) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
    93  	return sql.Collation_binary, 7
    94  }
    95  
    96  type Partition struct {
    97  	key []byte
    98  }
    99  
   100  func (p *Partition) Key() []byte { return p.key }
   101  
   102  // UnderlyingTable implements the plan.TableNode interface
   103  func (p *PatchTableFunction) UnderlyingTable() sql.Table {
   104  	return p
   105  }
   106  
   107  // Collation implements the sql.Table interface.
   108  func (p *PatchTableFunction) Collation() sql.CollationID {
   109  	return sql.Collation_Default
   110  }
   111  
   112  // Partitions is a sql.Table interface function that returns a partition of the data. This data has a single partition.
   113  func (p *PatchTableFunction) Partitions(ctx *sql.Context) (sql.PartitionIter, error) {
   114  	return dtables.NewSliceOfPartitionsItr([]sql.Partition{
   115  		&Partition{key: schemaAndDataChangePartitionKey},
   116  	}), nil
   117  }
   118  
   119  // PartitionRows is a sql.Table interface function that takes a partition and returns all rows in that partition.
   120  // This table has a partition for just schema changes, one for just data changes, and one for both.
   121  func (p *PatchTableFunction) PartitionRows(ctx *sql.Context, partition sql.Partition) (sql.RowIter, error) {
   122  	fromCommitVal, toCommitVal, dotCommitVal, tableName, err := p.evaluateArguments()
   123  	if err != nil {
   124  		return nil, err
   125  	}
   126  
   127  	sqledb, ok := p.database.(dsess.SqlDatabase)
   128  	if !ok {
   129  		return nil, fmt.Errorf("unable to get dolt database")
   130  	}
   131  
   132  	fromRefDetails, toRefDetails, err := loadDetailsForRefs(ctx, fromCommitVal, toCommitVal, dotCommitVal, sqledb)
   133  	if err != nil {
   134  		return nil, err
   135  	}
   136  
   137  	tableDeltas, err := diff.GetTableDeltas(ctx, fromRefDetails.root, toRefDetails.root)
   138  	if err != nil {
   139  		return nil, err
   140  	}
   141  
   142  	sort.Slice(tableDeltas, func(i, j int) bool {
   143  		return strings.Compare(tableDeltas[i].ToName, tableDeltas[j].ToName) < 0
   144  	})
   145  
   146  	// If tableNameExpr defined, return a single table patch result
   147  	if p.tableNameExpr != nil {
   148  		fromTblExists, err := fromRefDetails.root.HasTable(ctx, tableName)
   149  		if err != nil {
   150  			return nil, err
   151  		}
   152  		toTblExists, err := toRefDetails.root.HasTable(ctx, tableName)
   153  		if err != nil {
   154  			return nil, err
   155  		}
   156  		if !fromTblExists && !toTblExists {
   157  			return nil, sql.ErrTableNotFound.New(tableName)
   158  		}
   159  
   160  		delta := findMatchingDelta(tableDeltas, tableName)
   161  		tableDeltas = []diff.TableDelta{delta}
   162  	}
   163  
   164  	includeSchemaDiff := bytes.Equal(partition.Key(), schemaAndDataChangePartitionKey) || bytes.Equal(partition.Key(), schemaChangePartitionKey)
   165  	includeDataDiff := bytes.Equal(partition.Key(), schemaAndDataChangePartitionKey) || bytes.Equal(partition.Key(), dataChangePartitionKey)
   166  
   167  	patches, err := getPatchNodes(ctx, sqledb.DbData(), tableDeltas, fromRefDetails, toRefDetails, includeSchemaDiff, includeDataDiff)
   168  	if err != nil {
   169  		return nil, err
   170  	}
   171  
   172  	return newPatchTableFunctionRowIter(patches, fromRefDetails.hashStr, toRefDetails.hashStr), nil
   173  }
   174  
   175  // LookupPartitions is a sql.IndexedTable interface function that takes an index lookup and returns the set of corresponding partitions.
   176  func (p *PatchTableFunction) LookupPartitions(context *sql.Context, lookup sql.IndexLookup) (sql.PartitionIter, error) {
   177  	if lookup.Index.ID() == diffTypeColumnName {
   178  		diffTypes, ok := index.LookupToPointSelectStr(lookup)
   179  		if !ok {
   180  			return nil, fmt.Errorf("failed to parse commit lookup ranges: %s", sql.DebugString(lookup.Ranges))
   181  		}
   182  
   183  		includeSchemaDiff := slices.Contains(diffTypes, diffTypeSchema)
   184  		includeDataDiff := slices.Contains(diffTypes, diffTypeData)
   185  
   186  		if includeSchemaDiff && includeDataDiff {
   187  			return dtables.NewSliceOfPartitionsItr([]sql.Partition{
   188  				&Partition{key: schemaAndDataChangePartitionKey},
   189  			}), nil
   190  		}
   191  
   192  		if includeSchemaDiff {
   193  			return dtables.NewSliceOfPartitionsItr([]sql.Partition{
   194  				&Partition{key: schemaChangePartitionKey},
   195  			}), nil
   196  		}
   197  
   198  		if includeDataDiff {
   199  			return dtables.NewSliceOfPartitionsItr([]sql.Partition{
   200  				&Partition{key: dataChangePartitionKey},
   201  			}), nil
   202  		}
   203  
   204  		return dtables.NewSliceOfPartitionsItr([]sql.Partition{}), nil
   205  	}
   206  
   207  	return dtables.NewSliceOfPartitionsItr([]sql.Partition{
   208  		&Partition{key: schemaAndDataChangePartitionKey},
   209  	}), nil
   210  }
   211  
   212  func (p *PatchTableFunction) IndexedAccess(lookup sql.IndexLookup) sql.IndexedTable {
   213  	return p
   214  }
   215  
   216  func (p *PatchTableFunction) GetIndexes(ctx *sql.Context) ([]sql.Index, error) {
   217  	return []sql.Index{
   218  		index.MockIndex(p.database.Name(), p.Name(), diffTypeColumnName, types.StringKind, false),
   219  	}, nil
   220  }
   221  
   222  func (p *PatchTableFunction) PreciseMatch() bool {
   223  	return true
   224  }
   225  
   226  var patchTableSchema = sql.Schema{
   227  	&sql.Column{Name: orderColumnName, Type: sqltypes.Uint64, PrimaryKey: true, Nullable: false},
   228  	&sql.Column{Name: fromColumnName, Type: sqltypes.LongText, Nullable: false},
   229  	&sql.Column{Name: toColumnName, Type: sqltypes.LongText, Nullable: false},
   230  	&sql.Column{Name: tableNameColumnName, Type: sqltypes.LongText, Nullable: false},
   231  	&sql.Column{Name: diffTypeColumnName, Type: sqltypes.LongText, Nullable: false},
   232  	&sql.Column{Name: statementColumnName, Type: sqltypes.LongText, Nullable: false},
   233  }
   234  
   235  // NewInstance creates a new instance of TableFunction interface
   236  func (p *PatchTableFunction) NewInstance(ctx *sql.Context, db sql.Database, exprs []sql.Expression) (sql.Node, error) {
   237  	newInstance := &PatchTableFunction{
   238  		ctx:      ctx,
   239  		database: db,
   240  	}
   241  
   242  	node, err := newInstance.WithExpressions(exprs...)
   243  	if err != nil {
   244  		return nil, err
   245  	}
   246  
   247  	return node, nil
   248  }
   249  
   250  // Resolved implements the sql.Resolvable interface
   251  func (p *PatchTableFunction) Resolved() bool {
   252  	if p.tableNameExpr != nil {
   253  		return p.commitsResolved() && p.tableNameExpr.Resolved()
   254  	}
   255  	return p.commitsResolved()
   256  }
   257  
   258  func (p *PatchTableFunction) IsReadOnly() bool {
   259  	return true
   260  }
   261  
   262  func (p *PatchTableFunction) commitsResolved() bool {
   263  	if p.dotCommitExpr != nil {
   264  		return p.dotCommitExpr.Resolved()
   265  	}
   266  	return p.fromCommitExpr.Resolved() && p.toCommitExpr.Resolved()
   267  }
   268  
   269  // String implements the Stringer interface
   270  func (p *PatchTableFunction) String() string {
   271  	if p.dotCommitExpr != nil {
   272  		if p.tableNameExpr != nil {
   273  			return fmt.Sprintf("DOLT_PATCH(%s, %s)", p.dotCommitExpr.String(), p.tableNameExpr.String())
   274  		}
   275  		return fmt.Sprintf("DOLT_PATCH(%s)", p.dotCommitExpr.String())
   276  	}
   277  	if p.tableNameExpr != nil {
   278  		return fmt.Sprintf("DOLT_PATCH(%s, %s, %s)", p.fromCommitExpr.String(), p.toCommitExpr.String(), p.tableNameExpr.String())
   279  	}
   280  	if p.fromCommitExpr != nil && p.toCommitExpr != nil {
   281  		return fmt.Sprintf("DOLT_PATCH(%s, %s)", p.fromCommitExpr.String(), p.toCommitExpr.String())
   282  	}
   283  	return fmt.Sprintf("DOLT_PATCH(<INVALID>)")
   284  }
   285  
   286  // Schema implements the sql.Node interface.
   287  func (p *PatchTableFunction) Schema() sql.Schema {
   288  	return patchTableSchema
   289  }
   290  
   291  // Children implements the sql.Node interface.
   292  func (p *PatchTableFunction) Children() []sql.Node {
   293  	return nil
   294  }
   295  
   296  // WithChildren implements the sql.Node interface.
   297  func (p *PatchTableFunction) WithChildren(children ...sql.Node) (sql.Node, error) {
   298  	if len(children) != 0 {
   299  		return nil, fmt.Errorf("unexpected children")
   300  	}
   301  	return p, nil
   302  }
   303  
   304  // CheckPrivileges implements the interface sql.Node.
   305  func (p *PatchTableFunction) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool {
   306  	if p.tableNameExpr != nil {
   307  		if !sqltypes.IsText(p.tableNameExpr.Type()) {
   308  			return false
   309  		}
   310  
   311  		tableNameVal, err := p.tableNameExpr.Eval(p.ctx, nil)
   312  		if err != nil {
   313  			return false
   314  		}
   315  		tableName, ok := tableNameVal.(string)
   316  		if !ok {
   317  			return false
   318  		}
   319  
   320  		subject := sql.PrivilegeCheckSubject{Database: p.database.Name(), Table: tableName}
   321  		return opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Select))
   322  	}
   323  
   324  	tblNames, err := p.database.GetTableNames(ctx)
   325  	if err != nil {
   326  		return false
   327  	}
   328  
   329  	operations := make([]sql.PrivilegedOperation, 0, len(tblNames))
   330  	for _, tblName := range tblNames {
   331  		subject := sql.PrivilegeCheckSubject{Database: p.database.Name(), Table: tblName}
   332  		operations = append(operations, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Select))
   333  	}
   334  
   335  	return opChecker.UserHasPrivileges(ctx, operations...)
   336  }
   337  
   338  // Expressions implements the sql.Expressioner interface.
   339  func (p *PatchTableFunction) Expressions() []sql.Expression {
   340  	exprs := []sql.Expression{}
   341  	if p.dotCommitExpr != nil {
   342  		exprs = append(exprs, p.dotCommitExpr)
   343  	} else {
   344  		exprs = append(exprs, p.fromCommitExpr, p.toCommitExpr)
   345  	}
   346  	if p.tableNameExpr != nil {
   347  		exprs = append(exprs, p.tableNameExpr)
   348  	}
   349  	return exprs
   350  }
   351  
   352  // WithExpressions implements the sql.Expressioner interface.
   353  func (p *PatchTableFunction) WithExpressions(expr ...sql.Expression) (sql.Node, error) {
   354  	if len(expr) < 1 {
   355  		return nil, sql.ErrInvalidArgumentNumber.New(p.Name(), "1 to 3", len(expr))
   356  	}
   357  
   358  	for _, expr := range expr {
   359  		if !expr.Resolved() {
   360  			return nil, ErrInvalidNonLiteralArgument.New(p.Name(), expr.String())
   361  		}
   362  		// prepared statements resolve functions beforehand, so above check fails
   363  		if _, ok := expr.(sql.FunctionExpression); ok {
   364  			return nil, ErrInvalidNonLiteralArgument.New(p.Name(), expr.String())
   365  		}
   366  	}
   367  
   368  	newPtf := *p
   369  	if strings.Contains(expr[0].String(), "..") {
   370  		if len(expr) < 1 || len(expr) > 2 {
   371  			return nil, sql.ErrInvalidArgumentNumber.New(newPtf.Name(), "1 or 2", len(expr))
   372  		}
   373  		newPtf.dotCommitExpr = expr[0]
   374  		if len(expr) == 2 {
   375  			newPtf.tableNameExpr = expr[1]
   376  		}
   377  	} else {
   378  		if len(expr) < 2 || len(expr) > 3 {
   379  			return nil, sql.ErrInvalidArgumentNumber.New(newPtf.Name(), "2 or 3", len(expr))
   380  		}
   381  		newPtf.fromCommitExpr = expr[0]
   382  		newPtf.toCommitExpr = expr[1]
   383  		if len(expr) == 3 {
   384  			newPtf.tableNameExpr = expr[2]
   385  		}
   386  	}
   387  
   388  	// validate the expressions
   389  	if newPtf.dotCommitExpr != nil {
   390  		if !sqltypes.IsText(newPtf.dotCommitExpr.Type()) && !expression.IsBindVar(newPtf.dotCommitExpr) {
   391  			return nil, sql.ErrInvalidArgumentDetails.New(newPtf.Name(), newPtf.dotCommitExpr.String())
   392  		}
   393  	} else {
   394  		if !sqltypes.IsText(newPtf.fromCommitExpr.Type()) && !expression.IsBindVar(newPtf.fromCommitExpr) {
   395  			return nil, sql.ErrInvalidArgumentDetails.New(newPtf.Name(), newPtf.fromCommitExpr.String())
   396  		}
   397  		if !sqltypes.IsText(newPtf.toCommitExpr.Type()) && !expression.IsBindVar(newPtf.toCommitExpr) {
   398  			return nil, sql.ErrInvalidArgumentDetails.New(newPtf.Name(), newPtf.toCommitExpr.String())
   399  		}
   400  	}
   401  
   402  	if newPtf.tableNameExpr != nil {
   403  		if !sqltypes.IsText(newPtf.tableNameExpr.Type()) && !expression.IsBindVar(newPtf.tableNameExpr) {
   404  			return nil, sql.ErrInvalidArgumentDetails.New(newPtf.Name(), newPtf.tableNameExpr.String())
   405  		}
   406  	}
   407  
   408  	return &newPtf, nil
   409  }
   410  
   411  // Database implements the sql.Databaser interface
   412  func (p *PatchTableFunction) Database() sql.Database {
   413  	return p.database
   414  }
   415  
   416  // WithDatabase implements the sql.Databaser interface
   417  func (p *PatchTableFunction) WithDatabase(database sql.Database) (sql.Node, error) {
   418  	np := *p
   419  	np.database = database
   420  	return &np, nil
   421  }
   422  
   423  // Name implements the sql.TableFunction interface
   424  func (p *PatchTableFunction) Name() string {
   425  	return p.String()
   426  }
   427  
   428  // RowIter implements the sql.ExecSourceRel interface
   429  func (p *PatchTableFunction) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error) {
   430  	partitions, err := p.Partitions(ctx)
   431  	if err != nil {
   432  		return nil, err
   433  	}
   434  
   435  	return sql.NewTableRowIter(ctx, p, partitions), nil
   436  }
   437  
   438  // evaluateArguments returns fromCommitVal, toCommitVal, dotCommitVal, and tableName.
   439  // It evaluates the argument expressions to turn them into values this PatchTableFunction
   440  // can use. Note that this method only evals the expressions, and doesn't validate the values.
   441  func (p *PatchTableFunction) evaluateArguments() (interface{}, interface{}, interface{}, string, error) {
   442  	var tableName string
   443  	if p.tableNameExpr != nil {
   444  		tableNameVal, err := p.tableNameExpr.Eval(p.ctx, nil)
   445  		if err != nil {
   446  			return nil, nil, nil, "", err
   447  		}
   448  		tn, ok := tableNameVal.(string)
   449  		if !ok {
   450  			return nil, nil, nil, "", ErrInvalidTableName.New(p.tableNameExpr.String())
   451  		}
   452  		tableName = tn
   453  	}
   454  
   455  	if p.dotCommitExpr != nil {
   456  		dotCommitVal, err := p.dotCommitExpr.Eval(p.ctx, nil)
   457  		if err != nil {
   458  			return nil, nil, nil, "", err
   459  		}
   460  
   461  		return nil, nil, dotCommitVal, tableName, nil
   462  	}
   463  
   464  	fromCommitVal, err := p.fromCommitExpr.Eval(p.ctx, nil)
   465  	if err != nil {
   466  		return nil, nil, nil, "", err
   467  	}
   468  
   469  	toCommitVal, err := p.toCommitExpr.Eval(p.ctx, nil)
   470  	if err != nil {
   471  		return nil, nil, nil, "", err
   472  	}
   473  
   474  	return fromCommitVal, toCommitVal, nil, tableName, nil
   475  }
   476  
// patchNode holds the generated statements for one table delta.
// tblName is the table's name: ToName, or FromName for a dropped table. For a database
// collation delta it is the delta's FromName, which carries diff.DBPrefix.
// schemaPatchStmts holds the DDL statements and dataPatchStmts the data statements; either
// may be empty when that diff kind was not requested or not applicable.
type patchNode struct {
	tblName          string
	schemaPatchStmts []string
	dataPatchStmts   []string
}
   482  
   483  func getPatchNodes(ctx *sql.Context, dbData env.DbData, tableDeltas []diff.TableDelta, fromRefDetails, toRefDetails *refDetails, includeSchemaDiff, includeDataDiff bool) (patches []*patchNode, err error) {
   484  	for _, td := range tableDeltas {
   485  		if td.FromTable == nil && td.ToTable == nil {
   486  			// no diff
   487  			if !strings.HasPrefix(td.FromName, diff.DBPrefix) || !strings.HasPrefix(td.ToName, diff.DBPrefix) {
   488  				continue
   489  			}
   490  
   491  			// db collation diff
   492  			dbName := strings.TrimPrefix(td.ToName, diff.DBPrefix)
   493  			fromColl, cerr := fromRefDetails.root.GetCollation(ctx)
   494  			if cerr != nil {
   495  				return nil, cerr
   496  			}
   497  			toColl, cerr := toRefDetails.root.GetCollation(ctx)
   498  			if cerr != nil {
   499  				return nil, cerr
   500  			}
   501  			alterDBCollStmt := sqlfmt.AlterDatabaseCollateStmt(dbName, fromColl, toColl)
   502  			patches = append(patches, &patchNode{
   503  				tblName:          td.FromName,
   504  				schemaPatchStmts: []string{alterDBCollStmt},
   505  				dataPatchStmts:   []string{},
   506  			})
   507  		}
   508  
   509  		tblName := td.ToName
   510  		if td.IsDrop() {
   511  			tblName = td.FromName
   512  		}
   513  
   514  		// Get SCHEMA DIFF
   515  		var schemaStmts []string
   516  		if includeSchemaDiff {
   517  			schemaStmts, err = getSchemaSqlPatch(ctx, toRefDetails.root, td)
   518  			if err != nil {
   519  				return nil, err
   520  			}
   521  		}
   522  
   523  		// Get DATA DIFF
   524  		var dataStmts []string
   525  		if includeDataDiff && canGetDataDiff(ctx, td) {
   526  			dataStmts, err = getUserTableDataSqlPatch(ctx, dbData, td, fromRefDetails, toRefDetails)
   527  			if err != nil {
   528  				return nil, err
   529  			}
   530  		}
   531  
   532  		patches = append(patches, &patchNode{tblName: tblName, schemaPatchStmts: schemaStmts, dataPatchStmts: dataStmts})
   533  	}
   534  
   535  	return patches, nil
   536  }
   537  
   538  func getSchemaSqlPatch(ctx *sql.Context, toRoot doltdb.RootValue, td diff.TableDelta) ([]string, error) {
   539  	toSchemas, err := doltdb.GetAllSchemas(ctx, toRoot)
   540  	if err != nil {
   541  		return nil, fmt.Errorf("could not read schemas from toRoot, cause: %s", err.Error())
   542  	}
   543  
   544  	fromSch, toSch, err := td.GetSchemas(ctx)
   545  	if err != nil {
   546  		return nil, fmt.Errorf("cannot retrieve schema for table %s, cause: %s", td.ToName, err.Error())
   547  	}
   548  
   549  	var ddlStatements []string
   550  	if td.IsDrop() {
   551  		ddlStatements = append(ddlStatements, sqlfmt.DropTableStmt(td.FromName))
   552  	} else if td.IsAdd() {
   553  		stmt, err := sqlfmt.GenerateCreateTableStatement(td.ToName, td.ToSch, td.ToFks, td.ToFksParentSch)
   554  		if err != nil {
   555  			return nil, errhand.VerboseErrorFromError(err)
   556  		}
   557  		ddlStatements = append(ddlStatements, stmt)
   558  	} else {
   559  		stmts, err := GetNonCreateNonDropTableSqlSchemaDiff(td, toSchemas, fromSch, toSch)
   560  		if err != nil {
   561  			return nil, err
   562  		}
   563  		ddlStatements = append(ddlStatements, stmts...)
   564  	}
   565  
   566  	return ddlStatements, nil
   567  }
   568  
   569  func canGetDataDiff(ctx *sql.Context, td diff.TableDelta) bool {
   570  	if td.IsDrop() {
   571  		return false // don't output DELETE FROM statements after DROP TABLE
   572  	}
   573  
   574  	// not diffable
   575  	if !schema.ArePrimaryKeySetsDiffable(td.Format(), td.FromSch, td.ToSch) {
   576  		ctx.Session.Warn(&sql.Warning{
   577  			Level:   "Warning",
   578  			Code:    mysql.ERNotSupportedYet,
   579  			Message: fmt.Sprintf("Primary key sets differ between revisions for table '%s', skipping data diff", td.ToName),
   580  		})
   581  		return false
   582  	}
   583  
   584  	return true
   585  }
   586  
   587  func getUserTableDataSqlPatch(ctx *sql.Context, dbData env.DbData, td diff.TableDelta, fromRefDetails, toRefDetails *refDetails) ([]string, error) {
   588  	// ToTable is used as target table as it cannot be nil at this point
   589  	diffSch, projections, ri, err := getDiffQuery(ctx, dbData, td, fromRefDetails, toRefDetails)
   590  	if err != nil {
   591  		return nil, err
   592  	}
   593  
   594  	targetPkSch, err := sqlutil.FromDoltSchema("", td.ToName, td.ToSch)
   595  	if err != nil {
   596  		return nil, err
   597  	}
   598  
   599  	return getDataSqlPatchResults(ctx, diffSch, targetPkSch.Schema, projections, ri, td.ToName, td.ToSch)
   600  }
   601  
   602  func getDataSqlPatchResults(ctx *sql.Context, diffQuerySch, targetSch sql.Schema, projections []sql.Expression, iter sql.RowIter, tn string, tsch schema.Schema) ([]string, error) {
   603  	ds, err := diff.NewDiffSplitter(diffQuerySch, targetSch)
   604  	if err != nil {
   605  		return nil, err
   606  	}
   607  
   608  	var res []string
   609  	for {
   610  		r, err := iter.Next(ctx)
   611  		if err == io.EOF {
   612  			return res, nil
   613  		} else if err != nil {
   614  			return nil, err
   615  		}
   616  
   617  		r, err = rowexec.ProjectRow(ctx, projections, r)
   618  		if err != nil {
   619  			return nil, err
   620  		}
   621  
   622  		oldRow, newRow, err := ds.SplitDiffResultRow(r)
   623  		if err != nil {
   624  			return nil, err
   625  		}
   626  
   627  		var stmt string
   628  		if oldRow.Row != nil {
   629  			stmt, err = diff.GetDataDiffStatement(tn, tsch, oldRow.Row, oldRow.RowDiff, oldRow.ColDiffs)
   630  			if err != nil {
   631  				return nil, err
   632  			}
   633  		}
   634  
   635  		if newRow.Row != nil {
   636  			stmt, err = diff.GetDataDiffStatement(tn, tsch, newRow.Row, newRow.RowDiff, newRow.ColDiffs)
   637  			if err != nil {
   638  				return nil, err
   639  			}
   640  		}
   641  
   642  		if stmt != "" {
   643  			res = append(res, stmt)
   644  		}
   645  	}
   646  }
   647  
   648  // GetNonCreateNonDropTableSqlSchemaDiff returns any schema diff in SQL statements that is NEITHER 'CREATE TABLE' NOR 'DROP TABLE' statements.
   649  func GetNonCreateNonDropTableSqlSchemaDiff(td diff.TableDelta, toSchemas map[string]schema.Schema, fromSch, toSch schema.Schema) ([]string, error) {
   650  	if td.IsAdd() || td.IsDrop() {
   651  		// use add and drop specific methods
   652  		return nil, nil
   653  	}
   654  
   655  	var ddlStatements []string
   656  	if td.FromName != td.ToName {
   657  		ddlStatements = append(ddlStatements, sqlfmt.RenameTableStmt(td.FromName, td.ToName))
   658  	}
   659  
   660  	eq := schema.SchemasAreEqual(fromSch, toSch)
   661  	if eq && !td.HasFKChanges() {
   662  		return ddlStatements, nil
   663  	}
   664  
   665  	colDiffs, unionTags := diff.DiffSchColumns(fromSch, toSch)
   666  	for _, tag := range unionTags {
   667  		cd := colDiffs[tag]
   668  		switch cd.DiffType {
   669  		case diff.SchDiffNone:
   670  		case diff.SchDiffAdded:
   671  			ddlStatements = append(ddlStatements, sqlfmt.AlterTableAddColStmt(td.ToName, sqlfmt.GenerateCreateTableColumnDefinition(*cd.New, sql.CollationID(td.ToSch.GetCollation()))))
   672  		case diff.SchDiffRemoved:
   673  			ddlStatements = append(ddlStatements, sqlfmt.AlterTableDropColStmt(td.ToName, cd.Old.Name))
   674  		case diff.SchDiffModified:
   675  			// Ignore any primary key set changes here
   676  			if cd.Old.IsPartOfPK != cd.New.IsPartOfPK {
   677  				continue
   678  			}
   679  			if cd.Old.Name != cd.New.Name {
   680  				ddlStatements = append(ddlStatements, sqlfmt.AlterTableRenameColStmt(td.ToName, cd.Old.Name, cd.New.Name))
   681  			}
   682  			if cd.Old.TypeInfo != cd.New.TypeInfo {
   683  				ddlStatements = append(ddlStatements, sqlfmt.AlterTableModifyColStmt(td.ToName,
   684  					sqlfmt.GenerateCreateTableColumnDefinition(*cd.New, sql.CollationID(td.ToSch.GetCollation()))))
   685  			}
   686  		}
   687  	}
   688  
   689  	// Print changes between a primary key set change. It contains an ALTER TABLE DROP and an ALTER TABLE ADD
   690  	if !schema.ColCollsAreEqual(fromSch.GetPKCols(), toSch.GetPKCols()) {
   691  		ddlStatements = append(ddlStatements, sqlfmt.AlterTableDropPks(td.ToName))
   692  		if toSch.GetPKCols().Size() > 0 {
   693  			ddlStatements = append(ddlStatements, sqlfmt.AlterTableAddPrimaryKeys(td.ToName, toSch.GetPKCols().GetColumnNames()))
   694  		}
   695  	}
   696  
   697  	for _, idxDiff := range diff.DiffSchIndexes(fromSch, toSch) {
   698  		switch idxDiff.DiffType {
   699  		case diff.SchDiffNone:
   700  		case diff.SchDiffAdded:
   701  			ddlStatements = append(ddlStatements, sqlfmt.AlterTableAddIndexStmt(td.ToName, idxDiff.To))
   702  		case diff.SchDiffRemoved:
   703  			ddlStatements = append(ddlStatements, sqlfmt.AlterTableDropIndexStmt(td.FromName, idxDiff.From))
   704  		case diff.SchDiffModified:
   705  			ddlStatements = append(ddlStatements, sqlfmt.AlterTableDropIndexStmt(td.FromName, idxDiff.From))
   706  			ddlStatements = append(ddlStatements, sqlfmt.AlterTableAddIndexStmt(td.ToName, idxDiff.To))
   707  		}
   708  	}
   709  
   710  	for _, fkDiff := range diff.DiffForeignKeys(td.FromFks, td.ToFks) {
   711  		switch fkDiff.DiffType {
   712  		case diff.SchDiffNone:
   713  		case diff.SchDiffAdded:
   714  			parentSch := toSchemas[fkDiff.To.ReferencedTableName]
   715  			ddlStatements = append(ddlStatements, sqlfmt.AlterTableAddForeignKeyStmt(fkDiff.To, toSch, parentSch))
   716  		case diff.SchDiffRemoved:
   717  			from := fkDiff.From
   718  			ddlStatements = append(ddlStatements, sqlfmt.AlterTableDropForeignKeyStmt(from.TableName, from.Name))
   719  		case diff.SchDiffModified:
   720  			from := fkDiff.From
   721  			ddlStatements = append(ddlStatements, sqlfmt.AlterTableDropForeignKeyStmt(from.TableName, from.Name))
   722  
   723  			parentSch := toSchemas[fkDiff.To.ReferencedTableName]
   724  			ddlStatements = append(ddlStatements, sqlfmt.AlterTableAddForeignKeyStmt(fkDiff.To, toSch, parentSch))
   725  		}
   726  	}
   727  
   728  	// Handle charset/collation changes
   729  	toCollation := toSch.GetCollation()
   730  	fromCollation := fromSch.GetCollation()
   731  	if toCollation != fromCollation {
   732  		ddlStatements = append(ddlStatements, sqlfmt.AlterTableCollateStmt(td.ToName, fromCollation, toCollation))
   733  	}
   734  
   735  	return ddlStatements, nil
   736  }
   737  
   738  // getDiffQuery returns diff schema for specified columns and array of sql.Expression as projection to be used
   739  // on diff table function row iter. This function attempts to imitate running a query
   740  // fmt.Sprintf("select %s, %s from dolt_diff('%s', '%s', '%s')", columnsWithDiff, "diff_type", fromRef, toRef, tableName)
   741  // on sql engine, which returns the schema and rowIter of the final data diff result.
   742  func getDiffQuery(ctx *sql.Context, dbData env.DbData, td diff.TableDelta, fromRefDetails, toRefDetails *refDetails) (sql.Schema, []sql.Expression, sql.RowIter, error) {
   743  	diffTableSchema, j, err := dtables.GetDiffTableSchemaAndJoiner(td.ToTable.Format(), td.FromSch, td.ToSch)
   744  	if err != nil {
   745  		return nil, nil, nil, err
   746  	}
   747  	diffPKSch, err := sqlutil.FromDoltSchema("", "", diffTableSchema)
   748  	if err != nil {
   749  		return nil, nil, nil, err
   750  	}
   751  
   752  	columnsWithDiff := getColumnNamesWithDiff(td.FromSch, td.ToSch)
   753  	diffQuerySqlSch, projections := getDiffQuerySqlSchemaAndProjections(diffPKSch.Schema, columnsWithDiff)
   754  
   755  	dp := dtables.NewDiffPartition(td.ToTable, td.FromTable, toRefDetails.hashStr, fromRefDetails.hashStr, toRefDetails.commitTime, fromRefDetails.commitTime, td.ToSch, td.FromSch)
   756  	ri := dtables.NewDiffPartitionRowIter(*dp, dbData.Ddb, j)
   757  
   758  	return diffQuerySqlSch, projections, ri, nil
   759  }
   760  
   761  func getColumnNamesWithDiff(fromSch, toSch schema.Schema) []string {
   762  	var cols []string
   763  
   764  	if fromSch != nil {
   765  		_ = fromSch.GetAllCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
   766  			cols = append(cols, fmt.Sprintf("from_%s", col.Name))
   767  			return false, nil
   768  		})
   769  	}
   770  	if toSch != nil {
   771  		_ = toSch.GetAllCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
   772  			cols = append(cols, fmt.Sprintf("to_%s", col.Name))
   773  			return false, nil
   774  		})
   775  	}
   776  	return cols
   777  }
   778  
   779  // getDiffQuerySqlSchemaAndProjections returns the schema of columns with data diff and "diff_type". This is used for diff splitter.
   780  // When extracting the diff schema, the ordering must follow the ordering of given columns
   781  func getDiffQuerySqlSchemaAndProjections(diffTableSch sql.Schema, columns []string) (sql.Schema, []sql.Expression) {
   782  	type column struct {
   783  		sqlCol *sql.Column
   784  		idx    int
   785  	}
   786  
   787  	columns = append(columns, diffTypeColumnName)
   788  	colMap := make(map[string]*column)
   789  	for _, c := range columns {
   790  		colMap[c] = nil
   791  	}
   792  
   793  	var cols = make([]*sql.Column, len(columns))
   794  	var getFieldCols = make([]sql.Expression, len(columns))
   795  
   796  	for i, c := range diffTableSch {
   797  		if _, ok := colMap[c.Name]; ok {
   798  			colMap[c.Name] = &column{c, i}
   799  		}
   800  	}
   801  
   802  	for i, c := range columns {
   803  		col := colMap[c].sqlCol
   804  		cols[i] = col
   805  		getFieldCols[i] = expression.NewGetField(colMap[c].idx, col.Type, col.Name, col.Nullable)
   806  	}
   807  
   808  	return cols, getFieldCols
   809  }
   810  
   811  //------------------------------------
   812  // patchTableFunctionRowIter
   813  //------------------------------------
   814  
   815  var _ sql.RowIter = (*patchTableFunctionRowIter)(nil)
   816  
   817  type patchTableFunctionRowIter struct {
   818  	patches        []*patchNode
   819  	patchIdx       int
   820  	statementIdx   int
   821  	fromRef        string
   822  	toRef          string
   823  	currentPatch   *patchNode
   824  	currentRowIter *sql.RowIter
   825  }
   826  
   827  // newPatchTableFunctionRowIter iterates over each patch nodes given returning
   828  // each statement in each patch node as a single row including from_commit_hash,
   829  // to_commit_hash and table_name prepended to diff_type and statement for each patch statement.
   830  func newPatchTableFunctionRowIter(patchNodes []*patchNode, fromRef, toRef string) sql.RowIter {
   831  	return &patchTableFunctionRowIter{
   832  		patches:      patchNodes,
   833  		patchIdx:     0,
   834  		statementIdx: 0,
   835  		fromRef:      fromRef,
   836  		toRef:        toRef,
   837  	}
   838  }
   839  
   840  func (itr *patchTableFunctionRowIter) Next(ctx *sql.Context) (sql.Row, error) {
   841  	for {
   842  		if itr.patchIdx >= len(itr.patches) {
   843  			return nil, io.EOF
   844  		}
   845  		if itr.currentPatch == nil {
   846  			itr.currentPatch = itr.patches[itr.patchIdx]
   847  		}
   848  		if itr.currentRowIter == nil {
   849  			ri := newPatchStatementsRowIter(itr.currentPatch.schemaPatchStmts, itr.currentPatch.dataPatchStmts)
   850  			itr.currentRowIter = &ri
   851  		}
   852  
   853  		row, err := (*itr.currentRowIter).Next(ctx)
   854  		if err == io.EOF {
   855  			itr.currentPatch = nil
   856  			itr.currentRowIter = nil
   857  			itr.patchIdx++
   858  			continue
   859  		} else if err != nil {
   860  			return nil, err
   861  		} else {
   862  			itr.statementIdx++
   863  			r := sql.Row{itr.statementIdx, itr.fromRef, itr.toRef, itr.currentPatch.tblName}
   864  			return r.Append(row), nil
   865  		}
   866  	}
   867  }
   868  
   869  func (itr *patchTableFunctionRowIter) Close(_ *sql.Context) error {
   870  	return nil
   871  }
   872  
   873  //------------------------------------
   874  // patchStatementsRowIter
   875  //------------------------------------
   876  
   877  var _ sql.RowIter = (*patchStatementsRowIter)(nil)
   878  
   879  type patchStatementsRowIter struct {
   880  	stmts  []string
   881  	ddlLen int
   882  	idx    int
   883  }
   884  
   885  // newPatchStatementsRowIter iterates over each patch statements returning row of diff_type of either 'schema' or 'data' with the statement.
   886  func newPatchStatementsRowIter(ddlStmts, dataStmts []string) sql.RowIter {
   887  	return &patchStatementsRowIter{
   888  		stmts:  append(ddlStmts, dataStmts...),
   889  		ddlLen: len(ddlStmts),
   890  		idx:    0,
   891  	}
   892  }
   893  
   894  func (p *patchStatementsRowIter) Next(ctx *sql.Context) (sql.Row, error) {
   895  	defer func() {
   896  		p.idx++
   897  	}()
   898  
   899  	if p.idx >= len(p.stmts) {
   900  		return nil, io.EOF
   901  	}
   902  
   903  	if p.stmts == nil {
   904  		return nil, io.EOF
   905  	}
   906  
   907  	stmt := p.stmts[p.idx]
   908  	diffType := diffTypeSchema
   909  	if p.idx >= p.ddlLen {
   910  		diffType = diffTypeData
   911  	}
   912  
   913  	return sql.Row{diffType, stmt}, nil
   914  }
   915  
   916  func (p *patchStatementsRowIter) Close(_ *sql.Context) error {
   917  	return nil
   918  }