github.com/zorawar87/trillian@v1.2.1/cmd/createtree/main_test.go (about)

     1  // Copyright 2017 Google Inc. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package main
    16  
    17  import (
    18  	"context"
    19  	"errors"
    20  	"flag"
    21  	"testing"
    22  	"time"
    23  
    24  	"github.com/golang/mock/gomock"
    25  	"github.com/golang/protobuf/proto"
    26  	"github.com/golang/protobuf/ptypes"
    27  	"github.com/golang/protobuf/ptypes/any"
    28  	"github.com/golang/protobuf/ptypes/empty"
    29  	"github.com/google/trillian"
    30  	"github.com/google/trillian/crypto/sigpb"
    31  	"github.com/google/trillian/testonly"
    32  	"github.com/google/trillian/util/flagsaver"
    33  )
    34  
    35  // defaultTree reflects all flag defaults with the addition of a valid private key.
    36  var defaultTree = &trillian.Tree{
    37  	TreeState:          trillian.TreeState_ACTIVE,
    38  	TreeType:           trillian.TreeType_LOG,
    39  	HashStrategy:       trillian.HashStrategy_RFC6962_SHA256,
    40  	HashAlgorithm:      sigpb.DigitallySigned_SHA256,
    41  	SignatureAlgorithm: sigpb.DigitallySigned_ECDSA,
    42  	PrivateKey:         mustMarshalAny(&empty.Empty{}),
    43  	MaxRootDuration:    ptypes.DurationProto(0 * time.Millisecond),
    44  }
    45  
    46  type testCase struct {
    47  	desc        string
    48  	setFlags    func()
    49  	validateErr error
    50  	createErr   error
    51  	initErr     error
    52  	wantErr     bool
    53  	wantTree    *trillian.Tree
    54  }
    55  
    56  func mustMarshalAny(p proto.Message) *any.Any {
    57  	anyKey, err := ptypes.MarshalAny(p)
    58  	if err != nil {
    59  		panic(err)
    60  	}
    61  	return anyKey
    62  }
    63  
    64  func TestCreateTree(t *testing.T) {
    65  	nonDefaultTree := *defaultTree
    66  	nonDefaultTree.TreeType = trillian.TreeType_MAP
    67  	nonDefaultTree.SignatureAlgorithm = sigpb.DigitallySigned_RSA
    68  	nonDefaultTree.DisplayName = "Llamas Map"
    69  	nonDefaultTree.Description = "For all your digital llama needs!"
    70  
    71  	runTest(t, []*testCase{
    72  		{
    73  			desc: "validOpts",
    74  			// runTest sets mandatory options, so no need to provide a setFlags func.
    75  			wantTree: defaultTree,
    76  		},
    77  		{
    78  			desc: "nonDefaultOpts",
    79  			setFlags: func() {
    80  				*treeType = nonDefaultTree.TreeType.String()
    81  				*signatureAlgorithm = nonDefaultTree.SignatureAlgorithm.String()
    82  				*displayName = nonDefaultTree.DisplayName
    83  				*description = nonDefaultTree.Description
    84  			},
    85  			wantTree: &nonDefaultTree,
    86  		},
    87  		{
    88  			desc: "mandatoryOptsNotSet",
    89  			// Undo the flags set by runTest, so that mandatory options are no longer set.
    90  			setFlags:    resetFlags,
    91  			validateErr: errAdminAddrNotSet,
    92  			wantErr:     true,
    93  		},
    94  		{
    95  			desc:        "emptyAddr",
    96  			setFlags:    func() { *adminServerAddr = "" },
    97  			validateErr: errAdminAddrNotSet,
    98  			wantErr:     true,
    99  		},
   100  		{
   101  			desc:        "invalidEnumOpts",
   102  			setFlags:    func() { *treeType = "LLAMA!" },
   103  			validateErr: errors.New("unknown TreeType"),
   104  			wantErr:     true,
   105  		},
   106  		{
   107  			desc:        "invalidKeyTypeOpts",
   108  			setFlags:    func() { *privateKeyFormat = "LLAMA!!" },
   109  			validateErr: errors.New("key protobuf must be one of"),
   110  			wantErr:     true,
   111  		},
   112  		{
   113  			desc:      "createErr",
   114  			createErr: errors.New("create tree failed"),
   115  			wantErr:   true,
   116  		},
   117  		{
   118  			desc: "logInitErr",
   119  			setFlags: func() {
   120  				nonDefaultTree.TreeType = trillian.TreeType_LOG
   121  				*treeType = nonDefaultTree.TreeType.String()
   122  			},
   123  			wantTree: defaultTree,
   124  			initErr:  errors.New("log init failed"),
   125  			wantErr:  true,
   126  		},
   127  		{
   128  			desc: "mapInitErr",
   129  			setFlags: func() {
   130  				nonDefaultTree.TreeType = trillian.TreeType_MAP
   131  				*treeType = nonDefaultTree.TreeType.String()
   132  			},
   133  			wantTree: &nonDefaultTree,
   134  			initErr:  errors.New("map init failed"),
   135  			wantErr:  true,
   136  		},
   137  	})
   138  }
   139  
   140  // runTest executes the createtree command against a fake TrillianAdminServer
   141  // for each of the provided tests, and checks that the tree in the request is
   142  // as expected, or an expected error occurs.
   143  // Prior to each test case, it:
   144  // 1. Resets all flags to their original values.
   145  // 2. Sets the adminServerAddr flag to point to the fake server.
   146  // 3. Calls the test's setFlags func (if provided) to allow it to change flags specific to the test.
   147  func runTest(t *testing.T, tests []*testCase) {
   148  	for _, tc := range tests {
   149  		t.Run(tc.desc, func(t *testing.T) {
   150  			ctrl := gomock.NewController(t)
   151  			defer ctrl.Finish()
   152  
   153  			s, stopFakeServer, err := testonly.NewMockServer(ctrl)
   154  			if err != nil {
   155  				t.Fatalf("Error starting fake server: %v", err)
   156  			}
   157  			defer stopFakeServer()
   158  			defer flagsaver.Save().Restore()
   159  			*adminServerAddr = s.Addr
   160  			if tc.setFlags != nil {
   161  				tc.setFlags()
   162  			}
   163  
   164  			call := s.Admin.EXPECT().CreateTree(gomock.Any(), gomock.Any()).Return(tc.wantTree, tc.createErr)
   165  			expectCalls(call, tc.createErr, tc.validateErr)
   166  			switch *treeType {
   167  			case "LOG":
   168  				call := s.Log.EXPECT().InitLog(gomock.Any(), gomock.Any()).Return(&trillian.InitLogResponse{}, tc.initErr)
   169  				expectCalls(call, tc.initErr, tc.validateErr, tc.createErr)
   170  				call = s.Log.EXPECT().GetLatestSignedLogRoot(gomock.Any(), gomock.Any()).Return(&trillian.GetLatestSignedLogRootResponse{}, nil)
   171  				expectCalls(call, nil, tc.validateErr, tc.createErr, tc.initErr)
   172  			case "MAP":
   173  				call := s.Map.EXPECT().InitMap(gomock.Any(), gomock.Any()).Return(&trillian.InitMapResponse{}, tc.initErr)
   174  				expectCalls(call, tc.initErr, tc.validateErr, tc.createErr)
   175  				call = s.Map.EXPECT().GetSignedMapRootByRevision(gomock.Any(), gomock.Any()).Return(&trillian.GetSignedMapRootResponse{}, nil)
   176  				expectCalls(call, nil, tc.validateErr, tc.createErr, tc.initErr)
   177  			}
   178  
   179  			ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
   180  			defer cancel()
   181  			_, err = createTree(ctx)
   182  			if hasErr := err != nil; hasErr != tc.wantErr {
   183  				t.Errorf("createTree() '%v', wantErr = %v", err, tc.wantErr)
   184  			}
   185  		})
   186  	}
   187  }
   188  
   189  // expectCalls returns the minimum number of times a function is expected to be called
   190  // given the return error for the function (err), and all previous errors in the function's
   191  // code path.
   192  func expectCalls(call *gomock.Call, err error, prevErr ...error) *gomock.Call {
   193  	// If a function prior to this function errored,
   194  	// we do not expect this function to be called.
   195  	for _, e := range prevErr {
   196  		if e != nil {
   197  			return call.Times(0)
   198  		}
   199  	}
   200  	// If this function errors, it will be retried multiple times.
   201  	if err != nil {
   202  		return call.MinTimes(2)
   203  	}
   204  	// If this function succeeds it should only be called once.
   205  	return call.Times(1)
   206  }
   207  
   208  // resetFlags sets all flags to their default values.
   209  func resetFlags() {
   210  	flag.Visit(func(f *flag.Flag) {
   211  		f.Value.Set(f.DefValue)
   212  	})
   213  }