// Copyright 2018 The Cockroach Authors.
//
// Licensed as a CockroachDB Enterprise file under the Cockroach Community
// License (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
//     https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt

package importccl

import (
	"bufio"
	"context"
	"io"
	"regexp"
	"strings"

	"github.com/cockroachdb/cockroach/pkg/keys"
	"github.com/cockroachdb/cockroach/pkg/roachpb"
	"github.com/cockroachdb/cockroach/pkg/sql"
	"github.com/cockroachdb/cockroach/pkg/sql/execinfrapb"
	"github.com/cockroachdb/cockroach/pkg/sql/parser"
	"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode"
	"github.com/cockroachdb/cockroach/pkg/sql/row"
	"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
	"github.com/cockroachdb/cockroach/pkg/sql/sqlbase"
	"github.com/cockroachdb/cockroach/pkg/storage/cloud"
	"github.com/cockroachdb/cockroach/pkg/util/ctxgroup"
	"github.com/cockroachdb/cockroach/pkg/util/errorutil/unimplemented"
	"github.com/cockroachdb/cockroach/pkg/util/hlc"
	"github.com/cockroachdb/cockroach/pkg/util/log"
	"github.com/cockroachdb/errors"
	"github.com/lib/pq/oid"
)

// postgreStream tokenizes a pg_dump text stream into SQL statements and,
// when inside a COPY ... FROM STDIN block, individual copy-data rows.
type postgreStream struct {
	// s scans the underlying reader; its split function (see split) switches
	// between whole-SQL-statement splitting and line splitting depending on
	// whether a COPY data block is currently being read.
	s *bufio.Scanner
	// copy is non-nil while the stream is positioned inside the data section
	// of a COPY ... FROM STDIN statement; it is cleared again when the copy
	// terminator is reached (see Next).
	copy *postgreStreamCopy
}

// newPostgreStream returns a struct that can stream statements from an
// io.Reader.
func newPostgreStream(r io.Reader, max int) *postgreStream {
	s := bufio.NewScanner(r)
	// max caps the scanner's buffer: any single statement or copy-data line
	// longer than max surfaces as bufio.ErrTooLong (rewrapped in Next).
	s.Buffer(nil, max)
	p := &postgreStream{s: s}
	s.Split(p.split)
	return p
}

// split is the bufio.SplitFunc installed on p.s. Outside a COPY data block it
// splits the input into whole SQL statements; inside one (p.copy != nil) it
// splits by lines, since copy data is line-oriented.
func (p *postgreStream) split(data []byte, atEOF bool) (advance int, token []byte, err error) {
	if p.copy == nil {
		return splitSQLSemicolon(data, atEOF)
	}
	return bufio.ScanLines(data, atEOF)
}

// splitSQLSemicolon is a bufio.SplitFunc that splits on SQL semicolon tokens.
func splitSQLSemicolon(data []byte, atEOF bool) (advance int, token []byte, err error) {
	if atEOF && len(data) == 0 {
		return 0, nil, nil
	}

	// SplitFirstStatement reports the end offset of the first complete
	// statement in the buffer, if one is present.
	if pos, ok := parser.SplitFirstStatement(string(data)); ok {
		return pos, data[:pos], nil
	}
	// If we're at EOF, we have a final, non-terminated line. Return it.
	if atEOF {
		return len(data), data, nil
	}
	// Request more data.
	return 0, nil, nil
}

