github.com/opentofu/opentofu@v1.7.1/internal/backend/remote-state/http/server_test.go (about)

     1  // Copyright (c) The OpenTofu Authors
     2  // SPDX-License-Identifier: MPL-2.0
     3  // Copyright (c) 2023 HashiCorp, Inc.
     4  // SPDX-License-Identifier: MPL-2.0
     5  
     6  package http
     7  
     8  //go:generate go run github.com/golang/mock/mockgen -package $GOPACKAGE -source $GOFILE -destination mock_$GOFILE
     9  
    10  import (
    11  	"context"
    12  	"crypto/tls"
    13  	"crypto/x509"
    14  	"encoding/json"
    15  	"errors"
    16  	"fmt"
    17  	"io"
    18  	"net"
    19  	"net/http"
    20  	"net/http/httptest"
    21  	"os"
    22  	"os/signal"
    23  	"path/filepath"
    24  	"reflect"
    25  	"strings"
    26  	"sync"
    27  	"syscall"
    28  	"testing"
    29  
    30  	"github.com/golang/mock/gomock"
    31  	"github.com/opentofu/opentofu/internal/addrs"
    32  	"github.com/opentofu/opentofu/internal/backend"
    33  	"github.com/opentofu/opentofu/internal/configs"
    34  	"github.com/opentofu/opentofu/internal/encryption"
    35  	"github.com/opentofu/opentofu/internal/states"
    36  	"github.com/zclconf/go-cty/cty"
    37  )
    38  
