github.com/opentofu/opentofu@v1.7.1/internal/backend/local/hook_state_test.go (about)

     1  // Copyright (c) The OpenTofu Authors
     2  // SPDX-License-Identifier: MPL-2.0
     3  // Copyright (c) 2023 HashiCorp, Inc.
     4  // SPDX-License-Identifier: MPL-2.0
     5  
     6  package local
     7  
     8  import (
     9  	"fmt"
    10  	"testing"
    11  	"time"
    12  
    13  	"github.com/google/go-cmp/cmp"
    14  	"github.com/opentofu/opentofu/internal/states"
    15  	"github.com/opentofu/opentofu/internal/states/statemgr"
    16  	"github.com/opentofu/opentofu/internal/tofu"
    17  )
    18  
    19  func TestStateHook_impl(t *testing.T) {
    20  	var _ tofu.Hook = new(StateHook)
    21  }
    22  
    23  func TestStateHook(t *testing.T) {
    24  	is := statemgr.NewTransientInMemory(nil)
    25  	var hook tofu.Hook = &StateHook{StateMgr: is}
    26  
    27  	s := statemgr.TestFullInitialState()
    28  	action, err := hook.PostStateUpdate(s)
    29  	if err != nil {
    30  		t.Fatalf("err: %s", err)
    31  	}
    32  	if action != tofu.HookActionContinue {
    33  		t.Fatalf("bad: %v", action)
    34  	}
    35  	if !is.State().Equal(s) {
    36  		t.Fatalf("bad state: %#v", is.State())
    37  	}
    38  }
    39  
    40  func TestStateHookStopping(t *testing.T) {
    41  	is := &testPersistentState{}
    42  	hook := &StateHook{
    43  		StateMgr:        is,
    44  		Schemas:         &tofu.Schemas{},
    45  		PersistInterval: 4 * time.Hour,
    46  		intermediatePersist: IntermediateStatePersistInfo{
    47  			LastPersist: time.Now(),
    48  		},
    49  	}
    50  
    51  	s := statemgr.TestFullInitialState()
    52  	action, err := hook.PostStateUpdate(s)
    53  	if err != nil {
    54  		t.Fatalf("unexpected error from PostStateUpdate: %s", err)
    55  	}
    56  	if got, want := action, tofu.HookActionContinue; got != want {
    57  		t.Fatalf("wrong hookaction %#v; want %#v", got, want)
    58  	}
    59  	if is.Written == nil || !is.Written.Equal(s) {
    60  		t.Fatalf("mismatching state written")
    61  	}
    62  	if is.Persisted != nil {
    63  		t.Fatalf("persisted too soon")
    64  	}
    65  
    66  	// We'll now force lastPersist to be long enough ago that persisting
    67  	// should be due on the next call.
    68  	hook.intermediatePersist.LastPersist = time.Now().Add(-5 * time.Hour)
    69  	hook.PostStateUpdate(s)
    70  	if is.Written == nil || !is.Written.Equal(s) {
    71  		t.Fatalf("mismatching state written")
    72  	}
    73  	if is.Persisted == nil || !is.Persisted.Equal(s) {
    74  		t.Fatalf("mismatching state persisted")
    75  	}
    76  	hook.PostStateUpdate(s)
    77  	if is.Written == nil || !is.Written.Equal(s) {
    78  		t.Fatalf("mismatching state written")
    79  	}
    80  	if is.Persisted == nil || !is.Persisted.Equal(s) {
    81  		t.Fatalf("mismatching state persisted")
    82  	}
    83  
    84  	gotLog := is.CallLog
    85  	wantLog := []string{
    86  		// Initial call before we reset lastPersist
    87  		"WriteState",
    88  
    89  		// Write and then persist after we reset lastPersist
    90  		"WriteState",
    91  		"PersistState",
    92  
    93  		// Final call when persisting wasn't due yet.
    94  		"WriteState",
    95  	}
    96  	if diff := cmp.Diff(wantLog, gotLog); diff != "" {
    97  		t.Fatalf("wrong call log so far\n%s", diff)
    98  	}
    99  
   100  	// We'll reset the log now before we try seeing what happens after
   101  	// we use "Stopped".
   102  	is.CallLog = is.CallLog[:0]
   103  	is.Persisted = nil
   104  
   105  	hook.Stopping()
   106  	if is.Persisted == nil || !is.Persisted.Equal(s) {
   107  		t.Fatalf("mismatching state persisted")
   108  	}
   109  
   110  	is.Persisted = nil
   111  	hook.PostStateUpdate(s)
   112  	if is.Persisted == nil || !is.Persisted.Equal(s) {
   113  		t.Fatalf("mismatching state persisted")
   114  	}
   115  	is.Persisted = nil
   116  	hook.PostStateUpdate(s)
   117  	if is.Persisted == nil || !is.Persisted.Equal(s) {
   118  		t.Fatalf("mismatching state persisted")
   119  	}
   120  
   121  	gotLog = is.CallLog
   122  	wantLog = []string{
   123  		// "Stopping" immediately persisted
   124  		"PersistState",
   125  
   126  		// PostStateUpdate then writes and persists on every call,
   127  		// on the assumption that we're now bailing out after
   128  		// being cancelled and trying to save as much state as we can.
   129  		"WriteState",
   130  		"PersistState",
   131  		"WriteState",
   132  		"PersistState",
   133  	}
   134  	if diff := cmp.Diff(wantLog, gotLog); diff != "" {
   135  		t.Fatalf("wrong call log once in stopping mode\n%s", diff)
   136  	}
   137  }
   138  
   139  func TestStateHookCustomPersistRule(t *testing.T) {
   140  	is := &testPersistentStateThatRefusesToPersist{}
   141  	hook := &StateHook{
   142  		StateMgr:        is,
   143  		Schemas:         &tofu.Schemas{},
   144  		PersistInterval: 4 * time.Hour,
   145  		intermediatePersist: IntermediateStatePersistInfo{
   146  			LastPersist: time.Now(),
   147  		},
   148  	}
   149  
   150  	s := statemgr.TestFullInitialState()
   151  	action, err := hook.PostStateUpdate(s)
   152  	if err != nil {
   153  		t.Fatalf("unexpected error from PostStateUpdate: %s", err)
   154  	}
   155  	if got, want := action, tofu.HookActionContinue; got != want {
   156  		t.Fatalf("wrong hookaction %#v; want %#v", got, want)
   157  	}
   158  	if is.Written == nil || !is.Written.Equal(s) {
   159  		t.Fatalf("mismatching state written")
   160  	}
   161  	if is.Persisted != nil {
   162  		t.Fatalf("persisted too soon")
   163  	}
   164  
   165  	// We'll now force lastPersist to be long enough ago that persisting
   166  	// should be due on the next call.
   167  	hook.intermediatePersist.LastPersist = time.Now().Add(-5 * time.Hour)
   168  	hook.PostStateUpdate(s)
   169  	if is.Written == nil || !is.Written.Equal(s) {
   170  		t.Fatalf("mismatching state written")
   171  	}
   172  	if is.Persisted != nil {
   173  		t.Fatalf("has a persisted state, but shouldn't")
   174  	}
   175  	hook.PostStateUpdate(s)
   176  	if is.Written == nil || !is.Written.Equal(s) {
   177  		t.Fatalf("mismatching state written")
   178  	}
   179  	if is.Persisted != nil {
   180  		t.Fatalf("has a persisted state, but shouldn't")
   181  	}
   182  
   183  	gotLog := is.CallLog
   184  	wantLog := []string{
   185  		// Initial call before we reset lastPersist
   186  		"WriteState",
   187  		"ShouldPersistIntermediateState",
   188  		// Previous call should return false, preventing a "PersistState" call
   189  
   190  		// Write and then decline to persist
   191  		"WriteState",
   192  		"ShouldPersistIntermediateState",
   193  		// Previous call should return false, preventing a "PersistState" call
   194  
   195  		// Final call before we start "stopping".
   196  		"WriteState",
   197  		"ShouldPersistIntermediateState",
   198  		// Previous call should return false, preventing a "PersistState" call
   199  	}
   200  	if diff := cmp.Diff(wantLog, gotLog); diff != "" {
   201  		t.Fatalf("wrong call log so far\n%s", diff)
   202  	}
   203  
   204  	// We'll reset the log now before we try seeing what happens after
   205  	// we use "Stopped".
   206  	is.CallLog = is.CallLog[:0]
   207  	is.Persisted = nil
   208  
   209  	hook.Stopping()
   210  	if is.Persisted == nil || !is.Persisted.Equal(s) {
   211  		t.Fatalf("mismatching state persisted")
   212  	}
   213  
   214  	is.Persisted = nil
   215  	hook.PostStateUpdate(s)
   216  	if is.Persisted == nil || !is.Persisted.Equal(s) {
   217  		t.Fatalf("mismatching state persisted")
   218  	}
   219  	is.Persisted = nil
   220  	hook.PostStateUpdate(s)
   221  	if is.Persisted == nil || !is.Persisted.Equal(s) {
   222  		t.Fatalf("mismatching state persisted")
   223  	}
   224  
   225  	gotLog = is.CallLog
   226  	wantLog = []string{
   227  		"ShouldPersistIntermediateState",
   228  		// Previous call should return true, allowing the following "PersistState" call
   229  		"PersistState",
   230  		"WriteState",
   231  		"ShouldPersistIntermediateState",
   232  		// Previous call should return true, allowing the following "PersistState" call
   233  		"PersistState",
   234  		"WriteState",
   235  		"ShouldPersistIntermediateState",
   236  		// Previous call should return true, allowing the following "PersistState" call
   237  		"PersistState",
   238  	}
   239  	if diff := cmp.Diff(wantLog, gotLog); diff != "" {
   240  		t.Fatalf("wrong call log once in stopping mode\n%s", diff)
   241  	}
   242  }
   243  
   244  type testPersistentState struct {
   245  	CallLog []string
   246  
   247  	Written   *states.State
   248  	Persisted *states.State
   249  }
   250  
   251  var _ statemgr.Writer = (*testPersistentState)(nil)
   252  var _ statemgr.Persister = (*testPersistentState)(nil)
   253  
   254  func (sm *testPersistentState) WriteState(state *states.State) error {
   255  	sm.CallLog = append(sm.CallLog, "WriteState")
   256  	sm.Written = state
   257  	return nil
   258  }
   259  
   260  func (sm *testPersistentState) PersistState(schemas *tofu.Schemas) error {
   261  	if schemas == nil {
   262  		return fmt.Errorf("no schemas")
   263  	}
   264  	sm.CallLog = append(sm.CallLog, "PersistState")
   265  	sm.Persisted = sm.Written
   266  	return nil
   267  }
   268  
   269  type testPersistentStateThatRefusesToPersist struct {
   270  	CallLog []string
   271  
   272  	Written   *states.State
   273  	Persisted *states.State
   274  }
   275  
   276  var _ statemgr.Writer = (*testPersistentStateThatRefusesToPersist)(nil)
   277  var _ statemgr.Persister = (*testPersistentStateThatRefusesToPersist)(nil)
   278  var _ IntermediateStateConditionalPersister = (*testPersistentStateThatRefusesToPersist)(nil)
   279  
   280  func (sm *testPersistentStateThatRefusesToPersist) WriteState(state *states.State) error {
   281  	sm.CallLog = append(sm.CallLog, "WriteState")
   282  	sm.Written = state
   283  	return nil
   284  }
   285  
   286  func (sm *testPersistentStateThatRefusesToPersist) PersistState(schemas *tofu.Schemas) error {
   287  	if schemas == nil {
   288  		return fmt.Errorf("no schemas")
   289  	}
   290  	sm.CallLog = append(sm.CallLog, "PersistState")
   291  	sm.Persisted = sm.Written
   292  	return nil
   293  }
   294  
   295  // ShouldPersistIntermediateState implements IntermediateStateConditionalPersister
   296  func (sm *testPersistentStateThatRefusesToPersist) ShouldPersistIntermediateState(info *IntermediateStatePersistInfo) bool {
   297  	sm.CallLog = append(sm.CallLog, "ShouldPersistIntermediateState")
   298  	return info.ForcePersist
   299  }