github.com/terramate-io/tf@v0.0.0-20230830114523-fce866b4dfcd/backend/remote-state/http/server_test.go (about)

     1  // Copyright (c) HashiCorp, Inc.
     2  // SPDX-License-Identifier: MPL-2.0
     3  
     4  package http
     5  
     6  //go:generate go run github.com/golang/mock/mockgen -package $GOPACKAGE -source $GOFILE -destination mock_$GOFILE
     7  
     8  import (
     9  	"context"
    10  	"crypto/tls"
    11  	"crypto/x509"
    12  	"encoding/json"
    13  	"io"
    14  	"net/http"
    15  	"net/http/httptest"
    16  	"os"
    17  	"os/signal"
    18  	"path/filepath"
    19  	"reflect"
    20  	"strings"
    21  	"sync"
    22  	"syscall"
    23  	"testing"
    24  
    25  	"github.com/golang/mock/gomock"
    26  	"github.com/terramate-io/tf/addrs"
    27  	"github.com/terramate-io/tf/backend"
    28  	"github.com/terramate-io/tf/configs"
    29  	"github.com/terramate-io/tf/states"
    30  	"github.com/zclconf/go-cty/cty"
    31  )
    32  
// sampleState is the raw version-4 state document that newHttpServer seeds
// under the "sample" resource. It declares no outputs, so a fresh read of
// "sample" has no "foo" output value.
const sampleState = `
{
    "version": 4,
    "serial": 0,
    "lineage": "666f9301-7e65-4b19-ae23-71184bb19b03",
    "remote": {
        "type": "http",
        "config": {
            "path": "local-state.tfstate"
        }
    }
}
`
    46  
    47  type (
    48  	HttpServerCallback interface {
    49  		StateGET(req *http.Request)
    50  		StatePOST(req *http.Request)
    51  		StateDELETE(req *http.Request)
    52  		StateLOCK(req *http.Request)
    53  		StateUNLOCK(req *http.Request)
    54  	}
    55  	httpServer struct {
    56  		r     *http.ServeMux
    57  		data  map[string]string
    58  		locks map[string]string
    59  		lock  sync.RWMutex
    60  
    61  		httpServerCallback HttpServerCallback
    62  	}
    63  	httpServerOpt func(*httpServer)
    64  )
    65  
    66  func withHttpServerCallback(callback HttpServerCallback) httpServerOpt {
    67  	return func(s *httpServer) {
    68  		s.httpServerCallback = callback
    69  	}
    70  }
    71  
    72  func newHttpServer(opts ...httpServerOpt) *httpServer {
    73  	r := http.NewServeMux()
    74  	s := &httpServer{
    75  		r:     r,
    76  		data:  make(map[string]string),
    77  		locks: make(map[string]string),
    78  	}
    79  	for _, opt := range opts {
    80  		opt(s)
    81  	}
    82  	s.data["sample"] = sampleState
    83  	r.HandleFunc("/state/", s.handleState)
    84  	return s
    85  }
    86  
    87  func (h *httpServer) getResource(req *http.Request) string {
    88  	switch pathParts := strings.SplitN(req.URL.Path, string(filepath.Separator), 3); len(pathParts) {
    89  	case 3:
    90  		return pathParts[2]
    91  	default:
    92  		return ""
    93  	}
    94  }
    95  
    96  func (h *httpServer) handleState(writer http.ResponseWriter, req *http.Request) {
    97  	switch req.Method {
    98  	case "GET":
    99  		h.handleStateGET(writer, req)
   100  	case "POST":
   101  		h.handleStatePOST(writer, req)
   102  	case "DELETE":
   103  		h.handleStateDELETE(writer, req)
   104  	case "LOCK":
   105  		h.handleStateLOCK(writer, req)
   106  	case "UNLOCK":
   107  		h.handleStateUNLOCK(writer, req)
   108  	}
   109  }
   110  
   111  func (h *httpServer) handleStateGET(writer http.ResponseWriter, req *http.Request) {
   112  	if h.httpServerCallback != nil {
   113  		defer h.httpServerCallback.StateGET(req)
   114  	}
   115  	resource := h.getResource(req)
   116  
   117  	h.lock.RLock()
   118  	defer h.lock.RUnlock()
   119  
   120  	if state, ok := h.data[resource]; ok {
   121  		_, _ = io.WriteString(writer, state)
   122  	} else {
   123  		writer.WriteHeader(http.StatusNotFound)
   124  	}
   125  }
   126  
   127  func (h *httpServer) handleStatePOST(writer http.ResponseWriter, req *http.Request) {
   128  	if h.httpServerCallback != nil {
   129  		defer h.httpServerCallback.StatePOST(req)
   130  	}
   131  	defer req.Body.Close()
   132  	resource := h.getResource(req)
   133  
   134  	data, err := io.ReadAll(req.Body)
   135  	if err != nil {
   136  		writer.WriteHeader(http.StatusBadRequest)
   137  		return
   138  	}
   139  
   140  	h.lock.Lock()
   141  	defer h.lock.Unlock()
   142  
   143  	h.data[resource] = string(data)
   144  	writer.WriteHeader(http.StatusOK)
   145  }
   146  
   147  func (h *httpServer) handleStateDELETE(writer http.ResponseWriter, req *http.Request) {
   148  	if h.httpServerCallback != nil {
   149  		defer h.httpServerCallback.StateDELETE(req)
   150  	}
   151  	resource := h.getResource(req)
   152  
   153  	h.lock.Lock()
   154  	defer h.lock.Unlock()
   155  
   156  	delete(h.data, resource)
   157  	writer.WriteHeader(http.StatusOK)
   158  }
   159  
   160  func (h *httpServer) handleStateLOCK(writer http.ResponseWriter, req *http.Request) {
   161  	if h.httpServerCallback != nil {
   162  		defer h.httpServerCallback.StateLOCK(req)
   163  	}
   164  	defer req.Body.Close()
   165  	resource := h.getResource(req)
   166  
   167  	data, err := io.ReadAll(req.Body)
   168  	if err != nil {
   169  		writer.WriteHeader(http.StatusBadRequest)
   170  		return
   171  	}
   172  
   173  	h.lock.Lock()
   174  	defer h.lock.Unlock()
   175  
   176  	if existingLock, ok := h.locks[resource]; ok {
   177  		writer.WriteHeader(http.StatusLocked)
   178  		_, _ = io.WriteString(writer, existingLock)
   179  	} else {
   180  		h.locks[resource] = string(data)
   181  		_, _ = io.WriteString(writer, existingLock)
   182  	}
   183  }
   184  
   185  func (h *httpServer) handleStateUNLOCK(writer http.ResponseWriter, req *http.Request) {
   186  	if h.httpServerCallback != nil {
   187  		defer h.httpServerCallback.StateUNLOCK(req)
   188  	}
   189  	defer req.Body.Close()
   190  	resource := h.getResource(req)
   191  
   192  	data, err := io.ReadAll(req.Body)
   193  	if err != nil {
   194  		writer.WriteHeader(http.StatusBadRequest)
   195  		return
   196  	}
   197  	var lockInfo map[string]interface{}
   198  	if err = json.Unmarshal(data, &lockInfo); err != nil {
   199  		writer.WriteHeader(http.StatusInternalServerError)
   200  		return
   201  	}
   202  
   203  	h.lock.Lock()
   204  	defer h.lock.Unlock()
   205  
   206  	if existingLock, ok := h.locks[resource]; ok {
   207  		var existingLockInfo map[string]interface{}
   208  		if err = json.Unmarshal([]byte(existingLock), &existingLockInfo); err != nil {
   209  			writer.WriteHeader(http.StatusInternalServerError)
   210  			return
   211  		}
   212  		lockID := lockInfo["ID"].(string)
   213  		existingID := existingLockInfo["ID"].(string)
   214  		if lockID != existingID {
   215  			writer.WriteHeader(http.StatusConflict)
   216  			_, _ = io.WriteString(writer, existingLock)
   217  		} else {
   218  			delete(h.locks, resource)
   219  			_, _ = io.WriteString(writer, existingLock)
   220  		}
   221  	} else {
   222  		writer.WriteHeader(http.StatusConflict)
   223  	}
   224  }
   225  
   226  func (h *httpServer) handler() http.Handler {
   227  	return h.r
   228  }
   229  
   230  func NewHttpTestServer(opts ...httpServerOpt) (*httptest.Server, error) {
   231  	clientCAData, err := os.ReadFile("testdata/certs/ca.cert.pem")
   232  	if err != nil {
   233  		return nil, err
   234  	}
   235  	clientCAs := x509.NewCertPool()
   236  	clientCAs.AppendCertsFromPEM(clientCAData)
   237  
   238  	cert, err := tls.LoadX509KeyPair("testdata/certs/server.crt", "testdata/certs/server.key")
   239  	if err != nil {
   240  		return nil, err
   241  	}
   242  
   243  	h := newHttpServer(opts...)
   244  	s := httptest.NewUnstartedServer(h.handler())
   245  	s.TLS = &tls.Config{
   246  		ClientAuth:   tls.RequireAndVerifyClientCert,
   247  		ClientCAs:    clientCAs,
   248  		Certificates: []tls.Certificate{cert},
   249  	}
   250  
   251  	s.StartTLS()
   252  	return s, nil
   253  }
   254  
   255  func TestMTLSServer_NoCertFails(t *testing.T) {
   256  	// Ensure that no calls are made to the server - everything is blocked by the tls.RequireAndVerifyClientCert
   257  	ctrl := gomock.NewController(t)
   258  	defer ctrl.Finish()
   259  	mockCallback := NewMockHttpServerCallback(ctrl)
   260  
   261  	// Fire up a test server
   262  	ts, err := NewHttpTestServer(withHttpServerCallback(mockCallback))
   263  	if err != nil {
   264  		t.Fatalf("unexpected error creating test server: %v", err)
   265  	}
   266  	defer ts.Close()
   267  
   268  	// Configure the backend to the pre-populated sample state
   269  	url := ts.URL + "/state/sample"
   270  	conf := map[string]cty.Value{
   271  		"address":                cty.StringVal(url),
   272  		"skip_cert_verification": cty.BoolVal(true),
   273  	}
   274  	b := backend.TestBackendConfig(t, New(), configs.SynthBody("synth", conf)).(*Backend)
   275  	if nil == b {
   276  		t.Fatal("nil backend")
   277  	}
   278  
   279  	// Now get a state manager and check that it fails to refresh the state
   280  	sm, err := b.StateMgr(backend.DefaultStateName)
   281  	if err != nil {
   282  		t.Fatalf("unexpected error fetching StateMgr with %s: %v", backend.DefaultStateName, err)
   283  	}
   284  	err = sm.RefreshState()
   285  	if nil == err {
   286  		t.Error("expected error when refreshing state without a client cert")
   287  	} else if !strings.Contains(err.Error(), "remote error: tls: bad certificate") {
   288  		t.Errorf("expected the error to report missing tls credentials: %v", err)
   289  	}
   290  }
   291  
