github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/cmd/cr2pg/main.go (about)

     1  // Copyright 2019 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  // cr2pg is a program that reads CockroachDB-formatted SQL files on stdin,
    12  // modifies them to be Postgres compatible, and outputs them to stdout.
    13  package main
    14  
    15  import (
    16  	"bufio"
    17  	"context"
    18  	"io"
    19  	"log"
    20  	"os"
    21  
    22  	"github.com/cockroachdb/cockroach/pkg/cmd/cr2pg/sqlstream"
    23  	"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
    24  	"github.com/cockroachdb/errors"
    25  	"golang.org/x/sync/errgroup"
    26  )
    27  
    28  func main() {
    29  	ctx := context.Background()
    30  	g, ctx := errgroup.WithContext(ctx)
    31  	done := ctx.Done()
    32  
    33  	readStmts := make(chan tree.Statement, 100)
    34  	writeStmts := make(chan tree.Statement, 100)
    35  	// Divide up work between parsing, filtering, and writing.
    36  	g.Go(func() error {
    37  		defer close(readStmts)
    38  		stream := sqlstream.NewStream(os.Stdin)
    39  		for {
    40  			stmt, err := stream.Next()
    41  			if err == io.EOF {
    42  				break
    43  			}
    44  			if err != nil {
    45  				return err
    46  			}
    47  			select {
    48  			case readStmts <- stmt:
    49  			case <-done:
    50  				return nil
    51  			}
    52  		}
    53  		return nil
    54  	})
    55  	g.Go(func() error {
    56  		defer close(writeStmts)
    57  		newstmts := make([]tree.Statement, 8)
    58  		for stmt := range readStmts {
    59  			newstmts = newstmts[:1]
    60  			newstmts[0] = stmt
    61  			switch stmt := stmt.(type) {
    62  			case *tree.CreateTable:
    63  				stmt.Interleave = nil
    64  				stmt.PartitionBy = nil
    65  				var newdefs tree.TableDefs
    66  				for _, def := range stmt.Defs {
    67  					switch def := def.(type) {
    68  					case *tree.FamilyTableDef:
    69  						// skip
    70  					case *tree.IndexTableDef:
    71  						// Postgres doesn't support
    72  						// indexes in CREATE TABLE,
    73  						// so split them out to their
    74  						// own statement.
    75  						newstmts = append(newstmts, &tree.CreateIndex{
    76  							Name:     def.Name,
    77  							Table:    stmt.Table,
    78  							Inverted: def.Inverted,
    79  							Columns:  def.Columns,
    80  							Storing:  def.Storing,
    81  						})
    82  					case *tree.UniqueConstraintTableDef:
    83  						if def.PrimaryKey {
    84  							// Postgres doesn't support descending PKs.
    85  							for i, col := range def.Columns {
    86  								if col.Direction != tree.Ascending {
    87  									return errors.New("PK directions not supported by postgres")
    88  								}
    89  								def.Columns[i].Direction = tree.DefaultDirection
    90  							}
    91  							// Unset Name here because
    92  							// constaint names cannot
    93  							// be shared among tables,
    94  							// so multiple PK constraints
    95  							// named "primary" is an error.
    96  							def.Name = ""
    97  							newdefs = append(newdefs, def)
    98  							break
    99  						}
   100  						newstmts = append(newstmts, &tree.CreateIndex{
   101  							Name:     def.Name,
   102  							Table:    stmt.Table,
   103  							Unique:   true,
   104  							Inverted: def.Inverted,
   105  							Columns:  def.Columns,
   106  							Storing:  def.Storing,
   107  						})
   108  					default:
   109  						newdefs = append(newdefs, def)
   110  					}
   111  				}
   112  				stmt.Defs = newdefs
   113  			}
   114  			for _, stmt := range newstmts {
   115  				select {
   116  				case writeStmts <- stmt:
   117  				case <-done:
   118  					return nil
   119  				}
   120  			}
   121  		}
   122  		return nil
   123  	})
   124  	g.Go(func() error {
   125  		w := bufio.NewWriterSize(os.Stdout, 1024*1024)
   126  		fmtctx := tree.NewFmtCtx(tree.FmtSimple)
   127  		for stmt := range writeStmts {
   128  			stmt.Format(fmtctx)
   129  			_, _ = w.WriteString(fmtctx.CloseAndGetString())
   130  			_, _ = w.WriteString(";\n\n")
   131  		}
   132  		w.Flush()
   133  		return nil
   134  	})
   135  	if err := g.Wait(); err != nil {
   136  		log.Fatal(err)
   137  	}
   138  }