github.com/letsencrypt/trillian@v1.1.2-0.20180615153820-ae375a99d36a/storage/testonly/admin_storage_tester.go (about)

     1  // Copyright 2017 Google Inc. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package testonly
    16  
    17  import (
    18  	"context"
    19  	"crypto"
    20  	"fmt"
    21  	"reflect"
    22  	"sort"
    23  	"testing"
    24  	"time"
    25  
    26  	"github.com/golang/protobuf/proto"
    27  	"github.com/golang/protobuf/ptypes"
    28  	"github.com/golang/protobuf/ptypes/any"
    29  	"github.com/golang/protobuf/ptypes/empty"
    30  	"github.com/google/trillian"
    31  	"github.com/google/trillian/crypto/keys"
    32  	"github.com/google/trillian/crypto/keys/pem"
    33  	"github.com/google/trillian/crypto/keyspb"
    34  	"github.com/google/trillian/crypto/sigpb"
    35  	"github.com/google/trillian/merkle/maphasher"
    36  	"github.com/google/trillian/storage"
    37  	"github.com/google/trillian/testonly"
    38  	"github.com/kylelemons/godebug/pretty"
    39  	"google.golang.org/grpc/codes"
    40  	"google.golang.org/grpc/status"
    41  
    42  	ktestonly "github.com/google/trillian/crypto/keys/testonly"
    43  	spb "github.com/google/trillian/crypto/sigpb"
    44  
    45  	_ "github.com/google/trillian/crypto/keys/der/proto" // PrivateKey proto handler
    46  	_ "github.com/google/trillian/crypto/keys/pem/proto" // PEMKeyFile proto handler
    47  	_ "github.com/google/trillian/merkle/maphasher"      // TEST_MAP_HASHER
    48  )
    49  
const (
	// privateKeyPass is the passphrase passed alongside privateKeyPEM wherever
	// that key is decoded in this file.
	privateKeyPass = "towel"
	// privateKeyPEM is a passphrase-encrypted EC private key in PEM form. It
	// supplies the private key material for LogTree and PreorderedLogTree.
	privateKeyPEM  = `
-----BEGIN EC PRIVATE KEY-----
Proc-Type: 4,ENCRYPTED
DEK-Info: DES-CBC,D95ECC664FF4BDEC

Xy3zzHFwlFwjE8L1NCngJAFbu3zFf4IbBOCsz6Fa790utVNdulZncNCl2FMK3U2T
sdoiTW8ymO+qgwcNrqvPVmjFRBtkN0Pn5lgbWhN/aK3TlS9IYJ/EShbMUzjgVzie
S9+/31whWcH/FLeLJx4cBzvhgCtfquwA+s5ojeLYYsk=
-----END EC PRIVATE KEY-----`
	// publicKeyPEM is a PEM-encoded public key. It supplies the public key
	// material for LogTree and PreorderedLogTree.
	publicKeyPEM = `
-----BEGIN PUBLIC KEY-----
MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEywnWicNEQ8bn3GXcGpA+tiU4VL70
Ws9xezgQPrg96YGsFrF6KYG68iqyHDlQ+4FWuKfGKXHn3ooVtB/pfawb5Q==
-----END PUBLIC KEY-----`
)
    67  
    68  // mustMarshalAny panics if ptypes.MarshalAny fails.
    69  func mustMarshalAny(pb proto.Message) *any.Any {
    70  	value, err := ptypes.MarshalAny(pb)
    71  	if err != nil {
    72  		panic(err)
    73  	}
    74  	return value
    75  }
    76  
    77  // TODO(phad): consider how to better break the import loop between trees and
    78  // trees/testonly (which is due to trees.Hash) than this.
    79  
    80  // hash returns the crypto.Hash configured by the tree.
    81  func hash(tree *trillian.Tree) (crypto.Hash, error) {
    82  	switch tree.HashAlgorithm {
    83  	case sigpb.DigitallySigned_SHA256:
    84  		return crypto.SHA256, nil
    85  	}
    86  	// There's no nil-like value for crypto.Hash, something has to be returned.
    87  	return crypto.SHA256, fmt.Errorf("unexpected hash algorithm: %s", tree.HashAlgorithm)
    88  }
    89  
