github.com/direktiv/go-powershell@v0.0.0-20220309043221-9f9747a0a631/shell.go (about)

     1  // Copyright (c) 2017 Gorillalabs. All rights reserved.
     2  
     3  package powershell
     4  
     5  import (
     6  	"fmt"
     7  	"io"
     8  	"regexp"
     9  	"sync"
    10  
    11  	"github.com/acarl005/stripansi"
    12  	"github.com/direktiv/go-powershell/backend"
    13  	"github.com/direktiv/go-powershell/utils"
    14  	"github.com/pkg/errors"
    15  )
    16  
    17  const newline = "\r\n"
    18  
    19  type Shell interface {
    20  	Execute(cmd string) (string, string, error)
    21  	Exit()
    22  }
    23  
    24  type shell struct {
    25  	handle backend.Waiter
    26  	stdin  io.Writer
    27  	stdout io.Reader
    28  	stderr io.Reader
    29  
    30  	be backend.Starter
    31  }
    32  
    33  func New(be backend.Starter) (Shell, error) {
    34  	handle, stdin, stdout, stderr, err := be.StartProcess("/bin/pwsh", "-NoExit", "-Command", "-")
    35  	if err != nil {
    36  		return nil, err
    37  	}
    38  
    39  	return &shell{handle, stdin, stdout, stderr, be}, nil
    40  }
    41  
    42  func (s *shell) Execute(cmd string) (string, string, error) {
    43  
    44  	if s.handle == nil {
    45  		return "", "", errors.Wrap(errors.New(cmd), "Cannot execute commands on closed shells.")
    46  	}
    47  
    48  	outBoundary := createBoundary()
    49  	errBoundary := createBoundary()
    50  
    51  	// wrap the command in special markers so we know when to stop reading from the pipes
    52  	full := fmt.Sprintf("%s; echo '%s'; [Console]::Error.WriteLine('%s')%s", cmd, outBoundary, errBoundary, newline)
    53  
    54  	_, err := s.stdin.Write([]byte(full))
    55  	if err != nil {
    56  		return "", "", errors.Wrap(errors.Wrap(err, cmd), "Could not send PowerShell command")
    57  	}
    58  
    59  	// read stdout and stderr
    60  	sout := ""
    61  	serr := ""
    62  
    63  	waiter := &sync.WaitGroup{}
    64  	waiter.Add(2)
    65  
    66  	var wr io.Writer
    67  	bl, ok := s.be.(*backend.Local)
    68  
    69  	if ok {
    70  		wr = bl.Writer
    71  	}
    72  
    73  	go streamReader(s.stdout, outBoundary, &sout, waiter, wr)
    74  	go streamReader(s.stderr, errBoundary, &serr, waiter, wr)
    75  
    76  	waiter.Wait()
    77  
    78  	if len(serr) > 0 {
    79  		return sout, serr, errors.Wrap(errors.New(cmd), serr)
    80  	}
    81  
    82  	return sout, serr, nil
    83  }
    84  
    85  func (s *shell) Exit() {
    86  	s.stdin.Write([]byte("exit" + newline))
    87  
    88  	// if it's possible to close stdin, do so (some backends, like the local one,
    89  	// do support it)
    90  	closer, ok := s.stdin.(io.Closer)
    91  	if ok {
    92  		closer.Close()
    93  	}
    94  
    95  	s.handle.Wait()
    96  
    97  	s.handle = nil
    98  	s.stdin = nil
    99  	s.stdout = nil
   100  	s.stderr = nil
   101  }
   102  
   103  func streamReader(stream io.Reader, boundary string, buffer *string, signal *sync.WaitGroup, wr io.Writer) error {
   104  	// read all output until we have found our boundary token
   105  	output := ""
   106  	bufsize := 64
   107  	marker := regexp.MustCompile("(?s)(.*)" + regexp.QuoteMeta(boundary))
   108  	marker2 := regexp.MustCompile("\\$gorilla[a-z0-9]*")
   109  
   110  	for {
   111  		buf := make([]byte, bufsize)
   112  		read, err := stream.Read(buf)
   113  		if err != nil {
   114  			return err
   115  		}
   116  
   117  		output = output + string(buf[:read])
   118  		if marker.MatchString(output) {
   119  			break
   120  		}
   121  
   122  		if wr != nil {
   123  			sb := marker2.ReplaceAll(buf[:read], []byte(""))
   124  			wr.Write([]byte(stripansi.Strip(string(sb))))
   125  		}
   126  	}
   127  
   128  	*buffer = marker.FindStringSubmatch(output)[1]
   129  
   130  	signal.Done()
   131  
   132  	return nil
   133  }
   134  
   135  func createBoundary() string {
   136  	return "$gorilla" + utils.CreateRandomString(12) + "$"
   137  }