github.com/google/martian/v3@v3.3.3/mitm/mitm_test.go (about)

     1  // Copyright 2015 Google Inc. All rights reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package mitm
    16  
    17  import (
    18  	"crypto/tls"
    19  	"crypto/x509"
    20  	"net"
    21  	"reflect"
    22  	"testing"
    23  	"time"
    24  )
    25  
    26  func TestMITM(t *testing.T) {
    27  	ca, priv, err := NewAuthority("martian.proxy", "Martian Authority", 24*time.Hour)
    28  	if err != nil {
    29  		t.Fatalf("NewAuthority(): got %v, want no error", err)
    30  	}
    31  
    32  	c, err := NewConfig(ca, priv)
    33  	if err != nil {
    34  		t.Fatalf("NewConfig(): got %v, want no error", err)
    35  	}
    36  
    37  	c.SetValidity(20 * time.Hour)
    38  	c.SetOrganization("Test Organization")
    39  
    40  	protos := []string{"http/1.1"}
    41  
    42  	conf := c.TLS()
    43  	if got := conf.NextProtos; !reflect.DeepEqual(got, protos) {
    44  		t.Errorf("conf.NextProtos: got %v, want %v", got, protos)
    45  	}
    46  	if conf.InsecureSkipVerify {
    47  		t.Error("conf.InsecureSkipVerify: got true, want false")
    48  	}
    49  
    50  	// Simulate a TLS connection without SNI.
    51  	clientHello := &tls.ClientHelloInfo{
    52  		ServerName: "",
    53  	}
    54  
    55  	if _, err := conf.GetCertificate(clientHello); err == nil {
    56  		t.Fatal("conf.GetCertificate(): got nil, want error")
    57  	}
    58  
    59  	// Simulate a TLS connection with SNI.
    60  	clientHello.ServerName = "example.com"
    61  
    62  	tlsc, err := conf.GetCertificate(clientHello)
    63  	if err != nil {
    64  		t.Fatalf("conf.GetCertificate(): got %v, want no error", err)
    65  	}
    66  
    67  	x509c := tlsc.Leaf
    68  	if got, want := x509c.Subject.CommonName, "example.com"; got != want {
    69  		t.Errorf("x509c.Subject.CommonName: got %q, want %q", got, want)
    70  	}
    71  
    72  	c.SkipTLSVerify(true)
    73  
    74  	conf = c.TLSForHost("example.com")
    75  	if got := conf.NextProtos; !reflect.DeepEqual(got, protos) {
    76  		t.Errorf("conf.NextProtos: got %v, want %v", got, protos)
    77  	}
    78  	if !conf.InsecureSkipVerify {
    79  		t.Error("conf.InsecureSkipVerify: got false, want true")
    80  	}
    81  
    82  	// Set SNI, takes precedence over host.
    83  	clientHello.ServerName = "google.com"
    84  	tlsc, err = conf.GetCertificate(clientHello)
    85  	if err != nil {
    86  		t.Fatalf("conf.GetCertificate(): got %v, want no error", err)
    87  	}
    88  
    89  	x509c = tlsc.Leaf
    90  	if got, want := x509c.Subject.CommonName, "google.com"; got != want {
    91  		t.Errorf("x509c.Subject.CommonName: got %q, want %q", got, want)
    92  	}
    93  
    94  	// Reset SNI to fallback to hostname.
    95  	clientHello.ServerName = ""
    96  	tlsc, err = conf.GetCertificate(clientHello)
    97  	if err != nil {
    98  		t.Fatalf("conf.GetCertificate(): got %v, want no error", err)
    99  	}
   100  
   101  	x509c = tlsc.Leaf
   102  	if got, want := x509c.Subject.CommonName, "example.com"; got != want {
   103  		t.Errorf("x509c.Subject.CommonName: got %q, want %q", got, want)
   104  	}
   105  }
   106  
   107  func TestCert(t *testing.T) {
   108  	ca, priv, err := NewAuthority("martian.proxy", "Martian Authority", 24*time.Hour)
   109  	if err != nil {
   110  		t.Fatalf("NewAuthority(): got %v, want no error", err)
   111  	}
   112  
   113  	c, err := NewConfig(ca, priv)
   114  	if err != nil {
   115  		t.Fatalf("NewConfig(): got %v, want no error", err)
   116  	}
   117  
   118  	tlsc, err := c.cert("example.com")
   119  	if err != nil {
   120  		t.Fatalf("c.cert(%q): got %v, want no error", "example.com:8080", err)
   121  	}
   122  
   123  	if tlsc.Certificate == nil {
   124  		t.Error("tlsc.Certificate: got nil, want certificate bytes")
   125  	}
   126  	if tlsc.PrivateKey == nil {
   127  		t.Error("tlsc.PrivateKey: got nil, want private key")
   128  	}
   129  
   130  	x509c := tlsc.Leaf
   131  	if x509c == nil {
   132  		t.Fatal("x509c: got nil, want *x509.Certificate")
   133  	}
   134  
   135  	if got := x509c.SerialNumber; got.Cmp(MaxSerialNumber) >= 0 {
   136  		t.Errorf("x509c.SerialNumber: got %v, want <= MaxSerialNumber", got)
   137  	}
   138  	if got, want := x509c.Subject.CommonName, "example.com"; got != want {
   139  		t.Errorf("X509c.Subject.CommonName: got %q, want %q", got, want)
   140  	}
   141  	if err := x509c.VerifyHostname("example.com"); err != nil {
   142  		t.Errorf("x509c.VerifyHostname(%q): got %v, want no error", "example.com", err)
   143  	}
   144  
   145  	if got, want := x509c.Subject.Organization, []string{"Martian Proxy"}; !reflect.DeepEqual(got, want) {
   146  		t.Errorf("x509c.Subject.Organization: got %v, want %v", got, want)
   147  	}
   148  
   149  	if got := x509c.SubjectKeyId; got == nil {
   150  		t.Error("x509c.SubjectKeyId: got nothing, want key ID")
   151  	}
   152  	if !x509c.BasicConstraintsValid {
   153  		t.Error("x509c.BasicConstraintsValid: got false, want true")
   154  	}
   155  
   156  	if got, want := x509c.KeyUsage, x509.KeyUsageKeyEncipherment; got&want == 0 {
   157  		t.Error("x509c.KeyUsage: got nothing, want to include x509.KeyUsageKeyEncipherment")
   158  	}
   159  	if got, want := x509c.KeyUsage, x509.KeyUsageDigitalSignature; got&want == 0 {
   160  		t.Error("x509c.KeyUsage: got nothing, want to include x509.KeyUsageDigitalSignature")
   161  	}
   162  
   163  	want := []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}
   164  	if got := x509c.ExtKeyUsage; !reflect.DeepEqual(got, want) {
   165  		t.Errorf("x509c.ExtKeyUsage: got %v, want %v", got, want)
   166  	}
   167  
   168  	if got, want := x509c.DNSNames, []string{"example.com"}; !reflect.DeepEqual(got, want) {
   169  		t.Errorf("x509c.DNSNames: got %v, want %v", got, want)
   170  	}
   171  
   172  	before := time.Now().Add(-2 * time.Hour)
   173  	if got := x509c.NotBefore; before.After(got) {
   174  		t.Errorf("x509c.NotBefore: got %v, want after %v", got, before)
   175  	}
   176  
   177  	after := time.Now().Add(2 * time.Hour)
   178  	if got := x509c.NotAfter; !after.After(got) {
   179  		t.Errorf("x509c.NotAfter: got %v, want before %v", got, want)
   180  	}
   181  
   182  	// Retrieve cached certificate.
   183  	tlsc2, err := c.cert("example.com")
   184  	if err != nil {
   185  		t.Fatalf("c.cert(%q): got %v, want no error", "example.com", err)
   186  	}
   187  	if tlsc != tlsc2 {
   188  		t.Error("tlsc2: got new certificate, want cached certificate")
   189  	}
   190  
   191  	// TLS certificate for IP.
   192  	tlsc, err = c.cert("10.0.0.1:8227")
   193  	if err != nil {
   194  		t.Fatalf("c.cert(%q): got %v, want no error", "10.0.0.1:8227", err)
   195  	}
   196  	x509c = tlsc.Leaf
   197  
   198  	if got, want := len(x509c.IPAddresses), 1; got != want {
   199  		t.Fatalf("len(x509c.IPAddresses): got %d, want %d", got, want)
   200  	}
   201  
   202  	if got, want := x509c.IPAddresses[0], net.ParseIP("10.0.0.1"); !got.Equal(want) {
   203  		t.Fatalf("x509c.IPAddresses: got %v, want %v", got, want)
   204  	}
   205  }