hub.fastgit.org/hashicorp/consul.git@v1.4.5/connect/testing.go (about)

     1  package connect
     2  
     3  import (
     4  	"crypto/tls"
     5  	"crypto/x509"
     6  	"fmt"
     7  	"io"
     8  	"log"
     9  	"net"
    10  	"net/http"
    11  	"os"
    12  	"sync/atomic"
    13  
    14  	"github.com/hashicorp/consul/agent/connect"
    15  	"github.com/hashicorp/consul/agent/structs"
    16  	"github.com/hashicorp/consul/lib/freeport"
    17  	testing "github.com/mitchellh/go-testing-interface"
    18  )
    19  
    20  // TestService returns a Service instance based on a static TLS Config.
    21  func TestService(t testing.T, service string, ca *structs.CARoot) *Service {
    22  	t.Helper()
    23  
    24  	// Don't need to talk to client since we are setting TLSConfig locally
    25  	svc, err := NewDevServiceWithTLSConfig(service,
    26  		log.New(os.Stderr, "", log.LstdFlags), TestTLSConfig(t, service, ca))
    27  	if err != nil {
    28  		t.Fatal(err)
    29  	}
    30  	return svc
    31  }
    32  
    33  // TestTLSConfig returns a *tls.Config suitable for use during tests.
    34  func TestTLSConfig(t testing.T, service string, ca *structs.CARoot) *tls.Config {
    35  	t.Helper()
    36  
    37  	cfg := defaultTLSConfig()
    38  	cfg.Certificates = []tls.Certificate{TestSvcKeyPair(t, service, ca)}
    39  	cfg.RootCAs = TestCAPool(t, ca)
    40  	cfg.ClientCAs = TestCAPool(t, ca)
    41  	return cfg
    42  }
    43  
    44  // TestCAPool returns an *x509.CertPool containing the passed CA's root(s)
    45  func TestCAPool(t testing.T, cas ...*structs.CARoot) *x509.CertPool {
    46  	t.Helper()
    47  	pool := x509.NewCertPool()
    48  	for _, ca := range cas {
    49  		pool.AppendCertsFromPEM([]byte(ca.RootCert))
    50  	}
    51  	return pool
    52  }
    53  
    54  // TestSvcKeyPair returns an tls.Certificate containing both cert and private
    55  // key for a given service under a given CA from the testdata dir.
    56  func TestSvcKeyPair(t testing.T, service string, ca *structs.CARoot) tls.Certificate {
    57  	t.Helper()
    58  	certPEM, keyPEM := connect.TestLeaf(t, service, ca)
    59  	cert, err := tls.X509KeyPair([]byte(certPEM), []byte(keyPEM))
    60  	if err != nil {
    61  		t.Fatal(err)
    62  	}
    63  	return cert
    64  }
    65  
    66  // TestPeerCertificates returns a []*x509.Certificate as you'd get from
    67  // tls.Conn.ConnectionState().PeerCertificates including the named certificate.
    68  func TestPeerCertificates(t testing.T, service string, ca *structs.CARoot) []*x509.Certificate {
    69  	t.Helper()
    70  	certPEM, _ := connect.TestLeaf(t, service, ca)
    71  	cert, err := connect.ParseCert(certPEM)
    72  	if err != nil {
    73  		t.Fatal(err)
    74  	}
    75  	return []*x509.Certificate{cert}
    76  }
    77  
