github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/client/proxy/transport/transportv1/client.go (about)

     1  // Copyright 2023 Gravitational, Inc
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package transportv1
    16  
    17  import (
    18  	"context"
    19  	"net"
    20  	"sync"
    21  
    22  	"github.com/gravitational/trace"
    23  	"golang.org/x/crypto/ssh/agent"
    24  	"google.golang.org/grpc/peer"
    25  
    26  	transportv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/transport/v1"
    27  	streamutils "github.com/gravitational/teleport/api/utils/grpc/stream"
    28  )
    29  
// Client is a wrapper around a [transportv1pb.TransportServiceClient] that
// hides the implementation details of establishing connections
// over gRPC streams.
type Client struct {
	// clt is the gRPC client that the methods of Client issue their RPCs on.
	clt transportv1pb.TransportServiceClient
}
    36  
    37  // NewClient constructs a Client that operates on the provided
    38  // [transportv1pb.TransportServiceClient]. An error is returned if the client
    39  // provided is nil.
    40  func NewClient(client transportv1pb.TransportServiceClient) (*Client, error) {
    41  	if client == nil {
    42  		return nil, trace.BadParameter("parameter client required")
    43  	}
    44  
    45  	return &Client{clt: client}, nil
    46  }
    47  
    48  // ClusterDetails retrieves the cluster details as observed by the Teleport Proxy
    49  // that the Client is connected to.
    50  func (c *Client) ClusterDetails(ctx context.Context) (*transportv1pb.ClusterDetails, error) {
    51  	resp, err := c.clt.GetClusterDetails(ctx, &transportv1pb.GetClusterDetailsRequest{})
    52  	if err != nil {
    53  		return nil, trace.Wrap(err)
    54  	}
    55  
    56  	return resp.Details, nil
    57  }
    58  
    59  // DialCluster establishes a connection to the provided cluster. The provided
    60  // src address will be used as the LocalAddr of the returned [net.Conn].
    61  func (c *Client) DialCluster(ctx context.Context, cluster string, src net.Addr) (net.Conn, error) {
    62  	// we do this rather than using context.Background to inherit any OTEL data
    63  	// from the dial context
    64  	connCtx, cancel := context.WithCancel(context.WithoutCancel(ctx))
    65  	stop := context.AfterFunc(ctx, cancel)
    66  	defer stop()
    67  
    68  	stream, err := c.clt.ProxyCluster(connCtx)
    69  	if err != nil {
    70  		cancel()
    71  		return nil, trace.Wrap(err, "unable to establish proxy stream")
    72  	}
    73  
    74  	if err := stream.Send(&transportv1pb.ProxyClusterRequest{Cluster: cluster}); err != nil {
    75  		cancel()
    76  		return nil, trace.Wrap(err, "failed to send cluster request")
    77  	}
    78  
    79  	if !stop() {
    80  		cancel()
    81  		return nil, trace.Wrap(connCtx.Err(), "unable to establish proxy stream")
    82  	}
    83  
    84  	streamRW, err := streamutils.NewReadWriter(clusterStream{stream: stream, cancel: cancel})
    85  	if err != nil {
    86  		cancel()
    87  		return nil, trace.Wrap(err, "unable to create stream reader")
    88  	}
    89  
    90  	p, ok := peer.FromContext(stream.Context())
    91  	if !ok {
    92  		streamRW.Close()
    93  		return nil, trace.BadParameter("unable to retrieve peer information")
    94  	}
    95  
    96  	return streamutils.NewConn(streamRW, src, p.Addr), nil
    97  }
    98  