// Shared tree fixtures and their derived empty-root hashes. The tests in this
// file copy or clone these values before handing them to storage.
var (
	// LogTree is a valid, LOG-type trillian.Tree for tests.
	// Its key pair is built from privateKeyPEM / publicKeyPEM.
	LogTree = &trillian.Tree{
		TreeState:          trillian.TreeState_ACTIVE,
		TreeType:           trillian.TreeType_LOG,
		HashStrategy:       trillian.HashStrategy_RFC6962_SHA256,
		HashAlgorithm:      spb.DigitallySigned_SHA256,
		SignatureAlgorithm: spb.DigitallySigned_ECDSA,
		DisplayName:        "Llamas Log",
		Description:        "Registry of publicly-owned llamas",
		PrivateKey: mustMarshalAny(&keyspb.PrivateKey{
			Der: ktestonly.MustMarshalPrivatePEMToDER(privateKeyPEM, privateKeyPass),
		}),
		PublicKey: &keyspb.PublicKey{
			Der: ktestonly.MustMarshalPublicPEMToDER(publicKeyPEM),
		},
		MaxRootDuration: ptypes.DurationProto(0 * time.Millisecond),
	}
	// LogTreeEmptyRootHash is the root hash of LogTree when empty.
	// It is computed once, at package initialisation, as the digest of empty
	// input under the hash algorithm configured on LogTree.
	LogTreeEmptyRootHash = func() []byte {
		hasher, err := hash(LogTree)
		if err != nil {
			panic(err)
		}
		return hasher.New().Sum(nil)
	}()

	// PreorderedLogTree is a valid, PREORDERED_LOG-type trillian.Tree for tests.
	// It uses the same key pair as LogTree.
	PreorderedLogTree = &trillian.Tree{
		TreeState:          trillian.TreeState_ACTIVE,
		TreeType:           trillian.TreeType_PREORDERED_LOG,
		HashStrategy:       trillian.HashStrategy_RFC6962_SHA256,
		HashAlgorithm:      spb.DigitallySigned_SHA256,
		SignatureAlgorithm: spb.DigitallySigned_ECDSA,
		DisplayName:        "Pre-ordered Log",
		Description:        "Mirror registry of publicly-owned llamas",
		PrivateKey: mustMarshalAny(&keyspb.PrivateKey{
			Der: ktestonly.MustMarshalPrivatePEMToDER(privateKeyPEM, privateKeyPass),
		}),
		PublicKey: &keyspb.PublicKey{
			Der: ktestonly.MustMarshalPublicPEMToDER(publicKeyPEM),
		},
		MaxRootDuration: ptypes.DurationProto(0 * time.Millisecond),
	}

	// MapTree is a valid, MAP-type trillian.Tree for tests.
	// Unlike the log fixtures, its key pair comes from the testonly Demo* keys.
	MapTree = &trillian.Tree{
		TreeState:          trillian.TreeState_ACTIVE,
		TreeType:           trillian.TreeType_MAP,
		HashStrategy:       trillian.HashStrategy_TEST_MAP_HASHER,
		HashAlgorithm:      spb.DigitallySigned_SHA256,
		SignatureAlgorithm: spb.DigitallySigned_ECDSA,
		DisplayName:        "Llamas Map",
		Description:        "Key Transparency map for all your digital llama needs.",
		PrivateKey: mustMarshalAny(&keyspb.PrivateKey{
			Der: ktestonly.MustMarshalPrivatePEMToDER(testonly.DemoPrivateKey, testonly.DemoPrivateKeyPass),
		}),
		PublicKey: &keyspb.PublicKey{
			Der: ktestonly.MustMarshalPublicPEMToDER(testonly.DemoPublicKey),
		},
		MaxRootDuration: ptypes.DurationProto(0 * time.Millisecond),
	}

	// MapTreeEmptyRootHash is the root hash of MapTree when 'empty' (i.e. no leaves are set).
	// It is computed once, at package initialisation, by asking a maphasher
	// built on MapTree's hash algorithm for its empty hash at the full bit length.
	MapTreeEmptyRootHash = func() []byte {
		hasher, err := hash(MapTree)
		if err != nil {
			panic(err)
		}
		mh := maphasher.New(hasher)
		return mh.HashEmpty(0 /*treeID - unused*/, nil /*index - unused*/, mh.BitLen())
	}()
)
   163  
