github.com/letsencrypt/trillian@v1.1.2-0.20180615153820-ae375a99d36a/trees/trees_test.go (about)

     1  // Copyright 2017 Google Inc. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package trees
    16  
    17  import (
    18  	"context"
    19  	"crypto"
    20  	"crypto/ecdsa"
    21  	"crypto/elliptic"
    22  	"crypto/rand"
    23  	"crypto/rsa"
    24  	"errors"
    25  	"fmt"
    26  	"testing"
    27  
    28  	"github.com/golang/mock/gomock"
    29  	"github.com/golang/protobuf/proto"
    30  	"github.com/golang/protobuf/ptypes"
    31  	"github.com/google/trillian"
    32  	"github.com/google/trillian/crypto/keys"
    33  	"github.com/google/trillian/crypto/sigpb"
    34  	"github.com/google/trillian/storage"
    35  	"github.com/google/trillian/storage/testonly"
    36  	"github.com/kylelemons/godebug/pretty"
    37  	"google.golang.org/grpc/codes"
    38  	"google.golang.org/grpc/status"
    39  
    40  	tcrypto "github.com/google/trillian/crypto"
    41  )
    42  
    43  func TestFromContext(t *testing.T) {
    44  	tests := []struct {
    45  		desc string
    46  		tree *trillian.Tree
    47  	}{
    48  		{desc: "noTree"},
    49  		{desc: "hasTree", tree: testonly.LogTree},
    50  	}
    51  	for _, test := range tests {
    52  		ctx := NewContext(context.Background(), test.tree)
    53  
    54  		tree, ok := FromContext(ctx)
    55  		switch wantOK := test.tree != nil; {
    56  		case ok != wantOK:
    57  			t.Errorf("%v: FromContext(%v) = (_, %v), want = (_, %v)", test.desc, ctx, ok, wantOK)
    58  		case ok && !proto.Equal(tree, test.tree):
    59  			t.Errorf("%v: FromContext(%v) = (%v, nil), want = (%v, nil)", test.desc, ctx, tree, test.tree)
    60  		case !ok && tree != nil:
    61  			t.Errorf("%v: FromContext(%v) = (%v, %v), want = (nil, %v)", test.desc, ctx, tree, ok, wantOK)
    62  		}
    63  	}
    64  }
    65  
    66  func TestGetTree(t *testing.T) {
    67  	logTree := *testonly.LogTree
    68  	logTree.TreeId = 1
    69  
    70  	mapTree := *testonly.MapTree
    71  	mapTree.TreeId = 2
    72  
    73  	frozenTree := *testonly.LogTree
    74  	frozenTree.TreeId = 3
    75  	frozenTree.TreeState = trillian.TreeState_FROZEN
    76  
    77  	drainingTree := *testonly.LogTree
    78  	drainingTree.TreeId = 3
    79  	drainingTree.TreeState = trillian.TreeState_DRAINING
    80  
    81  	softDeletedTree := *testonly.LogTree
    82  	softDeletedTree.Deleted = true
    83  	softDeletedTree.DeleteTime = ptypes.TimestampNow()
    84  
    85  	tests := []struct {
    86  		desc                           string
    87  		treeID                         int64
    88  		opts                           GetOpts
    89  		ctxTree, storageTree, wantTree *trillian.Tree
    90  		beginErr, getErr, commitErr    error
    91  		wantErr                        bool
    92  		code                           codes.Code
    93  	}{
    94  		{
    95  			desc:        "anyTree",
    96  			treeID:      logTree.TreeId,
    97  			opts:        NewGetOpts(Query),
    98  			storageTree: &logTree,
    99  			wantTree:    &logTree,
   100  		},
   101  		{
   102  			desc:        "logTree",
   103  			treeID:      logTree.TreeId,
   104  			opts:        NewGetOpts(Query, trillian.TreeType_LOG),
   105  			storageTree: &logTree,
   106  			wantTree:    &logTree,
   107  		},
   108  		{
   109  			desc:        "mapTree",
   110  			treeID:      mapTree.TreeId,
   111  			opts:        NewGetOpts(Query, trillian.TreeType_MAP),
   112  			storageTree: &mapTree,
   113  			wantTree:    &mapTree,
   114  		},
   115  		{
   116  			desc:        "logTreeButMaybeMap",
   117  			treeID:      logTree.TreeId,
   118  			opts:        NewGetOpts(Query, trillian.TreeType_LOG, trillian.TreeType_MAP),
   119  			storageTree: &logTree,
   120  			wantTree:    &logTree,
   121  		},
   122  		{
   123  			desc:        "mapTreeButMaybeLog",
   124  			treeID:      mapTree.TreeId,
   125  			opts:        NewGetOpts(Query, trillian.TreeType_LOG, trillian.TreeType_MAP),
   126  			storageTree: &mapTree,
   127  			wantTree:    &mapTree,
   128  		},
   129  		{
   130  			desc:        "wrongType1",
   131  			treeID:      logTree.TreeId,
   132  			opts:        NewGetOpts(Query, trillian.TreeType_MAP),
   133  			storageTree: &logTree,
   134  			wantErr:     true,
   135  			code:        codes.InvalidArgument,
   136  		},
   137  		{
   138  			desc:        "wrongType2",
   139  			treeID:      mapTree.TreeId,
   140  			opts:        NewGetOpts(Query, trillian.TreeType_LOG),
   141  			storageTree: &mapTree,
   142  			wantErr:     true,
   143  			code:        codes.InvalidArgument,
   144  		},
   145  		{
   146  			desc:        "wrongType3",
   147  			treeID:      mapTree.TreeId,
   148  			opts:        NewGetOpts(Query, trillian.TreeType_LOG, trillian.TreeType_PREORDERED_LOG),
   149  			storageTree: &mapTree,
   150  			wantErr:     true,
   151  			code:        codes.InvalidArgument,
   152  		},
   153  		{
   154  			desc:        "adminLog",
   155  			treeID:      logTree.TreeId,
   156  			opts:        NewGetOpts(Admin, trillian.TreeType_LOG),
   157  			storageTree: &logTree,
   158  			wantTree:    &logTree,
   159  		},
   160  		{
   161  			desc:        "adminPreordered",
   162  			treeID:      testonly.PreorderedLogTree.TreeId,
   163  			opts:        NewGetOpts(Admin, trillian.TreeType_PREORDERED_LOG),
   164  			storageTree: testonly.PreorderedLogTree,
   165  			wantTree:    testonly.PreorderedLogTree,
   166  		},
   167  		{
   168  			desc:        "adminFrozen",
   169  			treeID:      logTree.TreeId,
   170  			opts:        NewGetOpts(Admin, trillian.TreeType_LOG),
   171  			storageTree: &frozenTree,
   172  			wantTree:    &frozenTree,
   173  		},
   174  		{
   175  			desc:        "adminMap",
   176  			treeID:      mapTree.TreeId,
   177  			opts:        NewGetOpts(Admin, trillian.TreeType_MAP),
   178  			storageTree: &mapTree,
   179  			wantTree:    &mapTree,
   180  		},
   181  		{
   182  			desc:        "queryLog",
   183  			treeID:      logTree.TreeId,
   184  			opts:        NewGetOpts(Query, trillian.TreeType_LOG),
   185  			storageTree: &logTree,
   186  			wantTree:    &logTree,
   187  		},
   188  		{
   189  			desc:        "queryPreordered",
   190  			treeID:      testonly.PreorderedLogTree.TreeId,
   191  			opts:        NewGetOpts(Query, trillian.TreeType_PREORDERED_LOG),
   192  			storageTree: testonly.PreorderedLogTree,
   193  			wantTree:    testonly.PreorderedLogTree,
   194  		},
   195  		{
   196  			desc:        "queryMap",
   197  			treeID:      mapTree.TreeId,
   198  			opts:        NewGetOpts(Query, trillian.TreeType_MAP),
   199  			storageTree: &mapTree,
   200  			wantTree:    &mapTree,
   201  		},
   202  		{
   203  			desc:        "queryFrozen",
   204  			treeID:      frozenTree.TreeId,
   205  			opts:        NewGetOpts(Query, trillian.TreeType_LOG),
   206  			storageTree: &frozenTree,
   207  			wantTree:    &frozenTree,
   208  		},
   209  		{
   210  			desc:        "sequenceFrozen",
   211  			treeID:      frozenTree.TreeId,
   212  			opts:        NewGetOpts(SequenceLog, trillian.TreeType_LOG),
   213  			storageTree: &frozenTree,
   214  			wantTree:    &frozenTree,
   215  			wantErr:     true,
   216  			code:        codes.PermissionDenied,
   217  		},
   218  		{
   219  			desc:        "queueFrozen",
   220  			treeID:      frozenTree.TreeId,
   221  			opts:        NewGetOpts(QueueLog, trillian.TreeType_LOG),
   222  			storageTree: &frozenTree,
   223  			wantTree:    &frozenTree,
   224  			wantErr:     true,
   225  			code:        codes.PermissionDenied,
   226  		},
   227  		{
   228  			desc:        "queryDraining",
   229  			treeID:      drainingTree.TreeId,
   230  			opts:        NewGetOpts(Query, trillian.TreeType_LOG),
   231  			storageTree: &drainingTree,
   232  			wantTree:    &drainingTree,
   233  		},
   234  		{
   235  			desc:        "sequenceDraining",
   236  			treeID:      drainingTree.TreeId,
   237  			opts:        NewGetOpts(SequenceLog, trillian.TreeType_LOG),
   238  			storageTree: &drainingTree,
   239  			wantTree:    &drainingTree,
   240  		},
   241  		{
   242  			desc:        "queueDraining",
   243  			treeID:      drainingTree.TreeId,
   244  			opts:        NewGetOpts(QueueLog, trillian.TreeType_LOG),
   245  			storageTree: &drainingTree,
   246  			wantTree:    &drainingTree,
   247  			wantErr:     true,
   248  			code:        codes.PermissionDenied,
   249  		},
   250  		{
   251  			desc:        "softDeleted",
   252  			treeID:      softDeletedTree.TreeId,
   253  			opts:        NewGetOpts(Query, trillian.TreeType_LOG),
   254  			storageTree: &softDeletedTree,
   255  			wantErr:     true, // Deleted = true makes the tree "invisible" for most RPCs
   256  			code:        codes.NotFound,
   257  		},
   258  		{
   259  			desc:     "treeInCtx",
   260  			treeID:   logTree.TreeId,
   261  			opts:     NewGetOpts(Query, trillian.TreeType_LOG),
   262  			ctxTree:  &logTree,
   263  			wantTree: &logTree,
   264  		},
   265  		{
   266  			desc:        "wrongTreeInCtx",
   267  			treeID:      logTree.TreeId,
   268  			opts:        NewGetOpts(Query, trillian.TreeType_LOG),
   269  			ctxTree:     &mapTree,
   270  			storageTree: &logTree,
   271  			wantTree:    &logTree,
   272  		},
   273  		{
   274  			desc:     "beginErr",
   275  			treeID:   logTree.TreeId,
   276  			opts:     NewGetOpts(Query, trillian.TreeType_LOG),
   277  			beginErr: errors.New("begin err"),
   278  			wantErr:  true,
   279  			code:     codes.Unknown,
   280  		},
   281  		{
   282  			desc:    "getErr",
   283  			treeID:  logTree.TreeId,
   284  			opts:    NewGetOpts(Query, trillian.TreeType_LOG),
   285  			getErr:  errors.New("get err"),
   286  			wantErr: true,
   287  			code:    codes.Unknown,
   288  		},
   289  		{
   290  			desc:      "commitErr",
   291  			treeID:    logTree.TreeId,
   292  			opts:      NewGetOpts(Query, trillian.TreeType_LOG),
   293  			commitErr: errors.New("commit err"),
   294  			wantErr:   true,
   295  			code:      codes.Unknown,
   296  		},
   297  	}
   298  
   299  	ctrl := gomock.NewController(t)
   300  	defer ctrl.Finish()
   301  
   302  	for _, test := range tests {
   303  		ctx := NewContext(context.Background(), test.ctxTree)
   304  
   305  		admin := storage.NewMockAdminStorage(ctrl)
   306  		tx := storage.NewMockReadOnlyAdminTX(ctrl)
   307  		admin.EXPECT().Snapshot(gomock.Any()).MaxTimes(1).Return(tx, test.beginErr)
   308  		tx.EXPECT().GetTree(gomock.Any(), test.treeID).MaxTimes(1).Return(test.storageTree, test.getErr)
   309  		tx.EXPECT().Close().MaxTimes(1).Return(nil)
   310  		tx.EXPECT().Commit().MaxTimes(1).Return(test.commitErr)
   311  
   312  		tree, err := GetTree(ctx, admin, test.treeID, test.opts)
   313  		if hasErr := err != nil; hasErr != test.wantErr {
   314  			t.Errorf("%v: GetTree() = (_, %q), wantErr = %v", test.desc, err, test.wantErr)
   315  			continue
   316  		} else if hasErr {
   317  			if status.Code(err) != test.code {
   318  				t.Errorf("%v: GetTree() = (_, %q), got ErrorCode: %v, want: %v", test.desc, err, status.Code(err), test.code)
   319  			}
   320  			continue
   321  		}
   322  
   323  		if !proto.Equal(tree, test.wantTree) {
   324  			diff := pretty.Compare(tree, test.wantTree)
   325  			t.Errorf("%v: post-GetTree diff:\n%v", test.desc, diff)
   326  		}
   327  	}
   328  }
   329  
   330  func TestHash(t *testing.T) {
   331  	tests := []struct {
   332  		hashAlgo sigpb.DigitallySigned_HashAlgorithm
   333  		wantHash crypto.Hash
   334  		wantErr  bool
   335  	}{
   336  		{hashAlgo: sigpb.DigitallySigned_NONE, wantErr: true},
   337  		{hashAlgo: sigpb.DigitallySigned_SHA256, wantHash: crypto.SHA256},
   338  	}
   339  
   340  	for _, test := range tests {
   341  		tree := *testonly.LogTree
   342  		tree.HashAlgorithm = test.hashAlgo
   343  
   344  		hash, err := Hash(&tree)
   345  		if hasErr := err != nil; hasErr != test.wantErr {
   346  			t.Errorf("Hash(%s) = (_, %q), wantErr = %v", test.hashAlgo, err, test.wantErr)
   347  			continue
   348  		} else if hasErr {
   349  			continue
   350  		}
   351  
   352  		if hash != test.wantHash {
   353  			t.Errorf("Hash(%s) = (%v, nil), want = (%v, nil)", test.hashAlgo, hash, test.wantHash)
   354  		}
   355  	}
   356  }
   357  
   358  func TestSigner(t *testing.T) {
   359  	ecdsaKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
   360  	if err != nil {
   361  		t.Fatalf("Error generating test ECDSA key: %v", err)
   362  	}
   363  
   364  	rsaKey, err := rsa.GenerateKey(rand.Reader, 1024)
   365  	if err != nil {
   366  		t.Fatalf("Error generating test RSA key: %v", err)
   367  	}
   368  
   369  	ctrl := gomock.NewController(t)
   370  	defer ctrl.Finish()
   371  
   372  	tests := []struct {
   373  		desc         string
   374  		sigAlgo      sigpb.DigitallySigned_SignatureAlgorithm
   375  		signer       crypto.Signer
   376  		newSignerErr error
   377  		wantErr      bool
   378  	}{
   379  		{
   380  			desc:    "anonymous",
   381  			sigAlgo: sigpb.DigitallySigned_ANONYMOUS,
   382  			wantErr: true,
   383  		},
   384  		{
   385  			desc:    "ecdsa",
   386  			sigAlgo: sigpb.DigitallySigned_ECDSA,
   387  			signer:  ecdsaKey,
   388  		},
   389  		{
   390  			desc:    "rsa",
   391  			sigAlgo: sigpb.DigitallySigned_RSA,
   392  			signer:  rsaKey,
   393  		},
   394  		{
   395  			desc:    "keyMismatch1",
   396  			sigAlgo: sigpb.DigitallySigned_ECDSA,
   397  			signer:  rsaKey,
   398  			wantErr: true,
   399  		},
   400  		{
   401  			desc:    "keyMismatch2",
   402  			sigAlgo: sigpb.DigitallySigned_RSA,
   403  			signer:  ecdsaKey,
   404  			wantErr: true,
   405  		},
   406  		{
   407  			desc:         "newSignerErr",
   408  			sigAlgo:      sigpb.DigitallySigned_ECDSA,
   409  			newSignerErr: errors.New("NewSigner() error"),
   410  			wantErr:      true,
   411  		},
   412  	}
   413  
   414  	ctx := context.Background()
   415  	for _, test := range tests {
   416  		t.Run(test.desc, func(t *testing.T) {
   417  			tree := *testonly.LogTree
   418  			tree.HashAlgorithm = sigpb.DigitallySigned_SHA256
   419  			tree.HashStrategy = trillian.HashStrategy_RFC6962_SHA256
   420  			tree.SignatureAlgorithm = test.sigAlgo
   421  
   422  			var wantKeyProto ptypes.DynamicAny
   423  			if err := ptypes.UnmarshalAny(tree.PrivateKey, &wantKeyProto); err != nil {
   424  				t.Fatalf("failed to unmarshal tree.PrivateKey: %v", err)
   425  			}
   426  
   427  			keys.RegisterHandler(wantKeyProto.Message, func(ctx context.Context, gotKeyProto proto.Message) (crypto.Signer, error) {
   428  				if !proto.Equal(gotKeyProto, wantKeyProto.Message) {
   429  					return nil, fmt.Errorf("NewSigner(_, %#v) called, want NewSigner(_, %#v)", gotKeyProto, wantKeyProto.Message)
   430  				}
   431  				return test.signer, test.newSignerErr
   432  			})
   433  			defer keys.UnregisterHandler(wantKeyProto.Message)
   434  
   435  			signer, err := Signer(ctx, &tree)
   436  			if hasErr := err != nil; hasErr != test.wantErr {
   437  				t.Fatalf("Signer(_, %s) = (_, %q), wantErr = %v", test.sigAlgo, err, test.wantErr)
   438  			} else if hasErr {
   439  				return
   440  			}
   441  
   442  			want := tcrypto.NewSigner(0, test.signer, crypto.SHA256)
   443  			if diff := pretty.Compare(signer, want); diff != "" {
   444  				t.Fatalf("post-Signer(_, %s) diff:\n%v", test.sigAlgo, diff)
   445  			}
   446  		})
   447  	}
   448  }