github.com/bartle-stripe/trillian@v1.2.1/storage/mysql/storage_test.go (about)

     1  // Copyright 2016 Google Inc. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package mysql
    16  
    17  import (
    18  	"bytes"
    19  	"context"
    20  	"crypto"
    21  	"crypto/sha256"
    22  	"database/sql"
    23  	"flag"
    24  	"fmt"
    25  	"os"
    26  	"testing"
    27  
    28  	"github.com/golang/glog"
    29  	"github.com/google/trillian"
    30  	"github.com/google/trillian/merkle"
    31  	"github.com/google/trillian/merkle/rfc6962"
    32  	"github.com/google/trillian/storage"
    33  	"github.com/google/trillian/storage/testdb"
    34  	"github.com/google/trillian/testonly"
    35  	"github.com/google/trillian/types"
    36  
    37  	tcrypto "github.com/google/trillian/crypto"
    38  	storageto "github.com/google/trillian/storage/testonly"
    39  )
    40  
    41  func TestNodeRoundTrip(t *testing.T) {
    42  	cleanTestDB(DB)
    43  	tree := createTreeOrPanic(DB, storageto.LogTree)
    44  	s := NewLogStorage(DB, nil)
    45  
    46  	const writeRevision = int64(100)
    47  	nodesToStore := createSomeNodes()
    48  	nodeIDsToRead := make([]storage.NodeID, len(nodesToStore))
    49  	for i := range nodesToStore {
    50  		nodeIDsToRead[i] = nodesToStore[i].NodeID
    51  	}
    52  
    53  	{
    54  		runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error {
    55  			forceWriteRevision(writeRevision, tx)
    56  
    57  			// Need to read nodes before attempting to write
    58  			if _, err := tx.GetMerkleNodes(ctx, 99, nodeIDsToRead); err != nil {
    59  				t.Fatalf("Failed to read nodes: %s", err)
    60  			}
    61  			if err := tx.SetMerkleNodes(ctx, nodesToStore); err != nil {
    62  				t.Fatalf("Failed to store nodes: %s", err)
    63  			}
    64  			return nil
    65  		})
    66  	}
    67  
    68  	{
    69  		runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error {
    70  			readNodes, err := tx.GetMerkleNodes(ctx, 100, nodeIDsToRead)
    71  			if err != nil {
    72  				t.Fatalf("Failed to retrieve nodes: %s", err)
    73  			}
    74  			if err := nodesAreEqual(readNodes, nodesToStore); err != nil {
    75  				t.Fatalf("Read back different nodes from the ones stored: %s", err)
    76  			}
    77  			return nil
    78  		})
    79  	}
    80  }
    81  
    82  // This test ensures that node writes cross subtree boundaries so this edge case in the subtree
    83  // cache gets exercised. Any tree size > 256 will do this.
    84  func TestLogNodeRoundTripMultiSubtree(t *testing.T) {
    85  	cleanTestDB(DB)
    86  	tree := createTreeOrPanic(DB, storageto.LogTree)
    87  	s := NewLogStorage(DB, nil)
    88  
    89  	const writeRevision = int64(100)
    90  	nodesToStore, err := createLogNodesForTreeAtSize(871, writeRevision)
    91  	if err != nil {
    92  		t.Fatalf("failed to create test tree: %v", err)
    93  	}
    94  	nodeIDsToRead := make([]storage.NodeID, len(nodesToStore))
    95  	for i := range nodesToStore {
    96  		nodeIDsToRead[i] = nodesToStore[i].NodeID
    97  	}
    98  
    99  	{
   100  		runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error {
   101  			forceWriteRevision(writeRevision, tx)
   102  
   103  			// Need to read nodes before attempting to write
   104  			if _, err := tx.GetMerkleNodes(ctx, writeRevision-1, nodeIDsToRead); err != nil {
   105  				t.Fatalf("Failed to read nodes: %s", err)
   106  			}
   107  			if err := tx.SetMerkleNodes(ctx, nodesToStore); err != nil {
   108  				t.Fatalf("Failed to store nodes: %s", err)
   109  			}
   110  			return nil
   111  		})
   112  	}
   113  
   114  	{
   115  		runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error {
   116  			readNodes, err := tx.GetMerkleNodes(ctx, 100, nodeIDsToRead)
   117  			if err != nil {
   118  				t.Fatalf("Failed to retrieve nodes: %s", err)
   119  			}
   120  			if err := nodesAreEqual(readNodes, nodesToStore); err != nil {
   121  				missing, extra := diffNodes(readNodes, nodesToStore)
   122  				for _, n := range missing {
   123  					t.Errorf("Missing: %s %s", n.NodeID.String(), n.NodeID.CoordString())
   124  				}
   125  				for _, n := range extra {
   126  					t.Errorf("Extra  : %s %s", n.NodeID.String(), n.NodeID.CoordString())
   127  				}
   128  				t.Fatalf("Read back different nodes from the ones stored: %s", err)
   129  			}
   130  			return nil
   131  		})
   132  	}
   133  }
   134  
   135  func forceWriteRevision(rev int64, tx storage.TreeTX) {
   136  	mtx, ok := tx.(*logTreeTX)
   137  	if !ok {
   138  		panic(nil)
   139  	}
   140  	mtx.treeTX.writeRevision = rev
   141  }
   142  
   143  func createSomeNodes() []storage.Node {
   144  	r := make([]storage.Node, 4)
   145  	for i := range r {
   146  		r[i].NodeID = storage.NewNodeIDWithPrefix(uint64(i), 8, 8, 8)
   147  		h := sha256.Sum256([]byte{byte(i)})
   148  		r[i].Hash = h[:]
   149  		glog.Infof("Node to store: %v\n", r[i].NodeID)
   150  	}
   151  	return r
   152  }
   153  
   154  func createLogNodesForTreeAtSize(ts, rev int64) ([]storage.Node, error) {
   155  	tree := merkle.NewCompactMerkleTree(rfc6962.New(crypto.SHA256))
   156  	nodeMap := make(map[string]storage.Node)
   157  	for l := 0; l < int(ts); l++ {
   158  		// We're only interested in the side effects of adding leaves - the node updates
   159  		if _, _, err := tree.AddLeaf([]byte(fmt.Sprintf("Leaf %d", l)), func(depth int, index int64, hash []byte) error {
   160  			nID, err := storage.NewNodeIDForTreeCoords(int64(depth), index, 64)
   161  			if err != nil {
   162  				return fmt.Errorf("failed to create a nodeID for tree - should not happen d:%d i:%d",
   163  					depth, index)
   164  			}
   165  
   166  			nodeMap[nID.String()] = storage.Node{NodeID: nID, NodeRevision: rev, Hash: hash}
   167  			return nil
   168  		}); err != nil {
   169  			return nil, err
   170  		}
   171  	}
   172  
   173  	// Unroll the map, which has deduped the updates for us and retained the latest
   174  	nodes := make([]storage.Node, 0, len(nodeMap))
   175  	for _, v := range nodeMap {
   176  		nodes = append(nodes, v)
   177  	}
   178  
   179  	return nodes, nil
   180  }
   181  
   182  func nodesAreEqual(lhs []storage.Node, rhs []storage.Node) error {
   183  	if ls, rs := len(lhs), len(rhs); ls != rs {
   184  		return fmt.Errorf("different number of nodes, %d vs %d", ls, rs)
   185  	}
   186  	for i := range lhs {
   187  		if l, r := lhs[i].NodeID.String(), rhs[i].NodeID.String(); l != r {
   188  			return fmt.Errorf("NodeIDs are not the same,\nlhs = %v,\nrhs = %v", l, r)
   189  		}
   190  		if l, r := lhs[i].Hash, rhs[i].Hash; !bytes.Equal(l, r) {
   191  			return fmt.Errorf("Hashes are not the same for %s,\nlhs = %v,\nrhs = %v", lhs[i].NodeID.CoordString(), l, r)
   192  		}
   193  	}
   194  	return nil
   195  }
   196  
   197  func diffNodes(got, want []storage.Node) ([]storage.Node, []storage.Node) {
   198  	var missing []storage.Node
   199  	gotMap := make(map[string]storage.Node)
   200  	for _, n := range got {
   201  		gotMap[n.NodeID.String()] = n
   202  	}
   203  	for _, n := range want {
   204  		_, ok := gotMap[n.NodeID.String()]
   205  		if !ok {
   206  			missing = append(missing, n)
   207  		}
   208  		delete(gotMap, n.NodeID.String())
   209  	}
   210  	// Unpack the extra nodes to return both as slices
   211  	extra := make([]storage.Node, 0, len(gotMap))
   212  	for _, v := range gotMap {
   213  		extra = append(extra, v)
   214  	}
   215  	return missing, extra
   216  }
   217  
   218  func openTestDBOrDie() *sql.DB {
   219  	db, err := testdb.NewTrillianDB(context.TODO())
   220  	if err != nil {
   221  		panic(err)
   222  	}
   223  	return db
   224  }
   225  
   226  // cleanTestDB deletes all the entries in the database.
   227  func cleanTestDB(db *sql.DB) {
   228  	for _, table := range allTables {
   229  		if _, err := db.ExecContext(context.TODO(), fmt.Sprintf("DELETE FROM %s", table)); err != nil {
   230  			panic(fmt.Sprintf("Failed to delete rows in %s: %v", table, err))
   231  		}
   232  	}
   233  }
   234  
   235  func createFakeSignedLogRoot(db *sql.DB, tree *trillian.Tree, treeSize uint64) {
   236  	signer := tcrypto.NewSigner(0, testonly.NewSignerWithFixedSig(nil, []byte("notnil")), crypto.SHA256)
   237  
   238  	ctx := context.Background()
   239  	l := NewLogStorage(db, nil)
   240  	err := l.ReadWriteTransaction(ctx, tree, func(ctx context.Context, tx storage.LogTreeTX) error {
   241  		root, err := signer.SignLogRoot(&types.LogRootV1{TreeSize: treeSize, RootHash: []byte{0}})
   242  		if err != nil {
   243  			return fmt.Errorf("Error creating new SignedLogRoot: %v", err)
   244  		}
   245  		if err := tx.StoreSignedLogRoot(ctx, *root); err != nil {
   246  			return fmt.Errorf("Error storing new SignedLogRoot: %v", err)
   247  		}
   248  		return nil
   249  	})
   250  	if err != nil {
   251  		panic(fmt.Sprintf("ReadWriteTransaction() = %v", err))
   252  	}
   253  }
   254  
   255  // createTree creates the specified tree using AdminStorage.
   256  func createTree(db *sql.DB, tree *trillian.Tree) (*trillian.Tree, error) {
   257  	ctx := context.Background()
   258  	s := NewAdminStorage(db)
   259  	tree, err := storage.CreateTree(ctx, s, tree)
   260  	if err != nil {
   261  		return nil, err
   262  	}
   263  	return tree, nil
   264  }
   265  
   266  func createTreeOrPanic(db *sql.DB, create *trillian.Tree) *trillian.Tree {
   267  	tree, err := createTree(db, create)
   268  	if err != nil {
   269  		panic(fmt.Sprintf("Error creating tree: %v", err))
   270  	}
   271  	return tree
   272  }
   273  
   274  // updateTree updates the specified tree using AdminStorage.
   275  func updateTree(db *sql.DB, treeID int64, updateFn func(*trillian.Tree)) (*trillian.Tree, error) {
   276  	ctx := context.Background()
   277  	s := NewAdminStorage(db)
   278  	return storage.UpdateTree(ctx, s, treeID, updateFn)
   279  }
   280  
   281  // DB is the database used for tests. It's initialized and closed by TestMain().
   282  var DB *sql.DB
   283  
   284  func TestMain(m *testing.M) {
   285  	flag.Parse()
   286  	if !testdb.MySQLAvailable() {
   287  		glog.Errorf("MySQL not available, skipping all MySQL storage tests")
   288  		return
   289  	}
   290  	DB = openTestDBOrDie()
   291  	defer DB.Close()
   292  	cleanTestDB(DB)
   293  	ec := m.Run()
   294  	os.Exit(ec)
   295  }