github.com/bartle-stripe/trillian@v1.2.1/storage/mysql/admin_storage.go (about)

     1  // Copyright 2017 Google Inc. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package mysql
    16  
    17  import (
    18  	"context"
    19  	"database/sql"
    20  	"fmt"
    21  	"sync"
    22  	"time"
    23  
    24  	"github.com/golang/glog"
    25  	"github.com/golang/protobuf/proto"
    26  	"github.com/golang/protobuf/ptypes"
    27  	"github.com/golang/protobuf/ptypes/any"
    28  	"github.com/google/trillian"
    29  	"github.com/google/trillian/crypto/keyspb"
    30  	spb "github.com/google/trillian/crypto/sigpb"
    31  	"github.com/google/trillian/storage"
    32  	"google.golang.org/grpc/codes"
    33  	"google.golang.org/grpc/status"
    34  )
    35  
const (
	// defaultSequenceIntervalSeconds is the SequenceIntervalSeconds value
	// written to the TreeControl row of every newly created tree.
	defaultSequenceIntervalSeconds = 60

	// nonDeletedWhere filters a Trees query down to trees that are not soft
	// deleted. Deleted may be NULL; NULL and 'false' both count as not deleted.
	nonDeletedWhere = " WHERE (Deleted IS NULL OR Deleted = 'false')"

	// Tree ID listings: all trees, or only the non-deleted ones.
	selectTreeIDs           = "SELECT TreeId FROM Trees"
	selectNonDeletedTreeIDs = selectTreeIDs + nonDeletedWhere

	// selectTrees reads every column consumed by readTree. The column order
	// here must match the order of the Scan destinations in readTree.
	selectTrees = `
		SELECT
			TreeId,
			TreeState,
			TreeType,
			HashStrategy,
			HashAlgorithm,
			SignatureAlgorithm,
			DisplayName,
			Description,
			CreateTimeMillis,
			UpdateTimeMillis,
			PrivateKey,
			PublicKey,
			MaxRootDurationMillis,
			Deleted,
			DeleteTimeMillis
		FROM Trees`
	selectNonDeletedTrees = selectTrees + nonDeletedWhere
	selectTreeByID        = selectTrees + " WHERE TreeId = ?"

	// updateTreeSQL overwrites the mutable columns of one tree. The
	// placeholder order must match the arguments UpdateTree passes to Exec.
	updateTreeSQL = `UPDATE Trees
		SET TreeState = ?, TreeType = ?, DisplayName = ?, Description = ?, UpdateTimeMillis = ?, MaxRootDurationMillis = ?, PrivateKey = ?
		WHERE TreeId = ?`
)
    69  
    70  // NewAdminStorage returns a MySQL storage.AdminStorage implementation backed by DB.
    71  func NewAdminStorage(db *sql.DB) storage.AdminStorage {
    72  	return &mysqlAdminStorage{db}
    73  }
    74  
    75  // mysqlAdminStorage implements storage.AdminStorage
    76  type mysqlAdminStorage struct {
    77  	db *sql.DB
    78  }
    79  
    80  func (s *mysqlAdminStorage) Snapshot(ctx context.Context) (storage.ReadOnlyAdminTX, error) {
    81  	return s.beginInternal(ctx)
    82  }
    83  
    84  func (s *mysqlAdminStorage) beginInternal(ctx context.Context) (storage.AdminTX, error) {
    85  	tx, err := s.db.BeginTx(ctx, nil /* opts */)
    86  	if err != nil {
    87  		return nil, err
    88  	}
    89  	return &adminTX{tx: tx}, nil
    90  }
    91  
    92  func (s *mysqlAdminStorage) ReadWriteTransaction(ctx context.Context, f storage.AdminTXFunc) error {
    93  	tx, err := s.beginInternal(ctx)
    94  	if err != nil {
    95  		return err
    96  	}
    97  	defer tx.Close()
    98  	if err := f(ctx, tx); err != nil {
    99  		return err
   100  	}
   101  	return tx.Commit()
   102  }
   103  
   104  func (s *mysqlAdminStorage) CheckDatabaseAccessible(ctx context.Context) error {
   105  	return s.db.PingContext(ctx)
   106  }
   107  
