github.com/scraniel/migrate@v0.0.0-20230320185700-339088f36cee/database/redshift/redshift_test.go (about)

     1  package redshift
     2  
     3  // error codes https://github.com/lib/pq/blob/master/error.go
     4  
     5  import (
     6  	"bytes"
     7  	"context"
     8  	"database/sql"
     9  	sqldriver "database/sql/driver"
    10  	"fmt"
    11  	"log"
    12  
    13  	"github.com/golang-migrate/migrate/v4"
    14  	"io"
    15  	"strconv"
    16  	"strings"
    17  	"testing"
    18  )
    19  
    20  import (
    21  	"github.com/dhui/dktest"
    22  )
    23  
    24  import (
    25  	"github.com/golang-migrate/migrate/v4/database"
    26  	dt "github.com/golang-migrate/migrate/v4/database/testing"
    27  	"github.com/golang-migrate/migrate/v4/dktesting"
    28  	_ "github.com/golang-migrate/migrate/v4/source/file"
    29  )
    30  
    31  var (
    32  	opts  = dktest.Options{PortRequired: true, ReadyFunc: isReady}
    33  	specs = []dktesting.ContainerSpec{
    34  		{ImageName: "postgres:8", Options: opts},
    35  	}
    36  )
    37  
    38  func redshiftConnectionString(host, port string) string {
    39  	return connectionString("redshift", host, port)
    40  }
    41  
    42  func pgConnectionString(host, port string) string {
    43  	return connectionString("postgres", host, port)
    44  }
    45  
    46  func connectionString(schema, host, port string) string {
    47  	return fmt.Sprintf("%s://postgres@%s:%s/postgres?sslmode=disable", schema, host, port)
    48  }
    49  
    50  func isReady(ctx context.Context, c dktest.ContainerInfo) bool {
    51  	ip, port, err := c.FirstPort()
    52  	if err != nil {
    53  		return false
    54  	}
    55  
    56  	db, err := sql.Open("postgres", pgConnectionString(ip, port))
    57  	if err != nil {
    58  		return false
    59  	}
    60  	defer func() {
    61  		if err := db.Close(); err != nil {
    62  			log.Println("close error:", err)
    63  		}
    64  	}()
    65  	if err = db.PingContext(ctx); err != nil {
    66  		switch err {
    67  		case sqldriver.ErrBadConn, io.EOF:
    68  			return false
    69  		default:
    70  			log.Println(err)
    71  		}
    72  		return false
    73  	}
    74  
    75  	return true
    76  }
    77  
    78  func Test(t *testing.T) {
    79  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
    80  		ip, port, err := c.FirstPort()
    81  		if err != nil {
    82  			t.Fatal(err)
    83  		}
    84  
    85  		addr := redshiftConnectionString(ip, port)
    86  		p := &Redshift{}
    87  		d, err := p.Open(addr)
    88  		if err != nil {
    89  			t.Fatal(err)
    90  		}
    91  		defer func() {
    92  			if err := d.Close(); err != nil {
    93  				t.Error(err)
    94  			}
    95  		}()
    96  		dt.Test(t, d, []byte("SELECT 1"))
    97  	})
    98  }
    99  
   100  func TestMigrate(t *testing.T) {
   101  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   102  		ip, port, err := c.FirstPort()
   103  		if err != nil {
   104  			t.Fatal(err)
   105  		}
   106  
   107  		addr := redshiftConnectionString(ip, port)
   108  		p := &Redshift{}
   109  		d, err := p.Open(addr)
   110  		if err != nil {
   111  			t.Fatal(err)
   112  		}
   113  		defer func() {
   114  			if err := d.Close(); err != nil {
   115  				t.Error(err)
   116  			}
   117  		}()
   118  		m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "postgres", d)
   119  		if err != nil {
   120  			t.Fatal(err)
   121  		}
   122  		dt.TestMigrate(t, m)
   123  	})
   124  }
   125  
   126  func TestMultiStatement(t *testing.T) {
   127  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   128  		ip, port, err := c.FirstPort()
   129  		if err != nil {
   130  			t.Fatal(err)
   131  		}
   132  
   133  		addr := redshiftConnectionString(ip, port)
   134  		p := &Redshift{}
   135  		d, err := p.Open(addr)
   136  		if err != nil {
   137  			t.Fatal(err)
   138  		}
   139  		defer func() {
   140  			if err := d.Close(); err != nil {
   141  				t.Error(err)
   142  			}
   143  		}()
   144  		if err := d.Run(bytes.NewReader([]byte("CREATE TABLE foo (foo text); CREATE TABLE bar (bar text);"))); err != nil {
   145  			t.Fatalf("expected err to be nil, got %v", err)
   146  		}
   147  
   148  		// make sure second table exists
   149  		var exists bool
   150  		if err := d.(*Redshift).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'bar' AND table_schema = (SELECT current_schema()))").Scan(&exists); err != nil {
   151  			t.Fatal(err)
   152  		}
   153  		if !exists {
   154  			t.Fatalf("expected table bar to exist")
   155  		}
   156  	})
   157  }
   158  
   159  func TestErrorParsing(t *testing.T) {
   160  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   161  		ip, port, err := c.FirstPort()
   162  		if err != nil {
   163  			t.Fatal(err)
   164  		}
   165  
   166  		addr := redshiftConnectionString(ip, port)
   167  		p := &Redshift{}
   168  		d, err := p.Open(addr)
   169  		if err != nil {
   170  			t.Fatal(err)
   171  		}
   172  		defer func() {
   173  			if err := d.Close(); err != nil {
   174  				t.Error(err)
   175  			}
   176  		}()
   177  
   178  		wantErr := `migration failed: syntax error at or near "TABLEE" (column 37) in line 1: CREATE TABLE foo ` +
   179  			`(foo text); CREATE TABLEE bar (bar text); (details: pq: syntax error at or near "TABLEE")`
   180  		if err := d.Run(bytes.NewReader([]byte("CREATE TABLE foo (foo text); CREATE TABLEE bar (bar text);"))); err == nil {
   181  			t.Fatal("expected err but got nil")
   182  		} else if err.Error() != wantErr {
   183  			t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error())
   184  		}
   185  	})
   186  }
   187  
   188  func TestFilterCustomQuery(t *testing.T) {
   189  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   190  		ip, port, err := c.FirstPort()
   191  		if err != nil {
   192  			t.Fatal(err)
   193  		}
   194  
   195  		addr := fmt.Sprintf("postgres://postgres@%v:%v/postgres?sslmode=disable&x-custom=foobar", ip, port)
   196  		p := &Redshift{}
   197  		d, err := p.Open(addr)
   198  		if err != nil {
   199  			t.Fatal(err)
   200  		}
   201  		defer func() {
   202  			if err := d.Close(); err != nil {
   203  				t.Error(err)
   204  			}
   205  		}()
   206  	})
   207  }
   208  
   209  func TestWithSchema(t *testing.T) {
   210  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   211  		ip, port, err := c.FirstPort()
   212  		if err != nil {
   213  			t.Fatal(err)
   214  		}
   215  
   216  		addr := redshiftConnectionString(ip, port)
   217  		p := &Redshift{}
   218  		d, err := p.Open(addr)
   219  		if err != nil {
   220  			t.Fatal(err)
   221  		}
   222  		defer func() {
   223  			if err := d.Close(); err != nil {
   224  				t.Error(err)
   225  			}
   226  		}()
   227  
   228  		// create foobar schema
   229  		if err := d.Run(bytes.NewReader([]byte("CREATE SCHEMA foobar AUTHORIZATION postgres"))); err != nil {
   230  			t.Fatal(err)
   231  		}
   232  		if err := d.SetVersion(1, false); err != nil {
   233  			t.Fatal(err)
   234  		}
   235  
   236  		// re-connect using that schema
   237  		d2, err := p.Open(fmt.Sprintf("postgres://postgres@%v:%v/postgres?sslmode=disable&search_path=foobar", ip, port))
   238  		if err != nil {
   239  			t.Fatal(err)
   240  		}
   241  		defer func() {
   242  			if err := d2.Close(); err != nil {
   243  				t.Error(err)
   244  			}
   245  		}()
   246  
   247  		version, _, err := d2.Version()
   248  		if err != nil {
   249  			t.Fatal(err)
   250  		}
   251  		if version != database.NilVersion {
   252  			t.Fatal("expected NilVersion")
   253  		}
   254  
   255  		// now update version and compare
   256  		if err := d2.SetVersion(2, false); err != nil {
   257  			t.Fatal(err)
   258  		}
   259  		version, _, err = d2.Version()
   260  		if err != nil {
   261  			t.Fatal(err)
   262  		}
   263  		if version != 2 {
   264  			t.Fatal("expected version 2")
   265  		}
   266  
   267  		// meanwhile, the public schema still has the other version
   268  		version, _, err = d.Version()
   269  		if err != nil {
   270  			t.Fatal(err)
   271  		}
   272  		if version != 1 {
   273  			t.Fatal("expected version 2")
   274  		}
   275  	})
   276  }
   277  
   278  func TestWithInstance(t *testing.T) {
   279  
   280  }
   281  
   282  func TestRedshift_Lock(t *testing.T) {
   283  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   284  		ip, port, err := c.FirstPort()
   285  		if err != nil {
   286  			t.Fatal(err)
   287  		}
   288  
   289  		addr := pgConnectionString(ip, port)
   290  		p := &Redshift{}
   291  		d, err := p.Open(addr)
   292  		if err != nil {
   293  			t.Fatal(err)
   294  		}
   295  
   296  		dt.Test(t, d, []byte("SELECT 1"))
   297  
   298  		ps := d.(*Redshift)
   299  
   300  		err = ps.Lock()
   301  		if err != nil {
   302  			t.Fatal(err)
   303  		}
   304  
   305  		err = ps.Unlock()
   306  		if err != nil {
   307  			t.Fatal(err)
   308  		}
   309  
   310  		err = ps.Lock()
   311  		if err != nil {
   312  			t.Fatal(err)
   313  		}
   314  
   315  		err = ps.Unlock()
   316  		if err != nil {
   317  			t.Fatal(err)
   318  		}
   319  	})
   320  }
   321  
   322  func Test_computeLineFromPos(t *testing.T) {
   323  	testcases := []struct {
   324  		pos      int
   325  		wantLine uint
   326  		wantCol  uint
   327  		input    string
   328  		wantOk   bool
   329  	}{
   330  		{
   331  			15, 2, 6, "SELECT *\nFROM foo", true, // foo table does not exists
   332  		},
   333  		{
   334  			16, 3, 6, "SELECT *\n\nFROM foo", true, // foo table does not exists, empty line
   335  		},
   336  		{
   337  			25, 3, 7, "SELECT *\nFROM foo\nWHERE x", true, // x column error
   338  		},
   339  		{
   340  			27, 5, 7, "SELECT *\n\nFROM foo\n\nWHERE x", true, // x column error, empty lines
   341  		},
   342  		{
   343  			10, 2, 1, "SELECT *\nFROMM foo", true, // FROMM typo
   344  		},
   345  		{
   346  			11, 3, 1, "SELECT *\n\nFROMM foo", true, // FROMM typo, empty line
   347  		},
   348  		{
   349  			17, 2, 8, "SELECT *\nFROM foo", true, // last character
   350  		},
   351  		{
   352  			18, 0, 0, "SELECT *\nFROM foo", false, // invalid position
   353  		},
   354  	}
   355  	for i, tc := range testcases {
   356  		t.Run("tc"+strconv.Itoa(i), func(t *testing.T) {
   357  			run := func(crlf bool, nonASCII bool) {
   358  				var name string
   359  				if crlf {
   360  					name = "crlf"
   361  				} else {
   362  					name = "lf"
   363  				}
   364  				if nonASCII {
   365  					name += "-nonascii"
   366  				} else {
   367  					name += "-ascii"
   368  				}
   369  				t.Run(name, func(t *testing.T) {
   370  					input := tc.input
   371  					if crlf {
   372  						input = strings.Replace(input, "\n", "\r\n", -1)
   373  					}
   374  					if nonASCII {
   375  						input = strings.Replace(input, "FROM", "FRÖM", -1)
   376  					}
   377  					gotLine, gotCol, gotOK := computeLineFromPos(input, tc.pos)
   378  
   379  					if tc.wantOk {
   380  						t.Logf("pos %d, want %d:%d, %#v", tc.pos, tc.wantLine, tc.wantCol, input)
   381  					}
   382  
   383  					if gotOK != tc.wantOk {
   384  						t.Fatalf("expected ok %v but got %v", tc.wantOk, gotOK)
   385  					}
   386  					if gotLine != tc.wantLine {
   387  						t.Fatalf("expected line %d but got %d", tc.wantLine, gotLine)
   388  					}
   389  					if gotCol != tc.wantCol {
   390  						t.Fatalf("expected col %d but got %d", tc.wantCol, gotCol)
   391  					}
   392  				})
   393  			}
   394  			run(false, false)
   395  			run(true, false)
   396  			run(false, true)
   397  			run(true, true)
   398  		})
   399  	}
   400  
   401  }