github.com/bartle-stripe/trillian@v1.2.1/storage/mysql/tree_storage.go (about)

     1  // Copyright 2016 Google Inc. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package mysql provides a MySQL-based storage layer implementation.
    16  package mysql
    17  
    18  import (
    19  	"context"
    20  	"database/sql"
    21  	"encoding/base64"
    22  	"fmt"
    23  	"runtime/debug"
    24  	"strings"
    25  	"sync"
    26  
    27  	"github.com/golang/glog"
    28  	"github.com/golang/protobuf/proto"
    29  	"github.com/google/trillian"
    30  	"github.com/google/trillian/storage"
    31  	"github.com/google/trillian/storage/cache"
    32  	"github.com/google/trillian/storage/storagepb"
    33  )
    34  
// SQL statement templates shared by the tree storage code in this file.
//
//   - insertSubtreeMultiSQL: multi-row INSERT into Subtree. The trailing
//     placeholderSQL marker is expanded (see setSubtreeStmt) into one
//     "(?, ?, ?, ?)" group per subtree being written.
//   - insertTreeHeadSQL: single-row INSERT into TreeHead. It is not referenced in
//     the visible part of this file; presumably used by the log/map storage code.
//   - selectTreeRevisionAtSizeOrLargerSQL: returns the lowest TreeRevision (and its
//     TreeSize) whose TreeSize is at least the requested size for a tree.
//   - selectSubtreeSQL: for each requested SubtreeId of a tree, selects the row
//     with the highest SubtreeRevision that is <= the requested revision. The
//     placeholderSQL marker is expanded (see getSubtreeStmt) into one '?' per
//     subtree ID; the remaining parameters are TreeId, revision, TreeId.
//   - placeholderSQL: marker text replaced by expandPlaceholderSQL.
//
// The template strings are also used verbatim as keys of the prepared statement
// cache, so they must stay fixed.
const (
	insertSubtreeMultiSQL = `INSERT INTO Subtree(TreeId, SubtreeId, Nodes, SubtreeRevision) ` + placeholderSQL
	insertTreeHeadSQL     = `INSERT INTO TreeHead(TreeId,TreeHeadTimestamp,TreeSize,RootHash,TreeRevision,RootSignature)
		 VALUES(?,?,?,?,?,?)`
	selectTreeRevisionAtSizeOrLargerSQL = "SELECT TreeRevision,TreeSize FROM TreeHead WHERE TreeId=? AND TreeSize>=? ORDER BY TreeRevision LIMIT 1"

	selectSubtreeSQL = `
 SELECT x.SubtreeId, x.MaxRevision, Subtree.Nodes
 FROM (
 	SELECT n.SubtreeId, max(n.SubtreeRevision) AS MaxRevision
	FROM Subtree n
	WHERE n.SubtreeId IN (` + placeholderSQL + `) AND
	 n.TreeId = ? AND n.SubtreeRevision <= ?
	GROUP BY n.SubtreeId
 ) AS x
 INNER JOIN Subtree 
 ON Subtree.SubtreeId = x.SubtreeId 
 AND Subtree.SubtreeRevision = x.MaxRevision 
 AND Subtree.TreeId = ?`
	placeholderSQL = "<placeholder>"
)
    57  
