github.com/meetsoni15/go-migrate/v4@v4.15.3-0.20221220054613-2c40bd0c4ee9/database/sqlserver/sqlserver_test.go (about)

     1  package sqlserver
     2  
import (
	"context"
	"database/sql"
	sqldriver "database/sql/driver"
	"errors"
	"fmt"
	"log"
	"runtime"
	"strings"
	"testing"
	"time"

	"github.com/dhui/dktest"
	"github.com/docker/go-connections/nat"
	"github.com/meetsoni15/go-migrate/v4"

	dt "github.com/meetsoni15/go-migrate/v4/database/testing"
	"github.com/meetsoni15/go-migrate/v4/dktesting"

	_ "github.com/meetsoni15/go-migrate/v4/source/file"
)
    23  
    24  const defaultPort = 1433
    25  const saPassword = "Root1234"
    26  
    27  var (
    28  	sqlEdgeOpts = dktest.Options{
    29  		Env: map[string]string{"ACCEPT_EULA": "Y", "MSSQL_SA_PASSWORD": saPassword},
    30  		PortBindings: map[nat.Port][]nat.PortBinding{
    31  			nat.Port(fmt.Sprintf("%d/tcp", defaultPort)): {
    32  				nat.PortBinding{
    33  					HostIP:   "0.0.0.0",
    34  					HostPort: "0/tcp",
    35  				},
    36  			},
    37  		},
    38  		PortRequired: true, ReadyFunc: isReady, PullTimeout: 2 * time.Minute,
    39  	}
    40  	sqlServerOpts = dktest.Options{
    41  		Env:          map[string]string{"ACCEPT_EULA": "Y", "MSSQL_SA_PASSWORD": saPassword, "MSSQL_PID": "Express"},
    42  		PortRequired: true, ReadyFunc: isReady, PullTimeout: 2 * time.Minute,
    43  	}
    44  	// Container versions: https://mcr.microsoft.com/v2/mssql/server/tags/list
    45  	specs = []dktesting.ContainerSpec{
    46  		{ImageName: "mcr.microsoft.com/azure-sql-edge:latest", Options: sqlEdgeOpts},
    47  		{ImageName: "mcr.microsoft.com/mssql/server:2017-latest", Options: sqlServerOpts},
    48  		{ImageName: "mcr.microsoft.com/mssql/server:2019-latest", Options: sqlServerOpts},
    49  	}
    50  )
    51  
    52  func msConnectionString(host, port string) string {
    53  	return fmt.Sprintf("sqlserver://sa:%v@%v:%v?database=master", saPassword, host, port)
    54  }
    55  
    56  func msConnectionStringMsiWithPassword(host, port string, useMsi bool) string {
    57  	return fmt.Sprintf("sqlserver://sa:%v@%v:%v?database=master&useMsi=%t", saPassword, host, port, useMsi)
    58  }
    59  
    60  func msConnectionStringMsi(host, port string, useMsi bool) string {
    61  	return fmt.Sprintf("sqlserver://sa@%v:%v?database=master&useMsi=%t", host, port, useMsi)
    62  }
    63  
    64  func isReady(ctx context.Context, c dktest.ContainerInfo) bool {
    65  	ip, port, err := c.Port(defaultPort)
    66  	if err != nil {
    67  		return false
    68  	}
    69  	uri := msConnectionString(ip, port)
    70  	db, err := sql.Open("sqlserver", uri)
    71  	if err != nil {
    72  		return false
    73  	}
    74  	defer func() {
    75  		if err := db.Close(); err != nil {
    76  			log.Println("close error:", err)
    77  		}
    78  	}()
    79  	if err = db.PingContext(ctx); err != nil {
    80  		switch err {
    81  		case sqldriver.ErrBadConn:
    82  			return false
    83  		default:
    84  			fmt.Println(err)
    85  		}
    86  		return false
    87  	}
    88  
    89  	return true
    90  }
    91  
    92  func SkipIfUnsupportedArch(t *testing.T, c dktest.ContainerInfo) {
    93  	if strings.Contains(c.ImageName, "mssql") && !strings.HasPrefix(runtime.GOARCH, "amd") {
    94  		t.Skip(fmt.Sprintf("Image %s is not supported on arch %s", c.ImageName, runtime.GOARCH))
    95  	}
    96  }
    97  
    98  func Test(t *testing.T) {
    99  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   100  		SkipIfUnsupportedArch(t, c)
   101  		ip, port, err := c.Port(defaultPort)
   102  		if err != nil {
   103  			t.Fatal(err)
   104  		}
   105  
   106  		addr := msConnectionString(ip, port)
   107  		p := &SQLServer{}
   108  		d, err := p.Open(addr)
   109  		if err != nil {
   110  			t.Fatalf("%v", err)
   111  		}
   112  
   113  		defer func() {
   114  			if err := d.Close(); err != nil {
   115  				t.Error(err)
   116  			}
   117  		}()
   118  
   119  		dt.Test(t, d, []byte("SELECT 1"))
   120  	})
   121  }
   122  
   123  func TestMigrate(t *testing.T) {
   124  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   125  		SkipIfUnsupportedArch(t, c)
   126  		ip, port, err := c.Port(defaultPort)
   127  		if err != nil {
   128  			t.Fatal(err)
   129  		}
   130  
   131  		addr := msConnectionString(ip, port)
   132  		p := &SQLServer{}
   133  		d, err := p.Open(addr)
   134  		if err != nil {
   135  			t.Fatalf("%v", err)
   136  		}
   137  
   138  		defer func() {
   139  			if err := d.Close(); err != nil {
   140  				t.Error(err)
   141  			}
   142  		}()
   143  
   144  		m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "master", d)
   145  		if err != nil {
   146  			t.Fatal(err)
   147  		}
   148  		dt.TestMigrate(t, m)
   149  	})
   150  }
   151  
   152  func TestMultiStatement(t *testing.T) {
   153  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   154  		SkipIfUnsupportedArch(t, c)
   155  		ip, port, err := c.Port(defaultPort)
   156  		if err != nil {
   157  			t.Fatal(err)
   158  		}
   159  
   160  		addr := msConnectionString(ip, port)
   161  		ms := &SQLServer{}
   162  		d, err := ms.Open(addr)
   163  		if err != nil {
   164  			t.Fatal(err)
   165  		}
   166  		defer func() {
   167  			if err := d.Close(); err != nil {
   168  				t.Error(err)
   169  			}
   170  		}()
   171  		if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLE bar (bar text);")); err != nil {
   172  			t.Fatalf("expected err to be nil, got %v", err)
   173  		}
   174  
   175  		// make sure second table exists
   176  		var exists int
   177  		if err := d.(*SQLServer).conn.QueryRowContext(context.Background(), "SELECT COUNT(1) FROM information_schema.tables WHERE table_name = 'bar' AND table_schema = (SELECT schema_name()) AND table_catalog = (SELECT db_name())").Scan(&exists); err != nil {
   178  			t.Fatal(err)
   179  		}
   180  		if exists != 1 {
   181  			t.Fatalf("expected table bar to exist")
   182  		}
   183  	})
   184  }
   185  
   186  func TestBatchedStatement(t *testing.T) {
   187  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   188  		ip, port, err := c.FirstPort()
   189  		if err != nil {
   190  			t.Fatal(err)
   191  		}
   192  
   193  		addr := msConnectionString(ip, port)
   194  		ms := &SQLServer{}
   195  		d, err := ms.Open(addr)
   196  		if err != nil {
   197  			t.Fatal(err)
   198  		}
   199  		defer func() {
   200  			if err := d.Close(); err != nil {
   201  				t.Error(err)
   202  			}
   203  		}()
   204  		if err := d.Run(strings.NewReader(`CREATE PROCEDURE uspA
   205  AS
   206  BEGIN
   207      SELECT 1;
   208  END;
   209  GO
   210  CREATE PROCEDURE uspB
   211  AS
   212  BEGIN
   213      SELECT 2;
   214  END`)); err != nil {
   215  			t.Fatalf("expected err to be nil, got %v", err)
   216  		}
   217  
   218  		// make sure second proc exists
   219  		var exists int
   220  		if err := d.(*SQLServer).conn.QueryRowContext(context.Background(), "Select COUNT(1) from sysobjects where type = 'P' and category = 0 and [NAME] = 'uspB'").Scan(&exists); err != nil {
   221  			t.Fatal(err)
   222  		}
   223  		if exists != 1 {
   224  			t.Fatalf("expected proc uspB to exist")
   225  		}
   226  	})
   227  }
   228  
   229  func TestErrorParsing(t *testing.T) {
   230  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   231  		SkipIfUnsupportedArch(t, c)
   232  		ip, port, err := c.Port(defaultPort)
   233  		if err != nil {
   234  			t.Fatal(err)
   235  		}
   236  
   237  		addr := msConnectionString(ip, port)
   238  
   239  		p := &SQLServer{}
   240  		d, err := p.Open(addr)
   241  		if err != nil {
   242  			t.Fatal(err)
   243  		}
   244  		defer func() {
   245  			if err := d.Close(); err != nil {
   246  				t.Error(err)
   247  			}
   248  		}()
   249  
   250  		wantErr := `migration failed: Unknown object type 'TABLEE' used in a CREATE, DROP, or ALTER statement. in line 1:` +
   251  			` CREATE TABLE foo (foo text); CREATE TABLEE bar (bar text); (details: mssql: Unknown object type ` +
   252  			`'TABLEE' used in a CREATE, DROP, or ALTER statement.)`
   253  		if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLEE bar (bar text);")); err == nil {
   254  			t.Fatal("expected err but got nil")
   255  		} else if err.Error() != wantErr {
   256  			t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error())
   257  		}
   258  	})
   259  }
   260  
   261  func TestLockWorks(t *testing.T) {
   262  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   263  		SkipIfUnsupportedArch(t, c)
   264  		ip, port, err := c.Port(defaultPort)
   265  		if err != nil {
   266  			t.Fatal(err)
   267  		}
   268  
   269  		addr := fmt.Sprintf("sqlserver://sa:%v@%v:%v?master", saPassword, ip, port)
   270  		p := &SQLServer{}
   271  		d, err := p.Open(addr)
   272  		if err != nil {
   273  			t.Fatalf("%v", err)
   274  		}
   275  		dt.Test(t, d, []byte("SELECT 1"))
   276  
   277  		ms := d.(*SQLServer)
   278  
   279  		err = ms.Lock()
   280  		if err != nil {
   281  			t.Fatal(err)
   282  		}
   283  		err = ms.Unlock()
   284  		if err != nil {
   285  			t.Fatal(err)
   286  		}
   287  
   288  		// make sure the 2nd lock works (RELEASE_LOCK is very finicky)
   289  		err = ms.Lock()
   290  		if err != nil {
   291  			t.Fatal(err)
   292  		}
   293  		err = ms.Unlock()
   294  		if err != nil {
   295  			t.Fatal(err)
   296  		}
   297  	})
   298  }
   299  
   300  func TestMsiTrue(t *testing.T) {
   301  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   302  		SkipIfUnsupportedArch(t, c)
   303  		ip, port, err := c.Port(defaultPort)
   304  		if err != nil {
   305  			t.Fatal(err)
   306  		}
   307  
   308  		addr := msConnectionStringMsi(ip, port, true)
   309  		p := &SQLServer{}
   310  		_, err = p.Open(addr)
   311  		if err == nil {
   312  			t.Fatal("MSI should fail when not running in an Azure context.")
   313  		}
   314  	})
   315  }
   316  
   317  func TestOpenWithPasswordAndMSI(t *testing.T) {
   318  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   319  		SkipIfUnsupportedArch(t, c)
   320  		ip, port, err := c.Port(defaultPort)
   321  		if err != nil {
   322  			t.Fatal(err)
   323  		}
   324  
   325  		addr := msConnectionStringMsiWithPassword(ip, port, true)
   326  		p := &SQLServer{}
   327  		_, err = p.Open(addr)
   328  		if err == nil {
   329  			t.Fatal("Open should fail when both password and useMsi=true are passed.")
   330  		}
   331  
   332  		addr = msConnectionStringMsiWithPassword(ip, port, false)
   333  		p = &SQLServer{}
   334  		d, err := p.Open(addr)
   335  		if err != nil {
   336  			t.Fatal(err)
   337  		}
   338  
   339  		defer func() {
   340  			if err := d.Close(); err != nil {
   341  				t.Error(err)
   342  			}
   343  		}()
   344  
   345  		dt.Test(t, d, []byte("SELECT 1"))
   346  	})
   347  }
   348  
   349  func TestMsiFalse(t *testing.T) {
   350  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   351  		SkipIfUnsupportedArch(t, c)
   352  		ip, port, err := c.Port(defaultPort)
   353  		if err != nil {
   354  			t.Fatal(err)
   355  		}
   356  
   357  		addr := msConnectionStringMsi(ip, port, false)
   358  		p := &SQLServer{}
   359  		_, err = p.Open(addr)
   360  		if err == nil {
   361  			t.Fatal("Open should fail since no password was passed and useMsi is false.")
   362  		}
   363  	})
   364  }