// Next returns the next statement. The type of statement can be one of
// tree.Statement, copyData, or errCopyDone. A nil statement and io.EOF are
// returned when there are no more statements.
func (p *postgreStream) Next() (interface{}, error) {
	// While inside a COPY data block, yield rows (or the errCopyDone marker)
	// instead of parsed statements.
	if p.copy != nil {
		row, err := p.copy.Next()
		if errors.Is(err, errCopyDone) {
			// End of the copy data; revert to statement scanning.
			p.copy = nil
			return errCopyDone, nil
		}
		return row, err
	}

	for p.s.Scan() {
		t := p.s.Text()
		stmts, err := parser.Parse(t)
		if err != nil {
			// Something non-parseable may be something we don't yet parse but still
			// want to ignore.
			if isIgnoredStatement(t) {
				continue
			}
			return nil, err
		}
		switch len(stmts) {
		case 0:
			// Got whitespace or comments; try again.
		case 1:
			// If the statement is COPY ... FROM STDIN, set p.copy so the next call to
			// this function will read copy data. We still return this COPY statement
			// for this invocation.
			if cf, ok := stmts[0].AST.(*tree.CopyFrom); ok && cf.Stdin {
				// Set p.copy which reconfigures the scanner's split func.
				p.copy = newPostgreStreamCopy(p.s, copyDefaultDelimiter, copyDefaultNull)

				// We expect a single newline character following the COPY statement before
				// the copy data starts.
				if !p.s.Scan() {
					return nil, errors.Errorf("expected empty line")
				}
				if err := p.s.Err(); err != nil {
					return nil, err
				}
				if len(p.s.Bytes()) != 0 {
					return nil, errors.Errorf("expected empty line")
				}
			}
			return stmts[0].AST, nil
		default:
			return nil, errors.Errorf("unexpected: got %d statements", len(stmts))
		}
	}
	if err := p.s.Err(); err != nil {
		// A too-long token means the input exceeded the buffer cap passed to
		// newPostgreStream; give the user a friendlier message.
		if errors.Is(err, bufio.ErrTooLong) {
			err = errors.HandledWithMessage(err, "line too long")
		}
		return nil, err
	}
	return nil, io.EOF
}

var (
	// ignoreComments matches a leading SQL line comment (with optional
	// preceding whitespace).
	ignoreComments = regexp.MustCompile(`^\s*(--.*)`)
	// ignoreStatements lists statement prefixes that we cannot parse but are
	// known to be safe to skip during IMPORT.
	ignoreStatements = []*regexp.Regexp{
		regexp.MustCompile("(?i)^alter function"),
		regexp.MustCompile("(?i)^alter sequence .* owned by"),
		regexp.MustCompile("(?i)^alter table .* owner to"),
		regexp.MustCompile("(?i)^comment on"),
		regexp.MustCompile("(?i)^create extension"),
		regexp.MustCompile("(?i)^create function"),
		regexp.MustCompile("(?i)^create trigger"),
		regexp.MustCompile("(?i)^grant .* on sequence"),
		regexp.MustCompile("(?i)^revoke .* on sequence"),
	}
)

// isIgnoredStatement reports whether s is a statement we do not parse but
// want to silently skip (see the use in postgreStream.Next).
func isIgnoredStatement(s string) bool {
	// Look for the first line with no whitespace or comments.
	for {
		m := ignoreComments.FindStringIndex(s)
		if m == nil {
			break
		}
		s = s[m[1]:]
	}
	s = strings.TrimSpace(s)
	for _, re := range ignoreStatements {
		if re.MatchString(s) {
			return true
		}
	}
	return false
}

// regclassRewriter is a tree.Visitor that strips `::regclass` casts from the
// first argument of nextval() calls (see removeDefaultRegclass).
type regclassRewriter struct{}

var _ tree.Visitor = regclassRewriter{}

// VisitPre implements tree.Visitor.
func (regclassRewriter) VisitPre(expr tree.Expr) (recurse bool, newExpr tree.Expr) {
	switch t := expr.(type) {
	case *tree.FuncExpr:
		switch t.Func.String() {
		case "nextval":
			if len(t.Exprs) > 0 {
				switch e := t.Exprs[0].(type) {
				case *tree.CastExpr:
					if typ, ok := tree.GetStaticallyKnownType(e.Type); ok && typ.Oid() == oid.T_regclass {
						// tree.Visitor says we should make a copy, but since copyNode is unexported
						// and there's no planner here, I think it's safe to directly modify the
						// statement here.
						t.Exprs[0] = e.Expr
					}
				}
			}
		}
	}
	return true, expr
}

// VisitPost implements tree.Visitor; it is a no-op.
func (regclassRewriter) VisitPost(expr tree.Expr) tree.Expr { return expr }

// removeDefaultRegclass removes `::regclass` casts from sequence operations
// (i.e., nextval) in DEFAULT column expressions.
func removeDefaultRegclass(create *tree.CreateTable) {
	for _, def := range create.Defs {
		switch def := def.(type) {
		case *tree.ColumnTableDef:
			if def.DefaultExpr.Expr != nil {
				def.DefaultExpr.Expr, _ = tree.WalkExpr(regclassRewriter{}, def.DefaultExpr.Expr)
			}
		}
	}
}

