github.com/bartle-stripe/trillian@v1.2.1/storage/mysql/log_storage_test.go

     1  // Copyright 2016 Google Inc. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package mysql
    16  
    17  import (
    18  	"bytes"
    19  	"context"
    20  	"crypto"
    21  	"crypto/sha256"
    22  	"database/sql"
    23  	"fmt"
    24  	"reflect"
    25  	"sort"
    26  	"testing"
    27  	"time"
    28  
    29  	"github.com/golang/protobuf/proto"
    30  	"github.com/golang/protobuf/ptypes"
    31  	"github.com/google/trillian"
    32  	"github.com/google/trillian/storage"
    33  	"github.com/google/trillian/storage/testonly"
    34  	"github.com/google/trillian/types"
    35  	"github.com/kylelemons/godebug/pretty"
    36  	"google.golang.org/grpc/codes"
    37  	"google.golang.org/grpc/status"
    38  
    39  	tcrypto "github.com/google/trillian/crypto"
    40  	ttestonly "github.com/google/trillian/testonly"
    41  
    42  	_ "github.com/go-sql-driver/mysql"
    43  )
    44  
    45  var allTables = []string{"Unsequenced", "TreeHead", "SequencedLeafData", "LeafData", "Subtree", "TreeControl", "Trees", "MapLeaf", "MapHead"}
    46  
    47  // These dummy hashes must be 32 bytes long to match the length of a real SHA-256 hash.
    48  var dummyHash = []byte("hashxxxxhashxxxxhashxxxxhashxxxx")
    49  var dummyRawHash = []byte("xxxxhashxxxxhashxxxxhashxxxxhash")
    50  var dummyRawHash2 = []byte("yyyyhashyyyyhashyyyyhashyyyyhash")
    51  var dummyHash2 = []byte("HASHxxxxhashxxxxhashxxxxhashxxxx")
    52  var dummyHash3 = []byte("hashxxxxhashxxxxhashxxxxHASHxxxx")
    53  
    54  // Time we will queue all leaves at
    55  var fakeQueueTime = time.Date(2016, 11, 10, 15, 16, 27, 0, time.UTC)
    56  
    57  // Time we will integrate all leaves at
    58  var fakeIntegrateTime = time.Date(2016, 11, 10, 15, 16, 30, 0, time.UTC)
    59  
    60  // Guard cutoff time used by tests that don't exercise the cutoff themselves (late enough to include all of the times above)
    61  var fakeDequeueCutoffTime = time.Date(2016, 11, 10, 15, 16, 30, 0, time.UTC)
    62  
    63  // Used for tests involving extra data
    64  var someExtraData = []byte("Some extra data")
    65  var someExtraData2 = []byte("Some even more extra data")
    66  
    67  const leavesToInsert = 5
    68  const sequenceNumber int64 = 237
    69  
    70  // Tests that access the db should each use a distinct log ID to prevent lock contention,
    71  // race conditions or other unexpected interactions when run in parallel. Tests that pass
    72  // should hold no locks afterwards.
    73  
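        // createFakeLeaf inserts rows into LeafData and SequencedLeafData for the given log as if
        // the leaf had already been queued and sequenced, and returns the trillian.LogLeaf that the
        // storage layer is expected to hand back for it.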
    74  func createFakeLeaf(ctx context.Context, db *sql.DB, logID int64, rawHash, hash, data, extraData []byte, seq int64, t *testing.T) *trillian.LogLeaf {
    75  	t.Helper()
    76  	queuedAtNanos := fakeQueueTime.UnixNano()
    77  	integratedAtNanos := fakeIntegrateTime.UnixNano()
    78  	_, err := db.ExecContext(ctx, "INSERT INTO LeafData(TreeId, LeafIdentityHash, LeafValue, ExtraData, QueueTimestampNanos) VALUES(?,?,?,?,?)", logID, rawHash, data, extraData, queuedAtNanos)
    79  	_, err2 := db.ExecContext(ctx, "INSERT INTO SequencedLeafData(TreeId, SequenceNumber, LeafIdentityHash, MerkleLeafHash, IntegrateTimestampNanos) VALUES(?,?,?,?,?)", logID, seq, rawHash, hash, integratedAtNanos)
    80  
    81  	if err != nil || err2 != nil {
    82  		t.Fatalf("Failed to create test leaves: %v %v", err, err2)
    83  	}
    84  	queueTimestamp, err := ptypes.TimestampProto(fakeQueueTime)
    85  	if err != nil {
    86  		panic(err)
    87  	}
    88  	integrateTimestamp, err := ptypes.TimestampProto(fakeIntegrateTime)
    89  	if err != nil {
    90  		panic(err)
    91  	}
    92  	return &trillian.LogLeaf{
    93  		MerkleLeafHash:     hash,
    94  		LeafValue:          data,
    95  		ExtraData:          extraData,
    96  		LeafIndex:          seq,
    97  		LeafIdentityHash:   rawHash,
    98  		QueueTimestamp:     queueTimestamp,
    99  		IntegrateTimestamp: integrateTimestamp,
   100  	}
   101  }
   102  
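        // checkLeafContents fails the test if the returned leaf's hashes, sequence number, value,
        // extra data or integrate timestamp differ from the expected values.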
   103  func checkLeafContents(leaf *trillian.LogLeaf, seq int64, rawHash, hash, data, extraData []byte, t *testing.T) {
   104  	t.Helper()
   105  	if got, want := leaf.MerkleLeafHash, hash; !bytes.Equal(got, want) {
   106  		t.Fatalf("Wrong leaf hash in returned leaf got\n%v\nwant:\n%v", got, want)
   107  	}
   108  
   109  	if got, want := leaf.LeafIdentityHash, rawHash; !bytes.Equal(got, want) {
   110  		t.Fatalf("Wrong raw leaf hash in returned leaf got\n%v\nwant:\n%v", got, want)
   111  	}
   112  
   113  	if got, want := leaf.LeafIndex, seq; got != want {
   114  		t.Fatalf("Bad sequence number in returned leaf got: %d, want: %d", got, want)
   115  	}
   116  
   117  	if got, want := leaf.LeafValue, data; !bytes.Equal(got, want) {
   118  		t.Fatalf("Unexpected data in returned leaf. got:\n%v\nwant:\n%v", got, want)
   119  	}
   120  
   121  	if got, want := leaf.ExtraData, extraData; !bytes.Equal(got, want) {
   122  		t.Fatalf("Unexpected extra data in returned leaf. got:\n%v\nwant:\n%v", got, want)
   123  	}
   124  
   125  	iTime, err := ptypes.Timestamp(leaf.IntegrateTimestamp)
   126  	if err != nil {
   127  		t.Fatalf("Got invalid integrate timestamp: %v", err)
   128  	}
   129  	if got, want := iTime.UnixNano(), fakeIntegrateTime.UnixNano(); got != want {
   130  		t.Errorf("Wrong IntegrateTimestamp: got %v, want %v", got, want)
   131  	}
   132  }
   133  
   134  func TestMySQLLogStorage_CheckDatabaseAccessible(t *testing.T) {
   135  	cleanTestDB(DB)
   136  	s := NewLogStorage(DB, nil)
   137  	if err := s.CheckDatabaseAccessible(context.Background()); err != nil {
   138  		t.Errorf("CheckDatabaseAccessible() = %v, want = nil", err)
   139  	}
   140  }
   141  
   142  func TestSnapshot(t *testing.T) {
   143  	cleanTestDB(DB)
   144  
   145  	frozenLog := createTreeOrPanic(DB, testonly.LogTree)
   146  	createFakeSignedLogRoot(DB, frozenLog, 0)
   147  	if _, err := updateTree(DB, frozenLog.TreeId, func(tree *trillian.Tree) {
   148  		tree.TreeState = trillian.TreeState_FROZEN
   149  	}); err != nil {
   150  		t.Fatalf("Error updating frozen tree: %v", err)
   151  	}
   152  
   153  	activeLog := createTreeOrPanic(DB, testonly.LogTree)
   154  	createFakeSignedLogRoot(DB, activeLog, 0)
   155  	mapTreeID := createTreeOrPanic(DB, testonly.MapTree).TreeId
   156  
   157  	tests := []struct {
   158  		desc    string
   159  		tree    *trillian.Tree
   160  		wantErr bool
   161  	}{
   162  		{
   163  			desc:    "unknownSnapshot",
   164  			tree:    logTree(-1),
   165  			wantErr: true,
   166  		},
   167  		{
   168  			desc: "activeLogSnapshot",
   169  			tree: activeLog,
   170  		},
   171  		{
   172  			desc: "frozenSnapshot",
   173  			tree: frozenLog,
   174  		},
   175  		{
   176  			desc:    "mapSnapshot",
   177  			tree:    logTree(mapTreeID),
   178  			wantErr: true,
   179  		},
   180  	}
   181  
   182  	ctx := context.Background()
   183  	s := NewLogStorage(DB, nil)
   184  	for _, test := range tests {
   185  		t.Run(test.desc, func(t *testing.T) {
   186  			tx, err := s.SnapshotForTree(ctx, test.tree)
   187  
   188  			if err == storage.ErrTreeNeedsInit {
   189  				defer tx.Close()
   190  			}
   191  
   192  			if hasErr := err != nil; hasErr != test.wantErr {
   193  				t.Fatalf("err = %q, wantErr = %v", err, test.wantErr)
   194  			} else if hasErr {
   195  				return
   196  			}
   197  			defer tx.Close()
   198  
   199  			_, err = tx.LatestSignedLogRoot(ctx)
   200  			if err != nil {
   201  				t.Errorf("LatestSignedLogRoot() returned err = %v", err)
   202  			}
   203  			if err := tx.Commit(); err != nil {
   204  				t.Errorf("Commit() returned err = %v", err)
   205  			}
   206  		})
   207  	}
   208  }
   209  
   210  func TestReadWriteTransaction(t *testing.T) {
   211  	cleanTestDB(DB)
   212  	activeLog := createTreeOrPanic(DB, testonly.LogTree)
   213  	createFakeSignedLogRoot(DB, activeLog, 0)
   214  
   215  	tests := []struct {
   216  		desc        string
   217  		tree        *trillian.Tree
   218  		wantErr     bool
   219  		wantLogRoot []byte
   220  		wantTXRev   int64
   221  	}{
   222  		{
   223  			// Unknown log IDs are now handled outside storage.
   224  			desc:        "unknownBegin",
   225  			tree:        logTree(-1),
   226  			wantLogRoot: nil,
   227  			wantTXRev:   -1,
   228  		},
   229  		{
   230  			desc: "activeLogBegin",
   231  			tree: activeLog,
   232  			wantLogRoot: func() []byte {
   233  				b, err := (&types.LogRootV1{RootHash: []byte{0}}).MarshalBinary()
   234  				if err != nil {
   235  					panic(err)
   236  				}
   237  				return b
   238  			}(),
   239  			wantTXRev: 1,
   240  		},
   241  	}
   242  
   243  	ctx := context.Background()
   244  	s := NewLogStorage(DB, nil)
   245  	for _, test := range tests {
   246  		t.Run(test.desc, func(t *testing.T) {
   247  			err := s.ReadWriteTransaction(ctx, test.tree, func(ctx context.Context, tx storage.LogTreeTX) error {
   248  				root, err := tx.LatestSignedLogRoot(ctx)
   249  				if err != nil {
   250  					t.Fatalf("%v: LatestSignedLogRoot() returned err = %v", test.desc, err)
   251  				}
   252  				if got, want := tx.WriteRevision(), test.wantTXRev; got != want {
   253  					t.Errorf("%v: WriteRevision() = %v, want = %v", test.desc, got, want)
   254  				}
   255  				if got, want := root.LogRoot, test.wantLogRoot; !bytes.Equal(got, want) {
   256  					t.Errorf("%v: LogRoot: \n%x, want \n%x", test.desc, got, want)
   257  				}
   258  				return nil
   259  			})
   260  			if hasErr := err != nil; hasErr != test.wantErr {
   261  				t.Fatalf("%v: err = %q, wantErr = %v", test.desc, err, test.wantErr)
   262  			} else if hasErr {
   263  				return
   264  			}
   265  		})
   266  	}
   267  }
   268  
   269  func TestQueueDuplicateLeaf(t *testing.T) {
   270  	cleanTestDB(DB)
   271  	tree := createTreeOrPanic(DB, testonly.LogTree)
   272  	s := NewLogStorage(DB, nil)
   273  	count := 15
   274  	leaves := createTestLeaves(int64(count), 10)
   275  	leaves2 := createTestLeaves(int64(count), 12)
   276  	leaves3 := createTestLeaves(3, 100)
   277  
   278  	// Note that tests accumulate queued leaves on top of each other.
   279  	var tests = []struct {
   280  		desc   string
   281  		leaves []*trillian.LogLeaf
   282  		want   []*trillian.LogLeaf
   283  	}{
   284  		{
   285  			desc:   "[10, 11, 12, ...]",
   286  			leaves: leaves,
   287  			want:   make([]*trillian.LogLeaf, count),
   288  		},
   289  		{
   290  			desc:   "[12, 13, 14, ...] so first (count-2) are duplicates",
   291  			leaves: leaves2,
   292  			want:   append(leaves[2:], nil, nil),
   293  		},
   294  		{
   295  			desc:   "[10, 100, 11, 101, 102] so [dup, new, dup, new, dup]",
   296  			leaves: []*trillian.LogLeaf{leaves[0], leaves3[0], leaves[1], leaves3[1], leaves[2]},
   297  			want:   []*trillian.LogLeaf{leaves[0], nil, leaves[1], nil, leaves[2]},
   298  		},
   299  	}
   300  
   301  	for _, test := range tests {
   302  		t.Run(test.desc, func(t *testing.T) {
   303  			runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error {
   304  				existing, err := tx.QueueLeaves(ctx, test.leaves, fakeQueueTime)
   305  				if err != nil {
   306  					t.Errorf("Failed to queue leaves: %v", err)
   307  					return err
   308  				}
   309  
   310  				if len(existing) != len(test.want) {
   311  					t.Fatalf("|QueueLeaves()|=%d; want %d", len(existing), len(test.want))
   312  				}
   313  				for i, want := range test.want {
   314  					got := existing[i]
   315  					if want == nil {
   316  						if got != nil {
   317  							t.Fatalf("QueueLeaves()[%d]=%v; want nil", i, got)
   318  						}
   319  						continue
   320  					}
   321  					if got == nil {
   322  						t.Fatalf("QueueLeaves()[%d]=nil; want non-nil", i)
   323  					} else if !bytes.Equal(got.LeafIdentityHash, want.LeafIdentityHash) {
   324  						t.Fatalf("QueueLeaves()[%d].LeafIdentityHash=%x; want %x", i, got.LeafIdentityHash, want.LeafIdentityHash)
   325  					}
   326  				}
   327  				return nil
   328  			})
   329  		})
   330  	}
   331  }
   332  
   333  func TestQueueLeaves(t *testing.T) {
   334  	ctx := context.Background()
   335  
   336  	cleanTestDB(DB)
   337  	tree := createTreeOrPanic(DB, testonly.LogTree)
   338  	s := NewLogStorage(DB, nil)
   339  
   340  	runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error {
   341  		leaves := createTestLeaves(leavesToInsert, 20)
   342  		if _, err := tx.QueueLeaves(ctx, leaves, fakeQueueTime); err != nil {
   343  			t.Fatalf("Failed to queue leaves: %v", err)
   344  		}
   345  		return nil
   346  	})
   347  
   348  	// Should see the leaves in the database. There is no API to read from the unsequenced data.
   349  	var count int
   350  	if err := DB.QueryRowContext(ctx, "SELECT COUNT(*) FROM Unsequenced WHERE TreeID=?", tree.TreeId).Scan(&count); err != nil {
   351  		t.Fatalf("Could not query row count: %v", err)
   352  	}
   353  	if leavesToInsert != count {
   354  		t.Fatalf("Expected %d unsequenced rows but got: %d", leavesToInsert, count)
   355  	}
   356  
   357  	// Additional check on timestamp being set correctly in the database
   358  	var queueTimestamp int64
   359  	if err := DB.QueryRowContext(ctx, "SELECT DISTINCT QueueTimestampNanos FROM Unsequenced WHERE TreeID=?", tree.TreeId).Scan(&queueTimestamp); err != nil {
   360  		t.Fatalf("Could not query timestamp: %v", err)
   361  	}
   362  	if got, want := queueTimestamp, fakeQueueTime.UnixNano(); got != want {
   363  		t.Fatalf("Incorrect queue timestamp got: %d want: %d", got, want)
   364  	}
   365  }
   366  
   367  // AddSequencedLeaves tests. ---------------------------------------------------
   368  
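        // addSequencedLeavesTest bundles the log storage and the pre-ordered log tree shared by the
        // AddSequencedLeaves tests below.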
   369  type addSequencedLeavesTest struct {
   370  	t    *testing.T
   371  	s    storage.LogStorage
   372  	tree *trillian.Tree
   373  }
   374  
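        // initAddSequencedLeavesTest wipes the test database and creates a fresh pre-ordered log
        // tree for an AddSequencedLeaves test to use.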
   375  func initAddSequencedLeavesTest(t *testing.T) addSequencedLeavesTest {
   376  	cleanTestDB(DB)
   377  	s := NewLogStorage(DB, nil)
   378  	tree := createTreeOrPanic(DB, testonly.PreorderedLogTree)
   379  	return addSequencedLeavesTest{t, s, tree}
   380  }
   381  
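        // addSequencedLeaves writes the given leaves to the tree via AddSequencedLeaves in a single
        // read/write transaction.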
   382  func (t *addSequencedLeavesTest) addSequencedLeaves(leaves []*trillian.LogLeaf) {
   383  	runLogTX(t.s, t.tree, t.t, func(ctx context.Context, tx storage.LogTreeTX) error {
   384  		if _, err := tx.AddSequencedLeaves(ctx, leaves, fakeQueueTime); err != nil {
   385  			t.t.Fatalf("Failed to add sequenced leaves: %v", err)
   386  		}
   387  		// TODO(pavelkalinnikov): Verify returned status for each leaf.
   388  		return nil
   389  	})
   390  }
   391  
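        // verifySequencedLeaves reads back the range [start, start+count) with GetLeavesByRange and
        // checks that the indices and identity hashes match the expected leaves.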
   392  func (t *addSequencedLeavesTest) verifySequencedLeaves(start, count int64, exp []*trillian.LogLeaf) {
   393  	var stored []*trillian.LogLeaf
   394  	runLogTX(t.s, t.tree, t.t, func(ctx context.Context, tx storage.LogTreeTX) error {
   395  		var err error
   396  		stored, err = tx.GetLeavesByRange(ctx, start, count)
   397  		if err != nil {
   398  			t.t.Fatalf("Failed to read sequenced leaves: %v", err)
   399  		}
   400  		return nil
   401  	})
   402  	if got, want := len(stored), len(exp); got != want {
   403  		t.t.Fatalf("Unexpected number of leaves: got %d, want %d", got, want)
   404  	}
   405  
   406  	for i, leaf := range stored {
   407  		if got, want := leaf.LeafIndex, exp[i].LeafIndex; got != want {
   408  			t.t.Fatalf("Leaf #%d: LeafIndex=%v, want %v", i, got, want)
   409  		}
   410  		if got, want := leaf.LeafIdentityHash, exp[i].LeafIdentityHash; !bytes.Equal(got, want) {
   411  			t.t.Fatalf("Leaf #%d: LeafIdentityHash=%v, want %v", i, got, want)
   412  		}
   413  	}
   414  }
   415  
   416  func TestAddSequencedLeavesUnordered(t *testing.T) {
   417  	const chunk = leavesToInsert
   418  	const count = chunk * 5
   419  	const extraCount = 16
   420  	leaves := createTestLeaves(count, 0)
   421  
   422  	aslt := initAddSequencedLeavesTest(t)
   423  	for _, idx := range []int{1, 0, 4, 2} {
   424  		aslt.addSequencedLeaves(leaves[chunk*idx : chunk*(idx+1)])
   425  	}
   426  	aslt.verifySequencedLeaves(0, count+extraCount, leaves[:chunk*3])
   427  	aslt.verifySequencedLeaves(chunk*4, chunk+extraCount, leaves[chunk*4:count])
   428  	aslt.addSequencedLeaves(leaves[chunk*3 : chunk*4])
   429  	aslt.verifySequencedLeaves(0, count+extraCount, leaves)
   430  }
   431  
   432  func TestAddSequencedLeavesWithDuplicates(t *testing.T) {
   433  	leaves := createTestLeaves(6, 0)
   434  
   435  	aslt := initAddSequencedLeavesTest(t)
   436  	aslt.addSequencedLeaves(leaves[:3])
   437  	aslt.verifySequencedLeaves(0, 3, leaves[:3])
   438  	aslt.addSequencedLeaves(leaves[2:]) // Full dup.
   439  	aslt.verifySequencedLeaves(0, 6, leaves)
   440  
   441  	dupLeaves := createTestLeaves(4, 6)
   442  	dupLeaves[0].LeafIdentityHash = leaves[0].LeafIdentityHash // Hash dup.
   443  	dupLeaves[2].LeafIndex = 2                                 // Index dup.
   444  	aslt.addSequencedLeaves(dupLeaves)
   445  	aslt.verifySequencedLeaves(6, 4, nil)
   446  	aslt.verifySequencedLeaves(7, 4, dupLeaves[1:2])
   447  	aslt.verifySequencedLeaves(8, 4, nil)
   448  	aslt.verifySequencedLeaves(9, 4, dupLeaves[3:4])
   449  
   450  	dupLeaves = createTestLeaves(4, 6)
   451  	aslt.addSequencedLeaves(dupLeaves)
   452  	aslt.verifySequencedLeaves(6, 4, dupLeaves)
   453  }
   454  
   455  // -----------------------------------------------------------------------------
   456  
   457  func TestDequeueLeavesNoneQueued(t *testing.T) {
   458  	cleanTestDB(DB)
   459  	tree := createTreeOrPanic(DB, testonly.LogTree)
   460  	s := NewLogStorage(DB, nil)
   461  
   462  	runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error {
   463  		leaves, err := tx.DequeueLeaves(ctx, 999, fakeDequeueCutoffTime)
   464  		if err != nil {
   465  			t.Fatalf("Didn't expect an error on dequeue with no work to be done: %v", err)
   466  		}
   467  		if len(leaves) > 0 {
   468  			t.Fatalf("Expected nothing to be dequeued but we got %d leaves", len(leaves))
   469  		}
   470  		return nil
   471  	})
   472  }
   473  
   474  func TestDequeueLeaves(t *testing.T) {
   475  	cleanTestDB(DB)
   476  	tree := createTreeOrPanic(DB, testonly.LogTree)
   477  	s := NewLogStorage(DB, nil)
   478  
   479  	{
   480  		runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error {
   481  			leaves := createTestLeaves(leavesToInsert, 20)
   482  			if _, err := tx.QueueLeaves(ctx, leaves, fakeDequeueCutoffTime); err != nil {
   483  				t.Fatalf("Failed to queue leaves: %v", err)
   484  			}
   485  			return nil
   486  		})
   487  	}
   488  
   489  	{
   490  		// Now try to dequeue them
   491  		runLogTX(s, tree, t, func(ctx context.Context, tx2 storage.LogTreeTX) error {
   492  			leaves2, err := tx2.DequeueLeaves(ctx, 99, fakeDequeueCutoffTime)
   493  			if err != nil {
   494  				t.Fatalf("Failed to dequeue leaves: %v", err)
   495  			}
   496  			if len(leaves2) != leavesToInsert {
   497  				t.Fatalf("Dequeued %d leaves but expected to get %d", len(leaves2), leavesToInsert)
   498  			}
   499  			ensureAllLeavesDistinct(leaves2, t)
   500  			return nil
   501  		})
   502  	}
   503  
   504  	{
   505  		// If we dequeue again then we should now get nothing
   506  		runLogTX(s, tree, t, func(ctx context.Context, tx3 storage.LogTreeTX) error {
   507  			leaves3, err := tx3.DequeueLeaves(ctx, 99, fakeDequeueCutoffTime)
   508  			if err != nil {
   509  				t.Fatalf("Failed to dequeue leaves (second time): %v", err)
   510  			}
   511  			if len(leaves3) != 0 {
   512  				t.Fatalf("Dequeued %d leaves but expected to get none", len(leaves3))
   513  			}
   514  			return nil
   515  		})
   516  	}
   517  }
   518  
   519  func TestDequeueLeavesHaveQueueTimestamp(t *testing.T) {
   520  	cleanTestDB(DB)
   521  	tree := createTreeOrPanic(DB, testonly.LogTree)
   522  	s := NewLogStorage(DB, nil)
   523  
   524  	{
   525  		runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error {
   526  			leaves := createTestLeaves(leavesToInsert, 20)
   527  			if _, err := tx.QueueLeaves(ctx, leaves, fakeDequeueCutoffTime); err != nil {
   528  				t.Fatalf("Failed to queue leaves: %v", err)
   529  			}
   530  			return nil
   531  		})
   532  	}
   533  
   534  	{
   535  		// Now try to dequeue them
   536  		runLogTX(s, tree, t, func(ctx context.Context, tx2 storage.LogTreeTX) error {
   537  			leaves2, err := tx2.DequeueLeaves(ctx, 99, fakeDequeueCutoffTime)
   538  			if err != nil {
   539  				t.Fatalf("Failed to dequeue leaves: %v", err)
   540  			}
   541  			if len(leaves2) != leavesToInsert {
   542  				t.Fatalf("Dequeued %d leaves but expected to get %d", len(leaves2), leavesToInsert)
   543  			}
   544  			ensureLeavesHaveQueueTimestamp(t, leaves2, fakeDequeueCutoffTime)
   545  			return nil
   546  		})
   547  	}
   548  }
   549  
   550  func TestDequeueLeavesTwoBatches(t *testing.T) {
   551  	cleanTestDB(DB)
   552  	tree := createTreeOrPanic(DB, testonly.LogTree)
   553  	s := NewLogStorage(DB, nil)
   554  
   555  	leavesToDequeue1 := 3
   556  	leavesToDequeue2 := 2
   557  
   558  	{
   559  		runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error {
   560  			leaves := createTestLeaves(leavesToInsert, 20)
   561  			if _, err := tx.QueueLeaves(ctx, leaves, fakeDequeueCutoffTime); err != nil {
   562  				t.Fatalf("Failed to queue leaves: %v", err)
   563  			}
   564  			return nil
   565  		})
   566  	}
   567  
   568  	var err error
   569  	var leaves2, leaves3, leaves4 []*trillian.LogLeaf
   570  	{
   571  		// Now try to dequeue some of them
   572  		runLogTX(s, tree, t, func(ctx context.Context, tx2 storage.LogTreeTX) error {
   573  			leaves2, err = tx2.DequeueLeaves(ctx, leavesToDequeue1, fakeDequeueCutoffTime)
   574  			if err != nil {
   575  				t.Fatalf("Failed to dequeue leaves: %v", err)
   576  			}
   577  			if len(leaves2) != leavesToDequeue1 {
   578  				t.Fatalf("Dequeued %d leaves but expected to get %d", len(leaves2), leavesToDequeue1)
   579  			}
   580  			ensureAllLeavesDistinct(leaves2, t)
   581  			ensureLeavesHaveQueueTimestamp(t, leaves2, fakeDequeueCutoffTime)
   582  			return nil
   583  		})
   584  
   585  		// Now try to dequeue the rest of them
   586  		runLogTX(s, tree, t, func(ctx context.Context, tx3 storage.LogTreeTX) error {
   587  			leaves3, err = tx3.DequeueLeaves(ctx, leavesToDequeue2, fakeDequeueCutoffTime)
   588  			if err != nil {
   589  				t.Fatalf("Failed to dequeue leaves: %v", err)
   590  			}
   591  			if len(leaves3) != leavesToDequeue2 {
   592  				t.Fatalf("Dequeued %d leaves but expected to get %d", len(leaves3), leavesToDequeue2)
   593  			}
   594  			ensureAllLeavesDistinct(leaves3, t)
   595  			ensureLeavesHaveQueueTimestamp(t, leaves3, fakeDequeueCutoffTime)
   596  
   597  			// Plus the union of the leaf batches should all have distinct hashes
   598  			leaves4 = append(leaves2, leaves3...)
   599  			ensureAllLeavesDistinct(leaves4, t)
   600  			return nil
   601  		})
   602  	}
   603  
   604  	{
   605  		// If we dequeue again then we should now get nothing
   606  		runLogTX(s, tree, t, func(ctx context.Context, tx4 storage.LogTreeTX) error {
   607  			leaves5, err := tx4.DequeueLeaves(ctx, 99, fakeDequeueCutoffTime)
   608  			if err != nil {
   609  				t.Fatalf("Failed to dequeue leaves (second time): %v", err)
   610  			}
   611  			if len(leaves5) != 0 {
   612  				t.Fatalf("Dequeued %d leaves but expected to get none", len(leaves5))
   613  			}
   614  			return nil
   615  		})
   616  	}
   617  }
   618  
   619  // Queues leaves and attempts to dequeue before the guard cutoff allows it. This should
   620  // return nothing. Then retry with an inclusive guard cutoff and ensure the leaves
   621  // are returned.
   622  func TestDequeueLeavesGuardInterval(t *testing.T) {
   623  	cleanTestDB(DB)
   624  	tree := createTreeOrPanic(DB, testonly.LogTree)
   625  	s := NewLogStorage(DB, nil)
   626  
   627  	{
   628  		runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error {
   629  			leaves := createTestLeaves(leavesToInsert, 20)
   630  			if _, err := tx.QueueLeaves(ctx, leaves, fakeQueueTime); err != nil {
   631  				t.Fatalf("Failed to queue leaves: %v", err)
   632  			}
   633  			return nil
   634  		})
   635  	}
   636  
   637  	{
   638  		// Now try to dequeue them using a cutoff that means we should get none
   639  		runLogTX(s, tree, t, func(ctx context.Context, tx2 storage.LogTreeTX) error {
   640  			leaves2, err := tx2.DequeueLeaves(ctx, 99, fakeQueueTime.Add(-time.Second))
   641  			if err != nil {
   642  				t.Fatalf("Failed to dequeue leaves: %v", err)
   643  			}
   644  			if len(leaves2) != 0 {
   645  				t.Fatalf("Dequeued %d leaves when they all should be in guard interval", len(leaves2))
   646  			}
   647  
   648  			// Try to dequeue again using a cutoff that should include them
   649  			leaves2, err = tx2.DequeueLeaves(ctx, 99, fakeQueueTime.Add(time.Second))
   650  			if err != nil {
   651  				t.Fatalf("Failed to dequeue leaves: %v", err)
   652  			}
   653  			if len(leaves2) != leavesToInsert {
   654  				t.Fatalf("Dequeued %d leaves but expected to get %d", len(leaves2), leavesToInsert)
   655  			}
   656  			ensureAllLeavesDistinct(leaves2, t)
   657  			return nil
   658  		})
   659  	}
   660  }
   661  
   662  func TestDequeueLeavesTimeOrdering(t *testing.T) {
   663  	// Queue two small batches of leaves at different timestamps. Do two separate dequeue
   664  	// transactions and make sure the returned leaves are respecting the time ordering of the
   665  	// queue.
   666  	cleanTestDB(DB)
   667  	tree := createTreeOrPanic(DB, testonly.LogTree)
   668  	s := NewLogStorage(DB, nil)
   669  
   670  	batchSize := 2
   671  	leaves := createTestLeaves(int64(batchSize), 0)
   672  	leaves2 := createTestLeaves(int64(batchSize), int64(batchSize))
   673  
   674  	{
   675  		runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error {
   676  			if _, err := tx.QueueLeaves(ctx, leaves, fakeQueueTime); err != nil {
   677  				t.Fatalf("QueueLeaves(1st batch) = %v", err)
   678  			}
   679  			// These are one second earlier so should be dequeued first
   680  			if _, err := tx.QueueLeaves(ctx, leaves2, fakeQueueTime.Add(-time.Second)); err != nil {
   681  				t.Fatalf("QueueLeaves(2nd batch) = %v", err)
   682  			}
   683  			return nil
   684  		})
   685  	}
   686  
   687  	{
   688  		// Now try to dequeue two leaves and we should get the second batch
   689  		runLogTX(s, tree, t, func(ctx context.Context, tx2 storage.LogTreeTX) error {
   690  			dequeue1, err := tx2.DequeueLeaves(ctx, batchSize, fakeQueueTime)
   691  			if err != nil {
   692  				t.Fatalf("DequeueLeaves(1st) = %v", err)
   693  			}
   694  			if got, want := len(dequeue1), batchSize; got != want {
   695  				t.Fatalf("Dequeue count mismatch (1st) got: %d, want: %d", got, want)
   696  			}
   697  			ensureAllLeavesDistinct(dequeue1, t)
   698  
   699  			// Ensure this is the second batch queued by comparing leaf hashes (must be distinct as
   700  			// the leaf data was).
   701  			if !leafInBatch(dequeue1[0], leaves2) || !leafInBatch(dequeue1[1], leaves2) {
   702  				t.Fatalf("Got leaf from wrong batch (1st dequeue): %v", dequeue1)
   703  			}
   704  			return nil
   705  		})
   706  
   707  		// Try to dequeue again and we should get the batch that was queued first, though at a later time
   708  		runLogTX(s, tree, t, func(ctx context.Context, tx3 storage.LogTreeTX) error {
   709  			dequeue2, err := tx3.DequeueLeaves(ctx, batchSize, fakeQueueTime)
   710  			if err != nil {
   711  				t.Fatalf("DequeueLeaves(2nd) = %v", err)
   712  			}
   713  			if got, want := len(dequeue2), batchSize; got != want {
   714  				t.Fatalf("Dequeue count mismatch (2nd) got: %d, want: %d", got, want)
   715  			}
   716  			ensureAllLeavesDistinct(dequeue2, t)
   717  
   718  			// Ensure this is the first batch by comparing leaf hashes.
   719  			if !leafInBatch(dequeue2[0], leaves) || !leafInBatch(dequeue2[1], leaves) {
   720  				t.Fatalf("Got leaf from wrong batch (2nd dequeue): %v", dequeue2)
   721  			}
   722  			return nil
   723  		})
   724  	}
   725  }
   726  
   727  func TestGetLeavesByHashNotPresent(t *testing.T) {
   728  	cleanTestDB(DB)
   729  	tree := createTreeOrPanic(DB, testonly.LogTree)
   730  	s := NewLogStorage(DB, nil)
   731  
   732  	runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error {
   733  		hashes := [][]byte{[]byte("thisdoesn'texist")}
   734  		leaves, err := tx.GetLeavesByHash(ctx, hashes, false)
   735  		if err != nil {
   736  			t.Fatalf("Error getting leaves by hash: %v", err)
   737  		}
   738  		if len(leaves) != 0 {
   739  			t.Fatalf("Expected no leaves returned but got %d", len(leaves))
   740  		}
   741  		return nil
   742  	})
   743  }
   744  
   745  func TestGetLeavesByHash(t *testing.T) {
   746  	ctx := context.Background()
   747  
   748  	// Create fake leaf as if it had been sequenced
   749  	cleanTestDB(DB)
   750  	tree := createTreeOrPanic(DB, testonly.LogTree)
   751  	s := NewLogStorage(DB, nil)
   752  
   753  	data := []byte("some data")
   754  	createFakeLeaf(ctx, DB, tree.TreeId, dummyRawHash, dummyHash, data, someExtraData, sequenceNumber, t)
   755  
   756  	runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error {
   757  		hashes := [][]byte{dummyHash}
   758  		leaves, err := tx.GetLeavesByHash(ctx, hashes, false)
   759  		if err != nil {
   760  			t.Fatalf("Unexpected error getting leaf by hash: %v", err)
   761  		}
   762  		if len(leaves) != 1 {
   763  			t.Fatalf("Got %d leaves but expected one", len(leaves))
   764  		}
   765  		checkLeafContents(leaves[0], sequenceNumber, dummyRawHash, dummyHash, data, someExtraData, t)
   766  		return nil
   767  	})
   768  }
   769  
   770  func TestGetLeafDataByIdentityHash(t *testing.T) {
   771  	ctx := context.Background()
   772  
   773  	// Create fake leaf as if it had been sequenced
   774  	cleanTestDB(DB)
   775  	tree := createTreeOrPanic(DB, testonly.LogTree)
   776  	s := NewLogStorage(DB, nil)
   777  	data := []byte("some data")
   778  	leaf := createFakeLeaf(ctx, DB, tree.TreeId, dummyRawHash, dummyHash, data, someExtraData, sequenceNumber, t)
   779  	leaf.LeafIndex = -1
   780  	leaf.MerkleLeafHash = []byte(dummyMerkleLeafHash)
   781  	leaf2 := createFakeLeaf(ctx, DB, tree.TreeId, dummyHash2, dummyHash2, data, someExtraData, sequenceNumber+1, t)
   782  	leaf2.LeafIndex = -1
   783  	leaf2.MerkleLeafHash = []byte(dummyMerkleLeafHash)
   784  
   785  	var tests = []struct {
   786  		hashes [][]byte
   787  		want   []*trillian.LogLeaf
   788  	}{
   789  		{
   790  			hashes: [][]byte{dummyRawHash},
   791  			want:   []*trillian.LogLeaf{leaf},
   792  		},
   793  		{
   794  			hashes: [][]byte{{0x01, 0x02}},
   795  		},
   796  		{
   797  			hashes: [][]byte{
   798  				dummyRawHash,
   799  				{0x01, 0x02},
   800  				dummyHash2,
   801  				{0x01, 0x02},
   802  			},
   803  			// Note: leaves not necessarily returned in order requested.
   804  			want: []*trillian.LogLeaf{leaf2, leaf},
   805  		},
   806  	}
   807  	for i, test := range tests {
   808  		t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
   809  			runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error {
   810  				leaves, err := tx.(*logTreeTX).getLeafDataByIdentityHash(ctx, test.hashes)
   811  				if err != nil {
   812  					t.Fatalf("getLeafDataByIdentityHash(_) = (_,%v); want (_,nil)", err)
   813  				}
   814  
   815  				if len(leaves) != len(test.want) {
   816  					t.Fatalf("getLeafDataByIdentityHash(_) = (|%d|,nil); want (|%d|,nil)", len(leaves), len(test.want))
   817  				}
   818  				leavesEquivalent(t, leaves, test.want)
   819  				return nil
   820  			})
   821  		})
   822  	}
   823  }
   824  
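        // leavesEquivalent checks that the two sets of leaves are equal, ignoring ordering.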
   825  func leavesEquivalent(t *testing.T, gotLeaves, wantLeaves []*trillian.LogLeaf) {
   826  	t.Helper()
   827  	want := make(map[string]*trillian.LogLeaf)
   828  	for _, w := range wantLeaves {
   829  		k := sha256.Sum256([]byte(w.String()))
   830  		want[string(k[:])] = w
   831  	}
   832  	got := make(map[string]*trillian.LogLeaf)
   833  	for _, g := range gotLeaves {
   834  		k := sha256.Sum256([]byte(g.String()))
   835  		got[string(k[:])] = g
   836  	}
   837  	if diff := pretty.Compare(want, got); diff != "" {
   838  		t.Errorf("leaves not equivalent: diff -want,+got:\n%v", diff)
   839  	}
   840  }
   841  
   842  func TestGetLeavesByIndex(t *testing.T) {
   843  	ctx := context.Background()
   844  
   845  	// Create fake leaf as if it had been sequenced, read it back and check contents
   846  	cleanTestDB(DB)
   847  	tree := createTreeOrPanic(DB, testonly.LogTree)
   848  	s := NewLogStorage(DB, nil)
   849  
   850  	// The leaf indices are checked against the tree size so we need a root.
   851  	createFakeSignedLogRoot(DB, tree, uint64(sequenceNumber+1))
   852  
   853  	data := []byte("some data")
   854  	data2 := []byte("some other data")
   855  	createFakeLeaf(ctx, DB, tree.TreeId, dummyRawHash, dummyHash, data, someExtraData, sequenceNumber, t)
   856  	createFakeLeaf(ctx, DB, tree.TreeId, dummyRawHash2, dummyHash2, data2, someExtraData2, sequenceNumber-1, t)
   857  
   858  	var tests = []struct {
   859  		desc     string
   860  		indices  []int64
   861  		wantErr  bool
   862  		wantCode codes.Code
   863  		checkFn  func([]*trillian.LogLeaf, *testing.T)
   864  	}{
   865  		{
   866  			desc:    "InTree",
   867  			indices: []int64{sequenceNumber},
   868  			checkFn: func(leaves []*trillian.LogLeaf, t *testing.T) {
   869  				checkLeafContents(leaves[0], sequenceNumber, dummyRawHash, dummyHash, data, someExtraData, t)
   870  			},
   871  		},
   872  		{
   873  			desc:    "InTree2",
   874  			indices: []int64{sequenceNumber - 1},
   875  			wantErr: false,
   876  			checkFn: func(leaves []*trillian.LogLeaf, t *testing.T) {
   877  				checkLeafContents(leaves[0], sequenceNumber-1, dummyRawHash2, dummyHash2, data2, someExtraData2, t)
   878  			},
   879  		},
   880  		{
   881  			desc:    "InTreeMultiple",
   882  			indices: []int64{sequenceNumber - 1, sequenceNumber},
   883  			checkFn: func(leaves []*trillian.LogLeaf, t *testing.T) {
   884  				checkLeafContents(leaves[1], sequenceNumber, dummyRawHash, dummyHash, data, someExtraData, t)
   885  				checkLeafContents(leaves[0], sequenceNumber-1, dummyRawHash2, dummyHash2, data2, someExtraData2, t)
   886  			},
   887  		},
   888  		{
   889  			desc:    "InTreeMultipleReverse",
   890  			indices: []int64{sequenceNumber, sequenceNumber - 1},
   891  			checkFn: func(leaves []*trillian.LogLeaf, t *testing.T) {
   892  				checkLeafContents(leaves[0], sequenceNumber, dummyRawHash, dummyHash, data, someExtraData, t)
   893  				checkLeafContents(leaves[1], sequenceNumber-1, dummyRawHash2, dummyHash2, data2, someExtraData2, t)
   894  			},
   895  		}, {
   896  			desc:     "OutsideTree",
   897  			indices:  []int64{sequenceNumber + 1},
   898  			wantErr:  true,
   899  			wantCode: codes.OutOfRange,
   900  		},
   901  		{
   902  			desc:     "LongWayOutsideTree",
   903  			indices:  []int64{9999},
   904  			wantErr:  true,
   905  			wantCode: codes.OutOfRange,
   906  		},
   907  		{
   908  			desc:     "MixedInOutTree",
   909  			indices:  []int64{sequenceNumber, sequenceNumber + 1},
   910  			wantErr:  true,
   911  			wantCode: codes.OutOfRange,
   912  		},
   913  		{
   914  			desc:     "MixedInOutTree2",
   915  			indices:  []int64{sequenceNumber - 1, sequenceNumber + 1},
   916  			wantErr:  true,
   917  			wantCode: codes.OutOfRange,
   918  		},
   919  	}
   920  
   921  	for _, test := range tests {
   922  		t.Run(test.desc, func(t *testing.T) {
   923  			runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error {
   924  				got, err := tx.GetLeavesByIndex(ctx, test.indices)
   925  				if test.wantErr {
   926  					if err == nil || status.Code(err) != test.wantCode {
   927  						t.Errorf("GetLeavesByIndex(%v)=%v,%v; want: nil, err with code %v", test.indices, got, err, test.wantCode)
   928  					}
   929  				} else {
   930  					if err != nil {
   931  						t.Errorf("GetLeavesByIndex(%v)=%v,%v; want: got, nil", test.indices, got, err)
   932  					}
        					// Run the per-test content checks, which were previously never invoked.
        					if test.checkFn != nil {
        						test.checkFn(got, t)
        					}
   933  				}
   934  				return nil
   935  			})
   936  		})
   937  	}
   938  }
   939  
   940  // GetLeavesByRange tests. -----------------------------------------------------
   941  
   942  type getLeavesByRangeTest struct {
   943  	start, count int64
   944  	want         []int64
   945  	wantErr      bool
   946  }
   947  
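        // testGetLeavesByRangeImpl creates a tree of the given type with leaves [0]..[19] (leaf [5]
        // missing) and a signed root claiming a tree size of 14, then runs each GetLeavesByRange
        // test case against it.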
   948  func testGetLeavesByRangeImpl(t *testing.T, create *trillian.Tree, tests []getLeavesByRangeTest) {
   949  	cleanTestDB(DB)
   950  
   951  	ctx := context.Background()
   952  	tree, err := createTree(DB, create)
   953  	if err != nil {
   954  		t.Fatalf("Error creating log: %v", err)
   955  	}
   956  	// Note: GetLeavesByRange loads the root internally to get the tree size.
   957  	createFakeSignedLogRoot(DB, tree, 14)
   958  	s := NewLogStorage(DB, nil)
   959  
   960  	// Create leaves [0]..[19] but drop leaf [5] and set the tree size to 14.
   961  	for i := int64(0); i < 20; i++ {
   962  		if i == 5 {
   963  			continue
   964  		}
   965  		data := []byte{byte(i)}
   966  		identityHash := sha256.Sum256(data)
   967  		createFakeLeaf(ctx, DB, tree.TreeId, identityHash[:], identityHash[:], data, someExtraData, i, t)
   968  	}
   969  
   970  	for _, test := range tests {
   971  		runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error {
   972  			leaves, err := tx.GetLeavesByRange(ctx, test.start, test.count)
   973  			if err != nil {
   974  				if !test.wantErr {
   975  					t.Errorf("GetLeavesByRange(%d, +%d)=_,%v; want _,nil", test.start, test.count, err)
   976  				}
   977  				return nil
   978  			}
   979  			if test.wantErr {
   980  				t.Errorf("GetLeavesByRange(%d, +%d)=_,nil; want _,non-nil", test.start, test.count)
   981  			}
   982  			got := make([]int64, len(leaves))
   983  			for i, leaf := range leaves {
   984  				got[i] = leaf.LeafIndex
   985  			}
   986  			if !reflect.DeepEqual(got, test.want) {
   987  				t.Errorf("GetLeavesByRange(%d, +%d)=%+v; want %+v", test.start, test.count, got, test.want)
   988  			}
   989  			return nil
   990  		})
   991  	}
   992  }
   993  
   994  func TestGetLeavesByRangeFromLog(t *testing.T) {
   995  	var tests = []getLeavesByRangeTest{
   996  		{start: 0, count: 1, want: []int64{0}},
   997  		{start: 0, count: 2, want: []int64{0, 1}},
   998  		{start: 1, count: 3, want: []int64{1, 2, 3}},
   999  		{start: 10, count: 7, want: []int64{10, 11, 12, 13}},
  1000  		{start: 13, count: 1, want: []int64{13}},
  1001  		{start: 14, count: 4, wantErr: true},   // Starts right after tree size.
  1002  		{start: 19, count: 2, wantErr: true},   // Starts further away.
  1003  		{start: 3, count: 5, wantErr: true},    // Hits non-contiguous leaves.
  1004  		{start: 5, count: 5, wantErr: true},    // Starts from a missing leaf.
  1005  		{start: 1, count: 0, wantErr: true},    // Empty range.
  1006  		{start: -1, count: 1, wantErr: true},   // Negative start.
  1007  		{start: 1, count: -1, wantErr: true},   // Negative count.
  1008  		{start: 100, count: 30, wantErr: true}, // Starts after all stored leaves.
  1009  	}
  1010  	testGetLeavesByRangeImpl(t, testonly.LogTree, tests)
  1011  }
  1012  
  1013  func TestGetLeavesByRangeFromPreorderedLog(t *testing.T) {
  1014  	var tests = []getLeavesByRangeTest{
  1015  		{start: 0, count: 1, want: []int64{0}},
  1016  		{start: 0, count: 2, want: []int64{0, 1}},
  1017  		{start: 1, count: 3, want: []int64{1, 2, 3}},
  1018  		{start: 10, count: 7, want: []int64{10, 11, 12, 13, 14, 15, 16}},
  1019  		{start: 13, count: 1, want: []int64{13}},
  1020  		// Starts right after tree size.
  1021  		{start: 14, count: 4, want: []int64{14, 15, 16, 17}},
  1022  		{start: 19, count: 2, want: []int64{19}}, // Starts further away.
  1023  		{start: 3, count: 5, wantErr: true},      // Hits non-contiguous leaves.
  1024  		{start: 5, count: 5, wantErr: true},      // Starts from a missing leaf.
  1025  		{start: 1, count: 0, wantErr: true},      // Empty range.
  1026  		{start: -1, count: 1, wantErr: true},     // Negative start.
  1027  		{start: 1, count: -1, wantErr: true},     // Negative count.
  1028  		{start: 100, count: 30, want: []int64{}}, // Starts after all stored leaves.
  1029  	}
  1030  	testGetLeavesByRangeImpl(t, testonly.PreorderedLogTree, tests)
  1031  }
  1032  
  1033  // -----------------------------------------------------------------------------
  1034  
  1035  func TestLatestSignedRootNoneWritten(t *testing.T) {
  1036  	ctx := context.Background()
  1037  
  1038  	cleanTestDB(DB)
  1039  	tree, err := createTree(DB, testonly.LogTree)
  1040  	if err != nil {
  1041  		t.Fatalf("createTree: %v", err)
  1042  	}
  1043  	s := NewLogStorage(DB, nil)
  1044  
  1045  	tx, err := s.SnapshotForTree(ctx, tree)
  1046  	if err != storage.ErrTreeNeedsInit {
  1047  		t.Fatalf("SnapshotForTree gave %v, want %v", err, storage.ErrTreeNeedsInit)
  1048  	}
  1049  	commit(tx, t)
  1050  }
  1051  
  1052  func TestLatestSignedLogRoot(t *testing.T) {
  1053  	cleanTestDB(DB)
  1054  	tree := createTreeOrPanic(DB, testonly.LogTree)
  1055  	s := NewLogStorage(DB, nil)
  1056  
  1057  	signer := tcrypto.NewSigner(tree.TreeId, ttestonly.NewSignerWithFixedSig(nil, []byte("notempty")), crypto.SHA256)
  1058  	root, err := signer.SignLogRoot(&types.LogRootV1{
  1059  		TimestampNanos: 98765,
  1060  		TreeSize:       16,
  1061  		Revision:       5,
  1062  		RootHash:       []byte(dummyHash),
  1063  	})
  1064  	if err != nil {
  1065  		t.Fatalf("SignLogRoot(): %v", err)
  1066  	}
  1067  
  1068  	runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error {
  1069  		if err := tx.StoreSignedLogRoot(ctx, *root); err != nil {
  1070  			t.Fatalf("Failed to store signed root: %v", err)
  1071  		}
  1072  		return nil
  1073  	})
  1074  
  1075  	{
  1076  		runLogTX(s, tree, t, func(ctx context.Context, tx2 storage.LogTreeTX) error {
  1077  			root2, err := tx2.LatestSignedLogRoot(ctx)
  1078  			if err != nil {
  1079  				t.Fatalf("Failed to read back new log root: %v", err)
  1080  			}
  1081  			if !proto.Equal(root, &root2) {
  1082  				t.Fatalf("Root round trip failed: <%v> and: <%v>", root, root2)
  1083  			}
  1084  			return nil
  1085  		})
  1086  	}
  1087  }
  1088  
  1089  func TestDuplicateSignedLogRoot(t *testing.T) {
  1090  	cleanTestDB(DB)
  1091  	tree := createTreeOrPanic(DB, testonly.LogTree)
  1092  	s := NewLogStorage(DB, nil)
  1093  
  1094  	signer := tcrypto.NewSigner(tree.TreeId, ttestonly.NewSignerWithFixedSig(nil, []byte("notempty")), crypto.SHA256)
  1095  	root, err := signer.SignLogRoot(&types.LogRootV1{
  1096  		TimestampNanos: 98765,
  1097  		TreeSize:       16,
  1098  		Revision:       5,
  1099  		RootHash:       []byte(dummyHash),
  1100  	})
  1101  	if err != nil {
  1102  		t.Fatalf("SignLogRoot(): %v", err)
  1103  	}
  1104  
  1105  	runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error {
  1106  		if err := tx.StoreSignedLogRoot(ctx, *root); err != nil {
  1107  			t.Fatalf("Failed to store signed root: %v", err)
  1108  		}
  1109  		// Shouldn't be able to do it again
  1110  		if err := tx.StoreSignedLogRoot(ctx, *root); err == nil {
  1111  			t.Fatal("Allowed duplicate signed root")
  1112  		}
  1113  		return nil
  1114  	})
  1115  }
  1116  
  1117  func TestLogRootUpdate(t *testing.T) {
  1118  	// Write two roots for a log and make sure the one with the newest timestamp supersedes the older one.
  1119  	cleanTestDB(DB)
  1120  	tree := createTreeOrPanic(DB, testonly.LogTree)
  1121  	s := NewLogStorage(DB, nil)
  1122  
  1123  	signer := tcrypto.NewSigner(tree.TreeId, ttestonly.NewSignerWithFixedSig(nil, []byte("notempty")), crypto.SHA256)
  1124  	root, err := signer.SignLogRoot(&types.LogRootV1{
  1125  		TimestampNanos: 98765,
  1126  		TreeSize:       16,
  1127  		Revision:       5,
  1128  		RootHash:       []byte(dummyHash),
  1129  	})
  1130  	if err != nil {
  1131  		t.Fatalf("SignLogRoot(): %v", err)
  1132  	}
  1133  	root2, err := signer.SignLogRoot(&types.LogRootV1{
  1134  		TimestampNanos: 98766,
  1135  		TreeSize:       16,
  1136  		Revision:       6,
  1137  		RootHash:       []byte(dummyHash),
  1138  	})
  1139  	if err != nil {
  1140  		t.Fatalf("SignLogRoot(): %v", err)
  1141  	}
  1142  
  1143  	runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error {
  1144  		if err := tx.StoreSignedLogRoot(ctx, *root); err != nil {
  1145  			t.Fatalf("Failed to store signed root: %v", err)
  1146  		}
  1147  		if err := tx.StoreSignedLogRoot(ctx, *root2); err != nil {
  1148  			t.Fatalf("Failed to store signed root: %v", err)
  1149  		}
  1150  		return nil
  1151  	})
  1152  
  1153  	runLogTX(s, tree, t, func(ctx context.Context, tx2 storage.LogTreeTX) error {
  1154  		root3, err := tx2.LatestSignedLogRoot(ctx)
  1155  		if err != nil {
  1156  			t.Fatalf("Failed to read back new log root: %v", err)
  1157  		}
  1158  		if !proto.Equal(root2, &root3) {
  1159  			t.Fatalf("Root round trip failed: <%v> and: <%v>", root2, root3)
  1160  		}
  1161  		return nil
  1162  	})
  1163  }
  1164  
  1165  func TestGetActiveLogIDs(t *testing.T) {
  1166  	ctx := context.Background()
  1167  
  1168  	cleanTestDB(DB)
  1169  	admin := NewAdminStorage(DB)
  1170  
  1171  	// Create a few test trees
  1172  	log1 := proto.Clone(testonly.LogTree).(*trillian.Tree)
  1173  	log2 := proto.Clone(testonly.LogTree).(*trillian.Tree)
  1174  	log3 := proto.Clone(testonly.PreorderedLogTree).(*trillian.Tree)
  1175  	drainingLog := proto.Clone(testonly.LogTree).(*trillian.Tree)
  1176  	frozenLog := proto.Clone(testonly.LogTree).(*trillian.Tree)
  1177  	deletedLog := proto.Clone(testonly.LogTree).(*trillian.Tree)
  1178  	map1 := proto.Clone(testonly.MapTree).(*trillian.Tree)
  1179  	map2 := proto.Clone(testonly.MapTree).(*trillian.Tree)
  1180  	deletedMap := proto.Clone(testonly.MapTree).(*trillian.Tree)
  1181  	for _, tree := range []*trillian.Tree{log1, log2, log3, drainingLog, frozenLog, deletedLog, map1, map2, deletedMap} {
  1182  		newTree, err := storage.CreateTree(ctx, admin, tree)
  1183  		if err != nil {
  1184  			t.Fatalf("CreateTree(%+v) returned err = %v", tree, err)
  1185  		}
  1186  		*tree = *newTree
  1187  	}
  1188  
  1189  	// FROZEN is not a valid initial state, so we have to update it separately.
  1190  	if _, err := storage.UpdateTree(ctx, admin, frozenLog.TreeId, func(t *trillian.Tree) {
  1191  		t.TreeState = trillian.TreeState_FROZEN
  1192  	}); err != nil {
  1193  		t.Fatalf("UpdateTree() returned err = %v", err)
  1194  	}
  1195  	// DRAINING is not a valid initial state, so we have to update it separately.
  1196  	if _, err := storage.UpdateTree(ctx, admin, drainingLog.TreeId, func(t *trillian.Tree) {
  1197  		t.TreeState = trillian.TreeState_DRAINING
  1198  	}); err != nil {
  1199  		t.Fatalf("UpdateTree() returned err = %v", err)
  1200  	}
  1201  
  1202  	// Update deleted trees accordingly
  1203  	updateDeletedStmt, err := DB.PrepareContext(ctx, "UPDATE Trees SET Deleted = ? WHERE TreeId = ?")
  1204  	if err != nil {
  1205  		t.Fatalf("PrepareContext() returned err = %v", err)
  1206  	}
  1207  	defer updateDeletedStmt.Close()
  1208  	for _, treeID := range []int64{deletedLog.TreeId, deletedMap.TreeId} {
  1209  		if _, err := updateDeletedStmt.ExecContext(ctx, true, treeID); err != nil {
  1210  			t.Fatalf("ExecContext(%v) returned err = %v", treeID, err)
  1211  		}
  1212  	}
  1213  
  1214  	s := NewLogStorage(DB, nil)
  1215  	tx, err := s.Snapshot(ctx)
  1216  	if err != nil {
  1217  		t.Fatalf("Snapshot() returns err = %v", err)
  1218  	}
  1219  	defer tx.Close()
  1220  	got, err := tx.GetActiveLogIDs(ctx)
  1221  	if err != nil {
  1222  		t.Fatalf("GetActiveLogIDs() returns err = %v", err)
  1223  	}
  1224  	if err := tx.Commit(); err != nil {
  1225  		t.Errorf("Commit() returned err = %v", err)
  1226  	}
  1227  
  1228  	want := []int64{log1.TreeId, log2.TreeId, log3.TreeId, drainingLog.TreeId}
  1229  	sort.Slice(got, func(i, j int) bool { return got[i] < got[j] })
  1230  	sort.Slice(want, func(i, j int) bool { return want[i] < want[j] })
  1231  	if diff := pretty.Compare(got, want); diff != "" {
  1232  		t.Errorf("post-GetActiveLogIDs diff (-got +want):\n%v", diff)
  1233  	}
  1234  }
  1235  
  1236  func TestGetActiveLogIDsEmpty(t *testing.T) {
  1237  	ctx := context.Background()
  1238  
  1239  	cleanTestDB(DB)
  1240  	s := NewLogStorage(DB, nil)
  1241  
  1242  	tx, err := s.Snapshot(context.Background())
  1243  	if err != nil {
  1244  		t.Fatalf("Snapshot() = (_, %v), want = (_, nil)", err)
  1245  	}
  1246  	defer tx.Close()
  1247  	ids, err := tx.GetActiveLogIDs(ctx)
  1248  	if err != nil {
  1249  		t.Fatalf("GetActiveLogIDs() = (_, %v), want = (_, nil)", err)
  1250  	}
  1251  	if err := tx.Commit(); err != nil {
  1252  		t.Errorf("Commit() = %v, want = nil", err)
  1253  	}
  1254  
  1255  	if got, want := len(ids), 0; got != want {
  1256  		t.Errorf("GetActiveLogIDs(): got %v IDs, want = %v", got, want)
  1257  	}
  1258  }
  1259  
  1260  func TestGetUnsequencedCounts(t *testing.T) {
  1261  	numLogs := 4
  1262  	cleanTestDB(DB)
  1263  	trees := make([]*trillian.Tree, 0, numLogs)
  1264  	for i := 0; i < numLogs; i++ {
  1265  		trees = append(trees, createTreeOrPanic(DB, testonly.LogTree))
  1266  	}
  1267  	s := NewLogStorage(DB, nil)
  1268  
  1269  	ctx := context.Background()
  1270  	expectedCount := make(map[int64]int64)
  1271  
  1272  	for i := int64(1); i < 10; i++ {
  1273  		// Put some leaves in the queue of each of the logs
  1274  		for j, tree := range trees {
  1275  			numToAdd := i + int64(j)
  1276  			runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error {
  1277  				leaves := createTestLeaves(numToAdd, expectedCount[tree.TreeId])
  1278  				if _, err := tx.QueueLeaves(ctx, leaves, fakeDequeueCutoffTime); err != nil {
  1279  					t.Fatalf("Failed to queue leaves: %v", err)
  1280  				}
  1281  				return nil
  1282  			})
  1283  
  1284  			expectedCount[tree.TreeId] += numToAdd
  1285  		}
  1286  
  1287  		// Now check what we get back from GetUnsequencedCounts matches
  1288  		tx, err := s.Snapshot(ctx)
  1289  		if err != nil {
  1290  			t.Fatalf("Snapshot() = (_, %v), want no error", err)
  1291  		}
  1292  		// tx explicitly closed in all branches
  1293  
  1294  		got, err := tx.GetUnsequencedCounts(ctx)
  1295  		if err != nil {
  1296  			tx.Close()
  1297  			t.Errorf("GetUnsequencedCounts() = %v, want no error", err)
        			return
  1298  		}
  1299  		if err := tx.Commit(); err != nil {
  1300  			t.Errorf("Commit() = %v, want no error", err)
  1301  			return
  1302  		}
  1303  		if diff := pretty.Compare(expectedCount, got); diff != "" {
  1304  			t.Errorf("GetUnsequencedCounts() = diff -want +got:\n%s", diff)
  1305  		}
  1306  	}
  1307  }
  1308  
  1309  func TestReadOnlyLogTX_Rollback(t *testing.T) {
  1310  	ctx := context.Background()
  1311  	cleanTestDB(DB)
  1312  	s := NewLogStorage(DB, nil)
  1313  	tx, err := s.Snapshot(ctx)
  1314  	if err != nil {
  1315  		t.Fatalf("Snapshot() = (_, %v), want = (_, nil)", err)
  1316  	}
  1317  	defer tx.Close()
  1318  	if _, err := tx.GetActiveLogIDs(ctx); err != nil {
  1319  		t.Fatalf("GetActiveLogIDs() = (_, %v), want = (_, nil)", err)
  1320  	}
  1321  	// It's a bit hard to have a more meaningful test. This should suffice.
  1322  	if err := tx.Rollback(); err != nil {
  1323  		t.Errorf("Rollback() = (_, %v), want = (_, nil)", err)
  1324  	}
  1325  }
  1326  
  1327  func TestGetSequencedLeafCount(t *testing.T) {
  1328  	ctx := context.Background()
  1329  
  1330  	// We'll create leaves for two different trees
  1331  	cleanTestDB(DB)
  1332  	log1 := createTreeOrPanic(DB, testonly.LogTree)
  1333  	log2 := createTreeOrPanic(DB, testonly.LogTree)
  1334  	s := NewLogStorage(DB, nil)
  1335  
  1336  	{
  1337  		// Create fake leaf as if it had been sequenced
  1338  		data := []byte("some data")
  1339  		createFakeLeaf(ctx, DB, log1.TreeId, dummyHash, dummyRawHash, data, someExtraData, sequenceNumber, t)
  1340  
  1341  		// Create fake leaves for second tree as if they had been sequenced
  1342  		data2 := []byte("some data 2")
  1343  		data3 := []byte("some data 3")
  1344  		createFakeLeaf(ctx, DB, log2.TreeId, dummyHash2, dummyRawHash, data2, someExtraData, sequenceNumber, t)
  1345  		createFakeLeaf(ctx, DB, log2.TreeId, dummyHash3, dummyRawHash, data3, someExtraData, sequenceNumber+1, t)
  1346  	}
  1347  
  1348  	// Read back the leaf counts from both trees
  1349  	runLogTX(s, log1, t, func(ctx context.Context, tx storage.LogTreeTX) error {
  1350  		count1, err := tx.GetSequencedLeafCount(ctx)
  1351  		if err != nil {
  1352  			t.Fatalf("unexpected error getting leaf count: %v", err)
  1353  		}
  1354  		if want, got := int64(1), count1; want != got {
  1355  			t.Fatalf("expected %d sequenced for logId but got %d", want, got)
  1356  		}
  1357  		return nil
  1358  	})
  1359  
  1360  	runLogTX(s, log2, t, func(ctx context.Context, tx storage.LogTreeTX) error {
  1361  		count2, err := tx.GetSequencedLeafCount(ctx)
  1362  		if err != nil {
  1363  			t.Fatalf("unexpected error getting leaf count2: %v", err)
  1364  		}
  1365  		if want, got := int64(2), count2; want != got {
  1366  			t.Fatalf("expected %d sequenced for logId2 but got %d", want, got)
  1367  		}
  1368  		return nil
  1369  	})
  1370  }
  1371  
  1372  func TestSortByLeafIdentityHash(t *testing.T) {
  1373  	l := make([]*trillian.LogLeaf, 30)
  1374  	for i := range l {
  1375  		hash := sha256.Sum256([]byte{byte(i)})
  1376  		leaf := trillian.LogLeaf{
  1377  			LeafIdentityHash: hash[:],
  1378  			LeafValue:        []byte(fmt.Sprintf("Value %d", i)),
  1379  			ExtraData:        []byte(fmt.Sprintf("Extra %d", i)),
  1380  			LeafIndex:        int64(i),
  1381  		}
  1382  		l[i] = &leaf
  1383  	}
  1384  	sort.Sort(byLeafIdentityHash(l))
  1385  	for i := range l {
  1386  		if i == 0 {
  1387  			continue
  1388  		}
  1389  		if bytes.Compare(l[i-1].LeafIdentityHash, l[i].LeafIdentityHash) != -1 {
  1390  			t.Errorf("sorted leaves not in order, [%d] = %x, [%d] = %x", i-1, l[i-1].LeafIdentityHash, i, l[i].LeafIdentityHash)
  1391  		}
  1392  	}
  1394  }
  1395  
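        // ensureAllLeavesDistinct fails the test if any two leaves share a LeafIdentityHash.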
  1396  func ensureAllLeavesDistinct(leaves []*trillian.LogLeaf, t *testing.T) {
  1397  	t.Helper()
  1398  	// All the leaf value hashes should be distinct because the leaves were created with distinct
  1399  	// leaf data. If only we had maps with slices as keys or sets or pretty much any kind of usable
  1400  	// data structures we could do this properly.
  1401  	for i := range leaves {
  1402  		for j := range leaves {
  1403  			if i != j && bytes.Equal(leaves[i].LeafIdentityHash, leaves[j].LeafIdentityHash) {
  1404  				t.Fatalf("Unexpectedly got a duplicate leaf hash: %v %v",
  1405  					leaves[i].LeafIdentityHash, leaves[j].LeafIdentityHash)
  1406  			}
  1407  		}
  1408  	}
  1409  }
  1410  
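        // ensureLeavesHaveQueueTimestamp fails the test if any leaf's queue timestamp differs from
        // the expected time.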
  1411  func ensureLeavesHaveQueueTimestamp(t *testing.T, leaves []*trillian.LogLeaf, want time.Time) {
  1412  	t.Helper()
  1413  	for _, leaf := range leaves {
  1414  		gotQTimestamp, err := ptypes.Timestamp(leaf.QueueTimestamp)
  1415  		if err != nil {
  1416  			t.Fatalf("Got invalid queue timestamp: %v", err)
  1417  		}
  1418  		if got, want := gotQTimestamp.UnixNano(), want.UnixNano(); got != want {
  1419  			t.Errorf("Got leaf with QueueTimestampNanos = %v, want %v: %v", got, want, leaf)
  1420  		}
  1421  	}
  1422  }
  1423  
  1424  // Creates some test leaves with predictable data
  1425  func createTestLeaves(n, startSeq int64) []*trillian.LogLeaf {
  1426  	var leaves []*trillian.LogLeaf
  1427  	for l := int64(0); l < n; l++ {
  1428  		lv := fmt.Sprintf("Leaf %d", l+startSeq)
  1429  		h := sha256.New()
  1430  		h.Write([]byte(lv))
  1431  		leafHash := h.Sum(nil)
  1432  		leaf := &trillian.LogLeaf{
  1433  			LeafIdentityHash: leafHash,
  1434  			MerkleLeafHash:   leafHash,
  1435  			LeafValue:        []byte(lv),
  1436  			ExtraData:        []byte(fmt.Sprintf("Extra %d", l)),
  1437  			LeafIndex:        int64(startSeq + l),
  1438  		}
  1439  		leaves = append(leaves, leaf)
  1440  	}
  1441  
  1442  	return leaves
  1443  }
  1444  
  1445  // Convenience helpers to avoid repeating "if err != nil { blah }" all over the place
  1446  func runLogTX(s storage.LogStorage, tree *trillian.Tree, t *testing.T, f storage.LogTXFunc) {
  1447  	t.Helper()
  1448  	if err := s.ReadWriteTransaction(context.Background(), tree, f); err != nil {
  1449  		t.Fatalf("Failed to run log tx: %v", err)
  1450  	}
  1451  }
  1452  
  1453  type committableTX interface {
  1454  	Commit() error
  1455  }
  1456  
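        // commit commits the transaction and reports any error via the test.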
  1457  func commit(tx committableTX, t *testing.T) {
  1458  	t.Helper()
  1459  	if err := tx.Commit(); err != nil {
  1460  		t.Errorf("Failed to commit tx: %v", err)
  1461  	}
  1462  }
  1463  
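        // leafInBatch reports whether a leaf with the same identity hash appears in the batch.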
  1464  func leafInBatch(leaf *trillian.LogLeaf, batch []*trillian.LogLeaf) bool {
  1465  	for _, bl := range batch {
  1466  		if bytes.Equal(bl.LeafIdentityHash, leaf.LeafIdentityHash) {
  1467  			return true
  1468  		}
  1469  	}
  1470  
  1471  	return false
  1472  }
  1473  
  1474  // byLeafIdentityHash allows sorting of leaves by their identity hash, so DB
  1475  // operations always happen in a consistent order.
  1476  type byLeafIdentityHash []*trillian.LogLeaf
  1477  
  1478  func (l byLeafIdentityHash) Len() int      { return len(l) }
  1479  func (l byLeafIdentityHash) Swap(i, j int) { l[i], l[j] = l[j], l[i] }
  1480  func (l byLeafIdentityHash) Less(i, j int) bool {
  1481  	return bytes.Compare(l[i].LeafIdentityHash, l[j].LeafIdentityHash) == -1
  1482  }
  1483  
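        // logTree returns a minimal LOG-type tree with the given ID, for tests that only need the
        // ID and type to be set.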
  1484  func logTree(logID int64) *trillian.Tree {
  1485  	return &trillian.Tree{
  1486  		TreeId:       logID,
  1487  		TreeType:     trillian.TreeType_LOG,
  1488  		HashStrategy: trillian.HashStrategy_RFC6962_SHA256,
  1489  	}
  1490  }