github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/utils/sshutils/conn_test.go (about)

     1  /*
     2  Copyright 2022 Gravitational, Inc.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package sshutils
    18  
    19  import (
    20  	"crypto/rand"
    21  	"crypto/rsa"
    22  	"crypto/x509"
    23  	"encoding/pem"
    24  	"net"
    25  	"testing"
    26  	"time"
    27  
    28  	"github.com/stretchr/testify/require"
    29  	"golang.org/x/crypto/ssh"
    30  
    31  	"github.com/gravitational/teleport/api/constants"
    32  )
    33  
    34  type server struct {
    35  	listener net.Listener
    36  	config   *ssh.ServerConfig
    37  	handler  func(*ssh.ServerConn)
    38  
    39  	cSigner ssh.Signer
    40  	hSigner ssh.Signer
    41  }
    42  
    43  func (s *server) Run(errC chan error) {
    44  	for {
    45  		conn, err := s.listener.Accept()
    46  		if err != nil {
    47  			errC <- err
    48  			return
    49  		}
    50  
    51  		go func() {
    52  			sconn, _, _, err := ssh.NewServerConn(conn, s.config)
    53  			if err != nil {
    54  				errC <- err
    55  				return
    56  			}
    57  			s.handler(sconn)
    58  		}()
    59  	}
    60  }
    61  
    62  func (s *server) Stop() error {
    63  	return s.listener.Close()
    64  }
    65  
    66  func generateSigner(t *testing.T) ssh.Signer {
    67  	private, err := rsa.GenerateKey(rand.Reader, 2048)
    68  	require.NoError(t, err)
    69  
    70  	block := &pem.Block{
    71  		Type:  "RSA PRIVATE KEY",
    72  		Bytes: x509.MarshalPKCS1PrivateKey(private),
    73  	}
    74  
    75  	privatePEM := pem.EncodeToMemory(block)
    76  	signer, err := ssh.ParsePrivateKey(privatePEM)
    77  	require.NoError(t, err)
    78  
    79  	return signer
    80  }
    81  
    82  func (s *server) GetClient(t *testing.T) (ssh.Conn, <-chan ssh.NewChannel, <-chan *ssh.Request) {
    83  	conn, err := net.Dial("tcp", s.listener.Addr().String())
    84  	require.NoError(t, err)
    85  
    86  	sconn, nc, r, err := ssh.NewClientConn(conn, "", &ssh.ClientConfig{
    87  		Auth:            []ssh.AuthMethod{ssh.PublicKeys(s.cSigner)},
    88  		HostKeyCallback: ssh.FixedHostKey(s.hSigner.PublicKey()),
    89  	})
    90  	require.NoError(t, err)
    91  
    92  	return sconn, nc, r
    93  }
    94  
    95  func newServer(t *testing.T, handler func(*ssh.ServerConn)) *server {
    96  	listener, err := net.Listen("tcp", "localhost:0")
    97  	require.NoError(t, err)
    98  
    99  	cSigner := generateSigner(t)
   100  	hSigner := generateSigner(t)
   101  
   102  	config := &ssh.ServerConfig{
   103  		NoClientAuth: true,
   104  	}
   105  	config.AddHostKey(hSigner)
   106  
   107  	return &server{
   108  		listener: listener,
   109  		config:   config,
   110  		handler:  handler,
   111  		cSigner:  cSigner,
   112  		hSigner:  hSigner,
   113  	}
   114  }
   115  
// TestTransportError ensures ConnectProxyTransport does not block forever
// when an error occurs while opening the transport channel.
//
// The server handler calls ConnectProxyTransport on each accepted connection
// and reports the returned error on handlerErrC. Two failure modes are
// exercised:
//  1. the client closes its connection while the transport channel request
//     is still unanswered;
//  2. the client explicitly rejects the transport channel.
//
// In both cases the handler must return an error within the timeout.
func TestTransportError(t *testing.T) {
	handlerErrC := make(chan error, 1)
	serverErrC := make(chan error, 1)

	server := newServer(t, func(sconn *ssh.ServerConn) {
		_, _, err := ConnectProxyTransport(sconn, &DialReq{
			Address: "test", ServerID: "test",
		}, false)
		handlerErrC <- err
	})

	go server.Run(serverErrC)
	t.Cleanup(func() { require.NoError(t, server.Stop()) })

	// Case 1: the client connection goes away while the channel is pending.
	sconn1, nc, _ := server.GetClient(t)
	// sconn1 is closed explicitly below, so closing it again during cleanup
	// is expected to return an error.
	t.Cleanup(func() { require.Error(t, sconn1.Close()) })

	// The handler's ConnectProxyTransport call arrives at the client as an
	// incoming channel of type constants.ChanTransport.
	// NOTE(review): this receive has no timeout; if the server never opens a
	// channel, the test blocks until the global `go test` timeout.
	channel := <-nc
	require.Equal(t, constants.ChanTransport, channel.ChannelType())

	// Close without answering the channel request; the handler must still
	// return, with an error, instead of waiting forever.
	sconn1.Close()
	err := timeoutErrC(t, handlerErrC, time.Second*5)
	require.Error(t, err)

	// Case 2: the client explicitly rejects the transport channel.
	sconn2, nc, _ := server.GetClient(t)
	// sconn2 is not closed anywhere earlier in the test, so this Close is
	// expected to succeed.
	t.Cleanup(func() { require.NoError(t, sconn2.Close()) })

	channel = <-nc
	require.Equal(t, constants.ChanTransport, channel.ChannelType())

	err = channel.Reject(ssh.ConnectionFailed, "test reject")
	require.NoError(t, err)

	err = timeoutErrC(t, handlerErrC, time.Second*5)
	require.Error(t, err)

	// Run reports Accept and handshake failures on serverErrC; none are
	// expected while the test is running. Non-blocking check.
	select {
	case err = <-serverErrC:
		require.FailNow(t, err.Error())
	default:
	}
}
   160  
   161  func timeoutErrC(t *testing.T, errC <-chan error, d time.Duration) error {
   162  	timeout := time.NewTimer(d)
   163  	select {
   164  	case err := <-errC:
   165  		return err
   166  	case <-timeout.C:
   167  		require.FailNow(t, "failed to receive on err channel in time")
   168  	}
   169  
   170  	return nil
   171  }