// AdminStorageTester runs a suite of tests against AdminStorage implementations.
// Each Test* method calls NewAdminStorage itself to obtain the storage it
// exercises; RunAllTests runs every method as a subtest.
type AdminStorageTester struct {
	// NewAdminStorage returns an AdminStorage instance pointing to a clean
	// test database.
	NewAdminStorage func() storage.AdminStorage
}
   170  
   171  // RunAllTests runs all AdminStorage tests.
   172  func (tester *AdminStorageTester) RunAllTests(t *testing.T) {
   173  	t.Run("TestCreateTree", tester.TestCreateTree)
   174  	t.Run("TestUpdateTree", tester.TestUpdateTree)
   175  	t.Run("TestListTrees", tester.TestListTrees)
   176  	t.Run("TestSoftDeleteTree", tester.TestSoftDeleteTree)
   177  	t.Run("TestSoftDeleteTreeErrors", tester.TestSoftDeleteTreeErrors)
   178  	t.Run("TestHardDeleteTree", tester.TestHardDeleteTree)
   179  	t.Run("TestHardDeleteTreeErrors", tester.TestHardDeleteTreeErrors)
   180  	t.Run("TestUndeleteTree", tester.TestUndeleteTree)
   181  	t.Run("TestUndeleteTreeErrors", tester.TestUndeleteTreeErrors)
   182  	t.Run("TestAdminTXReadWriteTransaction", tester.TestAdminTXReadWriteTransaction)
   183  }
   184  
   185  // TestCreateTree tests AdminStorage Tree creation.
   186  func (tester *AdminStorageTester) TestCreateTree(t *testing.T) {
   187  	// Check that validation runs, but leave details to the validation
   188  	// tests.
   189  	invalidTree := *LogTree
   190  	invalidTree.TreeType = trillian.TreeType_UNKNOWN_TREE_TYPE
   191  
   192  	validTree1 := *LogTree
   193  	validTree2 := *MapTree
   194  	validTree3 := *PreorderedLogTree
   195  
   196  	validTreeWithoutOptionals := *LogTree
   197  	validTreeWithoutOptionals.DisplayName = ""
   198  	validTreeWithoutOptionals.Description = ""
   199  
   200  	tests := []struct {
   201  		desc    string
   202  		tree    *trillian.Tree
   203  		wantErr bool
   204  	}{
   205  		{
   206  			desc:    "invalidTree",
   207  			tree:    &invalidTree,
   208  			wantErr: true,
   209  		},
   210  		{
   211  			desc: "validTree1",
   212  			tree: &validTree1,
   213  		},
   214  		{
   215  			desc: "validTree2",
   216  			tree: &validTree2,
   217  		},
   218  		{
   219  			desc: "validTree3",
   220  			tree: &validTree3,
   221  		},
   222  		{
   223  			desc: "validTreeWithoutOptionals",
   224  			tree: &validTreeWithoutOptionals,
   225  		},
   226  	}
   227  
   228  	ctx := context.Background()
   229  	s := tester.NewAdminStorage()
   230  	for _, test := range tests {
   231  		func() {
   232  			// Test CreateTree up to the tx commit
   233  			newTree, err := storage.CreateTree(ctx, s, test.tree)
   234  			if hasErr := err != nil; hasErr != test.wantErr {
   235  				t.Errorf("%v: CreateTree() = (_, %v), wantErr = %v", test.desc, err, test.wantErr)
   236  				return
   237  			} else if hasErr {
   238  				// Tested above
   239  				return
   240  			}
   241  
   242  			createTime := newTree.CreateTime
   243  			updateTime := newTree.UpdateTime
   244  			if _, err := ptypes.Timestamp(createTime); err != nil {
   245  				t.Errorf("%v: CreateTime malformed after creation: %v", test.desc, newTree)
   246  				return
   247  			}
   248  
   249  			switch {
   250  			case newTree.TreeId == 0:
   251  				t.Errorf("%v: TreeID not returned from creation: %v", test.desc, newTree)
   252  				return
   253  			case !reflect.DeepEqual(createTime, updateTime):
   254  				t.Errorf("%v: CreateTime != UpdateTime: %v", test.desc, newTree)
   255  				return
   256  			}
   257  
   258  			wantTree := *test.tree
   259  			wantTree.TreeId = newTree.TreeId
   260  			wantTree.CreateTime = createTime
   261  			wantTree.UpdateTime = updateTime
   262  			// Ignore storage_settings changes (OK to vary between implementations)
   263  			wantTree.StorageSettings = newTree.StorageSettings
   264  			if !proto.Equal(newTree, &wantTree) {
   265  				diff := pretty.Compare(newTree, &wantTree)
   266  				t.Errorf("%v: post-CreateTree diff:\n%v", test.desc, diff)
   267  				return
   268  			}
   269  
   270  			if err := assertStoredTree(ctx, s, newTree); err != nil {
   271  				t.Errorf("%v: %v", test.desc, err)
   272  			}
   273  		}()
   274  	}
   275  }
   276  
   277  // TestUpdateTree tests AdminStorage Tree updates.
   278  func (tester *AdminStorageTester) TestUpdateTree(t *testing.T) {
   279  	ctx := context.Background()
   280  	s := tester.NewAdminStorage()
   281  
   282  	unrelatedTree := makeTreeOrFail(ctx, s, spec{Tree: MapTree}, t.Fatalf)
   283  
   284  	referenceLog := *LogTree
   285  	validLog := referenceLog
   286  	validLog.TreeState = trillian.TreeState_FROZEN
   287  	validLog.DisplayName = "Frozen Tree"
   288  	validLog.Description = "A Frozen Tree"
   289  	validLogFunc := func(tree *trillian.Tree) {
   290  		tree.TreeState = validLog.TreeState
   291  		tree.DisplayName = validLog.DisplayName
   292  		tree.Description = validLog.Description
   293  	}
   294  
   295  	validLogWithoutOptionalsFunc := func(tree *trillian.Tree) {
   296  		tree.DisplayName = ""
   297  		tree.Description = ""
   298  	}
   299  	validLogWithoutOptionals := referenceLog
   300  	validLogWithoutOptionalsFunc(&validLogWithoutOptionals)
   301  
   302  	invalidLogFunc := func(tree *trillian.Tree) {
   303  		tree.TreeState = trillian.TreeState_UNKNOWN_TREE_STATE
   304  	}
   305  
   306  	readonlyChangedFunc := func(tree *trillian.Tree) {
   307  		tree.TreeType = trillian.TreeType_MAP
   308  	}
   309  
   310  	referenceMap := *MapTree
   311  	validMap := referenceMap
   312  	validMap.DisplayName = "Updated Map"
   313  	validMapFunc := func(tree *trillian.Tree) {
   314  		tree.DisplayName = validMap.DisplayName
   315  	}
   316  
   317  	newPrivateKey := &empty.Empty{}
   318  	privateKeyChangedButKeyMaterialSameTree := *LogTree
   319  	privateKeyChangedButKeyMaterialSameTree.PrivateKey = testonly.MustMarshalAny(t, newPrivateKey)
   320  	keys.RegisterHandler(newPrivateKey, func(ctx context.Context, pb proto.Message) (crypto.Signer, error) {
   321  		return pem.UnmarshalPrivateKey(privateKeyPEM, privateKeyPass)
   322  	})
   323  	defer keys.UnregisterHandler(newPrivateKey)
   324  
   325  	privateKeyChangedButKeyMaterialSameFunc := func(tree *trillian.Tree) {
   326  		tree.PrivateKey = privateKeyChangedButKeyMaterialSameTree.PrivateKey
   327  	}
   328  
   329  	privateKeyChangedAndKeyMaterialDifferentFunc := func(tree *trillian.Tree) {
   330  		tree.PrivateKey = testonly.MustMarshalAny(t, &keyspb.PrivateKey{
   331  			Der: ktestonly.MustMarshalPrivatePEMToDER(testonly.DemoPrivateKey, testonly.DemoPrivateKeyPass),
   332  		})
   333  	}
   334  
   335  	// Test for an unknown tree outside the loop: it makes the test logic simpler
   336  	if _, err := storage.UpdateTree(ctx, s, -1, func(tree *trillian.Tree) {}); err == nil {
   337  		t.Error("UpdateTree() for treeID -1 returned nil err")
   338  	}
   339  
   340  	tests := []struct {
   341  		desc         string
   342  		create, want *trillian.Tree
   343  		updateFunc   func(*trillian.Tree)
   344  		wantErr      bool
   345  	}{
   346  		{
   347  			desc:       "validLog",
   348  			create:     &referenceLog,
   349  			updateFunc: validLogFunc,
   350  			want:       &validLog,
   351  		},
   352  		{
   353  			desc:       "validLogWithoutOptionals",
   354  			create:     &referenceLog,
   355  			updateFunc: validLogWithoutOptionalsFunc,
   356  			want:       &validLogWithoutOptionals,
   357  		},
   358  		{
   359  			desc:       "invalidLog",
   360  			create:     &referenceLog,
   361  			updateFunc: invalidLogFunc,
   362  			wantErr:    true,
   363  		},
   364  		{
   365  			desc:       "readonlyChanged",
   366  			create:     &referenceLog,
   367  			updateFunc: readonlyChangedFunc,
   368  			wantErr:    true,
   369  		},
   370  		{
   371  			desc:       "validMap",
   372  			create:     &referenceMap,
   373  			updateFunc: validMapFunc,
   374  			want:       &validMap,
   375  		},
   376  		{
   377  			desc:       "privateKeyChangedButKeyMaterialSame",
   378  			create:     &referenceLog,
   379  			updateFunc: privateKeyChangedButKeyMaterialSameFunc,
   380  			want:       &privateKeyChangedButKeyMaterialSameTree,
   381  		},
   382  		{
   383  			desc:       "privateKeyChangedAndKeyMaterialDifferent",
   384  			create:     &referenceLog,
   385  			updateFunc: privateKeyChangedAndKeyMaterialDifferentFunc,
   386  			wantErr:    true,
   387  		},
   388  	}
   389  	for _, test := range tests {
   390  		createdTree, err := storage.CreateTree(ctx, s, test.create)
   391  		if err != nil {
   392  			t.Errorf("CreateTree() = (_, %v), want = (_, nil)", err)
   393  			continue
   394  		}
   395  
   396  		updatedTree, err := storage.UpdateTree(ctx, s, createdTree.TreeId, test.updateFunc)
   397  		if hasErr := err != nil; hasErr != test.wantErr {
   398  			t.Errorf("%v: UpdateTree() = (_, %v), wantErr = %v", test.desc, err, test.wantErr)
   399  			continue
   400  		} else if hasErr {
   401  			continue
   402  		}
   403  
   404  		if createdTree.TreeId != updatedTree.TreeId {
   405  			t.Errorf("%v: TreeId = %v, want = %v", test.desc, updatedTree.TreeId, createdTree.TreeId)
   406  		}
   407  		if !reflect.DeepEqual(createdTree.CreateTime, updatedTree.CreateTime) {
   408  			t.Errorf("%v: CreateTime = %v, want = %v", test.desc, updatedTree.CreateTime, createdTree.CreateTime)
   409  		}
   410  		createUpdateTime, err := ptypes.Timestamp(createdTree.UpdateTime)
   411  		if err != nil {
   412  			t.Errorf("%v: createdTree.UpdateTime malformed: %v", test.desc, err)
   413  		}
   414  		updatedUpdateTime, err := ptypes.Timestamp(updatedTree.UpdateTime)
   415  		if err != nil {
   416  			t.Errorf("%v: updatedTree.UpdateTime malformed: %v", test.desc, err)
   417  		}
   418  		if createUpdateTime.After(updatedUpdateTime) {
   419  			t.Errorf("%v: UpdateTime = %v, want >= %v", test.desc, updatedTree.UpdateTime, createdTree.UpdateTime)
   420  		}
   421  		// Copy storage-generated values to want before comparing
   422  		wantTree := *test.want
   423  		wantTree.TreeId = updatedTree.TreeId
   424  		wantTree.CreateTime = updatedTree.CreateTime
   425  		wantTree.UpdateTime = updatedTree.UpdateTime
   426  		// Ignore storage_settings changes (OK to vary between implementations)
   427  		wantTree.StorageSettings = updatedTree.StorageSettings
   428  		if !proto.Equal(updatedTree, &wantTree) {
   429  			diff := pretty.Compare(updatedTree, &wantTree)
   430  			t.Errorf("%v: updatedTree doesn't match wantTree:\n%s", test.desc, diff)
   431  		}
   432  
   433  		if err := assertStoredTree(ctx, s, updatedTree); err != nil {
   434  			t.Errorf("%v: %v", test.desc, err)
   435  		}
   436  
   437  		if err := assertStoredTree(ctx, s, unrelatedTree); err != nil {
   438  			t.Errorf("%v: %v", test.desc, err)
   439  		}
   440  	}
   441  }
   442  
   443  // TestListTrees tests both ListTreeIDs and ListTrees.
   444  func (tester *AdminStorageTester) TestListTrees(t *testing.T) {
   445  	ctx := context.Background()
   446  	s := tester.NewAdminStorage()
   447  
   448  	run := func(desc string, includeDeleted bool, wantTrees []*trillian.Tree) {
   449  		if err := storage.RunInAdminSnapshot(ctx, s, func(tx storage.ReadOnlyAdminTX) error {
   450  			if err := runListTreeIDsTest(ctx, tx, includeDeleted, wantTrees); err != nil {
   451  				t.Errorf("%v: %v", desc, err)
   452  			}
   453  			if err := runListTreesTest(ctx, tx, includeDeleted, wantTrees); err != nil {
   454  				t.Errorf("%v: %v", desc, err)
   455  			}
   456  			// Always return nil, as we're reporting errors independently above.
   457  			return nil
   458  		}); err != nil {
   459  			// Capture Begin() / Commit() errors
   460  			t.Errorf("%v: RunInAdminSnapshot() returned err = %v", desc, err)
   461  		}
   462  	}
   463  
   464  	// Do a first pass with an empty DB
   465  	run("empty", false /* includeDeleted */, nil /* wantTrees */)
   466  	run("emptyDeleted", true /* includeDeleted */, nil /* wantTrees */)
   467  
   468  	// Add some trees and do another pass
   469  	activeLog := makeTreeOrFail(ctx, s, spec{Tree: LogTree}, t.Fatalf)
   470  	frozenLog := makeTreeOrFail(ctx, s, spec{Tree: LogTree, Frozen: true}, t.Fatalf)
   471  	deletedLog := makeTreeOrFail(ctx, s, spec{Tree: LogTree, Deleted: true}, t.Fatalf)
   472  	activeMap := makeTreeOrFail(ctx, s, spec{Tree: MapTree}, t.Fatalf)
   473  	run("multipleTrees", false /* includeDeleted */, []*trillian.Tree{activeLog, frozenLog, activeMap})
   474  	run("multipleTreesDeleted", true /* includeDeleted */, []*trillian.Tree{activeLog, frozenLog, deletedLog, activeMap})
   475  }
   476  
   477  func runListTreeIDsTest(ctx context.Context, tx storage.ReadOnlyAdminTX, includeDeleted bool, wantTrees []*trillian.Tree) error {
   478  	got, err := tx.ListTreeIDs(ctx, includeDeleted)
   479  	if err != nil {
   480  		return fmt.Errorf("ListTreeIDs() returned err = %v", err)
   481  	}
   482  
   483  	want := make([]int64, 0, len(wantTrees))
   484  	for _, tree := range wantTrees {
   485  		want = append(want, tree.TreeId)
   486  	}
   487  
   488  	sort.Slice(got, func(i, j int) bool { return got[i] < got[j] })
   489  	sort.Slice(want, func(i, j int) bool { return want[i] < want[j] })
   490  	if diff := pretty.Compare(got, want); diff != "" {
   491  		return fmt.Errorf("post-ListTreeIDs() diff (-got +want):\n%v", diff)
   492  	}
   493  	return nil
   494  }
   495  
   496  func runListTreesTest(ctx context.Context, tx storage.ReadOnlyAdminTX, includeDeleted bool, wantTrees []*trillian.Tree) error {
   497  	got, err := tx.ListTrees(ctx, includeDeleted)
   498  	if err != nil {
   499  		return fmt.Errorf("ListTrees() returned err = %v", err)
   500  	}
   501  
   502  	if len(got) != len(wantTrees) {
   503  		return fmt.Errorf("ListTrees() returned %v trees, want = %v", len(got), len(wantTrees))
   504  	}
   505  
   506  	want := wantTrees
   507  	sort.Slice(got, func(i, j int) bool { return got[i].TreeId < got[j].TreeId })
   508  	sort.Slice(want, func(i, j int) bool { return want[i].TreeId < want[j].TreeId })
   509  
   510  	for i, wantTree := range want {
   511  		if !proto.Equal(got[i], wantTree) {
   512  			return fmt.Errorf("post-ListTrees() diff (-got +want):\n%v", pretty.Compare(got, want))
   513  		}
   514  	}
   515  	return nil
   516  }
   517  
   518  // TestSoftDeleteTree tests success scenarios of SoftDeleteTree.
   519  func (tester *AdminStorageTester) TestSoftDeleteTree(t *testing.T) {
   520  	ctx := context.Background()
   521  	s := tester.NewAdminStorage()
   522  
   523  	logTree := makeTreeOrFail(ctx, s, spec{Tree: LogTree}, t.Fatalf)
   524  	mapTree := makeTreeOrFail(ctx, s, spec{Tree: MapTree}, t.Fatalf)
   525  
   526  	tests := []struct {
   527  		desc string
   528  		tree *trillian.Tree
   529  	}{
   530  		{desc: "logTree", tree: logTree},
   531  		{desc: "mapTree", tree: mapTree},
   532  	}
   533  	for _, test := range tests {
   534  		deletedTree, err := storage.SoftDeleteTree(ctx, s, test.tree.TreeId)
   535  		if err != nil {
   536  			t.Errorf("%v: SoftDeleteTree() returned err = %v", test.desc, err)
   537  			continue
   538  		}
   539  
   540  		if deletedTree.GetDeleteTime().GetSeconds() == 0 {
   541  			t.Errorf("%v: tree.DeleteTime = %v, want > 0", test.desc, deletedTree.DeleteTime)
   542  		}
   543  
   544  		wantTree := proto.Clone(test.tree).(*trillian.Tree)
   545  		wantTree.Deleted = true
   546  		wantTree.DeleteTime = deletedTree.DeleteTime
   547  		if got, want := deletedTree, wantTree; !proto.Equal(got, want) {
   548  			t.Errorf("%v: post-SoftDeleteTree diff (-got +want):\n%v", test.desc, pretty.Compare(got, want))
   549  		}
   550  
   551  		if err := assertStoredTree(ctx, s, deletedTree); err != nil {
   552  			t.Errorf("%v: %v", test.desc, err)
   553  		}
   554  	}
   555  }
   556  
   557  // TestSoftDeleteTreeErrors tests error scenarios of SoftDeleteTree.
   558  func (tester *AdminStorageTester) TestSoftDeleteTreeErrors(t *testing.T) {
   559  	ctx := context.Background()
   560  	s := tester.NewAdminStorage()
   561  
   562  	softDeleted := makeTreeOrFail(ctx, s, spec{Tree: LogTree, Deleted: true}, t.Fatalf)
   563  
   564  	tests := []struct {
   565  		desc     string
   566  		treeID   int64
   567  		wantCode codes.Code
   568  	}{
   569  		{desc: "unknownTree", treeID: 12345, wantCode: codes.NotFound},
   570  		{desc: "alreadyDeleted", treeID: softDeleted.TreeId, wantCode: codes.FailedPrecondition},
   571  	}
   572  	for _, test := range tests {
   573  		if _, err := storage.SoftDeleteTree(ctx, s, test.treeID); status.Code(err) != test.wantCode {
   574  			t.Errorf("%v: SoftDeleteTree() returned err = %v, wantCode = %s", test.desc, err, test.wantCode)
   575  		}
   576  	}
   577  }
   578  
   579  // TestHardDeleteTree tests success scenarios of HardDeleteTree.
   580  func (tester *AdminStorageTester) TestHardDeleteTree(t *testing.T) {
   581  	ctx := context.Background()
   582  	s := tester.NewAdminStorage()
   583  
   584  	logTree := makeTreeOrFail(ctx, s, spec{Tree: LogTree, Deleted: true}, t.Fatalf)
   585  	frozenTree := makeTreeOrFail(ctx, s, spec{Tree: LogTree, Deleted: true, Frozen: true}, t.Fatalf)
   586  	mapTree := makeTreeOrFail(ctx, s, spec{Tree: MapTree, Deleted: true}, t.Fatalf)
   587  
   588  	tests := []struct {
   589  		desc   string
   590  		treeID int64
   591  	}{
   592  		{desc: "logTree", treeID: logTree.TreeId},
   593  		{desc: "frozenTree", treeID: frozenTree.TreeId},
   594  		{desc: "mapTree", treeID: mapTree.TreeId},
   595  	}
   596  	for _, test := range tests {
   597  		if err := storage.HardDeleteTree(ctx, s, test.treeID); err != nil {
   598  			t.Errorf("%v: HardDeleteTree() returned err = %v", test.desc, err)
   599  			continue
   600  		}
   601  	}
   602  }
   603  
   604  // TestHardDeleteTreeErrors tests error scenarios of HardDeleteTree.
   605  func (tester *AdminStorageTester) TestHardDeleteTreeErrors(t *testing.T) {
   606  	ctx := context.Background()
   607  	s := tester.NewAdminStorage()
   608  
   609  	activeTree := makeTreeOrFail(ctx, s, spec{Tree: LogTree}, t.Fatalf)
   610  
   611  	tests := []struct {
   612  		desc     string
   613  		treeID   int64
   614  		wantCode codes.Code
   615  	}{
   616  		{desc: "unknownTree", treeID: 12345, wantCode: codes.NotFound},
   617  		{desc: "activeTree", treeID: activeTree.TreeId, wantCode: codes.FailedPrecondition},
   618  	}
   619  	for _, test := range tests {
   620  		if err := storage.HardDeleteTree(ctx, s, test.treeID); status.Code(err) != test.wantCode {
   621  			t.Errorf("%v: HardDeleteTree() returned err = %v, wantCode = %s", test.desc, err, test.wantCode)
   622  		}
   623  	}
   624  }
   625  
   626  // TestUndeleteTree tests success scenarios of UndeleteTree.
   627  func (tester *AdminStorageTester) TestUndeleteTree(t *testing.T) {
   628  	ctx := context.Background()
   629  	s := tester.NewAdminStorage()
   630  
   631  	activeDeleted := makeTreeOrFail(ctx, s, spec{Tree: LogTree, Deleted: true}, t.Fatalf)
   632  	frozenDeleted := makeTreeOrFail(ctx, s, spec{Tree: LogTree, Frozen: true, Deleted: true}, t.Fatalf)
   633  
   634  	tests := []struct {
   635  		desc string
   636  		tree *trillian.Tree
   637  	}{
   638  		{desc: "activeTree", tree: activeDeleted},
   639  		{desc: "frozenTree", tree: frozenDeleted},
   640  	}
   641  	for _, test := range tests {
   642  		tree, err := storage.UndeleteTree(ctx, s, test.tree.TreeId)
   643  		if err != nil {
   644  			t.Errorf("%v: UndeleteTree() returned err = %v", test.desc, err)
   645  			continue
   646  		}
   647  
   648  		want := proto.Clone(test.tree).(*trillian.Tree)
   649  		want.Deleted = false
   650  		want.DeleteTime = nil
   651  		if got := tree; !proto.Equal(got, want) {
   652  			t.Errorf("%v: post-UndeleteTree diff (-got +want):\n%v", test.desc, pretty.Compare(got, want))
   653  		}
   654  
   655  		if err := assertStoredTree(ctx, s, tree); err != nil {
   656  			t.Errorf("%v: %v", test.desc, err)
   657  		}
   658  	}
   659  }
   660  
   661  // TestUndeleteTreeErrors tests error scenarios of UndeleteTree.
   662  func (tester *AdminStorageTester) TestUndeleteTreeErrors(t *testing.T) {
   663  	ctx := context.Background()
   664  	s := tester.NewAdminStorage()
   665  
   666  	activeTree := makeTreeOrFail(ctx, s, spec{Tree: LogTree}, t.Fatalf)
   667  
   668  	tests := []struct {
   669  		desc     string
   670  		treeID   int64
   671  		wantCode codes.Code
   672  	}{
   673  		{desc: "unknownTree", treeID: 12345, wantCode: codes.NotFound},
   674  		{desc: "activeTree", treeID: activeTree.TreeId, wantCode: codes.FailedPrecondition},
   675  	}
   676  	for _, test := range tests {
   677  		if _, err := storage.UndeleteTree(ctx, s, test.treeID); status.Code(err) != test.wantCode {
   678  			t.Errorf("%v: UndeleteTree() returned err = %v, wantCode = %s", test.desc, err, test.wantCode)
   679  		}
   680  	}
   681  }
   682  
   683  // TestAdminTXReadWriteTransaction tests the ReadWriteTransaction method on AdminStorage.
   684  func (tester *AdminStorageTester) TestAdminTXReadWriteTransaction(t *testing.T) {
   685  	tests := []struct {
   686  		wantCommit bool
   687  	}{
   688  		{wantCommit: true},
   689  		{wantCommit: false},
   690  	}
   691  
   692  	ctx := context.Background()
   693  	s := tester.NewAdminStorage()
   694  
   695  	var tree *trillian.Tree
   696  
   697  	for i, test := range tests {
   698  		t.Run(fmt.Sprintf("%+v", test), func(t *testing.T) {
   699  			err := s.ReadWriteTransaction(ctx, func(ctx context.Context, tx storage.AdminTX) error {
   700  				var err error
   701  				tree, err = tx.CreateTree(ctx, LogTree)
   702  				if err != nil {
   703  					t.Fatalf("%v: CreateTree() = (_, %v), want = (_, nil)", i, err)
   704  				}
   705  				if !test.wantCommit {
   706  					return fmt.Errorf("No commit %d", i)
   707  				}
   708  				return nil
   709  			})
   710  			if (err != nil && test.wantCommit) ||
   711  				(err == nil && !test.wantCommit) {
   712  				t.Fatalf("%v: ReadWriteTransaction() = (_, %v), want = (_, nil)", i, err)
   713  			}
   714  
   715  			tx2, err := s.Snapshot(ctx)
   716  			if err != nil {
   717  				t.Fatalf("%v: Snapshot() = (_, %v), want = (_, nil)", i, err)
   718  			}
   719  			defer tx2.Close()
   720  			_, err = tx2.GetTree(ctx, tree.TreeId)
   721  			if hasErr := err != nil; !test.wantCommit != hasErr {
   722  				t.Errorf("%v: GetTree() = (_, %v), but wantCommit = %v", i, err, test.wantCommit)
   723  			}
   724  
   725  			// Multiple Close() calls are fine too
   726  			if err := tx2.Close(); err != nil {
   727  				t.Errorf("%v: Close() = %v, want = nil", i, err)
   728  				return
   729  			}
   730  		})
   731  	}
   732  }
   733  
   734  // assertStoredTree verifies that "want" is equal to the tree stored under its ID.
   735  func assertStoredTree(ctx context.Context, s storage.AdminStorage, want *trillian.Tree) error {
   736  	got, err := storage.GetTree(ctx, s, want.TreeId)
   737  	if err != nil {
   738  		return fmt.Errorf("GetTree() returned err = %v", err)
   739  	}
   740  	if !proto.Equal(got, want) {
   741  		return fmt.Errorf("post-GetTree() diff (-got +want):\n%v", pretty.Compare(got, want))
   742  	}
   743  	return nil
   744  }
   745  
