github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/client/proxy/client_test.go (about)

     1  // Copyright 2023 Gravitational, Inc
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package proxy
    16  
import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/asn1"
	"errors"
	"fmt"
	"io"
	"net"
	"slices"
	"testing"
	"time"

	"github.com/google/go-cmp/cmp"
	"github.com/google/go-cmp/cmp/cmpopts"
	"github.com/gravitational/trace"
	"github.com/gravitational/trace/trail"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"golang.org/x/crypto/ssh"
	"golang.org/x/crypto/ssh/agent"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/grpc/test/bufconn"
	"google.golang.org/protobuf/testing/protocmp"

	"github.com/gravitational/teleport/api/client"
	"github.com/gravitational/teleport/api/client/proto"
	transportv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/transport/v1"
	"github.com/gravitational/teleport/api/utils/grpc/stream"
)
    49  
    50  type fakeGetClusterDetails func(context.Context, *transportv1pb.GetClusterDetailsRequest) (*transportv1pb.GetClusterDetailsResponse, error)
    51  
    52  type fakeProxySSHServer func(transportv1pb.TransportService_ProxySSHServer) error
    53  
    54  type fakeProxyClusterServer func(transportv1pb.TransportService_ProxyClusterServer) error
    55  
    56  // fakeTransportService is a [transportv1pb.TransportServiceServer] implementation
    57  // that allows tests to manipulate the server side of various RPCs.
    58  type fakeTransportService struct {
    59  	transportv1pb.UnimplementedTransportServiceServer
    60  
    61  	details fakeGetClusterDetails
    62  	ssh     fakeProxySSHServer
    63  	cluster fakeProxyClusterServer
    64  }
    65  
    66  func (s fakeTransportService) GetClusterDetails(ctx context.Context, req *transportv1pb.GetClusterDetailsRequest) (*transportv1pb.GetClusterDetailsResponse, error) {
    67  	if s.details == nil {
    68  		return s.UnimplementedTransportServiceServer.GetClusterDetails(ctx, req)
    69  	}
    70  	return s.details(ctx, req)
    71  }
    72  
    73  func (s fakeTransportService) ProxySSH(stream transportv1pb.TransportService_ProxySSHServer) error {
    74  	if s.ssh == nil {
    75  		return s.UnimplementedTransportServiceServer.ProxySSH(stream)
    76  	}
    77  	return s.ssh(stream)
    78  }
    79  
    80  func (s fakeTransportService) ProxyCluster(stream transportv1pb.TransportService_ProxyClusterServer) error {
    81  	if s.cluster == nil {
    82  		return s.UnimplementedTransportServiceServer.ProxyCluster(stream)
    83  	}
    84  	return s.cluster(stream)
    85  }
    86  
    87  // newGRPCServer creates a [grpc.Server] and registers the
    88  // provided [transportv1pb.TransportServiceServer].
    89  func newGRPCServer(t *testing.T, srv transportv1pb.TransportServiceServer) *fakeGRPCServer {
    90  	// gRPC testPack.
    91  	lis := bufconn.Listen(100)
    92  	t.Cleanup(func() { require.NoError(t, lis.Close()) })
    93  
    94  	s := grpc.NewServer()
    95  	t.Cleanup(s.Stop)
    96  
    97  	// Register service.
    98  	if srv != nil {
    99  		transportv1pb.RegisterTransportServiceServer(s, srv)
   100  	}
   101  
   102  	// Start.
   103  	go func() {
   104  		if err := s.Serve(lis); err != nil && !errors.Is(err, grpc.ErrServerStopped) {
   105  			panic(fmt.Sprintf("Serve returned err = %v", err))
   106  		}
   107  	}()
   108  
   109  	return &fakeGRPCServer{Listener: lis}
   110  }
   111  
   112  type fakeGRPCServer struct {
   113  	*bufconn.Listener
   114  }
   115  
   116  type fakeAuthServer struct {
   117  	*proto.UnimplementedAuthServiceServer
   118  	listener net.Listener
   119  	srv      *grpc.Server
   120  }
   121  
   122  func newFakeAuthServer(t *testing.T, conn net.Conn) *fakeAuthServer {
   123  	f := &fakeAuthServer{
   124  		listener:                       newOneShotListener(conn),
   125  		UnimplementedAuthServiceServer: &proto.UnimplementedAuthServiceServer{},
   126  		srv:                            grpc.NewServer(),
   127  	}
   128  
   129  	t.Cleanup(f.Stop)
   130  	proto.RegisterAuthServiceServer(f.srv, f)
   131  	return f
   132  }
   133  
   134  func (f *fakeAuthServer) Ping(context.Context, *proto.PingRequest) (*proto.PingResponse, error) {
   135  	return &proto.PingResponse{
   136  		ClusterName:   "test",
   137  		ServerVersion: "1.0.0",
   138  		IsBoring:      true,
   139  	}, nil
   140  }
   141  
   142  func (f *fakeAuthServer) Serve() error {
   143  	return f.srv.Serve(f.listener)
   144  }
   145  
   146  func (f *fakeAuthServer) Stop() {
   147  	_ = f.listener.Close()
   148  	f.srv.Stop()
   149  }
   150  
   151  type oneShotListener struct {
   152  	conn       net.Conn
   153  	closedCh   chan struct{}
   154  	listenedCh chan struct{}
   155  }
   156  
   157  func newOneShotListener(conn net.Conn) oneShotListener {
   158  	return oneShotListener{
   159  		conn:       conn,
   160  		closedCh:   make(chan struct{}),
   161  		listenedCh: make(chan struct{}),
   162  	}
   163  }
   164  
   165  func (l oneShotListener) Accept() (net.Conn, error) {
   166  	select {
   167  	case <-l.listenedCh:
   168  		<-l.closedCh
   169  		return nil, net.ErrClosed
   170  	default:
   171  		close(l.listenedCh)
   172  		return l.conn, nil
   173  	}
   174  }
   175  
   176  func (l oneShotListener) Close() error {
   177  	select {
   178  	case <-l.closedCh:
   179  	default:
   180  		close(l.closedCh)
   181  	}
   182  
   183  	return nil
   184  }
   185  
   186  func (l oneShotListener) Addr() net.Addr {
   187  	return addr("127.0.0.1")
   188  }
   189  
   190  // addr is a [net.Addr] implementation for static tcp addresses.
   191  type addr string
   192  
   193  func (a addr) Network() string {
   194  	return "tcp"
   195  }
   196  
   197  func (a addr) String() string {
   198  	return string(a)
   199  }
   200  
   201  type fakeProxy struct {
   202  	*fakeGRPCServer
   203  }
   204  
   205  func newFakeProxy(t *testing.T, transportService transportv1pb.TransportServiceServer) *fakeProxy {
   206  	grpcSrv := newGRPCServer(t, transportService)
   207  
   208  	return &fakeProxy{
   209  		fakeGRPCServer: grpcSrv,
   210  	}
   211  }
   212  
   213  func (f *fakeProxy) clientConfig(t *testing.T) ClientConfig {
   214  	return ClientConfig{
   215  		ProxyAddress: "127.0.0.1",
   216  		SSHConfig:    &ssh.ClientConfig{},
   217  		DialOpts: []grpc.DialOption{grpc.WithContextDialer(func(ctx context.Context, s string) (net.Conn, error) {
   218  			return f.fakeGRPCServer.DialContext(ctx)
   219  		})},
   220  	}
   221  }
   222  
   223  func TestNewClient(t *testing.T) {
   224  	t.Parallel()
   225  
   226  	ctx := context.Background()
   227  	tests := []struct {
   228  		name      string
   229  		srv       transportv1pb.TransportServiceServer
   230  		assertion func(t *testing.T, clt *Client, err error)
   231  	}{
   232  		{
   233  			name: "does not implement transport",
   234  			assertion: func(t *testing.T, clt *Client, err error) {
   235  				require.NoError(t, err)
   236  				require.NotNil(t, clt)
   237  
   238  				details, err := clt.transport.ClusterDetails(context.Background())
   239  				require.Error(t, err)
   240  				require.Nil(t, details)
   241  			},
   242  		},
   243  		{
   244  			name: "compliant grpc server",
   245  			srv: fakeTransportService{
   246  				details: func(ctx context.Context, request *transportv1pb.GetClusterDetailsRequest) (*transportv1pb.GetClusterDetailsResponse, error) {
   247  					return &transportv1pb.GetClusterDetailsResponse{Details: &transportv1pb.ClusterDetails{FipsEnabled: true}}, nil
   248  				},
   249  			},
   250  			assertion: func(t *testing.T, clt *Client, err error) {
   251  				require.NoError(t, err)
   252  				require.NotNil(t, clt)
   253  			},
   254  		},
   255  	}
   256  
   257  	for _, test := range tests {
   258  		t.Run(test.name, func(t *testing.T) {
   259  			proxy := newFakeProxy(t, test.srv)
   260  			cfg := proxy.clientConfig(t)
   261  
   262  			clt, err := NewClient(ctx, cfg)
   263  			if clt != nil {
   264  				t.Cleanup(func() { require.NoError(t, clt.Close()) })
   265  			}
   266  			test.assertion(t, clt, err)
   267  		})
   268  	}
   269  }
   270  
   271  func TestClient_ClusterDetails(t *testing.T) {
   272  	t.Parallel()
   273  	ctx := context.Background()
   274  
   275  	tests := []struct {
   276  		name      string
   277  		srv       transportv1pb.TransportServiceServer
   278  		assertion func(t *testing.T, details ClusterDetails, err error)
   279  	}{
   280  		{
   281  			name: "cluster details",
   282  			srv: fakeTransportService{
   283  				details: func(ctx context.Context, request *transportv1pb.GetClusterDetailsRequest) (*transportv1pb.GetClusterDetailsResponse, error) {
   284  					return &transportv1pb.GetClusterDetailsResponse{Details: &transportv1pb.ClusterDetails{FipsEnabled: false}}, nil
   285  				},
   286  			},
   287  			assertion: func(t *testing.T, details ClusterDetails, err error) {
   288  				require.NoError(t, err)
   289  				require.False(t, details.FIPS)
   290  			},
   291  		},
   292  		{
   293  			name: "cluster details fails",
   294  			srv: fakeTransportService{
   295  				details: func() func(ctx context.Context, request *transportv1pb.GetClusterDetailsRequest) (*transportv1pb.GetClusterDetailsResponse, error) {
   296  					return func(ctx context.Context, request *transportv1pb.GetClusterDetailsRequest) (*transportv1pb.GetClusterDetailsResponse, error) {
   297  						return nil, trace.ConnectionProblem(nil, "connection closed")
   298  					}
   299  				}(),
   300  			},
   301  			assertion: func(t *testing.T, details ClusterDetails, err error) {
   302  				require.Error(t, err)
   303  			},
   304  		},
   305  	}
   306  
   307  	for _, test := range tests {
   308  		t.Run(test.name, func(t *testing.T) {
   309  			proxy := newFakeProxy(t, test.srv)
   310  			cfg := proxy.clientConfig(t)
   311  
   312  			cfg.DialOpts = append(cfg.DialOpts, grpc.WithDisableRetry())
   313  
   314  			clt, err := NewClient(ctx, cfg)
   315  			require.NoError(t, err)
   316  			t.Cleanup(func() { require.NoError(t, clt.Close()) })
   317  
   318  			details, err := clt.ClusterDetails(ctx)
   319  			test.assertion(t, details, err)
   320  		})
   321  	}
   322  }
   323  
   324  func TestClient_DialHost(t *testing.T) {
   325  	t.Parallel()
   326  	ctx := context.Background()
   327  
   328  	tests := []struct {
   329  		name      string
   330  		srv       transportv1pb.TransportServiceServer
   331  		keyring   agent.ExtendedAgent
   332  		assertion func(t *testing.T, conn net.Conn, details ClusterDetails, err error)
   333  	}{
   334  		{
   335  			name: "grpc connection fails",
   336  			srv: fakeTransportService{
   337  				details: func(ctx context.Context, request *transportv1pb.GetClusterDetailsRequest) (*transportv1pb.GetClusterDetailsResponse, error) {
   338  					return &transportv1pb.GetClusterDetailsResponse{Details: &transportv1pb.ClusterDetails{FipsEnabled: true}}, nil
   339  				},
   340  				ssh: func(server transportv1pb.TransportService_ProxySSHServer) error {
   341  					_, err := server.Recv()
   342  					if err != nil {
   343  						return trail.ToGRPC(trace.Wrap(err))
   344  					}
   345  
   346  					return trail.ToGRPC(trace.ConnectionProblem(nil, "connection closed"))
   347  				},
   348  			},
   349  			assertion: func(t *testing.T, conn net.Conn, details ClusterDetails, err error) {
   350  				require.ErrorIs(t, err, trace.ConnectionProblem(nil, "connection closed"))
   351  				require.Nil(t, conn)
   352  				require.False(t, details.FIPS)
   353  			},
   354  		},
   355  		{
   356  			name: "grpc connection established",
   357  			srv: fakeTransportService{
   358  				details: func(ctx context.Context, request *transportv1pb.GetClusterDetailsRequest) (*transportv1pb.GetClusterDetailsResponse, error) {
   359  					return &transportv1pb.GetClusterDetailsResponse{Details: &transportv1pb.ClusterDetails{FipsEnabled: true}}, nil
   360  				},
   361  				ssh: func(server transportv1pb.TransportService_ProxySSHServer) error {
   362  					_, err := server.Recv()
   363  					if err != nil {
   364  						return trail.ToGRPC(trace.Wrap(err))
   365  					}
   366  
   367  					if err := server.Send(&transportv1pb.ProxySSHResponse{Details: &transportv1pb.ClusterDetails{FipsEnabled: true}}); err != nil {
   368  						return trail.ToGRPC(err)
   369  					}
   370  
   371  					req, err := server.Recv()
   372  					if err != nil {
   373  						return trail.ToGRPC(trace.Wrap(err))
   374  					}
   375  
   376  					switch f := req.Frame.(type) {
   377  					case *transportv1pb.ProxySSHRequest_Ssh:
   378  						if err := server.Send(&transportv1pb.ProxySSHResponse{
   379  							Details: nil,
   380  							Frame:   &transportv1pb.ProxySSHResponse_Ssh{Ssh: &transportv1pb.Frame{Payload: f.Ssh.Payload}},
   381  						}); err != nil {
   382  							return trail.ToGRPC(trace.Wrap(err))
   383  						}
   384  					default:
   385  						return trace.BadParameter("unexpected frame type received")
   386  					}
   387  
   388  					return nil
   389  				},
   390  			},
   391  			assertion: func(t *testing.T, conn net.Conn, details ClusterDetails, err error) {
   392  				require.NoError(t, err)
   393  				require.NotNil(t, conn)
   394  				require.True(t, details.FIPS)
   395  
   396  				// test that the server echos data back over the connection
   397  				msg := []byte("hello123")
   398  				n, err := conn.Write(msg)
   399  				require.NoError(t, err)
   400  				require.Len(t, msg, n)
   401  
   402  				out := make([]byte, len(msg))
   403  				n, err = conn.Read(out)
   404  				require.NoError(t, err)
   405  				require.Len(t, msg, n)
   406  				require.Equal(t, msg, out)
   407  
   408  				require.NoError(t, conn.Close())
   409  			},
   410  		},
   411  	}
   412  
   413  	for _, test := range tests {
   414  		t.Run(test.name, func(t *testing.T) {
   415  			proxy := newFakeProxy(t, test.srv)
   416  			cfg := proxy.clientConfig(t)
   417  
   418  			clt, err := NewClient(ctx, cfg)
   419  			require.NoError(t, err)
   420  			t.Cleanup(func() { require.NoError(t, clt.Close()) })
   421  
   422  			conn, details, err := clt.DialHost(ctx, "test", "cluster", test.keyring)
   423  			test.assertion(t, conn, details, err)
   424  		})
   425  	}
   426  }
   427  
   428  func TestClient_DialCluster(t *testing.T) {
   429  	t.Parallel()
   430  	ctx := context.Background()
   431  
   432  	tests := []struct {
   433  		name      string
   434  		authCfg   func(config *client.Config)
   435  		srv       transportv1pb.TransportServiceServer
   436  		keyring   agent.ExtendedAgent
   437  		assertion func(t *testing.T, clt *client.Client, err error)
   438  	}{
   439  		{
   440  			name: "grpc connection fails",
   441  			authCfg: func(config *client.Config) {
   442  				config.DialTimeout = 500 * time.Millisecond // speed up dial failure
   443  			},
   444  			srv: fakeTransportService{
   445  				details: func(ctx context.Context, request *transportv1pb.GetClusterDetailsRequest) (*transportv1pb.GetClusterDetailsResponse, error) {
   446  					return &transportv1pb.GetClusterDetailsResponse{Details: &transportv1pb.ClusterDetails{FipsEnabled: true}}, nil
   447  				},
   448  				cluster: func(server transportv1pb.TransportService_ProxyClusterServer) error {
   449  					_, err := server.Recv()
   450  					if err != nil {
   451  						return trace.Wrap(err)
   452  					}
   453  
   454  					return trace.ConnectionProblem(nil, "connection closed")
   455  				},
   456  			},
   457  			assertion: func(t *testing.T, clt *client.Client, err error) {
   458  				require.Error(t, err)
   459  				require.Nil(t, clt)
   460  			},
   461  		},
   462  		{
   463  			name:    "grpc connection established",
   464  			authCfg: func(config *client.Config) {},
   465  			srv: fakeTransportService{
   466  				details: func(ctx context.Context, request *transportv1pb.GetClusterDetailsRequest) (*transportv1pb.GetClusterDetailsResponse, error) {
   467  					return &transportv1pb.GetClusterDetailsResponse{Details: &transportv1pb.ClusterDetails{FipsEnabled: true}}, nil
   468  				},
   469  				cluster: func(server transportv1pb.TransportService_ProxyClusterServer) error {
   470  					_, err := server.Recv()
   471  					if err != nil {
   472  						return trace.Wrap(err)
   473  					}
   474  
   475  					rw, err := stream.NewReadWriter(clusterStream{stream: server})
   476  					if err != nil {
   477  						return trace.Wrap(err)
   478  					}
   479  
   480  					auth := newFakeAuthServer(t, stream.NewConn(rw, nil, nil))
   481  					err = auth.Serve()
   482  					return trace.Wrap(err)
   483  				},
   484  			},
   485  			assertion: func(t *testing.T, clt *client.Client, err error) {
   486  				require.NoError(t, err)
   487  				require.NotNil(t, clt)
   488  
   489  				expected := &proto.PingResponse{
   490  					ClusterName:   "test",
   491  					ServerVersion: "1.0.0",
   492  					IsBoring:      true,
   493  				}
   494  
   495  				resp, err := clt.Ping(ctx)
   496  				require.NoError(t, err)
   497  				require.Empty(t, cmp.Diff(expected, resp, protocmp.Transform()))
   498  			},
   499  		},
   500  	}
   501  
   502  	for _, test := range tests {
   503  		t.Run(test.name, func(t *testing.T) {
   504  			proxy := newFakeProxy(t, test.srv)
   505  			cfg := proxy.clientConfig(t)
   506  
   507  			clt, err := NewClient(ctx, cfg)
   508  			require.NoError(t, err)
   509  			t.Cleanup(func() { require.NoError(t, clt.Close()) })
   510  
   511  			authCfg, err := clt.ClientConfig(ctx, "cluster")
   512  			require.NoError(t, err)
   513  
   514  			authCfg.DialOpts = []grpc.DialOption{
   515  				grpc.WithTransportCredentials(insecure.NewCredentials()),
   516  				grpc.WithReturnConnectionError(),
   517  				grpc.WithDisableRetry(),
   518  				grpc.FailOnNonTempDialError(true),
   519  			}
   520  			authCfg.Credentials = []client.Credentials{insecureCredentials{}}
   521  			authCfg.DialTimeout = 3 * time.Second
   522  
   523  			test.authCfg(&authCfg)
   524  
   525  			authClt, err := client.New(ctx, authCfg)
   526  			if authClt != nil {
   527  				t.Cleanup(func() {
   528  					require.NoError(t, authClt.Close())
   529  				})
   530  			}
   531  			test.assertion(t, authClt, err)
   532  		})
   533  	}
   534  }
   535  
   536  // clusterStream implements the [streamutils.Source] interface
   537  // for a [transportv1pb.TransportService_ProxyClusterServer].
   538  type clusterStream struct {
   539  	stream transportv1pb.TransportService_ProxyClusterServer
   540  }
   541  
   542  func (c clusterStream) Recv() ([]byte, error) {
   543  	req, err := c.stream.Recv()
   544  	if err != nil {
   545  		return nil, trace.Wrap(err)
   546  	}
   547  
   548  	if req.Frame == nil {
   549  		return nil, trace.BadParameter("received invalid frame")
   550  	}
   551  
   552  	return req.Frame.Payload, nil
   553  }
   554  
   555  func (c clusterStream) Send(frame []byte) error {
   556  	return trace.Wrap(c.stream.Send(&transportv1pb.ProxyClusterResponse{Frame: &transportv1pb.Frame{Payload: frame}}))
   557  }
   558  
   559  func TestClient_SSHConfig(t *testing.T) {
   560  	t.Parallel()
   561  
   562  	proxy := newFakeProxy(t, fakeTransportService{})
   563  	cfg := proxy.clientConfig(t)
   564  
   565  	clt, err := NewClient(context.Background(), cfg)
   566  	require.NoError(t, err)
   567  	t.Cleanup(func() { require.NoError(t, clt.Close()) })
   568  
   569  	const user = "test-user"
   570  	sshConfig := clt.SSHConfig(user)
   571  	require.NotNil(t, sshConfig)
   572  	require.Equal(t, user, sshConfig.User)
   573  	require.Empty(t, cmp.Diff(cfg.SSHConfig, sshConfig, cmpopts.IgnoreFields(ssh.ClientConfig{}, "User", "Auth", "HostKeyCallback")))
   574  }
   575  
   576  type fakeTransportCredentials struct {
   577  	credentials.TransportCredentials
   578  	info credentials.AuthInfo
   579  	err  error
   580  }
   581  
   582  type fakeAuthInfo struct{}
   583  
   584  func (f fakeAuthInfo) AuthType() string {
   585  	return "test"
   586  }
   587  
   588  func (t fakeTransportCredentials) ClientHandshake(ctx context.Context, addr string, conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
   589  	return conn, t.info, t.err
   590  }
   591  
   592  func TestClusterCredentials(t *testing.T) {
   593  	t.Parallel()
   594  
   595  	cases := []struct {
   596  		name                string
   597  		expectedClusterName string
   598  		credentials         fakeTransportCredentials
   599  		errAssertion        require.ErrorAssertionFunc
   600  	}{
   601  		{
   602  			name:         "handshake error",
   603  			credentials:  fakeTransportCredentials{err: context.Canceled},
   604  			errAssertion: require.Error,
   605  		},
   606  		{
   607  			name:         "no tls auth info",
   608  			credentials:  fakeTransportCredentials{info: fakeAuthInfo{}},
   609  			errAssertion: require.NoError,
   610  		},
   611  		{
   612  			name:         "no server cert",
   613  			credentials:  fakeTransportCredentials{info: credentials.TLSInfo{}},
   614  			errAssertion: require.NoError,
   615  		},
   616  		{
   617  			name: "no cluster oid set",
   618  			credentials: fakeTransportCredentials{info: credentials.TLSInfo{
   619  				State: tls.ConnectionState{
   620  					PeerCertificates: []*x509.Certificate{
   621  						{
   622  							Subject: pkix.Name{
   623  								Names: []pkix.AttributeTypeAndValue{
   624  									{
   625  										Type: asn1.ObjectIdentifier{1, 3, 9999, 0, 1},
   626  									},
   627  									{
   628  										Type: asn1.ObjectIdentifier{1, 3, 9999, 2, 1},
   629  									},
   630  									{
   631  										Type: asn1.ObjectIdentifier{1, 3, 9999, 0, 2},
   632  									},
   633  									{
   634  										Type: asn1.ObjectIdentifier{1, 3, 9999, 2, 2},
   635  									},
   636  								},
   637  							},
   638  						},
   639  					},
   640  				},
   641  			}},
   642  			errAssertion: require.NoError,
   643  		}, {
   644  			name:                "cluster name presented",
   645  			expectedClusterName: "test-cluster",
   646  			credentials: fakeTransportCredentials{info: credentials.TLSInfo{
   647  				State: tls.ConnectionState{
   648  					PeerCertificates: []*x509.Certificate{
   649  						{
   650  							Subject: pkix.Name{
   651  								Names: []pkix.AttributeTypeAndValue{
   652  									{
   653  										Type: asn1.ObjectIdentifier{1, 3, 9999, 2, 1},
   654  									},
   655  									{
   656  										Type: asn1.ObjectIdentifier{1, 3, 9999, 0, 2},
   657  									},
   658  									{
   659  										Type: asn1.ObjectIdentifier{1, 3, 9999, 2, 2},
   660  									},
   661  									{
   662  										Type:  teleportClusterASN1ExtensionOID,
   663  										Value: "test-cluster",
   664  									},
   665  								},
   666  							},
   667  						},
   668  					},
   669  				},
   670  			}},
   671  			errAssertion: require.NoError,
   672  		},
   673  	}
   674  
   675  	for _, test := range cases {
   676  		t.Run(test.name, func(t *testing.T) {
   677  			c := &clusterName{}
   678  			creds := clusterCredentials{TransportCredentials: test.credentials, clusterName: c}
   679  			_, _, err := creds.ClientHandshake(context.Background(), "127.0.0.1", nil)
   680  			test.errAssertion(t, err)
   681  			require.Equal(t, test.expectedClusterName, c.get())
   682  		})
   683  	}
   684  }
   685  
   686  func TestNewDialerForGRPCClient(t *testing.T) {
   687  	t.Run("Check that PROXYHeaderGetter if present sends PROXY header as first bytes on the connection", func(t *testing.T) {
   688  		listener, err := net.Listen("tcp", "127.0.0.1:0")
   689  		require.NoError(t, err)
   690  		t.Cleanup(func() {
   691  			require.NoError(t, listener.Close())
   692  		})
   693  
   694  		prefix := []byte("FAKEPROXY")
   695  		proxyHeaderGetter := func() ([]byte, error) {
   696  			return prefix, nil
   697  		}
   698  
   699  		ctx := context.Background()
   700  		cfg := &ClientConfig{
   701  			PROXYHeaderGetter: proxyHeaderGetter,
   702  		}
   703  		dialer := newDialerForGRPCClient(ctx, cfg)
   704  
   705  		resultChan := make(chan bool)
   706  		// Start listening, emulating receiving end of connection
   707  		go func() {
   708  			conn, err := listener.Accept()
   709  			if err != nil {
   710  				assert.Fail(t, err.Error())
   711  				return
   712  			}
   713  
   714  			buf := make([]byte, len(prefix))
   715  			_, err = conn.Read(buf)
   716  			assert.NoError(t, err)
   717  			t.Cleanup(func() {
   718  				require.NoError(t, conn.Close())
   719  			})
   720  
   721  			// On the received connection first bytes should be our PROXY prefix
   722  			resultChan <- slices.Equal(buf, prefix)
   723  		}()
   724  
   725  		conn, err := dialer(ctx, listener.Addr().String())
   726  		require.NoError(t, err)
   727  		t.Cleanup(func() {
   728  			require.NoError(t, conn.Close())
   729  		})
   730  
   731  		select {
   732  		case res := <-resultChan:
   733  			require.True(t, res, "Didn't receive required prefix as first bytes on the connection")
   734  		case <-time.After(time.Second):
   735  			require.Fail(t, "Timed out waiting for connection")
   736  		}
   737  	})
   738  }