github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/utils/sshutils/checker.go (about)

     1  /*
     2  Copyright 2019-2021 Gravitational, Inc.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package sshutils
    18  
    19  import (
    20  	"crypto/rsa"
    21  	"net"
    22  
    23  	"github.com/gravitational/trace"
    24  	"golang.org/x/crypto/ssh"
    25  
    26  	"github.com/gravitational/teleport/api/constants"
    27  )
    28  
    29  // CertChecker is a drop-in replacement for ssh.CertChecker. In FIPS mode,
    30  // checks if the certificate (or key) were generated with a supported algorithm.
    31  type CertChecker struct {
    32  	ssh.CertChecker
    33  
    34  	// FIPS means in addition to checking the validity of the key or
    35  	// certificate, also check that FIPS 140-2 algorithms were used.
    36  	FIPS bool
    37  
    38  	// OnCheckCert is called when validating host certificate.
    39  	OnCheckCert func(*ssh.Certificate) error
    40  }
    41  
    42  // Authenticate checks the validity of a user certificate.
    43  func (c *CertChecker) Authenticate(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
    44  	err := c.validateFIPS(key)
    45  	if err != nil {
    46  		return nil, trace.Wrap(err)
    47  	}
    48  
    49  	perms, err := c.CertChecker.Authenticate(conn, key)
    50  	if err != nil {
    51  		return nil, trace.Wrap(err)
    52  	}
    53  
    54  	return perms, nil
    55  }
    56  
    57  // CheckCert checks certificate metadata and signature.
    58  func (c *CertChecker) CheckCert(principal string, cert *ssh.Certificate) error {
    59  	err := c.validateFIPS(cert)
    60  	if err != nil {
    61  		return trace.Wrap(err)
    62  	}
    63  
    64  	err = c.CertChecker.CheckCert(principal, cert)
    65  	if err != nil {
    66  		return trace.Wrap(err)
    67  	}
    68  
    69  	if c.OnCheckCert != nil {
    70  		if err := c.OnCheckCert(cert); err != nil {
    71  			return trace.Wrap(err)
    72  		}
    73  	}
    74  
    75  	return nil
    76  }
    77  
    78  // CheckHostKey checks the validity of a host certificate.
    79  func (c *CertChecker) CheckHostKey(addr string, remote net.Addr, key ssh.PublicKey) error {
    80  	err := c.validateFIPS(key)
    81  	if err != nil {
    82  		return trace.Wrap(err)
    83  	}
    84  
    85  	err = c.CertChecker.CheckHostKey(addr, remote, key)
    86  	if err != nil {
    87  		return trace.Wrap(err)
    88  	}
    89  
    90  	if cert, ok := key.(*ssh.Certificate); ok && c.OnCheckCert != nil {
    91  		if err := c.OnCheckCert(cert); err != nil {
    92  			return trace.Wrap(err)
    93  		}
    94  	}
    95  
    96  	return nil
    97  }
    98  
    99  func (c *CertChecker) validateFIPS(key ssh.PublicKey) error {
   100  	// When not in FIPS mode, accept all algorithms and key sizes.
   101  	if !c.FIPS {
   102  		return nil
   103  	}
   104  
   105  	switch cert := key.(type) {
   106  	case *ssh.Certificate:
   107  		err := validateFIPSAlgorithm(cert.Key)
   108  		if err != nil {
   109  			return trace.Wrap(err)
   110  		}
   111  		err = validateFIPSAlgorithm(cert.SignatureKey)
   112  		if err != nil {
   113  			return trace.Wrap(err)
   114  		}
   115  		return nil
   116  	default:
   117  		return validateFIPSAlgorithm(key)
   118  	}
   119  }
   120  
   121  func validateFIPSAlgorithm(key ssh.PublicKey) error {
   122  	cryptoKey, ok := key.(ssh.CryptoPublicKey)
   123  	if !ok {
   124  		return trace.BadParameter("unable to determine underlying public key")
   125  	}
   126  	k, ok := cryptoKey.CryptoPublicKey().(*rsa.PublicKey)
   127  	if !ok {
   128  		return trace.BadParameter("only RSA keys supported")
   129  	}
   130  	if k.N.BitLen() != constants.RSAKeySize {
   131  		return trace.BadParameter("found %v-bit key, only %v-bit supported", k.N.BitLen(), constants.RSAKeySize)
   132  	}
   133  	return nil
   134  }