github.com/hashicorp/go-plugin@v1.6.0/internal/grpcmux/grpc_server_muxer.go (about) 1 // Copyright (c) HashiCorp, Inc. 2 // SPDX-License-Identifier: MPL-2.0 3 4 package grpcmux 5 6 import ( 7 "errors" 8 "fmt" 9 "net" 10 "sync" 11 "time" 12 13 "github.com/hashicorp/go-hclog" 14 "github.com/hashicorp/yamux" 15 ) 16 17 var _ GRPCMuxer = (*GRPCServerMuxer)(nil) 18 var _ net.Listener = (*GRPCServerMuxer)(nil) 19 20 // GRPCServerMuxer implements the server (plugin) side of the gRPC broker's 21 // GRPCMuxer interface for multiplexing multiple gRPC broker connections over 22 // a single net.Conn. 23 // 24 // The server side needs a listener to serve the gRPC broker's control services, 25 // which includes the service we will receive knocks on. That means we always 26 // accept the first connection onto a "default" main listener, and if we accept 27 // any further connections without receiving a knock first, they are also given 28 // to the default listener. 29 // 30 // When creating additional multiplexed listeners for specific stream IDs, we 31 // can't control the order in which gRPC servers will call Accept() on each 32 // listener, but we do need to control which gRPC server accepts which connection. 33 // As such, each multiplexed listener blocks waiting on a channel. It will be 34 // unblocked when a knock is received for the matching stream ID. 35 type GRPCServerMuxer struct { 36 addr net.Addr 37 logger hclog.Logger 38 39 sessionErrCh chan error 40 sess *yamux.Session 41 42 knockCh chan uint32 43 44 acceptMutex sync.Mutex 45 acceptChannels map[uint32]chan acceptResult 46 } 47 48 func NewGRPCServerMuxer(logger hclog.Logger, ln net.Listener) *GRPCServerMuxer { 49 m := &GRPCServerMuxer{ 50 addr: ln.Addr(), 51 logger: logger, 52 53 sessionErrCh: make(chan error), 54 55 knockCh: make(chan uint32, 1), 56 acceptChannels: make(map[uint32]chan acceptResult), 57 } 58 59 go m.acceptSession(ln) 60 61 return m 62 } 63 64 // acceptSessionAndMuxAccept is responsible for establishing the yamux session, 65 // and then kicking off the acceptLoop function. 66 func (m *GRPCServerMuxer) acceptSession(ln net.Listener) { 67 defer close(m.sessionErrCh) 68 69 m.logger.Debug("accepting initial connection", "addr", m.addr) 70 conn, err := ln.Accept() 71 if err != nil { 72 m.sessionErrCh <- err 73 return 74 } 75 76 m.logger.Debug("initial server connection accepted", "addr", m.addr) 77 cfg := yamux.DefaultConfig() 78 cfg.Logger = m.logger.Named("yamux").StandardLogger(&hclog.StandardLoggerOptions{ 79 InferLevels: true, 80 }) 81 cfg.LogOutput = nil 82 m.sess, err = yamux.Server(conn, cfg) 83 if err != nil { 84 m.sessionErrCh <- err 85 return 86 } 87 } 88 89 func (m *GRPCServerMuxer) session() (*yamux.Session, error) { 90 select { 91 case err := <-m.sessionErrCh: 92 if err != nil { 93 return nil, err 94 } 95 case <-time.After(5 * time.Second): 96 return nil, errors.New("timed out waiting for connection to be established") 97 } 98 99 // Should never happen. 100 if m.sess == nil { 101 return nil, errors.New("no connection established and no error received") 102 } 103 104 return m.sess, nil 105 } 106 107 // Accept accepts all incoming connections and routes them to the correct 108 // stream ID based on the most recent knock received. 109 func (m *GRPCServerMuxer) Accept() (net.Conn, error) { 110 session, err := m.session() 111 if err != nil { 112 return nil, fmt.Errorf("error establishing yamux session: %w", err) 113 } 114 115 for { 116 conn, acceptErr := session.Accept() 117 118 select { 119 case id := <-m.knockCh: 120 m.acceptMutex.Lock() 121 acceptCh, ok := m.acceptChannels[id] 122 m.acceptMutex.Unlock() 123 124 if !ok { 125 if conn != nil { 126 _ = conn.Close() 127 } 128 return nil, fmt.Errorf("received knock on ID %d that doesn't have a listener", id) 129 } 130 m.logger.Debug("sending conn to brokered listener", "id", id) 131 acceptCh <- acceptResult{ 132 conn: conn, 133 err: acceptErr, 134 } 135 default: 136 m.logger.Debug("sending conn to default listener") 137 return conn, acceptErr 138 } 139 } 140 } 141 142 func (m *GRPCServerMuxer) Addr() net.Addr { 143 return m.addr 144 } 145 146 func (m *GRPCServerMuxer) Close() error { 147 session, err := m.session() 148 if err != nil { 149 return err 150 } 151 152 return session.Close() 153 } 154 155 func (m *GRPCServerMuxer) Enabled() bool { 156 return m != nil 157 } 158 159 func (m *GRPCServerMuxer) Listener(id uint32, doneCh <-chan struct{}) (net.Listener, error) { 160 sess, err := m.session() 161 if err != nil { 162 return nil, err 163 } 164 165 ln := newBlockedServerListener(sess.Addr(), doneCh) 166 m.acceptMutex.Lock() 167 m.acceptChannels[id] = ln.acceptCh 168 m.acceptMutex.Unlock() 169 170 return ln, nil 171 } 172 173 func (m *GRPCServerMuxer) Dial() (net.Conn, error) { 174 sess, err := m.session() 175 if err != nil { 176 return nil, err 177 } 178 179 stream, err := sess.OpenStream() 180 if err != nil { 181 return nil, fmt.Errorf("error dialling new server stream: %w", err) 182 } 183 184 return stream, nil 185 } 186 187 func (m *GRPCServerMuxer) AcceptKnock(id uint32) error { 188 m.knockCh <- id 189 return nil 190 }