github.com/bartle-stripe/trillian@v1.2.1/storage/testonly/admin_storage_tester.go (about)

     1  // Copyright 2017 Google Inc. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package testonly
    16  
    17  import (
    18  	"context"
    19  	"crypto"
    20  	"fmt"
    21  	"reflect"
    22  	"sort"
    23  	"testing"
    24  	"time"
    25  
    26  	"github.com/golang/protobuf/proto"
    27  	"github.com/golang/protobuf/ptypes"
    28  	"github.com/golang/protobuf/ptypes/any"
    29  	"github.com/golang/protobuf/ptypes/empty"
    30  	"github.com/google/trillian"
    31  	"github.com/google/trillian/crypto/keys"
    32  	"github.com/google/trillian/crypto/keys/pem"
    33  	"github.com/google/trillian/crypto/keyspb"
    34  	"github.com/google/trillian/crypto/sigpb"
    35  	"github.com/google/trillian/merkle/maphasher"
    36  	"github.com/google/trillian/storage"
    37  	"github.com/google/trillian/testonly"
    38  	"github.com/kylelemons/godebug/pretty"
    39  	"google.golang.org/grpc/codes"
    40  	"google.golang.org/grpc/status"
    41  
    42  	ktestonly "github.com/google/trillian/crypto/keys/testonly"
    43  	spb "github.com/google/trillian/crypto/sigpb"
    44  
    45  	_ "github.com/google/trillian/crypto/keys/der/proto" // PrivateKey proto handler
    46  	_ "github.com/google/trillian/crypto/keys/pem/proto" // PEMKeyFile proto handler
    47  	_ "github.com/google/trillian/merkle/maphasher"      // TEST_MAP_HASHER
    48  )
    49  
const (
	// privateKeyPass is the passphrase that decrypts privateKeyPEM.
	privateKeyPass = "towel"

	// privateKeyPEM is a passphrase-protected (DES-CBC, see the DEK-Info
	// header) EC private key. LogTree and PreorderedLogTree use it as their
	// private key and pair it with PublicKeyPEM.
	privateKeyPEM = `
-----BEGIN EC PRIVATE KEY-----
Proc-Type: 4,ENCRYPTED
DEK-Info: DES-CBC,D95ECC664FF4BDEC

Xy3zzHFwlFwjE8L1NCngJAFbu3zFf4IbBOCsz6Fa790utVNdulZncNCl2FMK3U2T
sdoiTW8ymO+qgwcNrqvPVmjFRBtkN0Pn5lgbWhN/aK3TlS9IYJ/EShbMUzjgVzie
S9+/31whWcH/FLeLJx4cBzvhgCtfquwA+s5ojeLYYsk=
-----END EC PRIVATE KEY-----`

	// PublicKeyPEM is the public key for: LogTree and PreorderedLogTree
	PublicKeyPEM = `
-----BEGIN PUBLIC KEY-----
MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEywnWicNEQ8bn3GXcGpA+tiU4VL70
Ws9xezgQPrg96YGsFrF6KYG68iqyHDlQ+4FWuKfGKXHn3ooVtB/pfawb5Q==
-----END PUBLIC KEY-----`
)
    68  
    69  // mustMarshalAny panics if ptypes.MarshalAny fails.
    70  func mustMarshalAny(pb proto.Message) *any.Any {
    71  	value, err := ptypes.MarshalAny(pb)
    72  	if err != nil {
    73  		panic(err)
    74  	}
    75  	return value
    76  }
    77  
