github.com/google/martian/v3@v3.3.3/mitm/mitm.go (about)

     1  // Copyright 2015 Google Inc. All rights reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package mitm provides tooling for MITMing TLS connections. It provides
    16  // tooling to create CA certs and generate TLS configs that can be used to MITM
    17  // a TLS connection with a provided CA certificate.
    18  package mitm
    19  
    20  import (
    21  	"bytes"
    22  	"crypto/rand"
    23  	"crypto/rsa"
    24  	"crypto/sha1"
    25  	"crypto/tls"
    26  	"crypto/x509"
    27  	"crypto/x509/pkix"
    28  	"errors"
    29  	"math/big"
    30  	"net"
    31  	"net/http"
    32  	"sync"
    33  	"time"
    34  
    35  	"github.com/google/martian/v3/h2"
    36  	"github.com/google/martian/v3/log"
    37  )
    38  
    39  // MaxSerialNumber is the upper boundary that is used to create unique serial
    40  // numbers for the certificate. This can be any unsigned integer up to 20
    41  // bytes (2^(8*20)-1).
    42  var MaxSerialNumber = big.NewInt(0).SetBytes(bytes.Repeat([]byte{255}, 20))
    43  
    44  // Config is a set of configuration values that are used to build TLS configs
    45  // capable of MITM.
    46  type Config struct {
    47  	ca                     *x509.Certificate
    48  	capriv                 interface{}
    49  	priv                   *rsa.PrivateKey
    50  	keyID                  []byte
    51  	validity               time.Duration
    52  	org                    string
    53  	h2Config               *h2.Config
    54  	getCertificate         func(*tls.ClientHelloInfo) (*tls.Certificate, error)
    55  	roots                  *x509.CertPool
    56  	skipVerify             bool
    57  	handshakeErrorCallback func(*http.Request, error)
    58  
    59  	certmu sync.RWMutex
    60  	certs  map[string]*tls.Certificate
    61  }
    62  
    63  // NewAuthority creates a new CA certificate and associated
    64  // private key.
    65  func NewAuthority(name, organization string, validity time.Duration) (*x509.Certificate, *rsa.PrivateKey, error) {
    66  	priv, err := rsa.GenerateKey(rand.Reader, 2048)
    67  	if err != nil {
    68  		return nil, nil, err
    69  	}
    70  	pub := priv.Public()
    71  
    72  	// Subject Key Identifier support for end entity certificate.
    73  	// https://www.ietf.org/rfc/rfc3280.txt (section 4.2.1.2)
    74  	pkixpub, err := x509.MarshalPKIXPublicKey(pub)
    75  	if err != nil {
    76  		return nil, nil, err
    77  	}
    78  	h := sha1.New()
    79  	h.Write(pkixpub)
    80  	keyID := h.Sum(nil)
    81  
    82  	// TODO: keep a map of used serial numbers to avoid potentially reusing a
    83  	// serial multiple times.
    84  	serial, err := rand.Int(rand.Reader, MaxSerialNumber)
    85  	if err != nil {
    86  		return nil, nil, err
    87  	}
    88  
    89  	tmpl := &x509.Certificate{
    90  		SerialNumber: serial,
    91  		Subject: pkix.Name{
    92  			CommonName:   name,
    93  			Organization: []string{organization},
    94  		},
    95  		SubjectKeyId:          keyID,
    96  		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
    97  		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
    98  		BasicConstraintsValid: true,
    99  		NotBefore:             time.Now().Add(-validity),
   100  		NotAfter:              time.Now().Add(validity),
   101  		DNSNames:              []string{name},
   102  		IsCA:                  true,
   103  	}
   104  
   105  	raw, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, pub, priv)
   106  	if err != nil {
   107  		return nil, nil, err
   108  	}
   109  
   110  	// Parse certificate bytes so that we have a leaf certificate.
   111  	x509c, err := x509.ParseCertificate(raw)
   112  	if err != nil {
   113  		return nil, nil, err
   114  	}
   115  
   116  	return x509c, priv, nil
   117  }
   118  
   119  // NewConfig creates a MITM config using the CA certificate and
   120  // private key to generate on-the-fly certificates.
   121  func NewConfig(ca *x509.Certificate, privateKey interface{}) (*Config, error) {
   122  	roots := x509.NewCertPool()
   123  	roots.AddCert(ca)
   124  
   125  	priv, err := rsa.GenerateKey(rand.Reader, 2048)
   126  	if err != nil {
   127  		return nil, err
   128  	}
   129  	pub := priv.Public()
   130  
   131  	// Subject Key Identifier support for end entity certificate.
   132  	// https://www.ietf.org/rfc/rfc3280.txt (section 4.2.1.2)
   133  	pkixpub, err := x509.MarshalPKIXPublicKey(pub)
   134  	if err != nil {
   135  		return nil, err
   136  	}
   137  	h := sha1.New()
   138  	h.Write(pkixpub)
   139  	keyID := h.Sum(nil)
   140  
   141  	return &Config{
   142  		ca:       ca,
   143  		capriv:   privateKey,
   144  		priv:     priv,
   145  		keyID:    keyID,
   146  		validity: time.Hour,
   147  		org:      "Martian Proxy",
   148  		certs:    make(map[string]*tls.Certificate),
   149  		roots:    roots,
   150  	}, nil
   151  }
   152  
   153  // SetValidity sets the validity window around the current time that the
   154  // certificate is valid for.
   155  func (c *Config) SetValidity(validity time.Duration) {
   156  	c.validity = validity
   157  }
   158  
   159  // SkipTLSVerify skips the TLS certification verification check.
   160  func (c *Config) SkipTLSVerify(skip bool) {
   161  	c.skipVerify = skip
   162  }
   163  
   164  // SetOrganization sets the organization of the certificate.
   165  func (c *Config) SetOrganization(org string) {
   166  	c.org = org
   167  }
   168  
   169  // SetH2Config configures processing of HTTP/2 streams.
   170  func (c *Config) SetH2Config(h2Config *h2.Config) {
   171  	c.h2Config = h2Config
   172  }
   173  
   174  // H2Config returns the current HTTP/2 configuration.
   175  func (c *Config) H2Config() *h2.Config {
   176  	return c.h2Config
   177  }
   178  
   179  // SetHandshakeErrorCallback sets the handshakeErrorCallback function.
   180  func (c *Config) SetHandshakeErrorCallback(cb func(*http.Request, error)) {
   181  	c.handshakeErrorCallback = cb
   182  }
   183  
   184  // HandshakeErrorCallback calls the handshakeErrorCallback function in this
   185  // Config, if it is non-nil. Request is the connect request that this handshake
   186  // is being executed through.
   187  func (c *Config) HandshakeErrorCallback(r *http.Request, err error) {
   188  	if c.handshakeErrorCallback != nil {
   189  		c.handshakeErrorCallback(r, err)
   190  	}
   191  }
   192  
   193  // TLS returns a *tls.Config that will generate certificates on-the-fly using
   194  // the SNI extension in the TLS ClientHello.
   195  func (c *Config) TLS() *tls.Config {
   196  	return &tls.Config{
   197  		InsecureSkipVerify: c.skipVerify,
   198  		GetCertificate: func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
   199  			if clientHello.ServerName == "" {
   200  				return nil, errors.New("mitm: SNI not provided, failed to build certificate")
   201  			}
   202  
   203  			return c.cert(clientHello.ServerName)
   204  		},
   205  		NextProtos: []string{"http/1.1"},
   206  	}
   207  }
   208  
   209  // TLSForHost returns a *tls.Config that will generate certificates on-the-fly
   210  // using SNI from the connection, or fall back to the provided hostname.
   211  func (c *Config) TLSForHost(hostname string) *tls.Config {
   212  	nextProtos := []string{"http/1.1"}
   213  	if c.h2AllowedHost(hostname) {
   214  		nextProtos = []string{"h2", "http/1.1"}
   215  	}
   216  	return &tls.Config{
   217  		InsecureSkipVerify: c.skipVerify,
   218  		GetCertificate: func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
   219  			host := clientHello.ServerName
   220  			if host == "" {
   221  				host = hostname
   222  			}
   223  
   224  			return c.cert(host)
   225  		},
   226  		NextProtos: nextProtos,
   227  	}
   228  }
   229  
   230  func (c *Config) h2AllowedHost(host string) bool {
   231  	return c.h2Config != nil &&
   232  		c.h2Config.AllowedHostsFilter != nil &&
   233  		c.h2Config.AllowedHostsFilter(host)
   234  }
   235  
   236  func (c *Config) cert(hostname string) (*tls.Certificate, error) {
   237  	// Remove the port if it exists.
   238  	host, _, err := net.SplitHostPort(hostname)
   239  	if err == nil {
   240  		hostname = host
   241  	}
   242  
   243  	c.certmu.RLock()
   244  	tlsc, ok := c.certs[hostname]
   245  	c.certmu.RUnlock()
   246  
   247  	if ok {
   248  		log.Debugf("mitm: cache hit for %s", hostname)
   249  
   250  		// Check validity of the certificate for hostname match, expiry, etc. In
   251  		// particular, if the cached certificate has expired, create a new one.
   252  		if _, err := tlsc.Leaf.Verify(x509.VerifyOptions{
   253  			DNSName: hostname,
   254  			Roots:   c.roots,
   255  		}); err == nil {
   256  			return tlsc, nil
   257  		}
   258  
   259  		log.Debugf("mitm: invalid certificate in cache for %s", hostname)
   260  	}
   261  
   262  	log.Debugf("mitm: cache miss for %s", hostname)
   263  
   264  	serial, err := rand.Int(rand.Reader, MaxSerialNumber)
   265  	if err != nil {
   266  		return nil, err
   267  	}
   268  
   269  	tmpl := &x509.Certificate{
   270  		SerialNumber: serial,
   271  		Subject: pkix.Name{
   272  			CommonName:   hostname,
   273  			Organization: []string{c.org},
   274  		},
   275  		SubjectKeyId:          c.keyID,
   276  		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
   277  		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
   278  		BasicConstraintsValid: true,
   279  		NotBefore:             time.Now().Add(-c.validity),
   280  		NotAfter:              time.Now().Add(c.validity),
   281  	}
   282  
   283  	if ip := net.ParseIP(hostname); ip != nil {
   284  		tmpl.IPAddresses = []net.IP{ip}
   285  	} else {
   286  		tmpl.DNSNames = []string{hostname}
   287  	}
   288  
   289  	raw, err := x509.CreateCertificate(rand.Reader, tmpl, c.ca, c.priv.Public(), c.capriv)
   290  	if err != nil {
   291  		return nil, err
   292  	}
   293  
   294  	// Parse certificate bytes so that we have a leaf certificate.
   295  	x509c, err := x509.ParseCertificate(raw)
   296  	if err != nil {
   297  		return nil, err
   298  	}
   299  
   300  	tlsc = &tls.Certificate{
   301  		Certificate: [][]byte{raw, c.ca.Raw},
   302  		PrivateKey:  c.priv,
   303  		Leaf:        x509c,
   304  	}
   305  
   306  	c.certmu.Lock()
   307  	c.certs[hostname] = tlsc
   308  	c.certmu.Unlock()
   309  
   310  	return tlsc, nil
   311  }