github.com/dolthub/go-mysql-server@v0.18.0/sql/plan/table_copier.go (about)

     1  package plan
     2  
     3  import (
     4  	"fmt"
     5  
     6  	"github.com/dolthub/go-mysql-server/sql"
     7  	"github.com/dolthub/go-mysql-server/sql/mysql_db"
     8  	"github.com/dolthub/go-mysql-server/sql/types"
     9  )
    10  
// TableCopier is a supporting node that allows for the optimization of copying tables. It should be used in two cases:
//  1. CREATE TABLE SELECT *
//  2. INSERT INTO SELECT * where the inserted table is empty. TODO: Implement this optimization
//
// For CREATE TABLE ... SELECT (see ProcessCreateTable) the destination table is created first and then filled either
// by a direct data copy through sql.TableCopierDatabase or, when that is not possible, by an INSERT INTO over Source.
type TableCopier struct {
	// Source is the node producing the rows to copy (the SELECT part of the statement).
	Source sql.Node
	// Destination describes the target table; ProcessCreateTable expects it to be a *CreateTable.
	Destination sql.Node
	// db is the database the destination table is looked up in and inserted into.
	db sql.Database
	// options carries the replace/ignore flags handed to the INSERT INTO fallback.
	options CopierProps
}
    20  
    21  var _ sql.Databaser = (*TableCopier)(nil)
    22  var _ sql.Node = (*TableCopier)(nil)
    23  var _ sql.CollationCoercible = (*TableCopier)(nil)
    24  
// CopierProps holds the options TableCopier passes to the INSERT INTO it builds when a direct copy is not possible.
type CopierProps struct {
	// replace is passed as the replace flag of the fallback InsertInto node (presumably REPLACE semantics).
	replace bool
	// ignore is passed as the ignore flag of the fallback InsertInto node (presumably IGNORE semantics).
	ignore bool
}
    29  
    30  func NewTableCopier(db sql.Database, createTableNode sql.Node, source sql.Node, prop CopierProps) *TableCopier {
    31  	return &TableCopier{
    32  		Source:      source,
    33  		Destination: createTableNode,
    34  		db:          db,
    35  		options:     prop,
    36  	}
    37  }
    38  
    39  func (tc *TableCopier) WithDatabase(db sql.Database) (sql.Node, error) {
    40  	ntc := *tc
    41  	ntc.db = db
    42  	return &ntc, nil
    43  }
    44  
    45  func (tc *TableCopier) IsReadOnly() bool {
    46  	return false
    47  }
    48  
    49  func (tc *TableCopier) Database() sql.Database {
    50  	return tc.db
    51  }
    52  
    53  func (tc *TableCopier) ProcessCreateTable(ctx *sql.Context, b sql.NodeExecBuilder, row sql.Row) (sql.RowIter, error) {
    54  	ct := tc.Destination.(*CreateTable)
    55  
    56  	_, err := b.Build(ctx, ct, row)
    57  	if err != nil {
    58  		return sql.RowsToRowIter(), err
    59  	}
    60  
    61  	table, tableExists, err := tc.db.GetTableInsensitive(ctx, ct.Name())
    62  	if err != nil {
    63  		return sql.RowsToRowIter(), err
    64  	}
    65  
    66  	if !tableExists {
    67  		return sql.RowsToRowIter(), fmt.Errorf("error: Newly created table does not exist")
    68  	}
    69  
    70  	if tc.createTableSelectCanBeCopied(table) {
    71  		return tc.CopyTableOver(ctx, tc.Source.Schema()[0].Source, table.Name())
    72  	}
    73  
    74  	// TODO: Improve parsing for CREATE TABLE SELECT to allow for IGNORE/REPLACE and custom specs
    75  	ii := NewInsertInto(tc.db, NewResolvedTable(table, tc.db, nil), tc.Source, tc.options.replace, nil, nil, tc.options.ignore)
    76  
    77  	// Wrap the insert into a row update accumulator
    78  	roa := NewRowUpdateAccumulator(ii, UpdateTypeInsert)
    79  
    80  	return b.Build(ctx, roa, row)
    81  }
    82  
    83  // createTableSelectCanBeCopied determines whether the newly created table's data can just be copied from the Source table
    84  func (tc *TableCopier) createTableSelectCanBeCopied(tableNode sql.Table) bool {
    85  	// The differences in LIMIT between integrators prevent us from using a copy
    86  	if _, ok := tc.Source.(*Limit); ok {
    87  		return false
    88  	}
    89  
    90  	// If the DB does not implement the TableCopierDatabase interface we cannot copy over the table.
    91  	if privDb, ok := tc.db.(mysql_db.PrivilegedDatabase); ok {
    92  		if _, ok := privDb.Unwrap().(sql.TableCopierDatabase); !ok {
    93  			return false
    94  		}
    95  	} else if _, ok := tc.db.(sql.TableCopierDatabase); !ok {
    96  		return false
    97  	}
    98  
    99  	// If there isn't a match in schema we cannot do a direct copy.
   100  	sourceSchema := tc.Source.Schema()
   101  	tableNodeSchema := tableNode.Schema()
   102  
   103  	if len(sourceSchema) != len(tableNodeSchema) {
   104  		return false
   105  	}
   106  
   107  	for i, sn := range sourceSchema {
   108  		if sn.Name != tableNodeSchema[i].Name {
   109  			return false
   110  		}
   111  	}
   112  
   113  	return true
   114  }
   115  
   116  // CopyTableOver is used when we can guarantee the Destination table will have the same data as the source table.
   117  func (tc *TableCopier) CopyTableOver(ctx *sql.Context, sourceTable string, destinationTable string) (sql.RowIter, error) {
   118  	db, ok := tc.db.(sql.TableCopierDatabase)
   119  	if !ok {
   120  		return sql.RowsToRowIter(), sql.ErrTableCopyingNotSupported.New()
   121  	}
   122  
   123  	rowsUpdated, err := db.CopyTableData(ctx, sourceTable, destinationTable)
   124  	if err != nil {
   125  		return sql.RowsToRowIter(), err
   126  	}
   127  
   128  	return sql.RowsToRowIter([]sql.Row{{types.OkResult{RowsAffected: rowsUpdated, InsertID: 0, Info: nil}}}...), nil
   129  }
   130  
   131  func (tc *TableCopier) Schema() sql.Schema {
   132  	return tc.Destination.Schema()
   133  }
   134  
   135  func (tc *TableCopier) Children() []sql.Node {
   136  	return nil
   137  }
   138  
   139  func (tc *TableCopier) WithChildren(...sql.Node) (sql.Node, error) {
   140  	return tc, nil
   141  }
   142  
   143  // CheckPrivileges implements the interface sql.Node.
   144  func (tc *TableCopier) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool {
   145  	//TODO: add a new branch when the INSERT optimization is added
   146  	subject := sql.PrivilegeCheckSubject{Database: tc.db.Name()}
   147  	return opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Create)) &&
   148  		tc.Source.CheckPrivileges(ctx, opChecker)
   149  }
   150  
   151  // CollationCoercibility implements the interface sql.CollationCoercible.
   152  func (*TableCopier) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   153  	return sql.Collation_binary, 7
   154  }
   155  
   156  func (tc *TableCopier) Resolved() bool {
   157  	return tc.Source.Resolved()
   158  }
   159  
   160  func (tc *TableCopier) String() string {
   161  	return fmt.Sprintf("TABLE_COPY SRC: %s into DST: %s", tc.Source, tc.Destination)
   162  }