// spec describes a tree for makeTree to create.
type spec struct {
	// Tree is the template tree; makeTree clones it before creating it in storage.
	Tree            *trillian.Tree
	// Frozen requests that the created tree be updated to TreeState_FROZEN;
	// Deleted requests that it be soft-deleted afterwards.
	Frozen, Deleted bool
}
   750  
   751  // makeTreeOrFail delegates to makeTree. If makeTree returns a non-nil error, failFn is called.
   752  func makeTreeOrFail(ctx context.Context, s storage.AdminStorage, spec spec, failFn func(string, ...interface{})) *trillian.Tree {
   753  	tree, err := makeTree(ctx, s, spec)
   754  	if err != nil {
   755  		failFn("makeTree() returned err = %v", err)
   756  		return nil
   757  	}
   758  	return tree
   759  }
   760  
   761  // makeTree creates a tree and updates it to Frozen and/or Deleted, according to "spec".
   762  func makeTree(ctx context.Context, s storage.AdminStorage, spec spec) (*trillian.Tree, error) {
   763  	tree := proto.Clone(spec.Tree).(*trillian.Tree)
   764  
   765  	var err error
   766  	tree, err = storage.CreateTree(ctx, s, tree)
   767  	if err != nil {
   768  		return nil, err
   769  	}
   770  
   771  	if spec.Frozen {
   772  		tree, err = storage.UpdateTree(ctx, s, tree.TreeId, func(t *trillian.Tree) {
   773  			t.TreeState = trillian.TreeState_FROZEN
   774  		})
   775  		if err != nil {
   776  			return nil, err
   777  		}
   778  	}
   779  
   780  	if spec.Deleted {
   781  		tree, err = storage.SoftDeleteTree(ctx, s, tree.TreeId)
   782  		if err != nil {
   783  			return nil, err
   784  		}
   785  	}
   786  
   787  	// Sanity checks
   788  	if spec.Frozen && tree.TreeState != trillian.TreeState_FROZEN {
   789  		return nil, fmt.Errorf("makeTree(): TreeState = %s, want = %s", tree.TreeState, trillian.TreeState_FROZEN)
   790  	}
   791  	if tree.Deleted != spec.Deleted {
   792  		return nil, fmt.Errorf("makeTree(): Deleted = %v, want = %v", tree.Deleted, spec.Deleted)
   793  	}
   794  
   795  	return tree, nil
   796  }