// clusterStream implements the [streamutils.Source] interface
// for a [transportv1pb.TransportService_ProxyClusterClient].
type clusterStream struct {
	// stream is the gRPC stream that frames are read from and written to.
	stream transportv1pb.TransportService_ProxyClusterClient

	// cancel, when non-nil, is invoked by Close to cancel the context the
	// stream was opened with.
	cancel context.CancelFunc
}
   105  
   106  func (c clusterStream) Recv() ([]byte, error) {
   107  	req, err := c.stream.Recv()
   108  	if err != nil {
   109  		return nil, trace.Wrap(err)
   110  	}
   111  
   112  	if req.Frame == nil {
   113  		return nil, trace.BadParameter("received invalid frame")
   114  	}
   115  
   116  	return req.Frame.Payload, nil
   117  }
   118  
   119  func (c clusterStream) Send(frame []byte) error {
   120  	return trace.Wrap(c.stream.Send(&transportv1pb.ProxyClusterRequest{Frame: &transportv1pb.Frame{Payload: frame}}))
   121  }
   122  
   123  func (c clusterStream) Close() error {
   124  	if c.cancel != nil {
   125  		c.cancel()
   126  	}
   127  	return nil
   128  }
   129  
   130  // DialHost establishes a connection to the instance in the provided cluster that matches
   131  // the hostport. If a keyring is provided then it will be forwarded to the remote instance.
   132  // The src address will be used as the LocalAddr of the returned [net.Conn].
   133  func (c *Client) DialHost(ctx context.Context, hostport, cluster string, src net.Addr, keyring agent.ExtendedAgent) (net.Conn, *transportv1pb.ClusterDetails, error) {
   134  	ctx, cancel := context.WithCancel(ctx)
   135  	stream, err := c.clt.ProxySSH(ctx)
   136  	if err != nil {
   137  		cancel()
   138  		return nil, nil, trace.Wrap(err, "unable to establish proxy stream")
   139  	}
   140  
   141  	if err := stream.Send(&transportv1pb.ProxySSHRequest{DialTarget: &transportv1pb.TargetHost{
   142  		HostPort: hostport,
   143  		Cluster:  cluster,
   144  	}}); err != nil {
   145  		cancel()
   146  		return nil, nil, trace.Wrap(err, "failed to send dial target request")
   147  	}
   148  
   149  	resp, err := stream.Recv()
   150  	if err != nil {
   151  		cancel()
   152  		return nil, nil, trace.Wrap(err, "failed to receive cluster details response")
   153  	}
   154  
   155  	// create streams for ssh and agent protocol
   156  	sshStream, agentStream := newSSHStreams(stream, cancel)
   157  
   158  	// create a reader writer for agent protocol
   159  	agentRW, err := streamutils.NewReadWriter(agentStream)
   160  	if err != nil {
   161  		return nil, nil, trace.Wrap(err)
   162  	}
   163  
   164  	// create a reader writer for SSH protocol
   165  	sshRW, err := streamutils.NewReadWriter(sshStream)
   166  	if err != nil {
   167  		return nil, nil, trace.Wrap(err)
   168  	}
   169  
   170  	sshConn := streamutils.NewConn(sshRW, src, addr(hostport))
   171  
   172  	// multiplex the frames to the correct handlers
   173  	var serveOnce sync.Once
   174  	go func() {
   175  		defer func() {
   176  			// closing the agentRW will terminate the agent.ServeAgent goroutine
   177  			agentRW.Close()
   178  			// closing the connection will close sshRW and end the connection for
   179  			// the user
   180  			sshConn.Close()
   181  		}()
   182  
   183  		for {
   184  			req, err := stream.Recv()
   185  			if err != nil {
   186  				sshStream.errorC <- trace.Wrap(err)
   187  				agentStream.errorC <- trace.Wrap(err)
   188  				return
   189  			}
   190  
   191  			switch frame := req.Frame.(type) {
   192  			case *transportv1pb.ProxySSHResponse_Ssh:
   193  				sshStream.incomingC <- frame.Ssh.Payload
   194  			case *transportv1pb.ProxySSHResponse_Agent:
   195  				if keyring == nil {
   196  					continue
   197  				}
   198  
   199  				// start serving the agent only if the upstream
   200  				// service attempts to interact with it
   201  				serveOnce.Do(func() {
   202  					go agent.ServeAgent(keyring, agentRW)
   203  				})
   204  
   205  				agentStream.incomingC <- frame.Agent.Payload
   206  			default:
   207  				continue
   208  			}
   209  		}
   210  	}()
   211  
   212  	return sshConn, resp.Details, nil
   213  }
   214  
   215  type addr string
   216  
   217  func (a addr) Network() string {
   218  	return "tcp"
   219  }
   220  
   221  func (a addr) String() string {
   222  	return string(a)
   223  }
   224  
