github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/client/proxy/transport/transportv1/client_test.go (about) 1 // Copyright 2023 Gravitational, Inc 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package transportv1 16 17 import ( 18 "bytes" 19 "context" 20 "crypto/rand" 21 "crypto/rsa" 22 "errors" 23 "fmt" 24 "io" 25 "math" 26 "net" 27 "sync/atomic" 28 "testing" 29 "time" 30 31 "github.com/gravitational/trace" 32 "github.com/gravitational/trace/trail" 33 "github.com/stretchr/testify/require" 34 "golang.org/x/crypto/ssh/agent" 35 "google.golang.org/grpc" 36 "google.golang.org/grpc/credentials/insecure" 37 "google.golang.org/grpc/test/bufconn" 38 39 transportv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/transport/v1" 40 "github.com/gravitational/teleport/api/utils/grpc/interceptors" 41 streamutils "github.com/gravitational/teleport/api/utils/grpc/stream" 42 ) 43 44 type fakeGetClusterDetailsServer func(context.Context, *transportv1pb.GetClusterDetailsRequest) (*transportv1pb.GetClusterDetailsResponse, error) 45 46 type fakeProxySSHServer func(transportv1pb.TransportService_ProxySSHServer) error 47 48 type fakeProxyClusterServer func(transportv1pb.TransportService_ProxyClusterServer) error 49 50 // fakeServer is a [transportv1pb.TransportServiceServer] implementation 51 // that allows tests to manipulate the server side of various RPCs. 52 type fakeServer struct { 53 transportv1pb.UnimplementedTransportServiceServer 54 55 details fakeGetClusterDetailsServer 56 ssh fakeProxySSHServer 57 cluster fakeProxyClusterServer 58 } 59 60 func (s fakeServer) GetClusterDetails(ctx context.Context, req *transportv1pb.GetClusterDetailsRequest) (*transportv1pb.GetClusterDetailsResponse, error) { 61 return s.details(ctx, req) 62 } 63 64 func (s fakeServer) ProxySSH(stream transportv1pb.TransportService_ProxySSHServer) error { 65 return s.ssh(stream) 66 } 67 68 func (s fakeServer) ProxyCluster(stream transportv1pb.TransportService_ProxyClusterServer) error { 69 return s.cluster(stream) 70 } 71 72 // TestClient_ClusterDetails validates that a Client can retrieve 73 // [transportv1pb.ClusterDetails] from a [transportv1pb.TransportServiceServer]. 74 func TestClient_ClusterDetails(t *testing.T) { 75 t.Parallel() 76 77 pack := newServer(t, fakeServer{ 78 details: func() fakeGetClusterDetailsServer { 79 var i atomic.Bool 80 return func(ctx context.Context, request *transportv1pb.GetClusterDetailsRequest) (*transportv1pb.GetClusterDetailsResponse, error) { 81 if i.CompareAndSwap(false, true) { 82 return &transportv1pb.GetClusterDetailsResponse{Details: &transportv1pb.ClusterDetails{FipsEnabled: true}}, nil 83 } 84 85 return nil, trail.ToGRPC(trace.NotImplemented("not implemented")) 86 } 87 }(), 88 }) 89 90 tests := []struct { 91 name string 92 assertion func(t *testing.T, response *transportv1pb.ClusterDetails, err error) 93 }{ 94 { 95 name: "details retrieved successfully", 96 assertion: func(t *testing.T, response *transportv1pb.ClusterDetails, err error) { 97 require.NoError(t, err) 98 require.NotNil(t, response) 99 require.True(t, response.FipsEnabled) 100 }, 101 }, 102 { 103 name: "error getting details", 104 assertion: func(t *testing.T, response *transportv1pb.ClusterDetails, err error) { 105 require.ErrorIs(t, err, trace.NotImplemented("not implemented")) 106 require.Nil(t, response) 107 }, 108 }, 109 } 110 111 for _, test := range tests { 112 t.Run(test.name, func(t *testing.T) { 113 resp, err := pack.Client.ClusterDetails(context.Background()) 114 test.assertion(t, resp, err) 115 }) 116 } 117 } 118 119 // TestClient_DialCluster validates that a Client can establish a 120 // connection to a cluster and that said connection is proxied over 121 // the gRPC stream. 122 func TestClient_DialCluster(t *testing.T) { 123 t.Parallel() 124 125 pack := newServer(t, fakeServer{ 126 cluster: func(server transportv1pb.TransportService_ProxyClusterServer) error { 127 req, err := server.Recv() 128 if err != nil { 129 return trail.ToGRPC(err) 130 } 131 132 switch req.Cluster { 133 case "": 134 return trail.ToGRPC(trace.BadParameter("first message must contain a cluster")) 135 case "not-implemented": 136 return trail.ToGRPC(trace.NotImplemented("not implemented")) 137 case "echo": 138 // get the payload written 139 req, err = server.Recv() 140 if err != nil { 141 return trail.ToGRPC(err) 142 } 143 144 // echo the data back 145 if err := server.Send(&transportv1pb.ProxyClusterResponse{Frame: &transportv1pb.Frame{Payload: req.Frame.Payload}}); err != nil { 146 return trail.ToGRPC(err) 147 } 148 149 return nil 150 default: 151 return trace.NotFound("unknown cluster: %q", req.Cluster) 152 } 153 }, 154 }) 155 156 tests := []struct { 157 name string 158 cluster string 159 assertion func(t *testing.T, conn net.Conn, err error) 160 }{ 161 { 162 name: "stream terminated", 163 cluster: "not-implemented", 164 assertion: func(t *testing.T, conn net.Conn, err error) { 165 require.NoError(t, err) 166 require.NotNil(t, conn) 167 168 n, err := conn.Read(make([]byte, 10)) 169 require.True(t, trace.IsConnectionProblem(err)) 170 require.Zero(t, n) 171 }, 172 }, 173 { 174 name: "invalid cluster name", 175 cluster: "unknown", 176 assertion: func(t *testing.T, conn net.Conn, err error) { 177 require.NoError(t, err) 178 require.NotNil(t, conn) 179 180 n, err := conn.Read(make([]byte, 10)) 181 require.True(t, trace.IsConnectionProblem(err)) 182 require.Zero(t, n) 183 }, 184 }, 185 { 186 name: "connection successfully established", 187 cluster: "echo", 188 assertion: func(t *testing.T, conn net.Conn, err error) { 189 require.NoError(t, err) 190 require.NotNil(t, conn) 191 192 msg := []byte("hello") 193 n, err := conn.Write(msg) 194 require.NoError(t, err) 195 require.Len(t, msg, n) 196 197 out := make([]byte, n) 198 n, err = conn.Read(out) 199 require.NoError(t, err) 200 require.Len(t, msg, n) 201 require.Equal(t, msg, out) 202 203 require.NoError(t, conn.Close()) 204 }, 205 }, 206 } 207 208 for _, test := range tests { 209 t.Run(test.name, func(t *testing.T) { 210 conn, err := pack.Client.DialCluster(context.Background(), test.cluster, nil) 211 test.assertion(t, conn, err) 212 }) 213 } 214 } 215 216 // TestClient_DialHost validates that a Client can establish a 217 // connection to a host and that both SSH and SSH Agent protocol is 218 // proxied over the gRPC stream. 219 func TestClient_DialHost(t *testing.T) { 220 t.Parallel() 221 222 keyring := newKeyring(t) 223 224 pack := newServer(t, fakeServer{ 225 ssh: func(server transportv1pb.TransportService_ProxySSHServer) error { 226 req, err := server.Recv() 227 if err != nil { 228 return trail.ToGRPC(err) 229 } 230 231 switch { 232 case req == nil: 233 return trail.ToGRPC(trace.BadParameter("first message must contain a dial target")) 234 case req.DialTarget.Cluster == "": 235 return trail.ToGRPC(trace.BadParameter("first message must contain a cluster")) 236 case req.DialTarget.HostPort == "": 237 return trail.ToGRPC(trace.BadParameter("invalid dial target")) 238 case req.DialTarget.Cluster == "not-implemented": 239 return trail.ToGRPC(trace.NotImplemented("not implemented")) 240 case req.DialTarget.Cluster == "payload-too-large": 241 // send the initial cluster details 242 if err := server.Send(&transportv1pb.ProxySSHResponse{Details: &transportv1pb.ClusterDetails{FipsEnabled: true}}); err != nil && !errors.Is(err, io.EOF) { 243 return trail.ToGRPC(trace.Wrap(err)) 244 } 245 246 // wait for the first ssh frame 247 req, err = server.Recv() 248 if err != nil { 249 return trail.ToGRPC(trace.Wrap(err)) 250 } 251 252 // write too much data to terminate the stream 253 switch req.Frame.(type) { 254 case *transportv1pb.ProxySSHRequest_Ssh: 255 if err := server.Send(&transportv1pb.ProxySSHResponse{ 256 Details: nil, 257 Frame: &transportv1pb.ProxySSHResponse_Ssh{Ssh: &transportv1pb.Frame{Payload: bytes.Repeat([]byte{0}, 1001)}}, 258 }); err != nil && !errors.Is(err, io.EOF) { 259 return trail.ToGRPC(trace.Wrap(err)) 260 } 261 case *transportv1pb.ProxySSHRequest_Agent: 262 return trail.ToGRPC(trace.BadParameter("test expects first frame to be ssh. got an agent frame")) 263 } 264 265 return nil 266 case req.DialTarget.Cluster == "echo": 267 // send the initial cluster details 268 if err := server.Send(&transportv1pb.ProxySSHResponse{Details: &transportv1pb.ClusterDetails{FipsEnabled: true}}); err != nil && !errors.Is(err, io.EOF) { 269 return trail.ToGRPC(trace.Wrap(err)) 270 } 271 272 // wait for the first ssh frame 273 req, err = server.Recv() 274 if err != nil { 275 return trail.ToGRPC(trace.Wrap(err)) 276 } 277 278 // write too much data to terminate the stream 279 switch f := req.Frame.(type) { 280 case *transportv1pb.ProxySSHRequest_Ssh: 281 if err := server.Send(&transportv1pb.ProxySSHResponse{ 282 Details: nil, 283 Frame: &transportv1pb.ProxySSHResponse_Ssh{Ssh: &transportv1pb.Frame{Payload: f.Ssh.Payload}}, 284 }); err != nil && !errors.Is(err, io.EOF) { 285 return trail.ToGRPC(trace.Wrap(err)) 286 } 287 case *transportv1pb.ProxySSHRequest_Agent: 288 return trail.ToGRPC(trace.BadParameter("test expects first frame to be ssh. got an agent frame")) 289 } 290 return nil 291 case req.DialTarget.Cluster == "forward": 292 // send the initial cluster details 293 if err := server.Send(&transportv1pb.ProxySSHResponse{Details: &transportv1pb.ClusterDetails{FipsEnabled: true}}); err != nil && !errors.Is(err, io.EOF) { 294 return trail.ToGRPC(trace.Wrap(err)) 295 } 296 297 // wait for the first ssh frame 298 req, err = server.Recv() 299 if err != nil { 300 return trail.ToGRPC(trace.Wrap(err)) 301 } 302 303 // echo the data back on an ssh frame 304 switch f := req.Frame.(type) { 305 case *transportv1pb.ProxySSHRequest_Ssh: 306 if err := server.Send(&transportv1pb.ProxySSHResponse{ 307 Details: nil, 308 Frame: &transportv1pb.ProxySSHResponse_Ssh{Ssh: &transportv1pb.Frame{Payload: f.Ssh.Payload}}, 309 }); err != nil && !errors.Is(err, io.EOF) { 310 return trail.ToGRPC(trace.Wrap(err)) 311 } 312 case *transportv1pb.ProxySSHRequest_Agent: 313 return trail.ToGRPC(trace.BadParameter("test expects first frame to be ssh. got an agent frame")) 314 } 315 316 // create an agent stream and writer to communicate agent protocol on 317 agentStream := newServerStream(server, func(payload []byte) *transportv1pb.ProxySSHResponse { 318 return &transportv1pb.ProxySSHResponse{Frame: &transportv1pb.ProxySSHResponse_Agent{Agent: &transportv1pb.Frame{Payload: payload}}} 319 }) 320 agentStreamRW, err := streamutils.NewReadWriter(agentStream) 321 if err != nil { 322 return trail.ToGRPC(trace.Wrap(err, "failed constructing ssh agent streamer")) 323 } 324 325 // read in agent frames 326 go func() { 327 for { 328 req, err := server.Recv() 329 if err != nil { 330 if errors.Is(err, io.EOF) { 331 return 332 } 333 334 return 335 } 336 337 switch frame := req.Frame.(type) { 338 case *transportv1pb.ProxySSHRequest_Agent: 339 agentStream.incomingC <- frame.Agent.Payload 340 default: 341 continue 342 } 343 } 344 }() 345 346 // create an agent that will communicate over the agent frames 347 // and list the keys from the client 348 clt := agent.NewClient(agentStreamRW) 349 keys, err := clt.List() 350 if err != nil { 351 return trail.ToGRPC(trace.Wrap(err)) 352 } 353 354 if len(keys) != 1 { 355 return trail.ToGRPC(fmt.Errorf("expected to receive 1 key. got %v", len(keys))) 356 } 357 358 // send the key blob back via an ssh frame to alert the 359 // test that we finished listing keys 360 if err := server.Send(&transportv1pb.ProxySSHResponse{ 361 Details: nil, 362 Frame: &transportv1pb.ProxySSHResponse_Ssh{Ssh: &transportv1pb.Frame{Payload: keys[0].Blob}}, 363 }); err != nil && !errors.Is(err, io.EOF) { 364 return trail.ToGRPC(trace.Wrap(err)) 365 } 366 return nil 367 default: 368 return trail.ToGRPC(trace.BadParameter("invalid cluster")) 369 } 370 }, 371 }) 372 373 tests := []struct { 374 name string 375 cluster string 376 target string 377 keyring agent.ExtendedAgent 378 assertion func(t *testing.T, conn net.Conn, details *transportv1pb.ClusterDetails, err error) 379 }{ 380 { 381 name: "stream terminated", 382 cluster: "not-implemented", 383 target: "127.0.0.1:8080", 384 assertion: func(t *testing.T, conn net.Conn, details *transportv1pb.ClusterDetails, err error) { 385 require.ErrorIs(t, err, trace.NotImplemented("not implemented")) 386 require.Nil(t, conn) 387 require.Nil(t, details) 388 }, 389 }, 390 { 391 name: "invalid dial target", 392 cluster: "valid", 393 assertion: func(t *testing.T, conn net.Conn, details *transportv1pb.ClusterDetails, err error) { 394 require.ErrorIs(t, err, trace.BadParameter("invalid dial target")) 395 require.Nil(t, conn) 396 require.Nil(t, details) 397 }, 398 }, 399 { 400 name: "connection terminated when receive returns an error", 401 cluster: "payload-too-large", 402 target: "127.0.0.1:8080", 403 assertion: func(t *testing.T, conn net.Conn, details *transportv1pb.ClusterDetails, err error) { 404 require.NoError(t, err) 405 require.NotNil(t, conn) 406 407 msg := []byte("hello") 408 n, err := conn.Write(msg) 409 require.NoError(t, err) 410 require.Len(t, msg, n) 411 412 out := make([]byte, 10) 413 n, err = conn.Read(out) 414 require.True(t, trace.IsConnectionProblem(err)) 415 require.Zero(t, n) 416 417 require.NoError(t, conn.Close()) 418 }, 419 }, 420 { 421 name: "connection successfully established without agent forwarding", 422 cluster: "echo", 423 target: "127.0.0.1:8080", 424 assertion: func(t *testing.T, conn net.Conn, details *transportv1pb.ClusterDetails, err error) { 425 require.NoError(t, err) 426 require.NotNil(t, conn) 427 428 msg := []byte("hello") 429 n, err := conn.Write(msg) 430 require.NoError(t, err) 431 require.Len(t, msg, n) 432 433 out := make([]byte, n) 434 n, err = conn.Read(out) 435 require.NoError(t, err) 436 require.Len(t, msg, n) 437 require.Equal(t, msg, out) 438 439 n, err = conn.Read(out) 440 require.ErrorIs(t, err, io.EOF) 441 require.Zero(t, n) 442 443 require.NoError(t, conn.Close()) 444 }, 445 }, 446 { 447 name: "connection successfully established with agent forwarding", 448 cluster: "forward", 449 target: "127.0.0.1:8080", 450 keyring: keyring, 451 assertion: func(t *testing.T, conn net.Conn, details *transportv1pb.ClusterDetails, err error) { 452 require.NoError(t, err) 453 require.NotNil(t, conn) 454 require.True(t, details.FipsEnabled) 455 456 // write data via ssh frames 457 msg := []byte("hello") 458 n, err := conn.Write(msg) 459 require.NoError(t, err) 460 require.Len(t, msg, n) 461 462 // read data via ssh frames 463 out := make([]byte, n) 464 n, err = conn.Read(out) 465 require.NoError(t, err) 466 require.Len(t, msg, n) 467 require.Equal(t, msg, out) 468 469 // get the keys from our local keyring 470 keys, err := keyring.List() 471 require.NoError(t, err) 472 require.Len(t, keys, 1) 473 474 // the server performs a remote list of keys 475 // via ssh frames. to prevent the test from terminating 476 // before it can complete it will write the blob of the 477 // listed key back on the ssh frame. verify that the key 478 // it received matches the one from out local keyring. 479 out = make([]byte, len(keys[0].Blob)) 480 n, err = conn.Read(out) 481 require.NoError(t, err) 482 require.Len(t, keys[0].Blob, n) 483 require.Equal(t, keys[0].Blob, out) 484 485 // close the stream 486 require.NoError(t, conn.Close()) 487 }, 488 }, 489 } 490 491 for _, test := range tests { 492 t.Run(test.name, func(t *testing.T) { 493 conn, details, err := pack.Client.DialHost(context.Background(), test.target, test.cluster, nil, test.keyring) 494 test.assertion(t, conn, details, err) 495 }) 496 } 497 } 498 499 // testPack used to test a [Client]. 500 type testPack struct { 501 Client *Client 502 Server transportv1pb.TransportServiceServer 503 } 504 505 // newServer creates a [grpc.Server] and registers the 506 // provided [transportv1pb.TransportServiceServer] with it opens 507 // an authenticated Client. 508 func newServer(t *testing.T, srv transportv1pb.TransportServiceServer) testPack { 509 // gRPC testPack. 510 const bufSize = 100 // arbitrary 511 lis := bufconn.Listen(bufSize) 512 t.Cleanup(func() { 513 require.NoError(t, lis.Close()) 514 }) 515 516 s := grpc.NewServer() 517 t.Cleanup(func() { 518 s.GracefulStop() 519 s.Stop() 520 }) 521 522 // Register service. 523 transportv1pb.RegisterTransportServiceServer(s, srv) 524 525 // Start. 526 go func() { 527 if err := s.Serve(lis); err != nil && !errors.Is(err, grpc.ErrServerStopped) { 528 panic(fmt.Sprintf("Serve returned err = %v", err)) 529 } 530 }() 531 532 // gRPC client. 533 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 534 defer cancel() 535 cc, err := grpc.DialContext(ctx, "unused", 536 grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) { 537 return lis.DialContext(ctx) 538 }), 539 grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(1000)), 540 grpc.WithTransportCredentials(insecure.NewCredentials()), 541 grpc.WithUnaryInterceptor(interceptors.GRPCClientUnaryErrorInterceptor), 542 grpc.WithStreamInterceptor(interceptors.GRPCClientStreamErrorInterceptor), 543 ) 544 require.NoError(t, err) 545 t.Cleanup(func() { 546 require.NoError(t, cc.Close()) 547 }) 548 549 return testPack{ 550 Client: &Client{clt: transportv1pb.NewTransportServiceClient(cc)}, 551 Server: srv, 552 } 553 } 554 555 // newKeyring returns an [agent.ExtendedAgent] that has 556 // one key populated in it. 557 func newKeyring(t *testing.T) agent.ExtendedAgent { 558 private, err := rsa.GenerateKey(rand.Reader, 2048) 559 require.NoError(t, err) 560 561 keyring := agent.NewKeyring() 562 563 require.NoError(t, keyring.Add(agent.AddedKey{ 564 PrivateKey: private, 565 Comment: "test", 566 LifetimeSecs: math.MaxUint32, 567 })) 568 569 extendedKeyring, ok := keyring.(agent.ExtendedAgent) 570 require.True(t, ok) 571 572 return extendedKeyring 573 } 574 575 // serverStream implements the [streamutils.Source] interface 576 // for a [transportv1pb.TransportService_ProxySSHServer]. Instead of 577 // reading directly from the stream reads are from an incoming 578 // channel that is fed by the multiplexer. 579 type serverStream struct { 580 incomingC chan []byte 581 stream transportv1pb.TransportService_ProxySSHServer 582 responseFn func(payload []byte) *transportv1pb.ProxySSHResponse 583 } 584 585 func newServerStream(stream transportv1pb.TransportService_ProxySSHServer, responseFn func(payload []byte) *transportv1pb.ProxySSHResponse) *serverStream { 586 return &serverStream{ 587 incomingC: make(chan []byte, 10), 588 stream: stream, 589 responseFn: responseFn, 590 } 591 } 592 593 func (s *serverStream) Recv() ([]byte, error) { 594 select { 595 case <-s.stream.Context().Done(): 596 return nil, io.EOF 597 case frame := <-s.incomingC: 598 return frame, nil 599 } 600 } 601 602 func (s *serverStream) Send(frame []byte) error { 603 return trace.Wrap(s.stream.Send(s.responseFn(frame))) 604 }