github.com/outbrain/consul@v1.4.5/connect/service_test.go (about)

     1  package connect
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/tls"
     7  	"crypto/x509"
     8  	"fmt"
     9  	"io"
    10  	"io/ioutil"
    11  	"net/http"
    12  	"strings"
    13  	"testing"
    14  	"time"
    15  
    16  	"github.com/stretchr/testify/assert"
    17  
    18  	"github.com/hashicorp/consul/agent"
    19  	"github.com/hashicorp/consul/agent/connect"
    20  	"github.com/hashicorp/consul/api"
    21  	"github.com/hashicorp/consul/testrpc"
    22  	"github.com/hashicorp/consul/testutil/retry"
    23  	"github.com/stretchr/testify/require"
    24  )
    25  
    26  // Assert io.Closer implementation
    27  var _ io.Closer = new(Service)
    28  
    29  func TestService_Name(t *testing.T) {
    30  	ca := connect.TestCA(t, nil)
    31  	s := TestService(t, "web", ca)
    32  	assert.Equal(t, "web", s.Name())
    33  }
    34  
    35  func TestService_Dial(t *testing.T) {
    36  	ca := connect.TestCA(t, nil)
    37  
    38  	tests := []struct {
    39  		name           string
    40  		accept         bool
    41  		handshake      bool
    42  		presentService string
    43  		wantErr        string
    44  	}{
    45  		{
    46  			name:           "working",
    47  			accept:         true,
    48  			handshake:      true,
    49  			presentService: "db",
    50  			wantErr:        "",
    51  		},
    52  		{
    53  			name:           "tcp connect fail",
    54  			accept:         false,
    55  			handshake:      false,
    56  			presentService: "db",
    57  			wantErr:        "connection refused",
    58  		},
    59  		{
    60  			name:           "handshake timeout",
    61  			accept:         true,
    62  			handshake:      false,
    63  			presentService: "db",
    64  			wantErr:        "i/o timeout",
    65  		},
    66  		{
    67  			name:           "bad cert",
    68  			accept:         true,
    69  			handshake:      true,
    70  			presentService: "web",
    71  			wantErr:        "peer certificate mismatch",
    72  		},
    73  	}
    74  	for _, tt := range tests {
    75  		t.Run(tt.name, func(t *testing.T) {
    76  			require := require.New(t)
    77  
    78  			s := TestService(t, "web", ca)
    79  
    80  			ctx, cancel := context.WithTimeout(context.Background(),
    81  				100*time.Millisecond)
    82  			defer cancel()
    83  
    84  			testSvr := NewTestServer(t, tt.presentService, ca)
    85  			testSvr.TimeoutHandshake = !tt.handshake
    86  
    87  			if tt.accept {
    88  				go func() {
    89  					err := testSvr.Serve()
    90  					require.NoError(err)
    91  				}()
    92  				defer testSvr.Close()
    93  				<-testSvr.Listening
    94  			}
    95  
    96  			// Always expect to be connecting to a "DB"
    97  			resolver := &StaticResolver{
    98  				Addr:    testSvr.Addr,
    99  				CertURI: connect.TestSpiffeIDService(t, "db"),
   100  			}
   101  
   102  			// All test runs should complete in under 500ms due to the timeout about.
   103  			// Don't wait for whole test run to get stuck.
   104  			testTimeout := 500 * time.Millisecond
   105  			testTimer := time.AfterFunc(testTimeout, func() {
   106  				panic(fmt.Sprintf("test timed out after %s", testTimeout))
   107  			})
   108  
   109  			conn, err := s.Dial(ctx, resolver)
   110  			testTimer.Stop()
   111  
   112  			if tt.wantErr == "" {
   113  				require.NoError(err)
   114  				require.IsType(&tls.Conn{}, conn)
   115  			} else {
   116  				require.Error(err)
   117  				require.Contains(err.Error(), tt.wantErr)
   118  			}
   119  
   120  			if err == nil {
   121  				conn.Close()
   122  			}
   123  		})
   124  	}
   125  }
   126  
   127  func TestService_ServerTLSConfig(t *testing.T) {
   128  	require := require.New(t)
   129  
   130  	a := agent.NewTestAgent(t, "007", "")
   131  	defer a.Shutdown()
   132  	testrpc.WaitForTestAgent(t, a.RPC, "dc1")
   133  	client := a.Client()
   134  	agent := client.Agent()
   135  
   136  	// NewTestAgent setup a CA already by default
   137  
   138  	// Register a local agent service with a managed proxy
   139  	reg := &api.AgentServiceRegistration{
   140  		Name: "web",
   141  		Port: 8080,
   142  	}
   143  	err := agent.ServiceRegister(reg)
   144  	require.NoError(err)
   145  
   146  	// Now we should be able to create a service that will eventually get it's TLS
   147  	// all by itself!
   148  	service, err := NewService("web", client)
   149  	require.NoError(err)
   150  
   151  	// Wait for it to be ready
   152  	select {
   153  	case <-service.ReadyWait():
   154  		// continue with test case below
   155  	case <-time.After(1 * time.Second):
   156  		t.Fatalf("timeout waiting for Service.ReadyWait after 1s")
   157  	}
   158  
   159  	tlsCfg := service.ServerTLSConfig()
   160  
   161  	// Sanity check it has a leaf with the right ServiceID and that validates with
   162  	// the given roots.
   163  	require.NotNil(tlsCfg.GetCertificate)
   164  	leaf, err := tlsCfg.GetCertificate(&tls.ClientHelloInfo{})
   165  	require.NoError(err)
   166  	cert, err := x509.ParseCertificate(leaf.Certificate[0])
   167  	require.NoError(err)
   168  	require.Len(cert.URIs, 1)
   169  	require.True(strings.HasSuffix(cert.URIs[0].String(), "/svc/web"))
   170  
   171  	// Verify it as a client would
   172  	err = clientSideVerifier(tlsCfg, leaf.Certificate)
   173  	require.NoError(err)
   174  
   175  	// Now test that rotating the root updates
   176  	{
   177  		// Setup a new generated CA
   178  		connect.TestCAConfigSet(t, a, nil)
   179  	}
   180  
   181  	// After some time, both root and leaves should be different but both should
   182  	// still be correct.
   183  	oldRootSubjects := bytes.Join(tlsCfg.RootCAs.Subjects(), []byte(", "))
   184  	oldLeafSerial := connect.HexString(cert.SerialNumber.Bytes())
   185  	oldLeafKeyID := connect.HexString(cert.SubjectKeyId)
   186  	retry.Run(t, func(r *retry.R) {
   187  		updatedCfg := service.ServerTLSConfig()
   188  
   189  		// Wait until roots are different
   190  		rootSubjects := bytes.Join(updatedCfg.RootCAs.Subjects(), []byte(", "))
   191  		if bytes.Equal(oldRootSubjects, rootSubjects) {
   192  			r.Fatalf("root certificates should have changed, got %s",
   193  				rootSubjects)
   194  		}
   195  
   196  		leaf, err := updatedCfg.GetCertificate(&tls.ClientHelloInfo{})
   197  		r.Check(err)
   198  		cert, err := x509.ParseCertificate(leaf.Certificate[0])
   199  		r.Check(err)
   200  
   201  		if oldLeafSerial == connect.HexString(cert.SerialNumber.Bytes()) {
   202  			r.Fatalf("leaf certificate should have changed, got serial %s",
   203  				oldLeafSerial)
   204  		}
   205  		if oldLeafKeyID == connect.HexString(cert.SubjectKeyId) {
   206  			r.Fatalf("leaf should have a different key, got matching SubjectKeyID = %s",
   207  				oldLeafKeyID)
   208  		}
   209  	})
   210  }
   211  
   212  func TestService_HTTPClient(t *testing.T) {
   213  	ca := connect.TestCA(t, nil)
   214  
   215  	s := TestService(t, "web", ca)
   216  
   217  	// Run a test HTTP server
   218  	testSvr := NewTestServer(t, "backend", ca)
   219  	defer testSvr.Close()
   220  	go func() {
   221  		err := testSvr.ServeHTTPS(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   222  			w.Write([]byte("Hello, I am Backend"))
   223  		}))
   224  		require.NoError(t, err)
   225  	}()
   226  	<-testSvr.Listening
   227  
   228  	// Still get connection refused some times so retry on those
   229  	retry.Run(t, func(r *retry.R) {
   230  		// Hook the service resolver to avoid needing full agent setup.
   231  		s.httpResolverFromAddr = func(addr string) (Resolver, error) {
   232  			// Require in this goroutine seems to block causing a timeout on the Get.
   233  			//require.Equal("https://backend.service.consul:443", addr)
   234  			return &StaticResolver{
   235  				Addr:    testSvr.Addr,
   236  				CertURI: connect.TestSpiffeIDService(t, "backend"),
   237  			}, nil
   238  		}
   239  
   240  		client := s.HTTPClient()
   241  		client.Timeout = 1 * time.Second
   242  
   243  		resp, err := client.Get("https://backend.service.consul/foo")
   244  		r.Check(err)
   245  		defer resp.Body.Close()
   246  
   247  		bodyBytes, err := ioutil.ReadAll(resp.Body)
   248  		r.Check(err)
   249  
   250  		got := string(bodyBytes)
   251  		want := "Hello, I am Backend"
   252  		if got != want {
   253  			r.Fatalf("got %s, want %s", got, want)
   254  		}
   255  	})
   256  }
   257  
   258  func TestService_HasDefaultHTTPResolverFromAddr(t *testing.T) {
   259  
   260  	client, err := api.NewClient(api.DefaultConfig())
   261  	require.NoError(t, err)
   262  
   263  	s, err := NewService("foo", client)
   264  	require.NoError(t, err)
   265  
   266  	// Sanity check this is actually set in constructor since we always override
   267  	// it in tests. Full tests of the resolver func are in resolver_test.go
   268  	require.NotNil(t, s.httpResolverFromAddr)
   269  
   270  	fn := s.httpResolverFromAddr
   271  
   272  	expected := &ConsulResolver{
   273  		Client:    client,
   274  		Namespace: "default",
   275  		Name:      "foo",
   276  		Type:      ConsulResolverTypeService,
   277  	}
   278  	got, err := fn("foo.service.consul")
   279  	require.NoError(t, err)
   280  	require.Equal(t, expected, got)
   281  }