github.com/letsencrypt/boulder@v0.20251208.0/cmd/admin/cert_test.go (about) 1 package main 2 3 import ( 4 "context" 5 "crypto/ecdsa" 6 "crypto/elliptic" 7 "crypto/rand" 8 "crypto/x509" 9 "encoding/pem" 10 "errors" 11 "os" 12 "path" 13 "reflect" 14 "slices" 15 "strings" 16 "sync" 17 "testing" 18 "time" 19 20 "github.com/jmhodges/clock" 21 "google.golang.org/grpc" 22 "google.golang.org/protobuf/types/known/emptypb" 23 24 "github.com/letsencrypt/boulder/core" 25 corepb "github.com/letsencrypt/boulder/core/proto" 26 berrors "github.com/letsencrypt/boulder/errors" 27 blog "github.com/letsencrypt/boulder/log" 28 "github.com/letsencrypt/boulder/mocks" 29 rapb "github.com/letsencrypt/boulder/ra/proto" 30 "github.com/letsencrypt/boulder/revocation" 31 sapb "github.com/letsencrypt/boulder/sa/proto" 32 "github.com/letsencrypt/boulder/test" 33 ) 34 35 // mockSAWithIncident is a mock which only implements the SerialsForIncident 36 // gRPC method. It can be initialized with a set of serials for that method 37 // to return. 38 type mockSAWithIncident struct { 39 sapb.StorageAuthorityReadOnlyClient 40 incidentSerials []string 41 } 42 43 // SerialsForIncident returns a fake gRPC stream client object which itself 44 // will return the mockSAWithIncident's serials in order. 
func (msa *mockSAWithIncident) SerialsForIncident(_ context.Context, _ *sapb.SerialsForIncidentRequest, _ ...grpc.CallOption) (grpc.ServerStreamingClient[sapb.IncidentSerial], error) {
	fakeResults := make([]*sapb.IncidentSerial, len(msa.incidentSerials))
	for i, serial := range msa.incidentSerials {
		fakeResults[i] = &sapb.IncidentSerial{Serial: serial}
	}
	return &mocks.ServerStreamClient[sapb.IncidentSerial]{Results: fakeResults}, nil
}

// TestSerialsFromIncidentTable checks that serialsFromIncidentTable returns
// exactly the serials the mock SA streams back for an incident table.
func TestSerialsFromIncidentTable(t *testing.T) {
	t.Parallel()
	serials := []string{"foo", "bar", "baz"}

	a := admin{
		saroc: &mockSAWithIncident{incidentSerials: serials},
	}

	res, err := a.serialsFromIncidentTable(context.Background(), "tablename")
	test.AssertNotError(t, err, "getting serials from mock SA")
	test.AssertDeepEquals(t, res, serials)
}

// TestSerialsFromFile writes serials to a newline-separated file and checks
// that serialsFromFile recovers the same list.
func TestSerialsFromFile(t *testing.T) {
	t.Parallel()
	serials := []string{"foo", "bar", "baz"}

	serialsFile := path.Join(t.TempDir(), "serials.txt")
	// The perm argument supplies the new file's permission bits. os.ModeAppend
	// carries none, so passing it created a mode-0000 file that could only be
	// read back when running as root. Owner read/write is what is needed here.
	err := os.WriteFile(serialsFile, []byte(strings.Join(serials, "\n")), 0o600)
	test.AssertNotError(t, err, "writing temp serials file")

	a := admin{}

	res, err := a.serialsFromFile(context.Background(), serialsFile)
	test.AssertNotError(t, err, "getting serials from file")
	test.AssertDeepEquals(t, res, serials)
}

// mockSAWithKey is a mock which only implements the GetSerialsByKey
// gRPC method. It can be initialized with a set of serials for that method
// to return, and the key hash those serials belong to.
type mockSAWithKey struct {
	sapb.StorageAuthorityReadOnlyClient
	keyHash []byte
	serials []string
}

