go.etcd.io/etcd@v3.3.27+incompatible/pkg/transport/listener_tls.go (about)

     1  // Copyright 2017 The etcd Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package transport
    16  
    17  import (
    18  	"context"
    19  	"crypto/tls"
    20  	"crypto/x509"
    21  	"fmt"
    22  	"io/ioutil"
    23  	"net"
    24  	"strings"
    25  	"sync"
    26  )
    27  
// tlsListener overrides a TLS listener: acceptLoop completes each accepted
// connection's TLS handshake in its own goroutine and runs check on it, so
// that only connections that pass (for example a SAN check or a CRL
// revocation check) are returned from Accept.
type tlsListener struct {
	// Listener is the TLS listener (built with tls.NewListener) that
	// acceptLoop pulls raw connections from.
	net.Listener
	// connc hands handshaked, checked connections from acceptLoop to Accept.
	connc            chan net.Conn
	// donec is closed once acceptLoop has fully shut down.
	donec            chan struct{}
	// err is the error that ended acceptLoop; Accept returns it after donec
	// is closed.
	err              error
	// handshakeFailure is called with the connection and the error when the
	// handshake or check fails.
	handshakeFailure func(*tls.Conn, error)
	// check is the extra verification run after a successful handshake.
	check            tlsCheckFunc
}
    39  
// tlsCheckFunc is an extra verification run on a connection after its TLS
// handshake has succeeded. A non-nil error rejects the connection. The
// context is cancelled when the listener's accept loop shuts down.
type tlsCheckFunc func(context.Context, *tls.Conn) error
    41  
    42  // NewTLSListener handshakes TLS connections and performs optional CRL checking.
    43  func NewTLSListener(l net.Listener, tlsinfo *TLSInfo) (net.Listener, error) {
    44  	check := func(context.Context, *tls.Conn) error { return nil }
    45  	return newTLSListener(l, tlsinfo, check)
    46  }
    47  
    48  func newTLSListener(l net.Listener, tlsinfo *TLSInfo, check tlsCheckFunc) (net.Listener, error) {
    49  	if tlsinfo == nil || tlsinfo.Empty() {
    50  		l.Close()
    51  		return nil, fmt.Errorf("cannot listen on TLS for %s: KeyFile and CertFile are not presented", l.Addr().String())
    52  	}
    53  	tlscfg, err := tlsinfo.ServerConfig()
    54  	if err != nil {
    55  		return nil, err
    56  	}
    57  
    58  	hf := tlsinfo.HandshakeFailure
    59  	if hf == nil {
    60  		hf = func(*tls.Conn, error) {}
    61  	}
    62  
    63  	if len(tlsinfo.CRLFile) > 0 {
    64  		prevCheck := check
    65  		check = func(ctx context.Context, tlsConn *tls.Conn) error {
    66  			if err := prevCheck(ctx, tlsConn); err != nil {
    67  				return err
    68  			}
    69  			st := tlsConn.ConnectionState()
    70  			if certs := st.PeerCertificates; len(certs) > 0 {
    71  				return checkCRL(tlsinfo.CRLFile, certs)
    72  			}
    73  			return nil
    74  		}
    75  	}
    76  
    77  	tlsl := &tlsListener{
    78  		Listener:         tls.NewListener(l, tlscfg),
    79  		connc:            make(chan net.Conn),
    80  		donec:            make(chan struct{}),
    81  		handshakeFailure: hf,
    82  		check:            check,
    83  	}
    84  	go tlsl.acceptLoop()
    85  	return tlsl, nil
    86  }
    87  
    88  func (l *tlsListener) Accept() (net.Conn, error) {
    89  	select {
    90  	case conn := <-l.connc:
    91  		return conn, nil
    92  	case <-l.donec:
    93  		return nil, l.err
    94  	}
    95  }
    96  
    97  func checkSAN(ctx context.Context, tlsConn *tls.Conn) error {
    98  	st := tlsConn.ConnectionState()
    99  	if certs := st.PeerCertificates; len(certs) > 0 {
   100  		addr := tlsConn.RemoteAddr().String()
   101  		return checkCertSAN(ctx, certs[0], addr)
   102  	}
   103  	return nil
   104  }
   105  