// TestServer runs a service listener that can be used to test clients. Its
// behavior can be controlled by the struct members. Start it with Serve (TLS
// echo server) or ServeHTTPS (HTTPS server) and stop it with Close.
type TestServer struct {
	// The service name to serve.
	Service string
	// The (test) CA to use for generating certs.
	CA *structs.CARoot
	// TimeoutHandshake, when true, makes Serve accept connections without ever
	// starting a TLS handshake on them. The conns are held open until Close, so
	// a client's handshake never completes. ServeHTTPS does not consult it.
	TimeoutHandshake bool
	// TLSCfg is the tls.Config that will be used. NewTestServer sets it up from
	// the service and CA it is given.
	TLSCfg *tls.Config
	// Addr is the listen address. NewTestServer sets it to a free port on
	// 127.0.0.1.
	Addr string
	// Listening is closed by Serve/ServeHTTPS once the listener has been bound.
	Listening chan struct{}

	// l is the listener bound by Serve/ServeHTTPS; Close closes it.
	l net.Listener
	// stopFlag is atomically set to 1 by the first call to Close.
	stopFlag int32
	// stopChan is closed by Close; Serve waits on it to close accepted conns.
	stopChan chan struct{}
}
   101  
   102  // NewTestServer returns a TestServer. It should be closed when test is
   103  // complete.
   104  func NewTestServer(t testing.T, service string, ca *structs.CARoot) *TestServer {
   105  	ports := freeport.GetT(t, 1)
   106  	return &TestServer{
   107  		Service:   service,
   108  		CA:        ca,
   109  		stopChan:  make(chan struct{}),
   110  		TLSCfg:    TestTLSConfig(t, service, ca),
   111  		Addr:      fmt.Sprintf("127.0.0.1:%d", ports[0]),
   112  		Listening: make(chan struct{}),
   113  	}
   114  }
   115  
   116  // Serve runs a tcp echo server and blocks until it is closed or errors. If
   117  // TimeoutHandshake is set it won't start TLS handshake on new connections.
   118  func (s *TestServer) Serve() error {
   119  	// Just accept TCP conn but so we can control timing of accept/handshake
   120  	l, err := net.Listen("tcp", s.Addr)
   121  	if err != nil {
   122  		return err
   123  	}
   124  	close(s.Listening)
   125  	s.l = l
   126  	log.Printf("test connect service listening on %s", s.Addr)
   127  
   128  	for {
   129  		conn, err := s.l.Accept()
   130  		if err != nil {
   131  			if atomic.LoadInt32(&s.stopFlag) == 1 {
   132  				return nil
   133  			}
   134  			return err
   135  		}
   136  
   137  		// Ignore the conn if we are not actively handshaking
   138  		if !s.TimeoutHandshake {
   139  			// Upgrade conn to TLS
   140  			conn = tls.Server(conn, s.TLSCfg)
   141  
   142  			// Run an echo service
   143  			log.Printf("test connect service accepted conn from %s, "+
   144  				" running echo service", conn.RemoteAddr())
   145  			go io.Copy(conn, conn)
   146  		}
   147  
   148  		// Close this conn when we stop
   149  		go func(c net.Conn) {
   150  			<-s.stopChan
   151  			c.Close()
   152  		}(conn)
   153  	}
   154  }
   155  
   156  // ServeHTTPS runs an HTTPS server with the given config. It invokes the passed
   157  // Handler for all requests.
   158  func (s *TestServer) ServeHTTPS(h http.Handler) error {
   159  	srv := http.Server{
   160  		Addr:      s.Addr,
   161  		TLSConfig: s.TLSCfg,
   162  		Handler:   h,
   163  	}
   164  	log.Printf("starting test connect HTTPS server on %s", s.Addr)
   165  
   166  	// Use our own listener so we can signal when it's ready.
   167  	l, err := net.Listen("tcp", s.Addr)
   168  	if err != nil {
   169  		return err
   170  	}
   171  	close(s.Listening)
   172  	s.l = l
   173  	log.Printf("test connect service listening on %s", s.Addr)
   174  
   175  	err = srv.ServeTLS(l, "", "")
   176  	if atomic.LoadInt32(&s.stopFlag) == 1 {
   177  		return nil
   178  	}
   179  	return err
   180  }
   181  
   182  // Close stops a TestServer
   183  func (s *TestServer) Close() error {
   184  	old := atomic.SwapInt32(&s.stopFlag, 1)
   185  	if old == 0 {
   186  		if s.l != nil {
   187  			s.l.Close()
   188  		}
   189  		close(s.stopChan)
   190  	}
   191  	return nil
   192  }