github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/utils/sshutils/conn_test.go (about) 1 /* 2 Copyright 2022 Gravitational, Inc. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package sshutils 18 19 import ( 20 "crypto/rand" 21 "crypto/rsa" 22 "crypto/x509" 23 "encoding/pem" 24 "net" 25 "testing" 26 "time" 27 28 "github.com/stretchr/testify/require" 29 "golang.org/x/crypto/ssh" 30 31 "github.com/gravitational/teleport/api/constants" 32 ) 33 34 type server struct { 35 listener net.Listener 36 config *ssh.ServerConfig 37 handler func(*ssh.ServerConn) 38 39 cSigner ssh.Signer 40 hSigner ssh.Signer 41 } 42 43 func (s *server) Run(errC chan error) { 44 for { 45 conn, err := s.listener.Accept() 46 if err != nil { 47 errC <- err 48 return 49 } 50 51 go func() { 52 sconn, _, _, err := ssh.NewServerConn(conn, s.config) 53 if err != nil { 54 errC <- err 55 return 56 } 57 s.handler(sconn) 58 }() 59 } 60 } 61 62 func (s *server) Stop() error { 63 return s.listener.Close() 64 } 65 66 func generateSigner(t *testing.T) ssh.Signer { 67 private, err := rsa.GenerateKey(rand.Reader, 2048) 68 require.NoError(t, err) 69 70 block := &pem.Block{ 71 Type: "RSA PRIVATE KEY", 72 Bytes: x509.MarshalPKCS1PrivateKey(private), 73 } 74 75 privatePEM := pem.EncodeToMemory(block) 76 signer, err := ssh.ParsePrivateKey(privatePEM) 77 require.NoError(t, err) 78 79 return signer 80 } 81 82 func (s *server) GetClient(t *testing.T) (ssh.Conn, <-chan ssh.NewChannel, <-chan *ssh.Request) { 83 conn, err := net.Dial("tcp", s.listener.Addr().String()) 84 require.NoError(t, err) 85 86 sconn, nc, r, err := ssh.NewClientConn(conn, "", &ssh.ClientConfig{ 87 Auth: []ssh.AuthMethod{ssh.PublicKeys(s.cSigner)}, 88 HostKeyCallback: ssh.FixedHostKey(s.hSigner.PublicKey()), 89 }) 90 require.NoError(t, err) 91 92 return sconn, nc, r 93 } 94 95 func newServer(t *testing.T, handler func(*ssh.ServerConn)) *server { 96 listener, err := net.Listen("tcp", "localhost:0") 97 require.NoError(t, err) 98 99 cSigner := generateSigner(t) 100 hSigner := generateSigner(t) 101 102 config := &ssh.ServerConfig{ 103 NoClientAuth: true, 104 } 105 config.AddHostKey(hSigner) 106 107 return &server{ 108 listener: listener, 109 config: config, 110 handler: handler, 111 cSigner: cSigner, 112 hSigner: hSigner, 113 } 114 } 115 116 // TestTransportError ensures ConnectProxyTransport does not block forever 117 // when an error occurs while opening the transport channel. 118 func TestTransportError(t *testing.T) { 119 handlerErrC := make(chan error, 1) 120 serverErrC := make(chan error, 1) 121 122 server := newServer(t, func(sconn *ssh.ServerConn) { 123 _, _, err := ConnectProxyTransport(sconn, &DialReq{ 124 Address: "test", ServerID: "test", 125 }, false) 126 handlerErrC <- err 127 }) 128 129 go server.Run(serverErrC) 130 t.Cleanup(func() { require.NoError(t, server.Stop()) }) 131 132 sconn1, nc, _ := server.GetClient(t) 133 t.Cleanup(func() { require.Error(t, sconn1.Close()) }) 134 135 channel := <-nc 136 require.Equal(t, constants.ChanTransport, channel.ChannelType()) 137 138 sconn1.Close() 139 err := timeoutErrC(t, handlerErrC, time.Second*5) 140 require.Error(t, err) 141 142 sconn2, nc, _ := server.GetClient(t) 143 t.Cleanup(func() { require.NoError(t, sconn2.Close()) }) 144 145 channel = <-nc 146 require.Equal(t, constants.ChanTransport, channel.ChannelType()) 147 148 err = channel.Reject(ssh.ConnectionFailed, "test reject") 149 require.NoError(t, err) 150 151 err = timeoutErrC(t, handlerErrC, time.Second*5) 152 require.Error(t, err) 153 154 select { 155 case err = <-serverErrC: 156 require.FailNow(t, err.Error()) 157 default: 158 } 159 } 160 161 func timeoutErrC(t *testing.T, errC <-chan error, d time.Duration) error { 162 timeout := time.NewTimer(d) 163 select { 164 case err := <-errC: 165 return err 166 case <-timeout.C: 167 require.FailNow(t, "failed to receive on err channel in time") 168 } 169 170 return nil 171 }