github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/mvdata/engine_table_writer.go (about)

     1  // Copyright 2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package mvdata
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"io"
    21  	"sync/atomic"
    22  
    23  	"github.com/dolthub/go-mysql-server/sql"
    24  	"github.com/dolthub/go-mysql-server/sql/analyzer"
    25  	"github.com/dolthub/go-mysql-server/sql/analyzer/analyzererrors"
    26  	"github.com/dolthub/go-mysql-server/sql/plan"
    27  	"github.com/dolthub/go-mysql-server/sql/planbuilder"
    28  	"github.com/dolthub/go-mysql-server/sql/rowexec"
    29  	"github.com/dolthub/go-mysql-server/sql/transform"
    30  
    31  	"github.com/dolthub/dolt/go/cmd/dolt/commands/engine"
    32  	"github.com/dolthub/dolt/go/libraries/doltcore/env"
    33  	"github.com/dolthub/dolt/go/libraries/doltcore/schema"
    34  	"github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil"
    35  	"github.com/dolthub/dolt/go/libraries/doltcore/table/typed/noms"
    36  	"github.com/dolthub/dolt/go/store/types"
    37  )
    38  
    39  const (
    40  	// tableWriterStatUpdateRate is the number of writes that will process before the updated stats are displayed.
    41  	tableWriterStatUpdateRate = 64 * 1024
    42  )
    43  
    44  // SqlEngineTableWriter is a utility for importing a set of rows through the sql engine.
    45  type SqlEngineTableWriter struct {
    46  	se     *engine.SqlEngine
    47  	sqlCtx *sql.Context
    48  
    49  	tableName  string
    50  	database   string
    51  	contOnErr  bool
    52  	force      bool
    53  	disableFks bool
    54  
    55  	statsCB noms.StatsCB
    56  	stats   types.AppliedEditStats
    57  	statOps int32
    58  
    59  	importOption       TableImportOp
    60  	tableSchema        sql.PrimaryKeySchema
    61  	rowOperationSchema sql.PrimaryKeySchema
    62  }
    63  
    64  func NewSqlEngineTableWriter(ctx context.Context, dEnv *env.DoltEnv, createTableSchema, rowOperationSchema schema.Schema, options *MoverOptions, statsCB noms.StatsCB) (*SqlEngineTableWriter, error) {
    65  	// TODO: Assert that dEnv.DoltDB.AccessMode() != ReadOnly?
    66  
    67  	mrEnv, err := env.MultiEnvForDirectory(ctx, dEnv.Config.WriteableConfig(), dEnv.FS, dEnv.Version, dEnv)
    68  	if err != nil {
    69  		return nil, err
    70  	}
    71  
    72  	// Simplest path would have our import path be a layer over load data
    73  	config := &engine.SqlEngineConfig{
    74  		ServerUser: "root",
    75  		Autocommit: false, // We set autocommit == false to ensure to improve performance. Bulk import should not commit on each row.
    76  		Bulk:       true,
    77  	}
    78  	se, err := engine.NewSqlEngine(
    79  		ctx,
    80  		mrEnv,
    81  		config,
    82  	)
    83  	if err != nil {
    84  		return nil, err
    85  	}
    86  	defer se.Close()
    87  
    88  	dbName := mrEnv.GetFirstDatabase()
    89  
    90  	if se.GetUnderlyingEngine().IsReadOnly() {
    91  		// SqlEngineTableWriter does not respect read only mode
    92  		return nil, analyzererrors.ErrReadOnlyDatabase.New(dbName)
    93  	}
    94  
    95  	sqlCtx, err := se.NewLocalContext(ctx)
    96  	if err != nil {
    97  		return nil, err
    98  	}
    99  	sqlCtx.SetCurrentDatabase(dbName)
   100  
   101  	doltCreateTableSchema, err := sqlutil.FromDoltSchema("", options.TableToWriteTo, createTableSchema)
   102  	if err != nil {
   103  		return nil, err
   104  	}
   105  
   106  	doltRowOperationSchema, err := sqlutil.FromDoltSchema("", options.TableToWriteTo, rowOperationSchema)
   107  	if err != nil {
   108  		return nil, err
   109  	}
   110  
   111  	return &SqlEngineTableWriter{
   112  		se:         se,
   113  		sqlCtx:     sqlCtx,
   114  		contOnErr:  options.ContinueOnErr,
   115  		force:      options.Force,
   116  		disableFks: options.DisableFks,
   117  
   118  		database:  dbName,
   119  		tableName: options.TableToWriteTo,
   120  
   121  		statsCB: statsCB,
   122  
   123  		importOption:       options.Operation,
   124  		tableSchema:        doltCreateTableSchema,
   125  		rowOperationSchema: doltRowOperationSchema,
   126  	}, nil
   127  }
   128  
   129  func (s *SqlEngineTableWriter) WriteRows(ctx context.Context, inputChannel chan sql.Row, badRowCb func(row sql.Row, rowSchema sql.PrimaryKeySchema, tableName string, lineNumber int, err error) bool) (err error) {
   130  	err = s.forceDropTableIfNeeded()
   131  	if err != nil {
   132  		return err
   133  	}
   134  
   135  	_, _, err = s.se.Query(s.sqlCtx, "START TRANSACTION")
   136  	if err != nil {
   137  		return err
   138  	}
   139  
   140  	if s.disableFks {
   141  		_, _, err = s.se.Query(s.sqlCtx, "SET FOREIGN_KEY_CHECKS = 0")
   142  		if err != nil {
   143  			return err
   144  		}
   145  	}
   146  
   147  	err = s.createOrEmptyTableIfNeeded()
   148  	if err != nil {
   149  		return err
   150  	}
   151  
   152  	updateStats := func(row sql.Row) {
   153  		if row == nil {
   154  			return
   155  		}
   156  
   157  		// If the length of the row does not match the schema then we have an update operation.
   158  		if len(row) != len(s.tableSchema.Schema) {
   159  			oldRow := row[:len(row)/2]
   160  			newRow := row[len(row)/2:]
   161  
   162  			if ok, err := oldRow.Equals(newRow, s.tableSchema.Schema); err == nil {
   163  				if ok {
   164  					s.stats.SameVal++
   165  				} else {
   166  					s.stats.Modifications++
   167  				}
   168  			}
   169  		} else {
   170  			s.stats.Additions++
   171  		}
   172  	}
   173  
   174  	insertOrUpdateOperation, err := s.getInsertNode(inputChannel, false)
   175  	if err != nil {
   176  		return err
   177  	}
   178  
   179  	iter, err := rowexec.DefaultBuilder.Build(s.sqlCtx, insertOrUpdateOperation, nil)
   180  	if err != nil {
   181  		return err
   182  	}
   183  
   184  	defer func() {
   185  		rerr := iter.Close(s.sqlCtx)
   186  		if err == nil {
   187  			err = rerr
   188  		}
   189  	}()
   190  
   191  	line := 1
   192  
   193  	for {
   194  		if s.statsCB != nil && atomic.LoadInt32(&s.statOps) >= tableWriterStatUpdateRate {
   195  			atomic.StoreInt32(&s.statOps, 0)
   196  			s.statsCB(s.stats)
   197  		}
   198  
   199  		row, err := iter.Next(s.sqlCtx)
   200  		line += 1
   201  
   202  		// All other errors are handled by the errorHandler
   203  		if err == nil {
   204  			_ = atomic.AddInt32(&s.statOps, 1)
   205  			updateStats(row)
   206  		} else if err == io.EOF {
   207  			atomic.LoadInt32(&s.statOps)
   208  			atomic.StoreInt32(&s.statOps, 0)
   209  			if s.statsCB != nil {
   210  				s.statsCB(s.stats)
   211  			}
   212  
   213  			return err
   214  		} else {
   215  			var offendingRow sql.Row
   216  			switch n := err.(type) {
   217  			case sql.WrappedInsertError:
   218  				offendingRow = n.OffendingRow
   219  			case sql.IgnorableError:
   220  				offendingRow = n.OffendingRow
   221  			}
   222  
   223  			quit := badRowCb(offendingRow, s.tableSchema, s.tableName, line, err)
   224  			if quit {
   225  				return err
   226  			}
   227  		}
   228  	}
   229  }
   230  
   231  func (s *SqlEngineTableWriter) Commit(ctx context.Context) error {
   232  	_, _, err := s.se.Query(s.sqlCtx, "COMMIT")
   233  	return err
   234  }
   235  
   236  func (s *SqlEngineTableWriter) RowOperationSchema() sql.PrimaryKeySchema {
   237  	return s.rowOperationSchema
   238  }
   239  
   240  func (s *SqlEngineTableWriter) TableSchema() sql.PrimaryKeySchema {
   241  	return s.tableSchema
   242  }
   243  
   244  // forceDropTableIfNeeded drop the given table in case the -f parameter is passed.
   245  func (s *SqlEngineTableWriter) forceDropTableIfNeeded() error {
   246  	if s.force {
   247  		_, _, err := s.se.Query(s.sqlCtx, fmt.Sprintf("DROP TABLE IF EXISTS `%s`", s.tableName))
   248  		return err
   249  	}
   250  
   251  	return nil
   252  }
   253  
   254  // createOrEmptyTableIfNeeded either creates or truncates the table given a -c or -r parameter.
   255  func (s *SqlEngineTableWriter) createOrEmptyTableIfNeeded() error {
   256  	switch s.importOption {
   257  	case CreateOp:
   258  		return s.createTable()
   259  	case ReplaceOp:
   260  		_, _, err := s.se.Query(s.sqlCtx, fmt.Sprintf("TRUNCATE TABLE `%s`", s.tableName))
   261  		return err
   262  	default:
   263  		return nil
   264  	}
   265  }
   266  
   267  // createTable creates a table.
   268  func (s *SqlEngineTableWriter) createTable() error {
   269  	// TODO don't use internal interfaces to do this, we had to have a sql.Schema somewhere
   270  	// upstream to make the dolt schema
   271  	sqlCols := make([]string, len(s.tableSchema.Schema))
   272  	for i, c := range s.tableSchema.Schema {
   273  		sqlCols[i] = sql.GenerateCreateTableColumnDefinition(c, c.Default.String(), c.OnUpdate.String(), sql.Collation_Default)
   274  	}
   275  	var pks string
   276  	var sep string
   277  	for _, i := range s.tableSchema.PkOrdinals {
   278  		pks += sep + sql.QuoteIdentifier(s.tableSchema.Schema[i].Name)
   279  		sep = ", "
   280  	}
   281  	if len(sep) > 0 {
   282  		sqlCols = append(sqlCols, fmt.Sprintf("PRIMARY KEY (%s)", pks))
   283  	}
   284  
   285  	createTable := sql.GenerateCreateTableStatement(s.tableName, sqlCols, "", sql.CharacterSet_utf8mb4.String(), sql.Collation_Default.String(), "")
   286  	_, iter, err := s.se.Query(s.sqlCtx, createTable)
   287  	if err != nil {
   288  		return err
   289  	}
   290  	_, err = sql.RowIterToRows(s.sqlCtx, iter)
   291  	return err
   292  }
   293  
   294  // createInsertImportNode creates the relevant/analyzed insert node given the import option. This insert node is wrapped
   295  // with an error handler.
   296  func (s *SqlEngineTableWriter) getInsertNode(inputChannel chan sql.Row, replace bool) (sql.Node, error) {
   297  	update := s.importOption == UpdateOp
   298  	colNames := ""
   299  	values := ""
   300  	duplicate := ""
   301  	if update {
   302  		duplicate += " ON DUPLICATE KEY UPDATE "
   303  	}
   304  	sep := ""
   305  	for _, col := range s.rowOperationSchema.Schema {
   306  		colNames += fmt.Sprintf("%s%s", sep, sql.QuoteIdentifier(col.Name))
   307  		values += fmt.Sprintf("%s1", sep)
   308  		if update {
   309  			duplicate += fmt.Sprintf("%s`%s` = VALUES(`%s`)", sep, col.Name, col.Name)
   310  		}
   311  		sep = ", "
   312  	}
   313  
   314  	sqlEngine := s.se.GetUnderlyingEngine()
   315  	binder := planbuilder.New(s.sqlCtx, sqlEngine.Analyzer.Catalog, sqlEngine.Parser)
   316  	insert := fmt.Sprintf("insert into `%s` (%s) VALUES (%s)%s", s.tableName, colNames, values, duplicate)
   317  	parsed, _, _, err := binder.Parse(insert, false)
   318  	if err != nil {
   319  		return nil, fmt.Errorf("error constructing import query '%s': %w", insert, err)
   320  	}
   321  	parsedIns, ok := parsed.(*plan.InsertInto)
   322  	if !ok {
   323  		return nil, fmt.Errorf("import setup expected *plan.InsertInto root, found %T", parsed)
   324  	}
   325  	schema := make(sql.Schema, len(s.rowOperationSchema.Schema))
   326  	for i, c := range s.rowOperationSchema.Schema {
   327  		newC := c.Copy()
   328  		newC.Source = planbuilder.OnDupValuesPrefix
   329  		schema[i] = newC
   330  	}
   331  
   332  	switch n := parsedIns.Source.(type) {
   333  	case *plan.Values:
   334  		parsedIns.Source = NewChannelRowSource(schema, inputChannel)
   335  	case *plan.Project:
   336  		n.Child = NewChannelRowSource(schema, inputChannel)
   337  	}
   338  
   339  	parsedIns.Ignore = s.contOnErr
   340  	parsedIns.IsReplace = replace
   341  	analyzed, err := s.se.Analyze(s.sqlCtx, parsedIns)
   342  	if err != nil {
   343  		return nil, err
   344  	}
   345  
   346  	analyzed = analyzer.StripPassthroughNodes(analyzed)
   347  
   348  	// Get the first insert (wrapped with the error handler)
   349  	transform.Inspect(analyzed, func(node sql.Node) bool {
   350  		switch n := node.(type) {
   351  		case *plan.InsertInto:
   352  			analyzed = n
   353  			return false
   354  		default:
   355  			return true
   356  		}
   357  	})
   358  
   359  	return analyzed, nil
   360  }