github.com/SmoothieNoIce/migrate@v3.5.4+incompatible/database/postgres/postgres_test.go (about)

     1  package postgres
     2  
// Postgres error codes reported through lib/pq are listed at https://github.com/lib/pq/blob/master/error.go
     4  
import (
	"bytes"
	"context"
	"database/sql"
	sqldriver "database/sql/driver"
	"fmt"
	"io"
	"net"
	"strconv"
	"strings"
	"testing"

	dt "github.com/golang-migrate/migrate/database/testing"
	mt "github.com/golang-migrate/migrate/testing"
)
    19  
    20  var versions = []mt.Version{
    21  	{Image: "postgres:10"},
    22  	{Image: "postgres:9.6"},
    23  	{Image: "postgres:9.5"},
    24  	{Image: "postgres:9.4"},
    25  	{Image: "postgres:9.3"},
    26  }
    27  
    28  func pgConnectionString(host string, port uint) string {
    29  	return fmt.Sprintf("postgres://postgres@%s:%v/postgres?sslmode=disable", host, port)
    30  }
    31  
    32  func isReady(i mt.Instance) bool {
    33  	db, err := sql.Open("postgres", pgConnectionString(i.Host(), i.Port()))
    34  	if err != nil {
    35  		return false
    36  	}
    37  	defer db.Close()
    38  	if err = db.Ping(); err != nil {
    39  		switch err {
    40  		case sqldriver.ErrBadConn, io.EOF:
    41  			return false
    42  		default:
    43  			fmt.Println(err)
    44  		}
    45  		return false
    46  	}
    47  
    48  	return true
    49  }
    50  
    51  func Test(t *testing.T) {
    52  	mt.ParallelTest(t, versions, isReady,
    53  		func(t *testing.T, i mt.Instance) {
    54  			p := &Postgres{}
    55  			addr := pgConnectionString(i.Host(), i.Port())
    56  			d, err := p.Open(addr)
    57  			if err != nil {
    58  				t.Fatalf("%v", err)
    59  			}
    60  			defer d.Close()
    61  			dt.Test(t, d, []byte("SELECT 1"))
    62  		})
    63  }
    64  
    65  func TestMultiStatement(t *testing.T) {
    66  	mt.ParallelTest(t, versions, isReady,
    67  		func(t *testing.T, i mt.Instance) {
    68  			p := &Postgres{}
    69  			addr := pgConnectionString(i.Host(), i.Port())
    70  			d, err := p.Open(addr)
    71  			if err != nil {
    72  				t.Fatalf("%v", err)
    73  			}
    74  			defer d.Close()
    75  			if err := d.Run(bytes.NewReader([]byte("CREATE TABLE foo (foo text); CREATE TABLE bar (bar text);"))); err != nil {
    76  				t.Fatalf("expected err to be nil, got %v", err)
    77  			}
    78  
    79  			// make sure second table exists
    80  			var exists bool
    81  			if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'bar' AND table_schema = (SELECT current_schema()))").Scan(&exists); err != nil {
    82  				t.Fatal(err)
    83  			}
    84  			if !exists {
    85  				t.Fatalf("expected table bar to exist")
    86  			}
    87  		})
    88  }
    89  
    90  func TestErrorParsing(t *testing.T) {
    91  	mt.ParallelTest(t, versions, isReady,
    92  		func(t *testing.T, i mt.Instance) {
    93  			p := &Postgres{}
    94  			addr := pgConnectionString(i.Host(), i.Port())
    95  			d, err := p.Open(addr)
    96  			if err != nil {
    97  				t.Fatalf("%v", err)
    98  			}
    99  			defer d.Close()
   100  
   101  			wantErr := `migration failed: syntax error at or near "TABLEE" (column 37) in line 1: CREATE TABLE foo ` +
   102  				`(foo text); CREATE TABLEE bar (bar text); (details: pq: syntax error at or near "TABLEE")`
   103  			if err := d.Run(bytes.NewReader([]byte("CREATE TABLE foo (foo text); CREATE TABLEE bar (bar text);"))); err == nil {
   104  				t.Fatal("expected err but got nil")
   105  			} else if err.Error() != wantErr {
   106  				t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error())
   107  			}
   108  		})
   109  }
   110  
   111  func TestFilterCustomQuery(t *testing.T) {
   112  	mt.ParallelTest(t, versions, isReady,
   113  		func(t *testing.T, i mt.Instance) {
   114  			p := &Postgres{}
   115  			addr := fmt.Sprintf("postgres://postgres@%v:%v/postgres?sslmode=disable&x-custom=foobar", i.Host(), i.Port())
   116  			d, err := p.Open(addr)
   117  			if err != nil {
   118  				t.Fatalf("%v", err)
   119  			}
   120  			defer d.Close()
   121  		})
   122  }
   123  
   124  func TestWithSchema(t *testing.T) {
   125  	mt.ParallelTest(t, versions, isReady,
   126  		func(t *testing.T, i mt.Instance) {
   127  			p := &Postgres{}
   128  			addr := pgConnectionString(i.Host(), i.Port())
   129  			d, err := p.Open(addr)
   130  			if err != nil {
   131  				t.Fatalf("%v", err)
   132  			}
   133  			defer d.Close()
   134  
   135  			// create foobar schema
   136  			if err := d.Run(bytes.NewReader([]byte("CREATE SCHEMA foobar AUTHORIZATION postgres"))); err != nil {
   137  				t.Fatal(err)
   138  			}
   139  			if err := d.SetVersion(1, false); err != nil {
   140  				t.Fatal(err)
   141  			}
   142  
   143  			// re-connect using that schema
   144  			d2, err := p.Open(fmt.Sprintf("postgres://postgres@%v:%v/postgres?sslmode=disable&search_path=foobar", i.Host(), i.Port()))
   145  			if err != nil {
   146  				t.Fatalf("%v", err)
   147  			}
   148  			defer d2.Close()
   149  
   150  			version, _, err := d2.Version()
   151  			if err != nil {
   152  				t.Fatal(err)
   153  			}
   154  			if version != -1 {
   155  				t.Fatal("expected NilVersion")
   156  			}
   157  
   158  			// now update version and compare
   159  			if err := d2.SetVersion(2, false); err != nil {
   160  				t.Fatal(err)
   161  			}
   162  			version, _, err = d2.Version()
   163  			if err != nil {
   164  				t.Fatal(err)
   165  			}
   166  			if version != 2 {
   167  				t.Fatal("expected version 2")
   168  			}
   169  
   170  			// meanwhile, the public schema still has the other version
   171  			version, _, err = d.Version()
   172  			if err != nil {
   173  				t.Fatal(err)
   174  			}
   175  			if version != 1 {
   176  				t.Fatal("expected version 2")
   177  			}
   178  		})
   179  }
   180  
   181  func TestWithInstance(t *testing.T) {
   182  
   183  }
   184  
   185  func TestPostgres_Lock(t *testing.T) {
   186  	mt.ParallelTest(t, versions, isReady,
   187  		func(t *testing.T, i mt.Instance) {
   188  			p := &Postgres{}
   189  			addr := pgConnectionString(i.Host(), i.Port())
   190  			d, err := p.Open(addr)
   191  			if err != nil {
   192  				t.Fatalf("%v", err)
   193  			}
   194  
   195  			dt.Test(t, d, []byte("SELECT 1"))
   196  
   197  			ps := d.(*Postgres)
   198  
   199  			err = ps.Lock()
   200  			if err != nil {
   201  				t.Fatal(err)
   202  			}
   203  
   204  			err = ps.Unlock()
   205  			if err != nil {
   206  				t.Fatal(err)
   207  			}
   208  
   209  			err = ps.Lock()
   210  			if err != nil {
   211  				t.Fatal(err)
   212  			}
   213  
   214  			err = ps.Unlock()
   215  			if err != nil {
   216  				t.Fatal(err)
   217  			}
   218  		})
   219  }
   220  
   221  func Test_computeLineFromPos(t *testing.T) {
   222  	testcases := []struct {
   223  		pos      int
   224  		wantLine uint
   225  		wantCol  uint
   226  		input    string
   227  		wantOk   bool
   228  	}{
   229  		{
   230  			15, 2, 6, "SELECT *\nFROM foo", true, // foo table does not exists
   231  		},
   232  		{
   233  			16, 3, 6, "SELECT *\n\nFROM foo", true, // foo table does not exists, empty line
   234  		},
   235  		{
   236  			25, 3, 7, "SELECT *\nFROM foo\nWHERE x", true, // x column error
   237  		},
   238  		{
   239  			27, 5, 7, "SELECT *\n\nFROM foo\n\nWHERE x", true, // x column error, empty lines
   240  		},
   241  		{
   242  			10, 2, 1, "SELECT *\nFROMM foo", true, // FROMM typo
   243  		},
   244  		{
   245  			11, 3, 1, "SELECT *\n\nFROMM foo", true, // FROMM typo, empty line
   246  		},
   247  		{
   248  			17, 2, 8, "SELECT *\nFROM foo", true, // last character
   249  		},
   250  		{
   251  			18, 0, 0, "SELECT *\nFROM foo", false, // invalid position
   252  		},
   253  	}
   254  	for i, tc := range testcases {
   255  		t.Run("tc"+strconv.Itoa(i), func(t *testing.T) {
   256  			run := func(crlf bool, nonASCII bool) {
   257  				var name string
   258  				if crlf {
   259  					name = "crlf"
   260  				} else {
   261  					name = "lf"
   262  				}
   263  				if nonASCII {
   264  					name += "-nonascii"
   265  				} else {
   266  					name += "-ascii"
   267  				}
   268  				t.Run(name, func(t *testing.T) {
   269  					input := tc.input
   270  					if crlf {
   271  						input = strings.Replace(input, "\n", "\r\n", -1)
   272  					}
   273  					if nonASCII {
   274  						input = strings.Replace(input, "FROM", "FRÖM", -1)
   275  					}
   276  					gotLine, gotCol, gotOK := computeLineFromPos(input, tc.pos)
   277  
   278  					if tc.wantOk {
   279  						t.Logf("pos %d, want %d:%d, %#v", tc.pos, tc.wantLine, tc.wantCol, input)
   280  					}
   281  
   282  					if gotOK != tc.wantOk {
   283  						t.Fatalf("expected ok %v but got %v", tc.wantOk, gotOK)
   284  					}
   285  					if gotLine != tc.wantLine {
   286  						t.Fatalf("expected line %d but got %d", tc.wantLine, gotLine)
   287  					}
   288  					if gotCol != tc.wantCol {
   289  						t.Fatalf("expected col %d but got %d", tc.wantCol, gotCol)
   290  					}
   291  				})
   292  			}
   293  			run(false, false)
   294  			run(true, false)
   295  			run(false, true)
   296  			run(true, true)
   297  		})
   298  	}
   299  
   300  }