github.com/terramate-io/tf@v0.0.0-20230830114523-fce866b4dfcd/backend/remote-state/http/server_test.go (about) 1 // Copyright (c) HashiCorp, Inc. 2 // SPDX-License-Identifier: MPL-2.0 3 4 package http 5 6 //go:generate go run github.com/golang/mock/mockgen -package $GOPACKAGE -source $GOFILE -destination mock_$GOFILE 7 8 import ( 9 "context" 10 "crypto/tls" 11 "crypto/x509" 12 "encoding/json" 13 "io" 14 "net/http" 15 "net/http/httptest" 16 "os" 17 "os/signal" 18 "path/filepath" 19 "reflect" 20 "strings" 21 "sync" 22 "syscall" 23 "testing" 24 25 "github.com/golang/mock/gomock" 26 "github.com/terramate-io/tf/addrs" 27 "github.com/terramate-io/tf/backend" 28 "github.com/terramate-io/tf/configs" 29 "github.com/terramate-io/tf/states" 30 "github.com/zclconf/go-cty/cty" 31 ) 32 33 const sampleState = ` 34 { 35 "version": 4, 36 "serial": 0, 37 "lineage": "666f9301-7e65-4b19-ae23-71184bb19b03", 38 "remote": { 39 "type": "http", 40 "config": { 41 "path": "local-state.tfstate" 42 } 43 } 44 } 45 ` 46 47 type ( 48 HttpServerCallback interface { 49 StateGET(req *http.Request) 50 StatePOST(req *http.Request) 51 StateDELETE(req *http.Request) 52 StateLOCK(req *http.Request) 53 StateUNLOCK(req *http.Request) 54 } 55 httpServer struct { 56 r *http.ServeMux 57 data map[string]string 58 locks map[string]string 59 lock sync.RWMutex 60 61 httpServerCallback HttpServerCallback 62 } 63 httpServerOpt func(*httpServer) 64 ) 65 66 func withHttpServerCallback(callback HttpServerCallback) httpServerOpt { 67 return func(s *httpServer) { 68 s.httpServerCallback = callback 69 } 70 } 71 72 func newHttpServer(opts ...httpServerOpt) *httpServer { 73 r := http.NewServeMux() 74 s := &httpServer{ 75 r: r, 76 data: make(map[string]string), 77 locks: make(map[string]string), 78 } 79 for _, opt := range opts { 80 opt(s) 81 } 82 s.data["sample"] = sampleState 83 r.HandleFunc("/state/", s.handleState) 84 return s 85 } 86 87 func (h *httpServer) getResource(req *http.Request) string { 88 switch pathParts := strings.SplitN(req.URL.Path, string(filepath.Separator), 3); len(pathParts) { 89 case 3: 90 return pathParts[2] 91 default: 92 return "" 93 } 94 } 95 96 func (h *httpServer) handleState(writer http.ResponseWriter, req *http.Request) { 97 switch req.Method { 98 case "GET": 99 h.handleStateGET(writer, req) 100 case "POST": 101 h.handleStatePOST(writer, req) 102 case "DELETE": 103 h.handleStateDELETE(writer, req) 104 case "LOCK": 105 h.handleStateLOCK(writer, req) 106 case "UNLOCK": 107 h.handleStateUNLOCK(writer, req) 108 } 109 } 110 111 func (h *httpServer) handleStateGET(writer http.ResponseWriter, req *http.Request) { 112 if h.httpServerCallback != nil { 113 defer h.httpServerCallback.StateGET(req) 114 } 115 resource := h.getResource(req) 116 117 h.lock.RLock() 118 defer h.lock.RUnlock() 119 120 if state, ok := h.data[resource]; ok { 121 _, _ = io.WriteString(writer, state) 122 } else { 123 writer.WriteHeader(http.StatusNotFound) 124 } 125 } 126 127 func (h *httpServer) handleStatePOST(writer http.ResponseWriter, req *http.Request) { 128 if h.httpServerCallback != nil { 129 defer h.httpServerCallback.StatePOST(req) 130 } 131 defer req.Body.Close() 132 resource := h.getResource(req) 133 134 data, err := io.ReadAll(req.Body) 135 if err != nil { 136 writer.WriteHeader(http.StatusBadRequest) 137 return 138 } 139 140 h.lock.Lock() 141 defer h.lock.Unlock() 142 143 h.data[resource] = string(data) 144 writer.WriteHeader(http.StatusOK) 145 } 146 147 func (h *httpServer) handleStateDELETE(writer http.ResponseWriter, req *http.Request) { 148 if h.httpServerCallback != nil { 149 defer h.httpServerCallback.StateDELETE(req) 150 } 151 resource := h.getResource(req) 152 153 h.lock.Lock() 154 defer h.lock.Unlock() 155 156 delete(h.data, resource) 157 writer.WriteHeader(http.StatusOK) 158 } 159 160 func (h *httpServer) handleStateLOCK(writer http.ResponseWriter, req *http.Request) { 161 if h.httpServerCallback != nil { 162 defer h.httpServerCallback.StateLOCK(req) 163 } 164 defer req.Body.Close() 165 resource := h.getResource(req) 166 167 data, err := io.ReadAll(req.Body) 168 if err != nil { 169 writer.WriteHeader(http.StatusBadRequest) 170 return 171 } 172 173 h.lock.Lock() 174 defer h.lock.Unlock() 175 176 if existingLock, ok := h.locks[resource]; ok { 177 writer.WriteHeader(http.StatusLocked) 178 _, _ = io.WriteString(writer, existingLock) 179 } else { 180 h.locks[resource] = string(data) 181 _, _ = io.WriteString(writer, existingLock) 182 } 183 } 184 185 func (h *httpServer) handleStateUNLOCK(writer http.ResponseWriter, req *http.Request) { 186 if h.httpServerCallback != nil { 187 defer h.httpServerCallback.StateUNLOCK(req) 188 } 189 defer req.Body.Close() 190 resource := h.getResource(req) 191 192 data, err := io.ReadAll(req.Body) 193 if err != nil { 194 writer.WriteHeader(http.StatusBadRequest) 195 return 196 } 197 var lockInfo map[string]interface{} 198 if err = json.Unmarshal(data, &lockInfo); err != nil { 199 writer.WriteHeader(http.StatusInternalServerError) 200 return 201 } 202 203 h.lock.Lock() 204 defer h.lock.Unlock() 205 206 if existingLock, ok := h.locks[resource]; ok { 207 var existingLockInfo map[string]interface{} 208 if err = json.Unmarshal([]byte(existingLock), &existingLockInfo); err != nil { 209 writer.WriteHeader(http.StatusInternalServerError) 210 return 211 } 212 lockID := lockInfo["ID"].(string) 213 existingID := existingLockInfo["ID"].(string) 214 if lockID != existingID { 215 writer.WriteHeader(http.StatusConflict) 216 _, _ = io.WriteString(writer, existingLock) 217 } else { 218 delete(h.locks, resource) 219 _, _ = io.WriteString(writer, existingLock) 220 } 221 } else { 222 writer.WriteHeader(http.StatusConflict) 223 } 224 } 225 226 func (h *httpServer) handler() http.Handler { 227 return h.r 228 } 229 230 func NewHttpTestServer(opts ...httpServerOpt) (*httptest.Server, error) { 231 clientCAData, err := os.ReadFile("testdata/certs/ca.cert.pem") 232 if err != nil { 233 return nil, err 234 } 235 clientCAs := x509.NewCertPool() 236 clientCAs.AppendCertsFromPEM(clientCAData) 237 238 cert, err := tls.LoadX509KeyPair("testdata/certs/server.crt", "testdata/certs/server.key") 239 if err != nil { 240 return nil, err 241 } 242 243 h := newHttpServer(opts...) 244 s := httptest.NewUnstartedServer(h.handler()) 245 s.TLS = &tls.Config{ 246 ClientAuth: tls.RequireAndVerifyClientCert, 247 ClientCAs: clientCAs, 248 Certificates: []tls.Certificate{cert}, 249 } 250 251 s.StartTLS() 252 return s, nil 253 } 254 255 func TestMTLSServer_NoCertFails(t *testing.T) { 256 // Ensure that no calls are made to the server - everything is blocked by the tls.RequireAndVerifyClientCert 257 ctrl := gomock.NewController(t) 258 defer ctrl.Finish() 259 mockCallback := NewMockHttpServerCallback(ctrl) 260 261 // Fire up a test server 262 ts, err := NewHttpTestServer(withHttpServerCallback(mockCallback)) 263 if err != nil { 264 t.Fatalf("unexpected error creating test server: %v", err) 265 } 266 defer ts.Close() 267 268 // Configure the backend to the pre-populated sample state 269 url := ts.URL + "/state/sample" 270 conf := map[string]cty.Value{ 271 "address": cty.StringVal(url), 272 "skip_cert_verification": cty.BoolVal(true), 273 } 274 b := backend.TestBackendConfig(t, New(), configs.SynthBody("synth", conf)).(*Backend) 275 if nil == b { 276 t.Fatal("nil backend") 277 } 278 279 // Now get a state manager and check that it fails to refresh the state 280 sm, err := b.StateMgr(backend.DefaultStateName) 281 if err != nil { 282 t.Fatalf("unexpected error fetching StateMgr with %s: %v", backend.DefaultStateName, err) 283 } 284 err = sm.RefreshState() 285 if nil == err { 286 t.Error("expected error when refreshing state without a client cert") 287 } else if !strings.Contains(err.Error(), "remote error: tls: bad certificate") { 288 t.Errorf("expected the error to report missing tls credentials: %v", err) 289 } 290 } 291 292 func TestMTLSServer_WithCertPasses(t *testing.T) { 293 // Ensure that the expected amount of calls is made to the server 294 ctrl := gomock.NewController(t) 295 defer ctrl.Finish() 296 mockCallback := NewMockHttpServerCallback(ctrl) 297 298 // Two or three (not testing the caching here) calls to GET 299 mockCallback.EXPECT(). 300 StateGET(gomock.Any()). 301 MinTimes(2). 302 MaxTimes(3) 303 // One call to the POST to write the data 304 mockCallback.EXPECT(). 305 StatePOST(gomock.Any()) 306 307 // Fire up a test server 308 ts, err := NewHttpTestServer(withHttpServerCallback(mockCallback)) 309 if err != nil { 310 t.Fatalf("unexpected error creating test server: %v", err) 311 } 312 defer ts.Close() 313 314 // Configure the backend to the pre-populated sample state, and with all the test certs lined up 315 url := ts.URL + "/state/sample" 316 caData, err := os.ReadFile("testdata/certs/ca.cert.pem") 317 if err != nil { 318 t.Fatalf("error reading ca certs: %v", err) 319 } 320 clientCertData, err := os.ReadFile("testdata/certs/client.crt") 321 if err != nil { 322 t.Fatalf("error reading client cert: %v", err) 323 } 324 clientKeyData, err := os.ReadFile("testdata/certs/client.key") 325 if err != nil { 326 t.Fatalf("error reading client key: %v", err) 327 } 328 conf := map[string]cty.Value{ 329 "address": cty.StringVal(url), 330 "lock_address": cty.StringVal(url), 331 "unlock_address": cty.StringVal(url), 332 "client_ca_certificate_pem": cty.StringVal(string(caData)), 333 "client_certificate_pem": cty.StringVal(string(clientCertData)), 334 "client_private_key_pem": cty.StringVal(string(clientKeyData)), 335 } 336 b := backend.TestBackendConfig(t, New(), configs.SynthBody("synth", conf)).(*Backend) 337 if nil == b { 338 t.Fatal("nil backend") 339 } 340 341 // Now get a state manager, fetch the state, and ensure that the "foo" output is not set 342 sm, err := b.StateMgr(backend.DefaultStateName) 343 if err != nil { 344 t.Fatalf("unexpected error fetching StateMgr with %s: %v", backend.DefaultStateName, err) 345 } 346 if err = sm.RefreshState(); err != nil { 347 t.Fatalf("unexpected error calling RefreshState: %v", err) 348 } 349 state := sm.State() 350 if nil == state { 351 t.Fatal("nil state") 352 } 353 stateFoo := state.OutputValue(addrs.OutputValue{Name: "foo"}.Absolute(addrs.RootModuleInstance)) 354 if stateFoo != nil { 355 t.Errorf("expected nil foo from state; got %v", stateFoo) 356 } 357 358 // Create a new state that has "foo" set to "bar" and ensure that state is as expected 359 state = states.BuildState(func(ss *states.SyncState) { 360 ss.SetOutputValue( 361 addrs.OutputValue{Name: "foo"}.Absolute(addrs.RootModuleInstance), 362 cty.StringVal("bar"), 363 false) 364 }) 365 stateFoo = state.OutputValue(addrs.OutputValue{Name: "foo"}.Absolute(addrs.RootModuleInstance)) 366 if nil == stateFoo { 367 t.Fatal("nil foo after building state with foo populated") 368 } 369 if foo := stateFoo.Value.AsString(); foo != "bar" { 370 t.Errorf("Expected built state foo value to be bar; got %s", foo) 371 } 372 373 // Ensure the change hasn't altered the current state manager state by checking "foo" and comparing states 374 curState := sm.State() 375 curStateFoo := curState.OutputValue(addrs.OutputValue{Name: "foo"}.Absolute(addrs.RootModuleInstance)) 376 if curStateFoo != nil { 377 t.Errorf("expected session manager state to be unaltered and still nil, but got: %v", curStateFoo) 378 } 379 if reflect.DeepEqual(state, curState) { 380 t.Errorf("expected %v != %v; but they were equal", state, curState) 381 } 382 383 // Write the new state, persist, and refresh 384 if err = sm.WriteState(state); err != nil { 385 t.Errorf("error writing state: %v", err) 386 } 387 if err = sm.PersistState(nil); err != nil { 388 t.Errorf("error persisting state: %v", err) 389 } 390 if err = sm.RefreshState(); err != nil { 391 t.Errorf("error refreshing state: %v", err) 392 } 393 394 // Get the state again and verify that is now the same as state and has the "foo" value set to "bar" 395 curState = sm.State() 396 if !reflect.DeepEqual(state, curState) { 397 t.Errorf("expected %v == %v; but they were unequal", state, curState) 398 } 399 curStateFoo = curState.OutputValue(addrs.OutputValue{Name: "foo"}.Absolute(addrs.RootModuleInstance)) 400 if nil == curStateFoo { 401 t.Fatal("nil foo") 402 } 403 if foo := curStateFoo.Value.AsString(); foo != "bar" { 404 t.Errorf("expected foo to be bar, but got: %s", foo) 405 } 406 } 407 408 // TestRunServer allows running the server for local debugging; it runs until ctl-c is received 409 func TestRunServer(t *testing.T) { 410 if _, ok := os.LookupEnv("TEST_RUN_SERVER"); !ok { 411 t.Skip("TEST_RUN_SERVER not set") 412 } 413 s, err := NewHttpTestServer() 414 if err != nil { 415 t.Fatalf("unexpected error creating test server: %v", err) 416 } 417 defer s.Close() 418 419 t.Log(s.URL) 420 421 ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) 422 defer cancel() 423 // wait until signal 424 <-ctx.Done() 425 }