// TestMTLSServer_WithCertPasses runs a read/modify/write round trip against the
// mTLS test server. The backend is configured with the CA, client certificate
// and client key from testdata/certs. The test checks that the "foo" output
// written through the state manager is read back.
func TestMTLSServer_WithCertPasses(t *testing.T) {
	// Ensure that the expected amount of calls is made to the server
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	mockCallback := NewMockHttpServerCallback(ctrl)

	// Two or three (not testing the caching here) calls to GET
	mockCallback.EXPECT().
		StateGET(gomock.Any()).
		MinTimes(2).
		MaxTimes(3)
	// One call to the POST to write the data
	mockCallback.EXPECT().
		StatePOST(gomock.Any())
	// No DELETE/LOCK/UNLOCK expectations are registered, so gomock reports any
	// such call as unexpected.

	// Fire up a test server
	ts, err := NewHttpTestServer(withHttpServerCallback(mockCallback))
	if err != nil {
		t.Fatalf("unexpected error creating test server: %v", err)
	}
	defer ts.Close()

	// Configure the backend to the pre-populated sample state, and with all the test certs lined up
	url := ts.URL + "/state/sample"
	caData, err := os.ReadFile("testdata/certs/ca.cert.pem")
	if err != nil {
		t.Fatalf("error reading ca certs: %v", err)
	}
	clientCertData, err := os.ReadFile("testdata/certs/client.crt")
	if err != nil {
		t.Fatalf("error reading client cert: %v", err)
	}
	clientKeyData, err := os.ReadFile("testdata/certs/client.key")
	if err != nil {
		t.Fatalf("error reading client key: %v", err)
	}
	conf := map[string]cty.Value{
		"address":                   cty.StringVal(url),
		"lock_address":              cty.StringVal(url),
		"unlock_address":            cty.StringVal(url),
		"client_ca_certificate_pem": cty.StringVal(string(caData)),
		"client_certificate_pem":    cty.StringVal(string(clientCertData)),
		"client_private_key_pem":    cty.StringVal(string(clientKeyData)),
	}
	b := backend.TestBackendConfig(t, New(), configs.SynthBody("synth", conf)).(*Backend)
	if nil == b {
		t.Fatal("nil backend")
	}

	// Now get a state manager, fetch the state, and ensure that the "foo" output is not set
	// (the seeded sampleState declares no outputs).
	sm, err := b.StateMgr(backend.DefaultStateName)
	if err != nil {
		t.Fatalf("unexpected error fetching StateMgr with %s: %v", backend.DefaultStateName, err)
	}
	if err = sm.RefreshState(); err != nil {
		t.Fatalf("unexpected error calling RefreshState: %v", err)
	}
	state := sm.State()
	if nil == state {
		t.Fatal("nil state")
	}
	stateFoo := state.OutputValue(addrs.OutputValue{Name: "foo"}.Absolute(addrs.RootModuleInstance))
	if stateFoo != nil {
		t.Errorf("expected nil foo from state; got %v", stateFoo)
	}

	// Create a new state that has "foo" set to "bar" and ensure that state is as expected
	state = states.BuildState(func(ss *states.SyncState) {
		ss.SetOutputValue(
			addrs.OutputValue{Name: "foo"}.Absolute(addrs.RootModuleInstance),
			cty.StringVal("bar"),
			false)
	})
	stateFoo = state.OutputValue(addrs.OutputValue{Name: "foo"}.Absolute(addrs.RootModuleInstance))
	if nil == stateFoo {
		t.Fatal("nil foo after building state with foo populated")
	}
	if foo := stateFoo.Value.AsString(); foo != "bar" {
		t.Errorf("Expected built state foo value to be bar; got %s", foo)
	}

	// Ensure the change hasn't altered the current state manager state by checking "foo" and comparing states
	curState := sm.State()
	curStateFoo := curState.OutputValue(addrs.OutputValue{Name: "foo"}.Absolute(addrs.RootModuleInstance))
	if curStateFoo != nil {
		t.Errorf("expected session manager state to be unaltered and still nil, but got: %v", curStateFoo)
	}
	if reflect.DeepEqual(state, curState) {
		t.Errorf("expected %v != %v; but they were equal", state, curState)
	}

	// Write the new state, persist, and refresh.
	// The order here matters: PersistState is expected to issue the single POST
	// and RefreshState another GET, which the mock call counts above account for.
	if err = sm.WriteState(state); err != nil {
		t.Errorf("error writing state: %v", err)
	}
	if err = sm.PersistState(nil); err != nil {
		t.Errorf("error persisting state: %v", err)
	}
	if err = sm.RefreshState(); err != nil {
		t.Errorf("error refreshing state: %v", err)
	}

	// Get the state again and verify that is now the same as state and has the "foo" value set to "bar"
	curState = sm.State()
	if !reflect.DeepEqual(state, curState) {
		t.Errorf("expected %v == %v; but they were unequal", state, curState)
	}
	curStateFoo = curState.OutputValue(addrs.OutputValue{Name: "foo"}.Absolute(addrs.RootModuleInstance))
	if nil == curStateFoo {
		t.Fatal("nil foo")
	}
	if foo := curStateFoo.Value.AsString(); foo != "bar" {
		t.Errorf("expected foo to be bar, but got: %s", foo)
	}
}
   407  
   408  // TestRunServer allows running the server for local debugging; it runs until ctl-c is received
   409  func TestRunServer(t *testing.T) {
   410  	if _, ok := os.LookupEnv("TEST_RUN_SERVER"); !ok {
   411  		t.Skip("TEST_RUN_SERVER not set")
   412  	}
   413  	s, err := NewHttpTestServer()
   414  	if err != nil {
   415  		t.Fatalf("unexpected error creating test server: %v", err)
   416  	}
   417  	defer s.Close()
   418  
   419  	t.Log(s.URL)
   420  
   421  	ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
   422  	defer cancel()
   423  	// wait until signal
   424  	<-ctx.Done()
   425  }