// adminTX is the storage.AdminTX handed out by mysqlAdminStorage. It wraps a
// single *sql.Tx; closed records whether Commit or Rollback has been called.
type adminTX struct {
	tx *sql.Tx

	// mu guards *direct* reads/writes on closed, which happen only on
	// Commit/Rollback/IsClosed/Close methods.
	// We don't check closed on *all* methods (apart from the ones above),
	// as we trust tx to keep tabs on its state (and consequently fail to do
	// queries after closed).
	mu     sync.RWMutex
	closed bool
}
   119  
   120  func (t *adminTX) Commit() error {
   121  	t.mu.Lock()
   122  	defer t.mu.Unlock()
   123  	t.closed = true
   124  	return t.tx.Commit()
   125  }
   126  
   127  func (t *adminTX) Rollback() error {
   128  	t.mu.Lock()
   129  	defer t.mu.Unlock()
   130  	t.closed = true
   131  	return t.tx.Rollback()
   132  }
   133  
   134  func (t *adminTX) IsClosed() bool {
   135  	t.mu.RLock()
   136  	defer t.mu.RUnlock()
   137  	return t.closed
   138  }
   139  
   140  func (t *adminTX) Close() error {
   141  	// Acquire and release read lock manually, without defer, as if the txn
   142  	// is not closed Rollback() will attempt to acquire the rw lock.
   143  	t.mu.RLock()
   144  	closed := t.closed
   145  	t.mu.RUnlock()
   146  	if !closed {
   147  		err := t.Rollback()
   148  		if err != nil {
   149  			glog.Warningf("Rollback error on Close(): %v", err)
   150  		}
   151  		return err
   152  	}
   153  	return nil
   154  }
   155  
   156  func (t *adminTX) GetTree(ctx context.Context, treeID int64) (*trillian.Tree, error) {
   157  	stmt, err := t.tx.PrepareContext(ctx, selectTreeByID)
   158  	if err != nil {
   159  		return nil, err
   160  	}
   161  	defer stmt.Close()
   162  
   163  	// GetTree is an entry point for most RPCs, let's provide somewhat nicer error messages.
   164  	tree, err := readTree(stmt.QueryRowContext(ctx, treeID))
   165  	switch {
   166  	case err == sql.ErrNoRows:
   167  		// ErrNoRows doesn't provide useful information, so we don't forward it.
   168  		return nil, status.Errorf(codes.NotFound, "tree %v not found", treeID)
   169  	case err != nil:
   170  		return nil, fmt.Errorf("error reading tree %v: %v", treeID, err)
   171  	}
   172  	return tree, nil
   173  }
   174  
