github.com/markdia/terraform@v0.5.1-0.20150508012022-f1ae920aa970/rpc/mux_broker.go (about)

     1  package rpc
     2  
     3  import (
     4  	"encoding/binary"
     5  	"fmt"
     6  	"net"
     7  	"sync"
     8  	"sync/atomic"
     9  	"time"
    10  
    11  	"github.com/hashicorp/yamux"
    12  )
    13  
    14  // muxBroker is responsible for brokering multiplexed connections by unique ID.
    15  //
    16  // This allows a plugin to request a channel with a specific ID to connect to
    17  // or accept a connection from, and the broker handles the details of
    18  // holding these channels open while they're being negotiated.
    19  type muxBroker struct {
    20  	nextId  uint32
    21  	session *yamux.Session
    22  	streams map[uint32]*muxBrokerPending
    23  
    24  	sync.Mutex
    25  }
    26  
    27  type muxBrokerPending struct {
    28  	ch     chan net.Conn
    29  	doneCh chan struct{}
    30  }
    31  
    32  func newMuxBroker(s *yamux.Session) *muxBroker {
    33  	return &muxBroker{
    34  		session: s,
    35  		streams: make(map[uint32]*muxBrokerPending),
    36  	}
    37  }
    38  
    39  // Accept accepts a connection by ID.
    40  //
    41  // This should not be called multiple times with the same ID at one time.
    42  func (m *muxBroker) Accept(id uint32) (net.Conn, error) {
    43  	var c net.Conn
    44  	p := m.getStream(id)
    45  	select {
    46  	case c = <-p.ch:
    47  		close(p.doneCh)
    48  	case <-time.After(5 * time.Second):
    49  		m.Lock()
    50  		defer m.Unlock()
    51  		delete(m.streams, id)
    52  
    53  		return nil, fmt.Errorf("timeout waiting for accept")
    54  	}
    55  
    56  	// Ack our connection
    57  	if err := binary.Write(c, binary.LittleEndian, id); err != nil {
    58  		c.Close()
    59  		return nil, err
    60  	}
    61  
    62  	return c, nil
    63  }
    64  
    65  // Close closes the connection and all sub-connections.
    66  func (m *muxBroker) Close() error {
    67  	return m.session.Close()
    68  }
    69  
    70  // Dial opens a connection by ID.
    71  func (m *muxBroker) Dial(id uint32) (net.Conn, error) {
    72  	// Open the stream
    73  	stream, err := m.session.OpenStream()
    74  	if err != nil {
    75  		return nil, err
    76  	}
    77  
    78  	// Write the stream ID onto the wire.
    79  	if err := binary.Write(stream, binary.LittleEndian, id); err != nil {
    80  		stream.Close()
    81  		return nil, err
    82  	}
    83  
    84  	// Read the ack that we connected. Then we're off!
    85  	var ack uint32
    86  	if err := binary.Read(stream, binary.LittleEndian, &ack); err != nil {
    87  		stream.Close()
    88  		return nil, err
    89  	}
    90  	if ack != id {
    91  		stream.Close()
    92  		return nil, fmt.Errorf("bad ack: %d (expected %d)", ack, id)
    93  	}
    94  
    95  	return stream, nil
    96  }
    97  
    98  // NextId returns a unique ID to use next.
    99  func (m *muxBroker) NextId() uint32 {
   100  	return atomic.AddUint32(&m.nextId, 1)
   101  }
   102  
   103  // Run starts the brokering and should be executed in a goroutine, since it
   104  // blocks forever, or until the session closes.
   105  func (m *muxBroker) Run() {
   106  	for {
   107  		stream, err := m.session.AcceptStream()
   108  		if err != nil {
   109  			// Once we receive an error, just exit
   110  			break
   111  		}
   112  
   113  		// Read the stream ID from the stream
   114  		var id uint32
   115  		if err := binary.Read(stream, binary.LittleEndian, &id); err != nil {
   116  			stream.Close()
   117  			continue
   118  		}
   119  
   120  		// Initialize the waiter
   121  		p := m.getStream(id)
   122  		select {
   123  		case p.ch <- stream:
   124  		default:
   125  		}
   126  
   127  		// Wait for a timeout
   128  		go m.timeoutWait(id, p)
   129  	}
   130  }
   131  
   132  func (m *muxBroker) getStream(id uint32) *muxBrokerPending {
   133  	m.Lock()
   134  	defer m.Unlock()
   135  
   136  	p, ok := m.streams[id]
   137  	if ok {
   138  		return p
   139  	}
   140  
   141  	m.streams[id] = &muxBrokerPending{
   142  		ch:     make(chan net.Conn, 1),
   143  		doneCh: make(chan struct{}),
   144  	}
   145  	return m.streams[id]
   146  }
   147  
   148  func (m *muxBroker) timeoutWait(id uint32, p *muxBrokerPending) {
   149  	// Wait for the stream to either be picked up and connected, or
   150  	// for a timeout.
   151  	timeout := false
   152  	select {
   153  	case <-p.doneCh:
   154  	case <-time.After(5 * time.Second):
   155  		timeout = true
   156  	}
   157  
   158  	m.Lock()
   159  	defer m.Unlock()
   160  
   161  	// Delete the stream so no one else can grab it
   162  	delete(m.streams, id)
   163  
   164  	// If we timed out, then check if we have a channel in the buffer,
   165  	// and if so, close it.
   166  	if timeout {
   167  		select {
   168  		case s := <-p.ch:
   169  			s.Close()
   170  		}
   171  	}
   172  }