github.com/lmb/consul@v1.4.1/connect/service_test.go (about) 1 package connect 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/tls" 7 "crypto/x509" 8 "fmt" 9 "io" 10 "io/ioutil" 11 "net/http" 12 "strings" 13 "testing" 14 "time" 15 16 "github.com/stretchr/testify/assert" 17 18 "github.com/hashicorp/consul/agent" 19 "github.com/hashicorp/consul/agent/connect" 20 "github.com/hashicorp/consul/api" 21 "github.com/hashicorp/consul/testrpc" 22 "github.com/hashicorp/consul/testutil/retry" 23 "github.com/stretchr/testify/require" 24 ) 25 26 // Assert io.Closer implementation 27 var _ io.Closer = new(Service) 28 29 func TestService_Name(t *testing.T) { 30 ca := connect.TestCA(t, nil) 31 s := TestService(t, "web", ca) 32 assert.Equal(t, "web", s.Name()) 33 } 34 35 func TestService_Dial(t *testing.T) { 36 ca := connect.TestCA(t, nil) 37 38 tests := []struct { 39 name string 40 accept bool 41 handshake bool 42 presentService string 43 wantErr string 44 }{ 45 { 46 name: "working", 47 accept: true, 48 handshake: true, 49 presentService: "db", 50 wantErr: "", 51 }, 52 { 53 name: "tcp connect fail", 54 accept: false, 55 handshake: false, 56 presentService: "db", 57 wantErr: "connection refused", 58 }, 59 { 60 name: "handshake timeout", 61 accept: true, 62 handshake: false, 63 presentService: "db", 64 wantErr: "i/o timeout", 65 }, 66 { 67 name: "bad cert", 68 accept: true, 69 handshake: true, 70 presentService: "web", 71 wantErr: "peer certificate mismatch", 72 }, 73 } 74 for _, tt := range tests { 75 t.Run(tt.name, func(t *testing.T) { 76 require := require.New(t) 77 78 s := TestService(t, "web", ca) 79 80 ctx, cancel := context.WithTimeout(context.Background(), 81 100*time.Millisecond) 82 defer cancel() 83 84 testSvr := NewTestServer(t, tt.presentService, ca) 85 testSvr.TimeoutHandshake = !tt.handshake 86 87 if tt.accept { 88 go func() { 89 err := testSvr.Serve() 90 require.NoError(err) 91 }() 92 defer testSvr.Close() 93 <-testSvr.Listening 94 } 95 96 // Always expect to be connecting to a "DB" 97 resolver := &StaticResolver{ 98 Addr: testSvr.Addr, 99 CertURI: connect.TestSpiffeIDService(t, "db"), 100 } 101 102 // All test runs should complete in under 500ms due to the timeout about. 103 // Don't wait for whole test run to get stuck. 104 testTimeout := 500 * time.Millisecond 105 testTimer := time.AfterFunc(testTimeout, func() { 106 panic(fmt.Sprintf("test timed out after %s", testTimeout)) 107 }) 108 109 conn, err := s.Dial(ctx, resolver) 110 testTimer.Stop() 111 112 if tt.wantErr == "" { 113 require.NoError(err) 114 require.IsType(&tls.Conn{}, conn) 115 } else { 116 require.Error(err) 117 require.Contains(err.Error(), tt.wantErr) 118 } 119 120 if err == nil { 121 conn.Close() 122 } 123 }) 124 } 125 } 126 127 func TestService_ServerTLSConfig(t *testing.T) { 128 require := require.New(t) 129 130 a := agent.NewTestAgent("007", "") 131 defer a.Shutdown() 132 testrpc.WaitForTestAgent(t, a.RPC, "dc1") 133 client := a.Client() 134 agent := client.Agent() 135 136 // NewTestAgent setup a CA already by default 137 138 // Register a local agent service with a managed proxy 139 reg := &api.AgentServiceRegistration{ 140 Name: "web", 141 Port: 8080, 142 } 143 err := agent.ServiceRegister(reg) 144 require.NoError(err) 145 146 // Now we should be able to create a service that will eventually get it's TLS 147 // all by itself! 148 service, err := NewService("web", client) 149 require.NoError(err) 150 151 // Wait for it to be ready 152 select { 153 case <-service.ReadyWait(): 154 // continue with test case below 155 case <-time.After(1 * time.Second): 156 t.Fatalf("timeout waiting for Service.ReadyWait after 1s") 157 } 158 159 tlsCfg := service.ServerTLSConfig() 160 161 // Sanity check it has a leaf with the right ServiceID and that validates with 162 // the given roots. 163 require.NotNil(tlsCfg.GetCertificate) 164 leaf, err := tlsCfg.GetCertificate(&tls.ClientHelloInfo{}) 165 require.NoError(err) 166 cert, err := x509.ParseCertificate(leaf.Certificate[0]) 167 require.NoError(err) 168 require.Len(cert.URIs, 1) 169 require.True(strings.HasSuffix(cert.URIs[0].String(), "/svc/web")) 170 171 // Verify it as a client would 172 err = clientSideVerifier(tlsCfg, leaf.Certificate) 173 require.NoError(err) 174 175 // Now test that rotating the root updates 176 { 177 // Setup a new generated CA 178 connect.TestCAConfigSet(t, a, nil) 179 } 180 181 // After some time, both root and leaves should be different but both should 182 // still be correct. 183 oldRootSubjects := bytes.Join(tlsCfg.RootCAs.Subjects(), []byte(", ")) 184 oldLeafSerial := connect.HexString(cert.SerialNumber.Bytes()) 185 oldLeafKeyID := connect.HexString(cert.SubjectKeyId) 186 retry.Run(t, func(r *retry.R) { 187 updatedCfg := service.ServerTLSConfig() 188 189 // Wait until roots are different 190 rootSubjects := bytes.Join(updatedCfg.RootCAs.Subjects(), []byte(", ")) 191 if bytes.Equal(oldRootSubjects, rootSubjects) { 192 r.Fatalf("root certificates should have changed, got %s", 193 rootSubjects) 194 } 195 196 leaf, err := updatedCfg.GetCertificate(&tls.ClientHelloInfo{}) 197 r.Check(err) 198 cert, err := x509.ParseCertificate(leaf.Certificate[0]) 199 r.Check(err) 200 201 if oldLeafSerial == connect.HexString(cert.SerialNumber.Bytes()) { 202 r.Fatalf("leaf certificate should have changed, got serial %s", 203 oldLeafSerial) 204 } 205 if oldLeafKeyID == connect.HexString(cert.SubjectKeyId) { 206 r.Fatalf("leaf should have a different key, got matching SubjectKeyID = %s", 207 oldLeafKeyID) 208 } 209 }) 210 } 211 212 func TestService_HTTPClient(t *testing.T) { 213 ca := connect.TestCA(t, nil) 214 215 s := TestService(t, "web", ca) 216 217 // Run a test HTTP server 218 testSvr := NewTestServer(t, "backend", ca) 219 defer testSvr.Close() 220 go func() { 221 err := testSvr.ServeHTTPS(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 222 w.Write([]byte("Hello, I am Backend")) 223 })) 224 require.NoError(t, err) 225 }() 226 <-testSvr.Listening 227 228 // Still get connection refused some times so retry on those 229 retry.Run(t, func(r *retry.R) { 230 // Hook the service resolver to avoid needing full agent setup. 231 s.httpResolverFromAddr = func(addr string) (Resolver, error) { 232 // Require in this goroutine seems to block causing a timeout on the Get. 233 //require.Equal("https://backend.service.consul:443", addr) 234 return &StaticResolver{ 235 Addr: testSvr.Addr, 236 CertURI: connect.TestSpiffeIDService(t, "backend"), 237 }, nil 238 } 239 240 client := s.HTTPClient() 241 client.Timeout = 1 * time.Second 242 243 resp, err := client.Get("https://backend.service.consul/foo") 244 r.Check(err) 245 defer resp.Body.Close() 246 247 bodyBytes, err := ioutil.ReadAll(resp.Body) 248 r.Check(err) 249 250 got := string(bodyBytes) 251 want := "Hello, I am Backend" 252 if got != want { 253 r.Fatalf("got %s, want %s", got, want) 254 } 255 }) 256 } 257 258 func TestService_HasDefaultHTTPResolverFromAddr(t *testing.T) { 259 260 client, err := api.NewClient(api.DefaultConfig()) 261 require.NoError(t, err) 262 263 s, err := NewService("foo", client) 264 require.NoError(t, err) 265 266 // Sanity check this is actually set in constructor since we always override 267 // it in tests. Full tests of the resolver func are in resolver_test.go 268 require.NotNil(t, s.httpResolverFromAddr) 269 270 fn := s.httpResolverFromAddr 271 272 expected := &ConsulResolver{ 273 Client: client, 274 Namespace: "default", 275 Name: "foo", 276 Type: ConsulResolverTypeService, 277 } 278 got, err := fn("foo.service.consul") 279 require.NoError(t, err) 280 require.Equal(t, expected, got) 281 }