github.com/graywolf-at-work-2/terraform-vendor@v1.4.5/internal/backend/remote-state/http/server_test.go (about) 1 package http 2 3 //go:generate go run github.com/golang/mock/mockgen -package $GOPACKAGE -source $GOFILE -destination mock_$GOFILE 4 5 import ( 6 "context" 7 "crypto/tls" 8 "crypto/x509" 9 "encoding/json" 10 "io" 11 "net/http" 12 "net/http/httptest" 13 "os" 14 "os/signal" 15 "path/filepath" 16 "reflect" 17 "strings" 18 "sync" 19 "syscall" 20 "testing" 21 22 "github.com/golang/mock/gomock" 23 "github.com/hashicorp/terraform/internal/addrs" 24 "github.com/hashicorp/terraform/internal/backend" 25 "github.com/hashicorp/terraform/internal/configs" 26 "github.com/hashicorp/terraform/internal/states" 27 "github.com/zclconf/go-cty/cty" 28 ) 29 30 const sampleState = ` 31 { 32 "version": 4, 33 "serial": 0, 34 "lineage": "666f9301-7e65-4b19-ae23-71184bb19b03", 35 "remote": { 36 "type": "http", 37 "config": { 38 "path": "local-state.tfstate" 39 } 40 } 41 } 42 ` 43 44 type ( 45 HttpServerCallback interface { 46 StateGET(req *http.Request) 47 StatePOST(req *http.Request) 48 StateDELETE(req *http.Request) 49 StateLOCK(req *http.Request) 50 StateUNLOCK(req *http.Request) 51 } 52 httpServer struct { 53 r *http.ServeMux 54 data map[string]string 55 locks map[string]string 56 lock sync.RWMutex 57 58 httpServerCallback HttpServerCallback 59 } 60 httpServerOpt func(*httpServer) 61 ) 62 63 func withHttpServerCallback(callback HttpServerCallback) httpServerOpt { 64 return func(s *httpServer) { 65 s.httpServerCallback = callback 66 } 67 } 68 69 func newHttpServer(opts ...httpServerOpt) *httpServer { 70 r := http.NewServeMux() 71 s := &httpServer{ 72 r: r, 73 data: make(map[string]string), 74 locks: make(map[string]string), 75 } 76 for _, opt := range opts { 77 opt(s) 78 } 79 s.data["sample"] = sampleState 80 r.HandleFunc("/state/", s.handleState) 81 return s 82 } 83 84 func (h *httpServer) getResource(req *http.Request) string { 85 switch pathParts := strings.SplitN(req.URL.Path, string(filepath.Separator), 3); len(pathParts) { 86 case 3: 87 return pathParts[2] 88 default: 89 return "" 90 } 91 } 92 93 func (h *httpServer) handleState(writer http.ResponseWriter, req *http.Request) { 94 switch req.Method { 95 case "GET": 96 h.handleStateGET(writer, req) 97 case "POST": 98 h.handleStatePOST(writer, req) 99 case "DELETE": 100 h.handleStateDELETE(writer, req) 101 case "LOCK": 102 h.handleStateLOCK(writer, req) 103 case "UNLOCK": 104 h.handleStateUNLOCK(writer, req) 105 } 106 } 107 108 func (h *httpServer) handleStateGET(writer http.ResponseWriter, req *http.Request) { 109 if h.httpServerCallback != nil { 110 defer h.httpServerCallback.StateGET(req) 111 } 112 resource := h.getResource(req) 113 114 h.lock.RLock() 115 defer h.lock.RUnlock() 116 117 if state, ok := h.data[resource]; ok { 118 _, _ = io.WriteString(writer, state) 119 } else { 120 writer.WriteHeader(http.StatusNotFound) 121 } 122 } 123 124 func (h *httpServer) handleStatePOST(writer http.ResponseWriter, req *http.Request) { 125 if h.httpServerCallback != nil { 126 defer h.httpServerCallback.StatePOST(req) 127 } 128 defer req.Body.Close() 129 resource := h.getResource(req) 130 131 data, err := io.ReadAll(req.Body) 132 if err != nil { 133 writer.WriteHeader(http.StatusBadRequest) 134 return 135 } 136 137 h.lock.Lock() 138 defer h.lock.Unlock() 139 140 h.data[resource] = string(data) 141 writer.WriteHeader(http.StatusOK) 142 } 143 144 func (h *httpServer) handleStateDELETE(writer http.ResponseWriter, req *http.Request) { 145 if h.httpServerCallback != nil { 146 defer h.httpServerCallback.StateDELETE(req) 147 } 148 resource := h.getResource(req) 149 150 h.lock.Lock() 151 defer h.lock.Unlock() 152 153 delete(h.data, resource) 154 writer.WriteHeader(http.StatusOK) 155 } 156 157 func (h *httpServer) handleStateLOCK(writer http.ResponseWriter, req *http.Request) { 158 if h.httpServerCallback != nil { 159 defer h.httpServerCallback.StateLOCK(req) 160 } 161 defer req.Body.Close() 162 resource := h.getResource(req) 163 164 data, err := io.ReadAll(req.Body) 165 if err != nil { 166 writer.WriteHeader(http.StatusBadRequest) 167 return 168 } 169 170 h.lock.Lock() 171 defer h.lock.Unlock() 172 173 if existingLock, ok := h.locks[resource]; ok { 174 writer.WriteHeader(http.StatusLocked) 175 _, _ = io.WriteString(writer, existingLock) 176 } else { 177 h.locks[resource] = string(data) 178 _, _ = io.WriteString(writer, existingLock) 179 } 180 } 181 182 func (h *httpServer) handleStateUNLOCK(writer http.ResponseWriter, req *http.Request) { 183 if h.httpServerCallback != nil { 184 defer h.httpServerCallback.StateUNLOCK(req) 185 } 186 defer req.Body.Close() 187 resource := h.getResource(req) 188 189 data, err := io.ReadAll(req.Body) 190 if err != nil { 191 writer.WriteHeader(http.StatusBadRequest) 192 return 193 } 194 var lockInfo map[string]interface{} 195 if err = json.Unmarshal(data, &lockInfo); err != nil { 196 writer.WriteHeader(http.StatusInternalServerError) 197 return 198 } 199 200 h.lock.Lock() 201 defer h.lock.Unlock() 202 203 if existingLock, ok := h.locks[resource]; ok { 204 var existingLockInfo map[string]interface{} 205 if err = json.Unmarshal([]byte(existingLock), &existingLockInfo); err != nil { 206 writer.WriteHeader(http.StatusInternalServerError) 207 return 208 } 209 lockID := lockInfo["ID"].(string) 210 existingID := existingLockInfo["ID"].(string) 211 if lockID != existingID { 212 writer.WriteHeader(http.StatusConflict) 213 _, _ = io.WriteString(writer, existingLock) 214 } else { 215 delete(h.locks, resource) 216 _, _ = io.WriteString(writer, existingLock) 217 } 218 } else { 219 writer.WriteHeader(http.StatusConflict) 220 } 221 } 222 223 func (h *httpServer) handler() http.Handler { 224 return h.r 225 } 226 227 func NewHttpTestServer(opts ...httpServerOpt) (*httptest.Server, error) { 228 clientCAData, err := os.ReadFile("testdata/certs/ca.cert.pem") 229 if err != nil { 230 return nil, err 231 } 232 clientCAs := x509.NewCertPool() 233 clientCAs.AppendCertsFromPEM(clientCAData) 234 235 cert, err := tls.LoadX509KeyPair("testdata/certs/server.crt", "testdata/certs/server.key") 236 if err != nil { 237 return nil, err 238 } 239 240 h := newHttpServer(opts...) 241 s := httptest.NewUnstartedServer(h.handler()) 242 s.TLS = &tls.Config{ 243 ClientAuth: tls.RequireAndVerifyClientCert, 244 ClientCAs: clientCAs, 245 Certificates: []tls.Certificate{cert}, 246 } 247 248 s.StartTLS() 249 return s, nil 250 } 251 252 func TestMTLSServer_NoCertFails(t *testing.T) { 253 // Ensure that no calls are made to the server - everything is blocked by the tls.RequireAndVerifyClientCert 254 ctrl := gomock.NewController(t) 255 defer ctrl.Finish() 256 mockCallback := NewMockHttpServerCallback(ctrl) 257 258 // Fire up a test server 259 ts, err := NewHttpTestServer(withHttpServerCallback(mockCallback)) 260 if err != nil { 261 t.Fatalf("unexpected error creating test server: %v", err) 262 } 263 defer ts.Close() 264 265 // Configure the backend to the pre-populated sample state 266 url := ts.URL + "/state/sample" 267 conf := map[string]cty.Value{ 268 "address": cty.StringVal(url), 269 "skip_cert_verification": cty.BoolVal(true), 270 } 271 b := backend.TestBackendConfig(t, New(), configs.SynthBody("synth", conf)).(*Backend) 272 if nil == b { 273 t.Fatal("nil backend") 274 } 275 276 // Now get a state manager and check that it fails to refresh the state 277 sm, err := b.StateMgr(backend.DefaultStateName) 278 if err != nil { 279 t.Fatalf("unexpected error fetching StateMgr with %s: %v", backend.DefaultStateName, err) 280 } 281 err = sm.RefreshState() 282 if nil == err { 283 t.Error("expected error when refreshing state without a client cert") 284 } else if !strings.Contains(err.Error(), "remote error: tls: bad certificate") { 285 t.Errorf("expected the error to report missing tls credentials: %v", err) 286 } 287 } 288 289 func TestMTLSServer_WithCertPasses(t *testing.T) { 290 // Ensure that the expected amount of calls is made to the server 291 ctrl := gomock.NewController(t) 292 defer ctrl.Finish() 293 mockCallback := NewMockHttpServerCallback(ctrl) 294 295 // Two or three (not testing the caching here) calls to GET 296 mockCallback.EXPECT(). 297 StateGET(gomock.Any()). 298 MinTimes(2). 299 MaxTimes(3) 300 // One call to the POST to write the data 301 mockCallback.EXPECT(). 302 StatePOST(gomock.Any()) 303 304 // Fire up a test server 305 ts, err := NewHttpTestServer(withHttpServerCallback(mockCallback)) 306 if err != nil { 307 t.Fatalf("unexpected error creating test server: %v", err) 308 } 309 defer ts.Close() 310 311 // Configure the backend to the pre-populated sample state, and with all the test certs lined up 312 url := ts.URL + "/state/sample" 313 caData, err := os.ReadFile("testdata/certs/ca.cert.pem") 314 if err != nil { 315 t.Fatalf("error reading ca certs: %v", err) 316 } 317 clientCertData, err := os.ReadFile("testdata/certs/client.crt") 318 if err != nil { 319 t.Fatalf("error reading client cert: %v", err) 320 } 321 clientKeyData, err := os.ReadFile("testdata/certs/client.key") 322 if err != nil { 323 t.Fatalf("error reading client key: %v", err) 324 } 325 conf := map[string]cty.Value{ 326 "address": cty.StringVal(url), 327 "lock_address": cty.StringVal(url), 328 "unlock_address": cty.StringVal(url), 329 "client_ca_certificate_pem": cty.StringVal(string(caData)), 330 "client_certificate_pem": cty.StringVal(string(clientCertData)), 331 "client_private_key_pem": cty.StringVal(string(clientKeyData)), 332 } 333 b := backend.TestBackendConfig(t, New(), configs.SynthBody("synth", conf)).(*Backend) 334 if nil == b { 335 t.Fatal("nil backend") 336 } 337 338 // Now get a state manager, fetch the state, and ensure that the "foo" output is not set 339 sm, err := b.StateMgr(backend.DefaultStateName) 340 if err != nil { 341 t.Fatalf("unexpected error fetching StateMgr with %s: %v", backend.DefaultStateName, err) 342 } 343 if err = sm.RefreshState(); err != nil { 344 t.Fatalf("unexpected error calling RefreshState: %v", err) 345 } 346 state := sm.State() 347 if nil == state { 348 t.Fatal("nil state") 349 } 350 stateFoo := state.OutputValue(addrs.OutputValue{Name: "foo"}.Absolute(addrs.RootModuleInstance)) 351 if stateFoo != nil { 352 t.Errorf("expected nil foo from state; got %v", stateFoo) 353 } 354 355 // Create a new state that has "foo" set to "bar" and ensure that state is as expected 356 state = states.BuildState(func(ss *states.SyncState) { 357 ss.SetOutputValue( 358 addrs.OutputValue{Name: "foo"}.Absolute(addrs.RootModuleInstance), 359 cty.StringVal("bar"), 360 false) 361 }) 362 stateFoo = state.OutputValue(addrs.OutputValue{Name: "foo"}.Absolute(addrs.RootModuleInstance)) 363 if nil == stateFoo { 364 t.Fatal("nil foo after building state with foo populated") 365 } 366 if foo := stateFoo.Value.AsString(); foo != "bar" { 367 t.Errorf("Expected built state foo value to be bar; got %s", foo) 368 } 369 370 // Ensure the change hasn't altered the current state manager state by checking "foo" and comparing states 371 curState := sm.State() 372 curStateFoo := curState.OutputValue(addrs.OutputValue{Name: "foo"}.Absolute(addrs.RootModuleInstance)) 373 if curStateFoo != nil { 374 t.Errorf("expected session manager state to be unaltered and still nil, but got: %v", curStateFoo) 375 } 376 if reflect.DeepEqual(state, curState) { 377 t.Errorf("expected %v != %v; but they were equal", state, curState) 378 } 379 380 // Write the new state, persist, and refresh 381 if err = sm.WriteState(state); err != nil { 382 t.Errorf("error writing state: %v", err) 383 } 384 if err = sm.PersistState(nil); err != nil { 385 t.Errorf("error persisting state: %v", err) 386 } 387 if err = sm.RefreshState(); err != nil { 388 t.Errorf("error refreshing state: %v", err) 389 } 390 391 // Get the state again and verify that is now the same as state and has the "foo" value set to "bar" 392 curState = sm.State() 393 if !reflect.DeepEqual(state, curState) { 394 t.Errorf("expected %v == %v; but they were unequal", state, curState) 395 } 396 curStateFoo = curState.OutputValue(addrs.OutputValue{Name: "foo"}.Absolute(addrs.RootModuleInstance)) 397 if nil == curStateFoo { 398 t.Fatal("nil foo") 399 } 400 if foo := curStateFoo.Value.AsString(); foo != "bar" { 401 t.Errorf("expected foo to be bar, but got: %s", foo) 402 } 403 } 404 405 // TestRunServer allows running the server for local debugging; it runs until ctl-c is received 406 func TestRunServer(t *testing.T) { 407 if _, ok := os.LookupEnv("TEST_RUN_SERVER"); !ok { 408 t.Skip("TEST_RUN_SERVER not set") 409 } 410 s, err := NewHttpTestServer() 411 if err != nil { 412 t.Fatalf("unexpected error creating test server: %v", err) 413 } 414 defer s.Close() 415 416 t.Log(s.URL) 417 418 ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) 419 defer cancel() 420 // wait until signal 421 <-ctx.Done() 422 }