github.com/HashDataInc/packer@v1.3.2/communicator/winrm/communicator.go (about)

     1  package winrm
     2  
     3  import (
     4  	"encoding/base64"
     5  	"fmt"
     6  	"io"
     7  	"io/ioutil"
     8  	"log"
     9  	"os"
    10  	"path/filepath"
    11  	"strings"
    12  	"sync"
    13  
    14  	"github.com/hashicorp/packer/packer"
    15  	"github.com/masterzen/winrm"
    16  	"github.com/packer-community/winrmcp/winrmcp"
    17  )
    18  
    19  // Communicator represents the WinRM communicator
    20  type Communicator struct {
    21  	config   *Config
    22  	client   *winrm.Client
    23  	endpoint *winrm.Endpoint
    24  }
    25  
    26  // New creates a new communicator implementation over WinRM.
    27  func New(config *Config) (*Communicator, error) {
    28  	endpoint := &winrm.Endpoint{
    29  		Host:     config.Host,
    30  		Port:     config.Port,
    31  		HTTPS:    config.Https,
    32  		Insecure: config.Insecure,
    33  
    34  		/*
    35  			TODO
    36  			HTTPS:    connInfo.HTTPS,
    37  			Insecure: connInfo.Insecure,
    38  			CACert:   connInfo.CACert,
    39  		*/
    40  	}
    41  
    42  	// Create the client
    43  	params := *winrm.DefaultParameters
    44  
    45  	if config.TransportDecorator != nil {
    46  		params.TransportDecorator = config.TransportDecorator
    47  	}
    48  
    49  	params.Timeout = formatDuration(config.Timeout)
    50  	client, err := winrm.NewClientWithParameters(
    51  		endpoint, config.Username, config.Password, &params)
    52  	if err != nil {
    53  		return nil, err
    54  	}
    55  
    56  	// Create the shell to verify the connection
    57  	log.Printf("[DEBUG] connecting to remote shell using WinRM")
    58  	shell, err := client.CreateShell()
    59  	if err != nil {
    60  		log.Printf("[ERROR] connection error: %s", err)
    61  		return nil, err
    62  	}
    63  
    64  	if err := shell.Close(); err != nil {
    65  		log.Printf("[ERROR] error closing connection: %s", err)
    66  		return nil, err
    67  	}
    68  
    69  	return &Communicator{
    70  		config:   config,
    71  		client:   client,
    72  		endpoint: endpoint,
    73  	}, nil
    74  }
    75  
    76  // Start implementation of communicator.Communicator interface
    77  func (c *Communicator) Start(rc *packer.RemoteCmd) error {
    78  	shell, err := c.client.CreateShell()
    79  	if err != nil {
    80  		return err
    81  	}
    82  
    83  	log.Printf("[INFO] starting remote command: %s", rc.Command)
    84  	cmd, err := shell.Execute(rc.Command)
    85  	if err != nil {
    86  		return err
    87  	}
    88  
    89  	go runCommand(shell, cmd, rc)
    90  	return nil
    91  }
    92  
    93  func runCommand(shell *winrm.Shell, cmd *winrm.Command, rc *packer.RemoteCmd) {
    94  	defer shell.Close()
    95  	var wg sync.WaitGroup
    96  
    97  	copyFunc := func(w io.Writer, r io.Reader) {
    98  		defer wg.Done()
    99  		io.Copy(w, r)
   100  	}
   101  
   102  	if rc.Stdout != nil && cmd.Stdout != nil {
   103  		wg.Add(1)
   104  		go copyFunc(rc.Stdout, cmd.Stdout)
   105  	} else {
   106  		log.Printf("[WARN] Failed to read stdout for command '%s'", rc.Command)
   107  	}
   108  
   109  	if rc.Stderr != nil && cmd.Stderr != nil {
   110  		wg.Add(1)
   111  		go copyFunc(rc.Stderr, cmd.Stderr)
   112  	} else {
   113  		log.Printf("[WARN] Failed to read stderr for command '%s'", rc.Command)
   114  	}
   115  
   116  	cmd.Wait()
   117  	wg.Wait()
   118  
   119  	code := cmd.ExitCode()
   120  	log.Printf("[INFO] command '%s' exited with code: %d", rc.Command, code)
   121  	rc.SetExited(code)
   122  }
   123  
   124  // Upload implementation of communicator.Communicator interface
   125  func (c *Communicator) Upload(path string, input io.Reader, fi *os.FileInfo) error {
   126  	wcp, err := c.newCopyClient()
   127  	if err != nil {
   128  		return fmt.Errorf("Was unable to create winrm client: %s", err)
   129  	}
   130  	if strings.HasSuffix(path, `\`) {
   131  		// path is a directory
   132  		path += filepath.Base((*fi).Name())
   133  	}
   134  	log.Printf("Uploading file to '%s'", path)
   135  	return wcp.Write(path, input)
   136  }
   137  
   138  // UploadDir implementation of communicator.Communicator interface
   139  func (c *Communicator) UploadDir(dst string, src string, exclude []string) error {
   140  	if !strings.HasSuffix(src, "/") {
   141  		dst = fmt.Sprintf("%s\\%s", dst, filepath.Base(src))
   142  	}
   143  	log.Printf("Uploading dir '%s' to '%s'", src, dst)
   144  	wcp, err := c.newCopyClient()
   145  	if err != nil {
   146  		return err
   147  	}
   148  	return wcp.Copy(src, dst)
   149  }
   150  
   151  func (c *Communicator) Download(src string, dst io.Writer) error {
   152  	client, err := c.newWinRMClient()
   153  	if err != nil {
   154  		return err
   155  	}
   156  
   157  	encodeScript := `$file=[System.IO.File]::ReadAllBytes("%s"); Write-Output $([System.Convert]::ToBase64String($file))`
   158  
   159  	base64DecodePipe := &Base64Pipe{w: dst}
   160  
   161  	cmd := winrm.Powershell(fmt.Sprintf(encodeScript, src))
   162  	_, err = client.Run(cmd, base64DecodePipe, ioutil.Discard)
   163  
   164  	return err
   165  }
   166  
   167  func (c *Communicator) DownloadDir(src string, dst string, exclude []string) error {
   168  	return fmt.Errorf("WinRM doesn't support download dir.")
   169  }
   170  
   171  func (c *Communicator) getClientConfig() *winrmcp.Config {
   172  	return &winrmcp.Config{
   173  		Auth: winrmcp.Auth{
   174  			User:     c.config.Username,
   175  			Password: c.config.Password,
   176  		},
   177  		Https:                 c.config.Https,
   178  		Insecure:              c.config.Insecure,
   179  		OperationTimeout:      c.config.Timeout,
   180  		MaxOperationsPerShell: 15, // lowest common denominator
   181  		TransportDecorator:    c.config.TransportDecorator,
   182  	}
   183  }
   184  
   185  func (c *Communicator) newCopyClient() (*winrmcp.Winrmcp, error) {
   186  	addr := fmt.Sprintf("%s:%d", c.endpoint.Host, c.endpoint.Port)
   187  	clientConfig := c.getClientConfig()
   188  	return winrmcp.New(addr, clientConfig)
   189  }
   190  
   191  func (c *Communicator) newWinRMClient() (*winrm.Client, error) {
   192  	conf := c.getClientConfig()
   193  
   194  	// Shamelessly borrowed from the winrmcp client to ensure
   195  	// that the client is configured using the same defaulting behaviors that
   196  	// winrmcp uses even we we aren't using winrmcp. This ensures similar
   197  	// behavior between upload, download, and copy functions. We can't use the
   198  	// one generated by winrmcp because it isn't exported.
   199  	var endpoint *winrm.Endpoint
   200  	endpoint = &winrm.Endpoint{
   201  		Host:          c.endpoint.Host,
   202  		Port:          c.endpoint.Port,
   203  		HTTPS:         conf.Https,
   204  		Insecure:      conf.Insecure,
   205  		TLSServerName: conf.TLSServerName,
   206  		CACert:        conf.CACertBytes,
   207  		Timeout:       conf.ConnectTimeout,
   208  	}
   209  	params := winrm.NewParameters(
   210  		winrm.DefaultParameters.Timeout,
   211  		winrm.DefaultParameters.Locale,
   212  		winrm.DefaultParameters.EnvelopeSize,
   213  	)
   214  
   215  	params.TransportDecorator = conf.TransportDecorator
   216  	params.Timeout = "PT3M"
   217  
   218  	client, err := winrm.NewClientWithParameters(
   219  		endpoint, conf.Auth.User, conf.Auth.Password, params)
   220  	return client, err
   221  }
   222  
   223  type Base64Pipe struct {
   224  	w io.Writer // underlying writer (file, buffer)
   225  }
   226  
   227  func (d *Base64Pipe) ReadFrom(r io.Reader) (int64, error) {
   228  	b, err := ioutil.ReadAll(r)
   229  	if err != nil {
   230  		return 0, err
   231  	}
   232  
   233  	var i int
   234  	i, err = d.Write(b)
   235  
   236  	if err != nil {
   237  		return 0, err
   238  	}
   239  
   240  	return int64(i), err
   241  }
   242  
   243  func (d *Base64Pipe) Write(p []byte) (int, error) {
   244  	dst := make([]byte, base64.StdEncoding.DecodedLen(len(p)))
   245  
   246  	decodedBytes, err := base64.StdEncoding.Decode(dst, p)
   247  	if err != nil {
   248  		return 0, err
   249  	}
   250  
   251  	return d.w.Write(dst[0:decodedBytes])
   252  }