github.com/letsencrypt/trillian@v1.1.2-0.20180615153820-ae375a99d36a/integration/quota/quota_test.go (about)

     1  // Copyright 2017 Google Inc. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package quota
    16  
    17  import (
    18  	"context"
    19  	"errors"
    20  	"fmt"
    21  	"hash"
    22  	"net"
    23  	"testing"
    24  	"time"
    25  
    26  	"github.com/google/trillian"
    27  	"github.com/google/trillian/extension"
    28  	"github.com/google/trillian/quota/etcd/etcdqm"
    29  	"github.com/google/trillian/quota/etcd/quotaapi"
    30  	"github.com/google/trillian/quota/etcd/quotapb"
    31  	"github.com/google/trillian/quota/mysqlqm"
    32  	"github.com/google/trillian/server"
    33  	"github.com/google/trillian/server/admin"
    34  	"github.com/google/trillian/server/interceptor"
    35  	"github.com/google/trillian/storage/mysql"
    36  	"github.com/google/trillian/storage/testdb"
    37  	"github.com/google/trillian/storage/testonly"
    38  	"github.com/google/trillian/testonly/integration"
    39  	"github.com/google/trillian/testonly/integration/etcd"
    40  	"github.com/google/trillian/trees"
    41  	"github.com/google/trillian/util"
    42  	"google.golang.org/grpc"
    43  	"google.golang.org/grpc/codes"
    44  	"google.golang.org/grpc/status"
    45  )
    46  
    47  func TestEtcdRateLimiting(t *testing.T) {
    48  	testdb.SkipIfNoMySQL(t)
    49  	ctx := context.Background()
    50  
    51  	registry, err := integration.NewRegistryForTests(ctx)
    52  	if err != nil {
    53  		t.Fatalf("NewRegistryForTests() returned err = %v", err)
    54  	}
    55  
    56  	_, etcdClient, cleanup, err := etcd.StartEtcd()
    57  	if err != nil {
    58  		t.Fatalf("StartEtcd() returned err = %v", err)
    59  	}
    60  	defer cleanup()
    61  
    62  	registry.QuotaManager = etcdqm.New(etcdClient)
    63  
    64  	s, err := newTestServer(registry)
    65  	if err != nil {
    66  		t.Fatalf("newTestServer() returned err = %v", err)
    67  	}
    68  	defer s.close()
    69  
    70  	quotapb.RegisterQuotaServer(s.server, quotaapi.NewServer(etcdClient))
    71  	quotaClient := quotapb.NewQuotaClient(s.conn)
    72  	go s.serve()
    73  
    74  	const maxTokens = 100
    75  	if _, err := quotaClient.CreateConfig(ctx, &quotapb.CreateConfigRequest{
    76  		Name: "quotas/global/write/config",
    77  		Config: &quotapb.Config{
    78  			State:     quotapb.Config_ENABLED,
    79  			MaxTokens: maxTokens,
    80  			ReplenishmentStrategy: &quotapb.Config_TimeBased{
    81  				TimeBased: &quotapb.TimeBasedStrategy{
    82  					TokensToReplenish:        maxTokens,
    83  					ReplenishIntervalSeconds: 1000,
    84  				},
    85  			},
    86  		},
    87  	}); err != nil {
    88  		t.Fatalf("CreateConfig() returned err = %v", err)
    89  	}
    90  
    91  	if err := runRateLimitingTest(ctx, s, maxTokens); err != nil {
    92  		t.Error(err)
    93  	}
    94  }
    95  
    96  func TestMySQLRateLimiting(t *testing.T) {
    97  	testdb.SkipIfNoMySQL(t)
    98  	ctx := context.Background()
    99  	db, err := testdb.NewTrillianDB(ctx)
   100  	if err != nil {
   101  		t.Fatalf("GetTestDB() returned err = %v", err)
   102  	}
   103  	defer db.Close()
   104  
   105  	const maxUnsequenced = 20
   106  	qm := &mysqlqm.QuotaManager{DB: db, MaxUnsequencedRows: maxUnsequenced}
   107  	registry := extension.Registry{
   108  		AdminStorage: mysql.NewAdminStorage(db),
   109  		LogStorage:   mysql.NewLogStorage(db, nil),
   110  		MapStorage:   mysql.NewMapStorage(db),
   111  		QuotaManager: qm,
   112  	}
   113  
   114  	s, err := newTestServer(registry)
   115  	if err != nil {
   116  		t.Fatalf("newTestServer() returned err = %v", err)
   117  	}
   118  	defer s.close()
   119  	go s.serve()
   120  
   121  	if err := runRateLimitingTest(ctx, s, maxUnsequenced); err != nil {
   122  		t.Error(err)
   123  	}
   124  }
   125  
   126  func runRateLimitingTest(ctx context.Context, s *testServer, numTokens int) error {
   127  	tree, err := s.admin.CreateTree(ctx, &trillian.CreateTreeRequest{Tree: testonly.LogTree})
   128  	if err != nil {
   129  		return fmt.Errorf("CreateTree() returned err = %v", err)
   130  	}
   131  	// InitLog costs 1 token
   132  	numTokens--
   133  	_, err = s.log.InitLog(ctx, &trillian.InitLogRequest{LogId: tree.TreeId})
   134  	if err != nil {
   135  		return fmt.Errorf("InitLog() returned err = %v", err)
   136  	}
   137  	hasherFn, err := trees.Hash(tree)
   138  	if err != nil {
   139  		return fmt.Errorf("Hash() returned err = %v", err)
   140  	}
   141  	hasher := hasherFn.New()
   142  	lw := &leafWriter{client: s.log, hash: hasher, treeID: tree.TreeId}
   143  
   144  	// Requests where leaves < numTokens should work
   145  	for i := 0; i < numTokens; i++ {
   146  		if err := lw.queueLeaf(ctx); err != nil {
   147  			return fmt.Errorf("queueLeaf(@%d) returned err = %v", i, err)
   148  		}
   149  	}
   150  
   151  	// Some point after now requests should start to fail
   152  	stop := false
   153  	timeout := time.After(1 * time.Second)
   154  	for !stop {
   155  		select {
   156  		case <-timeout:
   157  			return errors.New("timed out before rate limiting kicked in")
   158  		default:
   159  			err := lw.queueLeaf(ctx)
   160  			if err == nil {
   161  				continue // Rate liming hasn't kicked in yet
   162  			}
   163  			if s, ok := status.FromError(err); !ok || s.Code() != codes.ResourceExhausted {
   164  				return fmt.Errorf("queueLeaf() returned err = %v", err)
   165  			}
   166  			stop = true
   167  		}
   168  	}
   169  	return nil
   170  }
   171  
   172  type leafWriter struct {
   173  	client trillian.TrillianLogClient
   174  	hash   hash.Hash
   175  	treeID int64
   176  	leafID int
   177  }
   178  
   179  func (w *leafWriter) queueLeaf(ctx context.Context) error {
   180  	value := []byte(fmt.Sprintf("leaf-%v", w.leafID))
   181  	w.leafID++
   182  
   183  	w.hash.Reset()
   184  	if _, err := w.hash.Write(value); err != nil {
   185  		return err
   186  	}
   187  	h := w.hash.Sum(nil)
   188  
   189  	_, err := w.client.QueueLeaf(ctx, &trillian.QueueLeafRequest{
   190  		LogId: w.treeID,
   191  		Leaf: &trillian.LogLeaf{
   192  			MerkleLeafHash: h,
   193  			LeafValue:      value,
   194  		}})
   195  	return err
   196  }
   197  
   198  type testServer struct {
   199  	lis    net.Listener
   200  	server *grpc.Server
   201  	conn   *grpc.ClientConn
   202  	admin  trillian.TrillianAdminClient
   203  	log    trillian.TrillianLogClient
   204  }
   205  
   206  func (s *testServer) close() {
   207  	if s.conn != nil {
   208  		s.conn.Close()
   209  	}
   210  	if s.server != nil {
   211  		s.server.GracefulStop()
   212  	}
   213  	if s.lis != nil {
   214  		s.lis.Close()
   215  	}
   216  }
   217  
   218  func (s *testServer) serve() {
   219  	s.server.Serve(s.lis)
   220  }
   221  
   222  // newTestServer returns a new testServer configured for integration tests.
   223  // Callers must defer-call s.close() to make sure resources aren't being leaked and must start the
   224  // server via s.serve().
   225  func newTestServer(registry extension.Registry) (*testServer, error) {
   226  	s := &testServer{}
   227  
   228  	intercept := interceptor.New(
   229  		registry.AdminStorage, registry.QuotaManager, false /* quotaDryRun */, registry.MetricFactory)
   230  	netInterceptor := interceptor.Combine(interceptor.ErrorWrapper, intercept.UnaryInterceptor)
   231  	s.server = grpc.NewServer(grpc.UnaryInterceptor(netInterceptor))
   232  	trillian.RegisterTrillianAdminServer(s.server, admin.New(registry, nil /* allowedTreeTypes */))
   233  	trillian.RegisterTrillianLogServer(s.server, server.NewTrillianLogRPCServer(registry, util.SystemTimeSource{}))
   234  
   235  	var err error
   236  	s.lis, err = net.Listen("tcp", "127.0.0.1:0")
   237  	if err != nil {
   238  		s.close()
   239  		return nil, err
   240  	}
   241  
   242  	s.conn, err = grpc.Dial(s.lis.Addr().String(), grpc.WithInsecure())
   243  	if err != nil {
   244  		s.close()
   245  		return nil, err
   246  	}
   247  	s.admin = trillian.NewTrillianAdminClient(s.conn)
   248  	s.log = trillian.NewTrillianLogClient(s.conn)
   249  	return s, nil
   250  }