github.com/kaisenlinux/docker.io@v0.0.0-20230510090727-ea55db55fac7/swarmkit/ca/transport.go (about)

     1  package ca
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"crypto/x509"
     7  	"crypto/x509/pkix"
     8  	"net"
     9  	"strings"
    10  	"sync"
    11  
    12  	"github.com/pkg/errors"
    13  	"google.golang.org/grpc/credentials"
    14  )
    15  
    16  var (
    17  	// alpnProtoStr is the specified application level protocols for gRPC.
    18  	alpnProtoStr = []string{"h2"}
    19  )
    20  
    21  // MutableTLSCreds is the credentials required for authenticating a connection using TLS.
    22  type MutableTLSCreds struct {
    23  	// Mutex for the tls config
    24  	sync.Mutex
    25  	// TLS configuration
    26  	config *tls.Config
    27  	// TLS Credentials
    28  	tlsCreds credentials.TransportCredentials
    29  	// store the subject for easy access
    30  	subject pkix.Name
    31  }
    32  
    33  // Info implements the credentials.TransportCredentials interface
    34  func (c *MutableTLSCreds) Info() credentials.ProtocolInfo {
    35  	return credentials.ProtocolInfo{
    36  		SecurityProtocol: "tls",
    37  		SecurityVersion:  "1.2",
    38  	}
    39  }
    40  
    41  // Clone returns new MutableTLSCreds created from underlying *tls.Config.
    42  // It panics if validation of underlying config fails.
    43  func (c *MutableTLSCreds) Clone() credentials.TransportCredentials {
    44  	c.Lock()
    45  	newCfg, err := NewMutableTLS(c.config.Clone())
    46  	if err != nil {
    47  		panic("validation error on Clone")
    48  	}
    49  	c.Unlock()
    50  	return newCfg
    51  }
    52  
    53  // OverrideServerName overrides *tls.Config.ServerName.
    54  func (c *MutableTLSCreds) OverrideServerName(name string) error {
    55  	c.Lock()
    56  	c.config.ServerName = name
    57  	c.Unlock()
    58  	return nil
    59  }
    60  
    61  // GetRequestMetadata implements the credentials.TransportCredentials interface
    62  func (c *MutableTLSCreds) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
    63  	return nil, nil
    64  }
    65  
    66  // RequireTransportSecurity implements the credentials.TransportCredentials interface
    67  func (c *MutableTLSCreds) RequireTransportSecurity() bool {
    68  	return true
    69  }
    70  
    71  // ClientHandshake implements the credentials.TransportCredentials interface
    72  func (c *MutableTLSCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
    73  	// borrow all the code from the original TLS credentials
    74  	c.Lock()
    75  	if c.config.ServerName == "" {
    76  		colonPos := strings.LastIndex(addr, ":")
    77  		if colonPos == -1 {
    78  			colonPos = len(addr)
    79  		}
    80  		c.config.ServerName = addr[:colonPos]
    81  	}
    82  
    83  	conn := tls.Client(rawConn, c.config)
    84  	// Need to allow conn.Handshake to have access to config,
    85  	// would create a deadlock otherwise
    86  	c.Unlock()
    87  	var err error
    88  	errChannel := make(chan error, 1)
    89  	go func() {
    90  		errChannel <- conn.Handshake()
    91  	}()
    92  	select {
    93  	case err = <-errChannel:
    94  	case <-ctx.Done():
    95  		err = ctx.Err()
    96  	}
    97  	if err != nil {
    98  		rawConn.Close()
    99  		return nil, nil, err
   100  	}
   101  	return conn, nil, nil
   102  }
   103  
   104  // ServerHandshake implements the credentials.TransportCredentials interface
   105  func (c *MutableTLSCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
   106  	c.Lock()
   107  	conn := tls.Server(rawConn, c.config)
   108  	c.Unlock()
   109  	if err := conn.Handshake(); err != nil {
   110  		rawConn.Close()
   111  		return nil, nil, err
   112  	}
   113  
   114  	return conn, credentials.TLSInfo{State: conn.ConnectionState()}, nil
   115  }
   116  
   117  // loadNewTLSConfig replaces the currently loaded TLS config with a new one
   118  func (c *MutableTLSCreds) loadNewTLSConfig(newConfig *tls.Config) error {
   119  	newSubject, err := GetAndValidateCertificateSubject(newConfig.Certificates)
   120  	if err != nil {
   121  		return err
   122  	}
   123  
   124  	c.Lock()
   125  	defer c.Unlock()
   126  	c.subject = newSubject
   127  	c.config = newConfig
   128  
   129  	return nil
   130  }
   131  
   132  // Config returns the current underlying TLS config.
   133  func (c *MutableTLSCreds) Config() *tls.Config {
   134  	c.Lock()
   135  	defer c.Unlock()
   136  
   137  	return c.config
   138  }
   139  
   140  // Role returns the OU for the certificate encapsulated in this TransportCredentials
   141  func (c *MutableTLSCreds) Role() string {
   142  	c.Lock()
   143  	defer c.Unlock()
   144  
   145  	return c.subject.OrganizationalUnit[0]
   146  }
   147  
   148  // Organization returns the O for the certificate encapsulated in this TransportCredentials
   149  func (c *MutableTLSCreds) Organization() string {
   150  	c.Lock()
   151  	defer c.Unlock()
   152  
   153  	return c.subject.Organization[0]
   154  }
   155  
   156  // NodeID returns the CN for the certificate encapsulated in this TransportCredentials
   157  func (c *MutableTLSCreds) NodeID() string {
   158  	c.Lock()
   159  	defer c.Unlock()
   160  
   161  	return c.subject.CommonName
   162  }
   163  
   164  // NewMutableTLS uses c to construct a mutable TransportCredentials based on TLS.
   165  func NewMutableTLS(c *tls.Config) (*MutableTLSCreds, error) {
   166  	originalTC := credentials.NewTLS(c)
   167  
   168  	if len(c.Certificates) < 1 {
   169  		return nil, errors.New("invalid configuration: needs at least one certificate")
   170  	}
   171  
   172  	subject, err := GetAndValidateCertificateSubject(c.Certificates)
   173  	if err != nil {
   174  		return nil, err
   175  	}
   176  
   177  	tc := &MutableTLSCreds{config: c, tlsCreds: originalTC, subject: subject}
   178  	tc.config.NextProtos = alpnProtoStr
   179  
   180  	return tc, nil
   181  }
   182  
   183  // GetAndValidateCertificateSubject is a helper method to retrieve and validate the subject
   184  // from the x509 certificate underlying a tls.Certificate
   185  func GetAndValidateCertificateSubject(certs []tls.Certificate) (pkix.Name, error) {
   186  	for i := range certs {
   187  		cert := &certs[i]
   188  		x509Cert, err := x509.ParseCertificate(cert.Certificate[0])
   189  		if err != nil {
   190  			continue
   191  		}
   192  		if len(x509Cert.Subject.OrganizationalUnit) < 1 {
   193  			return pkix.Name{}, errors.New("no OU found in certificate subject")
   194  		}
   195  
   196  		if len(x509Cert.Subject.Organization) < 1 {
   197  			return pkix.Name{}, errors.New("no organization found in certificate subject")
   198  		}
   199  		if x509Cert.Subject.CommonName == "" {
   200  			return pkix.Name{}, errors.New("no valid subject names found for TLS configuration")
   201  		}
   202  
   203  		return x509Cert.Subject, nil
   204  	}
   205  
   206  	return pkix.Name{}, errors.New("no valid certificates found for TLS configuration")
   207  }