github.com/myhau/pulumi/pkg/v3@v3.70.2-0.20221116134521-f2775972e587/testing/integration/util.go (about)

     1  // Copyright 2016-2018, Pulumi Corporation.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package integration
    16  
    17  import (
    18  	"fmt"
    19  	"io"
    20  	"io/ioutil"
    21  	"net/http"
    22  	"os"
    23  	"os/exec"
    24  	"path"
    25  	"path/filepath"
    26  	"strings"
    27  	"testing"
    28  	"time"
    29  
    30  	"github.com/stretchr/testify/assert"
    31  
    32  	"github.com/pulumi/pulumi/sdk/v3/go/common/resource"
    33  	"github.com/pulumi/pulumi/sdk/v3/go/common/util/contract"
    34  )
    35  
    36  // DecodeMapString takes a string of the form key1=value1:key2=value2 and returns a go map.
    37  func DecodeMapString(val string) (map[string]string, error) {
    38  	newMap := make(map[string]string)
    39  
    40  	if val != "" {
    41  		for _, overrideClause := range strings.Split(val, ":") {
    42  			data := strings.Split(overrideClause, "=")
    43  			if len(data) != 2 {
    44  				return nil, fmt.Errorf(
    45  					"could not decode %s as an override, should be of the form <package>=<version>",
    46  					overrideClause)
    47  			}
    48  			packageName := data[0]
    49  			packageVersion := data[1]
    50  			newMap[packageName] = packageVersion
    51  		}
    52  	}
    53  
    54  	return newMap, nil
    55  }
    56  
    57  // ReplaceInFile does a find and replace for a given string within a file.
    58  func ReplaceInFile(old, new, path string) error {
    59  	rawContents, err := ioutil.ReadFile(path)
    60  	if err != nil {
    61  		return err
    62  	}
    63  	newContents := strings.Replace(string(rawContents), old, new, -1)
    64  	return ioutil.WriteFile(path, []byte(newContents), os.ModePerm)
    65  }
    66  
    67  // getCmdBin returns the binary named bin in location loc or, if it hasn't yet been initialized, will lazily
    68  // populate it by either using the default def or, if empty, looking on the current $PATH.
    69  func getCmdBin(loc *string, bin, def string) (string, error) {
    70  	if *loc == "" {
    71  		*loc = def
    72  		if *loc == "" {
    73  			var err error
    74  			*loc, err = exec.LookPath(bin)
    75  			if err != nil {
    76  				return "", fmt.Errorf("Expected to find `%s` binary on $PATH: %w", bin, err)
    77  			}
    78  		}
    79  	}
    80  	return *loc, nil
    81  }
    82  
    83  func uniqueSuffix() string {
    84  	// .<timestamp>.<five random hex characters>
    85  	timestamp := time.Now().Format("20060102-150405")
    86  	suffix, err := resource.NewUniqueHex("."+timestamp+".", 5, -1)
    87  	contract.AssertNoError(err)
    88  	return suffix
    89  }
    90  
    91  const (
    92  	commandOutputFolderName = "command-output"
    93  )
    94  
    95  func writeCommandOutput(commandName, runDir string, output []byte) (string, error) {
    96  	logFileDir := filepath.Join(runDir, commandOutputFolderName)
    97  	if err := os.MkdirAll(logFileDir, 0700); err != nil {
    98  		return "", fmt.Errorf("Failed to create '%s': %w", logFileDir, err)
    99  	}
   100  
   101  	logFile := filepath.Join(logFileDir, commandName+uniqueSuffix()+".log")
   102  
   103  	if err := ioutil.WriteFile(logFile, output, 0600); err != nil {
   104  		return "", fmt.Errorf("Failed to write '%s': %w", logFile, err)
   105  	}
   106  
   107  	return logFile, nil
   108  }
   109  
   110  // CopyFile copies a single file from src to dst
   111  // From https://blog.depado.eu/post/copy-files-and-directories-in-go
   112  func CopyFile(src, dst string) error {
   113  	var err error
   114  	var srcfd *os.File
   115  	var dstfd *os.File
   116  	var srcinfo os.FileInfo
   117  	var n int64
   118  
   119  	if srcfd, err = os.Open(src); err != nil {
   120  		return err
   121  	}
   122  	defer srcfd.Close()
   123  
   124  	if dstfd, err = os.Create(dst); err != nil {
   125  		return err
   126  	}
   127  	defer dstfd.Close()
   128  
   129  	if n, err = io.Copy(dstfd, srcfd); err != nil {
   130  		return err
   131  	}
   132  	if srcinfo, err = os.Stat(src); err != nil {
   133  		return err
   134  	}
   135  	if n != srcinfo.Size() {
   136  		return fmt.Errorf("failed to copy all bytes from %v to %v", src, dst)
   137  	}
   138  	return os.Chmod(dst, srcinfo.Mode())
   139  }
   140  
   141  // CopyDir copies a whole directory recursively
   142  // From https://blog.depado.eu/post/copy-files-and-directories-in-go
   143  func CopyDir(src, dst string) error {
   144  	var err error
   145  	var fds []os.DirEntry
   146  	var srcinfo os.FileInfo
   147  
   148  	if srcinfo, err = os.Stat(src); err != nil {
   149  		return err
   150  	}
   151  
   152  	if err = os.MkdirAll(dst, srcinfo.Mode()); err != nil {
   153  		return err
   154  	}
   155  
   156  	if fds, err = os.ReadDir(src); err != nil {
   157  		return err
   158  	}
   159  	for _, fd := range fds {
   160  		srcfp := path.Join(src, fd.Name())
   161  		dstfp := path.Join(dst, fd.Name())
   162  
   163  		if fd.IsDir() {
   164  			if err = CopyDir(srcfp, dstfp); err != nil {
   165  				fmt.Println(err)
   166  			}
   167  		} else {
   168  			if err = CopyFile(srcfp, dstfp); err != nil {
   169  				fmt.Println(err)
   170  			}
   171  		}
   172  	}
   173  	return nil
   174  }
   175  
   176  // AssertHTTPResultWithRetry attempts to assert that an HTTP endpoint exists
   177  // and evaluate its response.
   178  func AssertHTTPResultWithRetry(
   179  	t *testing.T,
   180  	output interface{},
   181  	headers map[string]string,
   182  	maxWait time.Duration,
   183  	check func(string) bool,
   184  ) bool {
   185  	hostname, ok := output.(string)
   186  	if !assert.True(t, ok, fmt.Sprintf("expected `%s` output", output)) {
   187  		return false
   188  	}
   189  	if !(strings.HasPrefix(hostname, "http://") || strings.HasPrefix(hostname, "https://")) {
   190  		hostname = fmt.Sprintf("http://%s", hostname)
   191  	}
   192  	var err error
   193  	var resp *http.Response
   194  	startTime := time.Now()
   195  	count, sleep := 0, 0
   196  	for true {
   197  		now := time.Now()
   198  		req, err := http.NewRequest("GET", hostname, nil)
   199  		if !assert.NoError(t, err, "error reading request: %v", err) {
   200  			return false
   201  		}
   202  
   203  		for k, v := range headers {
   204  			// Host header cannot be set via req.Header.Set(), and must be set
   205  			// directly.
   206  			if strings.ToLower(k) == "host" {
   207  				req.Host = v
   208  				continue
   209  			}
   210  			req.Header.Set(k, v)
   211  		}
   212  
   213  		client := &http.Client{Timeout: time.Second * 10}
   214  		resp, err = client.Do(req)
   215  
   216  		if err == nil && resp.StatusCode == 200 {
   217  			break
   218  		}
   219  		if now.Sub(startTime) >= maxWait {
   220  			t.Logf("Timeout after %v. Unable to http.get %v successfully.", maxWait, hostname)
   221  			break
   222  		}
   223  		count++
   224  		// delay 10s, 20s, then 30s and stay at 30s
   225  		if sleep > 30 {
   226  			sleep = 30
   227  		} else {
   228  			sleep += 10
   229  		}
   230  		time.Sleep(time.Duration(sleep) * time.Second)
   231  		t.Logf("Http Error: %v\n", err)
   232  		t.Logf("  Retry: %v, elapsed wait: %v, max wait %v\n", count, now.Sub(startTime), maxWait)
   233  	}
   234  	if !assert.NoError(t, err) {
   235  		return false
   236  	}
   237  	// Read the body
   238  	defer resp.Body.Close()
   239  	body, err := ioutil.ReadAll(resp.Body)
   240  	if !assert.NoError(t, err) {
   241  		return false
   242  	}
   243  	// Verify it matches expectations
   244  	return check(string(body))
   245  }