// mySQLTreeStorage is shared between the mySQLLog- and (forthcoming) mySQLMap-
// Storage implementations, and contains functionality which is common to both.
type mySQLTreeStorage struct {
	// db is the handle used to prepare statements and to begin transactions.
	db *sql.DB

	// Must hold the mutex before manipulating the statement map. Sharing a lock because
	// it only needs to be held while the statements are built, not while they execute and
	// this will be a short time.
	//
	// statements is a two-level cache of prepared statements: the outer key is the
	// statement template text, and the inner key is the number of placeholder
	// slots the template was expanded to (the num argument of getStmt).
	statementMutex sync.Mutex
	statements     map[string]map[int]*sql.Stmt
}
    70  
    71  // OpenDB opens a database connection for all MySQL-based storage implementations.
    72  func OpenDB(dbURL string) (*sql.DB, error) {
    73  	db, err := sql.Open("mysql", dbURL)
    74  	if err != nil {
    75  		// Don't log uri as it could contain credentials
    76  		glog.Warningf("Could not open MySQL database, check config: %s", err)
    77  		return nil, err
    78  	}
    79  
    80  	if _, err := db.ExecContext(context.TODO(), "SET sql_mode = 'STRICT_ALL_TABLES'"); err != nil {
    81  		glog.Warningf("Failed to set strict mode on mysql db: %s", err)
    82  		return nil, err
    83  	}
    84  
    85  	return db, nil
    86  }
    87  
    88  func newTreeStorage(db *sql.DB) *mySQLTreeStorage {
    89  	return &mySQLTreeStorage{
    90  		db:         db,
    91  		statements: make(map[string]map[int]*sql.Stmt),
    92  	}
    93  }
    94  
    95  // expandPlaceholderSQL expands an sql statement by adding a specified number of '?'
    96  // placeholder slots. At most one placeholder will be expanded.
    97  func expandPlaceholderSQL(sql string, num int, first, rest string) string {
    98  	if num <= 0 {
    99  		panic(fmt.Errorf("trying to expand SQL placeholder with <= 0 parameters: %s", sql))
   100  	}
   101  
   102  	parameters := first + strings.Repeat(","+rest, num-1)
   103  
   104  	return strings.Replace(sql, placeholderSQL, parameters, 1)
   105  }
   106  
   107  // getStmt creates and caches sql.Stmt structs based on the passed in statement
   108  // and number of bound arguments.
   109  // TODO(al,martin): consider pulling this all out as a separate unit for reuse
   110  // elsewhere.
   111  func (m *mySQLTreeStorage) getStmt(ctx context.Context, statement string, num int, first, rest string) (*sql.Stmt, error) {
   112  	m.statementMutex.Lock()
   113  	defer m.statementMutex.Unlock()
   114  
   115  	if m.statements[statement] != nil {
   116  		if m.statements[statement][num] != nil {
   117  			// TODO(al,martin): we'll possibly need to expire Stmts from the cache,
   118  			// e.g. when DB connections break etc.
   119  			return m.statements[statement][num], nil
   120  		}
   121  	} else {
   122  		m.statements[statement] = make(map[int]*sql.Stmt)
   123  	}
   124  
   125  	s, err := m.db.PrepareContext(ctx, expandPlaceholderSQL(statement, num, first, rest))
   126  
   127  	if err != nil {
   128  		glog.Warningf("Failed to prepare statement %d: %s", num, err)
   129  		return nil, err
   130  	}
   131  
   132  	m.statements[statement][num] = s
   133  
   134  	return s, nil
   135  }
   136  
   137  func (m *mySQLTreeStorage) getSubtreeStmt(ctx context.Context, num int) (*sql.Stmt, error) {
   138  	return m.getStmt(ctx, selectSubtreeSQL, num, "?", "?")
   139  }
   140  
   141  func (m *mySQLTreeStorage) setSubtreeStmt(ctx context.Context, num int) (*sql.Stmt, error) {
   142  	return m.getStmt(ctx, insertSubtreeMultiSQL, num, "VALUES(?, ?, ?, ?)", "(?, ?, ?, ?)")
   143  }
   144  
   145  func (m *mySQLTreeStorage) beginTreeTx(ctx context.Context, tree *trillian.Tree, hashSizeBytes int, subtreeCache cache.SubtreeCache) (treeTX, error) {
   146  	t, err := m.db.BeginTx(ctx, nil /* opts */)
   147  	if err != nil {
   148  		glog.Warningf("Could not start tree TX: %s", err)
   149  		return treeTX{}, err
   150  	}
   151  	return treeTX{
   152  		tx:            t,
   153  		ts:            m,
   154  		treeID:        tree.TreeId,
   155  		treeType:      tree.TreeType,
   156  		hashSizeBytes: hashSizeBytes,
   157  		subtreeCache:  subtreeCache,
   158  		writeRevision: -1,
   159  	}, nil
   160  }
   161  