// readPostgresCreateTable returns table descriptors for all tables or the
// matching table from SQL statements. If match is non-empty, only the table
// with that name is built (and it is an error if it is not found); otherwise
// descriptors for all tables and sequences are returned. max is the scanner
// buffer cap passed through to newPostgreStream.
func readPostgresCreateTable(
	ctx context.Context,
	input io.Reader,
	evalCtx *tree.EvalContext,
	p sql.PlanHookState,
	match string,
	parentID sqlbase.ID,
	walltime int64,
	fks fkHandler,
	max int,
) ([]*sqlbase.TableDescriptor, error) {
	// Modify the CreateTable stmt with the various index additions. We do this
	// instead of creating a full table descriptor first and adding indexes
	// later because MakeSimpleTableDescriptor calls the sql package which calls
	// AllocateIDs which adds the hidden rowid and default primary key. This means
	// we'd have to delete the index and row and modify the column family. This
	// is much easier and probably safer too.
	createTbl := make(map[string]*tree.CreateTable)
	createSeq := make(map[string]*tree.CreateSequence)
	tableFKs := make(map[string][]*tree.ForeignKeyConstraintTableDef)
	ps := newPostgreStream(input, max)
	params := p.RunParams(ctx)
	for {
		stmt, err := ps.Next()
		if err == io.EOF {
			// End of input: materialize descriptors from the accumulated
			// CREATE SEQUENCE / CREATE TABLE statements, then resolve FKs.
			ret := make([]*sqlbase.TableDescriptor, 0, len(createTbl))
			for name, seq := range createSeq {
				// IDs are assigned sequentially starting at defaultCSVTableID.
				id := sqlbase.ID(int(defaultCSVTableID) + len(ret))
				desc, err := sql.MakeSequenceTableDesc(
					name,
					seq.Options,
					parentID,
					keys.PublicSchemaID,
					id,
					hlc.Timestamp{WallTime: walltime},
					sqlbase.NewDefaultPrivilegeDescriptor(),
					false, /* temporary */
					&params,
				)
				if err != nil {
					return nil, err
				}
				fks.resolver[desc.Name] = &desc
				ret = append(ret, desc.TableDesc())
			}
			backrefs := make(map[sqlbase.ID]*sqlbase.MutableTableDescriptor)
			for _, create := range createTbl {
				// nil entries mark tables that were seen but filtered out by
				// `match`; skip them.
				if create == nil {
					continue
				}
				removeDefaultRegclass(create)
				id := sqlbase.ID(int(defaultCSVTableID) + len(ret))
				desc, err := MakeSimpleTableDescriptor(evalCtx.Ctx(), p.ExecCfg().Settings, create, parentID, id, fks, walltime)
				if err != nil {
					return nil, err
				}
				fks.resolver[desc.Name] = desc
				backrefs[desc.ID] = desc
				ret = append(ret, desc.TableDesc())
			}
			// Resolve FK constraints collected from ALTER TABLE statements.
			for name, constraints := range tableFKs {
				desc := fks.resolver[name]
				if desc == nil {
					continue
				}
				for _, constraint := range constraints {
					if err := sql.ResolveFK(
						evalCtx.Ctx(), nil /* txn */, fks.resolver, desc, constraint, backrefs, sql.NewTable, tree.ValidationDefault, evalCtx,
					); err != nil {
						return nil, err
					}
				}
				if err := fixDescriptorFKState(desc.TableDesc()); err != nil {
					return nil, err
				}
			}
			if match != "" && len(ret) != 1 {
				found := make([]string, 0, len(createTbl))
				for name := range createTbl {
					found = append(found, name)
				}
				return nil, errors.Errorf("table %q not found in file (found tables: %s)", match, strings.Join(found, ", "))
			}
			if len(ret) == 0 {
				return nil, errors.Errorf("no table definition found")
			}
			return ret, nil
		}
		if err != nil {
			return nil, errors.Wrap(err, "postgres parse error")
		}
		switch stmt := stmt.(type) {
		case *tree.CreateTable:
			name, err := getTableName(&stmt.Table)
			if err != nil {
				return nil, err
			}
			if match != "" && match != name {
				// Record the name (for the "not found" error message) but
				// don't build a descriptor for it.
				createTbl[name] = nil
			} else {
				createTbl[name] = stmt
			}
		case *tree.CreateIndex:
			name, err := getTableName(&stmt.Table)
			if err != nil {
				return nil, err
			}
			create := createTbl[name]
			if create == nil {
				break
			}
			// Fold the CREATE INDEX into the table's CreateTable defs (see
			// the comment at the top of the function).
			var idx tree.TableDef = &tree.IndexTableDef{
				Name:        stmt.Name,
				Columns:     stmt.Columns,
				Storing:     stmt.Storing,
				Inverted:    stmt.Inverted,
				Interleave:  stmt.Interleave,
				PartitionBy: stmt.PartitionBy,
			}
			if stmt.Unique {
				idx = &tree.UniqueConstraintTableDef{IndexTableDef: *idx.(*tree.IndexTableDef)}
			}
			create.Defs = append(create.Defs, idx)
		case *tree.AlterTable:
			name, err := getTableName2(stmt.Table)
			if err != nil {
				return nil, err
			}
			create := createTbl[name]
			if create == nil {
				break
			}
			for _, cmd := range stmt.Cmds {
				switch cmd := cmd.(type) {
				case *tree.AlterTableAddConstraint:
					switch con := cmd.ConstraintDef.(type) {
					case *tree.ForeignKeyConstraintTableDef:
						// FKs are deferred until all tables are built so the
						// referenced table is guaranteed to exist.
						if !fks.skip {
							tableFKs[name] = append(tableFKs[name], con)
						}
					default:
						create.Defs = append(create.Defs, cmd.ConstraintDef)
					}
				case *tree.AlterTableSetDefault:
					// Rewrite the matching column's DEFAULT in place.
					for i, def := range create.Defs {
						def, ok := def.(*tree.ColumnTableDef)
						if !ok || def.Name != cmd.Column {
							continue
						}
						def.DefaultExpr.Expr = cmd.Default
						create.Defs[i] = def
					}
				case *tree.AlterTableValidateConstraint:
					// ignore
				default:
					return nil, errors.Errorf("unsupported statement: %s", stmt)
				}
			}
		case *tree.CreateSequence:
			name, err := getTableName(&stmt.Name)
			if err != nil {
				return nil, err
			}
			if match == "" || match == name {
				createSeq[name] = stmt
			}
		}
	}
}

