github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/mvdata/engine_table_writer.go (about) 1 // Copyright 2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package mvdata 16 17 import ( 18 "context" 19 "fmt" 20 "io" 21 "sync/atomic" 22 23 "github.com/dolthub/go-mysql-server/sql" 24 "github.com/dolthub/go-mysql-server/sql/analyzer" 25 "github.com/dolthub/go-mysql-server/sql/analyzer/analyzererrors" 26 "github.com/dolthub/go-mysql-server/sql/plan" 27 "github.com/dolthub/go-mysql-server/sql/planbuilder" 28 "github.com/dolthub/go-mysql-server/sql/rowexec" 29 "github.com/dolthub/go-mysql-server/sql/transform" 30 31 "github.com/dolthub/dolt/go/cmd/dolt/commands/engine" 32 "github.com/dolthub/dolt/go/libraries/doltcore/env" 33 "github.com/dolthub/dolt/go/libraries/doltcore/schema" 34 "github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil" 35 "github.com/dolthub/dolt/go/libraries/doltcore/table/typed/noms" 36 "github.com/dolthub/dolt/go/store/types" 37 ) 38 39 const ( 40 // tableWriterStatUpdateRate is the number of writes that will process before the updated stats are displayed. 41 tableWriterStatUpdateRate = 64 * 1024 42 ) 43 44 // SqlEngineTableWriter is a utility for importing a set of rows through the sql engine. 45 type SqlEngineTableWriter struct { 46 se *engine.SqlEngine 47 sqlCtx *sql.Context 48 49 tableName string 50 database string 51 contOnErr bool 52 force bool 53 disableFks bool 54 55 statsCB noms.StatsCB 56 stats types.AppliedEditStats 57 statOps int32 58 59 importOption TableImportOp 60 tableSchema sql.PrimaryKeySchema 61 rowOperationSchema sql.PrimaryKeySchema 62 } 63 64 func NewSqlEngineTableWriter(ctx context.Context, dEnv *env.DoltEnv, createTableSchema, rowOperationSchema schema.Schema, options *MoverOptions, statsCB noms.StatsCB) (*SqlEngineTableWriter, error) { 65 // TODO: Assert that dEnv.DoltDB.AccessMode() != ReadOnly? 66 67 mrEnv, err := env.MultiEnvForDirectory(ctx, dEnv.Config.WriteableConfig(), dEnv.FS, dEnv.Version, dEnv) 68 if err != nil { 69 return nil, err 70 } 71 72 // Simplest path would have our import path be a layer over load data 73 config := &engine.SqlEngineConfig{ 74 ServerUser: "root", 75 Autocommit: false, // We set autocommit == false to ensure to improve performance. Bulk import should not commit on each row. 76 Bulk: true, 77 } 78 se, err := engine.NewSqlEngine( 79 ctx, 80 mrEnv, 81 config, 82 ) 83 if err != nil { 84 return nil, err 85 } 86 defer se.Close() 87 88 dbName := mrEnv.GetFirstDatabase() 89 90 if se.GetUnderlyingEngine().IsReadOnly() { 91 // SqlEngineTableWriter does not respect read only mode 92 return nil, analyzererrors.ErrReadOnlyDatabase.New(dbName) 93 } 94 95 sqlCtx, err := se.NewLocalContext(ctx) 96 if err != nil { 97 return nil, err 98 } 99 sqlCtx.SetCurrentDatabase(dbName) 100 101 doltCreateTableSchema, err := sqlutil.FromDoltSchema("", options.TableToWriteTo, createTableSchema) 102 if err != nil { 103 return nil, err 104 } 105 106 doltRowOperationSchema, err := sqlutil.FromDoltSchema("", options.TableToWriteTo, rowOperationSchema) 107 if err != nil { 108 return nil, err 109 } 110 111 return &SqlEngineTableWriter{ 112 se: se, 113 sqlCtx: sqlCtx, 114 contOnErr: options.ContinueOnErr, 115 force: options.Force, 116 disableFks: options.DisableFks, 117 118 database: dbName, 119 tableName: options.TableToWriteTo, 120 121 statsCB: statsCB, 122 123 importOption: options.Operation, 124 tableSchema: doltCreateTableSchema, 125 rowOperationSchema: doltRowOperationSchema, 126 }, nil 127 } 128 129 func (s *SqlEngineTableWriter) WriteRows(ctx context.Context, inputChannel chan sql.Row, badRowCb func(row sql.Row, rowSchema sql.PrimaryKeySchema, tableName string, lineNumber int, err error) bool) (err error) { 130 err = s.forceDropTableIfNeeded() 131 if err != nil { 132 return err 133 } 134 135 _, _, err = s.se.Query(s.sqlCtx, "START TRANSACTION") 136 if err != nil { 137 return err 138 } 139 140 if s.disableFks { 141 _, _, err = s.se.Query(s.sqlCtx, "SET FOREIGN_KEY_CHECKS = 0") 142 if err != nil { 143 return err 144 } 145 } 146 147 err = s.createOrEmptyTableIfNeeded() 148 if err != nil { 149 return err 150 } 151 152 updateStats := func(row sql.Row) { 153 if row == nil { 154 return 155 } 156 157 // If the length of the row does not match the schema then we have an update operation. 158 if len(row) != len(s.tableSchema.Schema) { 159 oldRow := row[:len(row)/2] 160 newRow := row[len(row)/2:] 161 162 if ok, err := oldRow.Equals(newRow, s.tableSchema.Schema); err == nil { 163 if ok { 164 s.stats.SameVal++ 165 } else { 166 s.stats.Modifications++ 167 } 168 } 169 } else { 170 s.stats.Additions++ 171 } 172 } 173 174 insertOrUpdateOperation, err := s.getInsertNode(inputChannel, false) 175 if err != nil { 176 return err 177 } 178 179 iter, err := rowexec.DefaultBuilder.Build(s.sqlCtx, insertOrUpdateOperation, nil) 180 if err != nil { 181 return err 182 } 183 184 defer func() { 185 rerr := iter.Close(s.sqlCtx) 186 if err == nil { 187 err = rerr 188 } 189 }() 190 191 line := 1 192 193 for { 194 if s.statsCB != nil && atomic.LoadInt32(&s.statOps) >= tableWriterStatUpdateRate { 195 atomic.StoreInt32(&s.statOps, 0) 196 s.statsCB(s.stats) 197 } 198 199 row, err := iter.Next(s.sqlCtx) 200 line += 1 201 202 // All other errors are handled by the errorHandler 203 if err == nil { 204 _ = atomic.AddInt32(&s.statOps, 1) 205 updateStats(row) 206 } else if err == io.EOF { 207 atomic.LoadInt32(&s.statOps) 208 atomic.StoreInt32(&s.statOps, 0) 209 if s.statsCB != nil { 210 s.statsCB(s.stats) 211 } 212 213 return err 214 } else { 215 var offendingRow sql.Row 216 switch n := err.(type) { 217 case sql.WrappedInsertError: 218 offendingRow = n.OffendingRow 219 case sql.IgnorableError: 220 offendingRow = n.OffendingRow 221 } 222 223 quit := badRowCb(offendingRow, s.tableSchema, s.tableName, line, err) 224 if quit { 225 return err 226 } 227 } 228 } 229 } 230 231 func (s *SqlEngineTableWriter) Commit(ctx context.Context) error { 232 _, _, err := s.se.Query(s.sqlCtx, "COMMIT") 233 return err 234 } 235 236 func (s *SqlEngineTableWriter) RowOperationSchema() sql.PrimaryKeySchema { 237 return s.rowOperationSchema 238 } 239 240 func (s *SqlEngineTableWriter) TableSchema() sql.PrimaryKeySchema { 241 return s.tableSchema 242 } 243 244 // forceDropTableIfNeeded drop the given table in case the -f parameter is passed. 245 func (s *SqlEngineTableWriter) forceDropTableIfNeeded() error { 246 if s.force { 247 _, _, err := s.se.Query(s.sqlCtx, fmt.Sprintf("DROP TABLE IF EXISTS `%s`", s.tableName)) 248 return err 249 } 250 251 return nil 252 } 253 254 // createOrEmptyTableIfNeeded either creates or truncates the table given a -c or -r parameter. 255 func (s *SqlEngineTableWriter) createOrEmptyTableIfNeeded() error { 256 switch s.importOption { 257 case CreateOp: 258 return s.createTable() 259 case ReplaceOp: 260 _, _, err := s.se.Query(s.sqlCtx, fmt.Sprintf("TRUNCATE TABLE `%s`", s.tableName)) 261 return err 262 default: 263 return nil 264 } 265 } 266 267 // createTable creates a table. 268 func (s *SqlEngineTableWriter) createTable() error { 269 // TODO don't use internal interfaces to do this, we had to have a sql.Schema somewhere 270 // upstream to make the dolt schema 271 sqlCols := make([]string, len(s.tableSchema.Schema)) 272 for i, c := range s.tableSchema.Schema { 273 sqlCols[i] = sql.GenerateCreateTableColumnDefinition(c, c.Default.String(), c.OnUpdate.String(), sql.Collation_Default) 274 } 275 var pks string 276 var sep string 277 for _, i := range s.tableSchema.PkOrdinals { 278 pks += sep + sql.QuoteIdentifier(s.tableSchema.Schema[i].Name) 279 sep = ", " 280 } 281 if len(sep) > 0 { 282 sqlCols = append(sqlCols, fmt.Sprintf("PRIMARY KEY (%s)", pks)) 283 } 284 285 createTable := sql.GenerateCreateTableStatement(s.tableName, sqlCols, "", sql.CharacterSet_utf8mb4.String(), sql.Collation_Default.String(), "") 286 _, iter, err := s.se.Query(s.sqlCtx, createTable) 287 if err != nil { 288 return err 289 } 290 _, err = sql.RowIterToRows(s.sqlCtx, iter) 291 return err 292 } 293 294 // createInsertImportNode creates the relevant/analyzed insert node given the import option. This insert node is wrapped 295 // with an error handler. 296 func (s *SqlEngineTableWriter) getInsertNode(inputChannel chan sql.Row, replace bool) (sql.Node, error) { 297 update := s.importOption == UpdateOp 298 colNames := "" 299 values := "" 300 duplicate := "" 301 if update { 302 duplicate += " ON DUPLICATE KEY UPDATE " 303 } 304 sep := "" 305 for _, col := range s.rowOperationSchema.Schema { 306 colNames += fmt.Sprintf("%s%s", sep, sql.QuoteIdentifier(col.Name)) 307 values += fmt.Sprintf("%s1", sep) 308 if update { 309 duplicate += fmt.Sprintf("%s`%s` = VALUES(`%s`)", sep, col.Name, col.Name) 310 } 311 sep = ", " 312 } 313 314 sqlEngine := s.se.GetUnderlyingEngine() 315 binder := planbuilder.New(s.sqlCtx, sqlEngine.Analyzer.Catalog, sqlEngine.Parser) 316 insert := fmt.Sprintf("insert into `%s` (%s) VALUES (%s)%s", s.tableName, colNames, values, duplicate) 317 parsed, _, _, err := binder.Parse(insert, false) 318 if err != nil { 319 return nil, fmt.Errorf("error constructing import query '%s': %w", insert, err) 320 } 321 parsedIns, ok := parsed.(*plan.InsertInto) 322 if !ok { 323 return nil, fmt.Errorf("import setup expected *plan.InsertInto root, found %T", parsed) 324 } 325 schema := make(sql.Schema, len(s.rowOperationSchema.Schema)) 326 for i, c := range s.rowOperationSchema.Schema { 327 newC := c.Copy() 328 newC.Source = planbuilder.OnDupValuesPrefix 329 schema[i] = newC 330 } 331 332 switch n := parsedIns.Source.(type) { 333 case *plan.Values: 334 parsedIns.Source = NewChannelRowSource(schema, inputChannel) 335 case *plan.Project: 336 n.Child = NewChannelRowSource(schema, inputChannel) 337 } 338 339 parsedIns.Ignore = s.contOnErr 340 parsedIns.IsReplace = replace 341 analyzed, err := s.se.Analyze(s.sqlCtx, parsedIns) 342 if err != nil { 343 return nil, err 344 } 345 346 analyzed = analyzer.StripPassthroughNodes(analyzed) 347 348 // Get the first insert (wrapped with the error handler) 349 transform.Inspect(analyzed, func(node sql.Node) bool { 350 switch n := node.(type) { 351 case *plan.InsertInto: 352 analyzed = n 353 return false 354 default: 355 return true 356 } 357 }) 358 359 return analyzed, nil 360 }