github.com/bartle-stripe/trillian@v1.2.1/storage/mysql/storage_test.go (about) 1 // Copyright 2016 Google Inc. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package mysql 16 17 import ( 18 "bytes" 19 "context" 20 "crypto" 21 "crypto/sha256" 22 "database/sql" 23 "flag" 24 "fmt" 25 "os" 26 "testing" 27 28 "github.com/golang/glog" 29 "github.com/google/trillian" 30 "github.com/google/trillian/merkle" 31 "github.com/google/trillian/merkle/rfc6962" 32 "github.com/google/trillian/storage" 33 "github.com/google/trillian/storage/testdb" 34 "github.com/google/trillian/testonly" 35 "github.com/google/trillian/types" 36 37 tcrypto "github.com/google/trillian/crypto" 38 storageto "github.com/google/trillian/storage/testonly" 39 ) 40 41 func TestNodeRoundTrip(t *testing.T) { 42 cleanTestDB(DB) 43 tree := createTreeOrPanic(DB, storageto.LogTree) 44 s := NewLogStorage(DB, nil) 45 46 const writeRevision = int64(100) 47 nodesToStore := createSomeNodes() 48 nodeIDsToRead := make([]storage.NodeID, len(nodesToStore)) 49 for i := range nodesToStore { 50 nodeIDsToRead[i] = nodesToStore[i].NodeID 51 } 52 53 { 54 runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error { 55 forceWriteRevision(writeRevision, tx) 56 57 // Need to read nodes before attempting to write 58 if _, err := tx.GetMerkleNodes(ctx, 99, nodeIDsToRead); err != nil { 59 t.Fatalf("Failed to read nodes: %s", err) 60 } 61 if err := tx.SetMerkleNodes(ctx, nodesToStore); err != nil { 62 t.Fatalf("Failed to store nodes: %s", err) 63 } 64 return nil 65 }) 66 } 67 68 { 69 runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error { 70 readNodes, err := tx.GetMerkleNodes(ctx, 100, nodeIDsToRead) 71 if err != nil { 72 t.Fatalf("Failed to retrieve nodes: %s", err) 73 } 74 if err := nodesAreEqual(readNodes, nodesToStore); err != nil { 75 t.Fatalf("Read back different nodes from the ones stored: %s", err) 76 } 77 return nil 78 }) 79 } 80 } 81 82 // This test ensures that node writes cross subtree boundaries so this edge case in the subtree 83 // cache gets exercised. Any tree size > 256 will do this. 84 func TestLogNodeRoundTripMultiSubtree(t *testing.T) { 85 cleanTestDB(DB) 86 tree := createTreeOrPanic(DB, storageto.LogTree) 87 s := NewLogStorage(DB, nil) 88 89 const writeRevision = int64(100) 90 nodesToStore, err := createLogNodesForTreeAtSize(871, writeRevision) 91 if err != nil { 92 t.Fatalf("failed to create test tree: %v", err) 93 } 94 nodeIDsToRead := make([]storage.NodeID, len(nodesToStore)) 95 for i := range nodesToStore { 96 nodeIDsToRead[i] = nodesToStore[i].NodeID 97 } 98 99 { 100 runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error { 101 forceWriteRevision(writeRevision, tx) 102 103 // Need to read nodes before attempting to write 104 if _, err := tx.GetMerkleNodes(ctx, writeRevision-1, nodeIDsToRead); err != nil { 105 t.Fatalf("Failed to read nodes: %s", err) 106 } 107 if err := tx.SetMerkleNodes(ctx, nodesToStore); err != nil { 108 t.Fatalf("Failed to store nodes: %s", err) 109 } 110 return nil 111 }) 112 } 113 114 { 115 runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error { 116 readNodes, err := tx.GetMerkleNodes(ctx, 100, nodeIDsToRead) 117 if err != nil { 118 t.Fatalf("Failed to retrieve nodes: %s", err) 119 } 120 if err := nodesAreEqual(readNodes, nodesToStore); err != nil { 121 missing, extra := diffNodes(readNodes, nodesToStore) 122 for _, n := range missing { 123 t.Errorf("Missing: %s %s", n.NodeID.String(), n.NodeID.CoordString()) 124 } 125 for _, n := range extra { 126 t.Errorf("Extra : %s %s", n.NodeID.String(), n.NodeID.CoordString()) 127 } 128 t.Fatalf("Read back different nodes from the ones stored: %s", err) 129 } 130 return nil 131 }) 132 } 133 } 134 135 func forceWriteRevision(rev int64, tx storage.TreeTX) { 136 mtx, ok := tx.(*logTreeTX) 137 if !ok { 138 panic(nil) 139 } 140 mtx.treeTX.writeRevision = rev 141 } 142 143 func createSomeNodes() []storage.Node { 144 r := make([]storage.Node, 4) 145 for i := range r { 146 r[i].NodeID = storage.NewNodeIDWithPrefix(uint64(i), 8, 8, 8) 147 h := sha256.Sum256([]byte{byte(i)}) 148 r[i].Hash = h[:] 149 glog.Infof("Node to store: %v\n", r[i].NodeID) 150 } 151 return r 152 } 153 154 func createLogNodesForTreeAtSize(ts, rev int64) ([]storage.Node, error) { 155 tree := merkle.NewCompactMerkleTree(rfc6962.New(crypto.SHA256)) 156 nodeMap := make(map[string]storage.Node) 157 for l := 0; l < int(ts); l++ { 158 // We're only interested in the side effects of adding leaves - the node updates 159 if _, _, err := tree.AddLeaf([]byte(fmt.Sprintf("Leaf %d", l)), func(depth int, index int64, hash []byte) error { 160 nID, err := storage.NewNodeIDForTreeCoords(int64(depth), index, 64) 161 if err != nil { 162 return fmt.Errorf("failed to create a nodeID for tree - should not happen d:%d i:%d", 163 depth, index) 164 } 165 166 nodeMap[nID.String()] = storage.Node{NodeID: nID, NodeRevision: rev, Hash: hash} 167 return nil 168 }); err != nil { 169 return nil, err 170 } 171 } 172 173 // Unroll the map, which has deduped the updates for us and retained the latest 174 nodes := make([]storage.Node, 0, len(nodeMap)) 175 for _, v := range nodeMap { 176 nodes = append(nodes, v) 177 } 178 179 return nodes, nil 180 } 181 182 func nodesAreEqual(lhs []storage.Node, rhs []storage.Node) error { 183 if ls, rs := len(lhs), len(rhs); ls != rs { 184 return fmt.Errorf("different number of nodes, %d vs %d", ls, rs) 185 } 186 for i := range lhs { 187 if l, r := lhs[i].NodeID.String(), rhs[i].NodeID.String(); l != r { 188 return fmt.Errorf("NodeIDs are not the same,\nlhs = %v,\nrhs = %v", l, r) 189 } 190 if l, r := lhs[i].Hash, rhs[i].Hash; !bytes.Equal(l, r) { 191 return fmt.Errorf("Hashes are not the same for %s,\nlhs = %v,\nrhs = %v", lhs[i].NodeID.CoordString(), l, r) 192 } 193 } 194 return nil 195 } 196 197 func diffNodes(got, want []storage.Node) ([]storage.Node, []storage.Node) { 198 var missing []storage.Node 199 gotMap := make(map[string]storage.Node) 200 for _, n := range got { 201 gotMap[n.NodeID.String()] = n 202 } 203 for _, n := range want { 204 _, ok := gotMap[n.NodeID.String()] 205 if !ok { 206 missing = append(missing, n) 207 } 208 delete(gotMap, n.NodeID.String()) 209 } 210 // Unpack the extra nodes to return both as slices 211 extra := make([]storage.Node, 0, len(gotMap)) 212 for _, v := range gotMap { 213 extra = append(extra, v) 214 } 215 return missing, extra 216 } 217 218 func openTestDBOrDie() *sql.DB { 219 db, err := testdb.NewTrillianDB(context.TODO()) 220 if err != nil { 221 panic(err) 222 } 223 return db 224 } 225 226 // cleanTestDB deletes all the entries in the database. 227 func cleanTestDB(db *sql.DB) { 228 for _, table := range allTables { 229 if _, err := db.ExecContext(context.TODO(), fmt.Sprintf("DELETE FROM %s", table)); err != nil { 230 panic(fmt.Sprintf("Failed to delete rows in %s: %v", table, err)) 231 } 232 } 233 } 234 235 func createFakeSignedLogRoot(db *sql.DB, tree *trillian.Tree, treeSize uint64) { 236 signer := tcrypto.NewSigner(0, testonly.NewSignerWithFixedSig(nil, []byte("notnil")), crypto.SHA256) 237 238 ctx := context.Background() 239 l := NewLogStorage(db, nil) 240 err := l.ReadWriteTransaction(ctx, tree, func(ctx context.Context, tx storage.LogTreeTX) error { 241 root, err := signer.SignLogRoot(&types.LogRootV1{TreeSize: treeSize, RootHash: []byte{0}}) 242 if err != nil { 243 return fmt.Errorf("Error creating new SignedLogRoot: %v", err) 244 } 245 if err := tx.StoreSignedLogRoot(ctx, *root); err != nil { 246 return fmt.Errorf("Error storing new SignedLogRoot: %v", err) 247 } 248 return nil 249 }) 250 if err != nil { 251 panic(fmt.Sprintf("ReadWriteTransaction() = %v", err)) 252 } 253 } 254 255 // createTree creates the specified tree using AdminStorage. 256 func createTree(db *sql.DB, tree *trillian.Tree) (*trillian.Tree, error) { 257 ctx := context.Background() 258 s := NewAdminStorage(db) 259 tree, err := storage.CreateTree(ctx, s, tree) 260 if err != nil { 261 return nil, err 262 } 263 return tree, nil 264 } 265 266 func createTreeOrPanic(db *sql.DB, create *trillian.Tree) *trillian.Tree { 267 tree, err := createTree(db, create) 268 if err != nil { 269 panic(fmt.Sprintf("Error creating tree: %v", err)) 270 } 271 return tree 272 } 273 274 // updateTree updates the specified tree using AdminStorage. 275 func updateTree(db *sql.DB, treeID int64, updateFn func(*trillian.Tree)) (*trillian.Tree, error) { 276 ctx := context.Background() 277 s := NewAdminStorage(db) 278 return storage.UpdateTree(ctx, s, treeID, updateFn) 279 } 280 281 // DB is the database used for tests. It's initialized and closed by TestMain(). 282 var DB *sql.DB 283 284 func TestMain(m *testing.M) { 285 flag.Parse() 286 if !testdb.MySQLAvailable() { 287 glog.Errorf("MySQL not available, skipping all MySQL storage tests") 288 return 289 } 290 DB = openTestDBOrDie() 291 defer DB.Close() 292 cleanTestDB(DB) 293 ec := m.Run() 294 os.Exit(ec) 295 }