github.com/rothwerx/packer@v0.9.0/packer/rpc/mux_broker.go (about)

     1  package rpc
     2  
     3  import (
     4  	"encoding/binary"
     5  	"fmt"
     6  	"io"
     7  	"net"
     8  	"sync"
     9  	"sync/atomic"
    10  	"time"
    11  
    12  	"github.com/hashicorp/yamux"
    13  )
    14  
// muxBroker is responsible for brokering multiplexed connections by unique ID.
//
// This allows a plugin to request a channel with a specific ID to connect to
// or accept a connection from, and the broker handles the details of
// holding these channels open while they're being negotiated.
//
// Inbound streams are parked in streams (keyed by ID) by Run until Accept
// claims them; outbound streams are opened directly by Dial. The embedded
// Mutex guards the streams map; nextId is only touched via sync/atomic.
type muxBroker struct {
	nextId  uint32                       // last ID handed out by NextId; accessed atomically
	session *yamux.Session               // session every brokered stream is opened on or accepted from
	streams map[uint32]*muxBrokerPending // pending slots by ID; guarded by the embedded Mutex

	sync.Mutex // guards streams (getStream, Accept, timeoutWait)
}
    27  
// muxBrokerPending is the rendezvous slot for a single stream ID: Run hands an
// inbound stream over on ch, and Accept closes doneCh once it has claimed it
// so that timeoutWait can stop waiting early.
type muxBrokerPending struct {
	ch     chan net.Conn // stream handed over by Run (getStream gives it a buffer of 1)
	doneCh chan struct{} // closed by Accept after it receives from ch
}
    32  
    33  func newMuxBroker(s *yamux.Session) *muxBroker {
    34  	return &muxBroker{
    35  		session: s,
    36  		streams: make(map[uint32]*muxBrokerPending),
    37  	}
    38  }
    39  
    40  func newMuxBrokerClient(rwc io.ReadWriteCloser) (*muxBroker, error) {
    41  	s, err := yamux.Client(rwc, nil)
    42  	if err != nil {
    43  		return nil, err
    44  	}
    45  
    46  	return newMuxBroker(s), nil
    47  }
    48  
    49  func newMuxBrokerServer(rwc io.ReadWriteCloser) (*muxBroker, error) {
    50  	s, err := yamux.Server(rwc, nil)
    51  	if err != nil {
    52  		return nil, err
    53  	}
    54  
    55  	return newMuxBroker(s), nil
    56  }
    57  
    58  // Accept accepts a connection by ID.
    59  //
    60  // This should not be called multiple times with the same ID at one time.
    61  func (m *muxBroker) Accept(id uint32) (net.Conn, error) {
    62  	var c net.Conn
    63  	p := m.getStream(id)
    64  	select {
    65  	case c = <-p.ch:
    66  		close(p.doneCh)
    67  	case <-time.After(5 * time.Second):
    68  		m.Lock()
    69  		defer m.Unlock()
    70  		delete(m.streams, id)
    71  
    72  		return nil, fmt.Errorf("timeout waiting for accept")
    73  	}
    74  
    75  	// Ack our connection
    76  	if err := binary.Write(c, binary.LittleEndian, id); err != nil {
    77  		c.Close()
    78  		return nil, err
    79  	}
    80  
    81  	return c, nil
    82  }
    83  
    84  // Close closes the connection and all sub-connections.
    85  func (m *muxBroker) Close() error {
    86  	return m.session.Close()
    87  }
    88  
    89  // Dial opens a connection by ID.
    90  func (m *muxBroker) Dial(id uint32) (net.Conn, error) {
    91  	// Open the stream
    92  	stream, err := m.session.OpenStream()
    93  	if err != nil {
    94  		return nil, err
    95  	}
    96  
    97  	// Write the stream ID onto the wire.
    98  	if err := binary.Write(stream, binary.LittleEndian, id); err != nil {
    99  		stream.Close()
   100  		return nil, err
   101  	}
   102  
   103  	// Read the ack that we connected. Then we're off!
   104  	var ack uint32
   105  	if err := binary.Read(stream, binary.LittleEndian, &ack); err != nil {
   106  		stream.Close()
   107  		return nil, err
   108  	}
   109  	if ack != id {
   110  		stream.Close()
   111  		return nil, fmt.Errorf("bad ack: %d (expected %d)", ack, id)
   112  	}
   113  
   114  	return stream, nil
   115  }
   116  
   117  // NextId returns a unique ID to use next.
   118  func (m *muxBroker) NextId() uint32 {
   119  	return atomic.AddUint32(&m.nextId, 1)
   120  }
   121  
   122  // Run starts the brokering and should be executed in a goroutine, since it
   123  // blocks forever, or until the session closes.
   124  func (m *muxBroker) Run() {
   125  	for {
   126  		stream, err := m.session.AcceptStream()
   127  		if err != nil {
   128  			// Once we receive an error, just exit
   129  			break
   130  		}
   131  
   132  		// Read the stream ID from the stream
   133  		var id uint32
   134  		if err := binary.Read(stream, binary.LittleEndian, &id); err != nil {
   135  			stream.Close()
   136  			continue
   137  		}
   138  
   139  		// Initialize the waiter
   140  		p := m.getStream(id)
   141  		select {
   142  		case p.ch <- stream:
   143  		default:
   144  		}
   145  
   146  		// Wait for a timeout
   147  		go m.timeoutWait(id, p)
   148  	}
   149  }
   150  
   151  func (m *muxBroker) getStream(id uint32) *muxBrokerPending {
   152  	m.Lock()
   153  	defer m.Unlock()
   154  
   155  	p, ok := m.streams[id]
   156  	if ok {
   157  		return p
   158  	}
   159  
   160  	m.streams[id] = &muxBrokerPending{
   161  		ch:     make(chan net.Conn, 1),
   162  		doneCh: make(chan struct{}),
   163  	}
   164  	return m.streams[id]
   165  }
   166  
   167  func (m *muxBroker) timeoutWait(id uint32, p *muxBrokerPending) {
   168  	// Wait for the stream to either be picked up and connected, or
   169  	// for a timeout.
   170  	timeout := false
   171  	select {
   172  	case <-p.doneCh:
   173  	case <-time.After(5 * time.Second):
   174  		timeout = true
   175  	}
   176  
   177  	m.Lock()
   178  	defer m.Unlock()
   179  
   180  	// Delete the stream so no one else can grab it
   181  	delete(m.streams, id)
   182  
   183  	// If we timed out, then check if we have a channel in the buffer,
   184  	// and if so, close it.
   185  	if timeout {
   186  		select {
   187  		case s := <-p.ch:
   188  			s.Close()
   189  		}
   190  	}
   191  }