github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/client/proxy/transport/transportv1/client.go (about) 1 // Copyright 2023 Gravitational, Inc 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package transportv1 16 17 import ( 18 "context" 19 "net" 20 "sync" 21 22 "github.com/gravitational/trace" 23 "golang.org/x/crypto/ssh/agent" 24 "google.golang.org/grpc/peer" 25 26 transportv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/transport/v1" 27 streamutils "github.com/gravitational/teleport/api/utils/grpc/stream" 28 ) 29 30 // Client is a wrapper around a [transportv1.TransportServiceClient] that 31 // hides the implementation details of establishing connections 32 // over gRPC streams. 33 type Client struct { 34 clt transportv1pb.TransportServiceClient 35 } 36 37 // NewClient constructs a Client that operates on the provided 38 // [transportv1pb.TransportServiceClient]. An error is returned if the client 39 // provided is nil. 40 func NewClient(client transportv1pb.TransportServiceClient) (*Client, error) { 41 if client == nil { 42 return nil, trace.BadParameter("parameter client required") 43 } 44 45 return &Client{clt: client}, nil 46 } 47 48 // ClusterDetails retrieves the cluster details as observed by the Teleport Proxy 49 // that the Client is connected to. 
func (c *Client) ClusterDetails(ctx context.Context) (*transportv1pb.ClusterDetails, error) {
	req := &transportv1pb.GetClusterDetailsRequest{}
	resp, err := c.clt.GetClusterDetails(ctx, req)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	return resp.Details, nil
}

// DialCluster establishes a connection to the provided cluster. The provided
// src address will be used as the LocalAddr of the returned [net.Conn].
//
// ctx only bounds the dial itself: once DialCluster has returned successfully,
// canceling ctx no longer affects the returned connection.
func (c *Client) DialCluster(ctx context.Context, cluster string, src net.Addr) (net.Conn, error) {
	// Detach the stream context from ctx's cancellation rather than starting
	// from context.Background, so that values carried by the dial context
	// (e.g. OTEL data) are inherited. While dialing is in progress, a
	// cancellation of ctx is still forwarded to the stream via AfterFunc.
	streamCtx, cancelStream := context.WithCancel(context.WithoutCancel(ctx))
	detach := context.AfterFunc(ctx, cancelStream)
	defer detach()

	// abort releases the stream context and reports err annotated with msg.
	abort := func(err error, msg string) (net.Conn, error) {
		cancelStream()
		return nil, trace.Wrap(err, msg)
	}

	stream, err := c.clt.ProxyCluster(streamCtx)
	if err != nil {
		return abort(err, "unable to establish proxy stream")
	}

	if err := stream.Send(&transportv1pb.ProxyClusterRequest{Cluster: cluster}); err != nil {
		return abort(err, "failed to send cluster request")
	}

	// A false result means ctx was canceled during the dial and AfterFunc has
	// already started canceling the stream. cancelStream is called here before
	// reading streamCtx.Err() so that the reported error is never nil.
	if !detach() {
		cancelStream()
		return nil, trace.Wrap(streamCtx.Err(), "unable to establish proxy stream")
	}

	rw, err := streamutils.NewReadWriter(clusterStream{stream: stream, cancel: cancelStream})
	if err != nil {
		return abort(err, "unable to create stream reader")
	}

	remote, ok := peer.FromContext(stream.Context())
	if !ok {
		// presumably closing rw closes the clusterStream source, which cancels
		// the stream context — confirm against streamutils.
		rw.Close()
		return nil, trace.BadParameter("unable to retrieve peer information")
	}

	return streamutils.NewConn(rw, src, remote.Addr), nil
}