// getTableName returns the bare table name from tn, rejecting any schema
// other than "public" (non-public schemas are unsupported; issue #26443).
func getTableName(tn *tree.TableName) (string, error) {
	if sc := tn.Schema(); sc != "" && sc != "public" {
		return "", unimplemented.NewWithIssueDetailf(
			26443,
			"import non-public schema",
			"non-public schemas unsupported: %s", sc,
		)
	}
	return tn.Table(), nil
}

// getTableName variant for UnresolvedObjectName.
func getTableName2(u *tree.UnresolvedObjectName) (string, error) {
	if u.NumParts >= 2 && u.Parts[1] != "public" {
		return "", unimplemented.NewWithIssueDetailf(
			26443,
			"import non-public schema",
			"non-public schemas unsupported: %s", u.Parts[1],
		)
	}
	return u.Parts[0], nil
}

// pgDumpReader is an inputConverter that ingests pg_dump files.
type pgDumpReader struct {
	// tables maps table name to its row converter; only entries whose
	// descriptor is a table (not a sequence) get a converter.
	tables map[string]*row.DatumRowConverter
	// descs maps name to the import-table spec for every requested object,
	// including sequences (used by the setval handling in readFile).
	descs map[string]*execinfrapb.ReadImportDataSpec_ImportTable
	// kvCh receives converted KV batches.
	kvCh chan row.KVBatch
	opts roachpb.PgDumpOptions
}

var _ inputConverter = &pgDumpReader{}

// newPgDumpReader creates a new inputConverter for pg_dump files.
func newPgDumpReader(
	ctx context.Context,
	kvCh chan row.KVBatch,
	opts roachpb.PgDumpOptions,
	descs map[string]*execinfrapb.ReadImportDataSpec_ImportTable,
	evalCtx *tree.EvalContext,
) (*pgDumpReader, error) {
	converters := make(map[string]*row.DatumRowConverter, len(descs))
	for name, table := range descs {
		// Only tables get row converters; sequences (and any other
		// non-table descriptors) are handled separately in readFile.
		if table.Desc.IsTable() {
			conv, err := row.NewDatumRowConverter(ctx, table.Desc, nil /* targetColNames */, evalCtx, kvCh)
			if err != nil {
				return nil, err
			}
			converters[name] = conv
		}
	}
	return &pgDumpReader{
		kvCh:   kvCh,
		tables: converters,
		descs:  descs,
		opts:   opts,
	}, nil
}

