github.com/letsencrypt/trillian@v1.1.2-0.20180615153820-ae375a99d36a/trees/trees_test.go (about) 1 // Copyright 2017 Google Inc. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package trees 16 17 import ( 18 "context" 19 "crypto" 20 "crypto/ecdsa" 21 "crypto/elliptic" 22 "crypto/rand" 23 "crypto/rsa" 24 "errors" 25 "fmt" 26 "testing" 27 28 "github.com/golang/mock/gomock" 29 "github.com/golang/protobuf/proto" 30 "github.com/golang/protobuf/ptypes" 31 "github.com/google/trillian" 32 "github.com/google/trillian/crypto/keys" 33 "github.com/google/trillian/crypto/sigpb" 34 "github.com/google/trillian/storage" 35 "github.com/google/trillian/storage/testonly" 36 "github.com/kylelemons/godebug/pretty" 37 "google.golang.org/grpc/codes" 38 "google.golang.org/grpc/status" 39 40 tcrypto "github.com/google/trillian/crypto" 41 ) 42 43 func TestFromContext(t *testing.T) { 44 tests := []struct { 45 desc string 46 tree *trillian.Tree 47 }{ 48 {desc: "noTree"}, 49 {desc: "hasTree", tree: testonly.LogTree}, 50 } 51 for _, test := range tests { 52 ctx := NewContext(context.Background(), test.tree) 53 54 tree, ok := FromContext(ctx) 55 switch wantOK := test.tree != nil; { 56 case ok != wantOK: 57 t.Errorf("%v: FromContext(%v) = (_, %v), want = (_, %v)", test.desc, ctx, ok, wantOK) 58 case ok && !proto.Equal(tree, test.tree): 59 t.Errorf("%v: FromContext(%v) = (%v, nil), want = (%v, nil)", test.desc, ctx, tree, test.tree) 60 case !ok && tree != nil: 61 t.Errorf("%v: FromContext(%v) = (%v, %v), want = (nil, %v)", test.desc, ctx, tree, ok, wantOK) 62 } 63 } 64 } 65 66 func TestGetTree(t *testing.T) { 67 logTree := *testonly.LogTree 68 logTree.TreeId = 1 69 70 mapTree := *testonly.MapTree 71 mapTree.TreeId = 2 72 73 frozenTree := *testonly.LogTree 74 frozenTree.TreeId = 3 75 frozenTree.TreeState = trillian.TreeState_FROZEN 76 77 drainingTree := *testonly.LogTree 78 drainingTree.TreeId = 3 79 drainingTree.TreeState = trillian.TreeState_DRAINING 80 81 softDeletedTree := *testonly.LogTree 82 softDeletedTree.Deleted = true 83 softDeletedTree.DeleteTime = ptypes.TimestampNow() 84 85 tests := []struct { 86 desc string 87 treeID int64 88 opts GetOpts 89 ctxTree, storageTree, wantTree *trillian.Tree 90 beginErr, getErr, commitErr error 91 wantErr bool 92 code codes.Code 93 }{ 94 { 95 desc: "anyTree", 96 treeID: logTree.TreeId, 97 opts: NewGetOpts(Query), 98 storageTree: &logTree, 99 wantTree: &logTree, 100 }, 101 { 102 desc: "logTree", 103 treeID: logTree.TreeId, 104 opts: NewGetOpts(Query, trillian.TreeType_LOG), 105 storageTree: &logTree, 106 wantTree: &logTree, 107 }, 108 { 109 desc: "mapTree", 110 treeID: mapTree.TreeId, 111 opts: NewGetOpts(Query, trillian.TreeType_MAP), 112 storageTree: &mapTree, 113 wantTree: &mapTree, 114 }, 115 { 116 desc: "logTreeButMaybeMap", 117 treeID: logTree.TreeId, 118 opts: NewGetOpts(Query, trillian.TreeType_LOG, trillian.TreeType_MAP), 119 storageTree: &logTree, 120 wantTree: &logTree, 121 }, 122 { 123 desc: "mapTreeButMaybeLog", 124 treeID: mapTree.TreeId, 125 opts: NewGetOpts(Query, trillian.TreeType_LOG, trillian.TreeType_MAP), 126 storageTree: &mapTree, 127 wantTree: &mapTree, 128 }, 129 { 130 desc: "wrongType1", 131 treeID: logTree.TreeId, 132 opts: NewGetOpts(Query, trillian.TreeType_MAP), 133 storageTree: &logTree, 134 wantErr: true, 135 code: codes.InvalidArgument, 136 }, 137 { 138 desc: "wrongType2", 139 treeID: mapTree.TreeId, 140 opts: NewGetOpts(Query, trillian.TreeType_LOG), 141 storageTree: &mapTree, 142 wantErr: true, 143 code: codes.InvalidArgument, 144 }, 145 { 146 desc: "wrongType3", 147 treeID: mapTree.TreeId, 148 opts: NewGetOpts(Query, trillian.TreeType_LOG, trillian.TreeType_PREORDERED_LOG), 149 storageTree: &mapTree, 150 wantErr: true, 151 code: codes.InvalidArgument, 152 }, 153 { 154 desc: "adminLog", 155 treeID: logTree.TreeId, 156 opts: NewGetOpts(Admin, trillian.TreeType_LOG), 157 storageTree: &logTree, 158 wantTree: &logTree, 159 }, 160 { 161 desc: "adminPreordered", 162 treeID: testonly.PreorderedLogTree.TreeId, 163 opts: NewGetOpts(Admin, trillian.TreeType_PREORDERED_LOG), 164 storageTree: testonly.PreorderedLogTree, 165 wantTree: testonly.PreorderedLogTree, 166 }, 167 { 168 desc: "adminFrozen", 169 treeID: logTree.TreeId, 170 opts: NewGetOpts(Admin, trillian.TreeType_LOG), 171 storageTree: &frozenTree, 172 wantTree: &frozenTree, 173 }, 174 { 175 desc: "adminMap", 176 treeID: mapTree.TreeId, 177 opts: NewGetOpts(Admin, trillian.TreeType_MAP), 178 storageTree: &mapTree, 179 wantTree: &mapTree, 180 }, 181 { 182 desc: "queryLog", 183 treeID: logTree.TreeId, 184 opts: NewGetOpts(Query, trillian.TreeType_LOG), 185 storageTree: &logTree, 186 wantTree: &logTree, 187 }, 188 { 189 desc: "queryPreordered", 190 treeID: testonly.PreorderedLogTree.TreeId, 191 opts: NewGetOpts(Query, trillian.TreeType_PREORDERED_LOG), 192 storageTree: testonly.PreorderedLogTree, 193 wantTree: testonly.PreorderedLogTree, 194 }, 195 { 196 desc: "queryMap", 197 treeID: mapTree.TreeId, 198 opts: NewGetOpts(Query, trillian.TreeType_MAP), 199 storageTree: &mapTree, 200 wantTree: &mapTree, 201 }, 202 { 203 desc: "queryFrozen", 204 treeID: frozenTree.TreeId, 205 opts: NewGetOpts(Query, trillian.TreeType_LOG), 206 storageTree: &frozenTree, 207 wantTree: &frozenTree, 208 }, 209 { 210 desc: "sequenceFrozen", 211 treeID: frozenTree.TreeId, 212 opts: NewGetOpts(SequenceLog, trillian.TreeType_LOG), 213 storageTree: &frozenTree, 214 wantTree: &frozenTree, 215 wantErr: true, 216 code: codes.PermissionDenied, 217 }, 218 { 219 desc: "queueFrozen", 220 treeID: frozenTree.TreeId, 221 opts: NewGetOpts(QueueLog, trillian.TreeType_LOG), 222 storageTree: &frozenTree, 223 wantTree: &frozenTree, 224 wantErr: true, 225 code: codes.PermissionDenied, 226 }, 227 { 228 desc: "queryDraining", 229 treeID: drainingTree.TreeId, 230 opts: NewGetOpts(Query, trillian.TreeType_LOG), 231 storageTree: &drainingTree, 232 wantTree: &drainingTree, 233 }, 234 { 235 desc: "sequenceDraining", 236 treeID: drainingTree.TreeId, 237 opts: NewGetOpts(SequenceLog, trillian.TreeType_LOG), 238 storageTree: &drainingTree, 239 wantTree: &drainingTree, 240 }, 241 { 242 desc: "queueDraining", 243 treeID: drainingTree.TreeId, 244 opts: NewGetOpts(QueueLog, trillian.TreeType_LOG), 245 storageTree: &drainingTree, 246 wantTree: &drainingTree, 247 wantErr: true, 248 code: codes.PermissionDenied, 249 }, 250 { 251 desc: "softDeleted", 252 treeID: softDeletedTree.TreeId, 253 opts: NewGetOpts(Query, trillian.TreeType_LOG), 254 storageTree: &softDeletedTree, 255 wantErr: true, // Deleted = true makes the tree "invisible" for most RPCs 256 code: codes.NotFound, 257 }, 258 { 259 desc: "treeInCtx", 260 treeID: logTree.TreeId, 261 opts: NewGetOpts(Query, trillian.TreeType_LOG), 262 ctxTree: &logTree, 263 wantTree: &logTree, 264 }, 265 { 266 desc: "wrongTreeInCtx", 267 treeID: logTree.TreeId, 268 opts: NewGetOpts(Query, trillian.TreeType_LOG), 269 ctxTree: &mapTree, 270 storageTree: &logTree, 271 wantTree: &logTree, 272 }, 273 { 274 desc: "beginErr", 275 treeID: logTree.TreeId, 276 opts: NewGetOpts(Query, trillian.TreeType_LOG), 277 beginErr: errors.New("begin err"), 278 wantErr: true, 279 code: codes.Unknown, 280 }, 281 { 282 desc: "getErr", 283 treeID: logTree.TreeId, 284 opts: NewGetOpts(Query, trillian.TreeType_LOG), 285 getErr: errors.New("get err"), 286 wantErr: true, 287 code: codes.Unknown, 288 }, 289 { 290 desc: "commitErr", 291 treeID: logTree.TreeId, 292 opts: NewGetOpts(Query, trillian.TreeType_LOG), 293 commitErr: errors.New("commit err"), 294 wantErr: true, 295 code: codes.Unknown, 296 }, 297 } 298 299 ctrl := gomock.NewController(t) 300 defer ctrl.Finish() 301 302 for _, test := range tests { 303 ctx := NewContext(context.Background(), test.ctxTree) 304 305 admin := storage.NewMockAdminStorage(ctrl) 306 tx := storage.NewMockReadOnlyAdminTX(ctrl) 307 admin.EXPECT().Snapshot(gomock.Any()).MaxTimes(1).Return(tx, test.beginErr) 308 tx.EXPECT().GetTree(gomock.Any(), test.treeID).MaxTimes(1).Return(test.storageTree, test.getErr) 309 tx.EXPECT().Close().MaxTimes(1).Return(nil) 310 tx.EXPECT().Commit().MaxTimes(1).Return(test.commitErr) 311 312 tree, err := GetTree(ctx, admin, test.treeID, test.opts) 313 if hasErr := err != nil; hasErr != test.wantErr { 314 t.Errorf("%v: GetTree() = (_, %q), wantErr = %v", test.desc, err, test.wantErr) 315 continue 316 } else if hasErr { 317 if status.Code(err) != test.code { 318 t.Errorf("%v: GetTree() = (_, %q), got ErrorCode: %v, want: %v", test.desc, err, status.Code(err), test.code) 319 } 320 continue 321 } 322 323 if !proto.Equal(tree, test.wantTree) { 324 diff := pretty.Compare(tree, test.wantTree) 325 t.Errorf("%v: post-GetTree diff:\n%v", test.desc, diff) 326 } 327 } 328 } 329 330 func TestHash(t *testing.T) { 331 tests := []struct { 332 hashAlgo sigpb.DigitallySigned_HashAlgorithm 333 wantHash crypto.Hash 334 wantErr bool 335 }{ 336 {hashAlgo: sigpb.DigitallySigned_NONE, wantErr: true}, 337 {hashAlgo: sigpb.DigitallySigned_SHA256, wantHash: crypto.SHA256}, 338 } 339 340 for _, test := range tests { 341 tree := *testonly.LogTree 342 tree.HashAlgorithm = test.hashAlgo 343 344 hash, err := Hash(&tree) 345 if hasErr := err != nil; hasErr != test.wantErr { 346 t.Errorf("Hash(%s) = (_, %q), wantErr = %v", test.hashAlgo, err, test.wantErr) 347 continue 348 } else if hasErr { 349 continue 350 } 351 352 if hash != test.wantHash { 353 t.Errorf("Hash(%s) = (%v, nil), want = (%v, nil)", test.hashAlgo, hash, test.wantHash) 354 } 355 } 356 } 357 358 func TestSigner(t *testing.T) { 359 ecdsaKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) 360 if err != nil { 361 t.Fatalf("Error generating test ECDSA key: %v", err) 362 } 363 364 rsaKey, err := rsa.GenerateKey(rand.Reader, 1024) 365 if err != nil { 366 t.Fatalf("Error generating test RSA key: %v", err) 367 } 368 369 ctrl := gomock.NewController(t) 370 defer ctrl.Finish() 371 372 tests := []struct { 373 desc string 374 sigAlgo sigpb.DigitallySigned_SignatureAlgorithm 375 signer crypto.Signer 376 newSignerErr error 377 wantErr bool 378 }{ 379 { 380 desc: "anonymous", 381 sigAlgo: sigpb.DigitallySigned_ANONYMOUS, 382 wantErr: true, 383 }, 384 { 385 desc: "ecdsa", 386 sigAlgo: sigpb.DigitallySigned_ECDSA, 387 signer: ecdsaKey, 388 }, 389 { 390 desc: "rsa", 391 sigAlgo: sigpb.DigitallySigned_RSA, 392 signer: rsaKey, 393 }, 394 { 395 desc: "keyMismatch1", 396 sigAlgo: sigpb.DigitallySigned_ECDSA, 397 signer: rsaKey, 398 wantErr: true, 399 }, 400 { 401 desc: "keyMismatch2", 402 sigAlgo: sigpb.DigitallySigned_RSA, 403 signer: ecdsaKey, 404 wantErr: true, 405 }, 406 { 407 desc: "newSignerErr", 408 sigAlgo: sigpb.DigitallySigned_ECDSA, 409 newSignerErr: errors.New("NewSigner() error"), 410 wantErr: true, 411 }, 412 } 413 414 ctx := context.Background() 415 for _, test := range tests { 416 t.Run(test.desc, func(t *testing.T) { 417 tree := *testonly.LogTree 418 tree.HashAlgorithm = sigpb.DigitallySigned_SHA256 419 tree.HashStrategy = trillian.HashStrategy_RFC6962_SHA256 420 tree.SignatureAlgorithm = test.sigAlgo 421 422 var wantKeyProto ptypes.DynamicAny 423 if err := ptypes.UnmarshalAny(tree.PrivateKey, &wantKeyProto); err != nil { 424 t.Fatalf("failed to unmarshal tree.PrivateKey: %v", err) 425 } 426 427 keys.RegisterHandler(wantKeyProto.Message, func(ctx context.Context, gotKeyProto proto.Message) (crypto.Signer, error) { 428 if !proto.Equal(gotKeyProto, wantKeyProto.Message) { 429 return nil, fmt.Errorf("NewSigner(_, %#v) called, want NewSigner(_, %#v)", gotKeyProto, wantKeyProto.Message) 430 } 431 return test.signer, test.newSignerErr 432 }) 433 defer keys.UnregisterHandler(wantKeyProto.Message) 434 435 signer, err := Signer(ctx, &tree) 436 if hasErr := err != nil; hasErr != test.wantErr { 437 t.Fatalf("Signer(_, %s) = (_, %q), wantErr = %v", test.sigAlgo, err, test.wantErr) 438 } else if hasErr { 439 return 440 } 441 442 want := tcrypto.NewSigner(0, test.signer, crypto.SHA256) 443 if diff := pretty.Compare(signer, want); diff != "" { 444 t.Fatalf("post-Signer(_, %s) diff:\n%v", test.sigAlgo, diff) 445 } 446 }) 447 } 448 }