github.com/rstandt/terraform@v0.12.32-0.20230710220336-b1063613405c/backend/atlas/state_client_test.go (about)

     1  package atlas
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/md5"
     7  	"crypto/tls"
     8  	"crypto/x509"
     9  	"encoding/json"
    10  	"errors"
    11  	"net/http"
    12  	"net/http/httptest"
    13  	"net/url"
    14  	"os"
    15  	"testing"
    16  	"time"
    17  
    18  	"github.com/zclconf/go-cty/cty"
    19  
    20  	"github.com/hashicorp/terraform/backend"
    21  	"github.com/hashicorp/terraform/configs"
    22  	"github.com/hashicorp/terraform/helper/acctest"
    23  	"github.com/hashicorp/terraform/state/remote"
    24  	"github.com/hashicorp/terraform/terraform"
    25  )
    26  
    27  func testStateClient(t *testing.T, c map[string]string) remote.Client {
    28  	vals := make(map[string]cty.Value)
    29  	for k, s := range c {
    30  		vals[k] = cty.StringVal(s)
    31  	}
    32  	synthBody := configs.SynthBody("<test>", vals)
    33  
    34  	b := backend.TestBackendConfig(t, New(), synthBody)
    35  	raw, err := b.StateMgr(backend.DefaultStateName)
    36  	if err != nil {
    37  		t.Fatalf("err: %s", err)
    38  	}
    39  
    40  	s := raw.(*remote.State)
    41  	return s.Client
    42  }
    43  
    44  func TestStateClient_impl(t *testing.T) {
    45  	var _ remote.Client = new(stateClient)
    46  }
    47  
    48  func TestStateClient(t *testing.T) {
    49  	acctest.RemoteTestPrecheck(t)
    50  
    51  	token := os.Getenv("ATLAS_TOKEN")
    52  	if token == "" {
    53  		t.Skipf("skipping, ATLAS_TOKEN must be set")
    54  	}
    55  
    56  	client := testStateClient(t, map[string]string{
    57  		"access_token": token,
    58  		"name":         "hashicorp/test-remote-state",
    59  	})
    60  
    61  	remote.TestClient(t, client)
    62  }
    63  
    64  func TestStateClient_noRetryOnBadCerts(t *testing.T) {
    65  	acctest.RemoteTestPrecheck(t)
    66  
    67  	client := testStateClient(t, map[string]string{
    68  		"access_token": "NOT_REQUIRED",
    69  		"name":         "hashicorp/test-remote-state",
    70  	})
    71  
    72  	ac := client.(*stateClient)
    73  	// trigger the StateClient to build the http client and assign HTTPClient
    74  	httpClient, err := ac.http()
    75  	if err != nil {
    76  		t.Fatal(err)
    77  	}
    78  
    79  	// remove the CA certs from the client
    80  	brokenCfg := &tls.Config{
    81  		RootCAs: new(x509.CertPool),
    82  	}
    83  	httpClient.HTTPClient.Transport.(*http.Transport).TLSClientConfig = brokenCfg
    84  
    85  	// Instrument CheckRetry to make sure we didn't retry
    86  	retries := 0
    87  	oldCheck := httpClient.CheckRetry
    88  	httpClient.CheckRetry = func(ctx context.Context, resp *http.Response, err error) (bool, error) {
    89  		if retries > 0 {
    90  			t.Fatal("retried after certificate error")
    91  		}
    92  		retries++
    93  		return oldCheck(ctx, resp, err)
    94  	}
    95  
    96  	_, err = client.Get()
    97  	if err != nil {
    98  		if err, ok := err.(*url.Error); ok {
    99  			if _, ok := err.Err.(x509.UnknownAuthorityError); ok {
   100  				return
   101  			}
   102  		}
   103  	}
   104  
   105  	t.Fatalf("expected x509.UnknownAuthorityError, got %v", err)
   106  }
   107  
   108  func TestStateClient_ReportedConflictEqualStates(t *testing.T) {
   109  	fakeAtlas := newFakeAtlas(t, testStateModuleOrderChange)
   110  	srv := fakeAtlas.Server()
   111  	defer srv.Close()
   112  
   113  	client := testStateClient(t, map[string]string{
   114  		"access_token": "sometoken",
   115  		"name":         "someuser/some-test-remote-state",
   116  		"address":      srv.URL,
   117  	})
   118  
   119  	state, err := terraform.ReadState(bytes.NewReader(testStateModuleOrderChange))
   120  	if err != nil {
   121  		t.Fatalf("err: %s", err)
   122  	}
   123  
   124  	var stateJson bytes.Buffer
   125  	if err := terraform.WriteState(state, &stateJson); err != nil {
   126  		t.Fatalf("err: %s", err)
   127  	}
   128  	if err := client.Put(stateJson.Bytes()); err != nil {
   129  		t.Fatalf("err: %s", err)
   130  	}
   131  }
   132  
   133  func TestStateClient_NoConflict(t *testing.T) {
   134  	fakeAtlas := newFakeAtlas(t, testStateSimple)
   135  	srv := fakeAtlas.Server()
   136  	defer srv.Close()
   137  
   138  	client := testStateClient(t, map[string]string{
   139  		"access_token": "sometoken",
   140  		"name":         "someuser/some-test-remote-state",
   141  		"address":      srv.URL,
   142  	})
   143  
   144  	state, err := terraform.ReadState(bytes.NewReader(testStateSimple))
   145  	if err != nil {
   146  		t.Fatalf("err: %s", err)
   147  	}
   148  
   149  	fakeAtlas.NoConflictAllowed(true)
   150  
   151  	var stateJson bytes.Buffer
   152  	if err := terraform.WriteState(state, &stateJson); err != nil {
   153  		t.Fatalf("err: %s", err)
   154  	}
   155  
   156  	if err := client.Put(stateJson.Bytes()); err != nil {
   157  		t.Fatalf("err: %s", err)
   158  	}
   159  }
   160  
   161  func TestStateClient_LegitimateConflict(t *testing.T) {
   162  	fakeAtlas := newFakeAtlas(t, testStateSimple)
   163  	srv := fakeAtlas.Server()
   164  	defer srv.Close()
   165  
   166  	client := testStateClient(t, map[string]string{
   167  		"access_token": "sometoken",
   168  		"name":         "someuser/some-test-remote-state",
   169  		"address":      srv.URL,
   170  	})
   171  
   172  	state, err := terraform.ReadState(bytes.NewReader(testStateSimple))
   173  	if err != nil {
   174  		t.Fatalf("err: %s", err)
   175  	}
   176  
   177  	var buf bytes.Buffer
   178  	terraform.WriteState(state, &buf)
   179  
   180  	// Changing the state but not the serial. Should generate a conflict.
   181  	state.RootModule().Outputs["drift"] = &terraform.OutputState{
   182  		Type:      "string",
   183  		Sensitive: false,
   184  		Value:     "happens",
   185  	}
   186  
   187  	var stateJson bytes.Buffer
   188  	if err := terraform.WriteState(state, &stateJson); err != nil {
   189  		t.Fatalf("err: %s", err)
   190  	}
   191  	if err := client.Put(stateJson.Bytes()); err == nil {
   192  		t.Fatal("Expected error from state conflict, got none.")
   193  	}
   194  }
   195  
   196  func TestStateClient_UnresolvableConflict(t *testing.T) {
   197  	fakeAtlas := newFakeAtlas(t, testStateSimple)
   198  
   199  	// Something unexpected causes Atlas to conflict in a way that we can't fix.
   200  	fakeAtlas.AlwaysConflict(true)
   201  
   202  	srv := fakeAtlas.Server()
   203  	defer srv.Close()
   204  
   205  	client := testStateClient(t, map[string]string{
   206  		"access_token": "sometoken",
   207  		"name":         "someuser/some-test-remote-state",
   208  		"address":      srv.URL,
   209  	})
   210  
   211  	state, err := terraform.ReadState(bytes.NewReader(testStateSimple))
   212  	if err != nil {
   213  		t.Fatalf("err: %s", err)
   214  	}
   215  
   216  	var stateJson bytes.Buffer
   217  	if err := terraform.WriteState(state, &stateJson); err != nil {
   218  		t.Fatalf("err: %s", err)
   219  	}
   220  	errCh := make(chan error)
   221  	go func() {
   222  		defer close(errCh)
   223  		if err := client.Put(stateJson.Bytes()); err == nil {
   224  			errCh <- errors.New("expected error from state conflict, got none.")
   225  			return
   226  		}
   227  	}()
   228  
   229  	select {
   230  	case err := <-errCh:
   231  		if err != nil {
   232  			t.Fatalf("error from anonymous test goroutine: %s", err)
   233  		}
   234  	case <-time.After(500 * time.Millisecond):
   235  		t.Fatalf("Timed out after 500ms, probably because retrying infinitely.")
   236  	}
   237  }
   238  
