github.com/terramate-io/tf@v0.0.0-20230830114523-fce866b4dfcd/backend/local/hook_state_test.go (about)

     1  // Copyright (c) HashiCorp, Inc.
     2  // SPDX-License-Identifier: MPL-2.0
     3  
     4  package local
     5  
     6  import (
     7  	"fmt"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/google/go-cmp/cmp"
    12  	"github.com/terramate-io/tf/states"
    13  	"github.com/terramate-io/tf/states/statemgr"
    14  	"github.com/terramate-io/tf/terraform"
    15  )
    16  
    17  func TestStateHook_impl(t *testing.T) {
    18  	var _ terraform.Hook = new(StateHook)
    19  }
    20  
    21  func TestStateHook(t *testing.T) {
    22  	is := statemgr.NewTransientInMemory(nil)
    23  	var hook terraform.Hook = &StateHook{StateMgr: is}
    24  
    25  	s := statemgr.TestFullInitialState()
    26  	action, err := hook.PostStateUpdate(s)
    27  	if err != nil {
    28  		t.Fatalf("err: %s", err)
    29  	}
    30  	if action != terraform.HookActionContinue {
    31  		t.Fatalf("bad: %v", action)
    32  	}
    33  	if !is.State().Equal(s) {
    34  		t.Fatalf("bad state: %#v", is.State())
    35  	}
    36  }
    37  
    38  func TestStateHookStopping(t *testing.T) {
    39  	is := &testPersistentState{}
    40  	hook := &StateHook{
    41  		StateMgr:        is,
    42  		Schemas:         &terraform.Schemas{},
    43  		PersistInterval: 4 * time.Hour,
    44  		intermediatePersist: IntermediateStatePersistInfo{
    45  			LastPersist: time.Now(),
    46  		},
    47  	}
    48  
    49  	s := statemgr.TestFullInitialState()
    50  	action, err := hook.PostStateUpdate(s)
    51  	if err != nil {
    52  		t.Fatalf("unexpected error from PostStateUpdate: %s", err)
    53  	}
    54  	if got, want := action, terraform.HookActionContinue; got != want {
    55  		t.Fatalf("wrong hookaction %#v; want %#v", got, want)
    56  	}
    57  	if is.Written == nil || !is.Written.Equal(s) {
    58  		t.Fatalf("mismatching state written")
    59  	}
    60  	if is.Persisted != nil {
    61  		t.Fatalf("persisted too soon")
    62  	}
    63  
    64  	// We'll now force lastPersist to be long enough ago that persisting
    65  	// should be due on the next call.
    66  	hook.intermediatePersist.LastPersist = time.Now().Add(-5 * time.Hour)
    67  	hook.PostStateUpdate(s)
    68  	if is.Written == nil || !is.Written.Equal(s) {
    69  		t.Fatalf("mismatching state written")
    70  	}
    71  	if is.Persisted == nil || !is.Persisted.Equal(s) {
    72  		t.Fatalf("mismatching state persisted")
    73  	}
    74  	hook.PostStateUpdate(s)
    75  	if is.Written == nil || !is.Written.Equal(s) {
    76  		t.Fatalf("mismatching state written")
    77  	}
    78  	if is.Persisted == nil || !is.Persisted.Equal(s) {
    79  		t.Fatalf("mismatching state persisted")
    80  	}
    81  
    82  	gotLog := is.CallLog
    83  	wantLog := []string{
    84  		// Initial call before we reset lastPersist
    85  		"WriteState",
    86  
    87  		// Write and then persist after we reset lastPersist
    88  		"WriteState",
    89  		"PersistState",
    90  
    91  		// Final call when persisting wasn't due yet.
    92  		"WriteState",
    93  	}
    94  	if diff := cmp.Diff(wantLog, gotLog); diff != "" {
    95  		t.Fatalf("wrong call log so far\n%s", diff)
    96  	}
    97  
    98  	// We'll reset the log now before we try seeing what happens after
    99  	// we use "Stopped".
   100  	is.CallLog = is.CallLog[:0]
   101  	is.Persisted = nil
   102  
   103  	hook.Stopping()
   104  	if is.Persisted == nil || !is.Persisted.Equal(s) {
   105  		t.Fatalf("mismatching state persisted")
   106  	}
   107  
   108  	is.Persisted = nil
   109  	hook.PostStateUpdate(s)
   110  	if is.Persisted == nil || !is.Persisted.Equal(s) {
   111  		t.Fatalf("mismatching state persisted")
   112  	}
   113  	is.Persisted = nil
   114  	hook.PostStateUpdate(s)
   115  	if is.Persisted == nil || !is.Persisted.Equal(s) {
   116  		t.Fatalf("mismatching state persisted")
   117  	}
   118  
   119  	gotLog = is.CallLog
   120  	wantLog = []string{
   121  		// "Stopping" immediately persisted
   122  		"PersistState",
   123  
   124  		// PostStateUpdate then writes and persists on every call,
   125  		// on the assumption that we're now bailing out after
   126  		// being cancelled and trying to save as much state as we can.
   127  		"WriteState",
   128  		"PersistState",
   129  		"WriteState",
   130  		"PersistState",
   131  	}
   132  	if diff := cmp.Diff(wantLog, gotLog); diff != "" {
   133  		t.Fatalf("wrong call log once in stopping mode\n%s", diff)
   134  	}
   135  }
   136  
   137  func TestStateHookCustomPersistRule(t *testing.T) {
   138  	is := &testPersistentStateThatRefusesToPersist{}
   139  	hook := &StateHook{
   140  		StateMgr:        is,
   141  		Schemas:         &terraform.Schemas{},
   142  		PersistInterval: 4 * time.Hour,
   143  		intermediatePersist: IntermediateStatePersistInfo{
   144  			LastPersist: time.Now(),
   145  		},
   146  	}
   147  
   148  	s := statemgr.TestFullInitialState()
   149  	action, err := hook.PostStateUpdate(s)
   150  	if err != nil {
   151  		t.Fatalf("unexpected error from PostStateUpdate: %s", err)
   152  	}
   153  	if got, want := action, terraform.HookActionContinue; got != want {
   154  		t.Fatalf("wrong hookaction %#v; want %#v", got, want)
   155  	}
   156  	if is.Written == nil || !is.Written.Equal(s) {
   157  		t.Fatalf("mismatching state written")
   158  	}
   159  	if is.Persisted != nil {
   160  		t.Fatalf("persisted too soon")
   161  	}
   162  
   163  	// We'll now force lastPersist to be long enough ago that persisting
   164  	// should be due on the next call.
   165  	hook.intermediatePersist.LastPersist = time.Now().Add(-5 * time.Hour)
   166  	hook.PostStateUpdate(s)
   167  	if is.Written == nil || !is.Written.Equal(s) {
   168  		t.Fatalf("mismatching state written")
   169  	}
   170  	if is.Persisted != nil {
   171  		t.Fatalf("has a persisted state, but shouldn't")
   172  	}
   173  	hook.PostStateUpdate(s)
   174  	if is.Written == nil || !is.Written.Equal(s) {
   175  		t.Fatalf("mismatching state written")
   176  	}
   177  	if is.Persisted != nil {
   178  		t.Fatalf("has a persisted state, but shouldn't")
   179  	}
   180  
   181  	gotLog := is.CallLog
   182  	wantLog := []string{
   183  		// Initial call before we reset lastPersist
   184  		"WriteState",
   185  		"ShouldPersistIntermediateState",
   186  		// Previous call should return false, preventing a "PersistState" call
   187  
   188  		// Write and then decline to persist
   189  		"WriteState",
   190  		"ShouldPersistIntermediateState",
   191  		// Previous call should return false, preventing a "PersistState" call
   192  
   193  		// Final call before we start "stopping".
   194  		"WriteState",
   195  		"ShouldPersistIntermediateState",
   196  		// Previous call should return false, preventing a "PersistState" call
   197  	}
   198  	if diff := cmp.Diff(wantLog, gotLog); diff != "" {
   199  		t.Fatalf("wrong call log so far\n%s", diff)
   200  	}
   201  
   202  	// We'll reset the log now before we try seeing what happens after
   203  	// we use "Stopped".
   204  	is.CallLog = is.CallLog[:0]
   205  	is.Persisted = nil
   206  
   207  	hook.Stopping()
   208  	if is.Persisted == nil || !is.Persisted.Equal(s) {
   209  		t.Fatalf("mismatching state persisted")
   210  	}
   211  
   212  	is.Persisted = nil
   213  	hook.PostStateUpdate(s)
   214  	if is.Persisted == nil || !is.Persisted.Equal(s) {
   215  		t.Fatalf("mismatching state persisted")
   216  	}
   217  	is.Persisted = nil
   218  	hook.PostStateUpdate(s)
   219  	if is.Persisted == nil || !is.Persisted.Equal(s) {
   220  		t.Fatalf("mismatching state persisted")
   221  	}
   222  
   223  	gotLog = is.CallLog
   224  	wantLog = []string{
   225  		"ShouldPersistIntermediateState",
   226  		// Previous call should return true, allowing the following "PersistState" call
   227  		"PersistState",
   228  		"WriteState",
   229  		"ShouldPersistIntermediateState",
   230  		// Previous call should return true, allowing the following "PersistState" call
   231  		"PersistState",
   232  		"WriteState",
   233  		"ShouldPersistIntermediateState",
   234  		// Previous call should return true, allowing the following "PersistState" call
   235  		"PersistState",
   236  	}
   237  	if diff := cmp.Diff(wantLog, gotLog); diff != "" {
   238  		t.Fatalf("wrong call log once in stopping mode\n%s", diff)
   239  	}
   240  }
   241  
   242  type testPersistentState struct {
   243  	CallLog []string
   244  
   245  	Written   *states.State
   246  	Persisted *states.State
   247  }
   248  
   249  var _ statemgr.Writer = (*testPersistentState)(nil)
   250  var _ statemgr.Persister = (*testPersistentState)(nil)
   251  
   252  func (sm *testPersistentState) WriteState(state *states.State) error {
   253  	sm.CallLog = append(sm.CallLog, "WriteState")
   254  	sm.Written = state
   255  	return nil
   256  }
   257  
   258  func (sm *testPersistentState) PersistState(schemas *terraform.Schemas) error {
   259  	if schemas == nil {
   260  		return fmt.Errorf("no schemas")
   261  	}
   262  	sm.CallLog = append(sm.CallLog, "PersistState")
   263  	sm.Persisted = sm.Written
   264  	return nil
   265  }
   266  
   267  type testPersistentStateThatRefusesToPersist struct {
   268  	CallLog []string
   269  
   270  	Written   *states.State
   271  	Persisted *states.State
   272  }
   273  
   274  var _ statemgr.Writer = (*testPersistentStateThatRefusesToPersist)(nil)
   275  var _ statemgr.Persister = (*testPersistentStateThatRefusesToPersist)(nil)
   276  var _ IntermediateStateConditionalPersister = (*testPersistentStateThatRefusesToPersist)(nil)
   277  
   278  func (sm *testPersistentStateThatRefusesToPersist) WriteState(state *states.State) error {
   279  	sm.CallLog = append(sm.CallLog, "WriteState")
   280  	sm.Written = state
   281  	return nil
   282  }
   283  
   284  func (sm *testPersistentStateThatRefusesToPersist) PersistState(schemas *terraform.Schemas) error {
   285  	if schemas == nil {
   286  		return fmt.Errorf("no schemas")
   287  	}
   288  	sm.CallLog = append(sm.CallLog, "PersistState")
   289  	sm.Persisted = sm.Written
   290  	return nil
   291  }
   292  
   293  // ShouldPersistIntermediateState implements IntermediateStateConditionalPersister
   294  func (sm *testPersistentStateThatRefusesToPersist) ShouldPersistIntermediateState(info *IntermediateStatePersistInfo) bool {
   295  	sm.CallLog = append(sm.CallLog, "ShouldPersistIntermediateState")
   296  	return info.ForcePersist
   297  }