github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/utils/sshutils/test.go (about)

     1  /*
     2  Copyright 2021 Gravitational, Inc.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package sshutils
    18  
    19  import (
    20  	"crypto/rand"
    21  	"crypto/rsa"
    22  	"time"
    23  
    24  	"github.com/gravitational/trace"
    25  	"golang.org/x/crypto/ssh"
    26  
    27  	"github.com/gravitational/teleport/api/constants"
    28  )
    29  
    30  const defaultPrincipal = "127.0.0.1"
    31  
    32  // MakeTestSSHCA generates a new SSH certificate authority for tests.
    33  func MakeTestSSHCA() (ssh.Signer, error) {
    34  	privateKey, err := rsa.GenerateKey(rand.Reader, constants.RSAKeySize)
    35  	if err != nil {
    36  		return nil, trace.Wrap(err)
    37  	}
    38  	ca, err := ssh.NewSignerFromKey(privateKey)
    39  	if err != nil {
    40  		return nil, trace.Wrap(err)
    41  	}
    42  	return ca, nil
    43  }
    44  
    45  // MakeSpoofedHostCert makes an SSH host certificate that claims to be signed
    46  // by the provided CA but in fact is signed by a different CA.
    47  func MakeSpoofedHostCert(realCA ssh.Signer) (ssh.Signer, error) {
    48  	fakeCA, err := MakeTestSSHCA()
    49  	if err != nil {
    50  		return nil, trace.Wrap(err)
    51  	}
    52  	return makeHostCert(realCA.PublicKey(), fakeCA, defaultPrincipal)
    53  }
    54  
    55  // MakeRealHostCert makes an SSH host certificate that is signed by the
    56  // provided CA.
    57  func MakeRealHostCert(realCA ssh.Signer) (ssh.Signer, error) {
    58  	return makeHostCert(realCA.PublicKey(), realCA, defaultPrincipal)
    59  }
    60  
    61  // MakeRealHostCertWithPrincipals makes an SSH host certificate that is signed by the
    62  // provided CA for the provided principals.
    63  func MakeRealHostCertWithPrincipals(realCA ssh.Signer, principals ...string) (ssh.Signer, error) {
    64  	return makeHostCert(realCA.PublicKey(), realCA, principals...)
    65  }
    66  
    67  func makeHostCert(signKey ssh.PublicKey, signer ssh.Signer, principals ...string) (ssh.Signer, error) {
    68  	priv, err := rsa.GenerateKey(rand.Reader, constants.RSAKeySize)
    69  	if err != nil {
    70  		return nil, trace.Wrap(err)
    71  	}
    72  
    73  	privSigner, err := ssh.NewSignerFromKey(priv)
    74  	if err != nil {
    75  		return nil, trace.Wrap(err)
    76  	}
    77  
    78  	pub, err := ssh.NewPublicKey(priv.Public())
    79  	if err != nil {
    80  		return nil, trace.Wrap(err)
    81  	}
    82  
    83  	nonce := make([]byte, 32)
    84  	if _, err = rand.Read(nonce); err != nil {
    85  		return nil, trace.Wrap(err)
    86  	}
    87  
    88  	cert := &ssh.Certificate{
    89  		Nonce:           nonce,
    90  		Key:             pub,
    91  		CertType:        ssh.HostCert,
    92  		SignatureKey:    signKey,
    93  		ValidPrincipals: principals,
    94  		ValidBefore:     uint64(time.Now().Add(time.Hour).Unix()),
    95  	}
    96  
    97  	// We cannot use ssh.Certificate SignCert method since we're intentionally
    98  	// setting invalid signature key to make a spoofed cert in some tests.
    99  	//
   100  	// When marshaling cert for signing, last 4 bytes containing trailing
   101  	// signature length are dropped:
   102  	//
   103  	// https://cs.opensource.google/go/x/crypto/+/32db7946:ssh/certs.go;l=456-462
   104  	bytesForSigning := cert.Marshal()
   105  	bytesForSigning = bytesForSigning[:len(bytesForSigning)-4]
   106  
   107  	cert.Signature, err = signer.Sign(rand.Reader, bytesForSigning)
   108  	if err != nil {
   109  		return nil, trace.Wrap(err)
   110  	}
   111  
   112  	certSigner, err := ssh.NewCertSigner(cert, privSigner)
   113  	if err != nil {
   114  		return nil, trace.Wrap(err)
   115  	}
   116  
   117  	return certSigner, nil
   118  }