github.com/sylr/terraform@v0.11.12-beta1/helper/resource/testing_test.go (about)

     1  package resource
     2  
     3  import (
     4  	"errors"
     5  	"flag"
     6  	"fmt"
     7  	"os"
     8  	"reflect"
     9  	"regexp"
    10  	"sort"
    11  	"strings"
    12  	"sync"
    13  	"sync/atomic"
    14  	"testing"
    15  
    16  	"github.com/hashicorp/go-multierror"
    17  	"github.com/hashicorp/terraform/terraform"
    18  )
    19  
    20  func init() {
    21  	testTesting = true
    22  
    23  	// TODO: Remove when we remove the guard on id checks
    24  	if err := os.Setenv("TF_ACC_IDONLY", "1"); err != nil {
    25  		panic(err)
    26  	}
    27  
    28  	if err := os.Setenv(TestEnvVar, "1"); err != nil {
    29  		panic(err)
    30  	}
    31  }
    32  
    33  // wrap the mock provider to implement TestProvider
    34  type resetProvider struct {
    35  	*terraform.MockResourceProvider
    36  	mu              sync.Mutex
    37  	TestResetCalled bool
    38  	TestResetError  error
    39  }
    40  
    41  func (p *resetProvider) TestReset() error {
    42  	p.mu.Lock()
    43  	defer p.mu.Unlock()
    44  	p.TestResetCalled = true
    45  	return p.TestResetError
    46  }
    47  
    48  func TestParallelTest(t *testing.T) {
    49  	mt := new(mockT)
    50  	ParallelTest(mt, TestCase{})
    51  
    52  	if !mt.ParallelCalled {
    53  		t.Fatal("Parallel() not called")
    54  	}
    55  }
    56  
    57  func TestTest(t *testing.T) {
    58  	mp := &resetProvider{
    59  		MockResourceProvider: testProvider(),
    60  	}
    61  
    62  	mp.DiffReturn = nil
    63  
    64  	mp.ApplyFn = func(
    65  		info *terraform.InstanceInfo,
    66  		state *terraform.InstanceState,
    67  		diff *terraform.InstanceDiff) (*terraform.InstanceState, error) {
    68  		if !diff.Destroy {
    69  			return &terraform.InstanceState{
    70  				ID: "foo",
    71  			}, nil
    72  		}
    73  
    74  		return nil, nil
    75  	}
    76  
    77  	var refreshCount int32
    78  	mp.RefreshFn = func(*terraform.InstanceInfo, *terraform.InstanceState) (*terraform.InstanceState, error) {
    79  		atomic.AddInt32(&refreshCount, 1)
    80  		return &terraform.InstanceState{ID: "foo"}, nil
    81  	}
    82  
    83  	checkDestroy := false
    84  	checkStep := false
    85  
    86  	checkDestroyFn := func(*terraform.State) error {
    87  		checkDestroy = true
    88  		return nil
    89  	}
    90  
    91  	checkStepFn := func(s *terraform.State) error {
    92  		checkStep = true
    93  
    94  		rs, ok := s.RootModule().Resources["test_instance.foo"]
    95  		if !ok {
    96  			t.Error("test_instance.foo is not present")
    97  			return nil
    98  		}
    99  		is := rs.Primary
   100  		if is.ID != "foo" {
   101  			t.Errorf("bad check ID: %s", is.ID)
   102  		}
   103  
   104  		return nil
   105  	}
   106  
   107  	mt := new(mockT)
   108  	Test(mt, TestCase{
   109  		Providers: map[string]terraform.ResourceProvider{
   110  			"test": mp,
   111  		},
   112  		CheckDestroy: checkDestroyFn,
   113  		Steps: []TestStep{
   114  			TestStep{
   115  				Config: testConfigStr,
   116  				Check:  checkStepFn,
   117  			},
   118  		},
   119  	})
   120  
   121  	if mt.failed() {
   122  		t.Fatalf("test failed: %s", mt.failMessage())
   123  	}
   124  	if mt.ParallelCalled {
   125  		t.Fatal("Parallel() called")
   126  	}
   127  	if !checkStep {
   128  		t.Fatal("didn't call check for step")
   129  	}
   130  	if !checkDestroy {
   131  		t.Fatal("didn't call check for destroy")
   132  	}
   133  	if !mp.TestResetCalled {
   134  		t.Fatal("didn't call TestReset")
   135  	}
   136  }
   137  
   138  func TestTest_plan_only(t *testing.T) {
   139  	mp := testProvider()
   140  	mp.ApplyReturn = &terraform.InstanceState{
   141  		ID: "foo",
   142  	}
   143  
   144  	checkDestroy := false
   145  
   146  	checkDestroyFn := func(*terraform.State) error {
   147  		checkDestroy = true
   148  		return nil
   149  	}
   150  
   151  	mt := new(mockT)
   152  	Test(mt, TestCase{
   153  		Providers: map[string]terraform.ResourceProvider{
   154  			"test": mp,
   155  		},
   156  		CheckDestroy: checkDestroyFn,
   157  		Steps: []TestStep{
   158  			TestStep{
   159  				Config:             testConfigStr,
   160  				PlanOnly:           true,
   161  				ExpectNonEmptyPlan: false,
   162  			},
   163  		},
   164  	})
   165  
   166  	if !mt.failed() {
   167  		t.Fatal("test should've failed")
   168  	}
   169  
   170  	expected := `Step 0 error: After applying this step, the plan was not empty:
   171  
   172  DIFF:
   173  
   174  CREATE: test_instance.foo
   175    foo: "" => "bar"
   176  
   177  STATE:
   178  
   179  <no state>`
   180  
   181  	if mt.failMessage() != expected {
   182  		t.Fatalf("Expected message: %s\n\ngot:\n\n%s", expected, mt.failMessage())
   183  	}
   184  
   185  	if !checkDestroy {
   186  		t.Fatal("didn't call check for destroy")
   187  	}
   188  }
   189  
   190  func TestTest_idRefresh(t *testing.T) {
   191  	// Refresh count should be 3:
   192  	//   1.) initial Ref/Plan/Apply
   193  	//   2.) post Ref/Plan/Apply for plan-check
   194  	//   3.) id refresh check
   195  	var expectedRefresh int32 = 3
   196  
   197  	mp := testProvider()
   198  	mp.DiffReturn = nil
   199  
   200  	mp.ApplyFn = func(
   201  		info *terraform.InstanceInfo,
   202  		state *terraform.InstanceState,
   203  		diff *terraform.InstanceDiff) (*terraform.InstanceState, error) {
   204  		if !diff.Destroy {
   205  			return &terraform.InstanceState{
   206  				ID: "foo",
   207  			}, nil
   208  		}
   209  
   210  		return nil, nil
   211  	}
   212  
   213  	var refreshCount int32
   214  	mp.RefreshFn = func(*terraform.InstanceInfo, *terraform.InstanceState) (*terraform.InstanceState, error) {
   215  		atomic.AddInt32(&refreshCount, 1)
   216  		return &terraform.InstanceState{ID: "foo"}, nil
   217  	}
   218  
   219  	mt := new(mockT)
   220  	Test(mt, TestCase{
   221  		IDRefreshName: "test_instance.foo",
   222  		Providers: map[string]terraform.ResourceProvider{
   223  			"test": mp,
   224  		},
   225  		Steps: []TestStep{
   226  			TestStep{
   227  				Config: testConfigStr,
   228  			},
   229  		},
   230  	})
   231  
   232  	if mt.failed() {
   233  		t.Fatalf("test failed: %s", mt.failMessage())
   234  	}
   235  
   236  	// See declaration of expectedRefresh for why that number
   237  	if refreshCount != expectedRefresh {
   238  		t.Fatalf("bad refresh count: %d", refreshCount)
   239  	}
   240  }
   241  
   242  func TestTest_idRefreshCustomName(t *testing.T) {
   243  	// Refresh count should be 3:
   244  	//   1.) initial Ref/Plan/Apply
   245  	//   2.) post Ref/Plan/Apply for plan-check
   246  	//   3.) id refresh check
   247  	var expectedRefresh int32 = 3
   248  
   249  	mp := testProvider()
   250  	mp.DiffReturn = nil
   251  
   252  	mp.ApplyFn = func(
   253  		info *terraform.InstanceInfo,
   254  		state *terraform.InstanceState,
   255  		diff *terraform.InstanceDiff) (*terraform.InstanceState, error) {
   256  		if !diff.Destroy {
   257  			return &terraform.InstanceState{
   258  				ID: "foo",
   259  			}, nil
   260  		}
   261  
   262  		return nil, nil
   263  	}
   264  
   265  	var refreshCount int32
   266  	mp.RefreshFn = func(*terraform.InstanceInfo, *terraform.InstanceState) (*terraform.InstanceState, error) {
   267  		atomic.AddInt32(&refreshCount, 1)
   268  		return &terraform.InstanceState{ID: "foo"}, nil
   269  	}
   270  
   271  	mt := new(mockT)
   272  	Test(mt, TestCase{
   273  		IDRefreshName: "test_instance.foo",
   274  		Providers: map[string]terraform.ResourceProvider{
   275  			"test": mp,
   276  		},
   277  		Steps: []TestStep{
   278  			TestStep{
   279  				Config: testConfigStr,
   280  			},
   281  		},
   282  	})
   283  
   284  	if mt.failed() {
   285  		t.Fatalf("test failed: %s", mt.failMessage())
   286  	}
   287  
   288  	// See declaration of expectedRefresh for why that number
   289  	if refreshCount != expectedRefresh {
   290  		t.Fatalf("bad refresh count: %d", refreshCount)
   291  	}
   292  }
   293  
   294  func TestTest_idRefreshFail(t *testing.T) {
   295  	// Refresh count should be 3:
   296  	//   1.) initial Ref/Plan/Apply
   297  	//   2.) post Ref/Plan/Apply for plan-check
   298  	//   3.) id refresh check
   299  	var expectedRefresh int32 = 3
   300  
   301  	mp := testProvider()
   302  	mp.DiffReturn = nil
   303  
   304  	mp.ApplyFn = func(
   305  		info *terraform.InstanceInfo,
   306  		state *terraform.InstanceState,
   307  		diff *terraform.InstanceDiff) (*terraform.InstanceState, error) {
   308  		if !diff.Destroy {
   309  			return &terraform.InstanceState{
   310  				ID: "foo",
   311  			}, nil
   312  		}
   313  
   314  		return nil, nil
   315  	}
   316  
   317  	var refreshCount int32
   318  	mp.RefreshFn = func(*terraform.InstanceInfo, *terraform.InstanceState) (*terraform.InstanceState, error) {
   319  		atomic.AddInt32(&refreshCount, 1)
   320  		if atomic.LoadInt32(&refreshCount) == expectedRefresh-1 {
   321  			return &terraform.InstanceState{
   322  				ID:         "foo",
   323  				Attributes: map[string]string{"foo": "bar"},
   324  			}, nil
   325  		} else if atomic.LoadInt32(&refreshCount) < expectedRefresh {
   326  			return &terraform.InstanceState{ID: "foo"}, nil
   327  		} else {
   328  			return nil, nil
   329  		}
   330  	}
   331  
   332  	mt := new(mockT)
   333  	Test(mt, TestCase{
   334  		IDRefreshName: "test_instance.foo",
   335  		Providers: map[string]terraform.ResourceProvider{
   336  			"test": mp,
   337  		},
   338  		Steps: []TestStep{
   339  			TestStep{
   340  				Config: testConfigStr,
   341  			},
   342  		},
   343  	})
   344  
   345  	if !mt.failed() {
   346  		t.Fatal("test didn't fail")
   347  	}
   348  	t.Logf("failure reason: %s", mt.failMessage())
   349  
   350  	// See declaration of expectedRefresh for why that number
   351  	if refreshCount != expectedRefresh {
   352  		t.Fatalf("bad refresh count: %d", refreshCount)
   353  	}
   354  }
   355  
   356  func TestTest_empty(t *testing.T) {
   357  	destroyCalled := false
   358  	checkDestroyFn := func(*terraform.State) error {
   359  		destroyCalled = true
   360  		return nil
   361  	}
   362  
   363  	mt := new(mockT)
   364  	Test(mt, TestCase{
   365  		CheckDestroy: checkDestroyFn,
   366  	})
   367  
   368  	if mt.failed() {
   369  		t.Fatal("test failed")
   370  	}
   371  	if destroyCalled {
   372  		t.Fatal("should not call check destroy if there is no steps")
   373  	}
   374  }
   375  
   376  func TestTest_noEnv(t *testing.T) {
   377  	// Unset the variable
   378  	if err := os.Setenv(TestEnvVar, ""); err != nil {
   379  		t.Fatalf("err: %s", err)
   380  	}
   381  	defer os.Setenv(TestEnvVar, "1")
   382  
   383  	mt := new(mockT)
   384  	Test(mt, TestCase{})
   385  
   386  	if !mt.SkipCalled {
   387  		t.Fatal("skip not called")
   388  	}
   389  }
   390  
   391  func TestTest_preCheck(t *testing.T) {
   392  	called := false
   393  
   394  	mt := new(mockT)
   395  	Test(mt, TestCase{
   396  		PreCheck: func() { called = true },
   397  	})
   398  
   399  	if !called {
   400  		t.Fatal("precheck should be called")
   401  	}
   402  }
   403  
   404  func TestTest_skipFunc(t *testing.T) {
   405  	preCheckCalled := false
   406  	skipped := false
   407  
   408  	mp := testProvider()
   409  	mp.ApplyReturn = &terraform.InstanceState{
   410  		ID: "foo",
   411  	}
   412  
   413  	checkStepFn := func(*terraform.State) error {
   414  		return fmt.Errorf("error")
   415  	}
   416  
   417  	mt := new(mockT)
   418  	Test(mt, TestCase{
   419  		Providers: map[string]terraform.ResourceProvider{
   420  			"test": mp,
   421  		},
   422  		PreCheck: func() { preCheckCalled = true },
   423  		Steps: []TestStep{
   424  			{
   425  				Config:   testConfigStr,
   426  				Check:    checkStepFn,
   427  				SkipFunc: func() (bool, error) { skipped = true; return true, nil },
   428  			},
   429  		},
   430  	})
   431  
   432  	if mt.failed() {
   433  		t.Fatal("Expected check to be skipped")
   434  	}
   435  
   436  	if !preCheckCalled {
   437  		t.Fatal("precheck should be called")
   438  	}
   439  	if !skipped {
   440  		t.Fatal("SkipFunc should be called")
   441  	}
   442  }
   443  
   444  func TestTest_stepError(t *testing.T) {
   445  	mp := testProvider()
   446  	mp.ApplyReturn = &terraform.InstanceState{
   447  		ID: "foo",
   448  	}
   449  
   450  	checkDestroy := false
   451  
   452  	checkDestroyFn := func(*terraform.State) error {
   453  		checkDestroy = true
   454  		return nil
   455  	}
   456  
   457  	checkStepFn := func(*terraform.State) error {
   458  		return fmt.Errorf("error")
   459  	}
   460  
   461  	mt := new(mockT)
   462  	Test(mt, TestCase{
   463  		Providers: map[string]terraform.ResourceProvider{
   464  			"test": mp,
   465  		},
   466  		CheckDestroy: checkDestroyFn,
   467  		Steps: []TestStep{
   468  			TestStep{
   469  				Config: testConfigStr,
   470  				Check:  checkStepFn,
   471  			},
   472  		},
   473  	})
   474  
   475  	if !mt.failed() {
   476  		t.Fatal("test should've failed")
   477  	}
   478  	expected := "Step 0 error: Check failed: error"
   479  	if mt.failMessage() != expected {
   480  		t.Fatalf("Expected message: %s\n\ngot:\n\n%s", expected, mt.failMessage())
   481  	}
   482  
   483  	if !checkDestroy {
   484  		t.Fatal("didn't call check for destroy")
   485  	}
   486  }
   487  
   488  func TestTest_factoryError(t *testing.T) {
   489  	resourceFactoryError := fmt.Errorf("resource factory error")
   490  
   491  	factory := func() (terraform.ResourceProvider, error) {
   492  		return nil, resourceFactoryError
   493  	}
   494  
   495  	mt := new(mockT)
   496  	Test(mt, TestCase{
   497  		ProviderFactories: map[string]terraform.ResourceProviderFactory{
   498  			"test": factory,
   499  		},
   500  		Steps: []TestStep{
   501  			TestStep{
   502  				ExpectError: regexp.MustCompile("resource factory error"),
   503  			},
   504  		},
   505  	})
   506  
   507  	if !mt.failed() {
   508  		t.Fatal("test should've failed")
   509  	}
   510  }
   511  
   512  func TestTest_resetError(t *testing.T) {
   513  	mp := &resetProvider{
   514  		MockResourceProvider: testProvider(),
   515  		TestResetError:       fmt.Errorf("provider reset error"),
   516  	}
   517  
   518  	mt := new(mockT)
   519  	Test(mt, TestCase{
   520  		Providers: map[string]terraform.ResourceProvider{
   521  			"test": mp,
   522  		},
   523  		Steps: []TestStep{
   524  			TestStep{
   525  				ExpectError: regexp.MustCompile("provider reset error"),
   526  			},
   527  		},
   528  	})
   529  
   530  	if !mt.failed() {
   531  		t.Fatal("test should've failed")
   532  	}
   533  }
   534  
   535  func TestTest_expectError(t *testing.T) {
   536  	cases := []struct {
   537  		name     string
   538  		planErr  bool
   539  		applyErr bool
   540  		badErr   bool
   541  	}{
   542  		{
   543  			name:     "successful apply",
   544  			planErr:  false,
   545  			applyErr: false,
   546  		},
   547  		{
   548  			name:     "bad plan",
   549  			planErr:  true,
   550  			applyErr: false,
   551  		},
   552  		{
   553  			name:     "bad apply",
   554  			planErr:  false,
   555  			applyErr: true,
   556  		},
   557  		{
   558  			name:     "bad plan, bad err",
   559  			planErr:  true,
   560  			applyErr: false,
   561  			badErr:   true,
   562  		},
   563  		{
   564  			name:     "bad apply, bad err",
   565  			planErr:  false,
   566  			applyErr: true,
   567  			badErr:   true,
   568  		},
   569  	}
   570  
   571  	for _, tc := range cases {
   572  		t.Run(tc.name, func(t *testing.T) {
   573  			mp := testProvider()
   574  			expectedText := "test provider error"
   575  			var errText string
   576  			if tc.badErr {
   577  				errText = "wrong provider error"
   578  			} else {
   579  				errText = expectedText
   580  			}
   581  			noErrText := "no error received, but expected a match to"
   582  			if tc.planErr {
   583  				mp.DiffReturnError = errors.New(errText)
   584  			}
   585  			if tc.applyErr {
   586  				mp.ApplyReturnError = errors.New(errText)
   587  			}
   588  			mt := new(mockT)
   589  			Test(mt, TestCase{
   590  				Providers: map[string]terraform.ResourceProvider{
   591  					"test": mp,
   592  				},
   593  				Steps: []TestStep{
   594  					TestStep{
   595  						Config:             testConfigStr,
   596  						ExpectError:        regexp.MustCompile(expectedText),
   597  						Check:              func(*terraform.State) error { return nil },
   598  						ExpectNonEmptyPlan: true,
   599  					},
   600  				},
   601  			},
   602  			)
   603  			if mt.FatalCalled {
   604  				t.Fatalf("fatal: %+v", mt.FatalArgs)
   605  			}
   606  			switch {
   607  			case len(mt.ErrorArgs) < 1 && !tc.planErr && !tc.applyErr:
   608  				t.Fatalf("expected error, got none")
   609  			case !tc.planErr && !tc.applyErr:
   610  				for _, e := range mt.ErrorArgs {
   611  					if regexp.MustCompile(noErrText).MatchString(fmt.Sprintf("%v", e)) {
   612  						return
   613  					}
   614  				}
   615  				t.Fatalf("expected error to match %s, got %+v", noErrText, mt.ErrorArgs)
   616  			case tc.badErr:
   617  				for _, e := range mt.ErrorArgs {
   618  					if regexp.MustCompile(expectedText).MatchString(fmt.Sprintf("%v", e)) {
   619  						return
   620  					}
   621  				}
   622  				t.Fatalf("expected error to match %s, got %+v", expectedText, mt.ErrorArgs)
   623  			}
   624  		})
   625  	}
   626  }
   627  
   628  func TestComposeAggregateTestCheckFunc(t *testing.T) {
   629  	check1 := func(s *terraform.State) error {
   630  		return errors.New("Error 1")
   631  	}
   632  
   633  	check2 := func(s *terraform.State) error {
   634  		return errors.New("Error 2")
   635  	}
   636  
   637  	f := ComposeAggregateTestCheckFunc(check1, check2)
   638  	err := f(nil)
   639  	if err == nil {
   640  		t.Fatalf("Expected errors")
   641  	}
   642  
   643  	multi := err.(*multierror.Error)
   644  	if !strings.Contains(multi.Errors[0].Error(), "Error 1") {
   645  		t.Fatalf("Expected Error 1, Got %s", multi.Errors[0])
   646  	}
   647  	if !strings.Contains(multi.Errors[1].Error(), "Error 2") {
   648  		t.Fatalf("Expected Error 2, Got %s", multi.Errors[1])
   649  	}
   650  }
   651  
   652  func TestComposeTestCheckFunc(t *testing.T) {
   653  	cases := []struct {
   654  		F      []TestCheckFunc
   655  		Result string
   656  	}{
   657  		{
   658  			F: []TestCheckFunc{
   659  				func(*terraform.State) error { return nil },
   660  			},
   661  			Result: "",
   662  		},
   663  
   664  		{
   665  			F: []TestCheckFunc{
   666  				func(*terraform.State) error {
   667  					return fmt.Errorf("error")
   668  				},
   669  				func(*terraform.State) error { return nil },
   670  			},
   671  			Result: "Check 1/2 error: error",
   672  		},
   673  
   674  		{
   675  			F: []TestCheckFunc{
   676  				func(*terraform.State) error { return nil },
   677  				func(*terraform.State) error {
   678  					return fmt.Errorf("error")
   679  				},
   680  			},
   681  			Result: "Check 2/2 error: error",
   682  		},
   683  
   684  		{
   685  			F: []TestCheckFunc{
   686  				func(*terraform.State) error { return nil },
   687  				func(*terraform.State) error { return nil },
   688  			},
   689  			Result: "",
   690  		},
   691  	}
   692  
   693  	for i, tc := range cases {
   694  		f := ComposeTestCheckFunc(tc.F...)
   695  		err := f(nil)
   696  		if err == nil {
   697  			err = fmt.Errorf("")
   698  		}
   699  		if tc.Result != err.Error() {
   700  			t.Fatalf("Case %d bad: %s", i, err)
   701  		}
   702  	}
   703  }
   704  
   705  // mockT implements TestT for testing
   706  type mockT struct {
   707  	ErrorCalled    bool
   708  	ErrorArgs      []interface{}
   709  	FatalCalled    bool
   710  	FatalArgs      []interface{}
   711  	ParallelCalled bool
   712  	SkipCalled     bool
   713  	SkipArgs       []interface{}
   714  
   715  	f bool
   716  }
   717  
   718  func (t *mockT) Error(args ...interface{}) {
   719  	t.ErrorCalled = true
   720  	t.ErrorArgs = args
   721  	t.f = true
   722  }
   723  
   724  func (t *mockT) Fatal(args ...interface{}) {
   725  	t.FatalCalled = true
   726  	t.FatalArgs = args
   727  	t.f = true
   728  }
   729  
   730  func (t *mockT) Parallel() {
   731  	t.ParallelCalled = true
   732  }
   733  
   734  func (t *mockT) Skip(args ...interface{}) {
   735  	t.SkipCalled = true
   736  	t.SkipArgs = args
   737  	t.f = true
   738  }
   739  
   740  func (t *mockT) Name() string {
   741  	return "MockedName"
   742  }
   743  
   744  func (t *mockT) failed() bool {
   745  	return t.f
   746  }
   747  
   748  func (t *mockT) failMessage() string {
   749  	if t.FatalCalled {
   750  		return t.FatalArgs[0].(string)
   751  	} else if t.ErrorCalled {
   752  		return t.ErrorArgs[0].(string)
   753  	} else if t.SkipCalled {
   754  		return t.SkipArgs[0].(string)
   755  	}
   756  
   757  	return "unknown"
   758  }
   759  
   760  func testProvider() *terraform.MockResourceProvider {
   761  	mp := new(terraform.MockResourceProvider)
   762  	mp.DiffReturn = &terraform.InstanceDiff{
   763  		Attributes: map[string]*terraform.ResourceAttrDiff{
   764  			"foo": &terraform.ResourceAttrDiff{
   765  				New: "bar",
   766  			},
   767  		},
   768  	}
   769  	mp.ResourcesReturn = []terraform.ResourceType{
   770  		terraform.ResourceType{Name: "test_instance"},
   771  	}
   772  
   773  	return mp
   774  }
   775  
   776  func TestTest_Main(t *testing.T) {
   777  	flag.Parse()
   778  	if *flagSweep == "" {
   779  		// Tests for the TestMain method used for Sweepers will panic without the -sweep
   780  		// flag specified. Mock the value for now
   781  		*flagSweep = "us-east-1"
   782  	}
   783  
   784  	cases := []struct {
   785  		Name            string
   786  		Sweepers        map[string]*Sweeper
   787  		ExpectedRunList []string
   788  		SweepRun        string
   789  	}{
   790  		{
   791  			Name: "normal",
   792  			Sweepers: map[string]*Sweeper{
   793  				"aws_dummy": &Sweeper{
   794  					Name: "aws_dummy",
   795  					F:    mockSweeperFunc,
   796  				},
   797  			},
   798  			ExpectedRunList: []string{"aws_dummy"},
   799  		},
   800  		{
   801  			Name: "with dep",
   802  			Sweepers: map[string]*Sweeper{
   803  				"aws_dummy": &Sweeper{
   804  					Name: "aws_dummy",
   805  					F:    mockSweeperFunc,
   806  				},
   807  				"aws_top": &Sweeper{
   808  					Name:         "aws_top",
   809  					Dependencies: []string{"aws_sub"},
   810  					F:            mockSweeperFunc,
   811  				},
   812  				"aws_sub": &Sweeper{
   813  					Name: "aws_sub",
   814  					F:    mockSweeperFunc,
   815  				},
   816  			},
   817  			ExpectedRunList: []string{"aws_dummy", "aws_sub", "aws_top"},
   818  		},
   819  		{
   820  			Name: "with filter",
   821  			Sweepers: map[string]*Sweeper{
   822  				"aws_dummy": &Sweeper{
   823  					Name: "aws_dummy",
   824  					F:    mockSweeperFunc,
   825  				},
   826  				"aws_top": &Sweeper{
   827  					Name:         "aws_top",
   828  					Dependencies: []string{"aws_sub"},
   829  					F:            mockSweeperFunc,
   830  				},
   831  				"aws_sub": &Sweeper{
   832  					Name: "aws_sub",
   833  					F:    mockSweeperFunc,
   834  				},
   835  			},
   836  			ExpectedRunList: []string{"aws_dummy"},
   837  			SweepRun:        "aws_dummy",
   838  		},
   839  		{
   840  			Name: "with two filters",
   841  			Sweepers: map[string]*Sweeper{
   842  				"aws_dummy": &Sweeper{
   843  					Name: "aws_dummy",
   844  					F:    mockSweeperFunc,
   845  				},
   846  				"aws_top": &Sweeper{
   847  					Name:         "aws_top",
   848  					Dependencies: []string{"aws_sub"},
   849  					F:            mockSweeperFunc,
   850  				},
   851  				"aws_sub": &Sweeper{
   852  					Name: "aws_sub",
   853  					F:    mockSweeperFunc,
   854  				},
   855  			},
   856  			ExpectedRunList: []string{"aws_dummy", "aws_sub"},
   857  			SweepRun:        "aws_dummy,aws_sub",
   858  		},
   859  		{
   860  			Name: "with dep and filter",
   861  			Sweepers: map[string]*Sweeper{
   862  				"aws_dummy": &Sweeper{
   863  					Name: "aws_dummy",
   864  					F:    mockSweeperFunc,
   865  				},
   866  				"aws_top": &Sweeper{
   867  					Name:         "aws_top",
   868  					Dependencies: []string{"aws_sub"},
   869  					F:            mockSweeperFunc,
   870  				},
   871  				"aws_sub": &Sweeper{
   872  					Name: "aws_sub",
   873  					F:    mockSweeperFunc,
   874  				},
   875  			},
   876  			ExpectedRunList: []string{"aws_top", "aws_sub"},
   877  			SweepRun:        "aws_top",
   878  		},
   879  		{
   880  			Name: "filter and none",
   881  			Sweepers: map[string]*Sweeper{
   882  				"aws_dummy": &Sweeper{
   883  					Name: "aws_dummy",
   884  					F:    mockSweeperFunc,
   885  				},
   886  				"aws_top": &Sweeper{
   887  					Name:         "aws_top",
   888  					Dependencies: []string{"aws_sub"},
   889  					F:            mockSweeperFunc,
   890  				},
   891  				"aws_sub": &Sweeper{
   892  					Name: "aws_sub",
   893  					F:    mockSweeperFunc,
   894  				},
   895  			},
   896  			SweepRun: "none",
   897  		},
   898  	}
   899  
   900  	for _, tc := range cases {
   901  		// reset sweepers
   902  		sweeperFuncs = map[string]*Sweeper{}
   903  
   904  		t.Run(tc.Name, func(t *testing.T) {
   905  			for n, s := range tc.Sweepers {
   906  				AddTestSweepers(n, s)
   907  			}
   908  			*flagSweepRun = tc.SweepRun
   909  
   910  			TestMain(&testing.M{})
   911  
   912  			// get list of tests ran from sweeperRunList keys
   913  			var keys []string
   914  			for k, _ := range sweeperRunList {
   915  				keys = append(keys, k)
   916  			}
   917  
   918  			sort.Strings(keys)
   919  			sort.Strings(tc.ExpectedRunList)
   920  			if !reflect.DeepEqual(keys, tc.ExpectedRunList) {
   921  				t.Fatalf("Expected keys mismatch, expected:\n%#v\ngot:\n%#v\n", tc.ExpectedRunList, keys)
   922  			}
   923  		})
   924  	}
   925  }
   926  
   927  func mockSweeperFunc(s string) error {
   928  	return nil
   929  }
   930  
   931  func TestTest_Taint(t *testing.T) {
   932  	mp := testProvider()
   933  	mp.DiffFn = func(
   934  		_ *terraform.InstanceInfo,
   935  		state *terraform.InstanceState,
   936  		_ *terraform.ResourceConfig,
   937  	) (*terraform.InstanceDiff, error) {
   938  		return &terraform.InstanceDiff{
   939  			DestroyTainted: state.Tainted,
   940  		}, nil
   941  	}
   942  
   943  	mp.ApplyFn = func(
   944  		info *terraform.InstanceInfo,
   945  		state *terraform.InstanceState,
   946  		diff *terraform.InstanceDiff,
   947  	) (*terraform.InstanceState, error) {
   948  		var id string
   949  		switch {
   950  		case diff.Destroy && !diff.DestroyTainted:
   951  			return nil, nil
   952  		case diff.DestroyTainted:
   953  			id = "tainted"
   954  		default:
   955  			id = "not_tainted"
   956  		}
   957  
   958  		return &terraform.InstanceState{
   959  			ID: id,
   960  		}, nil
   961  	}
   962  
   963  	mp.RefreshFn = func(
   964  		_ *terraform.InstanceInfo,
   965  		state *terraform.InstanceState,
   966  	) (*terraform.InstanceState, error) {
   967  		return state, nil
   968  	}
   969  
   970  	mt := new(mockT)
   971  	Test(mt, TestCase{
   972  		Providers: map[string]terraform.ResourceProvider{
   973  			"test": mp,
   974  		},
   975  		Steps: []TestStep{
   976  			TestStep{
   977  				Config: testConfigStr,
   978  				Check: func(s *terraform.State) error {
   979  					rs := s.RootModule().Resources["test_instance.foo"]
   980  					if rs.Primary.ID != "not_tainted" {
   981  						return fmt.Errorf("expected not_tainted, got %s", rs.Primary.ID)
   982  					}
   983  					return nil
   984  				},
   985  			},
   986  			TestStep{
   987  				Taint:  []string{"test_instance.foo"},
   988  				Config: testConfigStr,
   989  				Check: func(s *terraform.State) error {
   990  					rs := s.RootModule().Resources["test_instance.foo"]
   991  					if rs.Primary.ID != "tainted" {
   992  						return fmt.Errorf("expected tainted, got %s", rs.Primary.ID)
   993  					}
   994  					return nil
   995  				},
   996  			},
   997  			TestStep{
   998  				Taint:       []string{"test_instance.fooo"},
   999  				Config:      testConfigStr,
  1000  				ExpectError: regexp.MustCompile("resource \"test_instance.fooo\" not found in state"),
  1001  			},
  1002  		},
  1003  	})
  1004  
  1005  	if mt.failed() {
  1006  		t.Fatalf("test failure: %s", mt.failMessage())
  1007  	}
  1008  }
  1009  
  1010  const testConfigStr = `
  1011  resource "test_instance" "foo" {}
  1012  `
  1013  
  1014  const testConfigStrProvider = `
  1015  provider "test" {}
  1016  `