github.com/hashicorp/go-plugin@v1.6.0/grpc_broker.go (about)

     1  // Copyright (c) HashiCorp, Inc.
     2  // SPDX-License-Identifier: MPL-2.0
     3  
     4  package plugin
     5  
     6  import (
     7  	"context"
     8  	"crypto/tls"
     9  	"errors"
    10  	"fmt"
    11  	"log"
    12  	"net"
    13  	"sync"
    14  	"sync/atomic"
    15  	"time"
    16  
    17  	"github.com/hashicorp/go-plugin/internal/grpcmux"
    18  	"github.com/hashicorp/go-plugin/internal/plugin"
    19  	"github.com/hashicorp/go-plugin/runner"
    20  
    21  	"github.com/oklog/run"
    22  	"google.golang.org/grpc"
    23  	"google.golang.org/grpc/credentials"
    24  )
    25  
    26  // streamer interface is used in the broker to send/receive connection
    27  // information.
    28  type streamer interface {
    29  	Send(*plugin.ConnInfo) error
    30  	Recv() (*plugin.ConnInfo, error)
    31  	Close()
    32  }
    33  
    34  // sendErr is used to pass errors back during a send.
    35  type sendErr struct {
    36  	i  *plugin.ConnInfo
    37  	ch chan error
    38  }
    39  
    40  // gRPCBrokerServer is used by the plugin to start a stream and to send
    41  // connection information to/from the plugin. Implements GRPCBrokerServer and
    42  // streamer interfaces.
    43  type gRPCBrokerServer struct {
    44  	plugin.UnimplementedGRPCBrokerServer
    45  
    46  	// send is used to send connection info to the gRPC stream.
    47  	send chan *sendErr
    48  
    49  	// recv is used to receive connection info from the gRPC stream.
    50  	recv chan *plugin.ConnInfo
    51  
    52  	// quit closes down the stream.
    53  	quit chan struct{}
    54  
    55  	// o is used to ensure we close the quit channel only once.
    56  	o sync.Once
    57  }
    58  
    59  func newGRPCBrokerServer() *gRPCBrokerServer {
    60  	return &gRPCBrokerServer{
    61  		send: make(chan *sendErr),
    62  		recv: make(chan *plugin.ConnInfo),
    63  		quit: make(chan struct{}),
    64  	}
    65  }
    66  
    67  // StartStream implements the GRPCBrokerServer interface and will block until
    68  // the quit channel is closed or the context reports Done. The stream will pass
    69  // connection information to/from the client.
    70  func (s *gRPCBrokerServer) StartStream(stream plugin.GRPCBroker_StartStreamServer) error {
    71  	doneCh := stream.Context().Done()
    72  	defer s.Close()
    73  
    74  	// Proccess send stream
    75  	go func() {
    76  		for {
    77  			select {
    78  			case <-doneCh:
    79  				return
    80  			case <-s.quit:
    81  				return
    82  			case se := <-s.send:
    83  				err := stream.Send(se.i)
    84  				se.ch <- err
    85  			}
    86  		}
    87  	}()
    88  
    89  	// Process receive stream
    90  	for {
    91  		i, err := stream.Recv()
    92  		if err != nil {
    93  			return err
    94  		}
    95  		select {
    96  		case <-doneCh:
    97  			return nil
    98  		case <-s.quit:
    99  			return nil
   100  		case s.recv <- i:
   101  		}
   102  	}
   103  
   104  	return nil
   105  }
   106  
   107  // Send is used by the GRPCBroker to pass connection information into the stream
   108  // to the client.
   109  func (s *gRPCBrokerServer) Send(i *plugin.ConnInfo) error {
   110  	ch := make(chan error)
   111  	defer close(ch)
   112  
   113  	select {
   114  	case <-s.quit:
   115  		return errors.New("broker closed")
   116  	case s.send <- &sendErr{
   117  		i:  i,
   118  		ch: ch,
   119  	}:
   120  	}
   121  
   122  	return <-ch
   123  }
   124  
   125  // Recv is used by the GRPCBroker to pass connection information that has been
   126  // sent from the client from the stream to the broker.
   127  func (s *gRPCBrokerServer) Recv() (*plugin.ConnInfo, error) {
   128  	select {
   129  	case <-s.quit:
   130  		return nil, errors.New("broker closed")
   131  	case i := <-s.recv:
   132  		return i, nil
   133  	}
   134  }
   135  
   136  // Close closes the quit channel, shutting down the stream.
   137  func (s *gRPCBrokerServer) Close() {
   138  	s.o.Do(func() {
   139  		close(s.quit)
   140  	})
   141  }
   142  
   143  // gRPCBrokerClientImpl is used by the client to start a stream and to send
   144  // connection information to/from the client. Implements GRPCBrokerClient and
   145  // streamer interfaces.
   146  type gRPCBrokerClientImpl struct {
   147  	// client is the underlying GRPC client used to make calls to the server.
   148  	client plugin.GRPCBrokerClient
   149  
   150  	// send is used to send connection info to the gRPC stream.
   151  	send chan *sendErr
   152  
   153  	// recv is used to receive connection info from the gRPC stream.
   154  	recv chan *plugin.ConnInfo
   155  
   156  	// quit closes down the stream.
   157  	quit chan struct{}
   158  
   159  	// o is used to ensure we close the quit channel only once.
   160  	o sync.Once
   161  }
   162  
   163  func newGRPCBrokerClient(conn *grpc.ClientConn) *gRPCBrokerClientImpl {
   164  	return &gRPCBrokerClientImpl{
   165  		client: plugin.NewGRPCBrokerClient(conn),
   166  		send:   make(chan *sendErr),
   167  		recv:   make(chan *plugin.ConnInfo),
   168  		quit:   make(chan struct{}),
   169  	}
   170  }
   171  
   172  // StartStream implements the GRPCBrokerClient interface and will block until
   173  // the quit channel is closed or the context reports Done. The stream will pass
   174  // connection information to/from the plugin.
   175  func (s *gRPCBrokerClientImpl) StartStream() error {
   176  	ctx, cancelFunc := context.WithCancel(context.Background())
   177  	defer cancelFunc()
   178  	defer s.Close()
   179  
   180  	stream, err := s.client.StartStream(ctx)
   181  	if err != nil {
   182  		return err
   183  	}
   184  	doneCh := stream.Context().Done()
   185  
   186  	go func() {
   187  		for {
   188  			select {
   189  			case <-doneCh:
   190  				return
   191  			case <-s.quit:
   192  				return
   193  			case se := <-s.send:
   194  				err := stream.Send(se.i)
   195  				se.ch <- err
   196  			}
   197  		}
   198  	}()
   199  
   200  	for {
   201  		i, err := stream.Recv()
   202  		if err != nil {
   203  			return err
   204  		}
   205  		select {
   206  		case <-doneCh:
   207  			return nil
   208  		case <-s.quit:
   209  			return nil
   210  		case s.recv <- i:
   211  		}
   212  	}
   213  
   214  	return nil
   215  }
   216  
   217  // Send is used by the GRPCBroker to pass connection information into the stream
   218  // to the plugin.
   219  func (s *gRPCBrokerClientImpl) Send(i *plugin.ConnInfo) error {
   220  	ch := make(chan error)
   221  	defer close(ch)
   222  
   223  	select {
   224  	case <-s.quit:
   225  		return errors.New("broker closed")
   226  	case s.send <- &sendErr{
   227  		i:  i,
   228  		ch: ch,
   229  	}:
   230  	}
   231  
   232  	return <-ch
   233  }
   234  
   235  // Recv is used by the GRPCBroker to pass connection information that has been
   236  // sent from the plugin to the broker.
   237  func (s *gRPCBrokerClientImpl) Recv() (*plugin.ConnInfo, error) {
   238  	select {
   239  	case <-s.quit:
   240  		return nil, errors.New("broker closed")
   241  	case i := <-s.recv:
   242  		return i, nil
   243  	}
   244  }
   245  
   246  // Close closes the quit channel, shutting down the stream.
   247  func (s *gRPCBrokerClientImpl) Close() {
   248  	s.o.Do(func() {
   249  		close(s.quit)
   250  	})
   251  }
   252  
   253  // GRPCBroker is responsible for brokering connections by unique ID.
   254  //
   255  // It is used by plugins to create multiple gRPC connections and data
   256  // streams between the plugin process and the host process.
   257  //
   258  // This allows a plugin to request a channel with a specific ID to connect to
   259  // or accept a connection from, and the broker handles the details of
   260  // holding these channels open while they're being negotiated.
   261  //
   262  // The Plugin interface has access to these for both Server and Client.
   263  // The broker can be used by either (optionally) to reserve and connect to
   264  // new streams. This is useful for complex args and return values,
   265  // or anything else you might need a data stream for.
   266  type GRPCBroker struct {
   267  	nextId   uint32
   268  	streamer streamer
   269  	tls      *tls.Config
   270  	doneCh   chan struct{}
   271  	o        sync.Once
   272  
   273  	clientStreams map[uint32]*gRPCBrokerPending
   274  	serverStreams map[uint32]*gRPCBrokerPending
   275  
   276  	unixSocketCfg  UnixSocketConfig
   277  	addrTranslator runner.AddrTranslator
   278  
   279  	dialMutex sync.Mutex
   280  
   281  	muxer grpcmux.GRPCMuxer
   282  
   283  	sync.Mutex
   284  }
   285  
   286  type gRPCBrokerPending struct {
   287  	ch     chan *plugin.ConnInfo
   288  	doneCh chan struct{}
   289  	once   sync.Once
   290  }
   291  
   292  func newGRPCBroker(s streamer, tls *tls.Config, unixSocketCfg UnixSocketConfig, addrTranslator runner.AddrTranslator, muxer grpcmux.GRPCMuxer) *GRPCBroker {
   293  	return &GRPCBroker{
   294  		streamer: s,
   295  		tls:      tls,
   296  		doneCh:   make(chan struct{}),
   297  
   298  		clientStreams: make(map[uint32]*gRPCBrokerPending),
   299  		serverStreams: make(map[uint32]*gRPCBrokerPending),
   300  		muxer:         muxer,
   301  
   302  		unixSocketCfg:  unixSocketCfg,
   303  		addrTranslator: addrTranslator,
   304  	}
   305  }
   306  
   307  // Accept accepts a connection by ID.
   308  //
   309  // This should not be called multiple times with the same ID at one time.
   310  func (b *GRPCBroker) Accept(id uint32) (net.Listener, error) {
   311  	if b.muxer.Enabled() {
   312  		p := b.getServerStream(id)
   313  		go func() {
   314  			err := b.listenForKnocks(id)
   315  			if err != nil {
   316  				log.Printf("[ERR]: error listening for knocks, id: %d, error: %s", id, err)
   317  			}
   318  		}()
   319  
   320  		ln, err := b.muxer.Listener(id, p.doneCh)
   321  		if err != nil {
   322  			return nil, err
   323  		}
   324  
   325  		ln = &rmListener{
   326  			Listener: ln,
   327  			close: func() error {
   328  				// We could have multiple listeners on the same ID, so use sync.Once
   329  				// for closing doneCh to ensure we don't get a panic.
   330  				p.once.Do(func() {
   331  					close(p.doneCh)
   332  				})
   333  
   334  				b.Lock()
   335  				defer b.Unlock()
   336  
   337  				// No longer need to listen for knocks once the listener is closed.
   338  				delete(b.serverStreams, id)
   339  
   340  				return nil
   341  			},
   342  		}
   343  
   344  		return ln, nil
   345  	}
   346  
   347  	listener, err := serverListener(b.unixSocketCfg)
   348  	if err != nil {
   349  		return nil, err
   350  	}
   351  
   352  	advertiseNet := listener.Addr().Network()
   353  	advertiseAddr := listener.Addr().String()
   354  	if b.addrTranslator != nil {
   355  		advertiseNet, advertiseAddr, err = b.addrTranslator.HostToPlugin(advertiseNet, advertiseAddr)
   356  		if err != nil {
   357  			return nil, err
   358  		}
   359  	}
   360  	err = b.streamer.Send(&plugin.ConnInfo{
   361  		ServiceId: id,
   362  		Network:   advertiseNet,
   363  		Address:   advertiseAddr,
   364  	})
   365  	if err != nil {
   366  		return nil, err
   367  	}
   368  
   369  	return listener, nil
   370  }
   371  
   372  // AcceptAndServe is used to accept a specific stream ID and immediately
   373  // serve a gRPC server on that stream ID. This is used to easily serve
   374  // complex arguments. Each AcceptAndServe call opens a new listener socket and
   375  // sends the connection info down the stream to the dialer. Since a new
   376  // connection is opened every call, these calls should be used sparingly.
   377  // Multiple gRPC server implementations can be registered to a single
   378  // AcceptAndServe call.
   379  func (b *GRPCBroker) AcceptAndServe(id uint32, newGRPCServer func([]grpc.ServerOption) *grpc.Server) {
   380  	ln, err := b.Accept(id)
   381  	if err != nil {
   382  		log.Printf("[ERR] plugin: plugin acceptAndServe error: %s", err)
   383  		return
   384  	}
   385  	defer ln.Close()
   386  
   387  	var opts []grpc.ServerOption
   388  	if b.tls != nil {
   389  		opts = []grpc.ServerOption{grpc.Creds(credentials.NewTLS(b.tls))}
   390  	}
   391  
   392  	server := newGRPCServer(opts)
   393  
   394  	// Here we use a run group to close this goroutine if the server is shutdown
   395  	// or the broker is shutdown.
   396  	var g run.Group
   397  	{
   398  		// Serve on the listener, if shutting down call GracefulStop.
   399  		g.Add(func() error {
   400  			return server.Serve(ln)
   401  		}, func(err error) {
   402  			server.GracefulStop()
   403  		})
   404  	}
   405  	{
   406  		// block on the closeCh or the doneCh. If we are shutting down close the
   407  		// closeCh.
   408  		closeCh := make(chan struct{})
   409  		g.Add(func() error {
   410  			select {
   411  			case <-b.doneCh:
   412  			case <-closeCh:
   413  			}
   414  			return nil
   415  		}, func(err error) {
   416  			close(closeCh)
   417  		})
   418  	}
   419  
   420  	// Block until we are done
   421  	g.Run()
   422  }
   423  
   424  // Close closes the stream and all servers.
   425  func (b *GRPCBroker) Close() error {
   426  	b.streamer.Close()
   427  	b.o.Do(func() {
   428  		close(b.doneCh)
   429  	})
   430  	return nil
   431  }
   432  
   433  func (b *GRPCBroker) listenForKnocks(id uint32) error {
   434  	p := b.getServerStream(id)
   435  	for {
   436  		select {
   437  		case msg := <-p.ch:
   438  			// Shouldn't be possible.
   439  			if msg.ServiceId != id {
   440  				return fmt.Errorf("knock received with wrong service ID; expected %d but got %d", id, msg.ServiceId)
   441  			}
   442  
   443  			// Also shouldn't be possible.
   444  			if msg.Knock == nil || !msg.Knock.Knock || msg.Knock.Ack {
   445  				return fmt.Errorf("knock received for service ID %d with incorrect values; knock=%+v", id, msg.Knock)
   446  			}
   447  
   448  			// Successful knock, open the door for the given ID.
   449  			var ackError string
   450  			err := b.muxer.AcceptKnock(id)
   451  			if err != nil {
   452  				ackError = err.Error()
   453  			}
   454  
   455  			// Send back an acknowledgement to allow the client to start dialling.
   456  			err = b.streamer.Send(&plugin.ConnInfo{
   457  				ServiceId: id,
   458  				Knock: &plugin.ConnInfo_Knock{
   459  					Knock: true,
   460  					Ack:   true,
   461  					Error: ackError,
   462  				},
   463  			})
   464  			if err != nil {
   465  				return fmt.Errorf("error sending back knock acknowledgement: %w", err)
   466  			}
   467  		case <-p.doneCh:
   468  			return nil
   469  		}
   470  	}
   471  }
   472  
   473  func (b *GRPCBroker) knock(id uint32) error {
   474  	// Send a knock.
   475  	err := b.streamer.Send(&plugin.ConnInfo{
   476  		ServiceId: id,
   477  		Knock: &plugin.ConnInfo_Knock{
   478  			Knock: true,
   479  		},
   480  	})
   481  	if err != nil {
   482  		return err
   483  	}
   484  
   485  	// Wait for the ack.
   486  	p := b.getClientStream(id)
   487  	select {
   488  	case msg := <-p.ch:
   489  		if msg.ServiceId != id {
   490  			return fmt.Errorf("handshake failed for multiplexing on id %d; got response for %d", id, msg.ServiceId)
   491  		}
   492  		if msg.Knock == nil || !msg.Knock.Knock || !msg.Knock.Ack {
   493  			return fmt.Errorf("handshake failed for multiplexing on id %d; expected knock and ack, but got %+v", id, msg.Knock)
   494  		}
   495  		if msg.Knock.Error != "" {
   496  			return fmt.Errorf("failed to knock for id %d: %s", id, msg.Knock.Error)
   497  		}
   498  	case <-time.After(5 * time.Second):
   499  		return fmt.Errorf("timeout waiting for multiplexing knock handshake on id %d", id)
   500  	}
   501  
   502  	return nil
   503  }
   504  
   505  func (b *GRPCBroker) muxDial(id uint32) func(string, time.Duration) (net.Conn, error) {
   506  	return func(string, time.Duration) (net.Conn, error) {
   507  		b.dialMutex.Lock()
   508  		defer b.dialMutex.Unlock()
   509  
   510  		// Tell the other side the listener ID it should give the next stream to.
   511  		err := b.knock(id)
   512  		if err != nil {
   513  			return nil, fmt.Errorf("failed to knock before dialling client: %w", err)
   514  		}
   515  
   516  		conn, err := b.muxer.Dial()
   517  		if err != nil {
   518  			return nil, err
   519  		}
   520  
   521  		return conn, nil
   522  	}
   523  }
   524  
   525  // Dial opens a connection by ID.
   526  func (b *GRPCBroker) Dial(id uint32) (conn *grpc.ClientConn, err error) {
   527  	if b.muxer.Enabled() {
   528  		return dialGRPCConn(b.tls, b.muxDial(id))
   529  	}
   530  
   531  	var c *plugin.ConnInfo
   532  
   533  	// Open the stream
   534  	p := b.getClientStream(id)
   535  	select {
   536  	case c = <-p.ch:
   537  		close(p.doneCh)
   538  	case <-time.After(5 * time.Second):
   539  		return nil, fmt.Errorf("timeout waiting for connection info")
   540  	}
   541  
   542  	network, address := c.Network, c.Address
   543  	if b.addrTranslator != nil {
   544  		network, address, err = b.addrTranslator.PluginToHost(network, address)
   545  		if err != nil {
   546  			return nil, err
   547  		}
   548  	}
   549  
   550  	var addr net.Addr
   551  	switch network {
   552  	case "tcp":
   553  		addr, err = net.ResolveTCPAddr("tcp", address)
   554  	case "unix":
   555  		addr, err = net.ResolveUnixAddr("unix", address)
   556  	default:
   557  		err = fmt.Errorf("Unknown address type: %s", c.Address)
   558  	}
   559  	if err != nil {
   560  		return nil, err
   561  	}
   562  
   563  	return dialGRPCConn(b.tls, netAddrDialer(addr))
   564  }
   565  
   566  // NextId returns a unique ID to use next.
   567  //
   568  // It is possible for very long-running plugin hosts to wrap this value,
   569  // though it would require a very large amount of calls. In practice
   570  // we've never seen it happen.
   571  func (m *GRPCBroker) NextId() uint32 {
   572  	return atomic.AddUint32(&m.nextId, 1)
   573  }
   574  
   575  // Run starts the brokering and should be executed in a goroutine, since it
   576  // blocks forever, or until the session closes.
   577  //
   578  // Uses of GRPCBroker never need to call this. It is called internally by
   579  // the plugin host/client.
   580  func (m *GRPCBroker) Run() {
   581  	for {
   582  		msg, err := m.streamer.Recv()
   583  		if err != nil {
   584  			// Once we receive an error, just exit
   585  			break
   586  		}
   587  
   588  		// Initialize the waiter
   589  		var p *gRPCBrokerPending
   590  		if msg.Knock != nil && msg.Knock.Knock && !msg.Knock.Ack {
   591  			p = m.getServerStream(msg.ServiceId)
   592  			// The server side doesn't close the channel immediately as it needs
   593  			// to continuously listen for knocks.
   594  		} else {
   595  			p = m.getClientStream(msg.ServiceId)
   596  			go m.timeoutWait(msg.ServiceId, p)
   597  		}
   598  		select {
   599  		case p.ch <- msg:
   600  		default:
   601  		}
   602  	}
   603  }
   604  
   605  // getClientStream is a buffer to receive new connection info and knock acks
   606  // by stream ID.
   607  func (m *GRPCBroker) getClientStream(id uint32) *gRPCBrokerPending {
   608  	m.Lock()
   609  	defer m.Unlock()
   610  
   611  	p, ok := m.clientStreams[id]
   612  	if ok {
   613  		return p
   614  	}
   615  
   616  	m.clientStreams[id] = &gRPCBrokerPending{
   617  		ch:     make(chan *plugin.ConnInfo, 1),
   618  		doneCh: make(chan struct{}),
   619  	}
   620  	return m.clientStreams[id]
   621  }
   622  
   623  // getServerStream is a buffer to receive knocks to a multiplexed stream ID
   624  // that its side is listening on. Not used unless multiplexing is enabled.
   625  func (m *GRPCBroker) getServerStream(id uint32) *gRPCBrokerPending {
   626  	m.Lock()
   627  	defer m.Unlock()
   628  
   629  	p, ok := m.serverStreams[id]
   630  	if ok {
   631  		return p
   632  	}
   633  
   634  	m.serverStreams[id] = &gRPCBrokerPending{
   635  		ch:     make(chan *plugin.ConnInfo, 1),
   636  		doneCh: make(chan struct{}),
   637  	}
   638  	return m.serverStreams[id]
   639  }
   640  
   641  func (m *GRPCBroker) timeoutWait(id uint32, p *gRPCBrokerPending) {
   642  	// Wait for the stream to either be picked up and connected, or
   643  	// for a timeout.
   644  	select {
   645  	case <-p.doneCh:
   646  	case <-time.After(5 * time.Second):
   647  	}
   648  
   649  	m.Lock()
   650  	defer m.Unlock()
   651  
   652  	// Delete the stream so no one else can grab it
   653  	delete(m.clientStreams, id)
   654  }