github.com/thanos-io/thanos@v0.32.5/test/e2e/tls_test.go (about)

     1  // Copyright (c) The Thanos Authors.
     2  // Licensed under the Apache License 2.0.
     3  
     4  package e2e_test
     5  
     6  import (
     7  	"bytes"
     8  	"context"
     9  	"crypto/rand"
    10  	"crypto/rsa"
    11  	"crypto/x509"
    12  	"encoding/pem"
    13  	"fmt"
    14  	"math/big"
    15  	"net"
    16  	"os"
    17  	"path/filepath"
    18  	"testing"
    19  	"time"
    20  
    21  	"github.com/fortytw2/leaktest"
    22  	"github.com/go-kit/log"
    23  	"google.golang.org/grpc"
    24  	"google.golang.org/grpc/credentials"
    25  	"google.golang.org/grpc/keepalive"
    26  
    27  	"github.com/efficientgo/core/testutil"
    28  	"github.com/thanos-io/thanos/pkg/testutil/e2eutil"
    29  
    30  	pb "google.golang.org/grpc/examples/features/proto/echo"
    31  
    32  	thTLS "github.com/thanos-io/thanos/pkg/tls"
    33  )
    34  
    35  var serverName = "thanos"
    36  
    37  func TestGRPCServerCertAutoRotate(t *testing.T) {
    38  	defer leaktest.CheckTimeout(t, 10*time.Second)() // To see whether any goroutines leaked.
    39  
    40  	logger := log.NewLogfmtLogger(os.Stderr)
    41  	expMessage := "hello world"
    42  
    43  	tmpDirClt := t.TempDir()
    44  	caClt := filepath.Join(tmpDirClt, "ca")
    45  	certClt := filepath.Join(tmpDirClt, "cert")
    46  	keyClt := filepath.Join(tmpDirClt, "key")
    47  
    48  	tmpDirSrv := t.TempDir()
    49  	caSrv := filepath.Join(tmpDirSrv, "ca")
    50  	certSrv := filepath.Join(tmpDirSrv, "cert")
    51  	keySrv := filepath.Join(tmpDirSrv, "key")
    52  
    53  	genCerts(t, certSrv, keySrv, caClt)
    54  	genCerts(t, certClt, keyClt, caSrv)
    55  
    56  	configSrv, err := thTLS.NewServerConfig(logger, certSrv, keySrv, caSrv)
    57  	testutil.Ok(t, err)
    58  
    59  	srv := grpc.NewServer(grpc.KeepaliveParams(keepalive.ServerParameters{MaxConnectionAge: 1 * time.Millisecond}), grpc.Creds(credentials.NewTLS(configSrv)))
    60  
    61  	pb.RegisterEchoServer(srv, &ecServer{})
    62  	p, err := e2eutil.FreePort()
    63  	testutil.Ok(t, err)
    64  	addr := fmt.Sprint("localhost:", p)
    65  	lis, err := net.Listen("tcp", addr)
    66  	testutil.Ok(t, err)
    67  
    68  	go func() {
    69  		testutil.Ok(t, srv.Serve(lis))
    70  	}()
    71  	defer func() { srv.Stop() }()
    72  	time.Sleep(50 * time.Millisecond) // Wait for the server to start.
    73  
    74  	// Setup the connection and the client.
    75  	configClt, err := thTLS.NewClientConfig(logger, certClt, keyClt, caClt, serverName, false)
    76  	testutil.Ok(t, err)
    77  	conn, err := grpc.Dial(addr, grpc.WithConnectParams(grpc.ConnectParams{MinConnectTimeout: 1 * time.Minute}), grpc.WithTransportCredentials(credentials.NewTLS(configClt)))
    78  	testutil.Ok(t, err)
    79  	defer func() {
    80  		testutil.Ok(t, conn.Close())
    81  	}()
    82  	clt := pb.NewEchoClient(conn)
    83  
    84  	// Check a good state.
    85  	resp, err := clt.UnaryEcho(context.Background(), &pb.EchoRequest{Message: expMessage})
    86  	testutil.Ok(t, err)
    87  	testutil.Equals(t, expMessage, resp.Message)
    88  
    89  	// Reload certs and check for a good state.
    90  	genCerts(t, certSrv, keySrv, caClt)
    91  	genCerts(t, certClt, keyClt, caSrv)
    92  	time.Sleep(50 * time.Millisecond) // Wait for the server MaxConnectionAge to expire.
    93  	resp, err = clt.UnaryEcho(context.Background(), &pb.EchoRequest{Message: expMessage})
    94  	testutil.Ok(t, err)
    95  	testutil.Equals(t, expMessage, resp.Message)
    96  }
    97  
    98  var caRoot = &x509.Certificate{
    99  	SerialNumber:          big.NewInt(2019),
   100  	NotAfter:              time.Now().AddDate(10, 0, 0),
   101  	IsCA:                  true,
   102  	ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
   103  	KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
   104  	BasicConstraintsValid: true,
   105  }
   106  
   107  var cert = &x509.Certificate{
   108  	SerialNumber: big.NewInt(1658),
   109  	DNSNames:     []string{serverName},
   110  	NotAfter:     time.Now().AddDate(10, 0, 0),
   111  	SubjectKeyId: []byte{1, 2, 3},
   112  	ExtKeyUsage:  []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
   113  	KeyUsage:     x509.KeyUsageDigitalSignature,
   114  }
   115  
   116  // genCerts generates certificates and writes those to the provided paths.
   117  // When the CA file already exists it is not overwritten and
   118  // it is used to sign the certificates.
   119  func genCerts(t *testing.T, certPath, privkeyPath, caPath string) {
   120  	var (
   121  		err       error
   122  		caPrivKey *rsa.PrivateKey
   123  		caSrvPriv = caPath + ".priv"
   124  	)
   125  
   126  	// When the CA private file exists don't overwrite it but
   127  	// use it to extract the private key to be used for signing the certificate.
   128  	if _, err := os.Stat(caSrvPriv); !os.IsNotExist(err) {
   129  		d, err := os.ReadFile(caSrvPriv)
   130  		testutil.Ok(t, err)
   131  		caPrivKey, err = x509.ParsePKCS1PrivateKey(d)
   132  		testutil.Ok(t, err)
   133  	} else {
   134  		caPrivKey, err = rsa.GenerateKey(rand.Reader, 1024)
   135  		testutil.Ok(t, err)
   136  	}
   137  
   138  	certPrivKey, err := rsa.GenerateKey(rand.Reader, 1024)
   139  	testutil.Ok(t, err)
   140  
   141  	// Sign the cert with the CA private key.
   142  	certBytes, err := x509.CreateCertificate(rand.Reader, cert, caRoot, &certPrivKey.PublicKey, caPrivKey)
   143  	testutil.Ok(t, err)
   144  
   145  	if caPath != "" {
   146  		caBytes, err := x509.CreateCertificate(rand.Reader, caRoot, caRoot, &caPrivKey.PublicKey, caPrivKey)
   147  		testutil.Ok(t, err)
   148  		caPEM := pem.EncodeToMemory(&pem.Block{
   149  			Type:  "CERTIFICATE",
   150  			Bytes: caBytes,
   151  		})
   152  		testutil.Ok(t, os.WriteFile(caPath, caPEM, 0644))
   153  		testutil.Ok(t, os.WriteFile(caSrvPriv, x509.MarshalPKCS1PrivateKey(caPrivKey), 0644))
   154  	}
   155  
   156  	if certPath != "" {
   157  		certPEM := new(bytes.Buffer)
   158  		testutil.Ok(t, pem.Encode(certPEM, &pem.Block{
   159  			Type:  "CERTIFICATE",
   160  			Bytes: certBytes,
   161  		}))
   162  		testutil.Ok(t, os.WriteFile(certPath, certPEM.Bytes(), 0644))
   163  	}
   164  
   165  	if privkeyPath != "" {
   166  		certPrivKeyPEM := new(bytes.Buffer)
   167  		testutil.Ok(t, pem.Encode(certPrivKeyPEM, &pem.Block{
   168  			Type:  "RSA PRIVATE KEY",
   169  			Bytes: x509.MarshalPKCS1PrivateKey(certPrivKey),
   170  		}))
   171  		testutil.Ok(t, os.WriteFile(privkeyPath, certPrivKeyPEM.Bytes(), 0644))
   172  	}
   173  }
   174  
   175  type ecServer struct {
   176  	pb.UnimplementedEchoServer
   177  }
   178  
   179  func (s *ecServer) UnaryEcho(ctx context.Context, req *pb.EchoRequest) (*pb.EchoResponse, error) {
   180  	return &pb.EchoResponse{Message: req.Message}, nil
   181  }
   182  
   183  func TestInvalidCertAndKey(t *testing.T) {
   184  	defer leaktest.CheckTimeout(t, 10*time.Second)()
   185  	logger := log.NewLogfmtLogger(os.Stderr)
   186  	tmpDirSrv := t.TempDir()
   187  	caSrv := filepath.Join(tmpDirSrv, "ca")
   188  	certSrv := filepath.Join(tmpDirSrv, "cert")
   189  	keySrv := filepath.Join(tmpDirSrv, "key")
   190  	// Certificate and key are not present in the above path
   191  	_, err := thTLS.NewServerConfig(logger, certSrv, keySrv, caSrv)
   192  	testutil.NotOk(t, err)
   193  }