github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/ccl/importccl/pg_testdata_helpers_test.go (about)

     1  // Copyright 2018 The Cockroach Authors.
     2  //
     3  // Licensed as a CockroachDB Enterprise file under the Cockroach Community
     4  // License (the "License"); you may not use this file except in compliance with
     5  // the License. You may obtain a copy of the License at
     6  //
     7  //     https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt
     8  
     9  package importccl
    10  
    11  import (
    12  	gosql "database/sql"
    13  	"fmt"
    14  	"io/ioutil"
    15  	"math/rand"
    16  	"os"
    17  	"os/exec"
    18  	"path/filepath"
    19  	"strconv"
    20  	"strings"
    21  	"testing"
    22  	"unicode/utf8"
    23  
    24  	"github.com/cockroachdb/cockroach/pkg/roachpb"
    25  	"github.com/cockroachdb/cockroach/pkg/testutils/sqlutils"
    26  	"github.com/cockroachdb/cockroach/pkg/util/envutil"
    27  	_ "github.com/lib/pq"
    28  )
    29  
    30  var rewritePostgresTestData = envutil.EnvOrDefaultBool("COCKROACH_REWRITE_POSTGRES_TESTDATA", false)
    31  
    32  var simplePostgresTestRows = func() []simpleTestRow {
    33  	badChars := []rune{'a', ';', '\n', ',', '"', '\\', '\r', '<', '\t', '✅', 'π', rune(10), rune(2425), rune(5183), utf8.RuneError}
    34  	r := rand.New(rand.NewSource(1))
    35  	testRows := []simpleTestRow{
    36  		{i: 0, s: `str`},
    37  		{i: 1, s: ``},
    38  		{i: 2, s: ` `},
    39  		{i: 3, s: `,`},
    40  		{i: 4, s: "\n"},
    41  		{i: 5, s: `\n`},
    42  		{i: 6, s: "\r\n"},
    43  		{i: 7, s: "\r"},
    44  		{i: 9, s: `"`},
    45  
    46  		{i: 10, s: injectNull},
    47  		{i: 11, s: `\N`},
    48  		{i: 12, s: `NULL`},
    49  
    50  		// Unicode
    51  		{i: 13, s: `¢`},
    52  		{i: 14, s: ` ¢ `},
    53  		{i: 15, s: `✅`},
    54  		{i: 16, s: `","\n,™¢`},
    55  		{i: 19, s: `✅¢©ƒƒƒƒåß∂√œ∫∑∆πœ∑˚¬≤µµç∫ø∆œ∑∆¬œ∫œ∑´´†¥¨ˆˆπ‘“æ…¬…¬˚ß∆å˚˙ƒ∆©˙©∂˙≥≤Ω˜˜µ√∫∫Ω¥∑`},
    56  		{i: 20, s: `a quote " or two quotes "" and a quote-comma ", , and then a quote and newline "` + "\n"},
    57  		{i: 21, s: `"a slash \, a double slash \\, a slash+quote \",  \` + "\n"},
    58  	}
    59  
    60  	for i := 0; i < 10; i++ {
    61  		buf := make([]byte, 200)
    62  		r.Seed(int64(i))
    63  		r.Read(buf)
    64  		testRows = append(testRows, simpleTestRow{i: i + 100, s: randStr(r, badChars, 1000), b: buf})
    65  	}
    66  	return testRows
    67  }()
    68  
    69  func getSimplePostgresDumpTestdata(t *testing.T) ([]simpleTestRow, string) {
    70  	dest := filepath.Join(`testdata`, `pgdump`, `simple.sql`)
    71  	if rewritePostgresTestData {
    72  		genSimplePostgresTestdata(t, func() { pgdump(t, dest, "simple") })
    73  	}
    74  	return simplePostgresTestRows, dest
    75  }
    76  
    77  func getSecondPostgresDumpTestdata(t *testing.T) (int, string) {
    78  	dest := filepath.Join(`testdata`, `pgdump`, `second.sql`)
    79  	if rewritePostgresTestData {
    80  		genSecondPostgresTestdata(t, func() { pgdump(t, dest, "second") })
    81  	}
    82  	return secondTableRows, dest
    83  }
    84  
    85  func getMultiTablePostgresDumpTestdata(t *testing.T) string {
    86  	dest := filepath.Join(`testdata`, `pgdump`, `db.sql`)
    87  	if rewritePostgresTestData {
    88  		genSequencePostgresTestdata(t, func() {
    89  			genSecondPostgresTestdata(t, func() {
    90  				genSimplePostgresTestdata(t, func() { pgdump(t, dest) })
    91  			})
    92  		})
    93  	}
    94  	return dest
    95  }
    96  
    97  type pgCopyDumpCfg struct {
    98  	name     string
    99  	filename string
   100  	opts     roachpb.PgCopyOptions
   101  }
   102  
   103  func getPgCopyTestdata(t *testing.T) ([]simpleTestRow, []pgCopyDumpCfg) {
   104  	configs := []pgCopyDumpCfg{
   105  		{
   106  			name: "default",
   107  			opts: roachpb.PgCopyOptions{
   108  				Delimiter: '\t',
   109  				Null:      `\N`,
   110  			},
   111  		},
   112  		{
   113  			name: "comma-null-header",
   114  			opts: roachpb.PgCopyOptions{
   115  				Delimiter: ',',
   116  				Null:      "null",
   117  			},
   118  		},
   119  	}
   120  
   121  	for i := range configs {
   122  		configs[i].filename = filepath.Join(`testdata`, `pgcopy`, configs[i].name, `test.txt`)
   123  	}
   124  
   125  	if rewritePostgresTestData {
   126  		genSimplePostgresTestdata(t, func() {
   127  			if err := os.RemoveAll(filepath.Join(`testdata`, `pgcopy`)); err != nil {
   128  				t.Fatal(err)
   129  			}
   130  			for _, cfg := range configs {
   131  				dest := filepath.Dir(cfg.filename)
   132  				if err := os.MkdirAll(dest, 0777); err != nil {
   133  					t.Fatal(err)
   134  				}
   135  
   136  				var sb strings.Builder
   137  				sb.WriteString(`COPY simple TO STDOUT WITH (FORMAT 'text'`)
   138  				if cfg.opts.Delimiter != copyDefaultDelimiter {
   139  					fmt.Fprintf(&sb, `, DELIMITER %q`, cfg.opts.Delimiter)
   140  				}
   141  				if cfg.opts.Null != copyDefaultNull {
   142  					fmt.Fprintf(&sb, `, NULL "%s"`, cfg.opts.Null)
   143  				}
   144  				sb.WriteString(`)`)
   145  				flags := []string{`-U`, `postgres`, `-h`, `127.0.0.1`, `test`, `-c`, sb.String()}
   146  				if res, err := exec.Command(
   147  					`psql`, flags...,
   148  				).CombinedOutput(); err != nil {
   149  					t.Fatal(err, string(res))
   150  				} else if err := ioutil.WriteFile(cfg.filename, res, 0666); err != nil {
   151  					t.Fatal(err)
   152  				}
   153  			}
   154  		})
   155  	}
   156  
   157  	return simplePostgresTestRows, configs
   158  }
   159  
   160  func genSimplePostgresTestdata(t *testing.T, dump func()) {
   161  	defer genPostgresTestdata(t,
   162  		"simple",
   163  		`i INT PRIMARY KEY, s text, b bytea`,
   164  		func(db *gosql.DB) {
   165  			// Postgres doesn't support creating non-unique indexes in CREATE TABLE;
   166  			// do it afterward.
   167  			if _, err := db.Exec(`
   168  				CREATE UNIQUE INDEX ON simple (b, s);
   169  				CREATE INDEX ON simple (s);
   170  			`); err != nil {
   171  				t.Fatal(err)
   172  			}
   173  			for _, tc := range simplePostgresTestRows {
   174  				s := &tc.s
   175  				if *s == injectNull {
   176  					s = nil
   177  				}
   178  				if _, err := db.Exec(
   179  					`INSERT INTO simple VALUES ($1, $2, NULLIF($3, ''::bytea))`, tc.i, s, tc.b,
   180  				); err != nil {
   181  					t.Fatal(err)
   182  				}
   183  			}
   184  		},
   185  	)()
   186  	dump()
   187  }
   188  
   189  func genSecondPostgresTestdata(t *testing.T, dump func()) {
   190  	defer genPostgresTestdata(t,
   191  		"second",
   192  		`i INT PRIMARY KEY, s TEXT`,
   193  		func(db *gosql.DB) {
   194  			for i := 0; i < secondTableRows; i++ {
   195  				if _, err := db.Exec(`INSERT INTO second VALUES ($1, $2)`, i, strconv.Itoa(i)); err != nil {
   196  					t.Fatal(err)
   197  				}
   198  			}
   199  		},
   200  	)()
   201  	dump()
   202  }
   203  
   204  func genSequencePostgresTestdata(t *testing.T, dump func()) {
   205  	defer genPostgresTestdata(t,
   206  		"seqtable",
   207  		`a INT, b INT`,
   208  		func(sqlDB *gosql.DB) {
   209  			db := sqlutils.MakeSQLRunner(sqlDB)
   210  			db.Exec(t, `DROP SEQUENCE IF EXISTS a_seq`)
   211  			db.Exec(t, `CREATE SEQUENCE a_seq`)
   212  			db.Exec(t, `ALTER TABLE seqtable ALTER COLUMN a SET DEFAULT nextval('a_seq'::REGCLASS)`)
   213  			for i := 0; i < secondTableRows; i++ {
   214  				db.Exec(t, `INSERT INTO seqtable (b) VALUES ($1 * 10)`, i)
   215  			}
   216  		},
   217  	)()
   218  	dump()
   219  }
   220  
   221  // genPostgresTestdata connects to the a local postgres, creates the passed
   222  // table and calls the passed `load` func to populate it and returns a
   223  // cleanup func.
   224  func genPostgresTestdata(t *testing.T, name, schema string, load func(*gosql.DB)) func() {
   225  	db, err := gosql.Open("postgres", "postgres://postgres@localhost/test?sslmode=disable")
   226  	if err != nil {
   227  		t.Fatal(err)
   228  	}
   229  	defer db.Close()
   230  
   231  	if _, err := db.Exec(
   232  		fmt.Sprintf(`DROP TABLE IF EXISTS %s`, name),
   233  	); err != nil {
   234  		t.Fatal(err)
   235  	}
   236  	if _, err := db.Exec(
   237  		fmt.Sprintf(`CREATE TABLE %s (%s)`, name, schema),
   238  	); err != nil {
   239  		t.Fatal(err)
   240  	}
   241  	load(db)
   242  	return func() {
   243  		db, err := gosql.Open("postgres", "postgres://postgres@localhost/test?sslmode=disable")
   244  		if err != nil {
   245  			t.Fatal(err)
   246  		}
   247  		defer db.Close()
   248  		if _, err := db.Exec(
   249  			fmt.Sprintf(`DROP TABLE IF EXISTS %s`, name),
   250  		); err != nil {
   251  			t.Fatal(err)
   252  		}
   253  	}
   254  }
   255  
   256  func pgdump(t *testing.T, dest string, tables ...string) {
   257  	if err := os.MkdirAll(filepath.Dir(dest), 0777); err != nil {
   258  		t.Fatal(err)
   259  	}
   260  
   261  	args := []string{`-U`, `postgres`, `-h`, `127.0.0.1`, `-d`, `test`}
   262  	for _, table := range tables {
   263  		args = append(args, `-t`, table)
   264  	}
   265  	out, err := exec.Command(`pg_dump`, args...).CombinedOutput()
   266  	if err != nil {
   267  		t.Fatalf("%s: %s", err, out)
   268  	}
   269  	if err := ioutil.WriteFile(dest, out, 0666); err != nil {
   270  		t.Fatal(err)
   271  	}
   272  }