github.com/olitvin/migrate/v4@v4.14.3-0.20210330111251-992b37ee04c8/database/postgres/postgres_test.go (about)

     1  package postgres
     2  
     3  // error codes https://github.com/lib/pq/blob/master/error.go
     4  
     5  import (
     6  	"context"
     7  	"database/sql"
     8  	sqldriver "database/sql/driver"
     9  	"errors"
    10  	"fmt"
    11  	"github.com/olitvin/migrate/v4"
    12  	"io"
    13  	"log"
    14  	"strconv"
    15  	"strings"
    16  	"sync"
    17  	"testing"
    18  
    19  	"github.com/dhui/dktest"
    20  
    21  	"github.com/olitvin/migrate/v4/database"
    22  	dt "github.com/olitvin/migrate/v4/database/testing"
    23  	"github.com/olitvin/migrate/v4/dktesting"
    24  	_ "github.com/olitvin/migrate/v4/source/file"
    25  )
    26  
    27  const (
    28  	pgPassword = "postgres"
    29  )
    30  
    31  var (
    32  	opts = dktest.Options{
    33  		Env:          map[string]string{"POSTGRES_PASSWORD": pgPassword},
    34  		PortRequired: true, ReadyFunc: isReady}
    35  	// Supported versions: https://www.postgresql.org/support/versioning/
    36  	specs = []dktesting.ContainerSpec{
    37  		{ImageName: "postgres:9.5", Options: opts},
    38  		{ImageName: "postgres:9.6", Options: opts},
    39  		{ImageName: "postgres:10", Options: opts},
    40  		{ImageName: "postgres:11", Options: opts},
    41  		{ImageName: "postgres:12", Options: opts},
    42  	}
    43  )
    44  
    45  func pgConnectionString(host, port string, options ...string) string {
    46  	options = append(options, "sslmode=disable")
    47  	return fmt.Sprintf("postgres://postgres:%s@%s:%s/postgres?%s", pgPassword, host, port, strings.Join(options, "&"))
    48  }
    49  
    50  func isReady(ctx context.Context, c dktest.ContainerInfo) bool {
    51  	ip, port, err := c.FirstPort()
    52  	if err != nil {
    53  		return false
    54  	}
    55  
    56  	db, err := sql.Open("postgres", pgConnectionString(ip, port))
    57  	if err != nil {
    58  		return false
    59  	}
    60  	defer func() {
    61  		if err := db.Close(); err != nil {
    62  			log.Println("close error:", err)
    63  		}
    64  	}()
    65  	if err = db.PingContext(ctx); err != nil {
    66  		switch err {
    67  		case sqldriver.ErrBadConn, io.EOF:
    68  			return false
    69  		default:
    70  			log.Println(err)
    71  		}
    72  		return false
    73  	}
    74  
    75  	return true
    76  }
    77  
    78  func mustRun(t *testing.T, d database.Driver, statements []string) {
    79  	for _, statement := range statements {
    80  		if err := d.Run(strings.NewReader(statement)); err != nil {
    81  			t.Fatal(err)
    82  		}
    83  	}
    84  }
    85  
    86  func Test(t *testing.T) {
    87  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
    88  		ip, port, err := c.FirstPort()
    89  		if err != nil {
    90  			t.Fatal(err)
    91  		}
    92  
    93  		addr := pgConnectionString(ip, port)
    94  		p := &Postgres{}
    95  		d, err := p.Open(addr)
    96  		if err != nil {
    97  			t.Fatal(err)
    98  		}
    99  		defer func() {
   100  			if err := d.Close(); err != nil {
   101  				t.Error(err)
   102  			}
   103  		}()
   104  		dt.Test(t, d, []byte("SELECT 1"))
   105  	})
   106  }
   107  
   108  func TestMigrate(t *testing.T) {
   109  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   110  		ip, port, err := c.FirstPort()
   111  		if err != nil {
   112  			t.Fatal(err)
   113  		}
   114  
   115  		addr := pgConnectionString(ip, port)
   116  		p := &Postgres{}
   117  		d, err := p.Open(addr)
   118  		if err != nil {
   119  			t.Fatal(err)
   120  		}
   121  		defer func() {
   122  			if err := d.Close(); err != nil {
   123  				t.Error(err)
   124  			}
   125  		}()
   126  		m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "postgres", d)
   127  		if err != nil {
   128  			t.Fatal(err)
   129  		}
   130  		dt.TestMigrate(t, m)
   131  	})
   132  }
   133  
   134  func TestMultipleStatements(t *testing.T) {
   135  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   136  		ip, port, err := c.FirstPort()
   137  		if err != nil {
   138  			t.Fatal(err)
   139  		}
   140  
   141  		addr := pgConnectionString(ip, port)
   142  		p := &Postgres{}
   143  		d, err := p.Open(addr)
   144  		if err != nil {
   145  			t.Fatal(err)
   146  		}
   147  		defer func() {
   148  			if err := d.Close(); err != nil {
   149  				t.Error(err)
   150  			}
   151  		}()
   152  		if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLE bar (bar text);")); err != nil {
   153  			t.Fatalf("expected err to be nil, got %v", err)
   154  		}
   155  
   156  		// make sure second table exists
   157  		var exists bool
   158  		if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'bar' AND table_schema = (SELECT current_schema()))").Scan(&exists); err != nil {
   159  			t.Fatal(err)
   160  		}
   161  		if !exists {
   162  			t.Fatalf("expected table bar to exist")
   163  		}
   164  	})
   165  }
   166  
   167  func TestMultipleStatementsInMultiStatementMode(t *testing.T) {
   168  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   169  		ip, port, err := c.FirstPort()
   170  		if err != nil {
   171  			t.Fatal(err)
   172  		}
   173  
   174  		addr := pgConnectionString(ip, port, "x-multi-statement=true")
   175  		p := &Postgres{}
   176  		d, err := p.Open(addr)
   177  		if err != nil {
   178  			t.Fatal(err)
   179  		}
   180  		defer func() {
   181  			if err := d.Close(); err != nil {
   182  				t.Error(err)
   183  			}
   184  		}()
   185  		if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE INDEX CONCURRENTLY idx_foo ON foo (foo);")); err != nil {
   186  			t.Fatalf("expected err to be nil, got %v", err)
   187  		}
   188  
   189  		// make sure created index exists
   190  		var exists bool
   191  		if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM pg_indexes WHERE schemaname = (SELECT current_schema()) AND indexname = 'idx_foo')").Scan(&exists); err != nil {
   192  			t.Fatal(err)
   193  		}
   194  		if !exists {
   195  			t.Fatalf("expected table bar to exist")
   196  		}
   197  	})
   198  }
   199  
   200  func TestErrorParsing(t *testing.T) {
   201  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   202  		ip, port, err := c.FirstPort()
   203  		if err != nil {
   204  			t.Fatal(err)
   205  		}
   206  
   207  		addr := pgConnectionString(ip, port)
   208  		p := &Postgres{}
   209  		d, err := p.Open(addr)
   210  		if err != nil {
   211  			t.Fatal(err)
   212  		}
   213  		defer func() {
   214  			if err := d.Close(); err != nil {
   215  				t.Error(err)
   216  			}
   217  		}()
   218  
   219  		wantErr := `migration failed: syntax error at or near "TABLEE" (column 37) in line 1: CREATE TABLE foo ` +
   220  			`(foo text); CREATE TABLEE bar (bar text); (details: pq: syntax error at or near "TABLEE")`
   221  		if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLEE bar (bar text);")); err == nil {
   222  			t.Fatal("expected err but got nil")
   223  		} else if err.Error() != wantErr {
   224  			t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error())
   225  		}
   226  	})
   227  }
   228  
   229  func TestFilterCustomQuery(t *testing.T) {
   230  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   231  		ip, port, err := c.FirstPort()
   232  		if err != nil {
   233  			t.Fatal(err)
   234  		}
   235  
   236  		addr := fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-custom=foobar",
   237  			pgPassword, ip, port)
   238  		p := &Postgres{}
   239  		d, err := p.Open(addr)
   240  		if err != nil {
   241  			t.Fatal(err)
   242  		}
   243  		defer func() {
   244  			if err := d.Close(); err != nil {
   245  				t.Error(err)
   246  			}
   247  		}()
   248  	})
   249  }
   250  
   251  func TestWithSchema(t *testing.T) {
   252  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   253  		ip, port, err := c.FirstPort()
   254  		if err != nil {
   255  			t.Fatal(err)
   256  		}
   257  
   258  		addr := pgConnectionString(ip, port)
   259  		p := &Postgres{}
   260  		d, err := p.Open(addr)
   261  		if err != nil {
   262  			t.Fatal(err)
   263  		}
   264  		defer func() {
   265  			if err := d.Close(); err != nil {
   266  				t.Fatal(err)
   267  			}
   268  		}()
   269  
   270  		// create foobar schema
   271  		if err := d.Run(strings.NewReader("CREATE SCHEMA foobar AUTHORIZATION postgres")); err != nil {
   272  			t.Fatal(err)
   273  		}
   274  		if err := d.SetVersion(1, false); err != nil {
   275  			t.Fatal(err)
   276  		}
   277  
   278  		// re-connect using that schema
   279  		d2, err := p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&search_path=foobar",
   280  			pgPassword, ip, port))
   281  		if err != nil {
   282  			t.Fatal(err)
   283  		}
   284  		defer func() {
   285  			if err := d2.Close(); err != nil {
   286  				t.Fatal(err)
   287  			}
   288  		}()
   289  
   290  		version, _, err := d2.Version()
   291  		if err != nil {
   292  			t.Fatal(err)
   293  		}
   294  		if version != database.NilVersion {
   295  			t.Fatal("expected NilVersion")
   296  		}
   297  
   298  		// now update version and compare
   299  		if err := d2.SetVersion(2, false); err != nil {
   300  			t.Fatal(err)
   301  		}
   302  		version, _, err = d2.Version()
   303  		if err != nil {
   304  			t.Fatal(err)
   305  		}
   306  		if version != 2 {
   307  			t.Fatal("expected version 2")
   308  		}
   309  
   310  		// meanwhile, the public schema still has the other version
   311  		version, _, err = d.Version()
   312  		if err != nil {
   313  			t.Fatal(err)
   314  		}
   315  		if version != 1 {
   316  			t.Fatal("expected version 2")
   317  		}
   318  	})
   319  }
   320  
   321  func TestFailToCreateTableWithoutPermissions(t *testing.T) {
   322  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   323  		ip, port, err := c.FirstPort()
   324  		if err != nil {
   325  			t.Fatal(err)
   326  		}
   327  
   328  		addr := pgConnectionString(ip, port)
   329  
   330  		// Check that opening the postgres connection returns NilVersion
   331  		p := &Postgres{}
   332  
   333  		d, err := p.Open(addr)
   334  
   335  		if err != nil {
   336  			t.Fatal(err)
   337  		}
   338  
   339  		defer func() {
   340  			if err := d.Close(); err != nil {
   341  				t.Error(err)
   342  			}
   343  		}()
   344  
   345  		// create user who is not the owner. Although we're concatenating strings in an sql statement it should be fine
   346  		// since this is a test environment and we're not expecting to the pgPassword to be malicious
   347  		mustRun(t, d, []string{
   348  			"CREATE USER not_owner WITH ENCRYPTED PASSWORD '" + pgPassword + "'",
   349  			"CREATE SCHEMA barfoo AUTHORIZATION postgres",
   350  			"GRANT USAGE ON SCHEMA barfoo TO not_owner",
   351  			"REVOKE CREATE ON SCHEMA barfoo FROM PUBLIC",
   352  			"REVOKE CREATE ON SCHEMA barfoo FROM not_owner",
   353  		})
   354  
   355  		// re-connect using that schema
   356  		d2, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo",
   357  			pgPassword, ip, port))
   358  
   359  		defer func() {
   360  			if d2 == nil {
   361  				return
   362  			}
   363  			if err := d2.Close(); err != nil {
   364  				t.Fatal(err)
   365  			}
   366  		}()
   367  
   368  		var e *database.Error
   369  		if !errors.As(err, &e) || err == nil {
   370  			t.Fatal("Unexpected error, want permission denied error. Got: ", err)
   371  		}
   372  
   373  		if !strings.Contains(e.OrigErr.Error(), "permission denied for schema barfoo") {
   374  			t.Fatal(e)
   375  		}
   376  	})
   377  }
   378  
   379  func TestCheckBeforeCreateTable(t *testing.T) {
   380  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   381  		ip, port, err := c.FirstPort()
   382  		if err != nil {
   383  			t.Fatal(err)
   384  		}
   385  
   386  		addr := pgConnectionString(ip, port)
   387  
   388  		// Check that opening the postgres connection returns NilVersion
   389  		p := &Postgres{}
   390  
   391  		d, err := p.Open(addr)
   392  
   393  		if err != nil {
   394  			t.Fatal(err)
   395  		}
   396  
   397  		defer func() {
   398  			if err := d.Close(); err != nil {
   399  				t.Error(err)
   400  			}
   401  		}()
   402  
   403  		// create user who is not the owner. Although we're concatenating strings in an sql statement it should be fine
   404  		// since this is a test environment and we're not expecting to the pgPassword to be malicious
   405  		mustRun(t, d, []string{
   406  			"CREATE USER not_owner WITH ENCRYPTED PASSWORD '" + pgPassword + "'",
   407  			"CREATE SCHEMA barfoo AUTHORIZATION postgres",
   408  			"GRANT USAGE ON SCHEMA barfoo TO not_owner",
   409  			"GRANT CREATE ON SCHEMA barfoo TO not_owner",
   410  		})
   411  
   412  		// re-connect using that schema
   413  		d2, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo",
   414  			pgPassword, ip, port))
   415  
   416  		if err != nil {
   417  			t.Fatal(err)
   418  		}
   419  
   420  		if err := d2.Close(); err != nil {
   421  			t.Fatal(err)
   422  		}
   423  
   424  		// revoke privileges
   425  		mustRun(t, d, []string{
   426  			"REVOKE CREATE ON SCHEMA barfoo FROM PUBLIC",
   427  			"REVOKE CREATE ON SCHEMA barfoo FROM not_owner",
   428  		})
   429  
   430  		// re-connect using that schema
   431  		d3, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo",
   432  			pgPassword, ip, port))
   433  
   434  		if err != nil {
   435  			t.Fatal(err)
   436  		}
   437  
   438  		version, _, err := d3.Version()
   439  
   440  		if err != nil {
   441  			t.Fatal(err)
   442  		}
   443  
   444  		if version != database.NilVersion {
   445  			t.Fatal("Unexpected version, want database.NilVersion. Got: ", version)
   446  		}
   447  
   448  		defer func() {
   449  			if err := d3.Close(); err != nil {
   450  				t.Fatal(err)
   451  			}
   452  		}()
   453  	})
   454  }
   455  
   456  func TestParallelSchema(t *testing.T) {
   457  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   458  		ip, port, err := c.FirstPort()
   459  		if err != nil {
   460  			t.Fatal(err)
   461  		}
   462  
   463  		addr := pgConnectionString(ip, port)
   464  		p := &Postgres{}
   465  		d, err := p.Open(addr)
   466  		if err != nil {
   467  			t.Fatal(err)
   468  		}
   469  		defer func() {
   470  			if err := d.Close(); err != nil {
   471  				t.Error(err)
   472  			}
   473  		}()
   474  
   475  		// create foo and bar schemas
   476  		if err := d.Run(strings.NewReader("CREATE SCHEMA foo AUTHORIZATION postgres")); err != nil {
   477  			t.Fatal(err)
   478  		}
   479  		if err := d.Run(strings.NewReader("CREATE SCHEMA bar AUTHORIZATION postgres")); err != nil {
   480  			t.Fatal(err)
   481  		}
   482  
   483  		// re-connect using that schemas
   484  		dfoo, err := p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&search_path=foo",
   485  			pgPassword, ip, port))
   486  		if err != nil {
   487  			t.Fatal(err)
   488  		}
   489  		defer func() {
   490  			if err := dfoo.Close(); err != nil {
   491  				t.Error(err)
   492  			}
   493  		}()
   494  
   495  		dbar, err := p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&search_path=bar",
   496  			pgPassword, ip, port))
   497  		if err != nil {
   498  			t.Fatal(err)
   499  		}
   500  		defer func() {
   501  			if err := dbar.Close(); err != nil {
   502  				t.Error(err)
   503  			}
   504  		}()
   505  
   506  		if err := dfoo.Lock(); err != nil {
   507  			t.Fatal(err)
   508  		}
   509  
   510  		if err := dbar.Lock(); err != nil {
   511  			t.Fatal(err)
   512  		}
   513  
   514  		if err := dbar.Unlock(); err != nil {
   515  			t.Fatal(err)
   516  		}
   517  
   518  		if err := dfoo.Unlock(); err != nil {
   519  			t.Fatal(err)
   520  		}
   521  	})
   522  }
   523  
   524  func TestPostgres_Lock(t *testing.T) {
   525  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   526  		ip, port, err := c.FirstPort()
   527  		if err != nil {
   528  			t.Fatal(err)
   529  		}
   530  
   531  		addr := pgConnectionString(ip, port)
   532  		p := &Postgres{}
   533  		d, err := p.Open(addr)
   534  		if err != nil {
   535  			t.Fatal(err)
   536  		}
   537  
   538  		dt.Test(t, d, []byte("SELECT 1"))
   539  
   540  		ps := d.(*Postgres)
   541  
   542  		err = ps.Lock()
   543  		if err != nil {
   544  			t.Fatal(err)
   545  		}
   546  
   547  		err = ps.Unlock()
   548  		if err != nil {
   549  			t.Fatal(err)
   550  		}
   551  
   552  		err = ps.Lock()
   553  		if err != nil {
   554  			t.Fatal(err)
   555  		}
   556  
   557  		err = ps.Unlock()
   558  		if err != nil {
   559  			t.Fatal(err)
   560  		}
   561  	})
   562  }
   563  
   564  func TestWithInstance_Concurrent(t *testing.T) {
   565  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   566  		ip, port, err := c.FirstPort()
   567  		if err != nil {
   568  			t.Fatal(err)
   569  		}
   570  
   571  		// The number of concurrent processes running WithInstance
   572  		const concurrency = 30
   573  
   574  		// We can instantiate a single database handle because it is
   575  		// actually a connection pool, and so, each of the below go
   576  		// routines will have a high probability of using a separate
   577  		// connection, which is something we want to exercise.
   578  		db, err := sql.Open("postgres", pgConnectionString(ip, port))
   579  		if err != nil {
   580  			t.Fatal(err)
   581  		}
   582  		defer func() {
   583  			if err := db.Close(); err != nil {
   584  				t.Error(err)
   585  			}
   586  		}()
   587  
   588  		db.SetMaxIdleConns(concurrency)
   589  		db.SetMaxOpenConns(concurrency)
   590  
   591  		var wg sync.WaitGroup
   592  		defer wg.Wait()
   593  
   594  		wg.Add(concurrency)
   595  		for i := 0; i < concurrency; i++ {
   596  			go func(i int) {
   597  				defer wg.Done()
   598  				_, err := WithInstance(db, &Config{})
   599  				if err != nil {
   600  					t.Errorf("process %d error: %s", i, err)
   601  				}
   602  			}(i)
   603  		}
   604  	})
   605  }
   606  func Test_computeLineFromPos(t *testing.T) {
   607  	testcases := []struct {
   608  		pos      int
   609  		wantLine uint
   610  		wantCol  uint
   611  		input    string
   612  		wantOk   bool
   613  	}{
   614  		{
   615  			15, 2, 6, "SELECT *\nFROM foo", true, // foo table does not exists
   616  		},
   617  		{
   618  			16, 3, 6, "SELECT *\n\nFROM foo", true, // foo table does not exists, empty line
   619  		},
   620  		{
   621  			25, 3, 7, "SELECT *\nFROM foo\nWHERE x", true, // x column error
   622  		},
   623  		{
   624  			27, 5, 7, "SELECT *\n\nFROM foo\n\nWHERE x", true, // x column error, empty lines
   625  		},
   626  		{
   627  			10, 2, 1, "SELECT *\nFROMM foo", true, // FROMM typo
   628  		},
   629  		{
   630  			11, 3, 1, "SELECT *\n\nFROMM foo", true, // FROMM typo, empty line
   631  		},
   632  		{
   633  			17, 2, 8, "SELECT *\nFROM foo", true, // last character
   634  		},
   635  		{
   636  			18, 0, 0, "SELECT *\nFROM foo", false, // invalid position
   637  		},
   638  	}
   639  	for i, tc := range testcases {
   640  		t.Run("tc"+strconv.Itoa(i), func(t *testing.T) {
   641  			run := func(crlf bool, nonASCII bool) {
   642  				var name string
   643  				if crlf {
   644  					name = "crlf"
   645  				} else {
   646  					name = "lf"
   647  				}
   648  				if nonASCII {
   649  					name += "-nonascii"
   650  				} else {
   651  					name += "-ascii"
   652  				}
   653  				t.Run(name, func(t *testing.T) {
   654  					input := tc.input
   655  					if crlf {
   656  						input = strings.Replace(input, "\n", "\r\n", -1)
   657  					}
   658  					if nonASCII {
   659  						input = strings.Replace(input, "FROM", "FRÖM", -1)
   660  					}
   661  					gotLine, gotCol, gotOK := computeLineFromPos(input, tc.pos)
   662  
   663  					if tc.wantOk {
   664  						t.Logf("pos %d, want %d:%d, %#v", tc.pos, tc.wantLine, tc.wantCol, input)
   665  					}
   666  
   667  					if gotOK != tc.wantOk {
   668  						t.Fatalf("expected ok %v but got %v", tc.wantOk, gotOK)
   669  					}
   670  					if gotLine != tc.wantLine {
   671  						t.Fatalf("expected line %d but got %d", tc.wantLine, gotLine)
   672  					}
   673  					if gotCol != tc.wantCol {
   674  						t.Fatalf("expected col %d but got %d", tc.wantCol, gotCol)
   675  					}
   676  				})
   677  			}
   678  			run(false, false)
   679  			run(true, false)
   680  			run(false, true)
   681  			run(true, true)
   682  		})
   683  	}
   684  
   685  }