github.com/letsencrypt/boulder@v0.20251208.0/cmd/admin/cert_test.go (about)

     1  package main
     2  
     3  import (
     4  	"context"
     5  	"crypto/ecdsa"
     6  	"crypto/elliptic"
     7  	"crypto/rand"
     8  	"crypto/x509"
     9  	"encoding/pem"
    10  	"errors"
    11  	"os"
    12  	"path"
    13  	"reflect"
    14  	"slices"
    15  	"strings"
    16  	"sync"
    17  	"testing"
    18  	"time"
    19  
    20  	"github.com/jmhodges/clock"
    21  	"google.golang.org/grpc"
    22  	"google.golang.org/protobuf/types/known/emptypb"
    23  
    24  	"github.com/letsencrypt/boulder/core"
    25  	corepb "github.com/letsencrypt/boulder/core/proto"
    26  	berrors "github.com/letsencrypt/boulder/errors"
    27  	blog "github.com/letsencrypt/boulder/log"
    28  	"github.com/letsencrypt/boulder/mocks"
    29  	rapb "github.com/letsencrypt/boulder/ra/proto"
    30  	"github.com/letsencrypt/boulder/revocation"
    31  	sapb "github.com/letsencrypt/boulder/sa/proto"
    32  	"github.com/letsencrypt/boulder/test"
    33  )
    34  
    35  // mockSAWithIncident is a mock which only implements the SerialsForIncident
    36  // gRPC method. It can be initialized with a set of serials for that method
    37  // to return.
    38  type mockSAWithIncident struct {
    39  	sapb.StorageAuthorityReadOnlyClient
    40  	incidentSerials []string
    41  }
    42  
    43  // SerialsForIncident returns a fake gRPC stream client object which itself
    44  // will return the mockSAWithIncident's serials in order.
    45  func (msa *mockSAWithIncident) SerialsForIncident(_ context.Context, _ *sapb.SerialsForIncidentRequest, _ ...grpc.CallOption) (grpc.ServerStreamingClient[sapb.IncidentSerial], error) {
    46  	fakeResults := make([]*sapb.IncidentSerial, len(msa.incidentSerials))
    47  	for i, serial := range msa.incidentSerials {
    48  		fakeResults[i] = &sapb.IncidentSerial{Serial: serial}
    49  	}
    50  	return &mocks.ServerStreamClient[sapb.IncidentSerial]{Results: fakeResults}, nil
    51  }
    52  
    53  func TestSerialsFromIncidentTable(t *testing.T) {
    54  	t.Parallel()
    55  	serials := []string{"foo", "bar", "baz"}
    56  
    57  	a := admin{
    58  		saroc: &mockSAWithIncident{incidentSerials: serials},
    59  	}
    60  
    61  	res, err := a.serialsFromIncidentTable(context.Background(), "tablename")
    62  	test.AssertNotError(t, err, "getting serials from mock SA")
    63  	test.AssertDeepEquals(t, res, serials)
    64  }
    65  
    66  func TestSerialsFromFile(t *testing.T) {
    67  	t.Parallel()
    68  	serials := []string{"foo", "bar", "baz"}
    69  
    70  	serialsFile := path.Join(t.TempDir(), "serials.txt")
    71  	err := os.WriteFile(serialsFile, []byte(strings.Join(serials, "\n")), os.ModeAppend)
    72  	test.AssertNotError(t, err, "writing temp serials file")
    73  
    74  	a := admin{}
    75  
    76  	res, err := a.serialsFromFile(context.Background(), serialsFile)
    77  	test.AssertNotError(t, err, "getting serials from file")
    78  	test.AssertDeepEquals(t, res, serials)
    79  }
    80  
    81  // mockSAWithKey is a mock which only implements the GetSerialsByKey
    82  // gRPC method. It can be initialized with a set of serials for that method
    83  // to return.
    84  type mockSAWithKey struct {
    85  	sapb.StorageAuthorityReadOnlyClient
    86  	keyHash []byte
    87  	serials []string
    88  }
    89  
    90  // GetSerialsByKey returns a fake gRPC stream client object which itself
    91  // will return the mockSAWithKey's serials in order.
    92  func (msa *mockSAWithKey) GetSerialsByKey(_ context.Context, req *sapb.SPKIHash, _ ...grpc.CallOption) (grpc.ServerStreamingClient[sapb.Serial], error) {
    93  	if !slices.Equal(req.KeyHash, msa.keyHash) {
    94  		return &mocks.ServerStreamClient[sapb.Serial]{}, nil
    95  	}
    96  	fakeResults := make([]*sapb.Serial, len(msa.serials))
    97  	for i, serial := range msa.serials {
    98  		fakeResults[i] = &sapb.Serial{Serial: serial}
    99  	}
   100  	return &mocks.ServerStreamClient[sapb.Serial]{Results: fakeResults}, nil
   101  }
   102  
   103  func TestSerialsFromPrivateKey(t *testing.T) {
   104  	serials := []string{"foo", "bar", "baz"}
   105  	fc := clock.NewFake()
   106  	fc.Set(time.Now())
   107  
   108  	privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
   109  	test.AssertNotError(t, err, "creating test private key")
   110  	keyBytes, err := x509.MarshalPKCS8PrivateKey(privKey)
   111  	test.AssertNotError(t, err, "marshalling test private key bytes")
   112  
   113  	keyFile := path.Join(t.TempDir(), "key.pem")
   114  	keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: keyBytes})
   115  	err = os.WriteFile(keyFile, keyPEM, os.ModeAppend)
   116  	test.AssertNotError(t, err, "writing test private key file")
   117  
   118  	keyHash, err := core.KeyDigest(privKey.Public())
   119  	test.AssertNotError(t, err, "computing test SPKI hash")
   120  
   121  	a := admin{saroc: &mockSAWithKey{keyHash: keyHash[:], serials: serials}}
   122  
   123  	res, err := a.serialsFromPrivateKey(context.Background(), keyFile)
   124  	test.AssertNotError(t, err, "getting serials from keyHashToSerial table")
   125  	test.AssertDeepEquals(t, res, serials)
   126  }
   127  
   128  // mockSAWithAccount is a mock which only implements the GetSerialsByAccount
   129  // gRPC method. It can be initialized with a set of serials for that method
   130  // to return.
   131  type mockSAWithAccount struct {
   132  	sapb.StorageAuthorityReadOnlyClient
   133  	regID   int64
   134  	serials []string
   135  }
   136  
   137  func (msa *mockSAWithAccount) GetRegistration(_ context.Context, req *sapb.RegistrationID, _ ...grpc.CallOption) (*corepb.Registration, error) {
   138  	if req.Id != msa.regID {
   139  		return nil, errors.New("no such reg")
   140  	}
   141  	return &corepb.Registration{}, nil
   142  }
   143  
   144  // GetSerialsByAccount returns a fake gRPC stream client object which itself
   145  // will return the mockSAWithAccount's serials in order.
   146  func (msa *mockSAWithAccount) GetSerialsByAccount(_ context.Context, req *sapb.RegistrationID, _ ...grpc.CallOption) (grpc.ServerStreamingClient[sapb.Serial], error) {
   147  	if req.Id != msa.regID {
   148  		return &mocks.ServerStreamClient[sapb.Serial]{}, nil
   149  	}
   150  	fakeResults := make([]*sapb.Serial, len(msa.serials))
   151  	for i, serial := range msa.serials {
   152  		fakeResults[i] = &sapb.Serial{Serial: serial}
   153  	}
   154  	return &mocks.ServerStreamClient[sapb.Serial]{Results: fakeResults}, nil
   155  }
   156  
   157  func TestSerialsFromRegID(t *testing.T) {
   158  	serials := []string{"foo", "bar", "baz"}
   159  	a := admin{saroc: &mockSAWithAccount{regID: 123, serials: serials}}
   160  
   161  	res, err := a.serialsFromRegID(context.Background(), 123)
   162  	test.AssertNotError(t, err, "getting serials from serials table")
   163  	test.AssertDeepEquals(t, res, serials)
   164  }
   165  
   166  // mockRARecordingRevocations is a mock which only implements the
   167  // AdministrativelyRevokeCertificate gRPC method. It can be initialized with
   168  // serials to recognize as already revoked, or to fail.
   169  type mockRARecordingRevocations struct {
   170  	rapb.RegistrationAuthorityClient
   171  	doomedToFail       []string
   172  	alreadyRevoked     []string
   173  	revocationRequests []*rapb.AdministrativelyRevokeCertificateRequest
   174  	sync.Mutex
   175  }
   176  
   177  // AdministrativelyRevokeCertificate records the request it received on the mock
   178  // RA struct, and succeeds if it doesn't recognize the serial as one it should
   179  // fail for.
   180  func (mra *mockRARecordingRevocations) AdministrativelyRevokeCertificate(_ context.Context, req *rapb.AdministrativelyRevokeCertificateRequest, _ ...grpc.CallOption) (*emptypb.Empty, error) {
   181  	mra.Lock()
   182  	defer mra.Unlock()
   183  	mra.revocationRequests = append(mra.revocationRequests, req)
   184  	if slices.Contains(mra.doomedToFail, req.Serial) {
   185  		return nil, errors.New("oops")
   186  	}
   187  	if slices.Contains(mra.alreadyRevoked, req.Serial) {
   188  		return nil, berrors.AlreadyRevokedError("too slow")
   189  	}
   190  	return &emptypb.Empty{}, nil
   191  }
   192  
   193  func (mra *mockRARecordingRevocations) reset() {
   194  	mra.doomedToFail = nil
   195  	mra.alreadyRevoked = nil
   196  	mra.revocationRequests = nil
   197  }
   198  
   199  func TestRevokeSerials(t *testing.T) {
   200  	t.Parallel()
   201  	serials := []string{
   202  		"2a18592b7f4bf596fb1a1df135567acd825a",
   203  		"038c3f6388afb7695dd4d6bbe3d264f1e4e2",
   204  		"048c3f6388afb7695dd4d6bbe3d264f1e5e5",
   205  	}
   206  	mra := mockRARecordingRevocations{}
   207  	log := blog.NewMock()
   208  	a := admin{rac: &mra, log: log}
   209  
   210  	assertRequestsContain := func(reqs []*rapb.AdministrativelyRevokeCertificateRequest, code revocation.Reason, skipBlockKey bool) {
   211  		t.Helper()
   212  		for _, req := range reqs {
   213  			test.AssertEquals(t, len(req.Cert), 0)
   214  			test.AssertEquals(t, req.Code, int64(code))
   215  			test.AssertEquals(t, req.SkipBlockKey, skipBlockKey)
   216  		}
   217  	}
   218  
   219  	// Revoking should result in 3 gRPC requests and quiet execution.
   220  	mra.reset()
   221  	log.Clear()
   222  	a.dryRun = false
   223  	err := a.revokeSerials(context.Background(), serials, 0, false, 1)
   224  	test.AssertEquals(t, len(log.GetAllMatching("invalid serial format")), 0)
   225  	test.AssertNotError(t, err, "")
   226  	test.AssertEquals(t, len(log.GetAll()), 0)
   227  	test.AssertEquals(t, len(mra.revocationRequests), 3)
   228  	assertRequestsContain(mra.revocationRequests, 0, false)
   229  
   230  	// Revoking an already-revoked serial should result in one log line.
   231  	mra.reset()
   232  	log.Clear()
   233  	mra.alreadyRevoked = []string{"048c3f6388afb7695dd4d6bbe3d264f1e5e5"}
   234  	err = a.revokeSerials(context.Background(), serials, 0, false, 1)
   235  	t.Logf("error: %s", err)
   236  	t.Logf("logs: %s", strings.Join(log.GetAll(), ""))
   237  	test.AssertError(t, err, "already-revoked should result in error")
   238  	test.AssertEquals(t, len(log.GetAllMatching("not revoking")), 1)
   239  	test.AssertEquals(t, len(mra.revocationRequests), 3)
   240  	assertRequestsContain(mra.revocationRequests, 0, false)
   241  
   242  	// Revoking a doomed-to-fail serial should also result in one log line.
   243  	mra.reset()
   244  	log.Clear()
   245  	mra.doomedToFail = []string{"048c3f6388afb7695dd4d6bbe3d264f1e5e5"}
   246  	err = a.revokeSerials(context.Background(), serials, 0, false, 1)
   247  	test.AssertError(t, err, "gRPC error should result in error")
   248  	test.AssertEquals(t, len(log.GetAllMatching("failed to revoke")), 1)
   249  	test.AssertEquals(t, len(mra.revocationRequests), 3)
   250  	assertRequestsContain(mra.revocationRequests, 0, false)
   251  
   252  	// Revoking with other parameters should get carried through.
   253  	mra.reset()
   254  	log.Clear()
   255  	err = a.revokeSerials(context.Background(), serials, 1, true, 3)
   256  	test.AssertNotError(t, err, "")
   257  	test.AssertEquals(t, len(mra.revocationRequests), 3)
   258  	assertRequestsContain(mra.revocationRequests, 1, true)
   259  
   260  	// Revoking in dry-run mode should result in no gRPC requests and three logs.
   261  	mra.reset()
   262  	log.Clear()
   263  	a.dryRun = true
   264  	a.rac = dryRunRAC{log: log}
   265  	err = a.revokeSerials(context.Background(), serials, 0, false, 1)
   266  	test.AssertNotError(t, err, "")
   267  	test.AssertEquals(t, len(log.GetAllMatching("dry-run:")), 3)
   268  	test.AssertEquals(t, len(mra.revocationRequests), 0)
   269  	assertRequestsContain(mra.revocationRequests, 0, false)
   270  }
   271  
   272  func TestRevokeMalformed(t *testing.T) {
   273  	t.Parallel()
   274  	mra := mockRARecordingRevocations{}
   275  	log := blog.NewMock()
   276  	a := &admin{
   277  		rac:    &mra,
   278  		log:    log,
   279  		dryRun: false,
   280  	}
   281  
   282  	s := subcommandRevokeCert{
   283  		crlShard: 623,
   284  	}
   285  	serial := "0379c3dfdd518be45948f2dbfa6ea3e9b209"
   286  	err := s.revokeMalformed(context.Background(), a, []string{serial}, 1)
   287  	if err != nil {
   288  		t.Errorf("revokedMalformed with crlShard 623: want success, got %s", err)
   289  	}
   290  	if len(mra.revocationRequests) != 1 {
   291  		t.Errorf("revokeMalformed: want 1 revocation request to SA, got %v", mra.revocationRequests)
   292  	}
   293  	if mra.revocationRequests[0].Serial != serial {
   294  		t.Errorf("revokeMalformed: want %s to be revoked, got %s", serial, mra.revocationRequests[0])
   295  	}
   296  
   297  	s = subcommandRevokeCert{
   298  		crlShard: 0,
   299  	}
   300  	err = s.revokeMalformed(context.Background(), a, []string{"038c3f6388afb7695dd4d6bbe3d264f1e4e2"}, 1)
   301  	if err == nil {
   302  		t.Errorf("revokedMalformed with crlShard 0: want error, got none")
   303  	}
   304  
   305  	s = subcommandRevokeCert{
   306  		crlShard: 623,
   307  	}
   308  	err = s.revokeMalformed(context.Background(), a, []string{"038c3f6388afb7695dd4d6bbe3d264f1e4e2", "28a94f966eae14e525777188512ddf5a0a3b"}, 1)
   309  	if err == nil {
   310  		t.Errorf("revokedMalformed with multiple serials: want error, got none")
   311  	}
   312  }
   313  
   314  func TestCleanSerials(t *testing.T) {
   315  	input := []string{
   316  		"2a:18:59:2b:7f:4b:f5:96:fb:1a:1d:f1:35:56:7a:cd:82:5a",
   317  		"03:8c:3f:63:88:af:b7:69:5d:d4:d6:bb:e3:d2:64:f1:e4:e2",
   318  		"038c3f6388afb7695dd4d6bbe3d264f1e4e2",
   319  	}
   320  	expected := []string{
   321  		"2a18592b7f4bf596fb1a1df135567acd825a",
   322  		"038c3f6388afb7695dd4d6bbe3d264f1e4e2",
   323  		"038c3f6388afb7695dd4d6bbe3d264f1e4e2",
   324  	}
   325  	output, err := cleanSerials(input)
   326  	if err != nil {
   327  		t.Errorf("cleanSerials(%s): %s, want %s", input, err, expected)
   328  	}
   329  	if !reflect.DeepEqual(output, expected) {
   330  		t.Errorf("cleanSerials(%s)=%s, want %s", input, output, expected)
   331  	}
   332  }