// GetSerialsByKey returns a fake gRPC stream client object which itself
// will return the mockSAWithKey's serials in order when the requested key
// hash matches keyHash, and an empty stream otherwise.
func (msa *mockSAWithKey) GetSerialsByKey(_ context.Context, req *sapb.SPKIHash, _ ...grpc.CallOption) (grpc.ServerStreamingClient[sapb.Serial], error) {
	if !slices.Equal(req.KeyHash, msa.keyHash) {
		return &mocks.ServerStreamClient[sapb.Serial]{}, nil
	}
	fakeResults := make([]*sapb.Serial, len(msa.serials))
	for i, serial := range msa.serials {
		fakeResults[i] = &sapb.Serial{Serial: serial}
	}
	return &mocks.ServerStreamClient[sapb.Serial]{Results: fakeResults}, nil
}

// TestSerialsFromPrivateKey writes a freshly generated ECDSA private key to a
// PEM file and checks that serialsFromPrivateKey returns the serials the mock
// SA holds for that key's SPKI hash.
func TestSerialsFromPrivateKey(t *testing.T) {
	serials := []string{"foo", "bar", "baz"}
	// NOTE(review): fc is never handed to admin below; it is also the only use
	// of the clock and time imports in this file.
	fc := clock.NewFake()
	fc.Set(time.Now())

	privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	test.AssertNotError(t, err, "creating test private key")
	keyBytes, err := x509.MarshalPKCS8PrivateKey(privKey)
	test.AssertNotError(t, err, "marshalling test private key bytes")

	keyFile := path.Join(t.TempDir(), "key.pem")
	// MarshalPKCS8PrivateKey produces a PKCS #8 structure, whose PEM label is
	// "PRIVATE KEY". "RSA PRIVATE KEY" denotes PKCS #1 RSA keys and mislabelled
	// this ECDSA key.
	keyPEM := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: keyBytes})
	// Owner read/write: os.ModeAppend carries no permission bits and would
	// create an unreadable (mode 0000) file for non-root users.
	err = os.WriteFile(keyFile, keyPEM, 0o600)
	test.AssertNotError(t, err, "writing test private key file")

	keyHash, err := core.KeyDigest(privKey.Public())
	test.AssertNotError(t, err, "computing test SPKI hash")

	a := admin{saroc: &mockSAWithKey{keyHash: keyHash[:], serials: serials}}

	res, err := a.serialsFromPrivateKey(context.Background(), keyFile)
	test.AssertNotError(t, err, "getting serials from keyHashToSerial table")
	test.AssertDeepEquals(t, res, serials)
}

