github.com/hashicorp/go-plugin@v1.6.0/grpc_stdio.go (about)

     1  // Copyright (c) HashiCorp, Inc.
     2  // SPDX-License-Identifier: MPL-2.0
     3  
     4  package plugin
     5  
     6  import (
     7  	"bufio"
     8  	"bytes"
     9  	"context"
    10  	"io"
    11  
    12  	empty "github.com/golang/protobuf/ptypes/empty"
    13  	hclog "github.com/hashicorp/go-hclog"
    14  	"github.com/hashicorp/go-plugin/internal/plugin"
    15  	"google.golang.org/grpc"
    16  	"google.golang.org/grpc/codes"
    17  	"google.golang.org/grpc/status"
    18  )
    19  
    20  // grpcStdioBuffer is the buffer size we try to fill when sending a chunk of
    21  // stdio data. This is currently 1 KB for no reason other than that seems like
    22  // enough (stdio data isn't that common) and is fairly low.
    23  const grpcStdioBuffer = 1 * 1024
    24  
    25  // grpcStdioServer implements the Stdio service and streams stdiout/stderr.
    26  type grpcStdioServer struct {
    27  	stdoutCh <-chan []byte
    28  	stderrCh <-chan []byte
    29  }
    30  
    31  // newGRPCStdioServer creates a new grpcStdioServer and starts the stream
    32  // copying for the given out and err readers.
    33  //
    34  // This must only be called ONCE per srcOut, srcErr.
    35  func newGRPCStdioServer(log hclog.Logger, srcOut, srcErr io.Reader) *grpcStdioServer {
    36  	stdoutCh := make(chan []byte)
    37  	stderrCh := make(chan []byte)
    38  
    39  	// Begin copying the streams
    40  	go copyChan(log, stdoutCh, srcOut)
    41  	go copyChan(log, stderrCh, srcErr)
    42  
    43  	// Construct our server
    44  	return &grpcStdioServer{
    45  		stdoutCh: stdoutCh,
    46  		stderrCh: stderrCh,
    47  	}
    48  }
    49  
    50  // StreamStdio streams our stdout/err as the response.
    51  func (s *grpcStdioServer) StreamStdio(
    52  	_ *empty.Empty,
    53  	srv plugin.GRPCStdio_StreamStdioServer,
    54  ) error {
    55  	// Share the same data value between runs. Sending this over the wire
    56  	// marshals it so we can reuse this.
    57  	var data plugin.StdioData
    58  
    59  	for {
    60  		// Read our data
    61  		select {
    62  		case data.Data = <-s.stdoutCh:
    63  			data.Channel = plugin.StdioData_STDOUT
    64  
    65  		case data.Data = <-s.stderrCh:
    66  			data.Channel = plugin.StdioData_STDERR
    67  
    68  		case <-srv.Context().Done():
    69  			return nil
    70  		}
    71  
    72  		// Not sure if this is possible, but if we somehow got here and
    73  		// we didn't populate any data at all, then just continue.
    74  		if len(data.Data) == 0 {
    75  			continue
    76  		}
    77  
    78  		// Send our data to the client.
    79  		if err := srv.Send(&data); err != nil {
    80  			return err
    81  		}
    82  	}
    83  }
    84  
    85  // grpcStdioClient wraps the stdio service as a client to copy
    86  // the stdio data to output writers.
    87  type grpcStdioClient struct {
    88  	log         hclog.Logger
    89  	stdioClient plugin.GRPCStdio_StreamStdioClient
    90  }
    91  
    92  // newGRPCStdioClient creates a grpcStdioClient. This will perform the
    93  // initial connection to the stdio service. If the stdio service is unavailable
    94  // then this will be a no-op. This allows this to work without error for
    95  // plugins that don't support this.
    96  func newGRPCStdioClient(
    97  	ctx context.Context,
    98  	log hclog.Logger,
    99  	conn *grpc.ClientConn,
   100  ) (*grpcStdioClient, error) {
   101  	client := plugin.NewGRPCStdioClient(conn)
   102  
   103  	// Connect immediately to the endpoint
   104  	stdioClient, err := client.StreamStdio(ctx, &empty.Empty{})
   105  
   106  	// If we get an Unavailable or Unimplemented error, this means that the plugin isn't
   107  	// updated and linking to the latest version of go-plugin that supports
   108  	// this. We fall back to the previous behavior of just not syncing anything.
   109  	if status.Code(err) == codes.Unavailable || status.Code(err) == codes.Unimplemented {
   110  		log.Warn("stdio service not available, stdout/stderr syncing unavailable")
   111  		stdioClient = nil
   112  		err = nil
   113  	}
   114  	if err != nil {
   115  		return nil, err
   116  	}
   117  
   118  	return &grpcStdioClient{
   119  		log:         log,
   120  		stdioClient: stdioClient,
   121  	}, nil
   122  }
   123  
   124  // Run starts the loop that receives stdio data and writes it to the given
   125  // writers. This blocks and should be run in a goroutine.
   126  func (c *grpcStdioClient) Run(stdout, stderr io.Writer) {
   127  	// This will be nil if stdio is not supported by the plugin
   128  	if c.stdioClient == nil {
   129  		c.log.Warn("stdio service unavailable, run will do nothing")
   130  		return
   131  	}
   132  
   133  	for {
   134  		c.log.Trace("waiting for stdio data")
   135  		data, err := c.stdioClient.Recv()
   136  		if err != nil {
   137  			if err == io.EOF ||
   138  				status.Code(err) == codes.Unavailable ||
   139  				status.Code(err) == codes.Canceled ||
   140  				status.Code(err) == codes.Unimplemented ||
   141  				err == context.Canceled {
   142  				c.log.Debug("received EOF, stopping recv loop", "err", err)
   143  				return
   144  			}
   145  
   146  			c.log.Error("error receiving data", "err", err)
   147  			return
   148  		}
   149  
   150  		// Determine our output writer based on channel
   151  		var w io.Writer
   152  		switch data.Channel {
   153  		case plugin.StdioData_STDOUT:
   154  			w = stdout
   155  
   156  		case plugin.StdioData_STDERR:
   157  			w = stderr
   158  
   159  		default:
   160  			c.log.Warn("unknown channel, dropping", "channel", data.Channel)
   161  			continue
   162  		}
   163  
   164  		// Write! In the event of an error we just continue.
   165  		if c.log.IsTrace() {
   166  			c.log.Trace("received data", "channel", data.Channel.String(), "len", len(data.Data))
   167  		}
   168  		if _, err := io.Copy(w, bytes.NewReader(data.Data)); err != nil {
   169  			c.log.Error("failed to copy all bytes", "err", err)
   170  		}
   171  	}
   172  }
   173  
   174  // copyChan copies an io.Reader into a channel.
   175  func copyChan(log hclog.Logger, dst chan<- []byte, src io.Reader) {
   176  	bufsrc := bufio.NewReader(src)
   177  
   178  	for {
   179  		// Make our data buffer. We allocate a new one per loop iteration
   180  		// so that we can send it over the channel.
   181  		var data [1024]byte
   182  
   183  		// Read the data, this will block until data is available
   184  		n, err := bufsrc.Read(data[:])
   185  
   186  		// We have to check if we have data BEFORE err != nil. The bufio
   187  		// docs guarantee n == 0 on EOF but its better to be safe here.
   188  		if n > 0 {
   189  			// We have data! Send it on the channel. This will block if there
   190  			// is no reader on the other side. We expect that go-plugin will
   191  			// connect immediately to the stdio server to drain this so we want
   192  			// this block to happen for backpressure.
   193  			dst <- data[:n]
   194  		}
   195  
   196  		// If we hit EOF we're done copying
   197  		if err == io.EOF {
   198  			log.Debug("stdio EOF, exiting copy loop")
   199  			return
   200  		}
   201  
   202  		// Any other error we just exit the loop. We don't expect there to
   203  		// be errors since our use case for this is reading/writing from
   204  		// a in-process pipe (os.Pipe).
   205  		if err != nil {
   206  			log.Warn("error copying stdio data, stopping copy", "err", err)
   207  			return
   208  		}
   209  	}
   210  }