// start is a no-op for pg_dump input.
func (m *pgDumpReader) start(ctx ctxgroup.Group) {
}

// readFiles delegates per-file processing to readFile via readInputFiles.
func (m *pgDumpReader) readFiles(
	ctx context.Context,
	dataFiles map[int32]string,
	resumePos map[int32]int64,
	format roachpb.IOFileFormat,
	makeExternalStorage cloud.ExternalStorageFactory,
) error {
	return readInputFiles(ctx, dataFiles, resumePos, format, m.readFile, makeExternalStorage)
}

// readFile streams one pg_dump file, converting data statements into KVs:
// INSERT rows are type-checked and evaluated, COPY ... FROM STDIN data is
// parsed row by row, and SELECT setval(...) calls update sequence values.
// All other statements are ignored. Rows numbered <= resumePos are skipped
// (resume after a checkpoint).
func (m *pgDumpReader) readFile(
	ctx context.Context, input *fileReader, inputIdx int32, resumePos int64, rejected chan string,
) error {
	// count is the 1-based running row number across the whole file; inserts
	// counts INSERT statements (both are used in error messages).
	var inserts, count int64
	ps := newPostgreStream(input, int(m.opts.MaxRowSize))
	semaCtx := tree.MakeSemaContext()
	for _, conv := range m.tables {
		conv.KvBatch.Source = inputIdx
		conv.FractionFn = input.ReadFraction
		conv.CompletedRowFn = func() int64 {
			return count
		}
	}

	for {
		stmt, err := ps.Next()
		if err == io.EOF {
			break
		}
		if err != nil {
			return errors.Wrap(err, "postgres parse error")
		}
		switch i := stmt.(type) {
		case *tree.Insert:
			n, ok := i.Table.(*tree.TableName)
			if !ok {
				return errors.Errorf("unexpected: %T", i.Table)
			}
			name, err := getTableName(n)
			if err != nil {
				return errors.Wrapf(err, "%s", i)
			}
			conv, ok := m.tables[name]
			if !ok {
				// not importing this table.
				continue
			}
			if ok && conv == nil {
				return errors.Errorf("missing schema info for requested table %q", name)
			}
			// Only literal VALUES-based inserts are supported.
			values, ok := i.Rows.Select.(*tree.ValuesClause)
			if !ok {
				return errors.Errorf("unsupported: %s", i.Rows.Select)
			}
			inserts++
			startingCount := count
			for _, tuple := range values.Rows {
				count++
				if count <= resumePos {
					continue
				}
				if expected, got := len(conv.VisibleCols), len(tuple); expected != got {
					return errors.Errorf("expected %d values, got %d: %v", expected, got, tuple)
				}
				// Type-check and evaluate each expression against the
				// corresponding visible column's type.
				for i, expr := range tuple {
					typed, err := expr.TypeCheck(ctx, &semaCtx, conv.VisibleColTypes[i])
					if err != nil {
						return errors.Wrapf(err, "reading row %d (%d in insert statement %d)",
							count, count-startingCount, inserts)
					}
					converted, err := typed.Eval(conv.EvalCtx)
					if err != nil {
						return errors.Wrapf(err, "reading row %d (%d in insert statement %d)",
							count, count-startingCount, inserts)
					}
					conv.Datums[i] = converted
				}
				if err := conv.Row(ctx, inputIdx, count); err != nil {
					return err
				}
			}
		case *tree.CopyFrom:
			if !i.Stdin {
				return errors.New("expected STDIN option on COPY FROM")
			}
			name, err := getTableName(&i.Table)
			if err != nil {
				return errors.Wrapf(err, "%s", i)
			}
			// A missing converter with importing==true means the descriptor
			// exists but is not a table — that's an error; a table that isn't
			// being imported at all just has its data consumed and dropped.
			conv, importing := m.tables[name]
			if importing && conv == nil {
				return errors.Errorf("missing schema info for requested table %q", name)
			}
			if conv != nil {
				if expected, got := len(conv.VisibleCols), len(i.Columns); expected != got {
					return errors.Errorf("expected %d columns, got %d", expected, got)
				}
				for colI, col := range i.Columns {
					if string(col) != conv.VisibleCols[colI].Name {
						return errors.Errorf("COPY columns do not match table columns for table %s", name)
					}
				}
			}
			for {
				row, err := ps.Next()
				// We expect an explicit copyDone here. io.EOF is unexpected.
				if err == io.EOF {
					return makeRowErr("", count, pgcode.ProtocolViolation,
						"unexpected EOF")
				}
				if row == errCopyDone {
					break
				}
				// count is bumped before the error check so wrapRowErr
				// reports the row number being read.
				count++
				if err != nil {
					return wrapRowErr(err, "", count, pgcode.Uncategorized, "")
				}
				if !importing {
					continue
				}
				if count <= resumePos {
					continue
				}
				switch row := row.(type) {
				case copyData:
					if expected, got := len(conv.VisibleCols), len(row); expected != got {
						return makeRowErr("", count, pgcode.Syntax,
							"expected %d values, got %d", expected, got)
					}
					for i, s := range row {
						// nil marks a NULL field in the copy data.
						if s == nil {
							conv.Datums[i] = tree.DNull
						} else {
							conv.Datums[i], err = sqlbase.ParseDatumStringAs(conv.VisibleColTypes[i], *s, conv.EvalCtx)
							if err != nil {
								col := conv.VisibleCols[i]
								return wrapRowErr(err, "", count, pgcode.Syntax,
									"parse %q as %s", col.Name, col.Type.SQLString())
							}
						}
					}
					if err := conv.Row(ctx, inputIdx, count); err != nil {
						return err
					}
				default:
					return makeRowErr("", count, pgcode.Uncategorized,
						"unexpected: %v", row)
				}
			}
		case *tree.Select:
			// Look for something of the form "SELECT pg_catalog.setval(...)". Any error
			// or unexpected value silently breaks out of this branch. We are silent
			// instead of returning an error because we expect input to be well-formatted
			// by pg_dump, and thus if it isn't, we don't try to figure out what to do.
			sc, ok := i.Select.(*tree.SelectClause)
			if !ok {
				break
			}
			if len(sc.Exprs) != 1 {
				break
			}
			fn, ok := sc.Exprs[0].Expr.(*tree.FuncExpr)
			if !ok || len(fn.Exprs) < 2 {
				break
			}
			if name := strings.ToLower(fn.Func.String()); name != "setval" && name != "pg_catalog.setval" {
				break
			}
			seqname, ok := fn.Exprs[0].(*tree.StrVal)
			if !ok {
				break
			}
			seqval, ok := fn.Exprs[1].(*tree.NumVal)
			if !ok {
				break
			}
			val, err := seqval.AsInt64()
			if err != nil {
				break
			}
			// Optional third argument is setval's is_called flag.
			isCalled := false
			if len(fn.Exprs) > 2 {
				called, ok := fn.Exprs[2].(*tree.DBool)
				if !ok {
					break
				}
				isCalled = bool(*called)
			}
			name, err := parser.ParseTableName(seqname.RawString())
			if err != nil {
				break
			}
			seq := m.descs[name.Parts[0]]
			if seq == nil {
				break
			}
			// Emit the sequence-value KV directly (sequences have no row
			// converter).
			key, val, err := sql.MakeSequenceKeyVal(keys.TODOSQLCodec, seq.Desc, val, isCalled)
			if err != nil {
				return wrapRowErr(err, "", count, pgcode.Uncategorized, "")
			}
			kv := roachpb.KeyValue{Key: key}
			kv.Value.SetInt(val)
			m.kvCh <- row.KVBatch{
				Source: inputIdx, KVs: []roachpb.KeyValue{kv}, Progress: input.ReadFraction(),
			}
		default:
			if log.V(3) {
				log.Infof(ctx, "ignoring %T stmt: %v", i, i)
			}
			continue
		}
	}
	// Flush any partially-filled batches before returning.
	for _, conv := range m.tables {
		if err := conv.SendBatch(ctx); err != nil {
			return err
		}
	}
	return nil
}