github.com/zorawar87/trillian@v1.2.1/server/admin/admin_server_test.go (about)

     1  // Copyright 2017 Google Inc. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package admin
    16  
    17  import (
    18  	"context"
    19  	"crypto"
    20  	"crypto/ecdsa"
    21  	"crypto/elliptic"
    22  	"crypto/rand"
    23  	"crypto/rsa"
    24  	"errors"
    25  	"fmt"
    26  	"strings"
    27  	"testing"
    28  	"time"
    29  
    30  	"github.com/golang/mock/gomock"
    31  	"github.com/golang/protobuf/proto"
    32  	"github.com/golang/protobuf/ptypes"
    33  	"github.com/golang/protobuf/ptypes/empty"
    34  	"github.com/golang/protobuf/ptypes/timestamp"
    35  	"github.com/google/trillian"
    36  	"github.com/google/trillian/crypto/keys"
    37  	"github.com/google/trillian/crypto/keys/der"
    38  	"github.com/google/trillian/crypto/keyspb"
    39  	"github.com/google/trillian/crypto/sigpb"
    40  	"github.com/google/trillian/extension"
    41  	"github.com/google/trillian/storage"
    42  	"github.com/google/trillian/storage/testonly"
    43  	"github.com/kylelemons/godebug/pretty"
    44  	"google.golang.org/genproto/protobuf/field_mask"
    45  	"google.golang.org/grpc/codes"
    46  	"google.golang.org/grpc/status"
    47  
    48  	ttestonly "github.com/google/trillian/testonly"
    49  )
    50  
    51  func TestServer_BeginError(t *testing.T) {
    52  	ctrl := gomock.NewController(t)
    53  	defer ctrl.Finish()
    54  
    55  	// PEM on the testonly trees is ECDSA, so let's use an ECDSA key for tests.
    56  	privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
    57  	if err != nil {
    58  		t.Fatalf("Error generating test key: %v", err)
    59  	}
    60  
    61  	validTree := *testonly.LogTree
    62  
    63  	// Need to remove the public key, as it won't correspond to the privateKey that was just generated.
    64  	validTree.PublicKey = nil
    65  
    66  	keyProto := &empty.Empty{}
    67  	validTree.PrivateKey = ttestonly.MustMarshalAny(t, keyProto)
    68  	keys.RegisterHandler(fakeKeyProtoHandler(keyProto, privateKey))
    69  	defer keys.UnregisterHandler(keyProto)
    70  
    71  	tests := []struct {
    72  		desc     string
    73  		fn       func(context.Context, *Server) error
    74  		snapshot bool
    75  	}{
    76  		{
    77  			desc: "ListTrees",
    78  			fn: func(ctx context.Context, s *Server) error {
    79  				_, err := s.ListTrees(ctx, &trillian.ListTreesRequest{})
    80  				return err
    81  			},
    82  			snapshot: true,
    83  		},
    84  		{
    85  			desc: "GetTree",
    86  			fn: func(ctx context.Context, s *Server) error {
    87  				_, err := s.GetTree(ctx, &trillian.GetTreeRequest{TreeId: 12345})
    88  				return err
    89  			},
    90  			snapshot: true,
    91  		},
    92  		{
    93  			desc: "CreateTree",
    94  			fn: func(ctx context.Context, s *Server) error {
    95  				_, err := s.CreateTree(ctx, &trillian.CreateTreeRequest{Tree: &validTree})
    96  				return err
    97  			},
    98  		},
    99  	}
   100  
   101  	ctx := context.Background()
   102  	for _, test := range tests {
   103  		as := storage.NewMockAdminStorage(ctrl)
   104  		if test.snapshot {
   105  			as.EXPECT().Snapshot(gomock.Any()).Return(nil, errors.New("snapshot error"))
   106  		} else {
   107  			as.EXPECT().ReadWriteTransaction(gomock.Any(), gomock.Any()).Return(errors.New("begin error"))
   108  		}
   109  
   110  		registry := extension.Registry{
   111  			AdminStorage: as,
   112  		}
   113  
   114  		s := &Server{registry: registry}
   115  		if err := test.fn(ctx, s); err == nil {
   116  			t.Errorf("%v: got = %v, want non-nil", test.desc, err)
   117  		}
   118  	}
   119  }
   120  
   121  func TestServer_ListTrees(t *testing.T) {
   122  	ctrl := gomock.NewController(t)
   123  	defer ctrl.Finish()
   124  
   125  	activeLog := proto.Clone(testonly.LogTree).(*trillian.Tree)
   126  	frozenLog := proto.Clone(testonly.LogTree).(*trillian.Tree)
   127  	frozenLog.TreeState = trillian.TreeState_FROZEN
   128  	deletedLog := proto.Clone(testonly.LogTree).(*trillian.Tree)
   129  	activeMap := proto.Clone(testonly.MapTree).(*trillian.Tree)
   130  	deletedMap := proto.Clone(testonly.MapTree).(*trillian.Tree)
   131  
   132  	id := int64(17)
   133  	nowPB := ptypes.TimestampNow()
   134  	for _, tree := range []*trillian.Tree{activeLog, frozenLog, deletedLog, activeMap, deletedMap} {
   135  		tree.TreeId = id
   136  		tree.CreateTime = proto.Clone(nowPB).(*timestamp.Timestamp)
   137  		tree.UpdateTime = proto.Clone(nowPB).(*timestamp.Timestamp)
   138  		id++
   139  		nowPB.Seconds++
   140  	}
   141  	for _, tree := range []*trillian.Tree{deletedLog, deletedMap} {
   142  		tree.Deleted = true
   143  		tree.DeleteTime = proto.Clone(nowPB).(*timestamp.Timestamp)
   144  		nowPB.Seconds++
   145  	}
   146  	nonDeletedTrees := []*trillian.Tree{activeLog, frozenLog, activeMap}
   147  	allTrees := []*trillian.Tree{activeLog, frozenLog, deletedLog, activeMap, deletedMap}
   148  
   149  	tests := []struct {
   150  		desc  string
   151  		req   *trillian.ListTreesRequest
   152  		trees []*trillian.Tree
   153  	}{
   154  		{desc: "emptyNonDeleted", req: &trillian.ListTreesRequest{}},
   155  		{desc: "empty", req: &trillian.ListTreesRequest{ShowDeleted: true}},
   156  		{desc: "nonDeleted", req: &trillian.ListTreesRequest{}, trees: nonDeletedTrees},
   157  		{
   158  			desc:  "allTreesDeleted",
   159  			req:   &trillian.ListTreesRequest{ShowDeleted: true},
   160  			trees: allTrees,
   161  		},
   162  	}
   163  
   164  	ctx := context.Background()
   165  	for _, test := range tests {
   166  		setup := setupAdminServer(
   167  			ctrl,
   168  			nil,  /* keygen */
   169  			true, /* snapshot */
   170  			true, /* shouldCommit */
   171  			false /* commitErr */)
   172  
   173  		tx := setup.snapshotTX
   174  		tx.EXPECT().ListTrees(gomock.Any(), test.req.ShowDeleted).Return(test.trees, nil)
   175  
   176  		s := setup.server
   177  		resp, err := s.ListTrees(ctx, test.req)
   178  		if err != nil {
   179  			t.Errorf("%v: ListTrees() returned err = %v", test.desc, err)
   180  			continue
   181  		}
   182  		want := []*trillian.Tree{}
   183  		for _, tree := range test.trees {
   184  			wantTree := proto.Clone(tree).(*trillian.Tree)
   185  			wantTree.PrivateKey = nil // redacted
   186  			want = append(want, wantTree)
   187  		}
   188  		for i, wantTree := range want {
   189  			if !proto.Equal(resp.Tree[i], wantTree) {
   190  				t.Errorf("%v: post-ListTrees() diff (-got +want):\n%v", test.desc, pretty.Compare(resp.Tree, want))
   191  				break
   192  			}
   193  		}
   194  	}
   195  }
   196  
   197  func TestServer_ListTreesErrors(t *testing.T) {
   198  	ctrl := gomock.NewController(t)
   199  	defer ctrl.Finish()
   200  
   201  	tests := []struct {
   202  		desc      string
   203  		listErr   error
   204  		commitErr bool
   205  	}{
   206  		{desc: "listErr", listErr: errors.New("error listing trees")},
   207  		{desc: "commitErr", commitErr: true},
   208  	}
   209  
   210  	ctx := context.Background()
   211  	for _, test := range tests {
   212  		setup := setupAdminServer(
   213  			ctrl,
   214  			nil,                 /* keygen */
   215  			true,                /* snapshot */
   216  			test.listErr == nil, /* shouldCommit */
   217  			test.commitErr /* commitErr */)
   218  
   219  		tx := setup.snapshotTX
   220  		tx.EXPECT().ListTrees(gomock.Any(), false).Return(nil, test.listErr)
   221  
   222  		s := setup.server
   223  		if _, err := s.ListTrees(ctx, &trillian.ListTreesRequest{}); err == nil {
   224  			t.Errorf("%v: ListTrees() returned err = nil, want non-nil", test.desc)
   225  		}
   226  	}
   227  }
   228  
   229  func TestServer_GetTree(t *testing.T) {
   230  	ctrl := gomock.NewController(t)
   231  	defer ctrl.Finish()
   232  
   233  	tests := []struct {
   234  		desc              string
   235  		getErr, commitErr bool
   236  	}{
   237  		{
   238  			desc: "success",
   239  		},
   240  		{
   241  			desc:   "unknownTree",
   242  			getErr: true,
   243  		},
   244  		{
   245  			desc:      "commitError",
   246  			commitErr: true,
   247  		},
   248  	}
   249  
   250  	ctx := context.Background()
   251  	for _, test := range tests {
   252  		setup := setupAdminServer(
   253  			ctrl,
   254  			nil,          /* keygen */
   255  			true,         /* snapshot */
   256  			!test.getErr, /* shouldCommit */
   257  			test.commitErr)
   258  
   259  		tx := setup.snapshotTX
   260  		s := setup.server
   261  
   262  		storedTree := *testonly.LogTree
   263  		storedTree.TreeId = 12345
   264  		if test.getErr {
   265  			tx.EXPECT().GetTree(gomock.Any(), storedTree.TreeId).Return(nil, errors.New("GetTree failed"))
   266  		} else {
   267  			tx.EXPECT().GetTree(gomock.Any(), storedTree.TreeId).Return(&storedTree, nil)
   268  		}
   269  		wantErr := test.getErr || test.commitErr
   270  
   271  		tree, err := s.GetTree(ctx, &trillian.GetTreeRequest{TreeId: storedTree.TreeId})
   272  		if hasErr := err != nil; hasErr != wantErr {
   273  			t.Errorf("%v: GetTree() = (_, %v), wantErr = %v", test.desc, err, wantErr)
   274  			continue
   275  		} else if hasErr {
   276  			continue
   277  		}
   278  
   279  		wantTree := storedTree
   280  		wantTree.PrivateKey = nil // redacted
   281  		if diff := pretty.Compare(tree, &wantTree); diff != "" {
   282  			t.Errorf("%v: post-GetTree diff (-got +want):\n%v", test.desc, diff)
   283  		}
   284  	}
   285  }
   286  
   287  func TestServer_CreateTree(t *testing.T) {
   288  	// PEM on the testonly trees is ECDSA, so let's use an ECDSA key for tests.
   289  	ecdsaPrivateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
   290  	if err != nil {
   291  		t.Fatalf("Error generating test ECDSA key: %v", err)
   292  	}
   293  
   294  	rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 4096)
   295  	if err != nil {
   296  		t.Fatalf("Error generating test RSA key: %v", err)
   297  	}
   298  
   299  	// Need to change the public key to correspond with the ECDSA private key generated above.
   300  	validTree := *testonly.LogTree
   301  	// Except in key generation test cases, a keys.ProtoHandler will be registered that
   302  	// returns ecdsaPrivateKey when passed an empty proto.
   303  	wantKeyProto := &empty.Empty{}
   304  	validTree.PrivateKey = ttestonly.MustMarshalAny(t, wantKeyProto)
   305  	validTree.PublicKey = func() *keyspb.PublicKey {
   306  		pb, err := der.ToPublicProto(ecdsaPrivateKey.Public())
   307  		if err != nil {
   308  			t.Fatalf("Error marshaling ECDSA public key: %v", err)
   309  		}
   310  		return pb
   311  	}()
   312  
   313  	mismatchedPublicKey := validTree
   314  	mismatchedPublicKey.PublicKey = testonly.LogTree.GetPublicKey()
   315  
   316  	omittedPublicKey := validTree
   317  	omittedPublicKey.PublicKey = nil
   318  
   319  	omittedPrivateKey := validTree
   320  	omittedPrivateKey.PrivateKey = nil
   321  
   322  	omittedKeys := omittedPublicKey
   323  	omittedKeys.PrivateKey = nil
   324  
   325  	invalidTree := validTree
   326  	invalidTree.TreeState = trillian.TreeState_UNKNOWN_TREE_STATE
   327  
   328  	invalidHashAlgo := validTree
   329  	invalidHashAlgo.HashAlgorithm = sigpb.DigitallySigned_NONE
   330  
   331  	invalidHashStrategy := validTree
   332  	invalidHashStrategy.HashStrategy = trillian.HashStrategy_UNKNOWN_HASH_STRATEGY
   333  
   334  	invalidSignatureAlgo := validTree
   335  	invalidSignatureAlgo.SignatureAlgorithm = sigpb.DigitallySigned_ANONYMOUS
   336  
   337  	keySignatureMismatch := validTree
   338  	keySignatureMismatch.SignatureAlgorithm = sigpb.DigitallySigned_RSA
   339  
   340  	tests := []struct {
   341  		desc                  string
   342  		req                   *trillian.CreateTreeRequest
   343  		wantKeyGenerator      bool
   344  		createErr             error
   345  		commitErr, wantCommit bool
   346  		wantErr               string
   347  	}{
   348  		{
   349  			desc:       "validTree",
   350  			req:        &trillian.CreateTreeRequest{Tree: &validTree},
   351  			wantCommit: true,
   352  		},
   353  		{
   354  			desc:    "nilTree",
   355  			req:     &trillian.CreateTreeRequest{},
   356  			wantErr: "tree is required",
   357  		},
   358  		{
   359  			desc:    "mismatchedPublicKey",
   360  			req:     &trillian.CreateTreeRequest{Tree: &mismatchedPublicKey},
   361  			wantErr: "public and private keys are not a pair",
   362  		},
   363  		{
   364  			desc:    "omittedPrivateKey",
   365  			req:     &trillian.CreateTreeRequest{Tree: &omittedPrivateKey},
   366  			wantErr: "private_key or key_spec is required",
   367  		},
   368  		{
   369  			desc: "privateKeySpec",
   370  			req: &trillian.CreateTreeRequest{
   371  				Tree: &omittedKeys,
   372  				KeySpec: &keyspb.Specification{
   373  					Params: &keyspb.Specification_EcdsaParams{},
   374  				},
   375  			},
   376  			wantKeyGenerator: true,
   377  			wantCommit:       true,
   378  		},
   379  		{
   380  			desc: "privateKeySpecButNoKeyGenerator",
   381  			req: &trillian.CreateTreeRequest{
   382  				Tree: &omittedKeys,
   383  				KeySpec: &keyspb.Specification{
   384  					Params: &keyspb.Specification_EcdsaParams{},
   385  				},
   386  			},
   387  			wantErr: "key generation is not enabled",
   388  		},
   389  		{
   390  			// Tree specifies ECDSA signatures, but key specification provides RSA parameters.
   391  			desc: "privateKeySpecWithMismatchedAlgorithm",
   392  			req: &trillian.CreateTreeRequest{
   393  				Tree: &omittedKeys,
   394  				KeySpec: &keyspb.Specification{
   395  					Params: &keyspb.Specification_RsaParams{},
   396  				},
   397  			},
   398  			wantKeyGenerator: true,
   399  			wantErr:          "signature not supported by signer",
   400  		},
   401  		{
   402  			desc: "privateKeySpecAndPrivateKeyProvided",
   403  			req: &trillian.CreateTreeRequest{
   404  				Tree: &validTree,
   405  				KeySpec: &keyspb.Specification{
   406  					Params: &keyspb.Specification_EcdsaParams{},
   407  				},
   408  			},
   409  			wantKeyGenerator: true,
   410  			wantErr:          "private_key and key_spec fields are mutually exclusive",
   411  		},
   412  		{
   413  			desc: "privateKeySpecAndPublicKeyProvided",
   414  			req: &trillian.CreateTreeRequest{
   415  				Tree: &omittedPrivateKey,
   416  				KeySpec: &keyspb.Specification{
   417  					Params: &keyspb.Specification_EcdsaParams{},
   418  				},
   419  			},
   420  			wantKeyGenerator: true,
   421  			wantErr:          "public_key and key_spec fields are mutually exclusive",
   422  		},
   423  		{
   424  			desc:       "omittedPublicKey",
   425  			req:        &trillian.CreateTreeRequest{Tree: &omittedPublicKey},
   426  			wantCommit: true,
   427  		},
   428  		{
   429  			desc:    "invalidHashAlgo",
   430  			req:     &trillian.CreateTreeRequest{Tree: &invalidHashAlgo},
   431  			wantErr: "unexpected hash algorithm",
   432  		},
   433  		{
   434  			desc:    "invalidHashStrategy",
   435  			req:     &trillian.CreateTreeRequest{Tree: &invalidHashStrategy},
   436  			wantErr: "unknown hasher",
   437  		},
   438  		{
   439  			desc:    "invalidSignatureAlgo",
   440  			req:     &trillian.CreateTreeRequest{Tree: &invalidSignatureAlgo},
   441  			wantErr: "signature algorithm not supported",
   442  		},
   443  		{
   444  			desc:    "keySignatureMismatch",
   445  			req:     &trillian.CreateTreeRequest{Tree: &keySignatureMismatch},
   446  			wantErr: "signature not supported by signer",
   447  		},
   448  		{
   449  			desc:      "createErr",
   450  			req:       &trillian.CreateTreeRequest{Tree: &invalidTree},
   451  			createErr: errors.New("storage CreateTree failed"),
   452  			wantErr:   "storage CreateTree failed",
   453  		},
   454  		{
   455  			desc:       "commitError",
   456  			req:        &trillian.CreateTreeRequest{Tree: &validTree},
   457  			commitErr:  true,
   458  			wantCommit: true,
   459  			wantErr:    "commit error",
   460  		},
   461  	}
   462  
   463  	ctx := context.Background()
   464  	for _, test := range tests {
   465  		t.Run(test.desc, func(t *testing.T) {
   466  			ctrl := gomock.NewController(t)
   467  			defer ctrl.Finish()
   468  
   469  			var privateKey crypto.Signer = ecdsaPrivateKey
   470  			var keygen keys.ProtoGenerator
   471  			// If KeySpec is set, select the correct type of key to "generate".
   472  			if test.req.GetKeySpec() != nil {
   473  				switch keySpec := test.req.GetKeySpec().GetParams().(type) {
   474  				case *keyspb.Specification_EcdsaParams:
   475  					privateKey = ecdsaPrivateKey
   476  				case *keyspb.Specification_RsaParams:
   477  					privateKey = rsaPrivateKey
   478  				default:
   479  					t.Fatalf("unexpected KeySpec.Params type: %T", keySpec)
   480  				}
   481  
   482  				if test.wantKeyGenerator {
   483  					// Setup a fake key generator. If it receives the expected KeySpec, it returns wantKeyProto,
   484  					// which a keys.ProtoHandler will expect to receive later on.
   485  					keygen = fakeKeyProtoGenerator(test.req.GetKeySpec(), wantKeyProto)
   486  				}
   487  			}
   488  
   489  			keys.RegisterHandler(fakeKeyProtoHandler(wantKeyProto, privateKey))
   490  			defer keys.UnregisterHandler(wantKeyProto)
   491  
   492  			setup := setupAdminServer(ctrl, keygen, false /* snapshot */, test.wantCommit, test.commitErr)
   493  			tx := setup.tx
   494  			s := setup.server
   495  			nowPB := ptypes.TimestampNow()
   496  
   497  			if test.req.Tree != nil {
   498  				var newTree trillian.Tree
   499  				tx.EXPECT().CreateTree(gomock.Any(), gomock.Any()).MaxTimes(1).Do(func(ctx context.Context, tree *trillian.Tree) {
   500  					newTree = *tree
   501  					newTree.TreeId = 12345
   502  					newTree.CreateTime = nowPB
   503  					newTree.UpdateTime = nowPB
   504  				}).Return(&newTree, test.createErr)
   505  			}
   506  
   507  			// Copy test.req so that any changes CreateTree makes don't affect the original, which may be shared between tests.
   508  			reqCopy := proto.Clone(test.req).(*trillian.CreateTreeRequest)
   509  			tree, err := s.CreateTree(ctx, reqCopy)
   510  			switch gotErr := err != nil; {
   511  			case gotErr && !strings.Contains(err.Error(), test.wantErr):
   512  				t.Fatalf("CreateTree() = (_, %q), want (_, %q)", err, test.wantErr)
   513  			case gotErr:
   514  				return
   515  			case test.wantErr != "":
   516  				t.Fatalf("CreateTree() = (_, nil), want (_, %q)", test.wantErr)
   517  			}
   518  
   519  			wantTree := *test.req.Tree
   520  			wantTree.TreeId = 12345
   521  			wantTree.CreateTime = nowPB
   522  			wantTree.UpdateTime = nowPB
   523  			wantTree.PrivateKey = nil // redacted
   524  			wantTree.PublicKey, err = der.ToPublicProto(privateKey.Public())
   525  			if err != nil {
   526  				t.Fatalf("failed to marshal test public key as protobuf: %v", err)
   527  			}
   528  			if diff := pretty.Compare(tree, &wantTree); diff != "" {
   529  				t.Fatalf("post-CreateTree diff (-got +want):\n%v", diff)
   530  			}
   531  		})
   532  	}
   533  }
   534  
   535  func TestServer_CreateTree_AllowedTreeTypes(t *testing.T) {
   536  	ctrl := gomock.NewController(t)
   537  	defer ctrl.Finish()
   538  
   539  	tests := []struct {
   540  		desc      string
   541  		treeTypes []trillian.TreeType
   542  		req       *trillian.CreateTreeRequest
   543  		wantCode  codes.Code
   544  		wantMsg   string
   545  	}{
   546  		{
   547  			desc:      "mapOnLogServer",
   548  			treeTypes: []trillian.TreeType{trillian.TreeType_LOG},
   549  			req:       &trillian.CreateTreeRequest{Tree: testonly.MapTree},
   550  			wantCode:  codes.InvalidArgument,
   551  			wantMsg:   "tree type MAP not allowed",
   552  		},
   553  		{
   554  			desc:      "logOnMapServer",
   555  			treeTypes: []trillian.TreeType{trillian.TreeType_MAP},
   556  			req:       &trillian.CreateTreeRequest{Tree: testonly.LogTree},
   557  			wantCode:  codes.InvalidArgument,
   558  			wantMsg:   "tree type LOG not allowed",
   559  		},
   560  		{
   561  			desc:      "preorderedLogOnLogServer",
   562  			treeTypes: []trillian.TreeType{trillian.TreeType_LOG},
   563  			req:       &trillian.CreateTreeRequest{Tree: testonly.PreorderedLogTree},
   564  			wantCode:  codes.InvalidArgument,
   565  			wantMsg:   "tree type PREORDERED_LOG not allowed",
   566  		},
   567  		{
   568  			desc:      "preorderedLogOnMapServer",
   569  			treeTypes: []trillian.TreeType{trillian.TreeType_MAP},
   570  			req:       &trillian.CreateTreeRequest{Tree: testonly.PreorderedLogTree},
   571  			wantCode:  codes.InvalidArgument,
   572  			wantMsg:   "tree type PREORDERED_LOG not allowed",
   573  		},
   574  		{
   575  			desc:      "logOnLogServer",
   576  			treeTypes: []trillian.TreeType{trillian.TreeType_LOG},
   577  			req:       &trillian.CreateTreeRequest{Tree: testonly.LogTree},
   578  			wantCode:  codes.OK,
   579  		},
   580  		{
   581  			desc:      "preorderedLogAllowed",
   582  			treeTypes: []trillian.TreeType{trillian.TreeType_LOG, trillian.TreeType_PREORDERED_LOG},
   583  			req:       &trillian.CreateTreeRequest{Tree: testonly.PreorderedLogTree},
   584  			wantCode:  codes.OK,
   585  		},
   586  		{
   587  			desc:      "mapOnMapServer",
   588  			treeTypes: []trillian.TreeType{trillian.TreeType_MAP},
   589  			req:       &trillian.CreateTreeRequest{Tree: testonly.MapTree},
   590  			wantCode:  codes.OK,
   591  		},
   592  		// treeTypes = nil is exercised by all other tests.
   593  	}
   594  
   595  	ctx := context.Background()
   596  	for _, test := range tests {
   597  		setup := setupAdminServer(
   598  			ctrl,
   599  			nil,   /* keygen */
   600  			false, /* snapshot */
   601  			test.wantCode == codes.OK, /* shouldCommit */
   602  			false /* commitErr */)
   603  		s := setup.server
   604  		tx := setup.tx
   605  		s.allowedTreeTypes = test.treeTypes
   606  
   607  		// Storage interactions aren't the focus of this test, so mocks are configured in a rather
   608  		// permissive way.
   609  		tx.EXPECT().CreateTree(gomock.Any(), gomock.Any()).AnyTimes().Return(&trillian.Tree{}, nil)
   610  
   611  		_, err := s.CreateTree(ctx, test.req)
   612  		switch s, ok := status.FromError(err); {
   613  		case !ok || s.Code() != test.wantCode:
   614  			t.Errorf("%v: CreateTree() returned err = %v, wantCode = %s", test.desc, err, test.wantCode)
   615  		case err != nil && !strings.Contains(err.Error(), test.wantMsg):
   616  			t.Errorf("%v: CreateTree() returned err = %q, wantMsg = %q", test.desc, err, test.wantMsg)
   617  		}
   618  	}
   619  }
   620  
   621  func TestServer_UpdateTree(t *testing.T) {
   622  	ctrl := gomock.NewController(t)
   623  	defer ctrl.Finish()
   624  
   625  	nowPB := ptypes.TimestampNow()
   626  	existingTree := *testonly.LogTree
   627  	existingTree.TreeId = 12345
   628  	existingTree.CreateTime = nowPB
   629  	existingTree.UpdateTime = nowPB
   630  	existingTree.MaxRootDuration = ptypes.DurationProto(1 * time.Nanosecond)
   631  
   632  	// Any valid proto works here, the type doesn't matter for this test.
   633  	settings := ttestonly.MustMarshalAny(t, &empty.Empty{})
   634  
   635  	// successTree specifies changes in all rw fields
   636  	successTree := &trillian.Tree{
   637  		TreeState:       trillian.TreeState_FROZEN,
   638  		DisplayName:     "Brand New Tree Name",
   639  		Description:     "Brand New Tree Desc",
   640  		StorageSettings: settings,
   641  		MaxRootDuration: ptypes.DurationProto(2 * time.Nanosecond),
   642  		PrivateKey:      ttestonly.MustMarshalAny(t, &empty.Empty{}),
   643  	}
   644  	successMask := &field_mask.FieldMask{
   645  		Paths: []string{"tree_state", "display_name", "description", "storage_settings", "max_root_duration", "private_key"},
   646  	}
   647  
   648  	successWant := existingTree
   649  	successWant.TreeState = successTree.TreeState
   650  	successWant.DisplayName = successTree.DisplayName
   651  	successWant.Description = successTree.Description
   652  	successWant.StorageSettings = successTree.StorageSettings
   653  	successWant.PrivateKey = nil // redacted on responses
   654  	successWant.MaxRootDuration = successTree.MaxRootDuration
   655  
   656  	tests := []struct {
   657  		desc                           string
   658  		req                            *trillian.UpdateTreeRequest
   659  		currentTree, wantTree          *trillian.Tree
   660  		updateErr                      error
   661  		commitErr, wantErr, wantCommit bool
   662  	}{
   663  		{
   664  			desc:        "success",
   665  			req:         &trillian.UpdateTreeRequest{Tree: successTree, UpdateMask: successMask},
   666  			currentTree: &existingTree,
   667  			wantTree:    &successWant,
   668  			wantCommit:  true,
   669  		},
   670  		{
   671  			desc:    "nilTree",
   672  			req:     &trillian.UpdateTreeRequest{},
   673  			wantErr: true,
   674  		},
   675  		{
   676  			desc:        "nilUpdateMask",
   677  			req:         &trillian.UpdateTreeRequest{Tree: successTree},
   678  			currentTree: &existingTree,
   679  			wantErr:     true,
   680  		},
   681  		{
   682  			desc:        "emptyUpdateMask",
   683  			req:         &trillian.UpdateTreeRequest{Tree: successTree, UpdateMask: &field_mask.FieldMask{}},
   684  			currentTree: &existingTree,
   685  			wantErr:     true,
   686  		},
   687  		{
   688  			desc: "readonlyField",
   689  			req: &trillian.UpdateTreeRequest{
   690  				Tree:       successTree,
   691  				UpdateMask: &field_mask.FieldMask{Paths: []string{"tree_id"}},
   692  			},
   693  			currentTree: &existingTree,
   694  			wantErr:     true,
   695  		},
   696  		{
   697  			desc:        "updateErr",
   698  			req:         &trillian.UpdateTreeRequest{Tree: successTree, UpdateMask: successMask},
   699  			updateErr:   errors.New("error updating tree"),
   700  			currentTree: &existingTree,
   701  			wantErr:     true,
   702  		},
   703  		{
   704  			desc:        "commitErr",
   705  			req:         &trillian.UpdateTreeRequest{Tree: successTree, UpdateMask: successMask},
   706  			currentTree: &existingTree,
   707  			commitErr:   true,
   708  			wantErr:     true,
   709  			wantCommit:  true,
   710  		},
   711  	}
   712  
   713  	ctx := context.Background()
   714  	for _, test := range tests {
   715  		setup := setupAdminServer(
   716  			ctrl,
   717  			nil,   /* keygen */
   718  			false, /* snapshot */
   719  			test.wantCommit,
   720  			test.commitErr)
   721  
   722  		tx := setup.tx
   723  		s := setup.server
   724  
   725  		if test.req.Tree != nil {
   726  			tx.EXPECT().UpdateTree(gomock.Any(), test.req.Tree.TreeId, gomock.Any()).MaxTimes(1).Do(func(ctx context.Context, treeID int64, updateFn func(*trillian.Tree)) {
   727  				// This step should be done by the storage layer, but since we're mocking it we have to trigger it ourselves.
   728  				updateFn(test.currentTree)
   729  			}).Return(test.currentTree, test.updateErr)
   730  		}
   731  
   732  		tree, err := s.UpdateTree(ctx, test.req)
   733  		if hasErr := err != nil; hasErr != test.wantErr {
   734  			t.Errorf("%v: UpdateTree() returned err = %q, wantErr = %v", test.desc, err, test.wantErr)
   735  			continue
   736  		} else if hasErr {
   737  			continue
   738  		}
   739  
   740  		if !proto.Equal(tree, test.wantTree) {
   741  			diff := pretty.Compare(tree, test.wantTree)
   742  			t.Errorf("%v: post-UpdateTree diff:\n%v", test.desc, diff)
   743  		}
   744  	}
   745  }
   746  
   747  func TestServer_DeleteTree(t *testing.T) {
   748  	ctrl := gomock.NewController(t)
   749  	defer ctrl.Finish()
   750  
   751  	logTree := proto.Clone(testonly.LogTree).(*trillian.Tree)
   752  	mapTree := proto.Clone(testonly.MapTree).(*trillian.Tree)
   753  	for i, tree := range []*trillian.Tree{logTree, mapTree} {
   754  		tree.TreeId = int64(i) + 10
   755  		tree.CreateTime, _ = ptypes.TimestampProto(time.Unix(int64(i)*3600, 0))
   756  		tree.UpdateTime = tree.CreateTime
   757  	}
   758  
   759  	tests := []struct {
   760  		desc string
   761  		tree *trillian.Tree
   762  	}{
   763  		{desc: "logTree", tree: logTree},
   764  		{desc: "mapTree", tree: mapTree},
   765  	}
   766  
   767  	ctx := context.Background()
   768  	for _, test := range tests {
   769  		setup := setupAdminServer(
   770  			ctrl,
   771  			nil,   /* keygen */
   772  			false, /* snapshot */
   773  			true,  /* shouldCommit */
   774  			false /* commitErr */)
   775  		req := &trillian.DeleteTreeRequest{TreeId: test.tree.TreeId}
   776  
   777  		tx := setup.tx
   778  		tx.EXPECT().SoftDeleteTree(gomock.Any(), req.TreeId).Return(test.tree, nil)
   779  
   780  		s := setup.server
   781  		got, err := s.DeleteTree(ctx, req)
   782  		if err != nil {
   783  			t.Errorf("%v: DeleteTree() returned err = %v", test.desc, err)
   784  			continue
   785  		}
   786  
   787  		want := proto.Clone(test.tree).(*trillian.Tree)
   788  		want.PrivateKey = nil // redacted
   789  		if !proto.Equal(got, want) {
   790  			diff := pretty.Compare(got, want)
   791  			t.Errorf("%v: post-DeleteTree() diff (-got +want):\n%v", test.desc, diff)
   792  		}
   793  	}
   794  }
   795  
   796  func TestServer_DeleteTreeErrors(t *testing.T) {
   797  	ctrl := gomock.NewController(t)
   798  	defer ctrl.Finish()
   799  
   800  	tests := []struct {
   801  		desc      string
   802  		deleteErr error
   803  		commitErr bool
   804  	}{
   805  		{desc: "deleteErr", deleteErr: errors.New("unknown tree")},
   806  		{desc: "commitErr", commitErr: true},
   807  	}
   808  
   809  	ctx := context.Background()
   810  	for _, test := range tests {
   811  		setup := setupAdminServer(
   812  			ctrl,
   813  			nil,   /* keygen */
   814  			false, /* snapshot */
   815  			test.deleteErr == nil, /* shouldCommit */
   816  			test.commitErr /* commitErr */)
   817  		req := &trillian.DeleteTreeRequest{TreeId: 10}
   818  
   819  		tx := setup.tx
   820  		tx.EXPECT().SoftDeleteTree(gomock.Any(), req.TreeId).Return(&trillian.Tree{}, test.deleteErr)
   821  
   822  		s := setup.server
   823  		if _, err := s.DeleteTree(ctx, req); err == nil {
   824  			t.Errorf("%v: DeleteTree() returned err = nil, want non-nil", test.desc)
   825  		}
   826  	}
   827  }
   828  
   829  func TestServer_UndeleteTree(t *testing.T) {
   830  	ctrl := gomock.NewController(t)
   831  	defer ctrl.Finish()
   832  
   833  	activeLog := proto.Clone(testonly.LogTree).(*trillian.Tree)
   834  	frozenLog := proto.Clone(testonly.LogTree).(*trillian.Tree)
   835  	frozenLog.TreeState = trillian.TreeState_FROZEN
   836  	activeMap := proto.Clone(testonly.MapTree).(*trillian.Tree)
   837  	for i, tree := range []*trillian.Tree{activeLog, frozenLog, activeMap} {
   838  		tree.TreeId = int64(i) + 10
   839  		tree.CreateTime, _ = ptypes.TimestampProto(time.Unix(int64(i)*3600, 0))
   840  		tree.UpdateTime = tree.CreateTime
   841  		tree.Deleted = true
   842  		tree.DeleteTime, _ = ptypes.TimestampProto(time.Unix(int64(i)*3600+10, 0))
   843  	}
   844  
   845  	tests := []struct {
   846  		desc string
   847  		tree *trillian.Tree
   848  	}{
   849  		{desc: "activeLog", tree: activeLog},
   850  		{desc: "frozenLog", tree: frozenLog},
   851  		{desc: "activeMap", tree: activeMap},
   852  	}
   853  
   854  	ctx := context.Background()
   855  	for _, test := range tests {
   856  		setup := setupAdminServer(
   857  			ctrl,
   858  			nil,   /* keygen */
   859  			false, /* snapshot */
   860  			true,  /* shouldCommit */
   861  			false /* commitErr */)
   862  		req := &trillian.UndeleteTreeRequest{TreeId: test.tree.TreeId}
   863  
   864  		tx := setup.tx
   865  		tx.EXPECT().UndeleteTree(gomock.Any(), req.TreeId).Return(test.tree, nil)
   866  
   867  		s := setup.server
   868  		got, err := s.UndeleteTree(ctx, req)
   869  		if err != nil {
   870  			t.Errorf("%v: UndeleteTree() returned err = %v", test.desc, err)
   871  			continue
   872  		}
   873  
   874  		want := proto.Clone(test.tree).(*trillian.Tree)
   875  		want.PrivateKey = nil // redacted
   876  		if !proto.Equal(got, want) {
   877  			diff := pretty.Compare(got, want)
   878  			t.Errorf("%v: post-UneleteTree() diff (-got +want):\n%v", test.desc, diff)
   879  		}
   880  	}
   881  }
   882  
   883  func TestServer_UndeleteTreeErrors(t *testing.T) {
   884  	ctrl := gomock.NewController(t)
   885  	defer ctrl.Finish()
   886  
   887  	tests := []struct {
   888  		desc        string
   889  		undeleteErr error
   890  		commitErr   bool
   891  	}{
   892  		{desc: "undeleteErr", undeleteErr: errors.New("unknown tree")},
   893  		{desc: "commitErr", commitErr: true},
   894  	}
   895  
   896  	ctx := context.Background()
   897  	for _, test := range tests {
   898  		setup := setupAdminServer(
   899  			ctrl,
   900  			nil,   /* keygen */
   901  			false, /* snapshot */
   902  			test.undeleteErr == nil, /* shouldCommit */
   903  			test.commitErr /* commitErr */)
   904  		req := &trillian.UndeleteTreeRequest{TreeId: 10}
   905  
   906  		tx := setup.tx
   907  		tx.EXPECT().UndeleteTree(gomock.Any(), req.TreeId).Return(&trillian.Tree{}, test.undeleteErr)
   908  
   909  		s := setup.server
   910  		if _, err := s.UndeleteTree(ctx, req); err == nil {
   911  			t.Errorf("%v: UndeleteTree() returned err = nil, want non-nil", test.desc)
   912  		}
   913  	}
   914  }
   915  