// sampleState is a minimal version-4 state document (serial 0, fixed lineage,
// an "http" remote section) used as a fixture: newHttpServer stores it under
// the "sample" resource key so tests have something to fetch.
const sampleState = `
{
    "version": 4,
    "serial": 0,
    "lineage": "666f9301-7e65-4b19-ae23-71184bb19b03",
    "remote": {
        "type": "http",
        "config": {
            "path": "local-state.tfstate"
        }
    }
}
`
    52  
    53  type (
    54  	HttpServerCallback interface {
    55  		StateGET(req *http.Request)
    56  		StatePOST(req *http.Request)
    57  		StateDELETE(req *http.Request)
    58  		StateLOCK(req *http.Request)
    59  		StateUNLOCK(req *http.Request)
    60  	}
    61  	httpServer struct {
    62  		r     *http.ServeMux
    63  		data  map[string]string
    64  		locks map[string]string
    65  		lock  sync.RWMutex
    66  
    67  		httpServerCallback HttpServerCallback
    68  	}
    69  	httpServerOpt func(*httpServer)
    70  )
    71  
    72  func withHttpServerCallback(callback HttpServerCallback) httpServerOpt {
    73  	return func(s *httpServer) {
    74  		s.httpServerCallback = callback
    75  	}
    76  }
    77  
    78  func newHttpServer(opts ...httpServerOpt) *httpServer {
    79  	r := http.NewServeMux()
    80  	s := &httpServer{
    81  		r:     r,
    82  		data:  make(map[string]string),
    83  		locks: make(map[string]string),
    84  	}
    85  	for _, opt := range opts {
    86  		opt(s)
    87  	}
    88  	s.data["sample"] = sampleState
    89  	r.HandleFunc("/state/", s.handleState)
    90  	return s
    91  }
    92  
    93  func (h *httpServer) getResource(req *http.Request) string {
    94  	switch pathParts := strings.SplitN(req.URL.Path, string(filepath.Separator), 3); len(pathParts) {
    95  	case 3:
    96  		return pathParts[2]
    97  	default:
    98  		return ""
    99  	}
   100  }
   101  
   102  func (h *httpServer) handleState(writer http.ResponseWriter, req *http.Request) {
   103  	switch req.Method {
   104  	case "GET":
   105  		h.handleStateGET(writer, req)
   106  	case "POST":
   107  		h.handleStatePOST(writer, req)
   108  	case "DELETE":
   109  		h.handleStateDELETE(writer, req)
   110  	case "LOCK":
   111  		h.handleStateLOCK(writer, req)
   112  	case "UNLOCK":
   113  		h.handleStateUNLOCK(writer, req)
   114  	}
   115  }
   116  
   117  func (h *httpServer) handleStateGET(writer http.ResponseWriter, req *http.Request) {
   118  	if h.httpServerCallback != nil {
   119  		defer h.httpServerCallback.StateGET(req)
   120  	}
   121  	resource := h.getResource(req)
   122  
   123  	h.lock.RLock()
   124  	defer h.lock.RUnlock()
   125  
   126  	if state, ok := h.data[resource]; ok {
   127  		_, _ = io.WriteString(writer, state)
   128  	} else {
   129  		writer.WriteHeader(http.StatusNotFound)
   130  	}
   131  }
   132  
   133  func (h *httpServer) handleStatePOST(writer http.ResponseWriter, req *http.Request) {
   134  	if h.httpServerCallback != nil {
   135  		defer h.httpServerCallback.StatePOST(req)
   136  	}
   137  	defer req.Body.Close()
   138  	resource := h.getResource(req)
   139  
   140  	data, err := io.ReadAll(req.Body)
   141  	if err != nil {
   142  		writer.WriteHeader(http.StatusBadRequest)
   143  		return
   144  	}
   145  
   146  	h.lock.Lock()
   147  	defer h.lock.Unlock()
   148  
   149  	h.data[resource] = string(data)
   150  	writer.WriteHeader(http.StatusOK)
   151  }
   152  
   153  func (h *httpServer) handleStateDELETE(writer http.ResponseWriter, req *http.Request) {
   154  	if h.httpServerCallback != nil {
   155  		defer h.httpServerCallback.StateDELETE(req)
   156  	}
   157  	resource := h.getResource(req)
   158  
   159  	h.lock.Lock()
   160  	defer h.lock.Unlock()
   161  
   162  	delete(h.data, resource)
   163  	writer.WriteHeader(http.StatusOK)
   164  }
   165  
   166  func (h *httpServer) handleStateLOCK(writer http.ResponseWriter, req *http.Request) {
   167  	if h.httpServerCallback != nil {
   168  		defer h.httpServerCallback.StateLOCK(req)
   169  	}
   170  	defer req.Body.Close()
   171  	resource := h.getResource(req)
   172  
   173  	data, err := io.ReadAll(req.Body)
   174  	if err != nil {
   175  		writer.WriteHeader(http.StatusBadRequest)
   176  		return
   177  	}
   178  
   179  	h.lock.Lock()
   180  	defer h.lock.Unlock()
   181  
   182  	if existingLock, ok := h.locks[resource]; ok {
   183  		writer.WriteHeader(http.StatusLocked)
   184  		_, _ = io.WriteString(writer, existingLock)
   185  	} else {
   186  		h.locks[resource] = string(data)
   187  		_, _ = io.WriteString(writer, existingLock)
   188  	}
   189  }
   190  
   191  func (h *httpServer) handleStateUNLOCK(writer http.ResponseWriter, req *http.Request) {
   192  	if h.httpServerCallback != nil {
   193  		defer h.httpServerCallback.StateUNLOCK(req)
   194  	}
   195  	defer req.Body.Close()
   196  	resource := h.getResource(req)
   197  
   198  	data, err := io.ReadAll(req.Body)
   199  	if err != nil {
   200  		writer.WriteHeader(http.StatusBadRequest)
   201  		return
   202  	}
   203  	var lockInfo map[string]interface{}
   204  	if err = json.Unmarshal(data, &lockInfo); err != nil {
   205  		writer.WriteHeader(http.StatusInternalServerError)
   206  		return
   207  	}
   208  
   209  	h.lock.Lock()
   210  	defer h.lock.Unlock()
   211  
   212  	if existingLock, ok := h.locks[resource]; ok {
   213  		var existingLockInfo map[string]interface{}
   214  		if err = json.Unmarshal([]byte(existingLock), &existingLockInfo); err != nil {
   215  			writer.WriteHeader(http.StatusInternalServerError)
   216  			return
   217  		}
   218  		lockID := lockInfo["ID"].(string)
   219  		existingID := existingLockInfo["ID"].(string)
   220  		if lockID != existingID {
   221  			writer.WriteHeader(http.StatusConflict)
   222  			_, _ = io.WriteString(writer, existingLock)
   223  		} else {
   224  			delete(h.locks, resource)
   225  			_, _ = io.WriteString(writer, existingLock)
   226  		}
   227  	} else {
   228  		writer.WriteHeader(http.StatusConflict)
   229  	}
   230  }
   231  
   232  func (h *httpServer) handler() http.Handler {
   233  	return h.r
   234  }
   235  
   236  func NewHttpTestServer(opts ...httpServerOpt) (*httptest.Server, error) {
   237  	clientCAData, err := os.ReadFile("testdata/certs/ca.cert.pem")
   238  	if err != nil {
   239  		return nil, err
   240  	}
   241  	clientCAs := x509.NewCertPool()
   242  	clientCAs.AppendCertsFromPEM(clientCAData)
   243  
   244  	cert, err := tls.LoadX509KeyPair("testdata/certs/server.crt", "testdata/certs/server.key")
   245  	if err != nil {
   246  		return nil, err
   247  	}
   248  
   249  	h := newHttpServer(opts...)
   250  	s := httptest.NewUnstartedServer(h.handler())
   251  	s.TLS = &tls.Config{
   252  		ClientAuth:   tls.RequireAndVerifyClientCert,
   253  		ClientCAs:    clientCAs,
   254  		Certificates: []tls.Certificate{cert},
   255  	}
   256  
   257  	s.StartTLS()
   258  	return s, nil
   259  }
   260  
   261  func TestMTLSServer_NoCertFails(t *testing.T) {
   262  	// Ensure that no calls are made to the server - everything is blocked by the tls.RequireAndVerifyClientCert
   263  	ctrl := gomock.NewController(t)
   264  	defer ctrl.Finish()
   265  	mockCallback := NewMockHttpServerCallback(ctrl)
   266  
   267  	// Fire up a test server
   268  	ts, err := NewHttpTestServer(withHttpServerCallback(mockCallback))
   269  	if err != nil {
   270  		t.Fatalf("unexpected error creating test server: %v", err)
   271  	}
   272  	defer ts.Close()
   273  
   274  	// Configure the backend to the pre-populated sample state
   275  	url := ts.URL + "/state/sample"
   276  	conf := map[string]cty.Value{
   277  		"address":                cty.StringVal(url),
   278  		"skip_cert_verification": cty.BoolVal(true),
   279  	}
   280  	b := backend.TestBackendConfig(t, New(encryption.StateEncryptionDisabled()), configs.SynthBody("synth", conf)).(*Backend)
   281  	if nil == b {
   282  		t.Fatal("nil backend")
   283  	}
   284  
   285  	// Now get a state manager and check that it fails to refresh the state
   286  	sm, err := b.StateMgr(backend.DefaultStateName)
   287  	if err != nil {
   288  		t.Fatalf("unexpected error fetching StateMgr with %s: %v", backend.DefaultStateName, err)
   289  	}
   290  
   291  	opErr := new(net.OpError)
   292  	err = sm.RefreshState()
   293  	if err == nil {
   294  		t.Fatal("expected error when refreshing state without a client cert")
   295  	}
   296  	if errors.As(err, &opErr) {
   297  		errType := fmt.Sprintf("%T", opErr.Err)
   298  		expected := "tls.alert"
   299  		if errType != expected {
   300  			t.Fatalf("expected net.OpError.Err type: %q got: %q error:%s", expected, errType, err)
   301  		}
   302  	} else {
   303  		t.Fatalf("unexpected error: %v", err)
   304  	}
   305  }
   306  
   307  func TestMTLSServer_WithCertPasses(t *testing.T) {
   308  	// Ensure that the expected amount of calls is made to the server
   309  	ctrl := gomock.NewController(t)
   310  	defer ctrl.Finish()
   311  	mockCallback := NewMockHttpServerCallback(ctrl)
   312  
   313  	// Two or three (not testing the caching here) calls to GET
   314  	mockCallback.EXPECT().
   315  		StateGET(gomock.Any()).
   316  		MinTimes(2).
   317  		MaxTimes(3)
   318  	// One call to the POST to write the data
   319  	mockCallback.EXPECT().
   320  		StatePOST(gomock.Any())
   321  
   322  	// Fire up a test server
   323  	ts, err := NewHttpTestServer(withHttpServerCallback(mockCallback))
   324  	if err != nil {
   325  		t.Fatalf("unexpected error creating test server: %v", err)
   326  	}
   327  	defer ts.Close()
   328  
   329  	// Configure the backend to the pre-populated sample state, and with all the test certs lined up
   330  	url := ts.URL + "/state/sample"
   331  	caData, err := os.ReadFile("testdata/certs/ca.cert.pem")
   332  	if err != nil {
   333  		t.Fatalf("error reading ca certs: %v", err)
   334  	}
   335  	clientCertData, err := os.ReadFile("testdata/certs/client.crt")
   336  	if err != nil {
   337  		t.Fatalf("error reading client cert: %v", err)
   338  	}
   339  	clientKeyData, err := os.ReadFile("testdata/certs/client.key")
   340  	if err != nil {
   341  		t.Fatalf("error reading client key: %v", err)
   342  	}
   343  	conf := map[string]cty.Value{
   344  		"address":                   cty.StringVal(url),
   345  		"lock_address":              cty.StringVal(url),
   346  		"unlock_address":            cty.StringVal(url),
   347  		"client_ca_certificate_pem": cty.StringVal(string(caData)),
   348  		"client_certificate_pem":    cty.StringVal(string(clientCertData)),
   349  		"client_private_key_pem":    cty.StringVal(string(clientKeyData)),
   350  	}
   351  	b := backend.TestBackendConfig(t, New(encryption.StateEncryptionDisabled()), configs.SynthBody("synth", conf)).(*Backend)
   352  	if nil == b {
   353  		t.Fatal("nil backend")
   354  	}
   355  
   356  	// Now get a state manager, fetch the state, and ensure that the "foo" output is not set
   357  	sm, err := b.StateMgr(backend.DefaultStateName)
   358  	if err != nil {
   359  		t.Fatalf("unexpected error fetching StateMgr with %s: %v", backend.DefaultStateName, err)
   360  	}
   361  	if err = sm.RefreshState(); err != nil {
   362  		t.Fatalf("unexpected error calling RefreshState: %v", err)
   363  	}
   364  	state := sm.State()
   365  	if nil == state {
   366  		t.Fatal("nil state")
   367  	}
   368  	stateFoo := state.OutputValue(addrs.OutputValue{Name: "foo"}.Absolute(addrs.RootModuleInstance))
   369  	if stateFoo != nil {
   370  		t.Errorf("expected nil foo from state; got %v", stateFoo)
   371  	}
   372  
   373  	// Create a new state that has "foo" set to "bar" and ensure that state is as expected
   374  	state = states.BuildState(func(ss *states.SyncState) {
   375  		ss.SetOutputValue(
   376  			addrs.OutputValue{Name: "foo"}.Absolute(addrs.RootModuleInstance),
   377  			cty.StringVal("bar"),
   378  			false)
   379  	})
   380  	stateFoo = state.OutputValue(addrs.OutputValue{Name: "foo"}.Absolute(addrs.RootModuleInstance))
   381  	if nil == stateFoo {
   382  		t.Fatal("nil foo after building state with foo populated")
   383  	}
   384  	if foo := stateFoo.Value.AsString(); foo != "bar" {
   385  		t.Errorf("Expected built state foo value to be bar; got %s", foo)
   386  	}
   387  
   388  	// Ensure the change hasn't altered the current state manager state by checking "foo" and comparing states
   389  	curState := sm.State()
   390  	curStateFoo := curState.OutputValue(addrs.OutputValue{Name: "foo"}.Absolute(addrs.RootModuleInstance))
   391  	if curStateFoo != nil {
   392  		t.Errorf("expected session manager state to be unaltered and still nil, but got: %v", curStateFoo)
   393  	}
   394  	if reflect.DeepEqual(state, curState) {
   395  		t.Errorf("expected %v != %v; but they were equal", state, curState)
   396  	}
   397  
   398  	// Write the new state, persist, and refresh
   399  	if err = sm.WriteState(state); err != nil {
   400  		t.Errorf("error writing state: %v", err)
   401  	}
   402  	if err = sm.PersistState(nil); err != nil {
   403  		t.Errorf("error persisting state: %v", err)
   404  	}
   405  	if err = sm.RefreshState(); err != nil {
   406  		t.Errorf("error refreshing state: %v", err)
   407  	}
   408  
   409  	// Get the state again and verify that is now the same as state and has the "foo" value set to "bar"
   410  	curState = sm.State()
   411  	if !reflect.DeepEqual(state, curState) {
   412  		t.Errorf("expected %v == %v; but they were unequal", state, curState)
   413  	}
   414  	curStateFoo = curState.OutputValue(addrs.OutputValue{Name: "foo"}.Absolute(addrs.RootModuleInstance))
   415  	if nil == curStateFoo {
   416  		t.Fatal("nil foo")
   417  	}
   418  	if foo := curStateFoo.Value.AsString(); foo != "bar" {
   419  		t.Errorf("expected foo to be bar, but got: %s", foo)
   420  	}
   421  }
   422  
   423  // TestRunServer allows running the server for local debugging; it runs until ctl-c is received
   424  func TestRunServer(t *testing.T) {
   425  	if _, ok := os.LookupEnv("TEST_RUN_SERVER"); !ok {
   426  		t.Skip("TEST_RUN_SERVER not set")
   427  	}
   428  	s, err := NewHttpTestServer()
   429  	if err != nil {
   430  		t.Fatalf("unexpected error creating test server: %v", err)
   431  	}
   432  	defer s.Close()
   433  
   434  	t.Log(s.URL)
   435  
   436  	ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
   437  	defer cancel()
   438  	// wait until signal
   439  	<-ctx.Done()
   440  }