github.com/terramate-io/tf@v0.0.0-20230830114523-fce866b4dfcd/backend/remote-state/pg/backend_test.go (about)

     1  // Copyright (c) HashiCorp, Inc.
     2  // SPDX-License-Identifier: MPL-2.0
     3  
     4  package pg
     5  
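// To run the acceptance tests, first create the test database:
//
//	createdb terraform_backend_pg_test
//
// then run (set DATABASE_URL to use a different server or database):
//
//	TF_ACC=1 GO111MODULE=on go test -v -mod=vendor -timeout=2m -parallel=4 github.com/terramate-io/tf/backend/remote-state/pg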
     8  
     9  import (
    10  	"database/sql"
    11  	"fmt"
    12  	"net/url"
    13  	"os"
    14  	"strings"
    15  	"testing"
    16  
    17  	"github.com/hashicorp/hcl/v2/hcldec"
    18  	"github.com/lib/pq"
    19  	"github.com/terramate-io/tf/backend"
    20  	"github.com/terramate-io/tf/states/remote"
    21  	"github.com/terramate-io/tf/states/statemgr"
    22  	"github.com/terramate-io/tf/tfdiags"
    23  )
    24  
    25  // Function to skip a test unless in ACCeptance test mode.
    26  //
    27  // A running Postgres server identified by env variable
    28  // DATABASE_URL is required for acceptance tests.
    29  func testACC(t *testing.T) string {
    30  	skip := os.Getenv("TF_ACC") == ""
    31  	if skip {
    32  		t.Log("pg backend tests require setting TF_ACC")
    33  		t.Skip()
    34  	}
    35  	databaseUrl, found := os.LookupEnv("DATABASE_URL")
    36  	if !found {
    37  		databaseUrl = "postgres://localhost/terraform_backend_pg_test?sslmode=disable"
    38  		os.Setenv("DATABASE_URL", databaseUrl)
    39  	}
    40  	u, err := url.Parse(databaseUrl)
    41  	if err != nil {
    42  		t.Fatal(err)
    43  	}
    44  	return u.Path[1:]
    45  }
    46  
    47  func TestBackend_impl(t *testing.T) {
    48  	var _ backend.Backend = new(Backend)
    49  }
    50  
    51  func TestBackendConfig(t *testing.T) {
    52  	databaseName := testACC(t)
    53  	connStr := getDatabaseUrl()
    54  
    55  	testCases := []struct {
    56  		Name                     string
    57  		EnvVars                  map[string]string
    58  		Config                   map[string]interface{}
    59  		ExpectConfigurationError string
    60  		ExpectConnectionError    string
    61  	}{
    62  		{
    63  			Name: "valid-config",
    64  			Config: map[string]interface{}{
    65  				"conn_str":    connStr,
    66  				"schema_name": fmt.Sprintf("terraform_%s", t.Name()),
    67  			},
    68  		},
    69  		{
    70  			Name: "missing-conn_str-defaults-to-localhost",
    71  			EnvVars: map[string]string{
    72  				"PGSSLMODE":  "disable",
    73  				"PGDATABASE": databaseName,
    74  			},
    75  			Config: map[string]interface{}{
    76  				"schema_name": fmt.Sprintf("terraform_%s", t.Name()),
    77  			},
    78  		},
    79  		{
    80  			Name: "conn-str-env-var",
    81  			EnvVars: map[string]string{
    82  				"PG_CONN_STR": connStr,
    83  			},
    84  			Config: map[string]interface{}{
    85  				"schema_name": fmt.Sprintf("terraform_%s", t.Name()),
    86  			},
    87  		},
    88  		{
    89  			Name: "setting-credentials-using-env-vars",
    90  			EnvVars: map[string]string{
    91  				"PGUSER":     "baduser",
    92  				"PGPASSWORD": "badpassword",
    93  			},
    94  			Config: map[string]interface{}{
    95  				"conn_str":    connStr,
    96  				"schema_name": fmt.Sprintf("terraform_%s", t.Name()),
    97  			},
    98  			ExpectConnectionError: `role "baduser" does not exist`,
    99  		},
   100  		{
   101  			Name: "host-in-env-vars",
   102  			EnvVars: map[string]string{
   103  				"PGHOST": "hostthatdoesnotexist",
   104  			},
   105  			Config: map[string]interface{}{
   106  				"schema_name": fmt.Sprintf("terraform_%s", t.Name()),
   107  			},
   108  			ExpectConnectionError: `no such host`,
   109  		},
   110  		{
   111  			Name: "boolean-env-vars",
   112  			EnvVars: map[string]string{
   113  				"PGSSLMODE":               "disable",
   114  				"PG_SKIP_SCHEMA_CREATION": "f",
   115  				"PG_SKIP_TABLE_CREATION":  "f",
   116  				"PG_SKIP_INDEX_CREATION":  "f",
   117  				"PGDATABASE":              databaseName,
   118  			},
   119  			Config: map[string]interface{}{
   120  				"schema_name": fmt.Sprintf("terraform_%s", t.Name()),
   121  			},
   122  		},
   123  		{
   124  			Name: "wrong-boolean-env-vars",
   125  			EnvVars: map[string]string{
   126  				"PGSSLMODE":               "disable",
   127  				"PG_SKIP_SCHEMA_CREATION": "foo",
   128  				"PGDATABASE":              databaseName,
   129  			},
   130  			Config: map[string]interface{}{
   131  				"schema_name": fmt.Sprintf("terraform_%s", t.Name()),
   132  			},
   133  			ExpectConfigurationError: `error getting default for "skip_schema_creation"`,
   134  		},
   135  	}
   136  
   137  	for _, tc := range testCases {
   138  		t.Run(tc.Name, func(t *testing.T) {
   139  			for k, v := range tc.EnvVars {
   140  				t.Setenv(k, v)
   141  			}
   142  
   143  			config := backend.TestWrapConfig(tc.Config)
   144  			schemaName := pq.QuoteIdentifier(tc.Config["schema_name"].(string))
   145  
   146  			dbCleaner, err := sql.Open("postgres", connStr)
   147  			if err != nil {
   148  				t.Fatal(err)
   149  			}
   150  			defer dbCleaner.Query(fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", schemaName))
   151  
   152  			var diags tfdiags.Diagnostics
   153  			b := New().(*Backend)
   154  			schema := b.ConfigSchema()
   155  			spec := schema.DecoderSpec()
   156  			obj, decDiags := hcldec.Decode(config, spec, nil)
   157  			diags = diags.Append(decDiags)
   158  
   159  			newObj, valDiags := b.PrepareConfig(obj)
   160  			diags = diags.Append(valDiags.InConfigBody(config, ""))
   161  
   162  			if tc.ExpectConfigurationError != "" {
   163  				if !diags.HasErrors() {
   164  					t.Fatal("error expected but got none")
   165  				}
   166  				if !strings.Contains(diags.ErrWithWarnings().Error(), tc.ExpectConfigurationError) {
   167  					t.Fatalf("failed to find %q in %s", tc.ExpectConfigurationError, diags.ErrWithWarnings())
   168  				}
   169  				return
   170  			} else if diags.HasErrors() {
   171  				t.Fatal(diags.ErrWithWarnings())
   172  			}
   173  
   174  			obj = newObj
   175  
   176  			confDiags := b.Configure(obj)
   177  			if tc.ExpectConnectionError != "" {
   178  				err := confDiags.InConfigBody(config, "").ErrWithWarnings()
   179  				if err == nil {
   180  					t.Fatal("error expected but got none")
   181  				}
   182  				if !strings.Contains(err.Error(), tc.ExpectConnectionError) {
   183  					t.Fatalf("failed to find %q in %s", tc.ExpectConnectionError, err)
   184  				}
   185  				return
   186  			} else if len(confDiags) != 0 {
   187  				confDiags = confDiags.InConfigBody(config, "")
   188  				t.Fatal(confDiags.ErrWithWarnings())
   189  			}
   190  
   191  			if b == nil {
   192  				t.Fatal("Backend could not be configured")
   193  			}
   194  
   195  			_, err = b.db.Query(fmt.Sprintf("SELECT name, data FROM %s.%s LIMIT 1", schemaName, statesTableName))
   196  			if err != nil {
   197  				t.Fatal(err)
   198  			}
   199  
   200  			_, err = b.StateMgr(backend.DefaultStateName)
   201  			if err != nil {
   202  				t.Fatal(err)
   203  			}
   204  
   205  			s, err := b.StateMgr(backend.DefaultStateName)
   206  			if err != nil {
   207  				t.Fatal(err)
   208  			}
   209  			c := s.(*remote.State).Client.(*RemoteClient)
   210  			if c.Name != backend.DefaultStateName {
   211  				t.Fatal("RemoteClient name is not configured")
   212  			}
   213  
   214  			backend.TestBackendStates(t, b)
   215  		})
   216  	}
   217  
   218  }
   219  
   220  func TestBackendConfigSkipOptions(t *testing.T) {
   221  	testACC(t)
   222  	connStr := getDatabaseUrl()
   223  
   224  	testCases := []struct {
   225  		Name               string
   226  		SkipSchemaCreation bool
   227  		SkipTableCreation  bool
   228  		SkipIndexCreation  bool
   229  		TestIndexIsPresent bool
   230  		Setup              func(t *testing.T, db *sql.DB, schemaName string)
   231  	}{
   232  		{
   233  			Name:               "skip_schema_creation",
   234  			SkipSchemaCreation: true,
   235  			TestIndexIsPresent: true,
   236  			Setup: func(t *testing.T, db *sql.DB, schemaName string) {
   237  				// create the schema as a prerequisites
   238  				_, err := db.Query(fmt.Sprintf(`CREATE SCHEMA IF NOT EXISTS %s`, schemaName))
   239  				if err != nil {
   240  					t.Fatal(err)
   241  				}
   242  			},
   243  		},
   244  		{
   245  			Name:               "skip_table_creation",
   246  			SkipTableCreation:  true,
   247  			TestIndexIsPresent: true,
   248  			Setup: func(t *testing.T, db *sql.DB, schemaName string) {
   249  				// since the table needs to be already created the schema must be too
   250  				_, err := db.Query(fmt.Sprintf(`CREATE SCHEMA %s`, schemaName))
   251  				if err != nil {
   252  					t.Fatal(err)
   253  				}
   254  				_, err = db.Query(fmt.Sprintf(`CREATE TABLE %s.%s (
   255  					id SERIAL PRIMARY KEY,
   256  					name TEXT,
   257  					data TEXT
   258  					)`, schemaName, statesTableName))
   259  				if err != nil {
   260  					t.Fatal(err)
   261  				}
   262  			},
   263  		},
   264  		{
   265  			Name:               "skip_index_creation",
   266  			SkipIndexCreation:  true,
   267  			TestIndexIsPresent: true,
   268  			Setup: func(t *testing.T, db *sql.DB, schemaName string) {
   269  				// Everything need to exists for the index to be created
   270  				_, err := db.Query(fmt.Sprintf(`CREATE SCHEMA %s`, schemaName))
   271  				if err != nil {
   272  					t.Fatal(err)
   273  				}
   274  				_, err = db.Query(fmt.Sprintf(`CREATE TABLE %s.%s (
   275  					id SERIAL PRIMARY KEY,
   276  					name TEXT,
   277  					data TEXT
   278  					)`, schemaName, statesTableName))
   279  				if err != nil {
   280  					t.Fatal(err)
   281  				}
   282  				_, err = db.Exec(fmt.Sprintf(`CREATE UNIQUE INDEX IF NOT EXISTS %s ON %s.%s (name)`, statesIndexName, schemaName, statesTableName))
   283  				if err != nil {
   284  					t.Fatal(err)
   285  				}
   286  			},
   287  		},
   288  		{
   289  			Name:              "missing_index",
   290  			SkipIndexCreation: true,
   291  		},
   292  	}
   293  
   294  	for _, tc := range testCases {
   295  		t.Run(tc.Name, func(t *testing.T) {
   296  			schemaName := tc.Name
   297  
   298  			config := backend.TestWrapConfig(map[string]interface{}{
   299  				"conn_str":             connStr,
   300  				"schema_name":          schemaName,
   301  				"skip_schema_creation": tc.SkipSchemaCreation,
   302  				"skip_table_creation":  tc.SkipTableCreation,
   303  				"skip_index_creation":  tc.SkipIndexCreation,
   304  			})
   305  			schemaName = pq.QuoteIdentifier(schemaName)
   306  			db, err := sql.Open("postgres", connStr)
   307  			if err != nil {
   308  				t.Fatal(err)
   309  			}
   310  
   311  			if tc.Setup != nil {
   312  				tc.Setup(t, db, schemaName)
   313  			}
   314  			defer db.Query(fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", schemaName))
   315  
   316  			b := backend.TestBackendConfig(t, New(), config).(*Backend)
   317  
   318  			if b == nil {
   319  				t.Fatal("Backend could not be configured")
   320  			}
   321  
   322  			// Make sure everything has been created
   323  
   324  			// This tests that both the schema and the table have been created
   325  			_, err = b.db.Query(fmt.Sprintf("SELECT name, data FROM %s.%s LIMIT 1", schemaName, statesTableName))
   326  			if err != nil {
   327  				t.Fatal(err)
   328  			}
   329  			if tc.TestIndexIsPresent {
   330  				// Make sure that the index exists
   331  				query := `select count(*) from pg_indexes where schemaname=$1 and tablename=$2 and indexname=$3;`
   332  				var count int
   333  				if err := b.db.QueryRow(query, tc.Name, statesTableName, statesIndexName).Scan(&count); err != nil {
   334  					t.Fatal(err)
   335  				}
   336  				if count != 1 {
   337  					t.Fatalf("The index has not been created (%d)", count)
   338  				}
   339  			}
   340  
   341  			_, err = b.StateMgr(backend.DefaultStateName)
   342  			if err != nil {
   343  				t.Fatal(err)
   344  			}
   345  
   346  			s, err := b.StateMgr(backend.DefaultStateName)
   347  			if err != nil {
   348  				t.Fatal(err)
   349  			}
   350  			c := s.(*remote.State).Client.(*RemoteClient)
   351  			if c.Name != backend.DefaultStateName {
   352  				t.Fatal("RemoteClient name is not configured")
   353  			}
   354  
   355  			// Make sure that all workspace must have a unique name
   356  			_, err = db.Exec(fmt.Sprintf(`INSERT INTO %s.%s VALUES (100, 'unique_name_test', '')`, schemaName, statesTableName))
   357  			if err != nil {
   358  				t.Fatal(err)
   359  			}
   360  			_, err = db.Exec(fmt.Sprintf(`INSERT INTO %s.%s VALUES (101, 'unique_name_test', '')`, schemaName, statesTableName))
   361  			if err == nil {
   362  				t.Fatal("Creating two workspaces with the same name did not raise an error")
   363  			}
   364  		})
   365  	}
   366  
   367  }
   368  
   369  func TestBackendStates(t *testing.T) {
   370  	testACC(t)
   371  	connStr := getDatabaseUrl()
   372  
   373  	testCases := []string{
   374  		fmt.Sprintf("terraform_%s", t.Name()),
   375  		fmt.Sprintf("test with spaces: %s", t.Name()),
   376  	}
   377  	for _, schemaName := range testCases {
   378  		t.Run(schemaName, func(t *testing.T) {
   379  			dbCleaner, err := sql.Open("postgres", connStr)
   380  			if err != nil {
   381  				t.Fatal(err)
   382  			}
   383  			defer dbCleaner.Query(fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", pq.QuoteIdentifier(schemaName)))
   384  
   385  			config := backend.TestWrapConfig(map[string]interface{}{
   386  				"conn_str":    connStr,
   387  				"schema_name": schemaName,
   388  			})
   389  			b := backend.TestBackendConfig(t, New(), config).(*Backend)
   390  
   391  			if b == nil {
   392  				t.Fatal("Backend could not be configured")
   393  			}
   394  
   395  			backend.TestBackendStates(t, b)
   396  		})
   397  	}
   398  }
   399  
   400  func TestBackendStateLocks(t *testing.T) {
   401  	testACC(t)
   402  	connStr := getDatabaseUrl()
   403  	schemaName := fmt.Sprintf("terraform_%s", t.Name())
   404  	dbCleaner, err := sql.Open("postgres", connStr)
   405  	if err != nil {
   406  		t.Fatal(err)
   407  	}
   408  	defer dbCleaner.Query(fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", schemaName))
   409  
   410  	config := backend.TestWrapConfig(map[string]interface{}{
   411  		"conn_str":    connStr,
   412  		"schema_name": schemaName,
   413  	})
   414  	b := backend.TestBackendConfig(t, New(), config).(*Backend)
   415  
   416  	if b == nil {
   417  		t.Fatal("Backend could not be configured")
   418  	}
   419  
   420  	bb := backend.TestBackendConfig(t, New(), config).(*Backend)
   421  
   422  	if bb == nil {
   423  		t.Fatal("Backend could not be configured")
   424  	}
   425  
   426  	backend.TestBackendStateLocks(t, b, bb)
   427  }
   428  
   429  func TestBackendConcurrentLock(t *testing.T) {
   430  	testACC(t)
   431  	connStr := getDatabaseUrl()
   432  	dbCleaner, err := sql.Open("postgres", connStr)
   433  	if err != nil {
   434  		t.Fatal(err)
   435  	}
   436  
   437  	getStateMgr := func(schemaName string) (statemgr.Full, *statemgr.LockInfo) {
   438  		defer dbCleaner.Query(fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", schemaName))
   439  		config := backend.TestWrapConfig(map[string]interface{}{
   440  			"conn_str":    connStr,
   441  			"schema_name": schemaName,
   442  		})
   443  		b := backend.TestBackendConfig(t, New(), config).(*Backend)
   444  
   445  		if b == nil {
   446  			t.Fatal("Backend could not be configured")
   447  		}
   448  		stateMgr, err := b.StateMgr(backend.DefaultStateName)
   449  		if err != nil {
   450  			t.Fatalf("Failed to get the state manager: %v", err)
   451  		}
   452  
   453  		info := statemgr.NewLockInfo()
   454  		info.Operation = "test"
   455  		info.Who = schemaName
   456  
   457  		return stateMgr, info
   458  	}
   459  
   460  	s1, i1 := getStateMgr(fmt.Sprintf("terraform_%s_1", t.Name()))
   461  	s2, i2 := getStateMgr(fmt.Sprintf("terraform_%s_2", t.Name()))
   462  
   463  	// First we need to create the workspace as the lock for creating them is
   464  	// global
   465  	lockID1, err := s1.Lock(i1)
   466  	if err != nil {
   467  		t.Fatalf("failed to lock first state: %v", err)
   468  	}
   469  
   470  	if err = s1.PersistState(nil); err != nil {
   471  		t.Fatalf("failed to persist state: %v", err)
   472  	}
   473  
   474  	if err := s1.Unlock(lockID1); err != nil {
   475  		t.Fatalf("failed to unlock first state: %v", err)
   476  	}
   477  
   478  	lockID2, err := s2.Lock(i2)
   479  	if err != nil {
   480  		t.Fatalf("failed to lock second state: %v", err)
   481  	}
   482  
   483  	if err = s2.PersistState(nil); err != nil {
   484  		t.Fatalf("failed to persist state: %v", err)
   485  	}
   486  
   487  	if err := s2.Unlock(lockID2); err != nil {
   488  		t.Fatalf("failed to unlock first state: %v", err)
   489  	}
   490  
   491  	// Now we can test concurrent lock
   492  	lockID1, err = s1.Lock(i1)
   493  	if err != nil {
   494  		t.Fatalf("failed to lock first state: %v", err)
   495  	}
   496  
   497  	lockID2, err = s2.Lock(i2)
   498  	if err != nil {
   499  		t.Fatalf("failed to lock second state: %v", err)
   500  	}
   501  
   502  	if err := s1.Unlock(lockID1); err != nil {
   503  		t.Fatalf("failed to unlock first state: %v", err)
   504  	}
   505  
   506  	if err := s2.Unlock(lockID2); err != nil {
   507  		t.Fatalf("failed to unlock first state: %v", err)
   508  	}
   509  }
   510  
   511  func getDatabaseUrl() string {
   512  	return os.Getenv("DATABASE_URL")
   513  }