// Stub Atlas HTTP API for a given state JSON string; does checksum-based
// conflict detection equivalent to Atlas's.
//
// A fakeAtlas stores one state document as raw bytes. It serves it on GET
// and replaces it on a non-conflicting PUT; see handler.
//
// NOTE(review): handler runs on httptest server goroutines and reads and
// writes these fields without any locking, while tests may also set the
// flags from the test goroutine. Keep flag changes before the requests that
// depend on them.
type fakeAtlas struct {
	// state is the raw JSON document currently stored by the fake; t is the
	// test that owns the fake and receives its failure reports.
	state []byte
	t     *testing.T

	// Used to test that we only do the special conflict handling retry once.
	// When true, every PUT is treated as a conflict regardless of content.
	alwaysConflict bool

	// Used to fail the test immediately if a conflict happens.
	noConflictAllowed bool
}
   251  
   252  func newFakeAtlas(t *testing.T, state []byte) *fakeAtlas {
   253  	return &fakeAtlas{
   254  		state: state,
   255  		t:     t,
   256  	}
   257  }
   258  
   259  func (f *fakeAtlas) Server() *httptest.Server {
   260  	return httptest.NewServer(http.HandlerFunc(f.handler))
   261  }
   262  
   263  func (f *fakeAtlas) CurrentState() *terraform.State {
   264  	// we read the state manually here, because terraform may alter state
   265  	// during read
   266  	currentState := &terraform.State{}
   267  	err := json.Unmarshal(f.state, currentState)
   268  	if err != nil {
   269  		f.t.Fatalf("err: %s", err)
   270  	}
   271  	return currentState
   272  }
   273  
   274  func (f *fakeAtlas) CurrentSerial() int64 {
   275  	return f.CurrentState().Serial
   276  }
   277  
   278  func (f *fakeAtlas) CurrentSum() [md5.Size]byte {
   279  	return md5.Sum(f.state)
   280  }
   281  
   282  func (f *fakeAtlas) AlwaysConflict(b bool) {
   283  	f.alwaysConflict = b
   284  }
   285  
   286  func (f *fakeAtlas) NoConflictAllowed(b bool) {
   287  	f.noConflictAllowed = b
   288  }
   289  
   290  func (f *fakeAtlas) handler(resp http.ResponseWriter, req *http.Request) {
   291  	// access tokens should only be sent as a header
   292  	if req.FormValue("access_token") != "" {
   293  		http.Error(resp, "access_token in request params", http.StatusBadRequest)
   294  		return
   295  	}
   296  
   297  	if req.Header.Get(atlasTokenHeader) == "" {
   298  		http.Error(resp, "missing access token", http.StatusBadRequest)
   299  		return
   300  	}
   301  
   302  	switch req.Method {
   303  	case "GET":
   304  		// Respond with the current stored state.
   305  		resp.Header().Set("Content-Type", "application/json")
   306  		resp.Write(f.state)
   307  	case "PUT":
   308  		var buf bytes.Buffer
   309  		buf.ReadFrom(req.Body)
   310  		sum := md5.Sum(buf.Bytes())
   311  
   312  		// we read the state manually here, because terraform may alter state
   313  		// during read
   314  		state := &terraform.State{}
   315  		err := json.Unmarshal(buf.Bytes(), state)
   316  		if err != nil {
   317  			f.t.Fatalf("err: %s", err)
   318  		}
   319  
   320  		conflict := f.CurrentSerial() == state.Serial && f.CurrentSum() != sum
   321  		conflict = conflict || f.alwaysConflict
   322  		if conflict {
   323  			if f.noConflictAllowed {
   324  				f.t.Fatal("Got conflict when NoConflictAllowed was set.")
   325  			}
   326  			http.Error(resp, "Conflict", 409)
   327  		} else {
   328  			f.state = buf.Bytes()
   329  			resp.WriteHeader(200)
   330  		}
   331  	}
   332  }
   333  
