go.etcd.io/etcd@v3.3.27+incompatible/pkg/transport/listener_tls.go (about) 1 // Copyright 2017 The etcd Authors 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package transport 16 17 import ( 18 "context" 19 "crypto/tls" 20 "crypto/x509" 21 "fmt" 22 "io/ioutil" 23 "net" 24 "strings" 25 "sync" 26 ) 27 28 // tlsListener overrides a TLS listener so it will reject client 29 // certificates with insufficient SAN credentials or CRL revoked 30 // certificates. 31 type tlsListener struct { 32 net.Listener 33 connc chan net.Conn 34 donec chan struct{} 35 err error 36 handshakeFailure func(*tls.Conn, error) 37 check tlsCheckFunc 38 } 39 40 type tlsCheckFunc func(context.Context, *tls.Conn) error 41 42 // NewTLSListener handshakes TLS connections and performs optional CRL checking. 43 func NewTLSListener(l net.Listener, tlsinfo *TLSInfo) (net.Listener, error) { 44 check := func(context.Context, *tls.Conn) error { return nil } 45 return newTLSListener(l, tlsinfo, check) 46 } 47 48 func newTLSListener(l net.Listener, tlsinfo *TLSInfo, check tlsCheckFunc) (net.Listener, error) { 49 if tlsinfo == nil || tlsinfo.Empty() { 50 l.Close() 51 return nil, fmt.Errorf("cannot listen on TLS for %s: KeyFile and CertFile are not presented", l.Addr().String()) 52 } 53 tlscfg, err := tlsinfo.ServerConfig() 54 if err != nil { 55 return nil, err 56 } 57 58 hf := tlsinfo.HandshakeFailure 59 if hf == nil { 60 hf = func(*tls.Conn, error) {} 61 } 62 63 if len(tlsinfo.CRLFile) > 0 { 64 prevCheck := check 65 check = func(ctx context.Context, tlsConn *tls.Conn) error { 66 if err := prevCheck(ctx, tlsConn); err != nil { 67 return err 68 } 69 st := tlsConn.ConnectionState() 70 if certs := st.PeerCertificates; len(certs) > 0 { 71 return checkCRL(tlsinfo.CRLFile, certs) 72 } 73 return nil 74 } 75 } 76 77 tlsl := &tlsListener{ 78 Listener: tls.NewListener(l, tlscfg), 79 connc: make(chan net.Conn), 80 donec: make(chan struct{}), 81 handshakeFailure: hf, 82 check: check, 83 } 84 go tlsl.acceptLoop() 85 return tlsl, nil 86 } 87 88 func (l *tlsListener) Accept() (net.Conn, error) { 89 select { 90 case conn := <-l.connc: 91 return conn, nil 92 case <-l.donec: 93 return nil, l.err 94 } 95 } 96 97 func checkSAN(ctx context.Context, tlsConn *tls.Conn) error { 98 st := tlsConn.ConnectionState() 99 if certs := st.PeerCertificates; len(certs) > 0 { 100 addr := tlsConn.RemoteAddr().String() 101 return checkCertSAN(ctx, certs[0], addr) 102 } 103 return nil 104 } 105 106 // acceptLoop launches each TLS handshake in a separate goroutine 107 // to prevent a hanging TLS connection from blocking other connections. 108 func (l *tlsListener) acceptLoop() { 109 var wg sync.WaitGroup 110 var pendingMu sync.Mutex 111 112 pending := make(map[net.Conn]struct{}) 113 ctx, cancel := context.WithCancel(context.Background()) 114 defer func() { 115 cancel() 116 pendingMu.Lock() 117 for c := range pending { 118 c.Close() 119 } 120 pendingMu.Unlock() 121 wg.Wait() 122 close(l.donec) 123 }() 124 125 for { 126 conn, err := l.Listener.Accept() 127 if err != nil { 128 l.err = err 129 return 130 } 131 132 pendingMu.Lock() 133 pending[conn] = struct{}{} 134 pendingMu.Unlock() 135 136 wg.Add(1) 137 go func() { 138 defer func() { 139 if conn != nil { 140 conn.Close() 141 } 142 wg.Done() 143 }() 144 145 tlsConn := conn.(*tls.Conn) 146 herr := tlsConn.Handshake() 147 pendingMu.Lock() 148 delete(pending, conn) 149 pendingMu.Unlock() 150 151 if herr != nil { 152 l.handshakeFailure(tlsConn, herr) 153 return 154 } 155 if err := l.check(ctx, tlsConn); err != nil { 156 l.handshakeFailure(tlsConn, err) 157 return 158 } 159 160 select { 161 case l.connc <- tlsConn: 162 conn = nil 163 case <-ctx.Done(): 164 } 165 }() 166 } 167 } 168 169 func checkCRL(crlPath string, cert []*x509.Certificate) error { 170 // TODO: cache 171 crlBytes, err := ioutil.ReadFile(crlPath) 172 if err != nil { 173 return err 174 } 175 certList, err := x509.ParseCRL(crlBytes) 176 if err != nil { 177 return err 178 } 179 revokedSerials := make(map[string]struct{}) 180 for _, rc := range certList.TBSCertList.RevokedCertificates { 181 revokedSerials[string(rc.SerialNumber.Bytes())] = struct{}{} 182 } 183 for _, c := range cert { 184 serial := string(c.SerialNumber.Bytes()) 185 if _, ok := revokedSerials[serial]; ok { 186 return fmt.Errorf("transport: certificate serial %x revoked", serial) 187 } 188 } 189 return nil 190 } 191 192 func checkCertSAN(ctx context.Context, cert *x509.Certificate, remoteAddr string) error { 193 if len(cert.IPAddresses) == 0 && len(cert.DNSNames) == 0 { 194 return nil 195 } 196 h, _, herr := net.SplitHostPort(remoteAddr) 197 if herr != nil { 198 return herr 199 } 200 if len(cert.IPAddresses) > 0 { 201 cerr := cert.VerifyHostname(h) 202 if cerr == nil { 203 return nil 204 } 205 if len(cert.DNSNames) == 0 { 206 return cerr 207 } 208 } 209 if len(cert.DNSNames) > 0 { 210 ok, err := isHostInDNS(ctx, h, cert.DNSNames) 211 if ok { 212 return nil 213 } 214 errStr := "" 215 if err != nil { 216 errStr = " (" + err.Error() + ")" 217 } 218 return fmt.Errorf("tls: %q does not match any of DNSNames %q"+errStr, h, cert.DNSNames) 219 } 220 return nil 221 } 222 223 func isHostInDNS(ctx context.Context, host string, dnsNames []string) (ok bool, err error) { 224 // reverse lookup 225 wildcards, names := []string{}, []string{} 226 for _, dns := range dnsNames { 227 if strings.HasPrefix(dns, "*.") { 228 wildcards = append(wildcards, dns[1:]) 229 } else { 230 names = append(names, dns) 231 } 232 } 233 lnames, lerr := net.DefaultResolver.LookupAddr(ctx, host) 234 for _, name := range lnames { 235 // strip trailing '.' from PTR record 236 if name[len(name)-1] == '.' { 237 name = name[:len(name)-1] 238 } 239 for _, wc := range wildcards { 240 if strings.HasSuffix(name, wc) { 241 return true, nil 242 } 243 } 244 for _, n := range names { 245 if n == name { 246 return true, nil 247 } 248 } 249 } 250 err = lerr 251 252 // forward lookup 253 for _, dns := range names { 254 addrs, lerr := net.DefaultResolver.LookupHost(ctx, dns) 255 if lerr != nil { 256 err = lerr 257 continue 258 } 259 for _, addr := range addrs { 260 if addr == host { 261 return true, nil 262 } 263 } 264 } 265 return false, err 266 } 267 268 func (l *tlsListener) Close() error { 269 err := l.Listener.Close() 270 <-l.donec 271 return err 272 }