github.com/turbot/go-exec-communicator@v0.0.0-20230412124734-9374347749b6/communicator.go (about)

     1  package communicator
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io"
     7  	"log"
     8  	"sync"
     9  	"sync/atomic"
    10  	"time"
    11  
    12  	"github.com/turbot/go-exec-communicator/remote"
    13  	"github.com/turbot/go-exec-communicator/shared"
    14  	"github.com/turbot/go-exec-communicator/ssh"
    15  	"github.com/turbot/go-exec-communicator/winrm"
    16  )
    17  
    18  // Communicator is an interface that must be implemented by all communicators
    19  // used for any of the provisioners
    20  type Communicator interface {
    21  	// Connect is used to set up the connection
    22  	Connect(*shared.Outputter) error
    23  
    24  	// Disconnect is used to terminate the connection
    25  	Disconnect() error
    26  
    27  	// Timeout returns the configured connection timeout
    28  	Timeout() time.Duration
    29  
    30  	// ScriptPath returns the configured script path
    31  	ScriptPath() string
    32  
    33  	// Start executes a remote command in a new session
    34  	Start(*remote.Cmd) error
    35  
    36  	// Upload is used to upload a single file
    37  	Upload(string, io.Reader) error
    38  
    39  	// UploadScript is used to upload a file as an executable script
    40  	UploadScript(string, io.Reader) error
    41  
    42  	// UploadDir is used to upload a directory
    43  	UploadDir(string, string) error
    44  }
    45  
    46  // New returns a configured Communicator or an error if the connection type is not supported
    47  func New(ci shared.ConnectionInfo) (Communicator, error) {
    48  
    49  	/*
    50  		v, err := shared.ConnectionBlockSupersetSchema.CoerceValue(v)
    51  		if err != nil {
    52  			return nil, err
    53  		}
    54  
    55  		typeVal := ci.Typev.GetAttr("type")
    56  		connType := ""
    57  		if !typeVal.IsNull() {
    58  			connType = typeVal.AsString()
    59  		}
    60  	*/
    61  
    62  	switch ci.Type {
    63  	case "ssh", "": // The default connection type is ssh, so if connType is empty use ssh
    64  		return ssh.New(ci)
    65  	case "winrm":
    66  		return winrm.New(ci)
    67  	default:
    68  		return nil, fmt.Errorf("connection type '%s' not supported", ci.Type)
    69  	}
    70  }
    71  
    72  // maxBackoffDelay is the maximum delay between retry attempts
    73  var maxBackoffDelay = 20 * time.Second
    74  var initialBackoffDelay = time.Second
    75  
    76  // in practice we want to abort the retry asap, but for tests we need to
    77  // synchronize the return.
    78  var retryTestWg *sync.WaitGroup
    79  
    80  // Fatal is an interface that error values can return to halt Retry
    81  type Fatal interface {
    82  	FatalError() error
    83  }
    84  
    85  // Retry retries the function f until it returns a nil error, a Fatal error, or
    86  // the context expires.
    87  func Retry(ctx context.Context, f func() error) error {
    88  	// container for atomic error value
    89  	type errWrap struct {
    90  		E error
    91  	}
    92  
    93  	// Try the function in a goroutine
    94  	var errVal atomic.Value
    95  	doneCh := make(chan struct{})
    96  	go func() {
    97  		if retryTestWg != nil {
    98  			defer retryTestWg.Done()
    99  		}
   100  
   101  		defer close(doneCh)
   102  
   103  		delay := time.Duration(0)
   104  		for {
   105  			// If our context ended, we want to exit right away.
   106  			select {
   107  			case <-ctx.Done():
   108  				return
   109  			case <-time.After(delay):
   110  			}
   111  
   112  			// Try the function call
   113  			err := f()
   114  
   115  			// return if we have no error, or a FatalError
   116  			done := false
   117  			switch e := err.(type) {
   118  			case nil:
   119  				done = true
   120  			case Fatal:
   121  				err = e.FatalError()
   122  				done = true
   123  			}
   124  
   125  			errVal.Store(errWrap{err})
   126  
   127  			if done {
   128  				return
   129  			}
   130  
   131  			log.Printf("[WARN] retryable error: %v", err)
   132  
   133  			delay *= 2
   134  
   135  			if delay == 0 {
   136  				delay = initialBackoffDelay
   137  			}
   138  
   139  			if delay > maxBackoffDelay {
   140  				delay = maxBackoffDelay
   141  			}
   142  
   143  			log.Printf("[INFO] sleeping for %s", delay)
   144  		}
   145  	}()
   146  
   147  	// Wait for completion
   148  	select {
   149  	case <-ctx.Done():
   150  	case <-doneCh:
   151  	}
   152  
   153  	var lastErr error
   154  	// Check if we got an error executing
   155  	if ev, ok := errVal.Load().(errWrap); ok {
   156  		lastErr = ev.E
   157  	}
   158  
   159  	// Check if we have a context error to check if we're interrupted or timeout
   160  	switch ctx.Err() {
   161  	case context.Canceled:
   162  		return fmt.Errorf("interrupted - last error: %v", lastErr)
   163  	case context.DeadlineExceeded:
   164  		return fmt.Errorf("timeout - last error: %v", lastErr)
   165  	}
   166  
   167  	if lastErr != nil {
   168  		return lastErr
   169  	}
   170  	return nil
   171  }