// adminTestSetup contains an operational Server and required dependencies.
// It's created via setupAdminServer.
//
// NOTE: setupAdminServer builds this value with a positional composite
// literal, so the field order below must not change.
type adminTestSetup struct {
	// registry is the extension.Registry the server was configured with.
	registry   extension.Registry
	// as is the AdminStorage installed in registry.
	as         storage.AdminStorage
	// tx is the mock read-write transaction; nil when the setup uses snapshots.
	tx         *storage.MockAdminTX
	// snapshotTX is the mock read-only transaction; nil unless the setup uses snapshots.
	snapshotTX *storage.MockReadOnlyAdminTX
	// server is the Server under test, built from registry.
	server     *Server
}
   925  
   926  // setupAdminServer configures mocks according to input parameters.
   927  // Storage will be set to use either snapshots or regular TXs via snapshot parameter.
   928  // Whether the snapshot/TX is expected to be committed (and if it should error doing so) is
   929  // controlled via shouldCommit and commitErr parameters.
   930  func setupAdminServer(ctrl *gomock.Controller, keygen keys.ProtoGenerator, snapshot, shouldCommit, commitErr bool) adminTestSetup {
   931  	as := &testonly.FakeAdminStorage{}
   932  
   933  	var snapshotTX *storage.MockReadOnlyAdminTX
   934  	var tx *storage.MockAdminTX
   935  	if snapshot {
   936  		snapshotTX = storage.NewMockReadOnlyAdminTX(ctrl)
   937  		snapshotTX.EXPECT().Close().MaxTimes(1).Return(nil)
   938  		as.ReadOnlyTX = append(as.ReadOnlyTX, snapshotTX)
   939  		if shouldCommit {
   940  			if commitErr {
   941  				snapshotTX.EXPECT().Commit().Return(errors.New("commit error"))
   942  			} else {
   943  				snapshotTX.EXPECT().Commit().Return(nil)
   944  			}
   945  		}
   946  	} else {
   947  		tx = storage.NewMockAdminTX(ctrl)
   948  		tx.EXPECT().Close().MaxTimes(1).Return(nil)
   949  		as.TX = append(as.TX, tx)
   950  		if shouldCommit {
   951  			if commitErr {
   952  				tx.EXPECT().Commit().Return(errors.New("commit error"))
   953  			} else {
   954  				tx.EXPECT().Commit().Return(nil)
   955  			}
   956  		}
   957  	}
   958  
   959  	registry := extension.Registry{
   960  		AdminStorage: as,
   961  		NewKeyProto:  keygen,
   962  	}
   963  
   964  	s := &Server{registry: registry}
   965  
   966  	return adminTestSetup{registry, as, tx, snapshotTX, s}
   967  }
   968  
   969  func fakeKeyProtoHandler(wantKeyProto proto.Message, key crypto.Signer) (proto.Message, keys.ProtoHandler) {
   970  	return wantKeyProto, func(ctx context.Context, gotKeyProto proto.Message) (crypto.Signer, error) {
   971  		if !proto.Equal(gotKeyProto, wantKeyProto) {
   972  			return nil, fmt.Errorf("NewSigner(_, %#v) called, want NewSigner(_, %#v)", gotKeyProto, wantKeyProto)
   973  		}
   974  		return key, nil
   975  	}
   976  }
   977  
   978  func fakeKeyProtoGenerator(wantKeySpec *keyspb.Specification, keyProto proto.Message) keys.ProtoGenerator {
   979  	return func(ctx context.Context, gotKeySpec *keyspb.Specification) (proto.Message, error) {
   980  		if !proto.Equal(gotKeySpec, wantKeySpec) {
   981  			return nil, fmt.Errorf("NewKeyProto(_, %#v) called, want NewKeyProto(_, %#v)", gotKeySpec, wantKeySpec)
   982  		}
   983  		return keyProto, nil
   984  	}
   985  }