// treeTX holds the state of one tree-scoped database transaction, as created by
// beginTreeTx.
type treeTX struct {
	// closed is set by Commit and Rollback; Close and IsOpen consult it.
	closed        bool
	// tx is the underlying SQL transaction.
	tx            *sql.Tx
	// ts is the storage that created this transaction; it supplies the cached
	// prepared statements.
	ts            *mySQLTreeStorage
	// treeID is bound as the TreeId parameter of the queries in this file.
	treeID        int64
	// treeType is copied from the tree passed to beginTreeTx.
	treeType      trillian.TreeType
	// hashSizeBytes is stored as passed to beginTreeTx.
	hashSizeBytes int
	// subtreeCache mediates node reads and writes; Commit flushes it to the
	// database when writeRevision is set.
	subtreeCache  cache.SubtreeCache
	// writeRevision is the SubtreeRevision stored with written subtrees. It starts
	// at -1, and Commit only flushes the cache when it is greater than -1.
	writeRevision int64
}
   172  
   173  func (t *treeTX) getSubtree(ctx context.Context, treeRevision int64, nodeID storage.NodeID) (*storagepb.SubtreeProto, error) {
   174  	s, err := t.getSubtrees(ctx, treeRevision, []storage.NodeID{nodeID})
   175  	if err != nil {
   176  		return nil, err
   177  	}
   178  	switch len(s) {
   179  	case 0:
   180  		return nil, nil
   181  	case 1:
   182  		return s[0], nil
   183  	default:
   184  		return nil, fmt.Errorf("got %d subtrees, but expected 1", len(s))
   185  	}
   186  }
   187  
   188  func (t *treeTX) getSubtrees(ctx context.Context, treeRevision int64, nodeIDs []storage.NodeID) ([]*storagepb.SubtreeProto, error) {
   189  	glog.V(4).Infof("getSubtrees(")
   190  	if len(nodeIDs) == 0 {
   191  		return nil, nil
   192  	}
   193  
   194  	tmpl, err := t.ts.getSubtreeStmt(ctx, len(nodeIDs))
   195  	if err != nil {
   196  		return nil, err
   197  	}
   198  	stx := t.tx.StmtContext(ctx, tmpl)
   199  	defer stx.Close()
   200  
   201  	args := make([]interface{}, 0, len(nodeIDs)+3)
   202  
   203  	// populate args with nodeIDs
   204  	for _, nodeID := range nodeIDs {
   205  		if nodeID.PrefixLenBits%8 != 0 {
   206  			return nil, fmt.Errorf("invalid subtree ID - not multiple of 8: %d", nodeID.PrefixLenBits)
   207  		}
   208  
   209  		nodeIDBytes := nodeID.Path[:nodeID.PrefixLenBits/8]
   210  		glog.V(4).Infof("  nodeID: %x", nodeIDBytes)
   211  
   212  		args = append(args, interface{}(nodeIDBytes))
   213  	}
   214  
   215  	args = append(args, interface{}(t.treeID))
   216  	args = append(args, interface{}(treeRevision))
   217  	args = append(args, interface{}(t.treeID))
   218  
   219  	rows, err := stx.QueryContext(ctx, args...)
   220  	if err != nil {
   221  		glog.Warningf("Failed to get merkle subtrees: %s", err)
   222  		return nil, err
   223  	}
   224  	defer rows.Close()
   225  
   226  	if rows.Err() != nil {
   227  		// Nothing from the DB
   228  		glog.Warningf("Nothing from DB: %s", rows.Err())
   229  		return nil, rows.Err()
   230  	}
   231  
   232  	ret := make([]*storagepb.SubtreeProto, 0, len(nodeIDs))
   233  
   234  	for rows.Next() {
   235  
   236  		var subtreeIDBytes []byte
   237  		var subtreeRev int64
   238  		var nodesRaw []byte
   239  		if err := rows.Scan(&subtreeIDBytes, &subtreeRev, &nodesRaw); err != nil {
   240  			glog.Warningf("Failed to scan merkle subtree: %s", err)
   241  			return nil, err
   242  		}
   243  		var subtree storagepb.SubtreeProto
   244  		if err := proto.Unmarshal(nodesRaw, &subtree); err != nil {
   245  			glog.Warningf("Failed to unmarshal SubtreeProto: %s", err)
   246  			return nil, err
   247  		}
   248  		if subtree.Prefix == nil {
   249  			subtree.Prefix = []byte{}
   250  		}
   251  		ret = append(ret, &subtree)
   252  
   253  		if glog.V(4) {
   254  			glog.Infof("  subtree: NID: %x, prefix: %x, depth: %d",
   255  				subtreeIDBytes, subtree.Prefix, subtree.Depth)
   256  			for k, v := range subtree.Leaves {
   257  				b, err := base64.StdEncoding.DecodeString(k)
   258  				if err != nil {
   259  					glog.Errorf("base64.DecodeString(%v): %v", k, err)
   260  				}
   261  				glog.Infof("     %x: %x", b, v)
   262  			}
   263  		}
   264  	}
   265  
   266  	// The InternalNodes cache is possibly nil here, but the SubtreeCache (which called
   267  	// this method) will re-populate it.
   268  	return ret, nil
   269  }
   270  
   271  func (t *treeTX) storeSubtrees(ctx context.Context, subtrees []*storagepb.SubtreeProto) error {
   272  	if glog.V(4) {
   273  		glog.Infof("storeSubtrees(")
   274  		for _, s := range subtrees {
   275  			glog.Infof("  prefix: %x, depth: %d", s.Prefix, s.Depth)
   276  			for k, v := range s.Leaves {
   277  				b, err := base64.StdEncoding.DecodeString(k)
   278  				if err != nil {
   279  					glog.Errorf("base64.DecodeString(%v): %v", k, err)
   280  				}
   281  				glog.Infof("     %x: %x", b, v)
   282  			}
   283  		}
   284  	}
   285  	if len(subtrees) == 0 {
   286  		glog.Warning("attempted to store 0 subtrees...")
   287  		return nil
   288  	}
   289  
   290  	// TODO(al): probably need to be able to batch this in the case where we have
   291  	// a really large number of subtrees to store.
   292  	args := make([]interface{}, 0, len(subtrees))
   293  
   294  	for _, s := range subtrees {
   295  		s := s
   296  		if s.Prefix == nil {
   297  			panic(fmt.Errorf("nil prefix on %v", s))
   298  		}
   299  		subtreeBytes, err := proto.Marshal(s)
   300  		if err != nil {
   301  			return err
   302  		}
   303  		args = append(args, t.treeID)
   304  		args = append(args, s.Prefix)
   305  		args = append(args, subtreeBytes)
   306  		args = append(args, t.writeRevision)
   307  	}
   308  
   309  	tmpl, err := t.ts.setSubtreeStmt(ctx, len(subtrees))
   310  	if err != nil {
   311  		return err
   312  	}
   313  	stx := t.tx.StmtContext(ctx, tmpl)
   314  	defer stx.Close()
   315  
   316  	r, err := stx.ExecContext(ctx, args...)
   317  	if err != nil {
   318  		glog.Warningf("Failed to set merkle subtrees: %s", err)
   319  		return err
   320  	}
   321  	_, _ = r.RowsAffected()
   322  	return nil
   323  }
   324  
   325  func checkResultOkAndRowCountIs(res sql.Result, err error, count int64) error {
   326  	// The Exec() might have just failed
   327  	if err != nil {
   328  		return err
   329  	}
   330  
   331  	// Otherwise we have to look at the result of the operation
   332  	rowsAffected, rowsError := res.RowsAffected()
   333  
   334  	if rowsError != nil {
   335  		return rowsError
   336  	}
   337  
   338  	if rowsAffected != count {
   339  		return fmt.Errorf("Expected %d row(s) to be affected but saw: %d", count,
   340  			rowsAffected)
   341  	}
   342  
   343  	return nil
   344  }
   345  
   346  // GetTreeRevisionAtSize returns the max node version for a tree at a particular size.
   347  // It is an error to request tree sizes larger than the currently published tree size.
   348  // For an inexact tree size this implementation always returns the next largest revision if an
   349  // exact one does not exist but it isn't required to do so.
   350  func (t *treeTX) GetTreeRevisionIncludingSize(ctx context.Context, treeSize int64) (int64, int64, error) {
   351  	// Negative size is not sensible and a zero sized tree has no nodes so no revisions
   352  	if treeSize <= 0 {
   353  		return 0, 0, fmt.Errorf("invalid tree size: %d", treeSize)
   354  	}
   355  
   356  	var treeRevision, actualTreeSize int64
   357  	err := t.tx.QueryRowContext(ctx, selectTreeRevisionAtSizeOrLargerSQL, t.treeID, treeSize).Scan(&treeRevision, &actualTreeSize)
   358  
   359  	return treeRevision, actualTreeSize, err
   360  }
   361  
   362  // getSubtreesAtRev returns a GetSubtreesFunc which reads at the passed in rev.
   363  func (t *treeTX) getSubtreesAtRev(ctx context.Context, rev int64) cache.GetSubtreesFunc {
   364  	return func(ids []storage.NodeID) ([]*storagepb.SubtreeProto, error) {
   365  		return t.getSubtrees(ctx, rev, ids)
   366  	}
   367  }
   368  
   369  // GetMerkleNodes returns the requests nodes at (or below) the passed in treeRevision.
   370  func (t *treeTX) GetMerkleNodes(ctx context.Context, treeRevision int64, nodeIDs []storage.NodeID) ([]storage.Node, error) {
   371  	return t.subtreeCache.GetNodes(nodeIDs, t.getSubtreesAtRev(ctx, treeRevision))
   372  }
   373  
   374  func (t *treeTX) SetMerkleNodes(ctx context.Context, nodes []storage.Node) error {
   375  	for _, n := range nodes {
   376  		err := t.subtreeCache.SetNodeHash(n.NodeID, n.Hash,
   377  			func(nID storage.NodeID) (*storagepb.SubtreeProto, error) {
   378  				return t.getSubtree(ctx, t.writeRevision, nID)
   379  			})
   380  		if err != nil {
   381  			return err
   382  		}
   383  	}
   384  	return nil
   385  }
   386  
   387  func (t *treeTX) Commit() error {
   388  	if t.writeRevision > -1 {
   389  		if err := t.subtreeCache.Flush(func(st []*storagepb.SubtreeProto) error {
   390  			return t.storeSubtrees(context.TODO(), st)
   391  		}); err != nil {
   392  			glog.Warningf("TX commit flush error: %v", err)
   393  			return err
   394  		}
   395  	}
   396  	t.closed = true
   397  	if err := t.tx.Commit(); err != nil {
   398  		glog.Warningf("TX commit error: %s, stack:\n%s", err, string(debug.Stack()))
   399  		return err
   400  	}
   401  	return nil
   402  }
   403  
   404  func (t *treeTX) Rollback() error {
   405  	t.closed = true
   406  	if err := t.tx.Rollback(); err != nil {
   407  		glog.Warningf("TX rollback error: %s, stack:\n%s", err, string(debug.Stack()))
   408  		return err
   409  	}
   410  	return nil
   411  }
   412  
   413  func (t *treeTX) Close() error {
   414  	if !t.closed {
   415  		err := t.Rollback()
   416  		if err != nil {
   417  			glog.Warningf("Rollback error on Close(): %v", err)
   418  		}
   419  		return err
   420  	}
   421  	return nil
   422  }
   423  
   424  func (t *treeTX) IsOpen() bool {
   425  	return !t.closed
   426  }