github.com/bartle-stripe/trillian@v1.2.1/storage/memory/log_storage.go (about)

     1  // Copyright 2017 Google Inc. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package memory
    16  
    17  import (
    18  	"container/list"
    19  	"context"
    20  	"fmt"
    21  	"math"
    22  	"strconv"
    23  	"sync"
    24  	"time"
    25  
    26  	"github.com/google/btree"
    27  	"github.com/google/trillian"
    28  	"github.com/google/trillian/merkle/hashers"
    29  	"github.com/google/trillian/monitoring"
    30  	"github.com/google/trillian/storage"
    31  	"github.com/google/trillian/storage/cache"
    32  	"github.com/google/trillian/types"
    33  	"google.golang.org/grpc/codes"
    34  	"google.golang.org/grpc/status"
    35  )
    36  
    37  const logIDLabel = "logid"
    38  
    39  var (
    40  	defaultLogStrata = []int{8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8}
    41  
    42  	once            sync.Once
    43  	queuedCounter   monitoring.Counter
    44  	dequeuedCounter monitoring.Counter
    45  )
    46  
    47  func createMetrics(mf monitoring.MetricFactory) {
    48  	queuedCounter = mf.NewCounter("mem_queued_leaves", "Number of leaves queued", logIDLabel)
    49  	dequeuedCounter = mf.NewCounter("mem_dequeued_leaves", "Number of leaves dequeued", logIDLabel)
    50  }
    51  
    52  func labelForTX(t *logTreeTX) string {
    53  	return strconv.FormatInt(t.treeID, 10)
    54  }
    55  
    56  // unseqKey formats a key for use in a tree's BTree store.
    57  // The associated Item value will be a list of unsequenced entries.
    58  func unseqKey(treeID int64) btree.Item {
    59  	return &kv{k: fmt.Sprintf("/%d/unseq", treeID)}
    60  }
    61  
    62  // seqLeafKey formats a key for use in a tree's BTree store.
    63  // The associated Item value will be the leaf at the given sequence number.
    64  func seqLeafKey(treeID, seq int64) btree.Item {
    65  	return &kv{k: fmt.Sprintf("/%d/seq/%020d", treeID, seq)}
    66  }
    67  
    68  // hashToSeqKey formats a key for use in a tree's BTree store.
    69  // The associated Item value will be the sequence number for the leaf with
    70  // the given hash.
    71  func hashToSeqKey(treeID int64) btree.Item {
    72  	return &kv{k: fmt.Sprintf("/%d/h2s", treeID)}
    73  }
    74  
    75  // sthKey formats a key for use in a tree's BTree store.
    76  // The associated Item value will be the STH with the given timestamp.
    77  func sthKey(treeID int64, timestamp uint64) btree.Item {
    78  	return &kv{k: fmt.Sprintf("/%d/sth/%020d", treeID, timestamp)}
    79  }
    80  
    81  type memoryLogStorage struct {
    82  	*memoryTreeStorage
    83  	admin         storage.AdminStorage
    84  	metricFactory monitoring.MetricFactory
    85  }
    86  
    87  // NewLogStorage creates an in-memory LogStorage instance.
    88  func NewLogStorage(mf monitoring.MetricFactory) storage.LogStorage {
    89  	if mf == nil {
    90  		mf = monitoring.InertMetricFactory{}
    91  	}
    92  	ret := &memoryLogStorage{
    93  		memoryTreeStorage: newTreeStorage(),
    94  		metricFactory:     mf,
    95  	}
    96  	ret.admin = NewAdminStorage(ret)
    97  	return ret
    98  }
    99  
   100  func (m *memoryLogStorage) CheckDatabaseAccessible(ctx context.Context) error {
   101  	return nil
   102  }
   103  
   104  type readOnlyLogTX struct {
   105  	ms *memoryTreeStorage
   106  }
   107  
   108  func (m *memoryLogStorage) Snapshot(ctx context.Context) (storage.ReadOnlyLogTX, error) {
   109  	return &readOnlyLogTX{m.memoryTreeStorage}, nil
   110  }
   111  
   112  func (t *readOnlyLogTX) Commit() error {
   113  	return nil
   114  }
   115  
   116  func (t *readOnlyLogTX) Rollback() error {
   117  	return nil
   118  }
   119  
   120  func (t *readOnlyLogTX) Close() error {
   121  	return nil
   122  }
   123  
   124  func (t *readOnlyLogTX) GetActiveLogIDs(ctx context.Context) ([]int64, error) {
   125  	t.ms.mu.RLock()
   126  	defer t.ms.mu.RUnlock()
   127  
   128  	ret := make([]int64, 0, len(t.ms.trees))
   129  	for k := range t.ms.trees {
   130  		ret = append(ret, k)
   131  	}
   132  	return ret, nil
   133  }
   134  
   135  func (m *memoryLogStorage) beginInternal(ctx context.Context, tree *trillian.Tree, readonly bool) (storage.LogTreeTX, error) {
   136  	once.Do(func() {
   137  		createMetrics(m.metricFactory)
   138  	})
   139  	hasher, err := hashers.NewLogHasher(tree.HashStrategy)
   140  	if err != nil {
   141  		return nil, err
   142  	}
   143  
   144  	stCache := cache.NewLogSubtreeCache(defaultLogStrata, hasher)
   145  	ttx, err := m.memoryTreeStorage.beginTreeTX(ctx, tree.TreeId, hasher.Size(), stCache, readonly)
   146  	if err != nil {
   147  		return nil, err
   148  	}
   149  
   150  	ltx := &logTreeTX{
   151  		treeTX: ttx,
   152  		ls:     m,
   153  	}
   154  
   155  	ltx.slr, err = ltx.fetchLatestRoot(ctx)
   156  	if err == storage.ErrTreeNeedsInit {
   157  		return ltx, err
   158  	} else if err != nil {
   159  		ttx.Rollback()
   160  		return nil, err
   161  	}
   162  
   163  	if err := ltx.root.UnmarshalBinary(ltx.slr.LogRoot); err != nil {
   164  		ttx.Rollback()
   165  		return nil, err
   166  	}
   167  
   168  	ltx.treeTX.writeRevision = int64(ltx.root.Revision) + 1
   169  
   170  	return ltx, nil
   171  }
   172  
   173  func (m *memoryLogStorage) ReadWriteTransaction(ctx context.Context, tree *trillian.Tree, f storage.LogTXFunc) error {
   174  	tx, err := m.beginInternal(ctx, tree, false /* readonly */)
   175  	if err != nil && err != storage.ErrTreeNeedsInit {
   176  		return err
   177  	}
   178  	defer tx.Close()
   179  	if err := f(ctx, tx); err != nil {
   180  		return err
   181  	}
   182  	return tx.Commit()
   183  }
   184  
   185  func (m *memoryLogStorage) AddSequencedLeaves(ctx context.Context, tree *trillian.Tree, leaves []*trillian.LogLeaf, timestamp time.Time) ([]*trillian.QueuedLogLeaf, error) {
   186  	return nil, status.Errorf(codes.Unimplemented, "AddSequencedLeaves is not implemented")
   187  }
   188  
   189  func (m *memoryLogStorage) SnapshotForTree(ctx context.Context, tree *trillian.Tree) (storage.ReadOnlyLogTreeTX, error) {
   190  	tx, err := m.beginInternal(ctx, tree, true /* readonly */)
   191  	if err != nil {
   192  		return nil, err
   193  	}
   194  	return tx.(storage.ReadOnlyLogTreeTX), err
   195  }
   196  
   197  func (m *memoryLogStorage) QueueLeaves(ctx context.Context, tree *trillian.Tree, leaves []*trillian.LogLeaf, queueTimestamp time.Time) ([]*trillian.QueuedLogLeaf, error) {
   198  	tx, err := m.beginInternal(ctx, tree, false /* readonly */)
   199  	if err != nil {
   200  		return nil, err
   201  	}
   202  	existing, err := tx.QueueLeaves(ctx, leaves, queueTimestamp)
   203  	if err != nil {
   204  		return nil, err
   205  	}
   206  
   207  	if err := tx.Commit(); err != nil {
   208  		return nil, err
   209  	}
   210  
   211  	ret := make([]*trillian.QueuedLogLeaf, len(leaves))
   212  	for i, e := range existing {
   213  		if e != nil {
   214  			ret[i] = &trillian.QueuedLogLeaf{
   215  				Leaf:   e,
   216  				Status: status.Newf(codes.AlreadyExists, "leaf already exists: %v", e.LeafIdentityHash).Proto(),
   217  			}
   218  			continue
   219  		}
   220  		ret[i] = &trillian.QueuedLogLeaf{Leaf: leaves[i]}
   221  	}
   222  	return ret, nil
   223  }
   224  
   225  type logTreeTX struct {
   226  	treeTX
   227  	ls   *memoryLogStorage
   228  	root types.LogRootV1
   229  	slr  trillian.SignedLogRoot
   230  }
   231  
   232  func (t *logTreeTX) ReadRevision() int64 {
   233  	return int64(t.root.Revision)
   234  }
   235  
   236  func (t *logTreeTX) WriteRevision() int64 {
   237  	return t.treeTX.writeRevision
   238  }
   239  
   240  func (t *logTreeTX) DequeueLeaves(ctx context.Context, limit int, cutoffTime time.Time) ([]*trillian.LogLeaf, error) {
   241  	leaves := make([]*trillian.LogLeaf, 0, limit)
   242  
   243  	q := t.tx.Get(unseqKey(t.treeID)).(*kv).v.(*list.List)
   244  	e := q.Front()
   245  	for i := 0; i < limit && e != nil; i++ {
   246  		// TODO(al): consider cutoffTime
   247  		leaves = append(leaves, e.Value.(*trillian.LogLeaf))
   248  		e = e.Next()
   249  	}
   250  
   251  	dequeuedCounter.Add(float64(len(leaves)), labelForTX(t))
   252  	return leaves, nil
   253  }
   254  
   255  func (t *logTreeTX) QueueLeaves(ctx context.Context, leaves []*trillian.LogLeaf, queueTimestamp time.Time) ([]*trillian.LogLeaf, error) {
   256  	// Don't accept batches if any of the leaves are invalid.
   257  	for _, leaf := range leaves {
   258  		if len(leaf.LeafIdentityHash) != t.hashSizeBytes {
   259  			return nil, fmt.Errorf("queued leaf must have a leaf ID hash of length %d", t.hashSizeBytes)
   260  		}
   261  	}
   262  	queuedCounter.Add(float64(len(leaves)), labelForTX(t))
   263  	// No deduping in this storage!
   264  	k := unseqKey(t.treeID)
   265  	q := t.tx.Get(k).(*kv).v.(*list.List)
   266  	for _, l := range leaves {
   267  		q.PushBack(l)
   268  	}
   269  	return make([]*trillian.LogLeaf, len(leaves)), nil
   270  }
   271  
   272  func (t *logTreeTX) AddSequencedLeaves(ctx context.Context, leaves []*trillian.LogLeaf, timestamp time.Time) ([]*trillian.QueuedLogLeaf, error) {
   273  	return nil, status.Errorf(codes.Unimplemented, "AddSequencedLeaves is not implemented")
   274  }
   275  
   276  func (t *logTreeTX) GetSequencedLeafCount(ctx context.Context) (int64, error) {
   277  	var sequencedLeafCount int64
   278  
   279  	t.tx.DescendRange(seqLeafKey(t.treeID, math.MaxInt64), seqLeafKey(t.treeID, 0), func(i btree.Item) bool {
   280  		sequencedLeafCount = i.(*kv).v.(*trillian.LogLeaf).LeafIndex + 1
   281  		return false
   282  	})
   283  	return sequencedLeafCount, nil
   284  }
   285  
   286  func (t *logTreeTX) GetLeavesByIndex(ctx context.Context, leaves []int64) ([]*trillian.LogLeaf, error) {
   287  	ret := make([]*trillian.LogLeaf, 0, len(leaves))
   288  	for _, seq := range leaves {
   289  		leaf := t.tx.Get(seqLeafKey(t.treeID, seq))
   290  		if leaf != nil {
   291  			ret = append(ret, leaf.(*kv).v.(*trillian.LogLeaf))
   292  		}
   293  	}
   294  	return ret, nil
   295  }
   296  
   297  func (t *logTreeTX) GetLeavesByRange(ctx context.Context, start, count int64) ([]*trillian.LogLeaf, error) {
   298  	ret := make([]*trillian.LogLeaf, 0, count)
   299  	for i := int64(0); i < count; i++ {
   300  		leaf := t.tx.Get(seqLeafKey(t.treeID, start+i))
   301  		if leaf != nil {
   302  			ret = append(ret, leaf.(*kv).v.(*trillian.LogLeaf))
   303  		}
   304  	}
   305  	return ret, nil
   306  }
   307  
   308  func (t *logTreeTX) GetLeavesByHash(ctx context.Context, leafHashes [][]byte, orderBySequence bool) ([]*trillian.LogLeaf, error) {
   309  	m := t.tx.Get(hashToSeqKey(t.treeID)).(*kv).v.(map[string][]int64)
   310  
   311  	ret := make([]*trillian.LogLeaf, 0, len(leafHashes))
   312  	for hash := range leafHashes {
   313  		seq, ok := m[string(hash)]
   314  		if !ok {
   315  			continue
   316  		}
   317  		for _, s := range seq {
   318  			l := t.tx.Get(seqLeafKey(t.treeID, s))
   319  			if l == nil {
   320  				continue
   321  			}
   322  			ret = append(ret, l.(*kv).v.(*trillian.LogLeaf))
   323  		}
   324  	}
   325  	return ret, nil
   326  }
   327  
   328  func (t *logTreeTX) LatestSignedLogRoot(ctx context.Context) (trillian.SignedLogRoot, error) {
   329  	return t.slr, nil
   330  }
   331  
   332  // fetchLatestRoot reads the latest SignedLogRoot from the DB and returns it.
   333  func (t *logTreeTX) fetchLatestRoot(ctx context.Context) (trillian.SignedLogRoot, error) {
   334  	r := t.tx.Get(sthKey(t.treeID, t.tree.currentSTH))
   335  	if r == nil {
   336  		return trillian.SignedLogRoot{}, storage.ErrTreeNeedsInit
   337  	}
   338  	return r.(*kv).v.(trillian.SignedLogRoot), nil
   339  }
   340  
   341  func (t *logTreeTX) StoreSignedLogRoot(ctx context.Context, slr trillian.SignedLogRoot) error {
   342  	var root types.LogRootV1
   343  	if err := root.UnmarshalBinary(slr.LogRoot); err != nil {
   344  		return err
   345  	}
   346  	k := sthKey(t.treeID, root.TimestampNanos)
   347  	k.(*kv).v = slr
   348  	t.tx.ReplaceOrInsert(k)
   349  
   350  	// TODO(alcutter): this breaks the transactional model
   351  	if root.TimestampNanos > t.tree.currentSTH {
   352  		t.tree.currentSTH = root.TimestampNanos
   353  	}
   354  	return nil
   355  }
   356  
   357  func (t *logTreeTX) UpdateSequencedLeaves(ctx context.Context, leaves []*trillian.LogLeaf) error {
   358  	countByMerkleHash := make(map[string]int)
   359  	for _, leaf := range leaves {
   360  		// This should fail on insert but catch it early
   361  		if got, want := len(leaf.LeafIdentityHash), t.hashSizeBytes; got != want {
   362  			return fmt.Errorf("sequenced leaf has incorrect hash size: got %v, want %v", got, want)
   363  		}
   364  		mh := string(leaf.MerkleLeafHash)
   365  		countByMerkleHash[mh]++
   366  		// insert sequenced leaf:
   367  		k := seqLeafKey(t.treeID, leaf.LeafIndex)
   368  		k.(*kv).v = leaf
   369  		t.tx.ReplaceOrInsert(k)
   370  		// update merkle-to-seq mapping:
   371  		m := t.tx.Get(hashToSeqKey(t.treeID))
   372  		l := m.(*kv).v.(map[string][]int64)[string(leaf.MerkleLeafHash)]
   373  		l = append(l, leaf.LeafIndex)
   374  		m.(*kv).v.(map[string][]int64)[string(leaf.MerkleLeafHash)] = l
   375  	}
   376  
   377  	q := t.tx.Get(unseqKey(t.treeID)).(*kv).v.(*list.List)
   378  	toRemove := make([]*list.Element, 0, q.Len())
   379  	for e := q.Front(); e != nil && len(countByMerkleHash) > 0; e = e.Next() {
   380  		h := e.Value.(*trillian.LogLeaf).MerkleLeafHash
   381  		mh := string(h)
   382  		if countByMerkleHash[mh] > 0 {
   383  			countByMerkleHash[mh]--
   384  			toRemove = append(toRemove, e)
   385  			if countByMerkleHash[mh] == 0 {
   386  				delete(countByMerkleHash, mh)
   387  			}
   388  		}
   389  	}
   390  	for _, e := range toRemove {
   391  		q.Remove(e)
   392  	}
   393  
   394  	if unknown := len(countByMerkleHash); unknown != 0 {
   395  		return fmt.Errorf("attempted to update %d unknown leaves: %x", unknown, countByMerkleHash)
   396  	}
   397  
   398  	return nil
   399  }
   400  
   401  func (t *logTreeTX) getActiveLogIDs(ctx context.Context) ([]int64, error) {
   402  	var ret []int64
   403  	for k := range t.ts.trees {
   404  		ret = append(ret, k)
   405  	}
   406  	return ret, nil
   407  }
   408  
   409  // GetActiveLogIDs returns a list of the IDs of all configured logs
   410  func (t *logTreeTX) GetActiveLogIDs(ctx context.Context) ([]int64, error) {
   411  	return t.getActiveLogIDs(ctx)
   412  }
   413  
   414  func (t *readOnlyLogTX) GetUnsequencedCounts(ctx context.Context) (storage.CountByLogID, error) {
   415  	t.ms.mu.RLock()
   416  	defer t.ms.mu.RUnlock()
   417  
   418  	ret := make(map[int64]int64)
   419  	for id, tree := range t.ms.trees {
   420  		tree.RLock()
   421  		defer tree.RUnlock() // OK to hold until method returns.
   422  
   423  		k := unseqKey(id)
   424  		queue := tree.store.Get(k).(*kv).v.(*list.List)
   425  		ret[id] = int64(queue.Len())
   426  	}
   427  	return ret, nil
   428  }