github.com/scraniel/migrate@v0.0.0-20230320185700-339088f36cee/database/sqlserver/sqlserver_test.go (about)

     1  package sqlserver
     2  
import (
	"context"
	"database/sql"
	sqldriver "database/sql/driver"
	"errors"
	"fmt"
	"log"
	"runtime"
	"strings"
	"testing"
	"time"

	"github.com/dhui/dktest"
	"github.com/docker/go-connections/nat"
	"github.com/golang-migrate/migrate/v4"

	dt "github.com/golang-migrate/migrate/v4/database/testing"
	"github.com/golang-migrate/migrate/v4/dktesting"

	_ "github.com/golang-migrate/migrate/v4/source/file"
)
    23  
    24  const defaultPort = 1433
    25  const saPassword = "Root1234"
    26  
    27  var (
    28  	sqlEdgeOpts = dktest.Options{
    29  		Env: map[string]string{"ACCEPT_EULA": "Y", "MSSQL_SA_PASSWORD": saPassword},
    30  		PortBindings: map[nat.Port][]nat.PortBinding{
    31  			nat.Port(fmt.Sprintf("%d/tcp", defaultPort)): {
    32  				nat.PortBinding{
    33  					HostIP:   "0.0.0.0",
    34  					HostPort: "0/tcp",
    35  				},
    36  			},
    37  		},
    38  		PortRequired: true, ReadyFunc: isReady, PullTimeout: 2 * time.Minute,
    39  	}
    40  	sqlServerOpts = dktest.Options{
    41  		Env:          map[string]string{"ACCEPT_EULA": "Y", "MSSQL_SA_PASSWORD": saPassword, "MSSQL_PID": "Express"},
    42  		PortRequired: true, ReadyFunc: isReady, PullTimeout: 2 * time.Minute,
    43  	}
    44  	// Container versions: https://mcr.microsoft.com/v2/mssql/server/tags/list
    45  	specs = []dktesting.ContainerSpec{
    46  		{ImageName: "mcr.microsoft.com/azure-sql-edge:latest", Options: sqlEdgeOpts},
    47  		{ImageName: "mcr.microsoft.com/mssql/server:2017-latest", Options: sqlServerOpts},
    48  		{ImageName: "mcr.microsoft.com/mssql/server:2019-latest", Options: sqlServerOpts},
    49  	}
    50  )
    51  
    52  func msConnectionString(host, port string) string {
    53  	return fmt.Sprintf("sqlserver://sa:%v@%v:%v?database=master", saPassword, host, port)
    54  }
    55  
    56  func msConnectionStringMsiWithPassword(host, port string, useMsi bool) string {
    57  	return fmt.Sprintf("sqlserver://sa:%v@%v:%v?database=master&useMsi=%t", saPassword, host, port, useMsi)
    58  }
    59  
    60  func msConnectionStringMsi(host, port string, useMsi bool) string {
    61  	return fmt.Sprintf("sqlserver://sa@%v:%v?database=master&useMsi=%t", host, port, useMsi)
    62  }
    63  
    64  func isReady(ctx context.Context, c dktest.ContainerInfo) bool {
    65  	ip, port, err := c.Port(defaultPort)
    66  	if err != nil {
    67  		return false
    68  	}
    69  	uri := msConnectionString(ip, port)
    70  	db, err := sql.Open("sqlserver", uri)
    71  	if err != nil {
    72  		return false
    73  	}
    74  	defer func() {
    75  		if err := db.Close(); err != nil {
    76  			log.Println("close error:", err)
    77  		}
    78  	}()
    79  	if err = db.PingContext(ctx); err != nil {
    80  		switch err {
    81  		case sqldriver.ErrBadConn:
    82  			return false
    83  		default:
    84  			fmt.Println(err)
    85  		}
    86  		return false
    87  	}
    88  
    89  	return true
    90  }
    91  
    92  func SkipIfUnsupportedArch(t *testing.T, c dktest.ContainerInfo) {
    93  	if strings.Contains(c.ImageName, "mssql") && !strings.HasPrefix(runtime.GOARCH, "amd") {
    94  		t.Skipf("Image %s is not supported on arch %s", c.ImageName, runtime.GOARCH)
    95  	}
    96  }
    97  
    98  func Test(t *testing.T) {
    99  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   100  		SkipIfUnsupportedArch(t, c)
   101  		ip, port, err := c.Port(defaultPort)
   102  		if err != nil {
   103  			t.Fatal(err)
   104  		}
   105  
   106  		addr := msConnectionString(ip, port)
   107  		p := &SQLServer{}
   108  		d, err := p.Open(addr)
   109  		if err != nil {
   110  			t.Fatalf("%v", err)
   111  		}
   112  
   113  		defer func() {
   114  			if err := d.Close(); err != nil {
   115  				t.Error(err)
   116  			}
   117  		}()
   118  
   119  		dt.Test(t, d, []byte("SELECT 1"))
   120  	})
   121  }
   122  
   123  func TestMigrate(t *testing.T) {
   124  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   125  		SkipIfUnsupportedArch(t, c)
   126  		ip, port, err := c.Port(defaultPort)
   127  		if err != nil {
   128  			t.Fatal(err)
   129  		}
   130  
   131  		addr := msConnectionString(ip, port)
   132  		p := &SQLServer{}
   133  		d, err := p.Open(addr)
   134  		if err != nil {
   135  			t.Fatalf("%v", err)
   136  		}
   137  
   138  		defer func() {
   139  			if err := d.Close(); err != nil {
   140  				t.Error(err)
   141  			}
   142  		}()
   143  
   144  		m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "master", d)
   145  		if err != nil {
   146  			t.Fatal(err)
   147  		}
   148  		dt.TestMigrate(t, m)
   149  	})
   150  }
   151  
   152  func TestMultiStatement(t *testing.T) {
   153  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   154  		SkipIfUnsupportedArch(t, c)
   155  		ip, port, err := c.Port(defaultPort)
   156  		if err != nil {
   157  			t.Fatal(err)
   158  		}
   159  
   160  		addr := msConnectionString(ip, port)
   161  		ms := &SQLServer{}
   162  		d, err := ms.Open(addr)
   163  		if err != nil {
   164  			t.Fatal(err)
   165  		}
   166  		defer func() {
   167  			if err := d.Close(); err != nil {
   168  				t.Error(err)
   169  			}
   170  		}()
   171  		if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLE bar (bar text);")); err != nil {
   172  			t.Fatalf("expected err to be nil, got %v", err)
   173  		}
   174  
   175  		// make sure second table exists
   176  		var exists int
   177  		if err := d.(*SQLServer).conn.QueryRowContext(context.Background(), "SELECT COUNT(1) FROM information_schema.tables WHERE table_name = 'bar' AND table_schema = (SELECT schema_name()) AND table_catalog = (SELECT db_name())").Scan(&exists); err != nil {
   178  			t.Fatal(err)
   179  		}
   180  		if exists != 1 {
   181  			t.Fatalf("expected table bar to exist")
   182  		}
   183  	})
   184  }
   185  
   186  func TestErrorParsing(t *testing.T) {
   187  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   188  		SkipIfUnsupportedArch(t, c)
   189  		ip, port, err := c.Port(defaultPort)
   190  		if err != nil {
   191  			t.Fatal(err)
   192  		}
   193  
   194  		addr := msConnectionString(ip, port)
   195  
   196  		p := &SQLServer{}
   197  		d, err := p.Open(addr)
   198  		if err != nil {
   199  			t.Fatal(err)
   200  		}
   201  		defer func() {
   202  			if err := d.Close(); err != nil {
   203  				t.Error(err)
   204  			}
   205  		}()
   206  
   207  		wantErr := `migration failed: Unknown object type 'TABLEE' used in a CREATE, DROP, or ALTER statement. in line 1:` +
   208  			` CREATE TABLE foo (foo text); CREATE TABLEE bar (bar text); (details: mssql: Unknown object type ` +
   209  			`'TABLEE' used in a CREATE, DROP, or ALTER statement.)`
   210  		if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLEE bar (bar text);")); err == nil {
   211  			t.Fatal("expected err but got nil")
   212  		} else if err.Error() != wantErr {
   213  			t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error())
   214  		}
   215  	})
   216  }
   217  
   218  func TestLockWorks(t *testing.T) {
   219  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   220  		SkipIfUnsupportedArch(t, c)
   221  		ip, port, err := c.Port(defaultPort)
   222  		if err != nil {
   223  			t.Fatal(err)
   224  		}
   225  
   226  		addr := fmt.Sprintf("sqlserver://sa:%v@%v:%v?master", saPassword, ip, port)
   227  		p := &SQLServer{}
   228  		d, err := p.Open(addr)
   229  		if err != nil {
   230  			t.Fatalf("%v", err)
   231  		}
   232  		dt.Test(t, d, []byte("SELECT 1"))
   233  
   234  		ms := d.(*SQLServer)
   235  
   236  		err = ms.Lock()
   237  		if err != nil {
   238  			t.Fatal(err)
   239  		}
   240  		err = ms.Unlock()
   241  		if err != nil {
   242  			t.Fatal(err)
   243  		}
   244  
   245  		// make sure the 2nd lock works (RELEASE_LOCK is very finicky)
   246  		err = ms.Lock()
   247  		if err != nil {
   248  			t.Fatal(err)
   249  		}
   250  		err = ms.Unlock()
   251  		if err != nil {
   252  			t.Fatal(err)
   253  		}
   254  	})
   255  }
   256  
   257  func TestMsiTrue(t *testing.T) {
   258  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   259  		SkipIfUnsupportedArch(t, c)
   260  		ip, port, err := c.Port(defaultPort)
   261  		if err != nil {
   262  			t.Fatal(err)
   263  		}
   264  
   265  		addr := msConnectionStringMsi(ip, port, true)
   266  		p := &SQLServer{}
   267  		_, err = p.Open(addr)
   268  		if err == nil {
   269  			t.Fatal("MSI should fail when not running in an Azure context.")
   270  		}
   271  	})
   272  }
   273  
   274  func TestOpenWithPasswordAndMSI(t *testing.T) {
   275  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   276  		SkipIfUnsupportedArch(t, c)
   277  		ip, port, err := c.Port(defaultPort)
   278  		if err != nil {
   279  			t.Fatal(err)
   280  		}
   281  
   282  		addr := msConnectionStringMsiWithPassword(ip, port, true)
   283  		p := &SQLServer{}
   284  		_, err = p.Open(addr)
   285  		if err == nil {
   286  			t.Fatal("Open should fail when both password and useMsi=true are passed.")
   287  		}
   288  
   289  		addr = msConnectionStringMsiWithPassword(ip, port, false)
   290  		p = &SQLServer{}
   291  		d, err := p.Open(addr)
   292  		if err != nil {
   293  			t.Fatal(err)
   294  		}
   295  
   296  		defer func() {
   297  			if err := d.Close(); err != nil {
   298  				t.Error(err)
   299  			}
   300  		}()
   301  
   302  		dt.Test(t, d, []byte("SELECT 1"))
   303  	})
   304  }
   305  
   306  func TestMsiFalse(t *testing.T) {
   307  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   308  		SkipIfUnsupportedArch(t, c)
   309  		ip, port, err := c.Port(defaultPort)
   310  		if err != nil {
   311  			t.Fatal(err)
   312  		}
   313  
   314  		addr := msConnectionStringMsi(ip, port, false)
   315  		p := &SQLServer{}
   316  		_, err = p.Open(addr)
   317  		if err == nil {
   318  			t.Fatal("Open should fail since no password was passed and useMsi is false.")
   319  		}
   320  	})
   321  }