github.com/opentofu/opentofu@v1.7.1/internal/backend/remote-state/http/server_test.go (about) 1 // Copyright (c) The OpenTofu Authors 2 // SPDX-License-Identifier: MPL-2.0 3 // Copyright (c) 2023 HashiCorp, Inc. 4 // SPDX-License-Identifier: MPL-2.0 5 6 package http 7 8 //go:generate go run github.com/golang/mock/mockgen -package $GOPACKAGE -source $GOFILE -destination mock_$GOFILE 9 10 import ( 11 "context" 12 "crypto/tls" 13 "crypto/x509" 14 "encoding/json" 15 "errors" 16 "fmt" 17 "io" 18 "net" 19 "net/http" 20 "net/http/httptest" 21 "os" 22 "os/signal" 23 "path/filepath" 24 "reflect" 25 "strings" 26 "sync" 27 "syscall" 28 "testing" 29 30 "github.com/golang/mock/gomock" 31 "github.com/opentofu/opentofu/internal/addrs" 32 "github.com/opentofu/opentofu/internal/backend" 33 "github.com/opentofu/opentofu/internal/configs" 34 "github.com/opentofu/opentofu/internal/encryption" 35 "github.com/opentofu/opentofu/internal/states" 36 "github.com/zclconf/go-cty/cty" 37 ) 38 39 const sampleState = ` 40 { 41 "version": 4, 42 "serial": 0, 43 "lineage": "666f9301-7e65-4b19-ae23-71184bb19b03", 44 "remote": { 45 "type": "http", 46 "config": { 47 "path": "local-state.tfstate" 48 } 49 } 50 } 51 ` 52 53 type ( 54 HttpServerCallback interface { 55 StateGET(req *http.Request) 56 StatePOST(req *http.Request) 57 StateDELETE(req *http.Request) 58 StateLOCK(req *http.Request) 59 StateUNLOCK(req *http.Request) 60 } 61 httpServer struct { 62 r *http.ServeMux 63 data map[string]string 64 locks map[string]string 65 lock sync.RWMutex 66 67 httpServerCallback HttpServerCallback 68 } 69 httpServerOpt func(*httpServer) 70 ) 71 72 func withHttpServerCallback(callback HttpServerCallback) httpServerOpt { 73 return func(s *httpServer) { 74 s.httpServerCallback = callback 75 } 76 } 77 78 func newHttpServer(opts ...httpServerOpt) *httpServer { 79 r := http.NewServeMux() 80 s := &httpServer{ 81 r: r, 82 data: make(map[string]string), 83 locks: make(map[string]string), 84 } 85 for _, opt := range opts { 86 opt(s) 87 } 88 s.data["sample"] = sampleState 89 r.HandleFunc("/state/", s.handleState) 90 return s 91 } 92 93 func (h *httpServer) getResource(req *http.Request) string { 94 switch pathParts := strings.SplitN(req.URL.Path, string(filepath.Separator), 3); len(pathParts) { 95 case 3: 96 return pathParts[2] 97 default: 98 return "" 99 } 100 } 101 102 func (h *httpServer) handleState(writer http.ResponseWriter, req *http.Request) { 103 switch req.Method { 104 case "GET": 105 h.handleStateGET(writer, req) 106 case "POST": 107 h.handleStatePOST(writer, req) 108 case "DELETE": 109 h.handleStateDELETE(writer, req) 110 case "LOCK": 111 h.handleStateLOCK(writer, req) 112 case "UNLOCK": 113 h.handleStateUNLOCK(writer, req) 114 } 115 } 116 117 func (h *httpServer) handleStateGET(writer http.ResponseWriter, req *http.Request) { 118 if h.httpServerCallback != nil { 119 defer h.httpServerCallback.StateGET(req) 120 } 121 resource := h.getResource(req) 122 123 h.lock.RLock() 124 defer h.lock.RUnlock() 125 126 if state, ok := h.data[resource]; ok { 127 _, _ = io.WriteString(writer, state) 128 } else { 129 writer.WriteHeader(http.StatusNotFound) 130 } 131 } 132 133 func (h *httpServer) handleStatePOST(writer http.ResponseWriter, req *http.Request) { 134 if h.httpServerCallback != nil { 135 defer h.httpServerCallback.StatePOST(req) 136 } 137 defer req.Body.Close() 138 resource := h.getResource(req) 139 140 data, err := io.ReadAll(req.Body) 141 if err != nil { 142 writer.WriteHeader(http.StatusBadRequest) 143 return 144 } 145 146 h.lock.Lock() 147 defer h.lock.Unlock() 148 149 h.data[resource] = string(data) 150 writer.WriteHeader(http.StatusOK) 151 } 152 153 func (h *httpServer) handleStateDELETE(writer http.ResponseWriter, req *http.Request) { 154 if h.httpServerCallback != nil { 155 defer h.httpServerCallback.StateDELETE(req) 156 } 157 resource := h.getResource(req) 158 159 h.lock.Lock() 160 defer h.lock.Unlock() 161 162 delete(h.data, resource) 163 writer.WriteHeader(http.StatusOK) 164 } 165 166 func (h *httpServer) handleStateLOCK(writer http.ResponseWriter, req *http.Request) { 167 if h.httpServerCallback != nil { 168 defer h.httpServerCallback.StateLOCK(req) 169 } 170 defer req.Body.Close() 171 resource := h.getResource(req) 172 173 data, err := io.ReadAll(req.Body) 174 if err != nil { 175 writer.WriteHeader(http.StatusBadRequest) 176 return 177 } 178 179 h.lock.Lock() 180 defer h.lock.Unlock() 181 182 if existingLock, ok := h.locks[resource]; ok { 183 writer.WriteHeader(http.StatusLocked) 184 _, _ = io.WriteString(writer, existingLock) 185 } else { 186 h.locks[resource] = string(data) 187 _, _ = io.WriteString(writer, existingLock) 188 } 189 } 190 191 func (h *httpServer) handleStateUNLOCK(writer http.ResponseWriter, req *http.Request) { 192 if h.httpServerCallback != nil { 193 defer h.httpServerCallback.StateUNLOCK(req) 194 } 195 defer req.Body.Close() 196 resource := h.getResource(req) 197 198 data, err := io.ReadAll(req.Body) 199 if err != nil { 200 writer.WriteHeader(http.StatusBadRequest) 201 return 202 } 203 var lockInfo map[string]interface{} 204 if err = json.Unmarshal(data, &lockInfo); err != nil { 205 writer.WriteHeader(http.StatusInternalServerError) 206 return 207 } 208 209 h.lock.Lock() 210 defer h.lock.Unlock() 211 212 if existingLock, ok := h.locks[resource]; ok { 213 var existingLockInfo map[string]interface{} 214 if err = json.Unmarshal([]byte(existingLock), &existingLockInfo); err != nil { 215 writer.WriteHeader(http.StatusInternalServerError) 216 return 217 } 218 lockID := lockInfo["ID"].(string) 219 existingID := existingLockInfo["ID"].(string) 220 if lockID != existingID { 221 writer.WriteHeader(http.StatusConflict) 222 _, _ = io.WriteString(writer, existingLock) 223 } else { 224 delete(h.locks, resource) 225 _, _ = io.WriteString(writer, existingLock) 226 } 227 } else { 228 writer.WriteHeader(http.StatusConflict) 229 } 230 } 231 232 func (h *httpServer) handler() http.Handler { 233 return h.r 234 } 235 236 func NewHttpTestServer(opts ...httpServerOpt) (*httptest.Server, error) { 237 clientCAData, err := os.ReadFile("testdata/certs/ca.cert.pem") 238 if err != nil { 239 return nil, err 240 } 241 clientCAs := x509.NewCertPool() 242 clientCAs.AppendCertsFromPEM(clientCAData) 243 244 cert, err := tls.LoadX509KeyPair("testdata/certs/server.crt", "testdata/certs/server.key") 245 if err != nil { 246 return nil, err 247 } 248 249 h := newHttpServer(opts...) 250 s := httptest.NewUnstartedServer(h.handler()) 251 s.TLS = &tls.Config{ 252 ClientAuth: tls.RequireAndVerifyClientCert, 253 ClientCAs: clientCAs, 254 Certificates: []tls.Certificate{cert}, 255 } 256 257 s.StartTLS() 258 return s, nil 259 } 260 261 func TestMTLSServer_NoCertFails(t *testing.T) { 262 // Ensure that no calls are made to the server - everything is blocked by the tls.RequireAndVerifyClientCert 263 ctrl := gomock.NewController(t) 264 defer ctrl.Finish() 265 mockCallback := NewMockHttpServerCallback(ctrl) 266 267 // Fire up a test server 268 ts, err := NewHttpTestServer(withHttpServerCallback(mockCallback)) 269 if err != nil { 270 t.Fatalf("unexpected error creating test server: %v", err) 271 } 272 defer ts.Close() 273 274 // Configure the backend to the pre-populated sample state 275 url := ts.URL + "/state/sample" 276 conf := map[string]cty.Value{ 277 "address": cty.StringVal(url), 278 "skip_cert_verification": cty.BoolVal(true), 279 } 280 b := backend.TestBackendConfig(t, New(encryption.StateEncryptionDisabled()), configs.SynthBody("synth", conf)).(*Backend) 281 if nil == b { 282 t.Fatal("nil backend") 283 } 284 285 // Now get a state manager and check that it fails to refresh the state 286 sm, err := b.StateMgr(backend.DefaultStateName) 287 if err != nil { 288 t.Fatalf("unexpected error fetching StateMgr with %s: %v", backend.DefaultStateName, err) 289 } 290 291 opErr := new(net.OpError) 292 err = sm.RefreshState() 293 if err == nil { 294 t.Fatal("expected error when refreshing state without a client cert") 295 } 296 if errors.As(err, &opErr) { 297 errType := fmt.Sprintf("%T", opErr.Err) 298 expected := "tls.alert" 299 if errType != expected { 300 t.Fatalf("expected net.OpError.Err type: %q got: %q error:%s", expected, errType, err) 301 } 302 } else { 303 t.Fatalf("unexpected error: %v", err) 304 } 305 } 306 307 func TestMTLSServer_WithCertPasses(t *testing.T) { 308 // Ensure that the expected amount of calls is made to the server 309 ctrl := gomock.NewController(t) 310 defer ctrl.Finish() 311 mockCallback := NewMockHttpServerCallback(ctrl) 312 313 // Two or three (not testing the caching here) calls to GET 314 mockCallback.EXPECT(). 315 StateGET(gomock.Any()). 316 MinTimes(2). 317 MaxTimes(3) 318 // One call to the POST to write the data 319 mockCallback.EXPECT(). 320 StatePOST(gomock.Any()) 321 322 // Fire up a test server 323 ts, err := NewHttpTestServer(withHttpServerCallback(mockCallback)) 324 if err != nil { 325 t.Fatalf("unexpected error creating test server: %v", err) 326 } 327 defer ts.Close() 328 329 // Configure the backend to the pre-populated sample state, and with all the test certs lined up 330 url := ts.URL + "/state/sample" 331 caData, err := os.ReadFile("testdata/certs/ca.cert.pem") 332 if err != nil { 333 t.Fatalf("error reading ca certs: %v", err) 334 } 335 clientCertData, err := os.ReadFile("testdata/certs/client.crt") 336 if err != nil { 337 t.Fatalf("error reading client cert: %v", err) 338 } 339 clientKeyData, err := os.ReadFile("testdata/certs/client.key") 340 if err != nil { 341 t.Fatalf("error reading client key: %v", err) 342 } 343 conf := map[string]cty.Value{ 344 "address": cty.StringVal(url), 345 "lock_address": cty.StringVal(url), 346 "unlock_address": cty.StringVal(url), 347 "client_ca_certificate_pem": cty.StringVal(string(caData)), 348 "client_certificate_pem": cty.StringVal(string(clientCertData)), 349 "client_private_key_pem": cty.StringVal(string(clientKeyData)), 350 } 351 b := backend.TestBackendConfig(t, New(encryption.StateEncryptionDisabled()), configs.SynthBody("synth", conf)).(*Backend) 352 if nil == b { 353 t.Fatal("nil backend") 354 } 355 356 // Now get a state manager, fetch the state, and ensure that the "foo" output is not set 357 sm, err := b.StateMgr(backend.DefaultStateName) 358 if err != nil { 359 t.Fatalf("unexpected error fetching StateMgr with %s: %v", backend.DefaultStateName, err) 360 } 361 if err = sm.RefreshState(); err != nil { 362 t.Fatalf("unexpected error calling RefreshState: %v", err) 363 } 364 state := sm.State() 365 if nil == state { 366 t.Fatal("nil state") 367 } 368 stateFoo := state.OutputValue(addrs.OutputValue{Name: "foo"}.Absolute(addrs.RootModuleInstance)) 369 if stateFoo != nil { 370 t.Errorf("expected nil foo from state; got %v", stateFoo) 371 } 372 373 // Create a new state that has "foo" set to "bar" and ensure that state is as expected 374 state = states.BuildState(func(ss *states.SyncState) { 375 ss.SetOutputValue( 376 addrs.OutputValue{Name: "foo"}.Absolute(addrs.RootModuleInstance), 377 cty.StringVal("bar"), 378 false) 379 }) 380 stateFoo = state.OutputValue(addrs.OutputValue{Name: "foo"}.Absolute(addrs.RootModuleInstance)) 381 if nil == stateFoo { 382 t.Fatal("nil foo after building state with foo populated") 383 } 384 if foo := stateFoo.Value.AsString(); foo != "bar" { 385 t.Errorf("Expected built state foo value to be bar; got %s", foo) 386 } 387 388 // Ensure the change hasn't altered the current state manager state by checking "foo" and comparing states 389 curState := sm.State() 390 curStateFoo := curState.OutputValue(addrs.OutputValue{Name: "foo"}.Absolute(addrs.RootModuleInstance)) 391 if curStateFoo != nil { 392 t.Errorf("expected session manager state to be unaltered and still nil, but got: %v", curStateFoo) 393 } 394 if reflect.DeepEqual(state, curState) { 395 t.Errorf("expected %v != %v; but they were equal", state, curState) 396 } 397 398 // Write the new state, persist, and refresh 399 if err = sm.WriteState(state); err != nil { 400 t.Errorf("error writing state: %v", err) 401 } 402 if err = sm.PersistState(nil); err != nil { 403 t.Errorf("error persisting state: %v", err) 404 } 405 if err = sm.RefreshState(); err != nil { 406 t.Errorf("error refreshing state: %v", err) 407 } 408 409 // Get the state again and verify that is now the same as state and has the "foo" value set to "bar" 410 curState = sm.State() 411 if !reflect.DeepEqual(state, curState) { 412 t.Errorf("expected %v == %v; but they were unequal", state, curState) 413 } 414 curStateFoo = curState.OutputValue(addrs.OutputValue{Name: "foo"}.Absolute(addrs.RootModuleInstance)) 415 if nil == curStateFoo { 416 t.Fatal("nil foo") 417 } 418 if foo := curStateFoo.Value.AsString(); foo != "bar" { 419 t.Errorf("expected foo to be bar, but got: %s", foo) 420 } 421 } 422 423 // TestRunServer allows running the server for local debugging; it runs until ctl-c is received 424 func TestRunServer(t *testing.T) { 425 if _, ok := os.LookupEnv("TEST_RUN_SERVER"); !ok { 426 t.Skip("TEST_RUN_SERVER not set") 427 } 428 s, err := NewHttpTestServer() 429 if err != nil { 430 t.Fatalf("unexpected error creating test server: %v", err) 431 } 432 defer s.Close() 433 434 t.Log(s.URL) 435 436 ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) 437 defer cancel() 438 // wait until signal 439 <-ctx.Done() 440 }