github.com/erikwilson/go-powershell@v0.0.0-20200701182037-6845e6fcfa79/shell.go (about)

     1  // Copyright (c) 2017 Gorillalabs. All rights reserved.
     2  
     3  package powershell
     4  
     5  import (
     6  	"fmt"
     7  	"io"
     8  	"regexp"
     9  	"sync"
    10  
    11  	"github.com/rancher/go-powershell/backend"
    12  	"github.com/rancher/go-powershell/utils"
    13  	"github.com/pkg/errors"
    14  )
    15  
    16  const newline = "\r\n"
    17  
    18  type Shell interface {
    19  	Execute(cmd string) (string, string, error)
    20  	Exit()
    21  }
    22  
    23  type shell struct {
    24  	handle backend.Waiter
    25  	stdin  io.Writer
    26  	stdout io.Reader
    27  	stderr io.Reader
    28  }
    29  
    30  func New(backend backend.Starter) (Shell, error) {
    31  	handle, stdin, stdout, stderr, err := backend.StartProcess("powershell.exe", "-NoExit", "-Command", "-")
    32  	if err != nil {
    33  		return nil, err
    34  	}
    35  
    36  	return &shell{handle, stdin, stdout, stderr}, nil
    37  }
    38  
    39  func (s *shell) Execute(cmd string) (string, string, error) {
    40  	if s.handle == nil {
    41  		return "", "", errors.Wrap(errors.New(cmd), "Cannot execute commands on closed shells.")
    42  	}
    43  
    44  	outBoundary := createBoundary()
    45  	errBoundary := createBoundary()
    46  
    47  	// wrap the command in special markers so we know when to stop reading from the pipes
    48  	full := fmt.Sprintf("%s; echo '%s'; [Console]::Error.WriteLine('%s')%s", cmd, outBoundary, errBoundary, newline)
    49  
    50  	_, err := s.stdin.Write([]byte(full))
    51  	if err != nil {
    52  		return "", "", errors.Wrap(errors.Wrap(err, cmd), "Could not send PowerShell command")
    53  	}
    54  
    55  	// read stdout and stderr
    56  	sout := ""
    57  	serr := ""
    58  
    59  	waiter := &sync.WaitGroup{}
    60  	waiter.Add(2)
    61  
    62  	go streamReader(s.stdout, outBoundary, &sout, waiter)
    63  	go streamReader(s.stderr, errBoundary, &serr, waiter)
    64  
    65  	waiter.Wait()
    66  
    67  	if len(serr) > 0 {
    68  		return sout, serr, errors.Wrap(errors.New(cmd), serr)
    69  	}
    70  
    71  	return sout, serr, nil
    72  }
    73  
    74  func (s *shell) Exit() {
    75  	s.stdin.Write([]byte("exit" + newline))
    76  
    77  	// if it's possible to close stdin, do so (some backends, like the local one,
    78  	// do support it)
    79  	closer, ok := s.stdin.(io.Closer)
    80  	if ok {
    81  		closer.Close()
    82  	}
    83  
    84  	s.handle.Wait()
    85  
    86  	s.handle = nil
    87  	s.stdin = nil
    88  	s.stdout = nil
    89  	s.stderr = nil
    90  }
    91  
    92  func streamReader(stream io.Reader, boundary string, buffer *string, signal *sync.WaitGroup) error {
    93  	// read all output until we have found our boundary token
    94  	output := ""
    95  	bufsize := 64
    96  	marker := regexp.MustCompile("(?s)(.*)" + regexp.QuoteMeta(boundary))
    97  
    98  	for {
    99  		buf := make([]byte, bufsize)
   100  		read, err := stream.Read(buf)
   101  		if err != nil {
   102  			return err
   103  		}
   104  
   105  		output = output + string(buf[:read])
   106  
   107  		if marker.MatchString(output) {
   108  			break
   109  		}
   110  	}
   111  
   112  	*buffer = marker.FindStringSubmatch(output)[1]
   113  	signal.Done()
   114  
   115  	return nil
   116  }
   117  
   118  func createBoundary() string {
   119  	return "$gorilla" + utils.CreateRandomString(12) + "$"
   120  }