github.com/zorawar87/trillian@v1.2.1/quota/mysqlqm/mysql_quota_test.go (about)

     1  // Copyright 2017 Google Inc. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package mysqlqm_test
    16  
    17  import (
    18  	"context"
    19  	"crypto"
    20  	"database/sql"
    21  	"fmt"
    22  	"testing"
    23  	"time"
    24  
    25  	"github.com/google/trillian"
    26  	"github.com/google/trillian/quota"
    27  	"github.com/google/trillian/quota/mysqlqm"
    28  	"github.com/google/trillian/storage"
    29  	"github.com/google/trillian/storage/mysql"
    30  	"github.com/google/trillian/storage/testdb"
    31  	"github.com/google/trillian/testonly"
    32  	"github.com/google/trillian/trees"
    33  	"github.com/google/trillian/types"
    34  	"github.com/kylelemons/godebug/pretty"
    35  
    36  	tcrypto "github.com/google/trillian/crypto"
    37  	stestonly "github.com/google/trillian/storage/testonly"
    38  )
    39  
    40  func TestQuotaManager_GetTokens(t *testing.T) {
    41  	testdb.SkipIfNoMySQL(t)
    42  	ctx := context.Background()
    43  
    44  	db, err := testdb.NewTrillianDB(ctx)
    45  	if err != nil {
    46  		t.Fatalf("GetTestDB() returned err = %v", err)
    47  	}
    48  	defer db.Close()
    49  
    50  	tree, err := createTree(ctx, db)
    51  	if err != nil {
    52  		t.Fatalf("createTree() returned err = %v", err)
    53  	}
    54  
    55  	tests := []struct {
    56  		desc                                           string
    57  		unsequencedRows, maxUnsequencedRows, numTokens int
    58  		specs                                          []quota.Spec
    59  		wantErr                                        bool
    60  	}{
    61  		{
    62  			desc:               "globalWriteSingleToken",
    63  			unsequencedRows:    10,
    64  			maxUnsequencedRows: 20,
    65  			numTokens:          1,
    66  			specs:              []quota.Spec{{Group: quota.Global, Kind: quota.Write}},
    67  		},
    68  		{
    69  			desc:               "globalWriteMultiToken",
    70  			unsequencedRows:    10,
    71  			maxUnsequencedRows: 20,
    72  			numTokens:          5,
    73  			specs:              []quota.Spec{{Group: quota.Global, Kind: quota.Write}},
    74  		},
    75  		{
    76  			desc:               "globalWriteOverQuota1",
    77  			unsequencedRows:    20,
    78  			maxUnsequencedRows: 20,
    79  			numTokens:          1,
    80  			specs:              []quota.Spec{{Group: quota.Global, Kind: quota.Write}},
    81  			wantErr:            true,
    82  		},
    83  		{
    84  			desc:               "globalWriteOverQuota2",
    85  			unsequencedRows:    15,
    86  			maxUnsequencedRows: 20,
    87  			numTokens:          10,
    88  			specs:              []quota.Spec{{Group: quota.Global, Kind: quota.Write}},
    89  			wantErr:            true,
    90  		},
    91  		{
    92  			desc:      "unlimitedQuotas",
    93  			numTokens: 10,
    94  			specs: []quota.Spec{
    95  				{Group: quota.User, Kind: quota.Read, User: "dylan"},
    96  				{Group: quota.Tree, Kind: quota.Read, TreeID: tree.TreeId},
    97  				{Group: quota.Global, Kind: quota.Read},
    98  				{Group: quota.User, Kind: quota.Write, User: "dylan"},
    99  				{Group: quota.Tree, Kind: quota.Write, TreeID: tree.TreeId},
   100  			},
   101  		},
   102  	}
   103  
   104  	for _, test := range tests {
   105  		if err := setUnsequencedRows(ctx, db, tree, test.unsequencedRows); err != nil {
   106  			t.Errorf("setUnsequencedRows() returned err = %v", err)
   107  			continue
   108  		}
   109  
   110  		// Test general cases using select count(*) to avoid flakiness / allow for more
   111  		// precise assertions.
   112  		// See TestQuotaManager_GetTokens_InformationSchema for information schema tests.
   113  		qm := &mysqlqm.QuotaManager{DB: db, MaxUnsequencedRows: test.maxUnsequencedRows, UseSelectCount: true}
   114  		err := qm.GetTokens(ctx, test.numTokens, test.specs)
   115  		if hasErr := err == mysqlqm.ErrTooManyUnsequencedRows; hasErr != test.wantErr {
   116  			t.Errorf("%v: GetTokens() returned err = %q, wantErr = %v", test.desc, err, test.wantErr)
   117  		}
   118  	}
   119  }
   120  
   121  func TestQuotaManager_GetTokens_InformationSchema(t *testing.T) {
   122  	testdb.SkipIfNoMySQL(t)
   123  	ctx := context.Background()
   124  
   125  	maxUnsequenced := 20
   126  	globalWriteSpec := []quota.Spec{{Group: quota.Global, Kind: quota.Write}}
   127  
   128  	// Make both variants go through the test.
   129  	tests := []struct {
   130  		useSelectCount bool
   131  	}{
   132  		{useSelectCount: true},
   133  		{useSelectCount: false},
   134  	}
   135  	for _, test := range tests {
   136  		desc := fmt.Sprintf("useSelectCount = %v", test.useSelectCount)
   137  		t.Run(desc, func(t *testing.T) {
   138  			db, err := testdb.NewTrillianDB(ctx)
   139  			if err != nil {
   140  				t.Fatalf("NewTrillianDB() returned err = %v", err)
   141  			}
   142  			defer db.Close()
   143  
   144  			tree, err := createTree(ctx, db)
   145  			if err != nil {
   146  				t.Fatalf("createTree() returned err = %v", err)
   147  			}
   148  
   149  			qm := &mysqlqm.QuotaManager{DB: db, MaxUnsequencedRows: maxUnsequenced, UseSelectCount: test.useSelectCount}
   150  
   151  			// All GetTokens() calls where leaves < maxUnsequenced should succeed:
   152  			// information_schema may be outdated, but it should refer to a valid point in the
   153  			// past.
   154  			for i := 0; i < maxUnsequenced-1; i++ {
   155  				if err := queueLeaves(ctx, db, tree, i /* firstID */, 1 /* num */); err != nil {
   156  					t.Fatalf("queueLeaves() returned err = %v", err)
   157  				}
   158  				if err := qm.GetTokens(ctx, 1 /* numTokens */, globalWriteSpec); err != nil {
   159  					t.Errorf("GetTokens() returned err = %v (%v leaves)", err, i+1)
   160  				}
   161  			}
   162  
   163  			// Make leaves = maxUnsequenced
   164  			if err := queueLeaves(ctx, db, tree, maxUnsequenced-1 /* firstID */, 1 /* num */); err != nil {
   165  				t.Fatalf("queueLeaves() returned err = %v", err)
   166  			}
   167  
   168  			// Allow some time for information_schema to "catch up".
   169  			stop := false
   170  			timeout := time.After(1 * time.Second)
   171  			for !stop {
   172  				select {
   173  				case <-timeout:
   174  					t.Errorf("timed out")
   175  					stop = true
   176  				default:
   177  					// An error means that GetTokens is working correctly
   178  					stop = qm.GetTokens(ctx, 1 /* numTokens */, globalWriteSpec) == mysqlqm.ErrTooManyUnsequencedRows
   179  				}
   180  			}
   181  		})
   182  	}
   183  }
   184  
   185  func TestQuotaManager_PeekTokens(t *testing.T) {
   186  	testdb.SkipIfNoMySQL(t)
   187  	ctx := context.Background()
   188  
   189  	db, err := testdb.NewTrillianDB(ctx)
   190  	if err != nil {
   191  		t.Fatalf("GetTestDB() returned err = %v", err)
   192  	}
   193  	defer db.Close()
   194  
   195  	tree, err := createTree(ctx, db)
   196  	if err != nil {
   197  		t.Fatalf("createTree() returned err = %v", err)
   198  	}
   199  
   200  	unsequencedRows := 10
   201  	maxUnsequencedRows := 1000
   202  	wantRows := maxUnsequencedRows - unsequencedRows
   203  	if err := setUnsequencedRows(ctx, db, tree, unsequencedRows); err != nil {
   204  		t.Fatalf("setUnsequencedRows() returned err = %v", err)
   205  	}
   206  
   207  	// Test using select count(*) to allow for precise assertions without flakiness.
   208  	qm := &mysqlqm.QuotaManager{DB: db, MaxUnsequencedRows: maxUnsequencedRows, UseSelectCount: true}
   209  	specs := allSpecs(ctx, qm, tree.TreeId)
   210  	tokens, err := qm.PeekTokens(ctx, specs)
   211  	if err != nil {
   212  		t.Fatalf("PeekTokens() returned err = %v", err)
   213  	}
   214  
   215  	// All specs but Global/Write are infinite
   216  	wantTokens := make(map[quota.Spec]int)
   217  	for _, spec := range specs {
   218  		wantTokens[spec] = quota.MaxTokens
   219  	}
   220  	wantTokens[quota.Spec{Group: quota.Global, Kind: quota.Write}] = wantRows
   221  
   222  	if diff := pretty.Compare(tokens, wantTokens); diff != "" {
   223  		t.Errorf("post-PeekTokens() diff:\n%v", diff)
   224  	}
   225  }
   226  
   227  func TestQuotaManager_Noops(t *testing.T) {
   228  	testdb.SkipIfNoMySQL(t)
   229  	ctx := context.Background()
   230  
   231  	db, err := testdb.NewTrillianDB(ctx)
   232  	if err != nil {
   233  		t.Fatalf("GetTestDB() returned err = %v", err)
   234  	}
   235  	defer db.Close()
   236  
   237  	qm := &mysqlqm.QuotaManager{DB: db, MaxUnsequencedRows: 1000}
   238  	specs := allSpecs(ctx, qm, 10 /* treeID */)
   239  
   240  	tests := []struct {
   241  		desc string
   242  		fn   func() error
   243  	}{
   244  		{
   245  			desc: "PutTokens",
   246  			fn: func() error {
   247  				return qm.PutTokens(ctx, 10 /* numTokens */, specs)
   248  			},
   249  		},
   250  		{
   251  			desc: "ResetQuota",
   252  			fn: func() error {
   253  				return qm.ResetQuota(ctx, specs)
   254  			},
   255  		},
   256  	}
   257  	for _, test := range tests {
   258  		if err := test.fn(); err != nil {
   259  			t.Errorf("%v: got err = %v", test.desc, err)
   260  		}
   261  	}
   262  }
   263  
   264  func allSpecs(ctx context.Context, qm quota.Manager, treeID int64) []quota.Spec {
   265  	return []quota.Spec{
   266  		{Group: quota.User, Kind: quota.Read, User: "florence"},
   267  		{Group: quota.Tree, Kind: quota.Read, TreeID: treeID},
   268  		{Group: quota.Global, Kind: quota.Read},
   269  		{Group: quota.User, Kind: quota.Write, User: "florence"},
   270  		{Group: quota.Tree, Kind: quota.Write, TreeID: treeID},
   271  		{Group: quota.Global, Kind: quota.Write},
   272  	}
   273  }
   274  
   275  func countUnsequenced(ctx context.Context, db *sql.DB) (int, error) {
   276  	var count int
   277  	if err := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM Unsequenced").Scan(&count); err != nil {
   278  		return 0, err
   279  	}
   280  	return count, nil
   281  }
   282  
   283  func createTree(ctx context.Context, db *sql.DB) (*trillian.Tree, error) {
   284  	var tree *trillian.Tree
   285  
   286  	{
   287  		as := mysql.NewAdminStorage(db)
   288  		err := as.ReadWriteTransaction(ctx, func(ctx context.Context, tx storage.AdminTX) error {
   289  			var err error
   290  			tree, err = tx.CreateTree(ctx, stestonly.LogTree)
   291  			return err
   292  		})
   293  		if err != nil {
   294  			return nil, err
   295  		}
   296  	}
   297  
   298  	{
   299  		ls := mysql.NewLogStorage(db, nil)
   300  		err := ls.ReadWriteTransaction(ctx, tree, func(ctx context.Context, tx storage.LogTreeTX) error {
   301  			signer := tcrypto.NewSigner(0, testonly.NewSignerWithFixedSig(nil, []byte("notempty")), crypto.SHA256)
   302  			slr, err := signer.SignLogRoot(&types.LogRootV1{RootHash: []byte{0}})
   303  			if err != nil {
   304  				return err
   305  			}
   306  			return tx.StoreSignedLogRoot(ctx, *slr)
   307  		})
   308  		if err != nil {
   309  			return nil, err
   310  		}
   311  	}
   312  
   313  	return tree, nil
   314  }
   315  
   316  func queueLeaves(ctx context.Context, db *sql.DB, tree *trillian.Tree, firstID, num int) error {
   317  	hasherFn, err := trees.Hash(tree)
   318  	if err != nil {
   319  		return err
   320  	}
   321  	hasher := hasherFn.New()
   322  
   323  	leaves := []*trillian.LogLeaf{}
   324  	for i := 0; i < num; i++ {
   325  		value := []byte(fmt.Sprintf("leaf-%v", firstID+i))
   326  		hasher.Reset()
   327  		if _, err := hasher.Write(value); err != nil {
   328  			return err
   329  		}
   330  		hash := hasher.Sum(nil)
   331  		leaves = append(leaves, &trillian.LogLeaf{
   332  			MerkleLeafHash:   hash,
   333  			LeafValue:        value,
   334  			ExtraData:        []byte("extra data"),
   335  			LeafIdentityHash: hash,
   336  		})
   337  	}
   338  
   339  	ls := mysql.NewLogStorage(db, nil)
   340  	return ls.ReadWriteTransaction(ctx, tree, func(ctx context.Context, tx storage.LogTreeTX) error {
   341  		_, err := tx.QueueLeaves(ctx, leaves, time.Now())
   342  		return err
   343  	})
   344  }
   345  
   346  func setUnsequencedRows(ctx context.Context, db *sql.DB, tree *trillian.Tree, wantRows int) error {
   347  	count, err := countUnsequenced(ctx, db)
   348  	if err != nil {
   349  		return err
   350  	}
   351  	if count == wantRows {
   352  		return nil
   353  	}
   354  
   355  	// Clear the tables and re-create leaves from scratch. It's easier than having to reason
   356  	// about duplicate entries.
   357  	if _, err := db.ExecContext(ctx, "DELETE FROM LeafData"); err != nil {
   358  		return err
   359  	}
   360  	if _, err := db.ExecContext(ctx, "DELETE FROM Unsequenced"); err != nil {
   361  		return err
   362  	}
   363  	if err := queueLeaves(ctx, db, tree, 0 /* firstID */, wantRows); err != nil {
   364  		return err
   365  	}
   366  
   367  	// Sanity check the final count
   368  	count, err = countUnsequenced(ctx, db)
   369  	if err != nil {
   370  		return err
   371  	}
   372  	if count != wantRows {
   373  		return fmt.Errorf("got %v unsequenced rows, want = %v", count, wantRows)
   374  	}
   375  
   376  	return nil
   377  }