github.com/graywolf-at-work-2/terraform-vendor@v1.4.5/internal/backend/remote-state/http/server_test.go (about)

     1  package http
     2  
     3  //go:generate go run github.com/golang/mock/mockgen -package $GOPACKAGE -source $GOFILE -destination mock_$GOFILE
     4  
import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"os"
	"os/signal"
	"path/filepath"
	"reflect"
	"strings"
	"sync"
	"syscall"
	"testing"

	"github.com/golang/mock/gomock"
	"github.com/hashicorp/terraform/internal/addrs"
	"github.com/hashicorp/terraform/internal/backend"
	"github.com/hashicorp/terraform/internal/configs"
	"github.com/hashicorp/terraform/internal/states"
	"github.com/zclconf/go-cty/cty"
)
    29  
// sampleState is the JSON state document that newHttpServer stores under the
// "sample" resource key before serving any request. It is a version-4 state
// with serial 0 that declares no outputs or resources, only a "remote" section.
// The text (including the leading newline) is served verbatim by the GET handler.
const sampleState = `
{
    "version": 4,
    "serial": 0,
    "lineage": "666f9301-7e65-4b19-ae23-71184bb19b03",
    "remote": {
        "type": "http",
        "config": {
            "path": "local-state.tfstate"
        }
    }
}
`
    43  
    44  type (
    45  	HttpServerCallback interface {
    46  		StateGET(req *http.Request)
    47  		StatePOST(req *http.Request)
    48  		StateDELETE(req *http.Request)
    49  		StateLOCK(req *http.Request)
    50  		StateUNLOCK(req *http.Request)
    51  	}
    52  	httpServer struct {
    53  		r     *http.ServeMux
    54  		data  map[string]string
    55  		locks map[string]string
    56  		lock  sync.RWMutex
    57  
    58  		httpServerCallback HttpServerCallback
    59  	}
    60  	httpServerOpt func(*httpServer)
    61  )
    62  
    63  func withHttpServerCallback(callback HttpServerCallback) httpServerOpt {
    64  	return func(s *httpServer) {
    65  		s.httpServerCallback = callback
    66  	}
    67  }
    68  
    69  func newHttpServer(opts ...httpServerOpt) *httpServer {
    70  	r := http.NewServeMux()
    71  	s := &httpServer{
    72  		r:     r,
    73  		data:  make(map[string]string),
    74  		locks: make(map[string]string),
    75  	}
    76  	for _, opt := range opts {
    77  		opt(s)
    78  	}
    79  	s.data["sample"] = sampleState
    80  	r.HandleFunc("/state/", s.handleState)
    81  	return s
    82  }
    83  
    84  func (h *httpServer) getResource(req *http.Request) string {
    85  	switch pathParts := strings.SplitN(req.URL.Path, string(filepath.Separator), 3); len(pathParts) {
    86  	case 3:
    87  		return pathParts[2]
    88  	default:
    89  		return ""
    90  	}
    91  }
    92  
    93  func (h *httpServer) handleState(writer http.ResponseWriter, req *http.Request) {
    94  	switch req.Method {
    95  	case "GET":
    96  		h.handleStateGET(writer, req)
    97  	case "POST":
    98  		h.handleStatePOST(writer, req)
    99  	case "DELETE":
   100  		h.handleStateDELETE(writer, req)
   101  	case "LOCK":
   102  		h.handleStateLOCK(writer, req)
   103  	case "UNLOCK":
   104  		h.handleStateUNLOCK(writer, req)
   105  	}
   106  }
   107  
   108  func (h *httpServer) handleStateGET(writer http.ResponseWriter, req *http.Request) {
   109  	if h.httpServerCallback != nil {
   110  		defer h.httpServerCallback.StateGET(req)
   111  	}
   112  	resource := h.getResource(req)
   113  
   114  	h.lock.RLock()
   115  	defer h.lock.RUnlock()
   116  
   117  	if state, ok := h.data[resource]; ok {
   118  		_, _ = io.WriteString(writer, state)
   119  	} else {
   120  		writer.WriteHeader(http.StatusNotFound)
   121  	}
   122  }
   123  
   124  func (h *httpServer) handleStatePOST(writer http.ResponseWriter, req *http.Request) {
   125  	if h.httpServerCallback != nil {
   126  		defer h.httpServerCallback.StatePOST(req)
   127  	}
   128  	defer req.Body.Close()
   129  	resource := h.getResource(req)
   130  
   131  	data, err := io.ReadAll(req.Body)
   132  	if err != nil {
   133  		writer.WriteHeader(http.StatusBadRequest)
   134  		return
   135  	}
   136  
   137  	h.lock.Lock()
   138  	defer h.lock.Unlock()
   139  
   140  	h.data[resource] = string(data)
   141  	writer.WriteHeader(http.StatusOK)
   142  }
   143  
   144  func (h *httpServer) handleStateDELETE(writer http.ResponseWriter, req *http.Request) {
   145  	if h.httpServerCallback != nil {
   146  		defer h.httpServerCallback.StateDELETE(req)
   147  	}
   148  	resource := h.getResource(req)
   149  
   150  	h.lock.Lock()
   151  	defer h.lock.Unlock()
   152  
   153  	delete(h.data, resource)
   154  	writer.WriteHeader(http.StatusOK)
   155  }
   156  
   157  func (h *httpServer) handleStateLOCK(writer http.ResponseWriter, req *http.Request) {
   158  	if h.httpServerCallback != nil {
   159  		defer h.httpServerCallback.StateLOCK(req)
   160  	}
   161  	defer req.Body.Close()
   162  	resource := h.getResource(req)
   163  
   164  	data, err := io.ReadAll(req.Body)
   165  	if err != nil {
   166  		writer.WriteHeader(http.StatusBadRequest)
   167  		return
   168  	}
   169  
   170  	h.lock.Lock()
   171  	defer h.lock.Unlock()
   172  
   173  	if existingLock, ok := h.locks[resource]; ok {
   174  		writer.WriteHeader(http.StatusLocked)
   175  		_, _ = io.WriteString(writer, existingLock)
   176  	} else {
   177  		h.locks[resource] = string(data)
   178  		_, _ = io.WriteString(writer, existingLock)
   179  	}
   180  }
   181  
   182  func (h *httpServer) handleStateUNLOCK(writer http.ResponseWriter, req *http.Request) {
   183  	if h.httpServerCallback != nil {
   184  		defer h.httpServerCallback.StateUNLOCK(req)
   185  	}
   186  	defer req.Body.Close()
   187  	resource := h.getResource(req)
   188  
   189  	data, err := io.ReadAll(req.Body)
   190  	if err != nil {
   191  		writer.WriteHeader(http.StatusBadRequest)
   192  		return
   193  	}
   194  	var lockInfo map[string]interface{}
   195  	if err = json.Unmarshal(data, &lockInfo); err != nil {
   196  		writer.WriteHeader(http.StatusInternalServerError)
   197  		return
   198  	}
   199  
   200  	h.lock.Lock()
   201  	defer h.lock.Unlock()
   202  
   203  	if existingLock, ok := h.locks[resource]; ok {
   204  		var existingLockInfo map[string]interface{}
   205  		if err = json.Unmarshal([]byte(existingLock), &existingLockInfo); err != nil {
   206  			writer.WriteHeader(http.StatusInternalServerError)
   207  			return
   208  		}
   209  		lockID := lockInfo["ID"].(string)
   210  		existingID := existingLockInfo["ID"].(string)
   211  		if lockID != existingID {
   212  			writer.WriteHeader(http.StatusConflict)
   213  			_, _ = io.WriteString(writer, existingLock)
   214  		} else {
   215  			delete(h.locks, resource)
   216  			_, _ = io.WriteString(writer, existingLock)
   217  		}
   218  	} else {
   219  		writer.WriteHeader(http.StatusConflict)
   220  	}
   221  }
   222  
   223  func (h *httpServer) handler() http.Handler {
   224  	return h.r
   225  }
   226  
   227  func NewHttpTestServer(opts ...httpServerOpt) (*httptest.Server, error) {
   228  	clientCAData, err := os.ReadFile("testdata/certs/ca.cert.pem")
   229  	if err != nil {
   230  		return nil, err
   231  	}
   232  	clientCAs := x509.NewCertPool()
   233  	clientCAs.AppendCertsFromPEM(clientCAData)
   234  
   235  	cert, err := tls.LoadX509KeyPair("testdata/certs/server.crt", "testdata/certs/server.key")
   236  	if err != nil {
   237  		return nil, err
   238  	}
   239  
   240  	h := newHttpServer(opts...)
   241  	s := httptest.NewUnstartedServer(h.handler())
   242  	s.TLS = &tls.Config{
   243  		ClientAuth:   tls.RequireAndVerifyClientCert,
   244  		ClientCAs:    clientCAs,
   245  		Certificates: []tls.Certificate{cert},
   246  	}
   247  
   248  	s.StartTLS()
   249  	return s, nil
   250  }
   251  
   252  func TestMTLSServer_NoCertFails(t *testing.T) {
   253  	// Ensure that no calls are made to the server - everything is blocked by the tls.RequireAndVerifyClientCert
   254  	ctrl := gomock.NewController(t)
   255  	defer ctrl.Finish()
   256  	mockCallback := NewMockHttpServerCallback(ctrl)
   257  
   258  	// Fire up a test server
   259  	ts, err := NewHttpTestServer(withHttpServerCallback(mockCallback))
   260  	if err != nil {
   261  		t.Fatalf("unexpected error creating test server: %v", err)
   262  	}
   263  	defer ts.Close()
   264  
   265  	// Configure the backend to the pre-populated sample state
   266  	url := ts.URL + "/state/sample"
   267  	conf := map[string]cty.Value{
   268  		"address":                cty.StringVal(url),
   269  		"skip_cert_verification": cty.BoolVal(true),
   270  	}
   271  	b := backend.TestBackendConfig(t, New(), configs.SynthBody("synth", conf)).(*Backend)
   272  	if nil == b {
   273  		t.Fatal("nil backend")
   274  	}
   275  
   276  	// Now get a state manager and check that it fails to refresh the state
   277  	sm, err := b.StateMgr(backend.DefaultStateName)
   278  	if err != nil {
   279  		t.Fatalf("unexpected error fetching StateMgr with %s: %v", backend.DefaultStateName, err)
   280  	}
   281  	err = sm.RefreshState()
   282  	if nil == err {
   283  		t.Error("expected error when refreshing state without a client cert")
   284  	} else if !strings.Contains(err.Error(), "remote error: tls: bad certificate") {
   285  		t.Errorf("expected the error to report missing tls credentials: %v", err)
   286  	}
   287  }
   288  
   289  func TestMTLSServer_WithCertPasses(t *testing.T) {
   290  	// Ensure that the expected amount of calls is made to the server
   291  	ctrl := gomock.NewController(t)
   292  	defer ctrl.Finish()
   293  	mockCallback := NewMockHttpServerCallback(ctrl)
   294  
   295  	// Two or three (not testing the caching here) calls to GET
   296  	mockCallback.EXPECT().
   297  		StateGET(gomock.Any()).
   298  		MinTimes(2).
   299  		MaxTimes(3)
   300  	// One call to the POST to write the data
   301  	mockCallback.EXPECT().
   302  		StatePOST(gomock.Any())
   303  
   304  	// Fire up a test server
   305  	ts, err := NewHttpTestServer(withHttpServerCallback(mockCallback))
   306  	if err != nil {
   307  		t.Fatalf("unexpected error creating test server: %v", err)
   308  	}
   309  	defer ts.Close()
   310  
   311  	// Configure the backend to the pre-populated sample state, and with all the test certs lined up
   312  	url := ts.URL + "/state/sample"
   313  	caData, err := os.ReadFile("testdata/certs/ca.cert.pem")
   314  	if err != nil {
   315  		t.Fatalf("error reading ca certs: %v", err)
   316  	}
   317  	clientCertData, err := os.ReadFile("testdata/certs/client.crt")
   318  	if err != nil {
   319  		t.Fatalf("error reading client cert: %v", err)
   320  	}
   321  	clientKeyData, err := os.ReadFile("testdata/certs/client.key")
   322  	if err != nil {
   323  		t.Fatalf("error reading client key: %v", err)
   324  	}
   325  	conf := map[string]cty.Value{
   326  		"address":                   cty.StringVal(url),
   327  		"lock_address":              cty.StringVal(url),
   328  		"unlock_address":            cty.StringVal(url),
   329  		"client_ca_certificate_pem": cty.StringVal(string(caData)),
   330  		"client_certificate_pem":    cty.StringVal(string(clientCertData)),
   331  		"client_private_key_pem":    cty.StringVal(string(clientKeyData)),
   332  	}
   333  	b := backend.TestBackendConfig(t, New(), configs.SynthBody("synth", conf)).(*Backend)
   334  	if nil == b {
   335  		t.Fatal("nil backend")
   336  	}
   337  
   338  	// Now get a state manager, fetch the state, and ensure that the "foo" output is not set
   339  	sm, err := b.StateMgr(backend.DefaultStateName)
   340  	if err != nil {
   341  		t.Fatalf("unexpected error fetching StateMgr with %s: %v", backend.DefaultStateName, err)
   342  	}
   343  	if err = sm.RefreshState(); err != nil {
   344  		t.Fatalf("unexpected error calling RefreshState: %v", err)
   345  	}
   346  	state := sm.State()
   347  	if nil == state {
   348  		t.Fatal("nil state")
   349  	}
   350  	stateFoo := state.OutputValue(addrs.OutputValue{Name: "foo"}.Absolute(addrs.RootModuleInstance))
   351  	if stateFoo != nil {
   352  		t.Errorf("expected nil foo from state; got %v", stateFoo)
   353  	}
   354  
   355  	// Create a new state that has "foo" set to "bar" and ensure that state is as expected
   356  	state = states.BuildState(func(ss *states.SyncState) {
   357  		ss.SetOutputValue(
   358  			addrs.OutputValue{Name: "foo"}.Absolute(addrs.RootModuleInstance),
   359  			cty.StringVal("bar"),
   360  			false)
   361  	})
   362  	stateFoo = state.OutputValue(addrs.OutputValue{Name: "foo"}.Absolute(addrs.RootModuleInstance))
   363  	if nil == stateFoo {
   364  		t.Fatal("nil foo after building state with foo populated")
   365  	}
   366  	if foo := stateFoo.Value.AsString(); foo != "bar" {
   367  		t.Errorf("Expected built state foo value to be bar; got %s", foo)
   368  	}
   369  
   370  	// Ensure the change hasn't altered the current state manager state by checking "foo" and comparing states
   371  	curState := sm.State()
   372  	curStateFoo := curState.OutputValue(addrs.OutputValue{Name: "foo"}.Absolute(addrs.RootModuleInstance))
   373  	if curStateFoo != nil {
   374  		t.Errorf("expected session manager state to be unaltered and still nil, but got: %v", curStateFoo)
   375  	}
   376  	if reflect.DeepEqual(state, curState) {
   377  		t.Errorf("expected %v != %v; but they were equal", state, curState)
   378  	}
   379  
   380  	// Write the new state, persist, and refresh
   381  	if err = sm.WriteState(state); err != nil {
   382  		t.Errorf("error writing state: %v", err)
   383  	}
   384  	if err = sm.PersistState(nil); err != nil {
   385  		t.Errorf("error persisting state: %v", err)
   386  	}
   387  	if err = sm.RefreshState(); err != nil {
   388  		t.Errorf("error refreshing state: %v", err)
   389  	}
   390  
   391  	// Get the state again and verify that is now the same as state and has the "foo" value set to "bar"
   392  	curState = sm.State()
   393  	if !reflect.DeepEqual(state, curState) {
   394  		t.Errorf("expected %v == %v; but they were unequal", state, curState)
   395  	}
   396  	curStateFoo = curState.OutputValue(addrs.OutputValue{Name: "foo"}.Absolute(addrs.RootModuleInstance))
   397  	if nil == curStateFoo {
   398  		t.Fatal("nil foo")
   399  	}
   400  	if foo := curStateFoo.Value.AsString(); foo != "bar" {
   401  		t.Errorf("expected foo to be bar, but got: %s", foo)
   402  	}
   403  }
   404  
   405  // TestRunServer allows running the server for local debugging; it runs until ctl-c is received
   406  func TestRunServer(t *testing.T) {
   407  	if _, ok := os.LookupEnv("TEST_RUN_SERVER"); !ok {
   408  		t.Skip("TEST_RUN_SERVER not set")
   409  	}
   410  	s, err := NewHttpTestServer()
   411  	if err != nil {
   412  		t.Fatalf("unexpected error creating test server: %v", err)
   413  	}
   414  	defer s.Close()
   415  
   416  	t.Log(s.URL)
   417  
   418  	ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
   419  	defer cancel()
   420  	// wait until signal
   421  	<-ctx.Done()
   422  }