github.com/hashicorp/go-plugin@v1.6.0/mux_broker.go (about)

     1  // Copyright (c) HashiCorp, Inc.
     2  // SPDX-License-Identifier: MPL-2.0
     3  
     4  package plugin
     5  
     6  import (
     7  	"encoding/binary"
     8  	"fmt"
     9  	"log"
    10  	"net"
    11  	"sync"
    12  	"sync/atomic"
    13  	"time"
    14  
    15  	"github.com/hashicorp/yamux"
    16  )
    17  
    18  // MuxBroker is responsible for brokering multiplexed connections by unique ID.
    19  //
    20  // It is used by plugins to multiplex multiple RPC connections and data
    21  // streams on top of a single connection between the plugin process and the
    22  // host process.
    23  //
    24  // This allows a plugin to request a channel with a specific ID to connect to
    25  // or accept a connection from, and the broker handles the details of
    26  // holding these channels open while they're being negotiated.
    27  //
    28  // The Plugin interface has access to these for both Server and Client.
    29  // The broker can be used by either (optionally) to reserve and connect to
    30  // new multiplexed streams. This is useful for complex args and return values,
    31  // or anything else you might need a data stream for.
    32  type MuxBroker struct {
    33  	nextId  uint32
    34  	session *yamux.Session
    35  	streams map[uint32]*muxBrokerPending
    36  
    37  	sync.Mutex
    38  }
    39  
    40  type muxBrokerPending struct {
    41  	ch     chan net.Conn
    42  	doneCh chan struct{}
    43  }
    44  
    45  func newMuxBroker(s *yamux.Session) *MuxBroker {
    46  	return &MuxBroker{
    47  		session: s,
    48  		streams: make(map[uint32]*muxBrokerPending),
    49  	}
    50  }
    51  
    52  // Accept accepts a connection by ID.
    53  //
    54  // This should not be called multiple times with the same ID at one time.
    55  func (m *MuxBroker) Accept(id uint32) (net.Conn, error) {
    56  	var c net.Conn
    57  	p := m.getStream(id)
    58  	select {
    59  	case c = <-p.ch:
    60  		close(p.doneCh)
    61  	case <-time.After(5 * time.Second):
    62  		m.Lock()
    63  		defer m.Unlock()
    64  		delete(m.streams, id)
    65  
    66  		return nil, fmt.Errorf("timeout waiting for accept")
    67  	}
    68  
    69  	// Ack our connection
    70  	if err := binary.Write(c, binary.LittleEndian, id); err != nil {
    71  		c.Close()
    72  		return nil, err
    73  	}
    74  
    75  	return c, nil
    76  }
    77  
    78  // AcceptAndServe is used to accept a specific stream ID and immediately
    79  // serve an RPC server on that stream ID. This is used to easily serve
    80  // complex arguments.
    81  //
    82  // The served interface is always registered to the "Plugin" name.
    83  func (m *MuxBroker) AcceptAndServe(id uint32, v interface{}) {
    84  	conn, err := m.Accept(id)
    85  	if err != nil {
    86  		log.Printf("[ERR] plugin: plugin acceptAndServe error: %s", err)
    87  		return
    88  	}
    89  
    90  	serve(conn, "Plugin", v)
    91  }
    92  
    93  // Close closes the connection and all sub-connections.
    94  func (m *MuxBroker) Close() error {
    95  	return m.session.Close()
    96  }
    97  
    98  // Dial opens a connection by ID.
    99  func (m *MuxBroker) Dial(id uint32) (net.Conn, error) {
   100  	// Open the stream
   101  	stream, err := m.session.OpenStream()
   102  	if err != nil {
   103  		return nil, err
   104  	}
   105  
   106  	// Write the stream ID onto the wire.
   107  	if err := binary.Write(stream, binary.LittleEndian, id); err != nil {
   108  		stream.Close()
   109  		return nil, err
   110  	}
   111  
   112  	// Read the ack that we connected. Then we're off!
   113  	var ack uint32
   114  	if err := binary.Read(stream, binary.LittleEndian, &ack); err != nil {
   115  		stream.Close()
   116  		return nil, err
   117  	}
   118  	if ack != id {
   119  		stream.Close()
   120  		return nil, fmt.Errorf("bad ack: %d (expected %d)", ack, id)
   121  	}
   122  
   123  	return stream, nil
   124  }
   125  
   126  // NextId returns a unique ID to use next.
   127  //
   128  // It is possible for very long-running plugin hosts to wrap this value,
   129  // though it would require a very large amount of RPC calls. In practice
   130  // we've never seen it happen.
   131  func (m *MuxBroker) NextId() uint32 {
   132  	return atomic.AddUint32(&m.nextId, 1)
   133  }
   134  
   135  // Run starts the brokering and should be executed in a goroutine, since it
   136  // blocks forever, or until the session closes.
   137  //
   138  // Uses of MuxBroker never need to call this. It is called internally by
   139  // the plugin host/client.
   140  func (m *MuxBroker) Run() {
   141  	for {
   142  		stream, err := m.session.AcceptStream()
   143  		if err != nil {
   144  			// Once we receive an error, just exit
   145  			break
   146  		}
   147  
   148  		// Read the stream ID from the stream
   149  		var id uint32
   150  		if err := binary.Read(stream, binary.LittleEndian, &id); err != nil {
   151  			stream.Close()
   152  			continue
   153  		}
   154  
   155  		// Initialize the waiter
   156  		p := m.getStream(id)
   157  		select {
   158  		case p.ch <- stream:
   159  		default:
   160  		}
   161  
   162  		// Wait for a timeout
   163  		go m.timeoutWait(id, p)
   164  	}
   165  }
   166  
   167  func (m *MuxBroker) getStream(id uint32) *muxBrokerPending {
   168  	m.Lock()
   169  	defer m.Unlock()
   170  
   171  	p, ok := m.streams[id]
   172  	if ok {
   173  		return p
   174  	}
   175  
   176  	m.streams[id] = &muxBrokerPending{
   177  		ch:     make(chan net.Conn, 1),
   178  		doneCh: make(chan struct{}),
   179  	}
   180  	return m.streams[id]
   181  }
   182  
   183  func (m *MuxBroker) timeoutWait(id uint32, p *muxBrokerPending) {
   184  	// Wait for the stream to either be picked up and connected, or
   185  	// for a timeout.
   186  	timeout := false
   187  	select {
   188  	case <-p.doneCh:
   189  	case <-time.After(5 * time.Second):
   190  		timeout = true
   191  	}
   192  
   193  	m.Lock()
   194  	defer m.Unlock()
   195  
   196  	// Delete the stream so no one else can grab it
   197  	delete(m.streams, id)
   198  
   199  	// If we timed out, then check if we have a channel in the buffer,
   200  	// and if so, close it.
   201  	if timeout {
   202  		select {
   203  		case s := <-p.ch:
   204  			s.Close()
   205  		}
   206  	}
   207  }