github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/services/integrationtesting/cert_test.go (about)

     1  //go:build !skipintegrationtests
     2  // +build !skipintegrationtests
     3  
     4  package integrationtesting_test
     5  
     6  import (
     7  	"context"
     8  	"crypto/ecdsa"
     9  	"crypto/elliptic"
    10  	"crypto/rand"
    11  	"crypto/x509"
    12  	"crypto/x509/pkix"
    13  	"encoding/pem"
    14  	"math/big"
    15  	"os"
    16  	"path/filepath"
    17  	"testing"
    18  	"time"
    19  
    20  	v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
    21  	"github.com/stretchr/testify/require"
    22  	"go.uber.org/goleak"
    23  	"google.golang.org/grpc"
    24  	"google.golang.org/grpc/backoff"
    25  
    26  	"github.com/authzed/spicedb/internal/datastore/memdb"
    27  	"github.com/authzed/spicedb/internal/dispatch/graph"
    28  	"github.com/authzed/spicedb/internal/middleware/consistency"
    29  	datastoremw "github.com/authzed/spicedb/internal/middleware/datastore"
    30  	"github.com/authzed/spicedb/internal/middleware/servicespecific"
    31  	tf "github.com/authzed/spicedb/internal/testfixtures"
    32  	"github.com/authzed/spicedb/pkg/cmd/server"
    33  	"github.com/authzed/spicedb/pkg/cmd/util"
    34  	"github.com/authzed/spicedb/pkg/tuple"
    35  	"github.com/authzed/spicedb/pkg/zedtoken"
    36  )
    37  
    38  func TestCertRotation(t *testing.T) {
    39  	const (
    40  		// length of time the initial cert is valid
    41  		initialValidDuration = 3 * time.Second
    42  
    43  		// continue making requests for waitFactor*initialValidDuration
    44  		waitFactor = 2
    45  	)
    46  
    47  	certDir, err := os.MkdirTemp("", "test-certs-")
    48  	require.NoError(t, err)
    49  
    50  	ca := &x509.Certificate{
    51  		NotBefore:             time.Now(),
    52  		NotAfter:              time.Now().Add(5 * time.Minute),
    53  		SerialNumber:          big.NewInt(0),
    54  		Subject:               pkix.Name{Organization: []string{"testCA"}},
    55  		IsCA:                  true,
    56  		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
    57  		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
    58  		BasicConstraintsValid: true,
    59  	}
    60  	caPrivateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
    61  	require.NoError(t, err)
    62  	caPublicKey := &caPrivateKey.PublicKey
    63  	caCertBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, caPublicKey, caPrivateKey)
    64  	require.NoError(t, err)
    65  	caCert, err := x509.ParseCertificate(caCertBytes)
    66  	require.NoError(t, err)
    67  	caFile, err := os.Create(filepath.Join(certDir, "ca.crt"))
    68  	require.NoError(t, err)
    69  	t.Cleanup(func() {
    70  		require.NoError(t, caFile.Close())
    71  	})
    72  	require.NoError(t, pem.Encode(caFile, &pem.Block{
    73  		Type:  "CERTIFICATE",
    74  		Bytes: caCert.Raw,
    75  	}))
    76  
    77  	old := &x509.Certificate{
    78  		SerialNumber: big.NewInt(1),
    79  		Subject: pkix.Name{
    80  			Organization: []string{"initialTestCert"},
    81  		},
    82  		NotBefore:             time.Now(),
    83  		NotAfter:              time.Now().Add(initialValidDuration),
    84  		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
    85  		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
    86  		BasicConstraintsValid: true,
    87  		DNSNames:              []string{"buffnet"},
    88  	}
    89  	oldCertPrivateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
    90  	require.NoError(t, err)
    91  	oldCertPublicKey := &oldCertPrivateKey.PublicKey
    92  	oldCertBytes, err := x509.CreateCertificate(rand.Reader, old, caCert, oldCertPublicKey, caPrivateKey)
    93  	require.NoError(t, err)
    94  	oldCert, err := x509.ParseCertificate(oldCertBytes)
    95  	require.NoError(t, err)
    96  
    97  	keyFile, err := os.Create(filepath.Join(certDir, "tls.key"))
    98  	require.NoError(t, err)
    99  	oldKeyBytes, err := x509.MarshalECPrivateKey(oldCertPrivateKey)
   100  	require.NoError(t, err)
   101  	require.NoError(t, pem.Encode(keyFile, &pem.Block{
   102  		Type:  "EC PRIVATE KEY",
   103  		Bytes: oldKeyBytes,
   104  	}))
   105  	require.NoError(t, keyFile.Close())
   106  
   107  	certFile, err := os.Create(filepath.Join(certDir, "tls.crt"))
   108  	require.NoError(t, err)
   109  	require.NoError(t, pem.Encode(certFile, &pem.Block{
   110  		Type:  "CERTIFICATE",
   111  		Bytes: oldCert.Raw,
   112  	}))
   113  	require.NoError(t, certFile.Close())
   114  
   115  	// start a server with an initial set of certs
   116  	emptyDS, err := memdb.NewMemdbDatastore(0, 10, time.Duration(90_000_000_000_000))
   117  	require.NoError(t, err)
   118  	ds, revision := tf.StandardDatastoreWithData(emptyDS, require.New(t))
   119  	ctx, cancel := context.WithCancel(context.Background())
   120  	srv, err := server.NewConfigWithOptionsAndDefaults(
   121  		server.WithDatastore(ds),
   122  		server.WithDispatcher(graph.NewLocalOnlyDispatcher(1)),
   123  		server.WithDispatchMaxDepth(50),
   124  		server.WithMaximumPreconditionCount(1000),
   125  		server.WithMaximumUpdatesPerWrite(1000),
   126  		server.WithGRPCServer(util.GRPCServerConfig{
   127  			Network:      util.BufferedNetwork,
   128  			Enabled:      true,
   129  			TLSCertPath:  certFile.Name(),
   130  			TLSKeyPath:   keyFile.Name(),
   131  			ClientCAPath: caFile.Name(),
   132  		}),
   133  		server.WithGRPCAuthFunc(func(ctx context.Context) (context.Context, error) {
   134  			return ctx, nil
   135  		}),
   136  		server.WithHTTPGateway(util.HTTPServerConfig{HTTPEnabled: false}),
   137  		server.WithMetricsAPI(util.HTTPServerConfig{HTTPEnabled: false}),
   138  		server.WithDispatchServer(util.GRPCServerConfig{Enabled: false}),
   139  		server.SetUnaryMiddlewareModification([]server.MiddlewareModification[grpc.UnaryServerInterceptor]{
   140  			{
   141  				Operation: server.OperationReplaceAllUnsafe,
   142  				Middlewares: []server.ReferenceableMiddleware[grpc.UnaryServerInterceptor]{
   143  					{
   144  						Name:       "datastore",
   145  						Middleware: datastoremw.UnaryServerInterceptor(ds),
   146  					},
   147  					{
   148  						Name:       "consistency",
   149  						Middleware: consistency.UnaryServerInterceptor(),
   150  					},
   151  					{
   152  						Name:       "servicespecific",
   153  						Middleware: servicespecific.UnaryServerInterceptor,
   154  					},
   155  				},
   156  			},
   157  		}),
   158  		server.SetStreamingMiddlewareModification([]server.MiddlewareModification[grpc.StreamServerInterceptor]{
   159  			{
   160  				Operation: server.OperationReplaceAllUnsafe,
   161  				Middlewares: []server.ReferenceableMiddleware[grpc.StreamServerInterceptor]{
   162  					{
   163  						Name:       "datastore",
   164  						Middleware: datastoremw.StreamServerInterceptor(ds),
   165  					},
   166  					{
   167  						Name:       "consistency",
   168  						Middleware: consistency.StreamServerInterceptor(),
   169  					},
   170  					{
   171  						Name:       "servicespecific",
   172  						Middleware: servicespecific.StreamServerInterceptor,
   173  					},
   174  				},
   175  			},
   176  		}),
   177  	).Complete(ctx)
   178  	require.NoError(t, err)
   179  
   180  	wait := make(chan struct{}, 1)
   181  	go func() {
   182  		require.NoError(t, srv.Run(ctx))
   183  		wait <- struct{}{}
   184  	}()
   185  
   186  	// If previous code takes more than initialValidDuration*2 to execute, the cert
   187  	// would have expired, and Dial would retry indefinitely, hence the context timeout
   188  	dialCtx, cancelDial := context.WithTimeout(ctx, initialValidDuration*2)
   189  	conn, err := srv.GRPCDialContext(dialCtx,
   190  		grpc.WithReturnConnectionError(),
   191  		grpc.WithConnectParams(grpc.ConnectParams{
   192  			Backoff: backoff.Config{
   193  				BaseDelay:  1 * time.Second,
   194  				Multiplier: 2,
   195  				MaxDelay:   15 * time.Second,
   196  			},
   197  		}),
   198  	)
   199  
   200  	require.NoError(t, err)
   201  	defer func() {
   202  		if conn != nil {
   203  			require.NoError(t, conn.Close())
   204  		}
   205  	}()
   206  	// requests work with the old key
   207  	client := v1.NewPermissionsServiceClient(conn)
   208  	rel := tuple.MustToRelationship(tuple.Parse(tf.StandardTuples[0]))
   209  	_, err = client.CheckPermission(ctx, &v1.CheckPermissionRequest{
   210  		Consistency: &v1.Consistency{
   211  			Requirement: &v1.Consistency_AtLeastAsFresh{
   212  				AtLeastAsFresh: zedtoken.MustNewFromRevision(revision),
   213  			},
   214  		},
   215  		Resource:   rel.Resource,
   216  		Permission: "viewer",
   217  		Subject:    rel.Subject,
   218  	})
   219  	require.NoError(t, err)
   220  
   221  	// rotate the key
   222  	newCert := &x509.Certificate{
   223  		SerialNumber: big.NewInt(2),
   224  		Subject: pkix.Name{
   225  			Organization: []string{"rotatedTestCert"},
   226  		},
   227  		NotBefore:             time.Now(),
   228  		NotAfter:              time.Now().Add(5 * time.Minute),
   229  		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
   230  		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
   231  		BasicConstraintsValid: true,
   232  		DNSNames:              []string{"buffnet"},
   233  	}
   234  	newCertPrivateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
   235  	require.NoError(t, err)
   236  	newCertPublicKey := &newCertPrivateKey.PublicKey
   237  	newCertBytes, err := x509.CreateCertificate(rand.Reader, newCert, caCert, newCertPublicKey, caPrivateKey)
   238  	require.NoError(t, err)
   239  	newCertParsed, err := x509.ParseCertificate(newCertBytes)
   240  	require.NoError(t, err)
   241  
   242  	keyFile, err = os.OpenFile(keyFile.Name(), os.O_WRONLY|os.O_TRUNC, 0o755)
   243  	require.NoError(t, err)
   244  	newKeyBytes, err := x509.MarshalECPrivateKey(newCertPrivateKey)
   245  	require.NoError(t, err)
   246  	require.NoError(t, pem.Encode(keyFile, &pem.Block{
   247  		Type:  "EC PRIVATE KEY",
   248  		Bytes: newKeyBytes,
   249  	}))
   250  	require.NoError(t, keyFile.Close())
   251  
   252  	certFile, err = os.OpenFile(certFile.Name(), os.O_WRONLY|os.O_TRUNC, 0o755)
   253  	require.NoError(t, err)
   254  	require.NoError(t, pem.Encode(certFile, &pem.Block{
   255  		Type:  "CERTIFICATE",
   256  		Bytes: newCertParsed.Raw,
   257  	}))
   258  	require.NoError(t, certFile.Close())
   259  
   260  	// check for waitFactor*initialValidDuration seconds
   261  	for i := 0; i < waitFactor; i++ {
   262  		_, err = client.CheckPermission(ctx, &v1.CheckPermissionRequest{
   263  			Consistency: &v1.Consistency{
   264  				Requirement: &v1.Consistency_AtLeastAsFresh{
   265  					AtLeastAsFresh: zedtoken.MustNewFromRevision(revision),
   266  				},
   267  			},
   268  			Resource:   rel.Resource,
   269  			Permission: "viewer",
   270  			Subject:    rel.Subject,
   271  		})
   272  		require.NoError(t, err)
   273  		time.Sleep(initialValidDuration)
   274  	}
   275  
   276  	cancel()
   277  	cancelDial()
   278  	select {
   279  	case <-wait:
   280  		return
   281  	case <-time.After(30 * time.Second):
   282  		require.Fail(t, "ungraceful server termination")
   283  	}
   284  	goleak.VerifyNone(t, goleak.IgnoreCurrent())
   285  }