github.com/google/martian/v3@v3.3.3/mitm/mitm_test.go (about) 1 // Copyright 2015 Google Inc. All rights reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package mitm 16 17 import ( 18 "crypto/tls" 19 "crypto/x509" 20 "net" 21 "reflect" 22 "testing" 23 "time" 24 ) 25 26 func TestMITM(t *testing.T) { 27 ca, priv, err := NewAuthority("martian.proxy", "Martian Authority", 24*time.Hour) 28 if err != nil { 29 t.Fatalf("NewAuthority(): got %v, want no error", err) 30 } 31 32 c, err := NewConfig(ca, priv) 33 if err != nil { 34 t.Fatalf("NewConfig(): got %v, want no error", err) 35 } 36 37 c.SetValidity(20 * time.Hour) 38 c.SetOrganization("Test Organization") 39 40 protos := []string{"http/1.1"} 41 42 conf := c.TLS() 43 if got := conf.NextProtos; !reflect.DeepEqual(got, protos) { 44 t.Errorf("conf.NextProtos: got %v, want %v", got, protos) 45 } 46 if conf.InsecureSkipVerify { 47 t.Error("conf.InsecureSkipVerify: got true, want false") 48 } 49 50 // Simulate a TLS connection without SNI. 51 clientHello := &tls.ClientHelloInfo{ 52 ServerName: "", 53 } 54 55 if _, err := conf.GetCertificate(clientHello); err == nil { 56 t.Fatal("conf.GetCertificate(): got nil, want error") 57 } 58 59 // Simulate a TLS connection with SNI. 60 clientHello.ServerName = "example.com" 61 62 tlsc, err := conf.GetCertificate(clientHello) 63 if err != nil { 64 t.Fatalf("conf.GetCertificate(): got %v, want no error", err) 65 } 66 67 x509c := tlsc.Leaf 68 if got, want := x509c.Subject.CommonName, "example.com"; got != want { 69 t.Errorf("x509c.Subject.CommonName: got %q, want %q", got, want) 70 } 71 72 c.SkipTLSVerify(true) 73 74 conf = c.TLSForHost("example.com") 75 if got := conf.NextProtos; !reflect.DeepEqual(got, protos) { 76 t.Errorf("conf.NextProtos: got %v, want %v", got, protos) 77 } 78 if !conf.InsecureSkipVerify { 79 t.Error("conf.InsecureSkipVerify: got false, want true") 80 } 81 82 // Set SNI, takes precedence over host. 83 clientHello.ServerName = "google.com" 84 tlsc, err = conf.GetCertificate(clientHello) 85 if err != nil { 86 t.Fatalf("conf.GetCertificate(): got %v, want no error", err) 87 } 88 89 x509c = tlsc.Leaf 90 if got, want := x509c.Subject.CommonName, "google.com"; got != want { 91 t.Errorf("x509c.Subject.CommonName: got %q, want %q", got, want) 92 } 93 94 // Reset SNI to fallback to hostname. 95 clientHello.ServerName = "" 96 tlsc, err = conf.GetCertificate(clientHello) 97 if err != nil { 98 t.Fatalf("conf.GetCertificate(): got %v, want no error", err) 99 } 100 101 x509c = tlsc.Leaf 102 if got, want := x509c.Subject.CommonName, "example.com"; got != want { 103 t.Errorf("x509c.Subject.CommonName: got %q, want %q", got, want) 104 } 105 } 106 107 func TestCert(t *testing.T) { 108 ca, priv, err := NewAuthority("martian.proxy", "Martian Authority", 24*time.Hour) 109 if err != nil { 110 t.Fatalf("NewAuthority(): got %v, want no error", err) 111 } 112 113 c, err := NewConfig(ca, priv) 114 if err != nil { 115 t.Fatalf("NewConfig(): got %v, want no error", err) 116 } 117 118 tlsc, err := c.cert("example.com") 119 if err != nil { 120 t.Fatalf("c.cert(%q): got %v, want no error", "example.com:8080", err) 121 } 122 123 if tlsc.Certificate == nil { 124 t.Error("tlsc.Certificate: got nil, want certificate bytes") 125 } 126 if tlsc.PrivateKey == nil { 127 t.Error("tlsc.PrivateKey: got nil, want private key") 128 } 129 130 x509c := tlsc.Leaf 131 if x509c == nil { 132 t.Fatal("x509c: got nil, want *x509.Certificate") 133 } 134 135 if got := x509c.SerialNumber; got.Cmp(MaxSerialNumber) >= 0 { 136 t.Errorf("x509c.SerialNumber: got %v, want <= MaxSerialNumber", got) 137 } 138 if got, want := x509c.Subject.CommonName, "example.com"; got != want { 139 t.Errorf("X509c.Subject.CommonName: got %q, want %q", got, want) 140 } 141 if err := x509c.VerifyHostname("example.com"); err != nil { 142 t.Errorf("x509c.VerifyHostname(%q): got %v, want no error", "example.com", err) 143 } 144 145 if got, want := x509c.Subject.Organization, []string{"Martian Proxy"}; !reflect.DeepEqual(got, want) { 146 t.Errorf("x509c.Subject.Organization: got %v, want %v", got, want) 147 } 148 149 if got := x509c.SubjectKeyId; got == nil { 150 t.Error("x509c.SubjectKeyId: got nothing, want key ID") 151 } 152 if !x509c.BasicConstraintsValid { 153 t.Error("x509c.BasicConstraintsValid: got false, want true") 154 } 155 156 if got, want := x509c.KeyUsage, x509.KeyUsageKeyEncipherment; got&want == 0 { 157 t.Error("x509c.KeyUsage: got nothing, want to include x509.KeyUsageKeyEncipherment") 158 } 159 if got, want := x509c.KeyUsage, x509.KeyUsageDigitalSignature; got&want == 0 { 160 t.Error("x509c.KeyUsage: got nothing, want to include x509.KeyUsageDigitalSignature") 161 } 162 163 want := []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth} 164 if got := x509c.ExtKeyUsage; !reflect.DeepEqual(got, want) { 165 t.Errorf("x509c.ExtKeyUsage: got %v, want %v", got, want) 166 } 167 168 if got, want := x509c.DNSNames, []string{"example.com"}; !reflect.DeepEqual(got, want) { 169 t.Errorf("x509c.DNSNames: got %v, want %v", got, want) 170 } 171 172 before := time.Now().Add(-2 * time.Hour) 173 if got := x509c.NotBefore; before.After(got) { 174 t.Errorf("x509c.NotBefore: got %v, want after %v", got, before) 175 } 176 177 after := time.Now().Add(2 * time.Hour) 178 if got := x509c.NotAfter; !after.After(got) { 179 t.Errorf("x509c.NotAfter: got %v, want before %v", got, want) 180 } 181 182 // Retrieve cached certificate. 183 tlsc2, err := c.cert("example.com") 184 if err != nil { 185 t.Fatalf("c.cert(%q): got %v, want no error", "example.com", err) 186 } 187 if tlsc != tlsc2 { 188 t.Error("tlsc2: got new certificate, want cached certificate") 189 } 190 191 // TLS certificate for IP. 192 tlsc, err = c.cert("10.0.0.1:8227") 193 if err != nil { 194 t.Fatalf("c.cert(%q): got %v, want no error", "10.0.0.1:8227", err) 195 } 196 x509c = tlsc.Leaf 197 198 if got, want := len(x509c.IPAddresses), 1; got != want { 199 t.Fatalf("len(x509c.IPAddresses): got %d, want %d", got, want) 200 } 201 202 if got, want := x509c.IPAddresses[0], net.ParseIP("10.0.0.1"); !got.Equal(want) { 203 t.Fatalf("x509c.IPAddresses: got %v, want %v", got, want) 204 } 205 }