github.com/opentofu/opentofu@v1.7.1/internal/backend/remote-state/pg/backend_test.go (about)

     1  // Copyright (c) The OpenTofu Authors
     2  // SPDX-License-Identifier: MPL-2.0
     3  // Copyright (c) 2023 HashiCorp, Inc.
     4  // SPDX-License-Identifier: MPL-2.0
     5  
     6  package pg
     7  
     8  // Create the test database: createdb terraform_backend_pg_test
// TF_ACC=1 GO111MODULE=on go test -v -mod=vendor -timeout=2m -parallel=4 github.com/opentofu/opentofu/internal/backend/remote-state/pg
    10  
    11  import (
    12  	"database/sql"
    13  	"fmt"
    14  	"net/url"
    15  	"os"
    16  	"strings"
    17  	"testing"
    18  
    19  	"github.com/hashicorp/hcl/v2/hcldec"
    20  	"github.com/lib/pq"
    21  	"github.com/opentofu/opentofu/internal/backend"
    22  	"github.com/opentofu/opentofu/internal/encryption"
    23  	"github.com/opentofu/opentofu/internal/states/remote"
    24  	"github.com/opentofu/opentofu/internal/states/statemgr"
    25  	"github.com/opentofu/opentofu/internal/tfdiags"
    26  )
    27  
    28  // Function to skip a test unless in ACCeptance test mode.
    29  //
    30  // A running Postgres server identified by env variable
    31  // DATABASE_URL is required for acceptance tests.
    32  func testACC(t *testing.T) (connectionURI *url.URL) {
    33  	skip := os.Getenv("TF_ACC") == "" && os.Getenv("TF_PG_TEST") == ""
    34  	if skip {
    35  		t.Log("pg backend tests requires setting TF_ACC or TF_PG_TEST")
    36  		t.Skip()
    37  	}
    38  	databaseUrl, found := os.LookupEnv("DATABASE_URL")
    39  	if !found {
    40  		t.Fatal("pg backend tests require setting DATABASE_URL")
    41  	}
    42  
    43  	u, err := url.Parse(databaseUrl)
    44  	if err != nil {
    45  		t.Fatal(err)
    46  	}
    47  	return u
    48  }
    49  
    50  func TestBackend_impl(t *testing.T) {
    51  	var _ backend.Backend = new(Backend)
    52  }
    53  
    54  func TestBackendConfig(t *testing.T) {
    55  	connectionURI := testACC(t)
    56  	connStr := os.Getenv("DATABASE_URL")
    57  
    58  	user := connectionURI.User.Username()
    59  	password, _ := connectionURI.User.Password()
    60  	databaseName := connectionURI.Path[1:]
    61  
    62  	connectionURIObfuscated := connectionURI
    63  	connectionURIObfuscated.User = nil
    64  
    65  	testCases := []struct {
    66  		Name                     string
    67  		EnvVars                  map[string]string
    68  		Config                   map[string]interface{}
    69  		ExpectConfigurationError string
    70  		ExpectConnectionError    string
    71  	}{
    72  		{
    73  			Name: "valid-config",
    74  			Config: map[string]interface{}{
    75  				"conn_str":    connStr,
    76  				"schema_name": fmt.Sprintf("terraform_%s", t.Name()),
    77  			},
    78  		},
    79  		{
    80  			Name: "missing-conn_str-defaults-to-localhost",
    81  			EnvVars: map[string]string{
    82  				"PGSSLMODE":  "disable",
    83  				"PGDATABASE": databaseName,
    84  				"PGUSER":     user,
    85  				"PGPASSWORD": password,
    86  			},
    87  			Config: map[string]interface{}{
    88  				"schema_name": fmt.Sprintf("terraform_%s", t.Name()),
    89  			},
    90  		},
    91  		{
    92  			Name: "conn-str-env-var",
    93  			EnvVars: map[string]string{
    94  				"PG_CONN_STR": connStr,
    95  			},
    96  			Config: map[string]interface{}{
    97  				"schema_name": fmt.Sprintf("terraform_%s", t.Name()),
    98  			},
    99  		},
   100  		{
   101  			Name: "setting-credentials-using-env-vars",
   102  			EnvVars: map[string]string{
   103  				"PGUSER":     "baduser",
   104  				"PGPASSWORD": "badpassword",
   105  			},
   106  			Config: map[string]interface{}{
   107  				"conn_str":    connectionURIObfuscated.String(),
   108  				"schema_name": fmt.Sprintf("terraform_%s", t.Name()),
   109  			},
   110  			ExpectConnectionError: `authentication failed for user "baduser"`,
   111  		},
   112  		{
   113  			Name: "host-in-env-vars",
   114  			EnvVars: map[string]string{
   115  				"PGHOST": "hostthatdoesnotexist",
   116  			},
   117  			Config: map[string]interface{}{
   118  				"schema_name": fmt.Sprintf("terraform_%s", t.Name()),
   119  			},
   120  			ExpectConnectionError: `no such host`,
   121  		},
   122  		{
   123  			Name: "boolean-env-vars",
   124  			EnvVars: map[string]string{
   125  				"PGSSLMODE":               "disable",
   126  				"PG_SKIP_SCHEMA_CREATION": "f",
   127  				"PG_SKIP_TABLE_CREATION":  "f",
   128  				"PG_SKIP_INDEX_CREATION":  "f",
   129  				"PGDATABASE":              databaseName,
   130  			},
   131  			Config: map[string]interface{}{
   132  				"conn_str":    connStr,
   133  				"schema_name": fmt.Sprintf("terraform_%s", t.Name()),
   134  			},
   135  		},
   136  		{
   137  			Name: "wrong-boolean-env-vars",
   138  			EnvVars: map[string]string{
   139  				"PGSSLMODE":               "disable",
   140  				"PG_SKIP_SCHEMA_CREATION": "foo",
   141  				"PGDATABASE":              databaseName,
   142  			},
   143  			Config: map[string]interface{}{
   144  				"schema_name": fmt.Sprintf("terraform_%s", t.Name()),
   145  			},
   146  			ExpectConfigurationError: `error getting default for "skip_schema_creation"`,
   147  		},
   148  	}
   149  
   150  	for _, tc := range testCases {
   151  		t.Run(tc.Name, func(t *testing.T) {
   152  			for k, v := range tc.EnvVars {
   153  				t.Setenv(k, v)
   154  			}
   155  
   156  			config := backend.TestWrapConfig(tc.Config)
   157  			schemaName := pq.QuoteIdentifier(tc.Config["schema_name"].(string))
   158  
   159  			dbCleaner, err := sql.Open("postgres", connStr)
   160  			if err != nil {
   161  				t.Fatal(err)
   162  			}
   163  			defer dbCleaner.Query(fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", schemaName))
   164  
   165  			var diags tfdiags.Diagnostics
   166  			b := New(encryption.StateEncryptionDisabled()).(*Backend)
   167  			schema := b.ConfigSchema()
   168  			spec := schema.DecoderSpec()
   169  			obj, decDiags := hcldec.Decode(config, spec, nil)
   170  			diags = diags.Append(decDiags)
   171  
   172  			newObj, valDiags := b.PrepareConfig(obj)
   173  			diags = diags.Append(valDiags.InConfigBody(config, ""))
   174  
   175  			if tc.ExpectConfigurationError != "" {
   176  				if !diags.HasErrors() {
   177  					t.Fatal("error expected but got none")
   178  				}
   179  				if !strings.Contains(diags.ErrWithWarnings().Error(), tc.ExpectConfigurationError) {
   180  					t.Fatalf("failed to find %q in %s", tc.ExpectConfigurationError, diags.ErrWithWarnings())
   181  				}
   182  				return
   183  			} else if diags.HasErrors() {
   184  				t.Fatal(diags.ErrWithWarnings())
   185  			}
   186  
   187  			obj = newObj
   188  
   189  			confDiags := b.Configure(obj)
   190  			if tc.ExpectConnectionError != "" {
   191  				err := confDiags.InConfigBody(config, "").ErrWithWarnings()
   192  				if err == nil {
   193  					t.Fatal("error expected but got none")
   194  				}
   195  				if !strings.Contains(err.Error(), tc.ExpectConnectionError) {
   196  					t.Fatalf("failed to find %q in %s", tc.ExpectConnectionError, err)
   197  				}
   198  				return
   199  			} else if len(confDiags) != 0 {
   200  				confDiags = confDiags.InConfigBody(config, "")
   201  				t.Fatal(confDiags.ErrWithWarnings())
   202  			}
   203  
   204  			if b == nil {
   205  				t.Fatal("Backend could not be configured")
   206  			}
   207  
   208  			_, err = b.db.Query(fmt.Sprintf("SELECT name, data FROM %s.%s LIMIT 1", schemaName, statesTableName))
   209  			if err != nil {
   210  				t.Fatal(err)
   211  			}
   212  
   213  			_, err = b.StateMgr(backend.DefaultStateName)
   214  			if err != nil {
   215  				t.Fatal(err)
   216  			}
   217  
   218  			s, err := b.StateMgr(backend.DefaultStateName)
   219  			if err != nil {
   220  				t.Fatal(err)
   221  			}
   222  			c := s.(*remote.State).Client.(*RemoteClient)
   223  			if c.Name != backend.DefaultStateName {
   224  				t.Fatal("RemoteClient name is not configured")
   225  			}
   226  
   227  			backend.TestBackendStates(t, b)
   228  		})
   229  	}
   230  
   231  }
   232  
   233  func TestBackendConfigSkipOptions(t *testing.T) {
   234  	testACC(t)
   235  	connStr := getDatabaseUrl()
   236  
   237  	testCases := []struct {
   238  		Name               string
   239  		SkipSchemaCreation bool
   240  		SkipTableCreation  bool
   241  		SkipIndexCreation  bool
   242  		TestIndexIsPresent bool
   243  		Setup              func(t *testing.T, db *sql.DB, schemaName string)
   244  	}{
   245  		{
   246  			Name:               "skip_schema_creation",
   247  			SkipSchemaCreation: true,
   248  			TestIndexIsPresent: true,
   249  			Setup: func(t *testing.T, db *sql.DB, schemaName string) {
   250  				// create the schema as a prerequisites
   251  				_, err := db.Query(fmt.Sprintf(`CREATE SCHEMA IF NOT EXISTS %s`, schemaName))
   252  				if err != nil {
   253  					t.Fatal(err)
   254  				}
   255  			},
   256  		},
   257  		{
   258  			Name:               "skip_table_creation",
   259  			SkipTableCreation:  true,
   260  			TestIndexIsPresent: true,
   261  			Setup: func(t *testing.T, db *sql.DB, schemaName string) {
   262  				// since the table needs to be already created the schema must be too
   263  				_, err := db.Query(fmt.Sprintf(`CREATE SCHEMA %s`, schemaName))
   264  				if err != nil {
   265  					t.Fatal(err)
   266  				}
   267  				_, err = db.Query(fmt.Sprintf(`CREATE TABLE %s.%s (
   268  					id SERIAL PRIMARY KEY,
   269  					name TEXT,
   270  					data TEXT
   271  					)`, schemaName, statesTableName))
   272  				if err != nil {
   273  					t.Fatal(err)
   274  				}
   275  			},
   276  		},
   277  		{
   278  			Name:               "skip_index_creation",
   279  			SkipIndexCreation:  true,
   280  			TestIndexIsPresent: true,
   281  			Setup: func(t *testing.T, db *sql.DB, schemaName string) {
   282  				// Everything need to exists for the index to be created
   283  				_, err := db.Query(fmt.Sprintf(`CREATE SCHEMA %s`, schemaName))
   284  				if err != nil {
   285  					t.Fatal(err)
   286  				}
   287  				_, err = db.Query(fmt.Sprintf(`CREATE TABLE %s.%s (
   288  					id SERIAL PRIMARY KEY,
   289  					name TEXT,
   290  					data TEXT
   291  					)`, schemaName, statesTableName))
   292  				if err != nil {
   293  					t.Fatal(err)
   294  				}
   295  				_, err = db.Exec(fmt.Sprintf(`CREATE UNIQUE INDEX IF NOT EXISTS %s ON %s.%s (name)`, statesIndexName, schemaName, statesTableName))
   296  				if err != nil {
   297  					t.Fatal(err)
   298  				}
   299  			},
   300  		},
   301  		{
   302  			Name:              "missing_index",
   303  			SkipIndexCreation: true,
   304  		},
   305  	}
   306  
   307  	for _, tc := range testCases {
   308  		t.Run(tc.Name, func(t *testing.T) {
   309  			schemaName := tc.Name
   310  
   311  			config := backend.TestWrapConfig(map[string]interface{}{
   312  				"conn_str":             connStr,
   313  				"schema_name":          schemaName,
   314  				"skip_schema_creation": tc.SkipSchemaCreation,
   315  				"skip_table_creation":  tc.SkipTableCreation,
   316  				"skip_index_creation":  tc.SkipIndexCreation,
   317  			})
   318  			schemaName = pq.QuoteIdentifier(schemaName)
   319  			db, err := sql.Open("postgres", connStr)
   320  			if err != nil {
   321  				t.Fatal(err)
   322  			}
   323  
   324  			if tc.Setup != nil {
   325  				tc.Setup(t, db, schemaName)
   326  			}
   327  			defer db.Query(fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", schemaName))
   328  
   329  			b := backend.TestBackendConfig(t, New(encryption.StateEncryptionDisabled()), config).(*Backend)
   330  
   331  			if b == nil {
   332  				t.Fatal("Backend could not be configured")
   333  			}
   334  
   335  			// Make sure everything has been created
   336  
   337  			// This tests that both the schema and the table have been created
   338  			_, err = b.db.Query(fmt.Sprintf("SELECT name, data FROM %s.%s LIMIT 1", schemaName, statesTableName))
   339  			if err != nil {
   340  				t.Fatal(err)
   341  			}
   342  			if tc.TestIndexIsPresent {
   343  				// Make sure that the index exists
   344  				query := `select count(*) from pg_indexes where schemaname=$1 and tablename=$2 and indexname=$3;`
   345  				var count int
   346  				if err := b.db.QueryRow(query, tc.Name, statesTableName, statesIndexName).Scan(&count); err != nil {
   347  					t.Fatal(err)
   348  				}
   349  				if count != 1 {
   350  					t.Fatalf("The index has not been created (%d)", count)
   351  				}
   352  			}
   353  
   354  			_, err = b.StateMgr(backend.DefaultStateName)
   355  			if err != nil {
   356  				t.Fatal(err)
   357  			}
   358  
   359  			s, err := b.StateMgr(backend.DefaultStateName)
   360  			if err != nil {
   361  				t.Fatal(err)
   362  			}
   363  			c := s.(*remote.State).Client.(*RemoteClient)
   364  			if c.Name != backend.DefaultStateName {
   365  				t.Fatal("RemoteClient name is not configured")
   366  			}
   367  
   368  			// Make sure that all workspace must have a unique name
   369  			_, err = db.Exec(fmt.Sprintf(`INSERT INTO %s.%s VALUES (100, 'unique_name_test', '')`, schemaName, statesTableName))
   370  			if err != nil {
   371  				t.Fatal(err)
   372  			}
   373  			_, err = db.Exec(fmt.Sprintf(`INSERT INTO %s.%s VALUES (101, 'unique_name_test', '')`, schemaName, statesTableName))
   374  			if err == nil {
   375  				t.Fatal("Creating two workspaces with the same name did not raise an error")
   376  			}
   377  		})
   378  	}
   379  
   380  }
   381  
   382  func TestBackendStates(t *testing.T) {
   383  	testACC(t)
   384  	connStr := getDatabaseUrl()
   385  
   386  	testCases := []string{
   387  		fmt.Sprintf("terraform_%s", t.Name()),
   388  		fmt.Sprintf("test with spaces: %s", t.Name()),
   389  	}
   390  	for _, schemaName := range testCases {
   391  		t.Run(schemaName, func(t *testing.T) {
   392  			dbCleaner, err := sql.Open("postgres", connStr)
   393  			if err != nil {
   394  				t.Fatal(err)
   395  			}
   396  			defer dbCleaner.Query(fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", pq.QuoteIdentifier(schemaName)))
   397  
   398  			config := backend.TestWrapConfig(map[string]interface{}{
   399  				"conn_str":    connStr,
   400  				"schema_name": schemaName,
   401  			})
   402  			b := backend.TestBackendConfig(t, New(encryption.StateEncryptionDisabled()), config).(*Backend)
   403  
   404  			if b == nil {
   405  				t.Fatal("Backend could not be configured")
   406  			}
   407  
   408  			backend.TestBackendStates(t, b)
   409  		})
   410  	}
   411  }
   412  
   413  func TestBackendStateLocks(t *testing.T) {
   414  	testACC(t)
   415  	connStr := getDatabaseUrl()
   416  	schemaName := fmt.Sprintf("terraform_%s", t.Name())
   417  	dbCleaner, err := sql.Open("postgres", connStr)
   418  	if err != nil {
   419  		t.Fatal(err)
   420  	}
   421  	defer dbCleaner.Query(fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", schemaName))
   422  
   423  	config := backend.TestWrapConfig(map[string]interface{}{
   424  		"conn_str":    connStr,
   425  		"schema_name": schemaName,
   426  	})
   427  	b := backend.TestBackendConfig(t, New(encryption.StateEncryptionDisabled()), config).(*Backend)
   428  
   429  	if b == nil {
   430  		t.Fatal("Backend could not be configured")
   431  	}
   432  
   433  	bb := backend.TestBackendConfig(t, New(encryption.StateEncryptionDisabled()), config).(*Backend)
   434  
   435  	if bb == nil {
   436  		t.Fatal("Backend could not be configured")
   437  	}
   438  
   439  	backend.TestBackendStateLocks(t, b, bb)
   440  }
   441  
   442  func TestBackendConcurrentLock(t *testing.T) {
   443  	testACC(t)
   444  	connStr := getDatabaseUrl()
   445  	dbCleaner, err := sql.Open("postgres", connStr)
   446  	if err != nil {
   447  		t.Fatal(err)
   448  	}
   449  
   450  	getStateMgr := func(schemaName string) (statemgr.Full, *statemgr.LockInfo) {
   451  		defer dbCleaner.Query(fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", schemaName))
   452  		config := backend.TestWrapConfig(map[string]interface{}{
   453  			"conn_str":    connStr,
   454  			"schema_name": schemaName,
   455  		})
   456  		b := backend.TestBackendConfig(t, New(encryption.StateEncryptionDisabled()), config).(*Backend)
   457  
   458  		if b == nil {
   459  			t.Fatal("Backend could not be configured")
   460  		}
   461  		stateMgr, err := b.StateMgr(backend.DefaultStateName)
   462  		if err != nil {
   463  			t.Fatalf("Failed to get the state manager: %v", err)
   464  		}
   465  
   466  		info := statemgr.NewLockInfo()
   467  		info.Operation = "test"
   468  		info.Who = schemaName
   469  
   470  		return stateMgr, info
   471  	}
   472  
   473  	s1, i1 := getStateMgr(fmt.Sprintf("terraform_%s_1", t.Name()))
   474  	s2, i2 := getStateMgr(fmt.Sprintf("terraform_%s_2", t.Name()))
   475  
   476  	// First we need to create the workspace as the lock for creating them is
   477  	// global
   478  	lockID1, err := s1.Lock(i1)
   479  	if err != nil {
   480  		t.Fatalf("failed to lock first state: %v", err)
   481  	}
   482  
   483  	if err = s1.PersistState(nil); err != nil {
   484  		t.Fatalf("failed to persist state: %v", err)
   485  	}
   486  
   487  	if err := s1.Unlock(lockID1); err != nil {
   488  		t.Fatalf("failed to unlock first state: %v", err)
   489  	}
   490  
   491  	lockID2, err := s2.Lock(i2)
   492  	if err != nil {
   493  		t.Fatalf("failed to lock second state: %v", err)
   494  	}
   495  
   496  	if err = s2.PersistState(nil); err != nil {
   497  		t.Fatalf("failed to persist state: %v", err)
   498  	}
   499  
   500  	if err := s2.Unlock(lockID2); err != nil {
   501  		t.Fatalf("failed to unlock first state: %v", err)
   502  	}
   503  
   504  	// Now we can test concurrent lock
   505  	lockID1, err = s1.Lock(i1)
   506  	if err != nil {
   507  		t.Fatalf("failed to lock first state: %v", err)
   508  	}
   509  
   510  	lockID2, err = s2.Lock(i2)
   511  	if err != nil {
   512  		t.Fatalf("failed to lock second state: %v", err)
   513  	}
   514  
   515  	if err := s1.Unlock(lockID1); err != nil {
   516  		t.Fatalf("failed to unlock first state: %v", err)
   517  	}
   518  
   519  	if err := s2.Unlock(lockID2); err != nil {
   520  		t.Fatalf("failed to unlock first state: %v", err)
   521  	}
   522  }
   523  
   524  func getDatabaseUrl() string {
   525  	return os.Getenv("DATABASE_URL")
   526  }