// clusterStream implements the [streamutils.Source] interface
// for a [transportv1pb.TransportService_ProxyClusterClient].
101 type clusterStream struct { 102 stream transportv1pb.TransportService_ProxyClusterClient 103 cancel context.CancelFunc 104 } 105 106 func (c clusterStream) Recv() ([]byte, error) { 107 req, err := c.stream.Recv() 108 if err != nil { 109 return nil, trace.Wrap(err) 110 } 111 112 if req.Frame == nil { 113 return nil, trace.BadParameter("received invalid frame") 114 } 115 116 return req.Frame.Payload, nil 117 } 118 119 func (c clusterStream) Send(frame []byte) error { 120 return trace.Wrap(c.stream.Send(&transportv1pb.ProxyClusterRequest{Frame: &transportv1pb.Frame{Payload: frame}})) 121 } 122 123 func (c clusterStream) Close() error { 124 if c.cancel != nil { 125 c.cancel() 126 } 127 return nil 128 } 129 130 // DialHost establishes a connection to the instance in the provided cluster that matches 131 // the hostport. If a keyring is provided then it will be forwarded to the remote instance. 132 // The src address will be used as the LocalAddr of the returned [net.Conn]. 133 func (c *Client) DialHost(ctx context.Context, hostport, cluster string, src net.Addr, keyring agent.ExtendedAgent) (net.Conn, *transportv1pb.ClusterDetails, error) { 134 ctx, cancel := context.WithCancel(ctx) 135 stream, err := c.clt.ProxySSH(ctx) 136 if err != nil { 137 cancel() 138 return nil, nil, trace.Wrap(err, "unable to establish proxy stream") 139 } 140 141 if err := stream.Send(&transportv1pb.ProxySSHRequest{DialTarget: &transportv1pb.TargetHost{ 142 HostPort: hostport, 143 Cluster: cluster, 144 }}); err != nil { 145 cancel() 146 return nil, nil, trace.Wrap(err, "failed to send dial target request") 147 } 148 149 resp, err := stream.Recv() 150 if err != nil { 151 cancel() 152 return nil, nil, trace.Wrap(err, "failed to receive cluster details response") 153 } 154 155 // create streams for ssh and agent protocol 156 sshStream, agentStream := newSSHStreams(stream, cancel) 157 158 // create a reader writer for agent protocol 159 agentRW, err := streamutils.NewReadWriter(agentStream) 160 
if err != nil { 161 return nil, nil, trace.Wrap(err) 162 } 163 164 // create a reader writer for SSH protocol 165 sshRW, err := streamutils.NewReadWriter(sshStream) 166 if err != nil { 167 return nil, nil, trace.Wrap(err) 168 } 169 170 sshConn := streamutils.NewConn(sshRW, src, addr(hostport)) 171 172 // multiplex the frames to the correct handlers 173 var serveOnce sync.Once 174 go func() { 175 defer func() { 176 // closing the agentRW will terminate the agent.ServeAgent goroutine 177 agentRW.Close() 178 // closing the connection will close sshRW and end the connection for 179 // the user 180 sshConn.Close() 181 }() 182 183 for { 184 req, err := stream.Recv() 185 if err != nil { 186 sshStream.errorC <- trace.Wrap(err) 187 agentStream.errorC <- trace.Wrap(err) 188 return 189 } 190 191 switch frame := req.Frame.(type) { 192 case *transportv1pb.ProxySSHResponse_Ssh: 193 sshStream.incomingC <- frame.Ssh.Payload 194 case *transportv1pb.ProxySSHResponse_Agent: 195 if keyring == nil { 196 continue 197 } 198 199 // start serving the agent only if the upstream 200 // service attempts to interact with it 201 serveOnce.Do(func() { 202 go agent.ServeAgent(keyring, agentRW) 203 }) 204 205 agentStream.incomingC <- frame.Agent.Payload 206 default: 207 continue 208 } 209 } 210 }() 211 212 return sshConn, resp.Details, nil 213 } 214 215 type addr string 216 217 func (a addr) Network() string { 218 return "tcp" 219 } 220 221 func (a addr) String() string { 222 return string(a) 223 } 224 225 // sshStream implements the [streamutils.Source] interface 226 // for a [transportv1pb.TransportService_ProxySSHClient]. Instead of 227 // reading directly from the stream reads are from an incoming 228 // channel that is fed by the multiplexer. 
229 type sshStream struct { 230 incomingC chan []byte 231 errorC chan error 232 requestFn func(payload []byte) *transportv1pb.ProxySSHRequest 233 closedC chan struct{} 234 wLock *sync.Mutex 235 stream transportv1pb.TransportService_ProxySSHClient 236 cancel context.CancelFunc 237 } 238 239 func newSSHStreams(stream transportv1pb.TransportService_ProxySSHClient, cancel context.CancelFunc) (ssh *sshStream, agent *sshStream) { 240 wLock := &sync.Mutex{} 241 closedC := make(chan struct{}) 242 243 ssh = &sshStream{ 244 incomingC: make(chan []byte, 10), 245 errorC: make(chan error, 1), 246 stream: stream, 247 requestFn: func(payload []byte) *transportv1pb.ProxySSHRequest { 248 return &transportv1pb.ProxySSHRequest{Frame: &transportv1pb.ProxySSHRequest_Ssh{Ssh: &transportv1pb.Frame{Payload: payload}}} 249 }, 250 wLock: wLock, 251 closedC: closedC, 252 cancel: cancel, 253 } 254 255 agent = &sshStream{ 256 incomingC: make(chan []byte, 10), 257 errorC: make(chan error, 1), 258 stream: stream, 259 requestFn: func(payload []byte) *transportv1pb.ProxySSHRequest { 260 return &transportv1pb.ProxySSHRequest{Frame: &transportv1pb.ProxySSHRequest_Agent{Agent: &transportv1pb.Frame{Payload: payload}}} 261 }, 262 wLock: wLock, 263 closedC: closedC, 264 cancel: cancel, 265 } 266 267 return ssh, agent 268 } 269 270 func (s *sshStream) Recv() ([]byte, error) { 271 select { 272 case err := <-s.errorC: 273 return nil, trace.Wrap(err) 274 case frame := <-s.incomingC: 275 return frame, nil 276 } 277 } 278 279 func (s *sshStream) Send(frame []byte) error { 280 // grab lock to prevent any other sends from occurring 281 s.wLock.Lock() 282 defer s.wLock.Unlock() 283 284 // only Send if the stream hasn't already been closed 285 select { 286 case <-s.closedC: 287 return nil 288 default: 289 return trace.Wrap(s.stream.Send(s.requestFn(frame))) 290 } 291 } 292 293 func (s *sshStream) Close() error { 294 s.cancel() 295 // grab lock to prevent any sends from occurring 296 s.wLock.Lock() 297 defer 
s.wLock.Unlock() 298 299 // only CloseSend if the stream hasn't already been closed 300 select { 301 case <-s.closedC: 302 return nil 303 default: 304 close(s.closedC) 305 return trace.Wrap(s.stream.CloseSend()) 306 } 307 }