// testStateModuleOrderChange is a tfstate file with the module order
// changed: root.child2.grandchild is listed before root.child1.grandchild.
// This is a structural but not a semantic difference. Terraform will sort
// these modules as it loads the state.
//
// It has serial 1 and no lineage, and it seeds the fake server in
// TestStateClient_ReportedConflictEqualStates. The exact bytes matter,
// because fakeAtlas.handler compares MD5 sums of the stored and uploaded
// documents. Do not reformat this literal.
var testStateModuleOrderChange = []byte(
	`{
    "version": 3,
    "serial": 1,
    "modules": [
        {
            "path": [
                "root",
                "child2",
                "grandchild"
            ],
            "outputs": {
                "foo": {
                    "sensitive": false,
                    "type": "string",
                    "value": "bar"
                }
            },
            "resources": null
        },
        {
            "path": [
                "root",
                "child1",
                "grandchild"
            ],
            "outputs": {
                "foo": {
                    "sensitive": false,
                    "type": "string",
                    "value": "bar"
                }
            },
            "resources": null
        }
    ]
}
`)
   375  
// testStateSimple is a minimal version-3 state with serial 2 and a fixed
// lineage. It contains a single root module exposing one string output,
// "foo", and no resources.
//
// It seeds the fake server in the NoConflict, LegitimateConflict and
// UnresolvableConflict tests. The exact bytes matter, because
// fakeAtlas.handler compares MD5 sums of the stored and uploaded documents.
// Do not reformat this literal.
var testStateSimple = []byte(
	`{
    "version": 3,
    "serial": 2,
    "lineage": "c00ad9ac-9b35-42fe-846e-b06f0ef877e9",
    "modules": [
        {
            "path": [
                "root"
            ],
            "outputs": {
                "foo": {
                    "sensitive": false,
                    "type": "string",
                    "value": "bar"
                }
            },
            "resources": {},
            "depends_on": []
        }
    ]
}
`)