// sshStream implements the [streamutils.Source] interface
// for a [transportv1pb.TransportService_ProxySSHClient]. Instead of
// reading directly from the stream, reads are served from an incoming
// channel that is fed by the multiplexer in DialHost.
type sshStream struct {
	// incomingC holds payloads the multiplexer routed to this stream.
	incomingC chan []byte

	// errorC receives the error that terminated the underlying stream.
	errorC chan error

	// requestFn wraps an outgoing payload in the request frame type for this
	// stream (SSH or agent).
	requestFn func(payload []byte) *transportv1pb.ProxySSHRequest

	// closedC is closed by Close. It is shared by the SSH and agent streams
	// created together by newSSHStreams.
	closedC chan struct{}

	// wLock serializes Send and Close. It is shared with the sibling stream
	// because both write to the same gRPC stream.
	wLock *sync.Mutex

	// stream is the gRPC stream shared by the SSH and agent streams.
	stream transportv1pb.TransportService_ProxySSHClient

	// cancel is invoked by Close to cancel the shared stream's context.
	cancel context.CancelFunc
}
   238  
   239  func newSSHStreams(stream transportv1pb.TransportService_ProxySSHClient, cancel context.CancelFunc) (ssh *sshStream, agent *sshStream) {
   240  	wLock := &sync.Mutex{}
   241  	closedC := make(chan struct{})
   242  
   243  	ssh = &sshStream{
   244  		incomingC: make(chan []byte, 10),
   245  		errorC:    make(chan error, 1),
   246  		stream:    stream,
   247  		requestFn: func(payload []byte) *transportv1pb.ProxySSHRequest {
   248  			return &transportv1pb.ProxySSHRequest{Frame: &transportv1pb.ProxySSHRequest_Ssh{Ssh: &transportv1pb.Frame{Payload: payload}}}
   249  		},
   250  		wLock:   wLock,
   251  		closedC: closedC,
   252  		cancel:  cancel,
   253  	}
   254  
   255  	agent = &sshStream{
   256  		incomingC: make(chan []byte, 10),
   257  		errorC:    make(chan error, 1),
   258  		stream:    stream,
   259  		requestFn: func(payload []byte) *transportv1pb.ProxySSHRequest {
   260  			return &transportv1pb.ProxySSHRequest{Frame: &transportv1pb.ProxySSHRequest_Agent{Agent: &transportv1pb.Frame{Payload: payload}}}
   261  		},
   262  		wLock:   wLock,
   263  		closedC: closedC,
   264  		cancel:  cancel,
   265  	}
   266  
   267  	return ssh, agent
   268  }
   269  
   270  func (s *sshStream) Recv() ([]byte, error) {
   271  	select {
   272  	case err := <-s.errorC:
   273  		return nil, trace.Wrap(err)
   274  	case frame := <-s.incomingC:
   275  		return frame, nil
   276  	}
   277  }
   278  
   279  func (s *sshStream) Send(frame []byte) error {
   280  	// grab lock to prevent any other sends from occurring
   281  	s.wLock.Lock()
   282  	defer s.wLock.Unlock()
   283  
   284  	// only Send if the stream hasn't already been closed
   285  	select {
   286  	case <-s.closedC:
   287  		return nil
   288  	default:
   289  		return trace.Wrap(s.stream.Send(s.requestFn(frame)))
   290  	}
   291  }
   292  
   293  func (s *sshStream) Close() error {
   294  	s.cancel()
   295  	// grab lock to prevent any sends from occurring
   296  	s.wLock.Lock()
   297  	defer s.wLock.Unlock()
   298  
   299  	// only CloseSend if the stream hasn't already been closed
   300  	select {
   301  	case <-s.closedC:
   302  		return nil
   303  	default:
   304  		close(s.closedC)
   305  		return trace.Wrap(s.stream.CloseSend())
   306  	}
   307  }