// TODO(phad): find a better way than the local hash helper below to break the
// import loop between trees and trees/testonly (caused by trees.Hash).
    80  
    81  // hash returns the crypto.Hash configured by the tree.
    82  func hash(tree *trillian.Tree) (crypto.Hash, error) {
    83  	switch tree.HashAlgorithm {
    84  	case sigpb.DigitallySigned_SHA256:
    85  		return crypto.SHA256, nil
    86  	}
    87  	// There's no nil-like value for crypto.Hash, something has to be returned.
    88  	return crypto.SHA256, fmt.Errorf("unexpected hash algorithm: %s", tree.HashAlgorithm)
    89  }
    90  
    91  var (
    92  	// LogTree is a valid, LOG-type trillian.Tree for tests.
    93  	LogTree = &trillian.Tree{
    94  		TreeState:          trillian.TreeState_ACTIVE,
    95  		TreeType:           trillian.TreeType_LOG,
    96  		HashStrategy:       trillian.HashStrategy_RFC6962_SHA256,
    97  		HashAlgorithm:      spb.DigitallySigned_SHA256,
    98  		SignatureAlgorithm: spb.DigitallySigned_ECDSA,
    99  		DisplayName:        "Llamas Log",
   100  		Description:        "Registry of publicly-owned llamas",
   101  		PrivateKey: mustMarshalAny(&keyspb.PrivateKey{
   102  			Der: ktestonly.MustMarshalPrivatePEMToDER(privateKeyPEM, privateKeyPass),
   103  		}),
   104  		PublicKey: &keyspb.PublicKey{
   105  			Der: ktestonly.MustMarshalPublicPEMToDER(PublicKeyPEM),
   106  		},
   107  		MaxRootDuration: ptypes.DurationProto(0 * time.Millisecond),
   108  	}
   109  	// LogTreeEmptyRootHash is the root hash of LogTree when empty.
   110  	LogTreeEmptyRootHash = func() []byte {
   111  		hasher, err := hash(LogTree)
   112  		if err != nil {
   113  			panic(err)
   114  		}
   115  		return hasher.New().Sum(nil)
   116  	}()
   117  
   118  	// PreorderedLogTree is a valid, PREORDERED_LOG-type trillian.Tree for tests.
   119  	PreorderedLogTree = &trillian.Tree{
   120  		TreeState:          trillian.TreeState_ACTIVE,
   121  		TreeType:           trillian.TreeType_PREORDERED_LOG,
   122  		HashStrategy:       trillian.HashStrategy_RFC6962_SHA256,
   123  		HashAlgorithm:      spb.DigitallySigned_SHA256,
   124  		SignatureAlgorithm: spb.DigitallySigned_ECDSA,
   125  		DisplayName:        "Pre-ordered Log",
   126  		Description:        "Mirror registry of publicly-owned llamas",
   127  		PrivateKey: mustMarshalAny(&keyspb.PrivateKey{
   128  			Der: ktestonly.MustMarshalPrivatePEMToDER(privateKeyPEM, privateKeyPass),
   129  		}),
   130  		PublicKey: &keyspb.PublicKey{
   131  			Der: ktestonly.MustMarshalPublicPEMToDER(PublicKeyPEM),
   132  		},
   133  		MaxRootDuration: ptypes.DurationProto(0 * time.Millisecond),
   134  	}
   135  
   136  	// MapTree is a valid, MAP-type trillian.Tree for tests.
   137  	MapTree = &trillian.Tree{
   138  		TreeState:          trillian.TreeState_ACTIVE,
   139  		TreeType:           trillian.TreeType_MAP,
   140  		HashStrategy:       trillian.HashStrategy_TEST_MAP_HASHER,
   141  		HashAlgorithm:      spb.DigitallySigned_SHA256,
   142  		SignatureAlgorithm: spb.DigitallySigned_ECDSA,
   143  		DisplayName:        "Llamas Map",
   144  		Description:        "Key Transparency map for all your digital llama needs.",
   145  		PrivateKey: mustMarshalAny(&keyspb.PrivateKey{
   146  			Der: ktestonly.MustMarshalPrivatePEMToDER(testonly.DemoPrivateKey, testonly.DemoPrivateKeyPass),
   147  		}),
   148  		PublicKey: &keyspb.PublicKey{
   149  			Der: ktestonly.MustMarshalPublicPEMToDER(testonly.DemoPublicKey),
   150  		},
   151  		MaxRootDuration: ptypes.DurationProto(0 * time.Millisecond),
   152  	}
   153  
   154  	// MapTreeEmptyRootHash is the root hash of MapTree when 'empty' (i.e. no leaves are set).
   155  	MapTreeEmptyRootHash = func() []byte {
   156  		hasher, err := hash(MapTree)
   157  		if err != nil {
   158  			panic(err)
   159  		}
   160  		mh := maphasher.New(hasher)
   161  		return mh.HashEmpty(0 /*treeID - unused*/, nil /*index - unused*/, mh.BitLen())
   162  	}()
   163  )
   164  
   165  // AdminStorageTester runs a suite of tests against AdminStorage implementations.
   166  type AdminStorageTester struct {
   167  	// NewAdminStorage returns an AdminStorage instance pointing to a clean
   168  	// test database.
   169  	NewAdminStorage func() storage.AdminStorage
   170  }
   171  
   172  // RunAllTests runs all AdminStorage tests.
   173  func (tester *AdminStorageTester) RunAllTests(t *testing.T) {
   174  	t.Run("TestCreateTree", tester.TestCreateTree)
   175  	t.Run("TestUpdateTree", tester.TestUpdateTree)
   176  	t.Run("TestListTrees", tester.TestListTrees)
   177  	t.Run("TestSoftDeleteTree", tester.TestSoftDeleteTree)
   178  	t.Run("TestSoftDeleteTreeErrors", tester.TestSoftDeleteTreeErrors)
   179  	t.Run("TestHardDeleteTree", tester.TestHardDeleteTree)
   180  	t.Run("TestHardDeleteTreeErrors", tester.TestHardDeleteTreeErrors)
   181  	t.Run("TestUndeleteTree", tester.TestUndeleteTree)
   182  	t.Run("TestUndeleteTreeErrors", tester.TestUndeleteTreeErrors)
   183  	t.Run("TestAdminTXReadWriteTransaction", tester.TestAdminTXReadWriteTransaction)
   184  }
   185  
   186  // TestCreateTree tests AdminStorage Tree creation.
   187  func (tester *AdminStorageTester) TestCreateTree(t *testing.T) {
   188  	// Check that validation runs, but leave details to the validation
   189  	// tests.
   190  	invalidTree := *LogTree
   191  	invalidTree.TreeType = trillian.TreeType_UNKNOWN_TREE_TYPE
   192  
   193  	validTree1 := *LogTree
   194  	validTree2 := *MapTree
   195  	validTree3 := *PreorderedLogTree
   196  
   197  	validTreeWithoutOptionals := *LogTree
   198  	validTreeWithoutOptionals.DisplayName = ""
   199  	validTreeWithoutOptionals.Description = ""
   200  
   201  	tests := []struct {
   202  		desc    string
   203  		tree    *trillian.Tree
   204  		wantErr bool
   205  	}{
   206  		{
   207  			desc:    "invalidTree",
   208  			tree:    &invalidTree,
   209  			wantErr: true,
   210  		},
   211  		{
   212  			desc: "validTree1",
   213  			tree: &validTree1,
   214  		},
   215  		{
   216  			desc: "validTree2",
   217  			tree: &validTree2,
   218  		},
   219  		{
   220  			desc: "validTree3",
   221  			tree: &validTree3,
   222  		},
   223  		{
   224  			desc: "validTreeWithoutOptionals",
   225  			tree: &validTreeWithoutOptionals,
   226  		},
   227  	}
   228  
   229  	ctx := context.Background()
   230  	s := tester.NewAdminStorage()
   231  	for _, test := range tests {
   232  		func() {
   233  			// Test CreateTree up to the tx commit
   234  			newTree, err := storage.CreateTree(ctx, s, test.tree)
   235  			if hasErr := err != nil; hasErr != test.wantErr {
   236  				t.Errorf("%v: CreateTree() = (_, %v), wantErr = %v", test.desc, err, test.wantErr)
   237  				return
   238  			} else if hasErr {
   239  				// Tested above
   240  				return
   241  			}
   242  
   243  			createTime := newTree.CreateTime
   244  			updateTime := newTree.UpdateTime
   245  			if _, err := ptypes.Timestamp(createTime); err != nil {
   246  				t.Errorf("%v: CreateTime malformed after creation: %v", test.desc, newTree)
   247  				return
   248  			}
   249  
   250  			switch {
   251  			case newTree.TreeId == 0:
   252  				t.Errorf("%v: TreeID not returned from creation: %v", test.desc, newTree)
   253  				return
   254  			case !reflect.DeepEqual(createTime, updateTime):
   255  				t.Errorf("%v: CreateTime != UpdateTime: %v", test.desc, newTree)
   256  				return
   257  			}
   258  
   259  			wantTree := *test.tree
   260  			wantTree.TreeId = newTree.TreeId
   261  			wantTree.CreateTime = createTime
   262  			wantTree.UpdateTime = updateTime
   263  			// Ignore storage_settings changes (OK to vary between implementations)
   264  			wantTree.StorageSettings = newTree.StorageSettings
   265  			if !proto.Equal(newTree, &wantTree) {
   266  				diff := pretty.Compare(newTree, &wantTree)
   267  				t.Errorf("%v: post-CreateTree diff:\n%v", test.desc, diff)
   268  				return
   269  			}
   270  
   271  			if err := assertStoredTree(ctx, s, newTree); err != nil {
   272  				t.Errorf("%v: %v", test.desc, err)
   273  			}
   274  		}()
   275  	}
   276  }
   277  
   278  // TestUpdateTree tests AdminStorage Tree updates.
   279  func (tester *AdminStorageTester) TestUpdateTree(t *testing.T) {
   280  	ctx := context.Background()
   281  	s := tester.NewAdminStorage()
   282  
   283  	unrelatedTree := makeTreeOrFail(ctx, s, spec{Tree: MapTree}, t.Fatalf)
   284  
   285  	referenceLog := *LogTree
   286  	validLog := referenceLog
   287  	validLog.TreeState = trillian.TreeState_FROZEN
   288  	validLog.DisplayName = "Frozen Tree"
   289  	validLog.Description = "A Frozen Tree"
   290  	validLogFunc := func(tree *trillian.Tree) {
   291  		tree.TreeState = validLog.TreeState
   292  		tree.DisplayName = validLog.DisplayName
   293  		tree.Description = validLog.Description
   294  	}
   295  
   296  	validLogWithoutOptionalsFunc := func(tree *trillian.Tree) {
   297  		tree.DisplayName = ""
   298  		tree.Description = ""
   299  	}
   300  	validLogWithoutOptionals := referenceLog
   301  	validLogWithoutOptionalsFunc(&validLogWithoutOptionals)
   302  
   303  	invalidLogFunc := func(tree *trillian.Tree) {
   304  		tree.TreeState = trillian.TreeState_UNKNOWN_TREE_STATE
   305  	}
   306  
   307  	readonlyChangedFunc := func(tree *trillian.Tree) {
   308  		tree.TreeType = trillian.TreeType_MAP
   309  	}
   310  
   311  	referenceMap := *MapTree
   312  	validMap := referenceMap
   313  	validMap.DisplayName = "Updated Map"
   314  	validMapFunc := func(tree *trillian.Tree) {
   315  		tree.DisplayName = validMap.DisplayName
   316  	}
   317  
   318  	newPrivateKey := &empty.Empty{}
   319  	privateKeyChangedButKeyMaterialSameTree := *LogTree
   320  	privateKeyChangedButKeyMaterialSameTree.PrivateKey = testonly.MustMarshalAny(t, newPrivateKey)
   321  	keys.RegisterHandler(newPrivateKey, func(ctx context.Context, pb proto.Message) (crypto.Signer, error) {
   322  		return pem.UnmarshalPrivateKey(privateKeyPEM, privateKeyPass)
   323  	})
   324  	defer keys.UnregisterHandler(newPrivateKey)
   325  
   326  	privateKeyChangedButKeyMaterialSameFunc := func(tree *trillian.Tree) {
   327  		tree.PrivateKey = privateKeyChangedButKeyMaterialSameTree.PrivateKey
   328  	}
   329  
   330  	privateKeyChangedAndKeyMaterialDifferentFunc := func(tree *trillian.Tree) {
   331  		tree.PrivateKey = testonly.MustMarshalAny(t, &keyspb.PrivateKey{
   332  			Der: ktestonly.MustMarshalPrivatePEMToDER(testonly.DemoPrivateKey, testonly.DemoPrivateKeyPass),
   333  		})
   334  	}
   335  
   336  	// Test for an unknown tree outside the loop: it makes the test logic simpler
   337  	if _, err := storage.UpdateTree(ctx, s, -1, func(tree *trillian.Tree) {}); err == nil {
   338  		t.Error("UpdateTree() for treeID -1 returned nil err")
   339  	}
   340  
   341  	tests := []struct {
   342  		desc         string
   343  		create, want *trillian.Tree
   344  		updateFunc   func(*trillian.Tree)
   345  		wantErr      bool
   346  	}{
   347  		{
   348  			desc:       "validLog",
   349  			create:     &referenceLog,
   350  			updateFunc: validLogFunc,
   351  			want:       &validLog,
   352  		},
   353  		{
   354  			desc:       "validLogWithoutOptionals",
   355  			create:     &referenceLog,
   356  			updateFunc: validLogWithoutOptionalsFunc,
   357  			want:       &validLogWithoutOptionals,
   358  		},
   359  		{
   360  			desc:       "invalidLog",
   361  			create:     &referenceLog,
   362  			updateFunc: invalidLogFunc,
   363  			wantErr:    true,
   364  		},
   365  		{
   366  			desc:       "readonlyChanged",
   367  			create:     &referenceLog,
   368  			updateFunc: readonlyChangedFunc,
   369  			wantErr:    true,
   370  		},
   371  		{
   372  			desc:       "validMap",
   373  			create:     &referenceMap,
   374  			updateFunc: validMapFunc,
   375  			want:       &validMap,
   376  		},
   377  		{
   378  			desc:       "privateKeyChangedButKeyMaterialSame",
   379  			create:     &referenceLog,
   380  			updateFunc: privateKeyChangedButKeyMaterialSameFunc,
   381  			want:       &privateKeyChangedButKeyMaterialSameTree,
   382  		},
   383  		{
   384  			desc:       "privateKeyChangedAndKeyMaterialDifferent",
   385  			create:     &referenceLog,
   386  			updateFunc: privateKeyChangedAndKeyMaterialDifferentFunc,
   387  			wantErr:    true,
   388  		},
   389  	}
   390  	for _, test := range tests {
   391  		createdTree, err := storage.CreateTree(ctx, s, test.create)
   392  		if err != nil {
   393  			t.Errorf("CreateTree() = (_, %v), want = (_, nil)", err)
   394  			continue
   395  		}
   396  
   397  		updatedTree, err := storage.UpdateTree(ctx, s, createdTree.TreeId, test.updateFunc)
   398  		if hasErr := err != nil; hasErr != test.wantErr {
   399  			t.Errorf("%v: UpdateTree() = (_, %v), wantErr = %v", test.desc, err, test.wantErr)
   400  			continue
   401  		} else if hasErr {
   402  			continue
   403  		}
   404  
   405  		if createdTree.TreeId != updatedTree.TreeId {
   406  			t.Errorf("%v: TreeId = %v, want = %v", test.desc, updatedTree.TreeId, createdTree.TreeId)
   407  		}
   408  		if !reflect.DeepEqual(createdTree.CreateTime, updatedTree.CreateTime) {
   409  			t.Errorf("%v: CreateTime = %v, want = %v", test.desc, updatedTree.CreateTime, createdTree.CreateTime)
   410  		}
   411  		createUpdateTime, err := ptypes.Timestamp(createdTree.UpdateTime)
   412  		if err != nil {
   413  			t.Errorf("%v: createdTree.UpdateTime malformed: %v", test.desc, err)
   414  		}
   415  		updatedUpdateTime, err := ptypes.Timestamp(updatedTree.UpdateTime)
   416  		if err != nil {
   417  			t.Errorf("%v: updatedTree.UpdateTime malformed: %v", test.desc, err)
   418  		}
   419  		if createUpdateTime.After(updatedUpdateTime) {
   420  			t.Errorf("%v: UpdateTime = %v, want >= %v", test.desc, updatedTree.UpdateTime, createdTree.UpdateTime)
   421  		}
   422  		// Copy storage-generated values to want before comparing
   423  		wantTree := *test.want
   424  		wantTree.TreeId = updatedTree.TreeId
   425  		wantTree.CreateTime = updatedTree.CreateTime
   426  		wantTree.UpdateTime = updatedTree.UpdateTime
   427  		// Ignore storage_settings changes (OK to vary between implementations)
   428  		wantTree.StorageSettings = updatedTree.StorageSettings
   429  		if !proto.Equal(updatedTree, &wantTree) {
   430  			diff := pretty.Compare(updatedTree, &wantTree)
   431  			t.Errorf("%v: updatedTree doesn't match wantTree:\n%s", test.desc, diff)
   432  		}
   433  
   434  		if err := assertStoredTree(ctx, s, updatedTree); err != nil {
   435  			t.Errorf("%v: %v", test.desc, err)
   436  		}
   437  
   438  		if err := assertStoredTree(ctx, s, unrelatedTree); err != nil {
   439  			t.Errorf("%v: %v", test.desc, err)
   440  		}
   441  	}
   442  }
   443  
   444  // TestListTrees tests both ListTreeIDs and ListTrees.
   445  func (tester *AdminStorageTester) TestListTrees(t *testing.T) {
   446  	ctx := context.Background()
   447  	s := tester.NewAdminStorage()
   448  
   449  	run := func(desc string, includeDeleted bool, wantTrees []*trillian.Tree) {
   450  		if err := storage.RunInAdminSnapshot(ctx, s, func(tx storage.ReadOnlyAdminTX) error {
   451  			if err := runListTreeIDsTest(ctx, tx, includeDeleted, wantTrees); err != nil {
   452  				t.Errorf("%v: %v", desc, err)
   453  			}
   454  			if err := runListTreesTest(ctx, tx, includeDeleted, wantTrees); err != nil {
   455  				t.Errorf("%v: %v", desc, err)
   456  			}
   457  			// Always return nil, as we're reporting errors independently above.
   458  			return nil
   459  		}); err != nil {
   460  			// Capture Begin() / Commit() errors
   461  			t.Errorf("%v: RunInAdminSnapshot() returned err = %v", desc, err)
   462  		}
   463  	}
   464  
   465  	// Do a first pass with an empty DB
   466  	run("empty", false /* includeDeleted */, nil /* wantTrees */)
   467  	run("emptyDeleted", true /* includeDeleted */, nil /* wantTrees */)
   468  
   469  	// Add some trees and do another pass
   470  	activeLog := makeTreeOrFail(ctx, s, spec{Tree: LogTree}, t.Fatalf)
   471  	frozenLog := makeTreeOrFail(ctx, s, spec{Tree: LogTree, Frozen: true}, t.Fatalf)
   472  	deletedLog := makeTreeOrFail(ctx, s, spec{Tree: LogTree, Deleted: true}, t.Fatalf)
   473  	activeMap := makeTreeOrFail(ctx, s, spec{Tree: MapTree}, t.Fatalf)
   474  	run("multipleTrees", false /* includeDeleted */, []*trillian.Tree{activeLog, frozenLog, activeMap})
   475  	run("multipleTreesDeleted", true /* includeDeleted */, []*trillian.Tree{activeLog, frozenLog, deletedLog, activeMap})
   476  }
   477  
   478  func runListTreeIDsTest(ctx context.Context, tx storage.ReadOnlyAdminTX, includeDeleted bool, wantTrees []*trillian.Tree) error {
   479  	got, err := tx.ListTreeIDs(ctx, includeDeleted)
   480  	if err != nil {
   481  		return fmt.Errorf("ListTreeIDs() returned err = %v", err)
   482  	}
   483  
   484  	want := make([]int64, 0, len(wantTrees))
   485  	for _, tree := range wantTrees {
   486  		want = append(want, tree.TreeId)
   487  	}
   488  
   489  	sort.Slice(got, func(i, j int) bool { return got[i] < got[j] })
   490  	sort.Slice(want, func(i, j int) bool { return want[i] < want[j] })
   491  	if diff := pretty.Compare(got, want); diff != "" {
   492  		return fmt.Errorf("post-ListTreeIDs() diff (-got +want):\n%v", diff)
   493  	}
   494  	return nil
   495  }
   496  
   497  func runListTreesTest(ctx context.Context, tx storage.ReadOnlyAdminTX, includeDeleted bool, wantTrees []*trillian.Tree) error {
   498  	got, err := tx.ListTrees(ctx, includeDeleted)
   499  	if err != nil {
   500  		return fmt.Errorf("ListTrees() returned err = %v", err)
   501  	}
   502  
   503  	if len(got) != len(wantTrees) {
   504  		return fmt.Errorf("ListTrees() returned %v trees, want = %v", len(got), len(wantTrees))
   505  	}
   506  
   507  	want := wantTrees
   508  	sort.Slice(got, func(i, j int) bool { return got[i].TreeId < got[j].TreeId })
   509  	sort.Slice(want, func(i, j int) bool { return want[i].TreeId < want[j].TreeId })
   510  
   511  	for i, wantTree := range want {
   512  		if !proto.Equal(got[i], wantTree) {
   513  			return fmt.Errorf("post-ListTrees() diff (-got +want):\n%v", pretty.Compare(got, want))
   514  		}
   515  	}
   516  	return nil
   517  }
   518  
   519  // TestSoftDeleteTree tests success scenarios of SoftDeleteTree.
   520  func (tester *AdminStorageTester) TestSoftDeleteTree(t *testing.T) {
   521  	ctx := context.Background()
   522  	s := tester.NewAdminStorage()
   523  
   524  	logTree := makeTreeOrFail(ctx, s, spec{Tree: LogTree}, t.Fatalf)
   525  	mapTree := makeTreeOrFail(ctx, s, spec{Tree: MapTree}, t.Fatalf)
   526  
   527  	tests := []struct {
   528  		desc string
   529  		tree *trillian.Tree
   530  	}{
   531  		{desc: "logTree", tree: logTree},
   532  		{desc: "mapTree", tree: mapTree},
   533  	}
   534  	for _, test := range tests {
   535  		deletedTree, err := storage.SoftDeleteTree(ctx, s, test.tree.TreeId)
   536  		if err != nil {
   537  			t.Errorf("%v: SoftDeleteTree() returned err = %v", test.desc, err)
   538  			continue
   539  		}
   540  
   541  		if deletedTree.GetDeleteTime().GetSeconds() == 0 {
   542  			t.Errorf("%v: tree.DeleteTime = %v, want > 0", test.desc, deletedTree.DeleteTime)
   543  		}
   544  
   545  		wantTree := proto.Clone(test.tree).(*trillian.Tree)
   546  		wantTree.Deleted = true
   547  		wantTree.DeleteTime = deletedTree.DeleteTime
   548  		if got, want := deletedTree, wantTree; !proto.Equal(got, want) {
   549  			t.Errorf("%v: post-SoftDeleteTree diff (-got +want):\n%v", test.desc, pretty.Compare(got, want))
   550  		}
   551  
   552  		if err := assertStoredTree(ctx, s, deletedTree); err != nil {
   553  			t.Errorf("%v: %v", test.desc, err)
   554  		}
   555  	}
   556  }
   557  
   558  // TestSoftDeleteTreeErrors tests error scenarios of SoftDeleteTree.
   559  func (tester *AdminStorageTester) TestSoftDeleteTreeErrors(t *testing.T) {
   560  	ctx := context.Background()
   561  	s := tester.NewAdminStorage()
   562  
   563  	softDeleted := makeTreeOrFail(ctx, s, spec{Tree: LogTree, Deleted: true}, t.Fatalf)
   564  
   565  	tests := []struct {
   566  		desc     string
   567  		treeID   int64
   568  		wantCode codes.Code
   569  	}{
   570  		{desc: "unknownTree", treeID: 12345, wantCode: codes.NotFound},
   571  		{desc: "alreadyDeleted", treeID: softDeleted.TreeId, wantCode: codes.FailedPrecondition},
   572  	}
   573  	for _, test := range tests {
   574  		if _, err := storage.SoftDeleteTree(ctx, s, test.treeID); status.Code(err) != test.wantCode {
   575  			t.Errorf("%v: SoftDeleteTree() returned err = %v, wantCode = %s", test.desc, err, test.wantCode)
   576  		}
   577  	}
   578  }
   579  
   580  // TestHardDeleteTree tests success scenarios of HardDeleteTree.
   581  func (tester *AdminStorageTester) TestHardDeleteTree(t *testing.T) {
   582  	ctx := context.Background()
   583  	s := tester.NewAdminStorage()
   584  
   585  	logTree := makeTreeOrFail(ctx, s, spec{Tree: LogTree, Deleted: true}, t.Fatalf)
   586  	frozenTree := makeTreeOrFail(ctx, s, spec{Tree: LogTree, Deleted: true, Frozen: true}, t.Fatalf)
   587  	mapTree := makeTreeOrFail(ctx, s, spec{Tree: MapTree, Deleted: true}, t.Fatalf)
   588  
   589  	tests := []struct {
   590  		desc   string
   591  		treeID int64
   592  	}{
   593  		{desc: "logTree", treeID: logTree.TreeId},
   594  		{desc: "frozenTree", treeID: frozenTree.TreeId},
   595  		{desc: "mapTree", treeID: mapTree.TreeId},
   596  	}
   597  	for _, test := range tests {
   598  		if err := storage.HardDeleteTree(ctx, s, test.treeID); err != nil {
   599  			t.Errorf("%v: HardDeleteTree() returned err = %v", test.desc, err)
   600  			continue
   601  		}
   602  	}
   603  }
   604  
   605  // TestHardDeleteTreeErrors tests error scenarios of HardDeleteTree.
   606  func (tester *AdminStorageTester) TestHardDeleteTreeErrors(t *testing.T) {
   607  	ctx := context.Background()
   608  	s := tester.NewAdminStorage()
   609  
   610  	activeTree := makeTreeOrFail(ctx, s, spec{Tree: LogTree}, t.Fatalf)
   611  
   612  	tests := []struct {
   613  		desc     string
   614  		treeID   int64
   615  		wantCode codes.Code
   616  	}{
   617  		{desc: "unknownTree", treeID: 12345, wantCode: codes.NotFound},
   618  		{desc: "activeTree", treeID: activeTree.TreeId, wantCode: codes.FailedPrecondition},
   619  	}
   620  	for _, test := range tests {
   621  		if err := storage.HardDeleteTree(ctx, s, test.treeID); status.Code(err) != test.wantCode {
   622  			t.Errorf("%v: HardDeleteTree() returned err = %v, wantCode = %s", test.desc, err, test.wantCode)
   623  		}
   624  	}
   625  }
   626  
   627  // TestUndeleteTree tests success scenarios of UndeleteTree.
   628  func (tester *AdminStorageTester) TestUndeleteTree(t *testing.T) {
   629  	ctx := context.Background()
   630  	s := tester.NewAdminStorage()
   631  
   632  	activeDeleted := makeTreeOrFail(ctx, s, spec{Tree: LogTree, Deleted: true}, t.Fatalf)
   633  	frozenDeleted := makeTreeOrFail(ctx, s, spec{Tree: LogTree, Frozen: true, Deleted: true}, t.Fatalf)
   634  
   635  	tests := []struct {
   636  		desc string
   637  		tree *trillian.Tree
   638  	}{
   639  		{desc: "activeTree", tree: activeDeleted},
   640  		{desc: "frozenTree", tree: frozenDeleted},
   641  	}
   642  	for _, test := range tests {
   643  		tree, err := storage.UndeleteTree(ctx, s, test.tree.TreeId)
   644  		if err != nil {
   645  			t.Errorf("%v: UndeleteTree() returned err = %v", test.desc, err)
   646  			continue
   647  		}
   648  
   649  		want := proto.Clone(test.tree).(*trillian.Tree)
   650  		want.Deleted = false
   651  		want.DeleteTime = nil
   652  		if got := tree; !proto.Equal(got, want) {
   653  			t.Errorf("%v: post-UndeleteTree diff (-got +want):\n%v", test.desc, pretty.Compare(got, want))
   654  		}
   655  
   656  		if err := assertStoredTree(ctx, s, tree); err != nil {
   657  			t.Errorf("%v: %v", test.desc, err)
   658  		}
   659  	}
   660  }
   661  
   662  // TestUndeleteTreeErrors tests error scenarios of UndeleteTree.
   663  func (tester *AdminStorageTester) TestUndeleteTreeErrors(t *testing.T) {
   664  	ctx := context.Background()
   665  	s := tester.NewAdminStorage()
   666  
   667  	activeTree := makeTreeOrFail(ctx, s, spec{Tree: LogTree}, t.Fatalf)
   668  
   669  	tests := []struct {
   670  		desc     string
   671  		treeID   int64
   672  		wantCode codes.Code
   673  	}{
   674  		{desc: "unknownTree", treeID: 12345, wantCode: codes.NotFound},
   675  		{desc: "activeTree", treeID: activeTree.TreeId, wantCode: codes.FailedPrecondition},
   676  	}
   677  	for _, test := range tests {
   678  		if _, err := storage.UndeleteTree(ctx, s, test.treeID); status.Code(err) != test.wantCode {
   679  			t.Errorf("%v: UndeleteTree() returned err = %v, wantCode = %s", test.desc, err, test.wantCode)
   680  		}
   681  	}
   682  }
   683  
   684  // TestAdminTXReadWriteTransaction tests the ReadWriteTransaction method on AdminStorage.
   685  func (tester *AdminStorageTester) TestAdminTXReadWriteTransaction(t *testing.T) {
   686  	tests := []struct {
   687  		wantCommit bool
   688  	}{
   689  		{wantCommit: true},
   690  		{wantCommit: false},
   691  	}
   692  
   693  	ctx := context.Background()
   694  	s := tester.NewAdminStorage()
   695  
   696  	var tree *trillian.Tree
   697  
   698  	for i, test := range tests {
   699  		t.Run(fmt.Sprintf("%+v", test), func(t *testing.T) {
   700  			err := s.ReadWriteTransaction(ctx, func(ctx context.Context, tx storage.AdminTX) error {
   701  				var err error
   702  				tree, err = tx.CreateTree(ctx, LogTree)
   703  				if err != nil {
   704  					t.Fatalf("%v: CreateTree() = (_, %v), want = (_, nil)", i, err)
   705  				}
   706  				if !test.wantCommit {
   707  					return fmt.Errorf("No commit %d", i)
   708  				}
   709  				return nil
   710  			})
   711  			if (err != nil && test.wantCommit) ||
   712  				(err == nil && !test.wantCommit) {
   713  				t.Fatalf("%v: ReadWriteTransaction() = (_, %v), want = (_, nil)", i, err)
   714  			}
   715  
   716  			tx2, err := s.Snapshot(ctx)
   717  			if err != nil {
   718  				t.Fatalf("%v: Snapshot() = (_, %v), want = (_, nil)", i, err)
   719  			}
   720  			defer tx2.Close()
   721  			_, err = tx2.GetTree(ctx, tree.TreeId)
   722  			if hasErr := err != nil; !test.wantCommit != hasErr {
   723  				t.Errorf("%v: GetTree() = (_, %v), but wantCommit = %v", i, err, test.wantCommit)
   724  			}
   725  
   726  			// Multiple Close() calls are fine too
   727  			if err := tx2.Close(); err != nil {
   728  				t.Errorf("%v: Close() = %v, want = nil", i, err)
   729  				return
   730  			}
   731  		})
   732  	}
   733  }
   734  
   735  // assertStoredTree verifies that "want" is equal to the tree stored under its ID.
   736  func assertStoredTree(ctx context.Context, s storage.AdminStorage, want *trillian.Tree) error {
   737  	got, err := storage.GetTree(ctx, s, want.TreeId)
   738  	if err != nil {
   739  		return fmt.Errorf("GetTree() returned err = %v", err)
   740  	}
   741  	if !proto.Equal(got, want) {
   742  		return fmt.Errorf("post-GetTree() diff (-got +want):\n%v", pretty.Compare(got, want))
   743  	}
   744  	return nil
   745  }
   746  
   747  type spec struct {
   748  	Tree            *trillian.Tree
   749  	Frozen, Deleted bool
   750  }
   751  
   752  // makeTreeOrFail delegates to makeTree. If makeTree returns a non-nil error, failFn is called.
   753  func makeTreeOrFail(ctx context.Context, s storage.AdminStorage, spec spec, failFn func(string, ...interface{})) *trillian.Tree {
   754  	tree, err := makeTree(ctx, s, spec)
   755  	if err != nil {
   756  		failFn("makeTree() returned err = %v", err)
   757  		return nil
   758  	}
   759  	return tree
   760  }
   761  
   762  // makeTree creates a tree and updates it to Frozen and/or Deleted, according to "spec".
   763  func makeTree(ctx context.Context, s storage.AdminStorage, spec spec) (*trillian.Tree, error) {
   764  	tree := proto.Clone(spec.Tree).(*trillian.Tree)
   765  
   766  	var err error
   767  	tree, err = storage.CreateTree(ctx, s, tree)
   768  	if err != nil {
   769  		return nil, err
   770  	}
   771  
   772  	if spec.Frozen {
   773  		tree, err = storage.UpdateTree(ctx, s, tree.TreeId, func(t *trillian.Tree) {
   774  			t.TreeState = trillian.TreeState_FROZEN
   775  		})
   776  		if err != nil {
   777  			return nil, err
   778  		}
   779  	}
   780  
   781  	if spec.Deleted {
   782  		tree, err = storage.SoftDeleteTree(ctx, s, tree.TreeId)
   783  		if err != nil {
   784  			return nil, err
   785  		}
   786  	}
   787  
   788  	// Sanity checks
   789  	if spec.Frozen && tree.TreeState != trillian.TreeState_FROZEN {
   790  		return nil, fmt.Errorf("makeTree(): TreeState = %s, want = %s", tree.TreeState, trillian.TreeState_FROZEN)
   791  	}
   792  	if tree.Deleted != spec.Deleted {
   793  		return nil, fmt.Errorf("makeTree(): Deleted = %v, want = %v", tree.Deleted, spec.Deleted)
   794  	}
   795  
   796  	return tree, nil
   797  }