github.com/bartle-stripe/trillian@v1.2.1/integration/quota/quota_test.go (about)

     1  // Copyright 2017 Google Inc. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package quota
    16  
    17  import (
    18  	"context"
    19  	"errors"
    20  	"fmt"
    21  	"hash"
    22  	"net"
    23  	"testing"
    24  	"time"
    25  
    26  	"github.com/google/trillian"
    27  	"github.com/google/trillian/extension"
    28  	"github.com/google/trillian/quota/etcd/etcdqm"
    29  	"github.com/google/trillian/quota/etcd/quotaapi"
    30  	"github.com/google/trillian/quota/etcd/quotapb"
    31  	"github.com/google/trillian/quota/mysqlqm"
    32  	"github.com/google/trillian/server"
    33  	"github.com/google/trillian/server/admin"
    34  	"github.com/google/trillian/server/interceptor"
    35  	"github.com/google/trillian/storage/mysql"
    36  	"github.com/google/trillian/storage/testdb"
    37  	"github.com/google/trillian/storage/testonly"
    38  	"github.com/google/trillian/testonly/integration"
    39  	"github.com/google/trillian/testonly/integration/etcd"
    40  	"github.com/google/trillian/trees"
    41  	"github.com/google/trillian/util"
    42  	"google.golang.org/grpc"
    43  	"google.golang.org/grpc/codes"
    44  	"google.golang.org/grpc/status"
    45  
    46  	grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
    47  )
    48  
    49  func TestEtcdRateLimiting(t *testing.T) {
    50  	testdb.SkipIfNoMySQL(t)
    51  	ctx := context.Background()
    52  
    53  	registry, err := integration.NewRegistryForTests(ctx)
    54  	if err != nil {
    55  		t.Fatalf("NewRegistryForTests() returned err = %v", err)
    56  	}
    57  
    58  	_, etcdClient, cleanup, err := etcd.StartEtcd()
    59  	if err != nil {
    60  		t.Fatalf("StartEtcd() returned err = %v", err)
    61  	}
    62  	defer cleanup()
    63  
    64  	registry.QuotaManager = etcdqm.New(etcdClient)
    65  
    66  	s, err := newTestServer(registry)
    67  	if err != nil {
    68  		t.Fatalf("newTestServer() returned err = %v", err)
    69  	}
    70  	defer s.close()
    71  
    72  	quotapb.RegisterQuotaServer(s.server, quotaapi.NewServer(etcdClient))
    73  	quotaClient := quotapb.NewQuotaClient(s.conn)
    74  	go s.serve()
    75  
    76  	const maxTokens = 100
    77  	if _, err := quotaClient.CreateConfig(ctx, &quotapb.CreateConfigRequest{
    78  		Name: "quotas/global/write/config",
    79  		Config: &quotapb.Config{
    80  			State:     quotapb.Config_ENABLED,
    81  			MaxTokens: maxTokens,
    82  			ReplenishmentStrategy: &quotapb.Config_TimeBased{
    83  				TimeBased: &quotapb.TimeBasedStrategy{
    84  					TokensToReplenish:        maxTokens,
    85  					ReplenishIntervalSeconds: 1000,
    86  				},
    87  			},
    88  		},
    89  	}); err != nil {
    90  		t.Fatalf("CreateConfig() returned err = %v", err)
    91  	}
    92  
    93  	if err := runRateLimitingTest(ctx, s, maxTokens); err != nil {
    94  		t.Error(err)
    95  	}
    96  }
    97  
    98  func TestMySQLRateLimiting(t *testing.T) {
    99  	testdb.SkipIfNoMySQL(t)
   100  	ctx := context.Background()
   101  	db, err := testdb.NewTrillianDB(ctx)
   102  	if err != nil {
   103  		t.Fatalf("GetTestDB() returned err = %v", err)
   104  	}
   105  	defer db.Close()
   106  
   107  	const maxUnsequenced = 20
   108  	qm := &mysqlqm.QuotaManager{DB: db, MaxUnsequencedRows: maxUnsequenced}
   109  	registry := extension.Registry{
   110  		AdminStorage: mysql.NewAdminStorage(db),
   111  		LogStorage:   mysql.NewLogStorage(db, nil),
   112  		MapStorage:   mysql.NewMapStorage(db),
   113  		QuotaManager: qm,
   114  	}
   115  
   116  	s, err := newTestServer(registry)
   117  	if err != nil {
   118  		t.Fatalf("newTestServer() returned err = %v", err)
   119  	}
   120  	defer s.close()
   121  	go s.serve()
   122  
   123  	if err := runRateLimitingTest(ctx, s, maxUnsequenced); err != nil {
   124  		t.Error(err)
   125  	}
   126  }
   127  
   128  func runRateLimitingTest(ctx context.Context, s *testServer, numTokens int) error {
   129  	tree, err := s.admin.CreateTree(ctx, &trillian.CreateTreeRequest{Tree: testonly.LogTree})
   130  	if err != nil {
   131  		return fmt.Errorf("CreateTree() returned err = %v", err)
   132  	}
   133  	// InitLog costs 1 token
   134  	numTokens--
   135  	_, err = s.log.InitLog(ctx, &trillian.InitLogRequest{LogId: tree.TreeId})
   136  	if err != nil {
   137  		return fmt.Errorf("InitLog() returned err = %v", err)
   138  	}
   139  	hasherFn, err := trees.Hash(tree)
   140  	if err != nil {
   141  		return fmt.Errorf("Hash() returned err = %v", err)
   142  	}
   143  	hasher := hasherFn.New()
   144  	lw := &leafWriter{client: s.log, hash: hasher, treeID: tree.TreeId}
   145  
   146  	// Requests where leaves < numTokens should work
   147  	for i := 0; i < numTokens; i++ {
   148  		if err := lw.queueLeaf(ctx); err != nil {
   149  			return fmt.Errorf("queueLeaf(@%d) returned err = %v", i, err)
   150  		}
   151  	}
   152  
   153  	// Some point after now requests should start to fail
   154  	stop := false
   155  	timeout := time.After(1 * time.Second)
   156  	for !stop {
   157  		select {
   158  		case <-timeout:
   159  			return errors.New("timed out before rate limiting kicked in")
   160  		default:
   161  			err := lw.queueLeaf(ctx)
   162  			if err == nil {
   163  				continue // Rate liming hasn't kicked in yet
   164  			}
   165  			if s, ok := status.FromError(err); !ok || s.Code() != codes.ResourceExhausted {
   166  				return fmt.Errorf("queueLeaf() returned err = %v", err)
   167  			}
   168  			stop = true
   169  		}
   170  	}
   171  	return nil
   172  }
   173  
   174  type leafWriter struct {
   175  	client trillian.TrillianLogClient
   176  	hash   hash.Hash
   177  	treeID int64
   178  	leafID int
   179  }
   180  
   181  func (w *leafWriter) queueLeaf(ctx context.Context) error {
   182  	value := []byte(fmt.Sprintf("leaf-%v", w.leafID))
   183  	w.leafID++
   184  
   185  	w.hash.Reset()
   186  	if _, err := w.hash.Write(value); err != nil {
   187  		return err
   188  	}
   189  	h := w.hash.Sum(nil)
   190  
   191  	_, err := w.client.QueueLeaf(ctx, &trillian.QueueLeafRequest{
   192  		LogId: w.treeID,
   193  		Leaf: &trillian.LogLeaf{
   194  			MerkleLeafHash: h,
   195  			LeafValue:      value,
   196  		}})
   197  	return err
   198  }
   199  
   200  type testServer struct {
   201  	lis    net.Listener
   202  	server *grpc.Server
   203  	conn   *grpc.ClientConn
   204  	admin  trillian.TrillianAdminClient
   205  	log    trillian.TrillianLogClient
   206  }
   207  
   208  func (s *testServer) close() {
   209  	if s.conn != nil {
   210  		s.conn.Close()
   211  	}
   212  	if s.server != nil {
   213  		s.server.GracefulStop()
   214  	}
   215  	if s.lis != nil {
   216  		s.lis.Close()
   217  	}
   218  }
   219  
   220  func (s *testServer) serve() {
   221  	s.server.Serve(s.lis)
   222  }
   223  
   224  // newTestServer returns a new testServer configured for integration tests.
   225  // Callers must defer-call s.close() to make sure resources aren't being leaked and must start the
   226  // server via s.serve().
   227  func newTestServer(registry extension.Registry) (*testServer, error) {
   228  	s := &testServer{}
   229  
   230  	ti := interceptor.New(registry.AdminStorage, registry.QuotaManager, false /* quotaDryRun */, registry.MetricFactory)
   231  	s.server = grpc.NewServer(
   232  		grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(
   233  			interceptor.ErrorWrapper,
   234  			ti.UnaryInterceptor,
   235  		)),
   236  	)
   237  	trillian.RegisterTrillianAdminServer(s.server, admin.New(registry, nil /* allowedTreeTypes */))
   238  	trillian.RegisterTrillianLogServer(s.server, server.NewTrillianLogRPCServer(registry, util.SystemTimeSource{}))
   239  
   240  	var err error
   241  	s.lis, err = net.Listen("tcp", "127.0.0.1:0")
   242  	if err != nil {
   243  		s.close()
   244  		return nil, err
   245  	}
   246  
   247  	s.conn, err = grpc.Dial(s.lis.Addr().String(), grpc.WithInsecure())
   248  	if err != nil {
   249  		s.close()
   250  		return nil, err
   251  	}
   252  	s.admin = trillian.NewTrillianAdminClient(s.conn)
   253  	s.log = trillian.NewTrillianLogClient(s.conn)
   254  	return s, nil
   255  }