github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/sqle/dolt_patch_table_function.go (about) 1 // Copyright 2023 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package sqle 16 17 import ( 18 "bytes" 19 "fmt" 20 "io" 21 "sort" 22 "strings" 23 24 "github.com/dolthub/go-mysql-server/sql" 25 "github.com/dolthub/go-mysql-server/sql/expression" 26 "github.com/dolthub/go-mysql-server/sql/rowexec" 27 sqltypes "github.com/dolthub/go-mysql-server/sql/types" 28 "github.com/dolthub/vitess/go/mysql" 29 "golang.org/x/exp/slices" 30 31 "github.com/dolthub/dolt/go/cmd/dolt/errhand" 32 "github.com/dolthub/dolt/go/libraries/doltcore/diff" 33 "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" 34 "github.com/dolthub/dolt/go/libraries/doltcore/env" 35 "github.com/dolthub/dolt/go/libraries/doltcore/schema" 36 "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess" 37 "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dtables" 38 "github.com/dolthub/dolt/go/libraries/doltcore/sqle/index" 39 "github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlfmt" 40 "github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil" 41 "github.com/dolthub/dolt/go/store/types" 42 ) 43 44 var _ sql.TableFunction = (*PatchTableFunction)(nil) 45 var _ sql.ExecSourceRel = (*PatchTableFunction)(nil) 46 var _ sql.IndexAddressable = (*PatchTableFunction)(nil) 47 var _ sql.IndexedTable = (*PatchTableFunction)(nil) 48 var _ sql.TableNode = (*PatchTableFunction)(nil) 49 50 const ( 51 diffTypeSchema = "schema" 52 diffTypeData = "data" 53 ) 54 55 var schemaChangePartitionKey = []byte(diffTypeSchema) 56 var dataChangePartitionKey = []byte(diffTypeData) 57 var schemaAndDataChangePartitionKey = []byte("all") 58 59 const ( 60 orderColumnName = "statement_order" 61 fromColumnName = "from_commit_hash" 62 toColumnName = "to_commit_hash" 63 tableNameColumnName = "table_name" 64 diffTypeColumnName = "diff_type" 65 statementColumnName = "statement" 66 patchTableDefaultRowCount = 100 67 ) 68 69 type PatchTableFunction struct { 70 ctx *sql.Context 71 72 fromCommitExpr sql.Expression 73 toCommitExpr sql.Expression 74 dotCommitExpr sql.Expression 75 tableNameExpr sql.Expression 76 database sql.Database 77 } 78 79 func (p *PatchTableFunction) DataLength(ctx *sql.Context) (uint64, error) { 80 numBytesPerRow := schema.SchemaAvgLength(p.Schema()) 81 numRows, _, err := p.RowCount(ctx) 82 if err != nil { 83 return 0, err 84 } 85 return numBytesPerRow * numRows, nil 86 } 87 88 func (p *PatchTableFunction) RowCount(_ *sql.Context) (uint64, bool, error) { 89 return patchTableDefaultRowCount, false, nil 90 } 91 92 func (p *PatchTableFunction) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { 93 return sql.Collation_binary, 7 94 } 95 96 type Partition struct { 97 key []byte 98 } 99 100 func (p *Partition) Key() []byte { return p.key } 101 102 // UnderlyingTable implements the plan.TableNode interface 103 func (p *PatchTableFunction) UnderlyingTable() sql.Table { 104 return p 105 } 106 107 // Collation implements the sql.Table interface. 108 func (p *PatchTableFunction) Collation() sql.CollationID { 109 return sql.Collation_Default 110 } 111 112 // Partitions is a sql.Table interface function that returns a partition of the data. This data has a single partition. 113 func (p *PatchTableFunction) Partitions(ctx *sql.Context) (sql.PartitionIter, error) { 114 return dtables.NewSliceOfPartitionsItr([]sql.Partition{ 115 &Partition{key: schemaAndDataChangePartitionKey}, 116 }), nil 117 } 118 119 // PartitionRows is a sql.Table interface function that takes a partition and returns all rows in that partition. 120 // This table has a partition for just schema changes, one for just data changes, and one for both. 121 func (p *PatchTableFunction) PartitionRows(ctx *sql.Context, partition sql.Partition) (sql.RowIter, error) { 122 fromCommitVal, toCommitVal, dotCommitVal, tableName, err := p.evaluateArguments() 123 if err != nil { 124 return nil, err 125 } 126 127 sqledb, ok := p.database.(dsess.SqlDatabase) 128 if !ok { 129 return nil, fmt.Errorf("unable to get dolt database") 130 } 131 132 fromRefDetails, toRefDetails, err := loadDetailsForRefs(ctx, fromCommitVal, toCommitVal, dotCommitVal, sqledb) 133 if err != nil { 134 return nil, err 135 } 136 137 tableDeltas, err := diff.GetTableDeltas(ctx, fromRefDetails.root, toRefDetails.root) 138 if err != nil { 139 return nil, err 140 } 141 142 sort.Slice(tableDeltas, func(i, j int) bool { 143 return strings.Compare(tableDeltas[i].ToName, tableDeltas[j].ToName) < 0 144 }) 145 146 // If tableNameExpr defined, return a single table patch result 147 if p.tableNameExpr != nil { 148 fromTblExists, err := fromRefDetails.root.HasTable(ctx, tableName) 149 if err != nil { 150 return nil, err 151 } 152 toTblExists, err := toRefDetails.root.HasTable(ctx, tableName) 153 if err != nil { 154 return nil, err 155 } 156 if !fromTblExists && !toTblExists { 157 return nil, sql.ErrTableNotFound.New(tableName) 158 } 159 160 delta := findMatchingDelta(tableDeltas, tableName) 161 tableDeltas = []diff.TableDelta{delta} 162 } 163 164 includeSchemaDiff := bytes.Equal(partition.Key(), schemaAndDataChangePartitionKey) || bytes.Equal(partition.Key(), schemaChangePartitionKey) 165 includeDataDiff := bytes.Equal(partition.Key(), schemaAndDataChangePartitionKey) || bytes.Equal(partition.Key(), dataChangePartitionKey) 166 167 patches, err := getPatchNodes(ctx, sqledb.DbData(), tableDeltas, fromRefDetails, toRefDetails, includeSchemaDiff, includeDataDiff) 168 if err != nil { 169 return nil, err 170 } 171 172 return newPatchTableFunctionRowIter(patches, fromRefDetails.hashStr, toRefDetails.hashStr), nil 173 } 174 175 // LookupPartitions is a sql.IndexedTable interface function that takes an index lookup and returns the set of corresponding partitions. 176 func (p *PatchTableFunction) LookupPartitions(context *sql.Context, lookup sql.IndexLookup) (sql.PartitionIter, error) { 177 if lookup.Index.ID() == diffTypeColumnName { 178 diffTypes, ok := index.LookupToPointSelectStr(lookup) 179 if !ok { 180 return nil, fmt.Errorf("failed to parse commit lookup ranges: %s", sql.DebugString(lookup.Ranges)) 181 } 182 183 includeSchemaDiff := slices.Contains(diffTypes, diffTypeSchema) 184 includeDataDiff := slices.Contains(diffTypes, diffTypeData) 185 186 if includeSchemaDiff && includeDataDiff { 187 return dtables.NewSliceOfPartitionsItr([]sql.Partition{ 188 &Partition{key: schemaAndDataChangePartitionKey}, 189 }), nil 190 } 191 192 if includeSchemaDiff { 193 return dtables.NewSliceOfPartitionsItr([]sql.Partition{ 194 &Partition{key: schemaChangePartitionKey}, 195 }), nil 196 } 197 198 if includeDataDiff { 199 return dtables.NewSliceOfPartitionsItr([]sql.Partition{ 200 &Partition{key: dataChangePartitionKey}, 201 }), nil 202 } 203 204 return dtables.NewSliceOfPartitionsItr([]sql.Partition{}), nil 205 } 206 207 return dtables.NewSliceOfPartitionsItr([]sql.Partition{ 208 &Partition{key: schemaAndDataChangePartitionKey}, 209 }), nil 210 } 211 212 func (p *PatchTableFunction) IndexedAccess(lookup sql.IndexLookup) sql.IndexedTable { 213 return p 214 } 215 216 func (p *PatchTableFunction) GetIndexes(ctx *sql.Context) ([]sql.Index, error) { 217 return []sql.Index{ 218 index.MockIndex(p.database.Name(), p.Name(), diffTypeColumnName, types.StringKind, false), 219 }, nil 220 } 221 222 func (p *PatchTableFunction) PreciseMatch() bool { 223 return true 224 } 225 226 var patchTableSchema = sql.Schema{ 227 &sql.Column{Name: orderColumnName, Type: sqltypes.Uint64, PrimaryKey: true, Nullable: false}, 228 &sql.Column{Name: fromColumnName, Type: sqltypes.LongText, Nullable: false}, 229 &sql.Column{Name: toColumnName, Type: sqltypes.LongText, Nullable: false}, 230 &sql.Column{Name: tableNameColumnName, Type: sqltypes.LongText, Nullable: false}, 231 &sql.Column{Name: diffTypeColumnName, Type: sqltypes.LongText, Nullable: false}, 232 &sql.Column{Name: statementColumnName, Type: sqltypes.LongText, Nullable: false}, 233 } 234 235 // NewInstance creates a new instance of TableFunction interface 236 func (p *PatchTableFunction) NewInstance(ctx *sql.Context, db sql.Database, exprs []sql.Expression) (sql.Node, error) { 237 newInstance := &PatchTableFunction{ 238 ctx: ctx, 239 database: db, 240 } 241 242 node, err := newInstance.WithExpressions(exprs...) 243 if err != nil { 244 return nil, err 245 } 246 247 return node, nil 248 } 249 250 // Resolved implements the sql.Resolvable interface 251 func (p *PatchTableFunction) Resolved() bool { 252 if p.tableNameExpr != nil { 253 return p.commitsResolved() && p.tableNameExpr.Resolved() 254 } 255 return p.commitsResolved() 256 } 257 258 func (p *PatchTableFunction) IsReadOnly() bool { 259 return true 260 } 261 262 func (p *PatchTableFunction) commitsResolved() bool { 263 if p.dotCommitExpr != nil { 264 return p.dotCommitExpr.Resolved() 265 } 266 return p.fromCommitExpr.Resolved() && p.toCommitExpr.Resolved() 267 } 268 269 // String implements the Stringer interface 270 func (p *PatchTableFunction) String() string { 271 if p.dotCommitExpr != nil { 272 if p.tableNameExpr != nil { 273 return fmt.Sprintf("DOLT_PATCH(%s, %s)", p.dotCommitExpr.String(), p.tableNameExpr.String()) 274 } 275 return fmt.Sprintf("DOLT_PATCH(%s)", p.dotCommitExpr.String()) 276 } 277 if p.tableNameExpr != nil { 278 return fmt.Sprintf("DOLT_PATCH(%s, %s, %s)", p.fromCommitExpr.String(), p.toCommitExpr.String(), p.tableNameExpr.String()) 279 } 280 if p.fromCommitExpr != nil && p.toCommitExpr != nil { 281 return fmt.Sprintf("DOLT_PATCH(%s, %s)", p.fromCommitExpr.String(), p.toCommitExpr.String()) 282 } 283 return fmt.Sprintf("DOLT_PATCH(<INVALID>)") 284 } 285 286 // Schema implements the sql.Node interface. 287 func (p *PatchTableFunction) Schema() sql.Schema { 288 return patchTableSchema 289 } 290 291 // Children implements the sql.Node interface. 292 func (p *PatchTableFunction) Children() []sql.Node { 293 return nil 294 } 295 296 // WithChildren implements the sql.Node interface. 297 func (p *PatchTableFunction) WithChildren(children ...sql.Node) (sql.Node, error) { 298 if len(children) != 0 { 299 return nil, fmt.Errorf("unexpected children") 300 } 301 return p, nil 302 } 303 304 // CheckPrivileges implements the interface sql.Node. 305 func (p *PatchTableFunction) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { 306 if p.tableNameExpr != nil { 307 if !sqltypes.IsText(p.tableNameExpr.Type()) { 308 return false 309 } 310 311 tableNameVal, err := p.tableNameExpr.Eval(p.ctx, nil) 312 if err != nil { 313 return false 314 } 315 tableName, ok := tableNameVal.(string) 316 if !ok { 317 return false 318 } 319 320 subject := sql.PrivilegeCheckSubject{Database: p.database.Name(), Table: tableName} 321 return opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Select)) 322 } 323 324 tblNames, err := p.database.GetTableNames(ctx) 325 if err != nil { 326 return false 327 } 328 329 operations := make([]sql.PrivilegedOperation, 0, len(tblNames)) 330 for _, tblName := range tblNames { 331 subject := sql.PrivilegeCheckSubject{Database: p.database.Name(), Table: tblName} 332 operations = append(operations, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Select)) 333 } 334 335 return opChecker.UserHasPrivileges(ctx, operations...) 336 } 337 338 // Expressions implements the sql.Expressioner interface. 339 func (p *PatchTableFunction) Expressions() []sql.Expression { 340 exprs := []sql.Expression{} 341 if p.dotCommitExpr != nil { 342 exprs = append(exprs, p.dotCommitExpr) 343 } else { 344 exprs = append(exprs, p.fromCommitExpr, p.toCommitExpr) 345 } 346 if p.tableNameExpr != nil { 347 exprs = append(exprs, p.tableNameExpr) 348 } 349 return exprs 350 } 351 352 // WithExpressions implements the sql.Expressioner interface. 353 func (p *PatchTableFunction) WithExpressions(expr ...sql.Expression) (sql.Node, error) { 354 if len(expr) < 1 { 355 return nil, sql.ErrInvalidArgumentNumber.New(p.Name(), "1 to 3", len(expr)) 356 } 357 358 for _, expr := range expr { 359 if !expr.Resolved() { 360 return nil, ErrInvalidNonLiteralArgument.New(p.Name(), expr.String()) 361 } 362 // prepared statements resolve functions beforehand, so above check fails 363 if _, ok := expr.(sql.FunctionExpression); ok { 364 return nil, ErrInvalidNonLiteralArgument.New(p.Name(), expr.String()) 365 } 366 } 367 368 newPtf := *p 369 if strings.Contains(expr[0].String(), "..") { 370 if len(expr) < 1 || len(expr) > 2 { 371 return nil, sql.ErrInvalidArgumentNumber.New(newPtf.Name(), "1 or 2", len(expr)) 372 } 373 newPtf.dotCommitExpr = expr[0] 374 if len(expr) == 2 { 375 newPtf.tableNameExpr = expr[1] 376 } 377 } else { 378 if len(expr) < 2 || len(expr) > 3 { 379 return nil, sql.ErrInvalidArgumentNumber.New(newPtf.Name(), "2 or 3", len(expr)) 380 } 381 newPtf.fromCommitExpr = expr[0] 382 newPtf.toCommitExpr = expr[1] 383 if len(expr) == 3 { 384 newPtf.tableNameExpr = expr[2] 385 } 386 } 387 388 // validate the expressions 389 if newPtf.dotCommitExpr != nil { 390 if !sqltypes.IsText(newPtf.dotCommitExpr.Type()) && !expression.IsBindVar(newPtf.dotCommitExpr) { 391 return nil, sql.ErrInvalidArgumentDetails.New(newPtf.Name(), newPtf.dotCommitExpr.String()) 392 } 393 } else { 394 if !sqltypes.IsText(newPtf.fromCommitExpr.Type()) && !expression.IsBindVar(newPtf.fromCommitExpr) { 395 return nil, sql.ErrInvalidArgumentDetails.New(newPtf.Name(), newPtf.fromCommitExpr.String()) 396 } 397 if !sqltypes.IsText(newPtf.toCommitExpr.Type()) && !expression.IsBindVar(newPtf.toCommitExpr) { 398 return nil, sql.ErrInvalidArgumentDetails.New(newPtf.Name(), newPtf.toCommitExpr.String()) 399 } 400 } 401 402 if newPtf.tableNameExpr != nil { 403 if !sqltypes.IsText(newPtf.tableNameExpr.Type()) && !expression.IsBindVar(newPtf.tableNameExpr) { 404 return nil, sql.ErrInvalidArgumentDetails.New(newPtf.Name(), newPtf.tableNameExpr.String()) 405 } 406 } 407 408 return &newPtf, nil 409 } 410 411 // Database implements the sql.Databaser interface 412 func (p *PatchTableFunction) Database() sql.Database { 413 return p.database 414 } 415 416 // WithDatabase implements the sql.Databaser interface 417 func (p *PatchTableFunction) WithDatabase(database sql.Database) (sql.Node, error) { 418 np := *p 419 np.database = database 420 return &np, nil 421 } 422 423 // Name implements the sql.TableFunction interface 424 func (p *PatchTableFunction) Name() string { 425 return p.String() 426 } 427 428 // RowIter implements the sql.ExecSourceRel interface 429 func (p *PatchTableFunction) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error) { 430 partitions, err := p.Partitions(ctx) 431 if err != nil { 432 return nil, err 433 } 434 435 return sql.NewTableRowIter(ctx, p, partitions), nil 436 } 437 438 // evaluateArguments returns fromCommitVal, toCommitVal, dotCommitVal, and tableName. 439 // It evaluates the argument expressions to turn them into values this PatchTableFunction 440 // can use. Note that this method only evals the expressions, and doesn't validate the values. 441 func (p *PatchTableFunction) evaluateArguments() (interface{}, interface{}, interface{}, string, error) { 442 var tableName string 443 if p.tableNameExpr != nil { 444 tableNameVal, err := p.tableNameExpr.Eval(p.ctx, nil) 445 if err != nil { 446 return nil, nil, nil, "", err 447 } 448 tn, ok := tableNameVal.(string) 449 if !ok { 450 return nil, nil, nil, "", ErrInvalidTableName.New(p.tableNameExpr.String()) 451 } 452 tableName = tn 453 } 454 455 if p.dotCommitExpr != nil { 456 dotCommitVal, err := p.dotCommitExpr.Eval(p.ctx, nil) 457 if err != nil { 458 return nil, nil, nil, "", err 459 } 460 461 return nil, nil, dotCommitVal, tableName, nil 462 } 463 464 fromCommitVal, err := p.fromCommitExpr.Eval(p.ctx, nil) 465 if err != nil { 466 return nil, nil, nil, "", err 467 } 468 469 toCommitVal, err := p.toCommitExpr.Eval(p.ctx, nil) 470 if err != nil { 471 return nil, nil, nil, "", err 472 } 473 474 return fromCommitVal, toCommitVal, nil, tableName, nil 475 } 476 477 type patchNode struct { 478 tblName string 479 schemaPatchStmts []string 480 dataPatchStmts []string 481 } 482 483 func getPatchNodes(ctx *sql.Context, dbData env.DbData, tableDeltas []diff.TableDelta, fromRefDetails, toRefDetails *refDetails, includeSchemaDiff, includeDataDiff bool) (patches []*patchNode, err error) { 484 for _, td := range tableDeltas { 485 if td.FromTable == nil && td.ToTable == nil { 486 // no diff 487 if !strings.HasPrefix(td.FromName, diff.DBPrefix) || !strings.HasPrefix(td.ToName, diff.DBPrefix) { 488 continue 489 } 490 491 // db collation diff 492 dbName := strings.TrimPrefix(td.ToName, diff.DBPrefix) 493 fromColl, cerr := fromRefDetails.root.GetCollation(ctx) 494 if cerr != nil { 495 return nil, cerr 496 } 497 toColl, cerr := toRefDetails.root.GetCollation(ctx) 498 if cerr != nil { 499 return nil, cerr 500 } 501 alterDBCollStmt := sqlfmt.AlterDatabaseCollateStmt(dbName, fromColl, toColl) 502 patches = append(patches, &patchNode{ 503 tblName: td.FromName, 504 schemaPatchStmts: []string{alterDBCollStmt}, 505 dataPatchStmts: []string{}, 506 }) 507 } 508 509 tblName := td.ToName 510 if td.IsDrop() { 511 tblName = td.FromName 512 } 513 514 // Get SCHEMA DIFF 515 var schemaStmts []string 516 if includeSchemaDiff { 517 schemaStmts, err = getSchemaSqlPatch(ctx, toRefDetails.root, td) 518 if err != nil { 519 return nil, err 520 } 521 } 522 523 // Get DATA DIFF 524 var dataStmts []string 525 if includeDataDiff && canGetDataDiff(ctx, td) { 526 dataStmts, err = getUserTableDataSqlPatch(ctx, dbData, td, fromRefDetails, toRefDetails) 527 if err != nil { 528 return nil, err 529 } 530 } 531 532 patches = append(patches, &patchNode{tblName: tblName, schemaPatchStmts: schemaStmts, dataPatchStmts: dataStmts}) 533 } 534 535 return patches, nil 536 } 537 538 func getSchemaSqlPatch(ctx *sql.Context, toRoot doltdb.RootValue, td diff.TableDelta) ([]string, error) { 539 toSchemas, err := doltdb.GetAllSchemas(ctx, toRoot) 540 if err != nil { 541 return nil, fmt.Errorf("could not read schemas from toRoot, cause: %s", err.Error()) 542 } 543 544 fromSch, toSch, err := td.GetSchemas(ctx) 545 if err != nil { 546 return nil, fmt.Errorf("cannot retrieve schema for table %s, cause: %s", td.ToName, err.Error()) 547 } 548 549 var ddlStatements []string 550 if td.IsDrop() { 551 ddlStatements = append(ddlStatements, sqlfmt.DropTableStmt(td.FromName)) 552 } else if td.IsAdd() { 553 stmt, err := sqlfmt.GenerateCreateTableStatement(td.ToName, td.ToSch, td.ToFks, td.ToFksParentSch) 554 if err != nil { 555 return nil, errhand.VerboseErrorFromError(err) 556 } 557 ddlStatements = append(ddlStatements, stmt) 558 } else { 559 stmts, err := GetNonCreateNonDropTableSqlSchemaDiff(td, toSchemas, fromSch, toSch) 560 if err != nil { 561 return nil, err 562 } 563 ddlStatements = append(ddlStatements, stmts...) 564 } 565 566 return ddlStatements, nil 567 } 568 569 func canGetDataDiff(ctx *sql.Context, td diff.TableDelta) bool { 570 if td.IsDrop() { 571 return false // don't output DELETE FROM statements after DROP TABLE 572 } 573 574 // not diffable 575 if !schema.ArePrimaryKeySetsDiffable(td.Format(), td.FromSch, td.ToSch) { 576 ctx.Session.Warn(&sql.Warning{ 577 Level: "Warning", 578 Code: mysql.ERNotSupportedYet, 579 Message: fmt.Sprintf("Primary key sets differ between revisions for table '%s', skipping data diff", td.ToName), 580 }) 581 return false 582 } 583 584 return true 585 } 586 587 func getUserTableDataSqlPatch(ctx *sql.Context, dbData env.DbData, td diff.TableDelta, fromRefDetails, toRefDetails *refDetails) ([]string, error) { 588 // ToTable is used as target table as it cannot be nil at this point 589 diffSch, projections, ri, err := getDiffQuery(ctx, dbData, td, fromRefDetails, toRefDetails) 590 if err != nil { 591 return nil, err 592 } 593 594 targetPkSch, err := sqlutil.FromDoltSchema("", td.ToName, td.ToSch) 595 if err != nil { 596 return nil, err 597 } 598 599 return getDataSqlPatchResults(ctx, diffSch, targetPkSch.Schema, projections, ri, td.ToName, td.ToSch) 600 } 601 602 func getDataSqlPatchResults(ctx *sql.Context, diffQuerySch, targetSch sql.Schema, projections []sql.Expression, iter sql.RowIter, tn string, tsch schema.Schema) ([]string, error) { 603 ds, err := diff.NewDiffSplitter(diffQuerySch, targetSch) 604 if err != nil { 605 return nil, err 606 } 607 608 var res []string 609 for { 610 r, err := iter.Next(ctx) 611 if err == io.EOF { 612 return res, nil 613 } else if err != nil { 614 return nil, err 615 } 616 617 r, err = rowexec.ProjectRow(ctx, projections, r) 618 if err != nil { 619 return nil, err 620 } 621 622 oldRow, newRow, err := ds.SplitDiffResultRow(r) 623 if err != nil { 624 return nil, err 625 } 626 627 var stmt string 628 if oldRow.Row != nil { 629 stmt, err = diff.GetDataDiffStatement(tn, tsch, oldRow.Row, oldRow.RowDiff, oldRow.ColDiffs) 630 if err != nil { 631 return nil, err 632 } 633 } 634 635 if newRow.Row != nil { 636 stmt, err = diff.GetDataDiffStatement(tn, tsch, newRow.Row, newRow.RowDiff, newRow.ColDiffs) 637 if err != nil { 638 return nil, err 639 } 640 } 641 642 if stmt != "" { 643 res = append(res, stmt) 644 } 645 } 646 } 647 648 // GetNonCreateNonDropTableSqlSchemaDiff returns any schema diff in SQL statements that is NEITHER 'CREATE TABLE' NOR 'DROP TABLE' statements. 649 func GetNonCreateNonDropTableSqlSchemaDiff(td diff.TableDelta, toSchemas map[string]schema.Schema, fromSch, toSch schema.Schema) ([]string, error) { 650 if td.IsAdd() || td.IsDrop() { 651 // use add and drop specific methods 652 return nil, nil 653 } 654 655 var ddlStatements []string 656 if td.FromName != td.ToName { 657 ddlStatements = append(ddlStatements, sqlfmt.RenameTableStmt(td.FromName, td.ToName)) 658 } 659 660 eq := schema.SchemasAreEqual(fromSch, toSch) 661 if eq && !td.HasFKChanges() { 662 return ddlStatements, nil 663 } 664 665 colDiffs, unionTags := diff.DiffSchColumns(fromSch, toSch) 666 for _, tag := range unionTags { 667 cd := colDiffs[tag] 668 switch cd.DiffType { 669 case diff.SchDiffNone: 670 case diff.SchDiffAdded: 671 ddlStatements = append(ddlStatements, sqlfmt.AlterTableAddColStmt(td.ToName, sqlfmt.GenerateCreateTableColumnDefinition(*cd.New, sql.CollationID(td.ToSch.GetCollation())))) 672 case diff.SchDiffRemoved: 673 ddlStatements = append(ddlStatements, sqlfmt.AlterTableDropColStmt(td.ToName, cd.Old.Name)) 674 case diff.SchDiffModified: 675 // Ignore any primary key set changes here 676 if cd.Old.IsPartOfPK != cd.New.IsPartOfPK { 677 continue 678 } 679 if cd.Old.Name != cd.New.Name { 680 ddlStatements = append(ddlStatements, sqlfmt.AlterTableRenameColStmt(td.ToName, cd.Old.Name, cd.New.Name)) 681 } 682 if cd.Old.TypeInfo != cd.New.TypeInfo { 683 ddlStatements = append(ddlStatements, sqlfmt.AlterTableModifyColStmt(td.ToName, 684 sqlfmt.GenerateCreateTableColumnDefinition(*cd.New, sql.CollationID(td.ToSch.GetCollation())))) 685 } 686 } 687 } 688 689 // Print changes between a primary key set change. It contains an ALTER TABLE DROP and an ALTER TABLE ADD 690 if !schema.ColCollsAreEqual(fromSch.GetPKCols(), toSch.GetPKCols()) { 691 ddlStatements = append(ddlStatements, sqlfmt.AlterTableDropPks(td.ToName)) 692 if toSch.GetPKCols().Size() > 0 { 693 ddlStatements = append(ddlStatements, sqlfmt.AlterTableAddPrimaryKeys(td.ToName, toSch.GetPKCols().GetColumnNames())) 694 } 695 } 696 697 for _, idxDiff := range diff.DiffSchIndexes(fromSch, toSch) { 698 switch idxDiff.DiffType { 699 case diff.SchDiffNone: 700 case diff.SchDiffAdded: 701 ddlStatements = append(ddlStatements, sqlfmt.AlterTableAddIndexStmt(td.ToName, idxDiff.To)) 702 case diff.SchDiffRemoved: 703 ddlStatements = append(ddlStatements, sqlfmt.AlterTableDropIndexStmt(td.FromName, idxDiff.From)) 704 case diff.SchDiffModified: 705 ddlStatements = append(ddlStatements, sqlfmt.AlterTableDropIndexStmt(td.FromName, idxDiff.From)) 706 ddlStatements = append(ddlStatements, sqlfmt.AlterTableAddIndexStmt(td.ToName, idxDiff.To)) 707 } 708 } 709 710 for _, fkDiff := range diff.DiffForeignKeys(td.FromFks, td.ToFks) { 711 switch fkDiff.DiffType { 712 case diff.SchDiffNone: 713 case diff.SchDiffAdded: 714 parentSch := toSchemas[fkDiff.To.ReferencedTableName] 715 ddlStatements = append(ddlStatements, sqlfmt.AlterTableAddForeignKeyStmt(fkDiff.To, toSch, parentSch)) 716 case diff.SchDiffRemoved: 717 from := fkDiff.From 718 ddlStatements = append(ddlStatements, sqlfmt.AlterTableDropForeignKeyStmt(from.TableName, from.Name)) 719 case diff.SchDiffModified: 720 from := fkDiff.From 721 ddlStatements = append(ddlStatements, sqlfmt.AlterTableDropForeignKeyStmt(from.TableName, from.Name)) 722 723 parentSch := toSchemas[fkDiff.To.ReferencedTableName] 724 ddlStatements = append(ddlStatements, sqlfmt.AlterTableAddForeignKeyStmt(fkDiff.To, toSch, parentSch)) 725 } 726 } 727 728 // Handle charset/collation changes 729 toCollation := toSch.GetCollation() 730 fromCollation := fromSch.GetCollation() 731 if toCollation != fromCollation { 732 ddlStatements = append(ddlStatements, sqlfmt.AlterTableCollateStmt(td.ToName, fromCollation, toCollation)) 733 } 734 735 return ddlStatements, nil 736 } 737 738 // getDiffQuery returns diff schema for specified columns and array of sql.Expression as projection to be used 739 // on diff table function row iter. This function attempts to imitate running a query 740 // fmt.Sprintf("select %s, %s from dolt_diff('%s', '%s', '%s')", columnsWithDiff, "diff_type", fromRef, toRef, tableName) 741 // on sql engine, which returns the schema and rowIter of the final data diff result. 742 func getDiffQuery(ctx *sql.Context, dbData env.DbData, td diff.TableDelta, fromRefDetails, toRefDetails *refDetails) (sql.Schema, []sql.Expression, sql.RowIter, error) { 743 diffTableSchema, j, err := dtables.GetDiffTableSchemaAndJoiner(td.ToTable.Format(), td.FromSch, td.ToSch) 744 if err != nil { 745 return nil, nil, nil, err 746 } 747 diffPKSch, err := sqlutil.FromDoltSchema("", "", diffTableSchema) 748 if err != nil { 749 return nil, nil, nil, err 750 } 751 752 columnsWithDiff := getColumnNamesWithDiff(td.FromSch, td.ToSch) 753 diffQuerySqlSch, projections := getDiffQuerySqlSchemaAndProjections(diffPKSch.Schema, columnsWithDiff) 754 755 dp := dtables.NewDiffPartition(td.ToTable, td.FromTable, toRefDetails.hashStr, fromRefDetails.hashStr, toRefDetails.commitTime, fromRefDetails.commitTime, td.ToSch, td.FromSch) 756 ri := dtables.NewDiffPartitionRowIter(*dp, dbData.Ddb, j) 757 758 return diffQuerySqlSch, projections, ri, nil 759 } 760 761 func getColumnNamesWithDiff(fromSch, toSch schema.Schema) []string { 762 var cols []string 763 764 if fromSch != nil { 765 _ = fromSch.GetAllCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) { 766 cols = append(cols, fmt.Sprintf("from_%s", col.Name)) 767 return false, nil 768 }) 769 } 770 if toSch != nil { 771 _ = toSch.GetAllCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) { 772 cols = append(cols, fmt.Sprintf("to_%s", col.Name)) 773 return false, nil 774 }) 775 } 776 return cols 777 } 778 779 // getDiffQuerySqlSchemaAndProjections returns the schema of columns with data diff and "diff_type". This is used for diff splitter. 780 // When extracting the diff schema, the ordering must follow the ordering of given columns 781 func getDiffQuerySqlSchemaAndProjections(diffTableSch sql.Schema, columns []string) (sql.Schema, []sql.Expression) { 782 type column struct { 783 sqlCol *sql.Column 784 idx int 785 } 786 787 columns = append(columns, diffTypeColumnName) 788 colMap := make(map[string]*column) 789 for _, c := range columns { 790 colMap[c] = nil 791 } 792 793 var cols = make([]*sql.Column, len(columns)) 794 var getFieldCols = make([]sql.Expression, len(columns)) 795 796 for i, c := range diffTableSch { 797 if _, ok := colMap[c.Name]; ok { 798 colMap[c.Name] = &column{c, i} 799 } 800 } 801 802 for i, c := range columns { 803 col := colMap[c].sqlCol 804 cols[i] = col 805 getFieldCols[i] = expression.NewGetField(colMap[c].idx, col.Type, col.Name, col.Nullable) 806 } 807 808 return cols, getFieldCols 809 } 810 811 //------------------------------------ 812 // patchTableFunctionRowIter 813 //------------------------------------ 814 815 var _ sql.RowIter = (*patchTableFunctionRowIter)(nil) 816 817 type patchTableFunctionRowIter struct { 818 patches []*patchNode 819 patchIdx int 820 statementIdx int 821 fromRef string 822 toRef string 823 currentPatch *patchNode 824 currentRowIter *sql.RowIter 825 } 826 827 // newPatchTableFunctionRowIter iterates over each patch nodes given returning 828 // each statement in each patch node as a single row including from_commit_hash, 829 // to_commit_hash and table_name prepended to diff_type and statement for each patch statement. 830 func newPatchTableFunctionRowIter(patchNodes []*patchNode, fromRef, toRef string) sql.RowIter { 831 return &patchTableFunctionRowIter{ 832 patches: patchNodes, 833 patchIdx: 0, 834 statementIdx: 0, 835 fromRef: fromRef, 836 toRef: toRef, 837 } 838 } 839 840 func (itr *patchTableFunctionRowIter) Next(ctx *sql.Context) (sql.Row, error) { 841 for { 842 if itr.patchIdx >= len(itr.patches) { 843 return nil, io.EOF 844 } 845 if itr.currentPatch == nil { 846 itr.currentPatch = itr.patches[itr.patchIdx] 847 } 848 if itr.currentRowIter == nil { 849 ri := newPatchStatementsRowIter(itr.currentPatch.schemaPatchStmts, itr.currentPatch.dataPatchStmts) 850 itr.currentRowIter = &ri 851 } 852 853 row, err := (*itr.currentRowIter).Next(ctx) 854 if err == io.EOF { 855 itr.currentPatch = nil 856 itr.currentRowIter = nil 857 itr.patchIdx++ 858 continue 859 } else if err != nil { 860 return nil, err 861 } else { 862 itr.statementIdx++ 863 r := sql.Row{itr.statementIdx, itr.fromRef, itr.toRef, itr.currentPatch.tblName} 864 return r.Append(row), nil 865 } 866 } 867 } 868 869 func (itr *patchTableFunctionRowIter) Close(_ *sql.Context) error { 870 return nil 871 } 872 873 //------------------------------------ 874 // patchStatementsRowIter 875 //------------------------------------ 876 877 var _ sql.RowIter = (*patchStatementsRowIter)(nil) 878 879 type patchStatementsRowIter struct { 880 stmts []string 881 ddlLen int 882 idx int 883 } 884 885 // newPatchStatementsRowIter iterates over each patch statements returning row of diff_type of either 'schema' or 'data' with the statement. 886 func newPatchStatementsRowIter(ddlStmts, dataStmts []string) sql.RowIter { 887 return &patchStatementsRowIter{ 888 stmts: append(ddlStmts, dataStmts...), 889 ddlLen: len(ddlStmts), 890 idx: 0, 891 } 892 } 893 894 func (p *patchStatementsRowIter) Next(ctx *sql.Context) (sql.Row, error) { 895 defer func() { 896 p.idx++ 897 }() 898 899 if p.idx >= len(p.stmts) { 900 return nil, io.EOF 901 } 902 903 if p.stmts == nil { 904 return nil, io.EOF 905 } 906 907 stmt := p.stmts[p.idx] 908 diffType := diffTypeSchema 909 if p.idx >= p.ddlLen { 910 diffType = diffTypeData 911 } 912 913 return sql.Row{diffType, stmt}, nil 914 } 915 916 func (p *patchStatementsRowIter) Close(_ *sql.Context) error { 917 return nil 918 }