// row is the Scan method shared by *sql.Row and *sql.Rows, which have no
// common interface in database/sql. readTree takes a row so it can decode
// either a single-row query result or the current row of a multi-row query.
type row interface {
	Scan(dest ...interface{}) error
}
   180  
   181  func readTree(row row) (*trillian.Tree, error) {
   182  	tree := &trillian.Tree{}
   183  
   184  	// Enums and Datetimes need an extra conversion step
   185  	var treeState, treeType, hashStrategy, hashAlgorithm, signatureAlgorithm string
   186  	var createMillis, updateMillis, maxRootDurationMillis int64
   187  	var displayName, description sql.NullString
   188  	var privateKey, publicKey []byte
   189  	var deleted sql.NullBool
   190  	var deleteMillis sql.NullInt64
   191  	err := row.Scan(
   192  		&tree.TreeId,
   193  		&treeState,
   194  		&treeType,
   195  		&hashStrategy,
   196  		&hashAlgorithm,
   197  		&signatureAlgorithm,
   198  		&displayName,
   199  		&description,
   200  		&createMillis,
   201  		&updateMillis,
   202  		&privateKey,
   203  		&publicKey,
   204  		&maxRootDurationMillis,
   205  		&deleted,
   206  		&deleteMillis,
   207  	)
   208  	if err != nil {
   209  		return nil, err
   210  	}
   211  
   212  	setNullStringIfValid(displayName, &tree.DisplayName)
   213  	setNullStringIfValid(description, &tree.Description)
   214  
   215  	// Convert all things!
   216  	if ts, ok := trillian.TreeState_value[treeState]; ok {
   217  		tree.TreeState = trillian.TreeState(ts)
   218  	} else {
   219  		return nil, fmt.Errorf("unknown TreeState: %v", treeState)
   220  	}
   221  	if tt, ok := trillian.TreeType_value[treeType]; ok {
   222  		tree.TreeType = trillian.TreeType(tt)
   223  	} else {
   224  		return nil, fmt.Errorf("unknown TreeType: %v", treeType)
   225  	}
   226  	if hs, ok := trillian.HashStrategy_value[hashStrategy]; ok {
   227  		tree.HashStrategy = trillian.HashStrategy(hs)
   228  	} else {
   229  		return nil, fmt.Errorf("unknown HashStrategy: %v", hashStrategy)
   230  	}
   231  	if ha, ok := spb.DigitallySigned_HashAlgorithm_value[hashAlgorithm]; ok {
   232  		tree.HashAlgorithm = spb.DigitallySigned_HashAlgorithm(ha)
   233  	} else {
   234  		return nil, fmt.Errorf("unknown HashAlgorithm: %v", hashAlgorithm)
   235  	}
   236  	if sa, ok := spb.DigitallySigned_SignatureAlgorithm_value[signatureAlgorithm]; ok {
   237  		tree.SignatureAlgorithm = spb.DigitallySigned_SignatureAlgorithm(sa)
   238  	} else {
   239  		return nil, fmt.Errorf("unknown SignatureAlgorithm: %v", signatureAlgorithm)
   240  	}
   241  
   242  	// Let's make sure we didn't mismatch any of the casts above
   243  	ok := tree.TreeState.String() == treeState
   244  	ok = ok && tree.TreeType.String() == treeType
   245  	ok = ok && tree.HashStrategy.String() == hashStrategy
   246  	ok = ok && tree.HashAlgorithm.String() == hashAlgorithm
   247  	ok = ok && tree.SignatureAlgorithm.String() == signatureAlgorithm
   248  	if !ok {
   249  		return nil, fmt.Errorf(
   250  			"mismatched enum: tree = %v, enums = [%v, %v, %v, %v, %v]",
   251  			tree,
   252  			treeState, treeType, hashStrategy, hashAlgorithm, signatureAlgorithm)
   253  	}
   254  
   255  	tree.CreateTime, err = ptypes.TimestampProto(fromMillisSinceEpoch(createMillis))
   256  	if err != nil {
   257  		return nil, fmt.Errorf("failed to parse create time: %v", err)
   258  	}
   259  	tree.UpdateTime, err = ptypes.TimestampProto(fromMillisSinceEpoch(updateMillis))
   260  	if err != nil {
   261  		return nil, fmt.Errorf("failed to parse update time: %v", err)
   262  	}
   263  	tree.MaxRootDuration = ptypes.DurationProto(time.Duration(maxRootDurationMillis * int64(time.Millisecond)))
   264  
   265  	tree.PrivateKey = &any.Any{}
   266  	if err := proto.Unmarshal(privateKey, tree.PrivateKey); err != nil {
   267  		return nil, fmt.Errorf("could not unmarshal PrivateKey: %v", err)
   268  	}
   269  	tree.PublicKey = &keyspb.PublicKey{Der: publicKey}
   270  
   271  	tree.Deleted = deleted.Valid && deleted.Bool
   272  	if tree.Deleted && deleteMillis.Valid {
   273  		tree.DeleteTime, err = ptypes.TimestampProto(fromMillisSinceEpoch(deleteMillis.Int64))
   274  		if err != nil {
   275  			return nil, fmt.Errorf("failed to parse delete time: %v", err)
   276  		}
   277  	}
   278  
   279  	return tree, nil
   280  }
   281  
   282  // setNullStringIfValid assigns src to dest if src is Valid.
   283  func setNullStringIfValid(src sql.NullString, dest *string) {
   284  	if src.Valid {
   285  		*dest = src.String
   286  	}
   287  }
   288  
   289  func (t *adminTX) ListTreeIDs(ctx context.Context, includeDeleted bool) ([]int64, error) {
   290  	var query string
   291  	if includeDeleted {
   292  		query = selectTreeIDs
   293  	} else {
   294  		query = selectNonDeletedTreeIDs
   295  	}
   296  
   297  	stmt, err := t.tx.PrepareContext(ctx, query)
   298  	if err != nil {
   299  		return nil, err
   300  	}
   301  	defer stmt.Close()
   302  
   303  	rows, err := stmt.QueryContext(ctx)
   304  	if err != nil {
   305  		return nil, err
   306  	}
   307  	defer rows.Close()
   308  
   309  	treeIDs := []int64{}
   310  	var treeID int64
   311  	for rows.Next() {
   312  		if err := rows.Scan(&treeID); err != nil {
   313  			return nil, err
   314  		}
   315  		treeIDs = append(treeIDs, treeID)
   316  	}
   317  	return treeIDs, nil
   318  }
   319  
   320  func (t *adminTX) ListTrees(ctx context.Context, includeDeleted bool) ([]*trillian.Tree, error) {
   321  	var query string
   322  	if includeDeleted {
   323  		query = selectTrees
   324  	} else {
   325  		query = selectNonDeletedTrees
   326  	}
   327  
   328  	stmt, err := t.tx.PrepareContext(ctx, query)
   329  	if err != nil {
   330  		return nil, err
   331  	}
   332  	defer stmt.Close()
   333  	rows, err := stmt.QueryContext(ctx)
   334  	if err != nil {
   335  		return nil, err
   336  	}
   337  	defer rows.Close()
   338  	trees := []*trillian.Tree{}
   339  	for rows.Next() {
   340  		tree, err := readTree(rows)
   341  		if err != nil {
   342  			return nil, err
   343  		}
   344  		trees = append(trees, tree)
   345  	}
   346  	return trees, nil
   347  }
   348  
   349  func (t *adminTX) CreateTree(ctx context.Context, tree *trillian.Tree) (*trillian.Tree, error) {
   350  	if err := storage.ValidateTreeForCreation(ctx, tree); err != nil {
   351  		return nil, err
   352  	}
   353  	if err := validateStorageSettings(tree); err != nil {
   354  		return nil, err
   355  	}
   356  
   357  	id, err := storage.NewTreeID()
   358  	if err != nil {
   359  		return nil, err
   360  	}
   361  
   362  	// Use the time truncated-to-millis throughout, as that's what's stored.
   363  	nowMillis := toMillisSinceEpoch(time.Now())
   364  	now := fromMillisSinceEpoch(nowMillis)
   365  
   366  	newTree := *tree
   367  	newTree.TreeId = id
   368  	newTree.CreateTime, err = ptypes.TimestampProto(now)
   369  	if err != nil {
   370  		return nil, fmt.Errorf("failed to build create time: %v", err)
   371  	}
   372  	newTree.UpdateTime, err = ptypes.TimestampProto(now)
   373  	if err != nil {
   374  		return nil, fmt.Errorf("failed to build update time: %v", err)
   375  	}
   376  	rootDuration, err := ptypes.Duration(newTree.MaxRootDuration)
   377  	if err != nil {
   378  		return nil, fmt.Errorf("could not parse MaxRootDuration: %v", err)
   379  	}
   380  
   381  	insertTreeStmt, err := t.tx.PrepareContext(
   382  		ctx,
   383  		`INSERT INTO Trees(
   384  			TreeId,
   385  			TreeState,
   386  			TreeType,
   387  			HashStrategy,
   388  			HashAlgorithm,
   389  			SignatureAlgorithm,
   390  			DisplayName,
   391  			Description,
   392  			CreateTimeMillis,
   393  			UpdateTimeMillis,
   394  			PrivateKey,
   395  			PublicKey,
   396  			MaxRootDurationMillis)
   397  		VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`)
   398  	if err != nil {
   399  		return nil, err
   400  	}
   401  	defer insertTreeStmt.Close()
   402  
   403  	privateKey, err := proto.Marshal(newTree.PrivateKey)
   404  	if err != nil {
   405  		return nil, fmt.Errorf("could not marshal PrivateKey: %v", err)
   406  	}
   407  
   408  	_, err = insertTreeStmt.ExecContext(
   409  		ctx,
   410  		newTree.TreeId,
   411  		newTree.TreeState.String(),
   412  		newTree.TreeType.String(),
   413  		newTree.HashStrategy.String(),
   414  		newTree.HashAlgorithm.String(),
   415  		newTree.SignatureAlgorithm.String(),
   416  		newTree.DisplayName,
   417  		newTree.Description,
   418  		nowMillis,
   419  		nowMillis,
   420  		privateKey,
   421  		newTree.PublicKey.GetDer(),
   422  		rootDuration/time.Millisecond,
   423  	)
   424  	if err != nil {
   425  		return nil, err
   426  	}
   427  
   428  	// MySQL silently truncates data when running in non-strict mode.
   429  	// We shouldn't be using non-strict modes, but let's guard against it
   430  	// anyway.
   431  	if _, err := t.GetTree(ctx, newTree.TreeId); err != nil {
   432  		// GetTree will fail for truncated enums (they get recorded as
   433  		// empty strings, which will not match any known value).
   434  		return nil, fmt.Errorf("enum truncated: %v", err)
   435  	}
   436  
   437  	insertControlStmt, err := t.tx.PrepareContext(
   438  		ctx,
   439  		`INSERT INTO TreeControl(
   440  			TreeId,
   441  			SigningEnabled,
   442  			SequencingEnabled,
   443  			SequenceIntervalSeconds)
   444  		VALUES(?, ?, ?, ?)`)
   445  	if err != nil {
   446  		return nil, err
   447  	}
   448  	defer insertControlStmt.Close()
   449  	_, err = insertControlStmt.ExecContext(
   450  		ctx,
   451  		newTree.TreeId,
   452  		true, /* SigningEnabled */
   453  		true, /* SequencingEnabled */
   454  		defaultSequenceIntervalSeconds,
   455  	)
   456  	if err != nil {
   457  		return nil, err
   458  	}
   459  
   460  	return &newTree, nil
   461  }
   462  
   463  func (t *adminTX) UpdateTree(ctx context.Context, treeID int64, updateFunc func(*trillian.Tree)) (*trillian.Tree, error) {
   464  	tree, err := t.GetTree(ctx, treeID)
   465  	if err != nil {
   466  		return nil, err
   467  	}
   468  
   469  	beforeUpdate := *tree
   470  	updateFunc(tree)
   471  	if err := storage.ValidateTreeForUpdate(ctx, &beforeUpdate, tree); err != nil {
   472  		return nil, err
   473  	}
   474  	if err := validateStorageSettings(tree); err != nil {
   475  		return nil, err
   476  	}
   477  
   478  	// TODO(pavelkalinnikov): When switching TreeType from PREORDERED_LOG to LOG,
   479  	// ensure all entries in SequencedLeafData are integrated.
   480  
   481  	// Use the time truncated-to-millis throughout, as that's what's stored.
   482  	nowMillis := toMillisSinceEpoch(time.Now())
   483  	now := fromMillisSinceEpoch(nowMillis)
   484  	tree.UpdateTime, err = ptypes.TimestampProto(now)
   485  	if err != nil {
   486  		return nil, fmt.Errorf("failed to build update time: %v", err)
   487  	}
   488  	rootDuration, err := ptypes.Duration(tree.MaxRootDuration)
   489  	if err != nil {
   490  		return nil, fmt.Errorf("could not parse MaxRootDuration: %v", err)
   491  	}
   492  
   493  	privateKey, err := proto.Marshal(tree.PrivateKey)
   494  	if err != nil {
   495  		return nil, fmt.Errorf("could not marshal PrivateKey: %v", err)
   496  	}
   497  
   498  	stmt, err := t.tx.PrepareContext(ctx, updateTreeSQL)
   499  	if err != nil {
   500  		return nil, err
   501  	}
   502  	defer stmt.Close()
   503  
   504  	if _, err = stmt.ExecContext(
   505  		ctx,
   506  		tree.TreeState.String(),
   507  		tree.TreeType.String(),
   508  		tree.DisplayName,
   509  		tree.Description,
   510  		nowMillis,
   511  		rootDuration/time.Millisecond,
   512  		privateKey,
   513  		tree.TreeId); err != nil {
   514  		return nil, err
   515  	}
   516  
   517  	return tree, nil
   518  }
   519  
   520  func (t *adminTX) SoftDeleteTree(ctx context.Context, treeID int64) (*trillian.Tree, error) {
   521  	return t.updateDeleted(ctx, treeID, true /* deleted */, toMillisSinceEpoch(time.Now()) /* deleteTimeMillis */)
   522  }
   523  
   524  func (t *adminTX) UndeleteTree(ctx context.Context, treeID int64) (*trillian.Tree, error) {
   525  	return t.updateDeleted(ctx, treeID, false /* deleted */, nil /* deleteTimeMillis */)
   526  }
   527  
   528  // updateDeleted updates the Deleted and DeleteTimeMillis fields of the specified tree.
   529  // deleteTimeMillis must be either an int64 (in millis since epoch) or nil.
   530  func (t *adminTX) updateDeleted(ctx context.Context, treeID int64, deleted bool, deleteTimeMillis interface{}) (*trillian.Tree, error) {
   531  	if err := validateDeleted(ctx, t.tx, treeID, !deleted); err != nil {
   532  		return nil, err
   533  	}
   534  	if _, err := t.tx.ExecContext(
   535  		ctx,
   536  		"UPDATE Trees SET Deleted = ?, DeleteTimeMillis = ? WHERE TreeId = ?",
   537  		deleted, deleteTimeMillis, treeID); err != nil {
   538  		return nil, err
   539  	}
   540  	return t.GetTree(ctx, treeID)
   541  }
   542  
   543  func (t *adminTX) HardDeleteTree(ctx context.Context, treeID int64) error {
   544  	if err := validateDeleted(ctx, t.tx, treeID, true /* wantDeleted */); err != nil {
   545  		return err
   546  	}
   547  
   548  	// TreeControl didn't have "ON DELETE CASCADE" on previous versions, so let's hit it explicitly
   549  	if _, err := t.tx.ExecContext(ctx, "DELETE FROM TreeControl WHERE TreeId = ?", treeID); err != nil {
   550  		return err
   551  	}
   552  	_, err := t.tx.ExecContext(ctx, "DELETE FROM Trees WHERE TreeId = ?", treeID)
   553  	return err
   554  }
   555  
   556  func validateDeleted(ctx context.Context, tx *sql.Tx, treeID int64, wantDeleted bool) error {
   557  	var nullDeleted sql.NullBool
   558  	switch err := tx.QueryRowContext(ctx, "SELECT Deleted FROM Trees WHERE TreeId = ?", treeID).Scan(&nullDeleted); {
   559  	case err == sql.ErrNoRows:
   560  		return status.Errorf(codes.NotFound, "tree %v not found", treeID)
   561  	case err != nil:
   562  		return err
   563  	}
   564  
   565  	switch deleted := nullDeleted.Valid && nullDeleted.Bool; {
   566  	case wantDeleted && !deleted:
   567  		return status.Errorf(codes.FailedPrecondition, "tree %v is not soft deleted", treeID)
   568  	case !wantDeleted && deleted:
   569  		return status.Errorf(codes.FailedPrecondition, "tree %v already soft deleted", treeID)
   570  	}
   571  	return nil
   572  }
   573  
   574  func toMillisSinceEpoch(t time.Time) int64 {
   575  	return t.UnixNano() / 1000000
   576  }
   577  
   578  func fromMillisSinceEpoch(ts int64) time.Time {
   579  	secs := int64(ts / 1000)
   580  	msecs := int64(ts % 1000)
   581  	return time.Unix(secs, msecs*1000000)
   582  }
   583  
   584  func validateStorageSettings(tree *trillian.Tree) error {
   585  	if tree.StorageSettings != nil {
   586  		return fmt.Errorf("storage_settings not supported, but got %v", tree.StorageSettings)
   587  	}
   588  	return nil
   589  }