github.com/operator-framework/operator-lifecycle-manager@v0.30.0/pkg/lib/filemonitor/cert_updater_test.go (about)

     1  package filemonitor
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"crypto/x509"
     7  	"fmt"
     8  	"html"
     9  	"net"
    10  	"net/http"
    11  	"os"
    12  	"path/filepath"
    13  	"strconv"
    14  	"testing"
    15  	"time"
    16  
    17  	"k8s.io/apimachinery/pkg/util/wait"
    18  
    19  	"github.com/sirupsen/logrus"
    20  
    21  	"github.com/stretchr/testify/assert"
    22  	"github.com/stretchr/testify/require"
    23  )
    24  
    25  func TestOLMGetCertRotationFn(t *testing.T) {
    26  	logger := logrus.New()
    27  	logger.SetLevel(logrus.DebugLevel)
    28  	logger.SetFormatter(&logrus.TextFormatter{
    29  		TimestampFormat: time.RFC3339Nano,
    30  	})
    31  
    32  	testData := "testdata"
    33  	monitorDir := "monitor"
    34  	caCrt := filepath.Join(testData, "ca.crt")
    35  	oldCrt := filepath.Join(testData, "server-old.crt")
    36  	oldKey := filepath.Join(testData, "server-old.key")
    37  	newCrt := filepath.Join(testData, "server-new.crt")
    38  	newKey := filepath.Join(testData, "server-new.key")
    39  	loadCrt := filepath.Join(monitorDir, "loaded.crt")
    40  	loadKey := filepath.Join(monitorDir, "loaded.key")
    41  
    42  	// these values must match values specified in the testdata generation script
    43  	expectedOldCN := "CN=127.0.0.1,OU=OpenShift,O=Red Hat,L=Columbia,ST=SC,C=US"
    44  	expectedNewCN := "CN=127.0.0.1,OU=OpenShift,O=Red Hat,L=New York City,ST=NY,C=US"
    45  
    46  	// the directory is expected to contain exactly one keypair, so create an empty directory to swap the keys in
    47  	err := os.RemoveAll(monitorDir) // this is for test development, shouldn't ever exist beforehand otherwise
    48  	require.NoError(t, err)
    49  	err = os.Mkdir(monitorDir, 0777)
    50  	require.NoError(t, err)
    51  
    52  	// symlink old files to loading files
    53  	err = os.Symlink(filepath.Join("..", oldCrt), loadCrt)
    54  	require.NoError(t, err)
    55  	err = os.Symlink(filepath.Join("..", oldKey), loadKey)
    56  	require.NoError(t, err)
    57  
    58  	certStore, err := NewCertStore(loadCrt, loadKey)
    59  	if err != nil {
    60  		require.NoError(t, err)
    61  	}
    62  
    63  	tlsGetCertFn := func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
    64  		return certStore.GetCertificate(), nil
    65  	}
    66  
    67  	csw, err := NewWatch(logger, []string{filepath.Dir(loadCrt), filepath.Dir(loadKey)}, certStore.HandleFilesystemUpdate)
    68  	require.NoError(t, err)
    69  	csw.Run(context.Background())
    70  
    71  	// find a free port to listen on and start server
    72  	listener, err := net.Listen("tcp", "localhost:0")
    73  	require.NoError(t, err)
    74  	freePort := listener.Addr().(*net.TCPAddr).Port
    75  	http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
    76  		fmt.Fprintf(w, "Path: %q", html.EscapeString(r.URL.Path))
    77  	})
    78  	httpsServer := &http.Server{
    79  		Addr: ":" + strconv.Itoa(freePort),
    80  		TLSConfig: &tls.Config{
    81  			GetCertificate: tlsGetCertFn,
    82  		},
    83  	}
    84  	go func() {
    85  		if err := httpsServer.ServeTLS(listener, "", ""); err != nil {
    86  			panic(err)
    87  		}
    88  	}()
    89  
    90  	caCert, err := os.ReadFile(caCrt)
    91  	require.NoError(t, err)
    92  	caCertPool := x509.NewCertPool()
    93  	caCertPool.AppendCertsFromPEM(caCert)
    94  
    95  	client := &http.Client{
    96  		Transport: &http.Transport{
    97  			TLSClientConfig: &tls.Config{
    98  				RootCAs: caCertPool,
    99  			},
   100  		},
   101  	}
   102  
   103  	resp, err := client.Get(fmt.Sprintf("https://localhost:%v", freePort))
   104  	require.NoError(t, err)
   105  	assert.Equal(t, resp.StatusCode, http.StatusOK)
   106  	assert.Equal(t, expectedOldCN, resp.TLS.PeerCertificates[0].Subject.String())
   107  	resp.Body.Close()
   108  	client.CloseIdleConnections()
   109  
   110  	// atomically switch out the symlink so the file contents are always seen in a consistent state
   111  	// (the same idea is used in the atomic writer in kubernetes)
   112  	atomicCrt := loadCrt + ".atomic-op"
   113  	atomicKey := loadKey + ".atomic-op"
   114  	err = os.Symlink(filepath.Join("..", newCrt), atomicCrt)
   115  	require.NoError(t, err)
   116  	err = os.Symlink(filepath.Join("..", newKey), atomicKey)
   117  	require.NoError(t, err)
   118  
   119  	err = os.Rename(atomicCrt, loadCrt)
   120  	require.NoError(t, err)
   121  	err = os.Rename(atomicKey, loadKey)
   122  	require.NoError(t, err)
   123  
   124  	// sometimes the the filesystem operations need time to catch up so the server cert is updated
   125  	err = wait.PollImmediate(500*time.Millisecond, 10*time.Second, func() (bool, error) {
   126  		currentCert, err := tlsGetCertFn(nil)
   127  		require.NoError(t, err)
   128  		info, err := x509.ParseCertificate(currentCert.Certificate[0])
   129  		if err != nil {
   130  			return false, err
   131  		}
   132  		if info.Subject.String() == expectedNewCN {
   133  			return true, nil
   134  		}
   135  
   136  		return false, nil
   137  	})
   138  	require.NoError(t, err)
   139  
   140  	resp, err = client.Get(fmt.Sprintf("https://localhost:%v", freePort))
   141  	require.NoError(t, err)
   142  	assert.Equal(t, resp.StatusCode, http.StatusOK)
   143  	assert.Equal(t, expectedNewCN, resp.TLS.PeerCertificates[0].Subject.String())
   144  
   145  	os.RemoveAll(monitorDir)
   146  }