github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/client/proxy/client_test.go (about) 1 // Copyright 2023 Gravitational, Inc 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package proxy 16 17 import ( 18 "context" 19 "crypto/tls" 20 "crypto/x509" 21 "crypto/x509/pkix" 22 "encoding/asn1" 23 "errors" 24 "fmt" 25 "net" 26 "slices" 27 "testing" 28 "time" 29 30 "github.com/google/go-cmp/cmp" 31 "github.com/google/go-cmp/cmp/cmpopts" 32 "github.com/gravitational/trace" 33 "github.com/gravitational/trace/trail" 34 "github.com/stretchr/testify/assert" 35 "github.com/stretchr/testify/require" 36 "golang.org/x/crypto/ssh" 37 "golang.org/x/crypto/ssh/agent" 38 "google.golang.org/grpc" 39 "google.golang.org/grpc/credentials" 40 "google.golang.org/grpc/credentials/insecure" 41 "google.golang.org/grpc/test/bufconn" 42 "google.golang.org/protobuf/testing/protocmp" 43 44 "github.com/gravitational/teleport/api/client" 45 "github.com/gravitational/teleport/api/client/proto" 46 transportv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/transport/v1" 47 "github.com/gravitational/teleport/api/utils/grpc/stream" 48 ) 49 50 type fakeGetClusterDetails func(context.Context, *transportv1pb.GetClusterDetailsRequest) (*transportv1pb.GetClusterDetailsResponse, error) 51 52 type fakeProxySSHServer func(transportv1pb.TransportService_ProxySSHServer) error 53 54 type fakeProxyClusterServer func(transportv1pb.TransportService_ProxyClusterServer) error 55 56 // fakeTransportService is a [transportv1pb.TransportServiceServer] implementation 57 // that allows tests to manipulate the server side of various RPCs. 58 type fakeTransportService struct { 59 transportv1pb.UnimplementedTransportServiceServer 60 61 details fakeGetClusterDetails 62 ssh fakeProxySSHServer 63 cluster fakeProxyClusterServer 64 } 65 66 func (s fakeTransportService) GetClusterDetails(ctx context.Context, req *transportv1pb.GetClusterDetailsRequest) (*transportv1pb.GetClusterDetailsResponse, error) { 67 if s.details == nil { 68 return s.UnimplementedTransportServiceServer.GetClusterDetails(ctx, req) 69 } 70 return s.details(ctx, req) 71 } 72 73 func (s fakeTransportService) ProxySSH(stream transportv1pb.TransportService_ProxySSHServer) error { 74 if s.ssh == nil { 75 return s.UnimplementedTransportServiceServer.ProxySSH(stream) 76 } 77 return s.ssh(stream) 78 } 79 80 func (s fakeTransportService) ProxyCluster(stream transportv1pb.TransportService_ProxyClusterServer) error { 81 if s.cluster == nil { 82 return s.UnimplementedTransportServiceServer.ProxyCluster(stream) 83 } 84 return s.cluster(stream) 85 } 86 87 // newGRPCServer creates a [grpc.Server] and registers the 88 // provided [transportv1pb.TransportServiceServer]. 89 func newGRPCServer(t *testing.T, srv transportv1pb.TransportServiceServer) *fakeGRPCServer { 90 // gRPC testPack. 91 lis := bufconn.Listen(100) 92 t.Cleanup(func() { require.NoError(t, lis.Close()) }) 93 94 s := grpc.NewServer() 95 t.Cleanup(s.Stop) 96 97 // Register service. 98 if srv != nil { 99 transportv1pb.RegisterTransportServiceServer(s, srv) 100 } 101 102 // Start. 103 go func() { 104 if err := s.Serve(lis); err != nil && !errors.Is(err, grpc.ErrServerStopped) { 105 panic(fmt.Sprintf("Serve returned err = %v", err)) 106 } 107 }() 108 109 return &fakeGRPCServer{Listener: lis} 110 } 111 112 type fakeGRPCServer struct { 113 *bufconn.Listener 114 } 115 116 type fakeAuthServer struct { 117 *proto.UnimplementedAuthServiceServer 118 listener net.Listener 119 srv *grpc.Server 120 } 121 122 func newFakeAuthServer(t *testing.T, conn net.Conn) *fakeAuthServer { 123 f := &fakeAuthServer{ 124 listener: newOneShotListener(conn), 125 UnimplementedAuthServiceServer: &proto.UnimplementedAuthServiceServer{}, 126 srv: grpc.NewServer(), 127 } 128 129 t.Cleanup(f.Stop) 130 proto.RegisterAuthServiceServer(f.srv, f) 131 return f 132 } 133 134 func (f *fakeAuthServer) Ping(context.Context, *proto.PingRequest) (*proto.PingResponse, error) { 135 return &proto.PingResponse{ 136 ClusterName: "test", 137 ServerVersion: "1.0.0", 138 IsBoring: true, 139 }, nil 140 } 141 142 func (f *fakeAuthServer) Serve() error { 143 return f.srv.Serve(f.listener) 144 } 145 146 func (f *fakeAuthServer) Stop() { 147 _ = f.listener.Close() 148 f.srv.Stop() 149 } 150 151 type oneShotListener struct { 152 conn net.Conn 153 closedCh chan struct{} 154 listenedCh chan struct{} 155 } 156 157 func newOneShotListener(conn net.Conn) oneShotListener { 158 return oneShotListener{ 159 conn: conn, 160 closedCh: make(chan struct{}), 161 listenedCh: make(chan struct{}), 162 } 163 } 164 165 func (l oneShotListener) Accept() (net.Conn, error) { 166 select { 167 case <-l.listenedCh: 168 <-l.closedCh 169 return nil, net.ErrClosed 170 default: 171 close(l.listenedCh) 172 return l.conn, nil 173 } 174 } 175 176 func (l oneShotListener) Close() error { 177 select { 178 case <-l.closedCh: 179 default: 180 close(l.closedCh) 181 } 182 183 return nil 184 } 185 186 func (l oneShotListener) Addr() net.Addr { 187 return addr("127.0.0.1") 188 } 189 190 // addr is a [net.Addr] implementation for static tcp addresses. 191 type addr string 192 193 func (a addr) Network() string { 194 return "tcp" 195 } 196 197 func (a addr) String() string { 198 return string(a) 199 } 200 201 type fakeProxy struct { 202 *fakeGRPCServer 203 } 204 205 func newFakeProxy(t *testing.T, transportService transportv1pb.TransportServiceServer) *fakeProxy { 206 grpcSrv := newGRPCServer(t, transportService) 207 208 return &fakeProxy{ 209 fakeGRPCServer: grpcSrv, 210 } 211 } 212 213 func (f *fakeProxy) clientConfig(t *testing.T) ClientConfig { 214 return ClientConfig{ 215 ProxyAddress: "127.0.0.1", 216 SSHConfig: &ssh.ClientConfig{}, 217 DialOpts: []grpc.DialOption{grpc.WithContextDialer(func(ctx context.Context, s string) (net.Conn, error) { 218 return f.fakeGRPCServer.DialContext(ctx) 219 })}, 220 } 221 } 222 223 func TestNewClient(t *testing.T) { 224 t.Parallel() 225 226 ctx := context.Background() 227 tests := []struct { 228 name string 229 srv transportv1pb.TransportServiceServer 230 assertion func(t *testing.T, clt *Client, err error) 231 }{ 232 { 233 name: "does not implement transport", 234 assertion: func(t *testing.T, clt *Client, err error) { 235 require.NoError(t, err) 236 require.NotNil(t, clt) 237 238 details, err := clt.transport.ClusterDetails(context.Background()) 239 require.Error(t, err) 240 require.Nil(t, details) 241 }, 242 }, 243 { 244 name: "compliant grpc server", 245 srv: fakeTransportService{ 246 details: func(ctx context.Context, request *transportv1pb.GetClusterDetailsRequest) (*transportv1pb.GetClusterDetailsResponse, error) { 247 return &transportv1pb.GetClusterDetailsResponse{Details: &transportv1pb.ClusterDetails{FipsEnabled: true}}, nil 248 }, 249 }, 250 assertion: func(t *testing.T, clt *Client, err error) { 251 require.NoError(t, err) 252 require.NotNil(t, clt) 253 }, 254 }, 255 } 256 257 for _, test := range tests { 258 t.Run(test.name, func(t *testing.T) { 259 proxy := newFakeProxy(t, test.srv) 260 cfg := proxy.clientConfig(t) 261 262 clt, err := NewClient(ctx, cfg) 263 if clt != nil { 264 t.Cleanup(func() { require.NoError(t, clt.Close()) }) 265 } 266 test.assertion(t, clt, err) 267 }) 268 } 269 } 270 271 func TestClient_ClusterDetails(t *testing.T) { 272 t.Parallel() 273 ctx := context.Background() 274 275 tests := []struct { 276 name string 277 srv transportv1pb.TransportServiceServer 278 assertion func(t *testing.T, details ClusterDetails, err error) 279 }{ 280 { 281 name: "cluster details", 282 srv: fakeTransportService{ 283 details: func(ctx context.Context, request *transportv1pb.GetClusterDetailsRequest) (*transportv1pb.GetClusterDetailsResponse, error) { 284 return &transportv1pb.GetClusterDetailsResponse{Details: &transportv1pb.ClusterDetails{FipsEnabled: false}}, nil 285 }, 286 }, 287 assertion: func(t *testing.T, details ClusterDetails, err error) { 288 require.NoError(t, err) 289 require.False(t, details.FIPS) 290 }, 291 }, 292 { 293 name: "cluster details fails", 294 srv: fakeTransportService{ 295 details: func() func(ctx context.Context, request *transportv1pb.GetClusterDetailsRequest) (*transportv1pb.GetClusterDetailsResponse, error) { 296 return func(ctx context.Context, request *transportv1pb.GetClusterDetailsRequest) (*transportv1pb.GetClusterDetailsResponse, error) { 297 return nil, trace.ConnectionProblem(nil, "connection closed") 298 } 299 }(), 300 }, 301 assertion: func(t *testing.T, details ClusterDetails, err error) { 302 require.Error(t, err) 303 }, 304 }, 305 } 306 307 for _, test := range tests { 308 t.Run(test.name, func(t *testing.T) { 309 proxy := newFakeProxy(t, test.srv) 310 cfg := proxy.clientConfig(t) 311 312 cfg.DialOpts = append(cfg.DialOpts, grpc.WithDisableRetry()) 313 314 clt, err := NewClient(ctx, cfg) 315 require.NoError(t, err) 316 t.Cleanup(func() { require.NoError(t, clt.Close()) }) 317 318 details, err := clt.ClusterDetails(ctx) 319 test.assertion(t, details, err) 320 }) 321 } 322 } 323 324 func TestClient_DialHost(t *testing.T) { 325 t.Parallel() 326 ctx := context.Background() 327 328 tests := []struct { 329 name string 330 srv transportv1pb.TransportServiceServer 331 keyring agent.ExtendedAgent 332 assertion func(t *testing.T, conn net.Conn, details ClusterDetails, err error) 333 }{ 334 { 335 name: "grpc connection fails", 336 srv: fakeTransportService{ 337 details: func(ctx context.Context, request *transportv1pb.GetClusterDetailsRequest) (*transportv1pb.GetClusterDetailsResponse, error) { 338 return &transportv1pb.GetClusterDetailsResponse{Details: &transportv1pb.ClusterDetails{FipsEnabled: true}}, nil 339 }, 340 ssh: func(server transportv1pb.TransportService_ProxySSHServer) error { 341 _, err := server.Recv() 342 if err != nil { 343 return trail.ToGRPC(trace.Wrap(err)) 344 } 345 346 return trail.ToGRPC(trace.ConnectionProblem(nil, "connection closed")) 347 }, 348 }, 349 assertion: func(t *testing.T, conn net.Conn, details ClusterDetails, err error) { 350 require.ErrorIs(t, err, trace.ConnectionProblem(nil, "connection closed")) 351 require.Nil(t, conn) 352 require.False(t, details.FIPS) 353 }, 354 }, 355 { 356 name: "grpc connection established", 357 srv: fakeTransportService{ 358 details: func(ctx context.Context, request *transportv1pb.GetClusterDetailsRequest) (*transportv1pb.GetClusterDetailsResponse, error) { 359 return &transportv1pb.GetClusterDetailsResponse{Details: &transportv1pb.ClusterDetails{FipsEnabled: true}}, nil 360 }, 361 ssh: func(server transportv1pb.TransportService_ProxySSHServer) error { 362 _, err := server.Recv() 363 if err != nil { 364 return trail.ToGRPC(trace.Wrap(err)) 365 } 366 367 if err := server.Send(&transportv1pb.ProxySSHResponse{Details: &transportv1pb.ClusterDetails{FipsEnabled: true}}); err != nil { 368 return trail.ToGRPC(err) 369 } 370 371 req, err := server.Recv() 372 if err != nil { 373 return trail.ToGRPC(trace.Wrap(err)) 374 } 375 376 switch f := req.Frame.(type) { 377 case *transportv1pb.ProxySSHRequest_Ssh: 378 if err := server.Send(&transportv1pb.ProxySSHResponse{ 379 Details: nil, 380 Frame: &transportv1pb.ProxySSHResponse_Ssh{Ssh: &transportv1pb.Frame{Payload: f.Ssh.Payload}}, 381 }); err != nil { 382 return trail.ToGRPC(trace.Wrap(err)) 383 } 384 default: 385 return trace.BadParameter("unexpected frame type received") 386 } 387 388 return nil 389 }, 390 }, 391 assertion: func(t *testing.T, conn net.Conn, details ClusterDetails, err error) { 392 require.NoError(t, err) 393 require.NotNil(t, conn) 394 require.True(t, details.FIPS) 395 396 // test that the server echos data back over the connection 397 msg := []byte("hello123") 398 n, err := conn.Write(msg) 399 require.NoError(t, err) 400 require.Len(t, msg, n) 401 402 out := make([]byte, len(msg)) 403 n, err = conn.Read(out) 404 require.NoError(t, err) 405 require.Len(t, msg, n) 406 require.Equal(t, msg, out) 407 408 require.NoError(t, conn.Close()) 409 }, 410 }, 411 } 412 413 for _, test := range tests { 414 t.Run(test.name, func(t *testing.T) { 415 proxy := newFakeProxy(t, test.srv) 416 cfg := proxy.clientConfig(t) 417 418 clt, err := NewClient(ctx, cfg) 419 require.NoError(t, err) 420 t.Cleanup(func() { require.NoError(t, clt.Close()) }) 421 422 conn, details, err := clt.DialHost(ctx, "test", "cluster", test.keyring) 423 test.assertion(t, conn, details, err) 424 }) 425 } 426 } 427 428 func TestClient_DialCluster(t *testing.T) { 429 t.Parallel() 430 ctx := context.Background() 431 432 tests := []struct { 433 name string 434 authCfg func(config *client.Config) 435 srv transportv1pb.TransportServiceServer 436 keyring agent.ExtendedAgent 437 assertion func(t *testing.T, clt *client.Client, err error) 438 }{ 439 { 440 name: "grpc connection fails", 441 authCfg: func(config *client.Config) { 442 config.DialTimeout = 500 * time.Millisecond // speed up dial failure 443 }, 444 srv: fakeTransportService{ 445 details: func(ctx context.Context, request *transportv1pb.GetClusterDetailsRequest) (*transportv1pb.GetClusterDetailsResponse, error) { 446 return &transportv1pb.GetClusterDetailsResponse{Details: &transportv1pb.ClusterDetails{FipsEnabled: true}}, nil 447 }, 448 cluster: func(server transportv1pb.TransportService_ProxyClusterServer) error { 449 _, err := server.Recv() 450 if err != nil { 451 return trace.Wrap(err) 452 } 453 454 return trace.ConnectionProblem(nil, "connection closed") 455 }, 456 }, 457 assertion: func(t *testing.T, clt *client.Client, err error) { 458 require.Error(t, err) 459 require.Nil(t, clt) 460 }, 461 }, 462 { 463 name: "grpc connection established", 464 authCfg: func(config *client.Config) {}, 465 srv: fakeTransportService{ 466 details: func(ctx context.Context, request *transportv1pb.GetClusterDetailsRequest) (*transportv1pb.GetClusterDetailsResponse, error) { 467 return &transportv1pb.GetClusterDetailsResponse{Details: &transportv1pb.ClusterDetails{FipsEnabled: true}}, nil 468 }, 469 cluster: func(server transportv1pb.TransportService_ProxyClusterServer) error { 470 _, err := server.Recv() 471 if err != nil { 472 return trace.Wrap(err) 473 } 474 475 rw, err := stream.NewReadWriter(clusterStream{stream: server}) 476 if err != nil { 477 return trace.Wrap(err) 478 } 479 480 auth := newFakeAuthServer(t, stream.NewConn(rw, nil, nil)) 481 err = auth.Serve() 482 return trace.Wrap(err) 483 }, 484 }, 485 assertion: func(t *testing.T, clt *client.Client, err error) { 486 require.NoError(t, err) 487 require.NotNil(t, clt) 488 489 expected := &proto.PingResponse{ 490 ClusterName: "test", 491 ServerVersion: "1.0.0", 492 IsBoring: true, 493 } 494 495 resp, err := clt.Ping(ctx) 496 require.NoError(t, err) 497 require.Empty(t, cmp.Diff(expected, resp, protocmp.Transform())) 498 }, 499 }, 500 } 501 502 for _, test := range tests { 503 t.Run(test.name, func(t *testing.T) { 504 proxy := newFakeProxy(t, test.srv) 505 cfg := proxy.clientConfig(t) 506 507 clt, err := NewClient(ctx, cfg) 508 require.NoError(t, err) 509 t.Cleanup(func() { require.NoError(t, clt.Close()) }) 510 511 authCfg, err := clt.ClientConfig(ctx, "cluster") 512 require.NoError(t, err) 513 514 authCfg.DialOpts = []grpc.DialOption{ 515 grpc.WithTransportCredentials(insecure.NewCredentials()), 516 grpc.WithReturnConnectionError(), 517 grpc.WithDisableRetry(), 518 grpc.FailOnNonTempDialError(true), 519 } 520 authCfg.Credentials = []client.Credentials{insecureCredentials{}} 521 authCfg.DialTimeout = 3 * time.Second 522 523 test.authCfg(&authCfg) 524 525 authClt, err := client.New(ctx, authCfg) 526 if authClt != nil { 527 t.Cleanup(func() { 528 require.NoError(t, authClt.Close()) 529 }) 530 } 531 test.assertion(t, authClt, err) 532 }) 533 } 534 } 535 536 // clusterStream implements the [streamutils.Source] interface 537 // for a [transportv1pb.TransportService_ProxyClusterServer]. 538 type clusterStream struct { 539 stream transportv1pb.TransportService_ProxyClusterServer 540 } 541 542 func (c clusterStream) Recv() ([]byte, error) { 543 req, err := c.stream.Recv() 544 if err != nil { 545 return nil, trace.Wrap(err) 546 } 547 548 if req.Frame == nil { 549 return nil, trace.BadParameter("received invalid frame") 550 } 551 552 return req.Frame.Payload, nil 553 } 554 555 func (c clusterStream) Send(frame []byte) error { 556 return trace.Wrap(c.stream.Send(&transportv1pb.ProxyClusterResponse{Frame: &transportv1pb.Frame{Payload: frame}})) 557 } 558 559 func TestClient_SSHConfig(t *testing.T) { 560 t.Parallel() 561 562 proxy := newFakeProxy(t, fakeTransportService{}) 563 cfg := proxy.clientConfig(t) 564 565 clt, err := NewClient(context.Background(), cfg) 566 require.NoError(t, err) 567 t.Cleanup(func() { require.NoError(t, clt.Close()) }) 568 569 const user = "test-user" 570 sshConfig := clt.SSHConfig(user) 571 require.NotNil(t, sshConfig) 572 require.Equal(t, user, sshConfig.User) 573 require.Empty(t, cmp.Diff(cfg.SSHConfig, sshConfig, cmpopts.IgnoreFields(ssh.ClientConfig{}, "User", "Auth", "HostKeyCallback"))) 574 } 575 576 type fakeTransportCredentials struct { 577 credentials.TransportCredentials 578 info credentials.AuthInfo 579 err error 580 } 581 582 type fakeAuthInfo struct{} 583 584 func (f fakeAuthInfo) AuthType() string { 585 return "test" 586 } 587 588 func (t fakeTransportCredentials) ClientHandshake(ctx context.Context, addr string, conn net.Conn) (net.Conn, credentials.AuthInfo, error) { 589 return conn, t.info, t.err 590 } 591 592 func TestClusterCredentials(t *testing.T) { 593 t.Parallel() 594 595 cases := []struct { 596 name string 597 expectedClusterName string 598 credentials fakeTransportCredentials 599 errAssertion require.ErrorAssertionFunc 600 }{ 601 { 602 name: "handshake error", 603 credentials: fakeTransportCredentials{err: context.Canceled}, 604 errAssertion: require.Error, 605 }, 606 { 607 name: "no tls auth info", 608 credentials: fakeTransportCredentials{info: fakeAuthInfo{}}, 609 errAssertion: require.NoError, 610 }, 611 { 612 name: "no server cert", 613 credentials: fakeTransportCredentials{info: credentials.TLSInfo{}}, 614 errAssertion: require.NoError, 615 }, 616 { 617 name: "no cluster oid set", 618 credentials: fakeTransportCredentials{info: credentials.TLSInfo{ 619 State: tls.ConnectionState{ 620 PeerCertificates: []*x509.Certificate{ 621 { 622 Subject: pkix.Name{ 623 Names: []pkix.AttributeTypeAndValue{ 624 { 625 Type: asn1.ObjectIdentifier{1, 3, 9999, 0, 1}, 626 }, 627 { 628 Type: asn1.ObjectIdentifier{1, 3, 9999, 2, 1}, 629 }, 630 { 631 Type: asn1.ObjectIdentifier{1, 3, 9999, 0, 2}, 632 }, 633 { 634 Type: asn1.ObjectIdentifier{1, 3, 9999, 2, 2}, 635 }, 636 }, 637 }, 638 }, 639 }, 640 }, 641 }}, 642 errAssertion: require.NoError, 643 }, { 644 name: "cluster name presented", 645 expectedClusterName: "test-cluster", 646 credentials: fakeTransportCredentials{info: credentials.TLSInfo{ 647 State: tls.ConnectionState{ 648 PeerCertificates: []*x509.Certificate{ 649 { 650 Subject: pkix.Name{ 651 Names: []pkix.AttributeTypeAndValue{ 652 { 653 Type: asn1.ObjectIdentifier{1, 3, 9999, 2, 1}, 654 }, 655 { 656 Type: asn1.ObjectIdentifier{1, 3, 9999, 0, 2}, 657 }, 658 { 659 Type: asn1.ObjectIdentifier{1, 3, 9999, 2, 2}, 660 }, 661 { 662 Type: teleportClusterASN1ExtensionOID, 663 Value: "test-cluster", 664 }, 665 }, 666 }, 667 }, 668 }, 669 }, 670 }}, 671 errAssertion: require.NoError, 672 }, 673 } 674 675 for _, test := range cases { 676 t.Run(test.name, func(t *testing.T) { 677 c := &clusterName{} 678 creds := clusterCredentials{TransportCredentials: test.credentials, clusterName: c} 679 _, _, err := creds.ClientHandshake(context.Background(), "127.0.0.1", nil) 680 test.errAssertion(t, err) 681 require.Equal(t, test.expectedClusterName, c.get()) 682 }) 683 } 684 } 685 686 func TestNewDialerForGRPCClient(t *testing.T) { 687 t.Run("Check that PROXYHeaderGetter if present sends PROXY header as first bytes on the connection", func(t *testing.T) { 688 listener, err := net.Listen("tcp", "127.0.0.1:0") 689 require.NoError(t, err) 690 t.Cleanup(func() { 691 require.NoError(t, listener.Close()) 692 }) 693 694 prefix := []byte("FAKEPROXY") 695 proxyHeaderGetter := func() ([]byte, error) { 696 return prefix, nil 697 } 698 699 ctx := context.Background() 700 cfg := &ClientConfig{ 701 PROXYHeaderGetter: proxyHeaderGetter, 702 } 703 dialer := newDialerForGRPCClient(ctx, cfg) 704 705 resultChan := make(chan bool) 706 // Start listening, emulating receiving end of connection 707 go func() { 708 conn, err := listener.Accept() 709 if err != nil { 710 assert.Fail(t, err.Error()) 711 return 712 } 713 714 buf := make([]byte, len(prefix)) 715 _, err = conn.Read(buf) 716 assert.NoError(t, err) 717 t.Cleanup(func() { 718 require.NoError(t, conn.Close()) 719 }) 720 721 // On the received connection first bytes should be our PROXY prefix 722 resultChan <- slices.Equal(buf, prefix) 723 }() 724 725 conn, err := dialer(ctx, listener.Addr().String()) 726 require.NoError(t, err) 727 t.Cleanup(func() { 728 require.NoError(t, conn.Close()) 729 }) 730 731 select { 732 case res := <-resultChan: 733 require.True(t, res, "Didn't receive required prefix as first bytes on the connection") 734 case <-time.After(time.Second): 735 require.Fail(t, "Timed out waiting for connection") 736 } 737 }) 738 }