// acceptLoop launches each TLS handshake in a separate goroutine
// to prevent a hanging TLS connection from blocking other connections.
//
// Connections whose handshake succeeds and that pass l.check are handed to
// Accept over l.connc. On a handshake or check failure, l.handshakeFailure is
// called and the connection is closed.
//
// The loop runs until the underlying Accept returns an error, which is stored
// in l.err. NOTE(review): any such error, including a possibly transient one,
// ends the loop for good.
//
// On exit, the deferred cleanup:
//   - cancels the context passed to checks and deliveries;
//   - closes connections still mid-handshake;
//   - waits for every handshake goroutine;
//   - finally closes l.donec so Accept and Close can observe the shutdown.
func (l *tlsListener) acceptLoop() {
	var wg sync.WaitGroup
	var pendingMu sync.Mutex

	// pending holds connections whose handshake has not finished yet, so that
	// shutdown can close them and let their Handshake calls return.
	pending := make(map[net.Conn]struct{})
	ctx, cancel := context.WithCancel(context.Background())
	defer func() {
		cancel()
		pendingMu.Lock()
		for c := range pending {
			c.Close()
		}
		pendingMu.Unlock()
		wg.Wait()
		// Closed last: l.err is already set and no goroutine is still running.
		close(l.donec)
	}()

	for {
		conn, err := l.Listener.Accept()
		if err != nil {
			l.err = err
			return
		}

		pendingMu.Lock()
		pending[conn] = struct{}{}
		pendingMu.Unlock()

		wg.Add(1)
		go func() {
			// conn is closed on every exit path except a successful hand-off
			// to Accept, which sets it to nil below.
			defer func() {
				if conn != nil {
					conn.Close()
				}
				wg.Done()
			}()

			// l.Listener is created with tls.NewListener, so accepted
			// connections are *tls.Conn.
			tlsConn := conn.(*tls.Conn)
			herr := tlsConn.Handshake()
			// The handshake is over (successfully or not), so shutdown no
			// longer needs to force-close this connection.
			pendingMu.Lock()
			delete(pending, conn)
			pendingMu.Unlock()

			if herr != nil {
				l.handshakeFailure(tlsConn, herr)
				return
			}
			if err := l.check(ctx, tlsConn); err != nil {
				l.handshakeFailure(tlsConn, err)
				return
			}

			// Wait until Accept takes the connection or the loop shuts down.
			select {
			case l.connc <- tlsConn:
				conn = nil // ownership passed to the Accept caller
			case <-ctx.Done():
			}
		}()
	}
}
   168  
   169  func checkCRL(crlPath string, cert []*x509.Certificate) error {
   170  	// TODO: cache
   171  	crlBytes, err := ioutil.ReadFile(crlPath)
   172  	if err != nil {
   173  		return err
   174  	}
   175  	certList, err := x509.ParseCRL(crlBytes)
   176  	if err != nil {
   177  		return err
   178  	}
   179  	revokedSerials := make(map[string]struct{})
   180  	for _, rc := range certList.TBSCertList.RevokedCertificates {
   181  		revokedSerials[string(rc.SerialNumber.Bytes())] = struct{}{}
   182  	}
   183  	for _, c := range cert {
   184  		serial := string(c.SerialNumber.Bytes())
   185  		if _, ok := revokedSerials[serial]; ok {
   186  			return fmt.Errorf("transport: certificate serial %x revoked", serial)
   187  		}
   188  	}
   189  	return nil
   190  }
   191  
   192  func checkCertSAN(ctx context.Context, cert *x509.Certificate, remoteAddr string) error {
   193  	if len(cert.IPAddresses) == 0 && len(cert.DNSNames) == 0 {
   194  		return nil
   195  	}
   196  	h, _, herr := net.SplitHostPort(remoteAddr)
   197  	if herr != nil {
   198  		return herr
   199  	}
   200  	if len(cert.IPAddresses) > 0 {
   201  		cerr := cert.VerifyHostname(h)
   202  		if cerr == nil {
   203  			return nil
   204  		}
   205  		if len(cert.DNSNames) == 0 {
   206  			return cerr
   207  		}
   208  	}
   209  	if len(cert.DNSNames) > 0 {
   210  		ok, err := isHostInDNS(ctx, h, cert.DNSNames)
   211  		if ok {
   212  			return nil
   213  		}
   214  		errStr := ""
   215  		if err != nil {
   216  			errStr = " (" + err.Error() + ")"
   217  		}
   218  		return fmt.Errorf("tls: %q does not match any of DNSNames %q"+errStr, h, cert.DNSNames)
   219  	}
   220  	return nil
   221  }
   222  
   223  func isHostInDNS(ctx context.Context, host string, dnsNames []string) (ok bool, err error) {
   224  	// reverse lookup
   225  	wildcards, names := []string{}, []string{}
   226  	for _, dns := range dnsNames {
   227  		if strings.HasPrefix(dns, "*.") {
   228  			wildcards = append(wildcards, dns[1:])
   229  		} else {
   230  			names = append(names, dns)
   231  		}
   232  	}
   233  	lnames, lerr := net.DefaultResolver.LookupAddr(ctx, host)
   234  	for _, name := range lnames {
   235  		// strip trailing '.' from PTR record
   236  		if name[len(name)-1] == '.' {
   237  			name = name[:len(name)-1]
   238  		}
   239  		for _, wc := range wildcards {
   240  			if strings.HasSuffix(name, wc) {
   241  				return true, nil
   242  			}
   243  		}
   244  		for _, n := range names {
   245  			if n == name {
   246  				return true, nil
   247  			}
   248  		}
   249  	}
   250  	err = lerr
   251  
   252  	// forward lookup
   253  	for _, dns := range names {
   254  		addrs, lerr := net.DefaultResolver.LookupHost(ctx, dns)
   255  		if lerr != nil {
   256  			err = lerr
   257  			continue
   258  		}
   259  		for _, addr := range addrs {
   260  			if addr == host {
   261  				return true, nil
   262  			}
   263  		}
   264  	}
   265  	return false, err
   266  }
   267  
   268  func (l *tlsListener) Close() error {
   269  	err := l.Listener.Close()
   270  	<-l.donec
   271  	return err
   272  }