// mockSAWithAccount is a mock which only implements the GetSerialsByAccount
// gRPC method. It can be initialized with a set of serials for that method
// to return.
131 type mockSAWithAccount struct { 132 sapb.StorageAuthorityReadOnlyClient 133 regID int64 134 serials []string 135 } 136 137 func (msa *mockSAWithAccount) GetRegistration(_ context.Context, req *sapb.RegistrationID, _ ...grpc.CallOption) (*corepb.Registration, error) { 138 if req.Id != msa.regID { 139 return nil, errors.New("no such reg") 140 } 141 return &corepb.Registration{}, nil 142 } 143 144 // GetSerialsByAccount returns a fake gRPC stream client object which itself 145 // will return the mockSAWithAccount's serials in order. 146 func (msa *mockSAWithAccount) GetSerialsByAccount(_ context.Context, req *sapb.RegistrationID, _ ...grpc.CallOption) (grpc.ServerStreamingClient[sapb.Serial], error) { 147 if req.Id != msa.regID { 148 return &mocks.ServerStreamClient[sapb.Serial]{}, nil 149 } 150 fakeResults := make([]*sapb.Serial, len(msa.serials)) 151 for i, serial := range msa.serials { 152 fakeResults[i] = &sapb.Serial{Serial: serial} 153 } 154 return &mocks.ServerStreamClient[sapb.Serial]{Results: fakeResults}, nil 155 } 156 157 func TestSerialsFromRegID(t *testing.T) { 158 serials := []string{"foo", "bar", "baz"} 159 a := admin{saroc: &mockSAWithAccount{regID: 123, serials: serials}} 160 161 res, err := a.serialsFromRegID(context.Background(), 123) 162 test.AssertNotError(t, err, "getting serials from serials table") 163 test.AssertDeepEquals(t, res, serials) 164 } 165 166 // mockRARecordingRevocations is a mock which only implements the 167 // AdministrativelyRevokeCertificate gRPC method. It can be initialized with 168 // serials to recognize as already revoked, or to fail. 
169 type mockRARecordingRevocations struct { 170 rapb.RegistrationAuthorityClient 171 doomedToFail []string 172 alreadyRevoked []string 173 revocationRequests []*rapb.AdministrativelyRevokeCertificateRequest 174 sync.Mutex 175 } 176 177 // AdministrativelyRevokeCertificate records the request it received on the mock 178 // RA struct, and succeeds if it doesn't recognize the serial as one it should 179 // fail for. 180 func (mra *mockRARecordingRevocations) AdministrativelyRevokeCertificate(_ context.Context, req *rapb.AdministrativelyRevokeCertificateRequest, _ ...grpc.CallOption) (*emptypb.Empty, error) { 181 mra.Lock() 182 defer mra.Unlock() 183 mra.revocationRequests = append(mra.revocationRequests, req) 184 if slices.Contains(mra.doomedToFail, req.Serial) { 185 return nil, errors.New("oops") 186 } 187 if slices.Contains(mra.alreadyRevoked, req.Serial) { 188 return nil, berrors.AlreadyRevokedError("too slow") 189 } 190 return &emptypb.Empty{}, nil 191 } 192 193 func (mra *mockRARecordingRevocations) reset() { 194 mra.doomedToFail = nil 195 mra.alreadyRevoked = nil 196 mra.revocationRequests = nil 197 } 198 199 func TestRevokeSerials(t *testing.T) { 200 t.Parallel() 201 serials := []string{ 202 "2a18592b7f4bf596fb1a1df135567acd825a", 203 "038c3f6388afb7695dd4d6bbe3d264f1e4e2", 204 "048c3f6388afb7695dd4d6bbe3d264f1e5e5", 205 } 206 mra := mockRARecordingRevocations{} 207 log := blog.NewMock() 208 a := admin{rac: &mra, log: log} 209 210 assertRequestsContain := func(reqs []*rapb.AdministrativelyRevokeCertificateRequest, code revocation.Reason, skipBlockKey bool) { 211 t.Helper() 212 for _, req := range reqs { 213 test.AssertEquals(t, len(req.Cert), 0) 214 test.AssertEquals(t, req.Code, int64(code)) 215 test.AssertEquals(t, req.SkipBlockKey, skipBlockKey) 216 } 217 } 218 219 // Revoking should result in 3 gRPC requests and quiet execution. 
220 mra.reset() 221 log.Clear() 222 a.dryRun = false 223 err := a.revokeSerials(context.Background(), serials, 0, false, 1) 224 test.AssertEquals(t, len(log.GetAllMatching("invalid serial format")), 0) 225 test.AssertNotError(t, err, "") 226 test.AssertEquals(t, len(log.GetAll()), 0) 227 test.AssertEquals(t, len(mra.revocationRequests), 3) 228 assertRequestsContain(mra.revocationRequests, 0, false) 229 230 // Revoking an already-revoked serial should result in one log line. 231 mra.reset() 232 log.Clear() 233 mra.alreadyRevoked = []string{"048c3f6388afb7695dd4d6bbe3d264f1e5e5"} 234 err = a.revokeSerials(context.Background(), serials, 0, false, 1) 235 t.Logf("error: %s", err) 236 t.Logf("logs: %s", strings.Join(log.GetAll(), "")) 237 test.AssertError(t, err, "already-revoked should result in error") 238 test.AssertEquals(t, len(log.GetAllMatching("not revoking")), 1) 239 test.AssertEquals(t, len(mra.revocationRequests), 3) 240 assertRequestsContain(mra.revocationRequests, 0, false) 241 242 // Revoking a doomed-to-fail serial should also result in one log line. 243 mra.reset() 244 log.Clear() 245 mra.doomedToFail = []string{"048c3f6388afb7695dd4d6bbe3d264f1e5e5"} 246 err = a.revokeSerials(context.Background(), serials, 0, false, 1) 247 test.AssertError(t, err, "gRPC error should result in error") 248 test.AssertEquals(t, len(log.GetAllMatching("failed to revoke")), 1) 249 test.AssertEquals(t, len(mra.revocationRequests), 3) 250 assertRequestsContain(mra.revocationRequests, 0, false) 251 252 // Revoking with other parameters should get carried through. 253 mra.reset() 254 log.Clear() 255 err = a.revokeSerials(context.Background(), serials, 1, true, 3) 256 test.AssertNotError(t, err, "") 257 test.AssertEquals(t, len(mra.revocationRequests), 3) 258 assertRequestsContain(mra.revocationRequests, 1, true) 259 260 // Revoking in dry-run mode should result in no gRPC requests and three logs. 
261 mra.reset() 262 log.Clear() 263 a.dryRun = true 264 a.rac = dryRunRAC{log: log} 265 err = a.revokeSerials(context.Background(), serials, 0, false, 1) 266 test.AssertNotError(t, err, "") 267 test.AssertEquals(t, len(log.GetAllMatching("dry-run:")), 3) 268 test.AssertEquals(t, len(mra.revocationRequests), 0) 269 assertRequestsContain(mra.revocationRequests, 0, false) 270 } 271 272 func TestRevokeMalformed(t *testing.T) { 273 t.Parallel() 274 mra := mockRARecordingRevocations{} 275 log := blog.NewMock() 276 a := &admin{ 277 rac: &mra, 278 log: log, 279 dryRun: false, 280 } 281 282 s := subcommandRevokeCert{ 283 crlShard: 623, 284 } 285 serial := "0379c3dfdd518be45948f2dbfa6ea3e9b209" 286 err := s.revokeMalformed(context.Background(), a, []string{serial}, 1) 287 if err != nil { 288 t.Errorf("revokedMalformed with crlShard 623: want success, got %s", err) 289 } 290 if len(mra.revocationRequests) != 1 { 291 t.Errorf("revokeMalformed: want 1 revocation request to SA, got %v", mra.revocationRequests) 292 } 293 if mra.revocationRequests[0].Serial != serial { 294 t.Errorf("revokeMalformed: want %s to be revoked, got %s", serial, mra.revocationRequests[0]) 295 } 296 297 s = subcommandRevokeCert{ 298 crlShard: 0, 299 } 300 err = s.revokeMalformed(context.Background(), a, []string{"038c3f6388afb7695dd4d6bbe3d264f1e4e2"}, 1) 301 if err == nil { 302 t.Errorf("revokedMalformed with crlShard 0: want error, got none") 303 } 304 305 s = subcommandRevokeCert{ 306 crlShard: 623, 307 } 308 err = s.revokeMalformed(context.Background(), a, []string{"038c3f6388afb7695dd4d6bbe3d264f1e4e2", "28a94f966eae14e525777188512ddf5a0a3b"}, 1) 309 if err == nil { 310 t.Errorf("revokedMalformed with multiple serials: want error, got none") 311 } 312 } 313 314 func TestCleanSerials(t *testing.T) { 315 input := []string{ 316 "2a:18:59:2b:7f:4b:f5:96:fb:1a:1d:f1:35:56:7a:cd:82:5a", 317 "03:8c:3f:63:88:af:b7:69:5d:d4:d6:bb:e3:d2:64:f1:e4:e2", 318 "038c3f6388afb7695dd4d6bbe3d264f1e4e2", 319 } 320 
expected := []string{ 321 "2a18592b7f4bf596fb1a1df135567acd825a", 322 "038c3f6388afb7695dd4d6bbe3d264f1e4e2", 323 "038c3f6388afb7695dd4d6bbe3d264f1e4e2", 324 } 325 output, err := cleanSerials(input) 326 if err != nil { 327 t.Errorf("cleanSerials(%s): %s, want %s", input, err, expected) 328 } 329 if !reflect.DeepEqual(output, expected) { 330 t.Errorf("cleanSerials(%s)=%s, want %s", input, output, expected) 331 } 332 }