github.com/nagyist/migrate/v4@v4.14.6/database/postgres/postgres_test.go (about)

     1  package postgres
     2  
     3  // error codes https://github.com/lib/pq/blob/master/error.go
     4  
     5  import (
     6  	"context"
     7  	"database/sql"
     8  	sqldriver "database/sql/driver"
     9  	"fmt"
    10  	"log"
    11  
    12  	"io"
    13  	"strconv"
    14  	"strings"
    15  	"sync"
    16  	"testing"
    17  
    18  	"github.com/nagyist/migrate/v4"
    19  
    20  	"github.com/dhui/dktest"
    21  
    22  	"github.com/nagyist/migrate/v4/database"
    23  	dt "github.com/nagyist/migrate/v4/database/testing"
    24  	"github.com/nagyist/migrate/v4/dktesting"
    25  	_ "github.com/nagyist/migrate/v4/source/file"
    26  )
    27  
    28  const (
    29  	pgPassword = "postgres"
    30  )
    31  
    32  var (
    33  	opts = dktest.Options{
    34  		Env:          map[string]string{"POSTGRES_PASSWORD": pgPassword},
    35  		PortRequired: true, ReadyFunc: isReady}
    36  	// Supported versions: https://www.postgresql.org/support/versioning/
    37  	specs = []dktesting.ContainerSpec{
    38  		{ImageName: "postgres:9.5", Options: opts},
    39  		{ImageName: "postgres:9.6", Options: opts},
    40  		{ImageName: "postgres:10", Options: opts},
    41  		{ImageName: "postgres:11", Options: opts},
    42  		{ImageName: "postgres:12", Options: opts},
    43  	}
    44  )
    45  
    46  func pgConnectionString(host, port string) string {
    47  	return fmt.Sprintf("postgres://postgres:%s@%s:%s/postgres?sslmode=disable", pgPassword, host, port)
    48  }
    49  
    50  func isReady(ctx context.Context, c dktest.ContainerInfo) bool {
    51  	ip, port, err := c.FirstPort()
    52  	if err != nil {
    53  		return false
    54  	}
    55  
    56  	db, err := sql.Open("postgres", pgConnectionString(ip, port))
    57  	if err != nil {
    58  		return false
    59  	}
    60  	defer func() {
    61  		if err := db.Close(); err != nil {
    62  			log.Println("close error:", err)
    63  		}
    64  	}()
    65  	if err = db.PingContext(ctx); err != nil {
    66  		switch err {
    67  		case sqldriver.ErrBadConn, io.EOF:
    68  			return false
    69  		default:
    70  			log.Println(err)
    71  		}
    72  		return false
    73  	}
    74  
    75  	return true
    76  }
    77  
    78  func Test(t *testing.T) {
    79  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
    80  		ip, port, err := c.FirstPort()
    81  		if err != nil {
    82  			t.Fatal(err)
    83  		}
    84  
    85  		addr := pgConnectionString(ip, port)
    86  		p := &Postgres{}
    87  		d, err := p.Open(addr)
    88  		if err != nil {
    89  			t.Fatal(err)
    90  		}
    91  		defer func() {
    92  			if err := d.Close(); err != nil {
    93  				t.Error(err)
    94  			}
    95  		}()
    96  		dt.Test(t, d, []byte("SELECT 1"))
    97  	})
    98  }
    99  
   100  func TestMigrate(t *testing.T) {
   101  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   102  		ip, port, err := c.FirstPort()
   103  		if err != nil {
   104  			t.Fatal(err)
   105  		}
   106  
   107  		addr := pgConnectionString(ip, port)
   108  		p := &Postgres{}
   109  		d, err := p.Open(addr)
   110  		if err != nil {
   111  			t.Fatal(err)
   112  		}
   113  		defer func() {
   114  			if err := d.Close(); err != nil {
   115  				t.Error(err)
   116  			}
   117  		}()
   118  		m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "postgres", d)
   119  		if err != nil {
   120  			t.Fatal(err)
   121  		}
   122  		dt.TestMigrate(t, m)
   123  	})
   124  }
   125  
   126  func TestMultiStatement(t *testing.T) {
   127  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   128  		ip, port, err := c.FirstPort()
   129  		if err != nil {
   130  			t.Fatal(err)
   131  		}
   132  
   133  		addr := pgConnectionString(ip, port)
   134  		p := &Postgres{}
   135  		d, err := p.Open(addr)
   136  		if err != nil {
   137  			t.Fatal(err)
   138  		}
   139  		defer func() {
   140  			if err := d.Close(); err != nil {
   141  				t.Error(err)
   142  			}
   143  		}()
   144  		if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLE bar (bar text);")); err != nil {
   145  			t.Fatalf("expected err to be nil, got %v", err)
   146  		}
   147  
   148  		// make sure second table exists
   149  		var exists bool
   150  		if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'bar' AND table_schema = (SELECT current_schema()))").Scan(&exists); err != nil {
   151  			t.Fatal(err)
   152  		}
   153  		if !exists {
   154  			t.Fatalf("expected table bar to exist")
   155  		}
   156  	})
   157  }
   158  
   159  func TestErrorParsing(t *testing.T) {
   160  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   161  		ip, port, err := c.FirstPort()
   162  		if err != nil {
   163  			t.Fatal(err)
   164  		}
   165  
   166  		addr := pgConnectionString(ip, port)
   167  		p := &Postgres{}
   168  		d, err := p.Open(addr)
   169  		if err != nil {
   170  			t.Fatal(err)
   171  		}
   172  		defer func() {
   173  			if err := d.Close(); err != nil {
   174  				t.Error(err)
   175  			}
   176  		}()
   177  
   178  		wantErr := `migration failed: syntax error at or near "TABLEE" (column 37) in line 1: CREATE TABLE foo ` +
   179  			`(foo text); CREATE TABLEE bar (bar text); (details: pq: syntax error at or near "TABLEE")`
   180  		if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLEE bar (bar text);")); err == nil {
   181  			t.Fatal("expected err but got nil")
   182  		} else if err.Error() != wantErr {
   183  			t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error())
   184  		}
   185  	})
   186  }
   187  
   188  func TestFilterCustomQuery(t *testing.T) {
   189  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   190  		ip, port, err := c.FirstPort()
   191  		if err != nil {
   192  			t.Fatal(err)
   193  		}
   194  
   195  		addr := fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-custom=foobar",
   196  			pgPassword, ip, port)
   197  		p := &Postgres{}
   198  		d, err := p.Open(addr)
   199  		if err != nil {
   200  			t.Fatal(err)
   201  		}
   202  		defer func() {
   203  			if err := d.Close(); err != nil {
   204  				t.Error(err)
   205  			}
   206  		}()
   207  	})
   208  }
   209  
   210  func TestWithSchema(t *testing.T) {
   211  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   212  		ip, port, err := c.FirstPort()
   213  		if err != nil {
   214  			t.Fatal(err)
   215  		}
   216  
   217  		addr := pgConnectionString(ip, port)
   218  		p := &Postgres{}
   219  		d, err := p.Open(addr)
   220  		if err != nil {
   221  			t.Fatal(err)
   222  		}
   223  		defer func() {
   224  			if err := d.Close(); err != nil {
   225  				t.Fatal(err)
   226  			}
   227  		}()
   228  
   229  		// create foobar schema
   230  		if err := d.Run(strings.NewReader("CREATE SCHEMA foobar AUTHORIZATION postgres")); err != nil {
   231  			t.Fatal(err)
   232  		}
   233  		if err := d.SetVersion(1, false); err != nil {
   234  			t.Fatal(err)
   235  		}
   236  
   237  		// re-connect using that schema
   238  		d2, err := p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&search_path=foobar",
   239  			pgPassword, ip, port))
   240  		if err != nil {
   241  			t.Fatal(err)
   242  		}
   243  		defer func() {
   244  			if err := d2.Close(); err != nil {
   245  				t.Fatal(err)
   246  			}
   247  		}()
   248  
   249  		version, _, err := d2.Version()
   250  		if err != nil {
   251  			t.Fatal(err)
   252  		}
   253  		if version != database.NilVersion {
   254  			t.Fatal("expected NilVersion")
   255  		}
   256  
   257  		// now update version and compare
   258  		if err := d2.SetVersion(2, false); err != nil {
   259  			t.Fatal(err)
   260  		}
   261  		version, _, err = d2.Version()
   262  		if err != nil {
   263  			t.Fatal(err)
   264  		}
   265  		if version != 2 {
   266  			t.Fatal("expected version 2")
   267  		}
   268  
   269  		// meanwhile, the public schema still has the other version
   270  		version, _, err = d.Version()
   271  		if err != nil {
   272  			t.Fatal(err)
   273  		}
   274  		if version != 1 {
   275  			t.Fatal("expected version 2")
   276  		}
   277  	})
   278  }
   279  
   280  func TestParallelSchema(t *testing.T) {
   281  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   282  		ip, port, err := c.FirstPort()
   283  		if err != nil {
   284  			t.Fatal(err)
   285  		}
   286  
   287  		addr := pgConnectionString(ip, port)
   288  		p := &Postgres{}
   289  		d, err := p.Open(addr)
   290  		if err != nil {
   291  			t.Fatal(err)
   292  		}
   293  		defer func() {
   294  			if err := d.Close(); err != nil {
   295  				t.Error(err)
   296  			}
   297  		}()
   298  
   299  		// create foo and bar schemas
   300  		if err := d.Run(strings.NewReader("CREATE SCHEMA foo AUTHORIZATION postgres")); err != nil {
   301  			t.Fatal(err)
   302  		}
   303  		if err := d.Run(strings.NewReader("CREATE SCHEMA bar AUTHORIZATION postgres")); err != nil {
   304  			t.Fatal(err)
   305  		}
   306  
   307  		// re-connect using that schemas
   308  		dfoo, err := p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&search_path=foo",
   309  			pgPassword, ip, port))
   310  		if err != nil {
   311  			t.Fatal(err)
   312  		}
   313  		defer func() {
   314  			if err := dfoo.Close(); err != nil {
   315  				t.Error(err)
   316  			}
   317  		}()
   318  
   319  		dbar, err := p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&search_path=bar",
   320  			pgPassword, ip, port))
   321  		if err != nil {
   322  			t.Fatal(err)
   323  		}
   324  		defer func() {
   325  			if err := dbar.Close(); err != nil {
   326  				t.Error(err)
   327  			}
   328  		}()
   329  
   330  		if err := dfoo.Lock(); err != nil {
   331  			t.Fatal(err)
   332  		}
   333  
   334  		if err := dbar.Lock(); err != nil {
   335  			t.Fatal(err)
   336  		}
   337  
   338  		if err := dbar.Unlock(); err != nil {
   339  			t.Fatal(err)
   340  		}
   341  
   342  		if err := dfoo.Unlock(); err != nil {
   343  			t.Fatal(err)
   344  		}
   345  	})
   346  }
   347  
   348  func TestWithInstance(t *testing.T) {
   349  
   350  }
   351  
   352  func TestPostgres_Lock(t *testing.T) {
   353  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   354  		ip, port, err := c.FirstPort()
   355  		if err != nil {
   356  			t.Fatal(err)
   357  		}
   358  
   359  		addr := pgConnectionString(ip, port)
   360  		p := &Postgres{}
   361  		d, err := p.Open(addr)
   362  		if err != nil {
   363  			t.Fatal(err)
   364  		}
   365  
   366  		dt.Test(t, d, []byte("SELECT 1"))
   367  
   368  		ps := d.(*Postgres)
   369  
   370  		err = ps.Lock()
   371  		if err != nil {
   372  			t.Fatal(err)
   373  		}
   374  
   375  		err = ps.Unlock()
   376  		if err != nil {
   377  			t.Fatal(err)
   378  		}
   379  
   380  		err = ps.Lock()
   381  		if err != nil {
   382  			t.Fatal(err)
   383  		}
   384  
   385  		err = ps.Unlock()
   386  		if err != nil {
   387  			t.Fatal(err)
   388  		}
   389  	})
   390  }
   391  
   392  func TestWithInstance_Concurrent(t *testing.T) {
   393  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   394  		ip, port, err := c.FirstPort()
   395  		if err != nil {
   396  			t.Fatal(err)
   397  		}
   398  
   399  		// The number of concurrent processes running WithInstance
   400  		const concurrency = 30
   401  
   402  		// We can instantiate a single database handle because it is
   403  		// actually a connection pool, and so, each of the below go
   404  		// routines will have a high probability of using a separate
   405  		// connection, which is something we want to exercise.
   406  		db, err := sql.Open("postgres", pgConnectionString(ip, port))
   407  		if err != nil {
   408  			t.Fatal(err)
   409  		}
   410  		defer func() {
   411  			if err := db.Close(); err != nil {
   412  				t.Error(err)
   413  			}
   414  		}()
   415  
   416  		db.SetMaxIdleConns(concurrency)
   417  		db.SetMaxOpenConns(concurrency)
   418  
   419  		var wg sync.WaitGroup
   420  		defer wg.Wait()
   421  
   422  		wg.Add(concurrency)
   423  		for i := 0; i < concurrency; i++ {
   424  			go func(i int) {
   425  				defer wg.Done()
   426  				_, err := WithInstance(db, &Config{})
   427  				if err != nil {
   428  					t.Errorf("process %d error: %s", i, err)
   429  				}
   430  			}(i)
   431  		}
   432  	})
   433  }
   434  func Test_computeLineFromPos(t *testing.T) {
   435  	testcases := []struct {
   436  		pos      int
   437  		wantLine uint
   438  		wantCol  uint
   439  		input    string
   440  		wantOk   bool
   441  	}{
   442  		{
   443  			15, 2, 6, "SELECT *\nFROM foo", true, // foo table does not exists
   444  		},
   445  		{
   446  			16, 3, 6, "SELECT *\n\nFROM foo", true, // foo table does not exists, empty line
   447  		},
   448  		{
   449  			25, 3, 7, "SELECT *\nFROM foo\nWHERE x", true, // x column error
   450  		},
   451  		{
   452  			27, 5, 7, "SELECT *\n\nFROM foo\n\nWHERE x", true, // x column error, empty lines
   453  		},
   454  		{
   455  			10, 2, 1, "SELECT *\nFROMM foo", true, // FROMM typo
   456  		},
   457  		{
   458  			11, 3, 1, "SELECT *\n\nFROMM foo", true, // FROMM typo, empty line
   459  		},
   460  		{
   461  			17, 2, 8, "SELECT *\nFROM foo", true, // last character
   462  		},
   463  		{
   464  			18, 0, 0, "SELECT *\nFROM foo", false, // invalid position
   465  		},
   466  	}
   467  	for i, tc := range testcases {
   468  		t.Run("tc"+strconv.Itoa(i), func(t *testing.T) {
   469  			run := func(crlf bool, nonASCII bool) {
   470  				var name string
   471  				if crlf {
   472  					name = "crlf"
   473  				} else {
   474  					name = "lf"
   475  				}
   476  				if nonASCII {
   477  					name += "-nonascii"
   478  				} else {
   479  					name += "-ascii"
   480  				}
   481  				t.Run(name, func(t *testing.T) {
   482  					input := tc.input
   483  					if crlf {
   484  						input = strings.Replace(input, "\n", "\r\n", -1)
   485  					}
   486  					if nonASCII {
   487  						input = strings.Replace(input, "FROM", "FRÖM", -1)
   488  					}
   489  					gotLine, gotCol, gotOK := computeLineFromPos(input, tc.pos)
   490  
   491  					if tc.wantOk {
   492  						t.Logf("pos %d, want %d:%d, %#v", tc.pos, tc.wantLine, tc.wantCol, input)
   493  					}
   494  
   495  					if gotOK != tc.wantOk {
   496  						t.Fatalf("expected ok %v but got %v", tc.wantOk, gotOK)
   497  					}
   498  					if gotLine != tc.wantLine {
   499  						t.Fatalf("expected line %d but got %d", tc.wantLine, gotLine)
   500  					}
   501  					if gotCol != tc.wantCol {
   502  						t.Fatalf("expected col %d but got %d", tc.wantCol, gotCol)
   503  					}
   504  				})
   505  			}
   506  			run(false, false)
   507  			run(true, false)
   508  			run(false, true)
   509  			run(true, true)
   510  		})
   511  	}
   512  }
   513  
   514  func Test_quoteIdentifier(t *testing.T) {
   515  	testcases := []struct {
   516  		migrationsTableHasSchema bool
   517  		name                     string
   518  		want                     string
   519  	}{
   520  		{
   521  			true,
   522  			"schema_name.table_name",
   523  			"\"schema_name\".\"table_name\"",
   524  		},
   525  		{
   526  			true,
   527  			"schema_name.table.name",
   528  			"\"schema_name\".\"table.name\"",
   529  		},
   530  		{
   531  			false,
   532  			"table_name",
   533  			"\"table_name\"",
   534  		},
   535  		{
   536  			false,
   537  			"table.name",
   538  			"\"table.name\"",
   539  		},
   540  	}
   541  	p := &Postgres{
   542  		config: &Config{
   543  			MigrationsTableHasSchema: true,
   544  		},
   545  	}
   546  
   547  	for _, tc := range testcases {
   548  		p.config.MigrationsTableHasSchema = tc.migrationsTableHasSchema
   549  		got := p.quoteIdentifier(tc.name)
   550  		if tc.want != got {
   551  			t.Fatalf("expected %s but got %s", tc.want, got)
   552  		}
   553  	}
   554  }