github.com/hashicorp/go-plugin@v1.6.0/internal/grpcmux/grpc_server_muxer.go (about)

     1  // Copyright (c) HashiCorp, Inc.
     2  // SPDX-License-Identifier: MPL-2.0
     3  
     4  package grpcmux
     5  
     6  import (
     7  	"errors"
     8  	"fmt"
     9  	"net"
    10  	"sync"
    11  	"time"
    12  
    13  	"github.com/hashicorp/go-hclog"
    14  	"github.com/hashicorp/yamux"
    15  )
    16  
    17  var _ GRPCMuxer = (*GRPCServerMuxer)(nil)
    18  var _ net.Listener = (*GRPCServerMuxer)(nil)
    19  
    20  // GRPCServerMuxer implements the server (plugin) side of the gRPC broker's
    21  // GRPCMuxer interface for multiplexing multiple gRPC broker connections over
    22  // a single net.Conn.
    23  //
    24  // The server side needs a listener to serve the gRPC broker's control services,
    25  // which includes the service we will receive knocks on. That means we always
    26  // accept the first connection onto a "default" main listener, and if we accept
    27  // any further connections without receiving a knock first, they are also given
    28  // to the default listener.
    29  //
    30  // When creating additional multiplexed listeners for specific stream IDs, we
    31  // can't control the order in which gRPC servers will call Accept() on each
    32  // listener, but we do need to control which gRPC server accepts which connection.
    33  // As such, each multiplexed listener blocks waiting on a channel. It will be
    34  // unblocked when a knock is received for the matching stream ID.
    35  type GRPCServerMuxer struct {
    36  	addr   net.Addr
    37  	logger hclog.Logger
    38  
    39  	sessionErrCh chan error
    40  	sess         *yamux.Session
    41  
    42  	knockCh chan uint32
    43  
    44  	acceptMutex    sync.Mutex
    45  	acceptChannels map[uint32]chan acceptResult
    46  }
    47  
    48  func NewGRPCServerMuxer(logger hclog.Logger, ln net.Listener) *GRPCServerMuxer {
    49  	m := &GRPCServerMuxer{
    50  		addr:   ln.Addr(),
    51  		logger: logger,
    52  
    53  		sessionErrCh: make(chan error),
    54  
    55  		knockCh:        make(chan uint32, 1),
    56  		acceptChannels: make(map[uint32]chan acceptResult),
    57  	}
    58  
    59  	go m.acceptSession(ln)
    60  
    61  	return m
    62  }
    63  
    64  // acceptSessionAndMuxAccept is responsible for establishing the yamux session,
    65  // and then kicking off the acceptLoop function.
    66  func (m *GRPCServerMuxer) acceptSession(ln net.Listener) {
    67  	defer close(m.sessionErrCh)
    68  
    69  	m.logger.Debug("accepting initial connection", "addr", m.addr)
    70  	conn, err := ln.Accept()
    71  	if err != nil {
    72  		m.sessionErrCh <- err
    73  		return
    74  	}
    75  
    76  	m.logger.Debug("initial server connection accepted", "addr", m.addr)
    77  	cfg := yamux.DefaultConfig()
    78  	cfg.Logger = m.logger.Named("yamux").StandardLogger(&hclog.StandardLoggerOptions{
    79  		InferLevels: true,
    80  	})
    81  	cfg.LogOutput = nil
    82  	m.sess, err = yamux.Server(conn, cfg)
    83  	if err != nil {
    84  		m.sessionErrCh <- err
    85  		return
    86  	}
    87  }
    88  
    89  func (m *GRPCServerMuxer) session() (*yamux.Session, error) {
    90  	select {
    91  	case err := <-m.sessionErrCh:
    92  		if err != nil {
    93  			return nil, err
    94  		}
    95  	case <-time.After(5 * time.Second):
    96  		return nil, errors.New("timed out waiting for connection to be established")
    97  	}
    98  
    99  	// Should never happen.
   100  	if m.sess == nil {
   101  		return nil, errors.New("no connection established and no error received")
   102  	}
   103  
   104  	return m.sess, nil
   105  }
   106  
   107  // Accept accepts all incoming connections and routes them to the correct
   108  // stream ID based on the most recent knock received.
   109  func (m *GRPCServerMuxer) Accept() (net.Conn, error) {
   110  	session, err := m.session()
   111  	if err != nil {
   112  		return nil, fmt.Errorf("error establishing yamux session: %w", err)
   113  	}
   114  
   115  	for {
   116  		conn, acceptErr := session.Accept()
   117  
   118  		select {
   119  		case id := <-m.knockCh:
   120  			m.acceptMutex.Lock()
   121  			acceptCh, ok := m.acceptChannels[id]
   122  			m.acceptMutex.Unlock()
   123  
   124  			if !ok {
   125  				if conn != nil {
   126  					_ = conn.Close()
   127  				}
   128  				return nil, fmt.Errorf("received knock on ID %d that doesn't have a listener", id)
   129  			}
   130  			m.logger.Debug("sending conn to brokered listener", "id", id)
   131  			acceptCh <- acceptResult{
   132  				conn: conn,
   133  				err:  acceptErr,
   134  			}
   135  		default:
   136  			m.logger.Debug("sending conn to default listener")
   137  			return conn, acceptErr
   138  		}
   139  	}
   140  }
   141  
   142  func (m *GRPCServerMuxer) Addr() net.Addr {
   143  	return m.addr
   144  }
   145  
   146  func (m *GRPCServerMuxer) Close() error {
   147  	session, err := m.session()
   148  	if err != nil {
   149  		return err
   150  	}
   151  
   152  	return session.Close()
   153  }
   154  
   155  func (m *GRPCServerMuxer) Enabled() bool {
   156  	return m != nil
   157  }
   158  
   159  func (m *GRPCServerMuxer) Listener(id uint32, doneCh <-chan struct{}) (net.Listener, error) {
   160  	sess, err := m.session()
   161  	if err != nil {
   162  		return nil, err
   163  	}
   164  
   165  	ln := newBlockedServerListener(sess.Addr(), doneCh)
   166  	m.acceptMutex.Lock()
   167  	m.acceptChannels[id] = ln.acceptCh
   168  	m.acceptMutex.Unlock()
   169  
   170  	return ln, nil
   171  }
   172  
   173  func (m *GRPCServerMuxer) Dial() (net.Conn, error) {
   174  	sess, err := m.session()
   175  	if err != nil {
   176  		return nil, err
   177  	}
   178  
   179  	stream, err := sess.OpenStream()
   180  	if err != nil {
   181  		return nil, fmt.Errorf("error dialling new server stream: %w", err)
   182  	}
   183  
   184  	return stream, nil
   185  }
   186  
   187  func (m *GRPCServerMuxer) AcceptKnock(id uint32) error {
   188  	m.knockCh <- id
   189  	return nil
   190  }