github.com/opentofu/opentofu@v1.7.1/internal/communicator/communicator.go (about)

     1  // Copyright (c) The OpenTofu Authors
     2  // SPDX-License-Identifier: MPL-2.0
     3  // Copyright (c) 2023 HashiCorp, Inc.
     4  // SPDX-License-Identifier: MPL-2.0
     5  
     6  package communicator
     7  
     8  import (
     9  	"context"
    10  	"fmt"
    11  	"io"
    12  	"log"
    13  	"sync"
    14  	"sync/atomic"
    15  	"time"
    16  
    17  	"github.com/opentofu/opentofu/internal/communicator/remote"
    18  	"github.com/opentofu/opentofu/internal/communicator/shared"
    19  	"github.com/opentofu/opentofu/internal/communicator/ssh"
    20  	"github.com/opentofu/opentofu/internal/communicator/winrm"
    21  	"github.com/opentofu/opentofu/internal/provisioners"
    22  	"github.com/zclconf/go-cty/cty"
    23  )
    24  
    25  // Communicator is an interface that must be implemented by all communicators
    26  // used for any of the provisioners
    27  type Communicator interface {
    28  	// Connect is used to set up the connection
    29  	Connect(provisioners.UIOutput) error
    30  
    31  	// Disconnect is used to terminate the connection
    32  	Disconnect() error
    33  
    34  	// Timeout returns the configured connection timeout
    35  	Timeout() time.Duration
    36  
    37  	// ScriptPath returns the configured script path
    38  	ScriptPath() string
    39  
    40  	// Start executes a remote command in a new session
    41  	Start(*remote.Cmd) error
    42  
    43  	// Upload is used to upload a single file
    44  	Upload(string, io.Reader) error
    45  
    46  	// UploadScript is used to upload a file as an executable script
    47  	UploadScript(string, io.Reader) error
    48  
    49  	// UploadDir is used to upload a directory
    50  	UploadDir(string, string) error
    51  }
    52  
    53  // New returns a configured Communicator or an error if the connection type is not supported
    54  func New(v cty.Value) (Communicator, error) {
    55  	v, err := shared.ConnectionBlockSupersetSchema.CoerceValue(v)
    56  	if err != nil {
    57  		return nil, err
    58  	}
    59  
    60  	typeVal := v.GetAttr("type")
    61  	connType := ""
    62  	if !typeVal.IsNull() {
    63  		connType = typeVal.AsString()
    64  	}
    65  
    66  	switch connType {
    67  	case "ssh", "": // The default connection type is ssh, so if connType is empty use ssh
    68  		return ssh.New(v)
    69  	case "winrm":
    70  		return winrm.New(v)
    71  	default:
    72  		return nil, fmt.Errorf("connection type '%s' not supported", connType)
    73  	}
    74  }
    75  
    76  // maxBackoffDelay is the maximum delay between retry attempts
    77  var maxBackoffDelay = 20 * time.Second
    78  var initialBackoffDelay = time.Second
    79  
    80  // in practice we want to abort the retry asap, but for tests we need to
    81  // synchronize the return.
    82  var retryTestWg *sync.WaitGroup
    83  
    84  // Fatal is an interface that error values can return to halt Retry
    85  type Fatal interface {
    86  	FatalError() error
    87  }
    88  
    89  // Retry retries the function f until it returns a nil error, a Fatal error, or
    90  // the context expires.
    91  func Retry(ctx context.Context, f func() error) error {
    92  	// container for atomic error value
    93  	type errWrap struct {
    94  		E error
    95  	}
    96  
    97  	// Try the function in a goroutine
    98  	var errVal atomic.Value
    99  	doneCh := make(chan struct{})
   100  	go func() {
   101  		if retryTestWg != nil {
   102  			defer retryTestWg.Done()
   103  		}
   104  
   105  		defer close(doneCh)
   106  
   107  		delay := time.Duration(0)
   108  		for {
   109  			// If our context ended, we want to exit right away.
   110  			select {
   111  			case <-ctx.Done():
   112  				return
   113  			case <-time.After(delay):
   114  			}
   115  
   116  			// Try the function call
   117  			err := f()
   118  
   119  			// return if we have no error, or a FatalError
   120  			done := false
   121  			switch e := err.(type) {
   122  			case nil:
   123  				done = true
   124  			case Fatal:
   125  				err = e.FatalError()
   126  				done = true
   127  			}
   128  
   129  			errVal.Store(errWrap{err})
   130  
   131  			if done {
   132  				return
   133  			}
   134  
   135  			log.Printf("[WARN] retryable error: %v", err)
   136  
   137  			delay *= 2
   138  
   139  			if delay == 0 {
   140  				delay = initialBackoffDelay
   141  			}
   142  
   143  			if delay > maxBackoffDelay {
   144  				delay = maxBackoffDelay
   145  			}
   146  
   147  			log.Printf("[INFO] sleeping for %s", delay)
   148  		}
   149  	}()
   150  
   151  	// Wait for completion
   152  	select {
   153  	case <-ctx.Done():
   154  	case <-doneCh:
   155  	}
   156  
   157  	var lastErr error
   158  	// Check if we got an error executing
   159  	if ev, ok := errVal.Load().(errWrap); ok {
   160  		lastErr = ev.E
   161  	}
   162  
   163  	// Check if we have a context error to check if we're interrupted or timeout
   164  	switch ctx.Err() {
   165  	case context.Canceled:
   166  		return fmt.Errorf("interrupted - last error: %w", lastErr)
   167  	case context.DeadlineExceeded:
   168  		return fmt.Errorf("timeout - last error: %w", lastErr)
   169  	}
   170  
   171  	if lastErr != nil {
   172  		return lastErr
   173  	}
   174  	return nil
   175  }