github.com/opentofu/opentofu@v1.7.1/internal/builtin/provisioners/remote-exec/resource_provisioner_test.go (about)

     1  // Copyright (c) The OpenTofu Authors
     2  // SPDX-License-Identifier: MPL-2.0
     3  // Copyright (c) 2023 HashiCorp, Inc.
     4  // SPDX-License-Identifier: MPL-2.0
     5  
     6  package remoteexec
     7  
     8  import (
     9  	"bytes"
    10  	"context"
    11  	"fmt"
    12  	"io"
    13  	"log"
    14  	"testing"
    15  	"time"
    16  
    17  	"strings"
    18  
    19  	"github.com/mitchellh/cli"
    20  	"github.com/opentofu/opentofu/internal/communicator"
    21  	"github.com/opentofu/opentofu/internal/communicator/remote"
    22  	"github.com/opentofu/opentofu/internal/provisioners"
    23  	"github.com/zclconf/go-cty/cty"
    24  )
    25  
    26  func TestResourceProvider_Validate_good(t *testing.T) {
    27  	c := cty.ObjectVal(map[string]cty.Value{
    28  		"inline": cty.ListVal([]cty.Value{cty.StringVal("echo foo")}),
    29  	})
    30  
    31  	resp := New().ValidateProvisionerConfig(provisioners.ValidateProvisionerConfigRequest{
    32  		Config: c,
    33  	})
    34  	if len(resp.Diagnostics) > 0 {
    35  		t.Fatal(resp.Diagnostics.ErrWithWarnings())
    36  	}
    37  }
    38  
    39  func TestResourceProvider_Validate_bad(t *testing.T) {
    40  	c := cty.ObjectVal(map[string]cty.Value{
    41  		"invalid": cty.StringVal("nope"),
    42  	})
    43  
    44  	resp := New().ValidateProvisionerConfig(provisioners.ValidateProvisionerConfigRequest{
    45  		Config: c,
    46  	})
    47  	if !resp.Diagnostics.HasErrors() {
    48  		t.Fatalf("Should have errors")
    49  	}
    50  }
    51  
    52  var expectedScriptOut = `cd /tmp
    53  wget http://foobar
    54  exit 0
    55  `
    56  
    57  func TestResourceProvider_generateScript(t *testing.T) {
    58  	inline := cty.ListVal([]cty.Value{
    59  		cty.StringVal("cd /tmp"),
    60  		cty.StringVal("wget http://foobar"),
    61  		cty.StringVal("exit 0"),
    62  	})
    63  
    64  	out, err := generateScripts(inline)
    65  	if err != nil {
    66  		t.Fatalf("err: %v", err)
    67  	}
    68  
    69  	if len(out) != 1 {
    70  		t.Fatal("expected 1 out")
    71  	}
    72  
    73  	if out[0] != expectedScriptOut {
    74  		t.Fatalf("bad: %v", out)
    75  	}
    76  }
    77  
    78  func TestResourceProvider_generateScriptEmptyInline(t *testing.T) {
    79  	inline := cty.ListVal([]cty.Value{cty.StringVal("")})
    80  
    81  	_, err := generateScripts(inline)
    82  	if err == nil {
    83  		t.Fatal("expected error, got none")
    84  	}
    85  
    86  	if !strings.Contains(err.Error(), "empty string") {
    87  		t.Fatalf("expected empty string error, got: %s", err)
    88  	}
    89  }
    90  
    91  func TestResourceProvider_CollectScripts_inline(t *testing.T) {
    92  	conf := map[string]cty.Value{
    93  		"inline": cty.ListVal([]cty.Value{
    94  			cty.StringVal("cd /tmp"),
    95  			cty.StringVal("wget http://foobar"),
    96  			cty.StringVal("exit 0"),
    97  		}),
    98  	}
    99  
   100  	scripts, err := collectScripts(cty.ObjectVal(conf))
   101  	if err != nil {
   102  		t.Fatalf("err: %v", err)
   103  	}
   104  
   105  	if len(scripts) != 1 {
   106  		t.Fatalf("bad: %v", scripts)
   107  	}
   108  
   109  	var out bytes.Buffer
   110  	_, err = io.Copy(&out, scripts[0])
   111  	if err != nil {
   112  		t.Fatalf("err: %v", err)
   113  	}
   114  
   115  	if out.String() != expectedScriptOut {
   116  		t.Fatalf("bad: %v", out.String())
   117  	}
   118  }
   119  
   120  func TestResourceProvider_CollectScripts_script(t *testing.T) {
   121  	p := New()
   122  	schema := p.GetSchema().Provisioner
   123  
   124  	conf, err := schema.CoerceValue(cty.ObjectVal(map[string]cty.Value{
   125  		"scripts": cty.ListVal([]cty.Value{
   126  			cty.StringVal("testdata/script1.sh"),
   127  		}),
   128  	}))
   129  	if err != nil {
   130  		t.Fatal(err)
   131  	}
   132  
   133  	scripts, err := collectScripts(conf)
   134  	if err != nil {
   135  		t.Fatalf("err: %v", err)
   136  	}
   137  
   138  	if len(scripts) != 1 {
   139  		t.Fatalf("bad: %v", scripts)
   140  	}
   141  
   142  	var out bytes.Buffer
   143  	_, err = io.Copy(&out, scripts[0])
   144  	if err != nil {
   145  		t.Fatalf("err: %v", err)
   146  	}
   147  
   148  	if out.String() != expectedScriptOut {
   149  		t.Fatalf("bad: %v", out.String())
   150  	}
   151  }
   152  
   153  func TestResourceProvider_CollectScripts_scripts(t *testing.T) {
   154  	p := New()
   155  	schema := p.GetSchema().Provisioner
   156  
   157  	conf, err := schema.CoerceValue(cty.ObjectVal(map[string]cty.Value{
   158  		"scripts": cty.ListVal([]cty.Value{
   159  			cty.StringVal("testdata/script1.sh"),
   160  			cty.StringVal("testdata/script1.sh"),
   161  			cty.StringVal("testdata/script1.sh"),
   162  		}),
   163  	}))
   164  	if err != nil {
   165  		log.Fatal(err)
   166  	}
   167  
   168  	scripts, err := collectScripts(conf)
   169  	if err != nil {
   170  		t.Fatalf("err: %v", err)
   171  	}
   172  
   173  	if len(scripts) != 3 {
   174  		t.Fatalf("bad: %v", scripts)
   175  	}
   176  
   177  	for idx := range scripts {
   178  		var out bytes.Buffer
   179  		_, err = io.Copy(&out, scripts[idx])
   180  		if err != nil {
   181  			t.Fatalf("err: %v", err)
   182  		}
   183  
   184  		if out.String() != expectedScriptOut {
   185  			t.Fatalf("bad: %v", out.String())
   186  		}
   187  	}
   188  }
   189  
   190  func TestResourceProvider_CollectScripts_scriptsEmpty(t *testing.T) {
   191  	p := New()
   192  	schema := p.GetSchema().Provisioner
   193  
   194  	conf, err := schema.CoerceValue(cty.ObjectVal(map[string]cty.Value{
   195  		"scripts": cty.ListVal([]cty.Value{cty.StringVal("")}),
   196  	}))
   197  	if err != nil {
   198  		t.Fatal(err)
   199  	}
   200  
   201  	_, err = collectScripts(conf)
   202  	if err == nil {
   203  		t.Fatal("expected error")
   204  	}
   205  
   206  	if !strings.Contains(err.Error(), "empty string") {
   207  		t.Fatalf("Expected empty string error, got: %s", err)
   208  	}
   209  }
   210  
   211  func TestProvisionerTimeout(t *testing.T) {
   212  	o := cli.NewMockUi()
   213  	c := new(communicator.MockCommunicator)
   214  
   215  	disconnected := make(chan struct{})
   216  	c.DisconnectFunc = func() error {
   217  		close(disconnected)
   218  		return nil
   219  	}
   220  
   221  	completed := make(chan struct{})
   222  	c.CommandFunc = func(cmd *remote.Cmd) error {
   223  		defer close(completed)
   224  		cmd.Init()
   225  		time.Sleep(2 * time.Second)
   226  		cmd.SetExitStatus(0, nil)
   227  		return nil
   228  	}
   229  	c.ConnTimeout = time.Second
   230  	c.UploadScripts = map[string]string{"hello": "echo hello"}
   231  	c.RemoteScriptPath = "hello"
   232  
   233  	conf := map[string]cty.Value{
   234  		"inline": cty.ListVal([]cty.Value{cty.StringVal("echo hello")}),
   235  	}
   236  
   237  	scripts, err := collectScripts(cty.ObjectVal(conf))
   238  	if err != nil {
   239  		t.Fatal(err)
   240  	}
   241  
   242  	ctx := context.Background()
   243  
   244  	done := make(chan struct{})
   245  
   246  	var runErr error
   247  	go func() {
   248  		defer close(done)
   249  		runErr = runScripts(ctx, o, c, scripts)
   250  	}()
   251  
   252  	select {
   253  	case <-disconnected:
   254  		t.Fatal("communicator disconnected before command completed")
   255  	case <-completed:
   256  	}
   257  
   258  	<-done
   259  	if runErr != nil {
   260  		t.Fatal(err)
   261  	}
   262  }
   263  
   264  // Validate that Stop can Close can be called even when not provisioning.
   265  func TestResourceProvisioner_StopClose(t *testing.T) {
   266  	p := New()
   267  	p.Stop()
   268  	p.Close()
   269  }
   270  
   271  func TestResourceProvisioner_connectionRequired(t *testing.T) {
   272  	p := New()
   273  	resp := p.ProvisionResource(provisioners.ProvisionResourceRequest{})
   274  	if !resp.Diagnostics.HasErrors() {
   275  		t.Fatal("expected error")
   276  	}
   277  
   278  	got := resp.Diagnostics.Err().Error()
   279  	if !strings.Contains(got, "Missing connection") {
   280  		t.Fatalf("expected 'Missing connection' error: got %q", got)
   281  	}
   282  }
   283  
   284  func TestResourceProvisioner_nullsInOptionals(t *testing.T) {
   285  	output := cli.NewMockUi()
   286  	p := New()
   287  	schema := p.GetSchema().Provisioner
   288  
   289  	for i, cfg := range []cty.Value{
   290  		cty.ObjectVal(map[string]cty.Value{
   291  			"script": cty.StringVal("echo"),
   292  			"inline": cty.NullVal(cty.List(cty.String)),
   293  		}),
   294  		cty.ObjectVal(map[string]cty.Value{
   295  			"inline": cty.ListVal([]cty.Value{
   296  				cty.NullVal(cty.String),
   297  			}),
   298  		}),
   299  		cty.ObjectVal(map[string]cty.Value{
   300  			"script": cty.NullVal(cty.String),
   301  		}),
   302  		cty.ObjectVal(map[string]cty.Value{
   303  			"scripts": cty.NullVal(cty.List(cty.String)),
   304  		}),
   305  		cty.ObjectVal(map[string]cty.Value{
   306  			"scripts": cty.ListVal([]cty.Value{
   307  				cty.NullVal(cty.String),
   308  			}),
   309  		}),
   310  	} {
   311  		t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
   312  
   313  			cfg, err := schema.CoerceValue(cfg)
   314  			if err != nil {
   315  				t.Fatal(err)
   316  			}
   317  
   318  			// verifying there are no panics
   319  			p.ProvisionResource(provisioners.ProvisionResourceRequest{
   320  				Config:   cfg,
   321  				UIOutput: output,
   322  			})
   323  		})
   324  	}
   325  }