github.com/opentofu/opentofu@v1.7.1/internal/backend/local/hook_state_test.go (about) 1 // Copyright (c) The OpenTofu Authors 2 // SPDX-License-Identifier: MPL-2.0 3 // Copyright (c) 2023 HashiCorp, Inc. 4 // SPDX-License-Identifier: MPL-2.0 5 6 package local 7 8 import ( 9 "fmt" 10 "testing" 11 "time" 12 13 "github.com/google/go-cmp/cmp" 14 "github.com/opentofu/opentofu/internal/states" 15 "github.com/opentofu/opentofu/internal/states/statemgr" 16 "github.com/opentofu/opentofu/internal/tofu" 17 ) 18 19 func TestStateHook_impl(t *testing.T) { 20 var _ tofu.Hook = new(StateHook) 21 } 22 23 func TestStateHook(t *testing.T) { 24 is := statemgr.NewTransientInMemory(nil) 25 var hook tofu.Hook = &StateHook{StateMgr: is} 26 27 s := statemgr.TestFullInitialState() 28 action, err := hook.PostStateUpdate(s) 29 if err != nil { 30 t.Fatalf("err: %s", err) 31 } 32 if action != tofu.HookActionContinue { 33 t.Fatalf("bad: %v", action) 34 } 35 if !is.State().Equal(s) { 36 t.Fatalf("bad state: %#v", is.State()) 37 } 38 } 39 40 func TestStateHookStopping(t *testing.T) { 41 is := &testPersistentState{} 42 hook := &StateHook{ 43 StateMgr: is, 44 Schemas: &tofu.Schemas{}, 45 PersistInterval: 4 * time.Hour, 46 intermediatePersist: IntermediateStatePersistInfo{ 47 LastPersist: time.Now(), 48 }, 49 } 50 51 s := statemgr.TestFullInitialState() 52 action, err := hook.PostStateUpdate(s) 53 if err != nil { 54 t.Fatalf("unexpected error from PostStateUpdate: %s", err) 55 } 56 if got, want := action, tofu.HookActionContinue; got != want { 57 t.Fatalf("wrong hookaction %#v; want %#v", got, want) 58 } 59 if is.Written == nil || !is.Written.Equal(s) { 60 t.Fatalf("mismatching state written") 61 } 62 if is.Persisted != nil { 63 t.Fatalf("persisted too soon") 64 } 65 66 // We'll now force lastPersist to be long enough ago that persisting 67 // should be due on the next call. 68 hook.intermediatePersist.LastPersist = time.Now().Add(-5 * time.Hour) 69 hook.PostStateUpdate(s) 70 if is.Written == nil || !is.Written.Equal(s) { 71 t.Fatalf("mismatching state written") 72 } 73 if is.Persisted == nil || !is.Persisted.Equal(s) { 74 t.Fatalf("mismatching state persisted") 75 } 76 hook.PostStateUpdate(s) 77 if is.Written == nil || !is.Written.Equal(s) { 78 t.Fatalf("mismatching state written") 79 } 80 if is.Persisted == nil || !is.Persisted.Equal(s) { 81 t.Fatalf("mismatching state persisted") 82 } 83 84 gotLog := is.CallLog 85 wantLog := []string{ 86 // Initial call before we reset lastPersist 87 "WriteState", 88 89 // Write and then persist after we reset lastPersist 90 "WriteState", 91 "PersistState", 92 93 // Final call when persisting wasn't due yet. 94 "WriteState", 95 } 96 if diff := cmp.Diff(wantLog, gotLog); diff != "" { 97 t.Fatalf("wrong call log so far\n%s", diff) 98 } 99 100 // We'll reset the log now before we try seeing what happens after 101 // we use "Stopped". 102 is.CallLog = is.CallLog[:0] 103 is.Persisted = nil 104 105 hook.Stopping() 106 if is.Persisted == nil || !is.Persisted.Equal(s) { 107 t.Fatalf("mismatching state persisted") 108 } 109 110 is.Persisted = nil 111 hook.PostStateUpdate(s) 112 if is.Persisted == nil || !is.Persisted.Equal(s) { 113 t.Fatalf("mismatching state persisted") 114 } 115 is.Persisted = nil 116 hook.PostStateUpdate(s) 117 if is.Persisted == nil || !is.Persisted.Equal(s) { 118 t.Fatalf("mismatching state persisted") 119 } 120 121 gotLog = is.CallLog 122 wantLog = []string{ 123 // "Stopping" immediately persisted 124 "PersistState", 125 126 // PostStateUpdate then writes and persists on every call, 127 // on the assumption that we're now bailing out after 128 // being cancelled and trying to save as much state as we can. 129 "WriteState", 130 "PersistState", 131 "WriteState", 132 "PersistState", 133 } 134 if diff := cmp.Diff(wantLog, gotLog); diff != "" { 135 t.Fatalf("wrong call log once in stopping mode\n%s", diff) 136 } 137 } 138 139 func TestStateHookCustomPersistRule(t *testing.T) { 140 is := &testPersistentStateThatRefusesToPersist{} 141 hook := &StateHook{ 142 StateMgr: is, 143 Schemas: &tofu.Schemas{}, 144 PersistInterval: 4 * time.Hour, 145 intermediatePersist: IntermediateStatePersistInfo{ 146 LastPersist: time.Now(), 147 }, 148 } 149 150 s := statemgr.TestFullInitialState() 151 action, err := hook.PostStateUpdate(s) 152 if err != nil { 153 t.Fatalf("unexpected error from PostStateUpdate: %s", err) 154 } 155 if got, want := action, tofu.HookActionContinue; got != want { 156 t.Fatalf("wrong hookaction %#v; want %#v", got, want) 157 } 158 if is.Written == nil || !is.Written.Equal(s) { 159 t.Fatalf("mismatching state written") 160 } 161 if is.Persisted != nil { 162 t.Fatalf("persisted too soon") 163 } 164 165 // We'll now force lastPersist to be long enough ago that persisting 166 // should be due on the next call. 167 hook.intermediatePersist.LastPersist = time.Now().Add(-5 * time.Hour) 168 hook.PostStateUpdate(s) 169 if is.Written == nil || !is.Written.Equal(s) { 170 t.Fatalf("mismatching state written") 171 } 172 if is.Persisted != nil { 173 t.Fatalf("has a persisted state, but shouldn't") 174 } 175 hook.PostStateUpdate(s) 176 if is.Written == nil || !is.Written.Equal(s) { 177 t.Fatalf("mismatching state written") 178 } 179 if is.Persisted != nil { 180 t.Fatalf("has a persisted state, but shouldn't") 181 } 182 183 gotLog := is.CallLog 184 wantLog := []string{ 185 // Initial call before we reset lastPersist 186 "WriteState", 187 "ShouldPersistIntermediateState", 188 // Previous call should return false, preventing a "PersistState" call 189 190 // Write and then decline to persist 191 "WriteState", 192 "ShouldPersistIntermediateState", 193 // Previous call should return false, preventing a "PersistState" call 194 195 // Final call before we start "stopping". 196 "WriteState", 197 "ShouldPersistIntermediateState", 198 // Previous call should return false, preventing a "PersistState" call 199 } 200 if diff := cmp.Diff(wantLog, gotLog); diff != "" { 201 t.Fatalf("wrong call log so far\n%s", diff) 202 } 203 204 // We'll reset the log now before we try seeing what happens after 205 // we use "Stopped". 206 is.CallLog = is.CallLog[:0] 207 is.Persisted = nil 208 209 hook.Stopping() 210 if is.Persisted == nil || !is.Persisted.Equal(s) { 211 t.Fatalf("mismatching state persisted") 212 } 213 214 is.Persisted = nil 215 hook.PostStateUpdate(s) 216 if is.Persisted == nil || !is.Persisted.Equal(s) { 217 t.Fatalf("mismatching state persisted") 218 } 219 is.Persisted = nil 220 hook.PostStateUpdate(s) 221 if is.Persisted == nil || !is.Persisted.Equal(s) { 222 t.Fatalf("mismatching state persisted") 223 } 224 225 gotLog = is.CallLog 226 wantLog = []string{ 227 "ShouldPersistIntermediateState", 228 // Previous call should return true, allowing the following "PersistState" call 229 "PersistState", 230 "WriteState", 231 "ShouldPersistIntermediateState", 232 // Previous call should return true, allowing the following "PersistState" call 233 "PersistState", 234 "WriteState", 235 "ShouldPersistIntermediateState", 236 // Previous call should return true, allowing the following "PersistState" call 237 "PersistState", 238 } 239 if diff := cmp.Diff(wantLog, gotLog); diff != "" { 240 t.Fatalf("wrong call log once in stopping mode\n%s", diff) 241 } 242 } 243 244 type testPersistentState struct { 245 CallLog []string 246 247 Written *states.State 248 Persisted *states.State 249 } 250 251 var _ statemgr.Writer = (*testPersistentState)(nil) 252 var _ statemgr.Persister = (*testPersistentState)(nil) 253 254 func (sm *testPersistentState) WriteState(state *states.State) error { 255 sm.CallLog = append(sm.CallLog, "WriteState") 256 sm.Written = state 257 return nil 258 } 259 260 func (sm *testPersistentState) PersistState(schemas *tofu.Schemas) error { 261 if schemas == nil { 262 return fmt.Errorf("no schemas") 263 } 264 sm.CallLog = append(sm.CallLog, "PersistState") 265 sm.Persisted = sm.Written 266 return nil 267 } 268 269 type testPersistentStateThatRefusesToPersist struct { 270 CallLog []string 271 272 Written *states.State 273 Persisted *states.State 274 } 275 276 var _ statemgr.Writer = (*testPersistentStateThatRefusesToPersist)(nil) 277 var _ statemgr.Persister = (*testPersistentStateThatRefusesToPersist)(nil) 278 var _ IntermediateStateConditionalPersister = (*testPersistentStateThatRefusesToPersist)(nil) 279 280 func (sm *testPersistentStateThatRefusesToPersist) WriteState(state *states.State) error { 281 sm.CallLog = append(sm.CallLog, "WriteState") 282 sm.Written = state 283 return nil 284 } 285 286 func (sm *testPersistentStateThatRefusesToPersist) PersistState(schemas *tofu.Schemas) error { 287 if schemas == nil { 288 return fmt.Errorf("no schemas") 289 } 290 sm.CallLog = append(sm.CallLog, "PersistState") 291 sm.Persisted = sm.Written 292 return nil 293 } 294 295 // ShouldPersistIntermediateState implements IntermediateStateConditionalPersister 296 func (sm *testPersistentStateThatRefusesToPersist) ShouldPersistIntermediateState(info *IntermediateStatePersistInfo) bool { 297 sm.CallLog = append(sm.CallLog, "ShouldPersistIntermediateState") 298 return info.ForcePersist 299 }