github.com/bartle-stripe/trillian@v1.2.1/cmd/createtree/main_test.go (about) 1 // Copyright 2017 Google Inc. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package main 16 17 import ( 18 "context" 19 "errors" 20 "flag" 21 "testing" 22 "time" 23 24 "github.com/golang/mock/gomock" 25 "github.com/golang/protobuf/proto" 26 "github.com/golang/protobuf/ptypes" 27 "github.com/golang/protobuf/ptypes/any" 28 "github.com/golang/protobuf/ptypes/empty" 29 "github.com/google/trillian" 30 "github.com/google/trillian/crypto/sigpb" 31 "github.com/google/trillian/testonly" 32 "github.com/google/trillian/util/flagsaver" 33 ) 34 35 // defaultTree reflects all flag defaults with the addition of a valid private key. 36 var defaultTree = &trillian.Tree{ 37 TreeState: trillian.TreeState_ACTIVE, 38 TreeType: trillian.TreeType_LOG, 39 HashStrategy: trillian.HashStrategy_RFC6962_SHA256, 40 HashAlgorithm: sigpb.DigitallySigned_SHA256, 41 SignatureAlgorithm: sigpb.DigitallySigned_ECDSA, 42 PrivateKey: mustMarshalAny(&empty.Empty{}), 43 MaxRootDuration: ptypes.DurationProto(0 * time.Millisecond), 44 } 45 46 type testCase struct { 47 desc string 48 setFlags func() 49 validateErr error 50 createErr error 51 initErr error 52 wantErr bool 53 wantTree *trillian.Tree 54 } 55 56 func mustMarshalAny(p proto.Message) *any.Any { 57 anyKey, err := ptypes.MarshalAny(p) 58 if err != nil { 59 panic(err) 60 } 61 return anyKey 62 } 63 64 func TestCreateTree(t *testing.T) { 65 nonDefaultTree := *defaultTree 66 nonDefaultTree.TreeType = trillian.TreeType_MAP 67 nonDefaultTree.SignatureAlgorithm = sigpb.DigitallySigned_RSA 68 nonDefaultTree.DisplayName = "Llamas Map" 69 nonDefaultTree.Description = "For all your digital llama needs!" 70 71 runTest(t, []*testCase{ 72 { 73 desc: "validOpts", 74 // runTest sets mandatory options, so no need to provide a setFlags func. 75 wantTree: defaultTree, 76 }, 77 { 78 desc: "nonDefaultOpts", 79 setFlags: func() { 80 *treeType = nonDefaultTree.TreeType.String() 81 *signatureAlgorithm = nonDefaultTree.SignatureAlgorithm.String() 82 *displayName = nonDefaultTree.DisplayName 83 *description = nonDefaultTree.Description 84 }, 85 wantTree: &nonDefaultTree, 86 }, 87 { 88 desc: "mandatoryOptsNotSet", 89 // Undo the flags set by runTest, so that mandatory options are no longer set. 90 setFlags: resetFlags, 91 validateErr: errAdminAddrNotSet, 92 wantErr: true, 93 }, 94 { 95 desc: "emptyAddr", 96 setFlags: func() { *adminServerAddr = "" }, 97 validateErr: errAdminAddrNotSet, 98 wantErr: true, 99 }, 100 { 101 desc: "invalidEnumOpts", 102 setFlags: func() { *treeType = "LLAMA!" }, 103 validateErr: errors.New("unknown TreeType"), 104 wantErr: true, 105 }, 106 { 107 desc: "invalidKeyTypeOpts", 108 setFlags: func() { *privateKeyFormat = "LLAMA!!" }, 109 validateErr: errors.New("key protobuf must be one of"), 110 wantErr: true, 111 }, 112 { 113 desc: "createErr", 114 createErr: errors.New("create tree failed"), 115 wantErr: true, 116 }, 117 { 118 desc: "logInitErr", 119 setFlags: func() { 120 nonDefaultTree.TreeType = trillian.TreeType_LOG 121 *treeType = nonDefaultTree.TreeType.String() 122 }, 123 wantTree: defaultTree, 124 initErr: errors.New("log init failed"), 125 wantErr: true, 126 }, 127 { 128 desc: "mapInitErr", 129 setFlags: func() { 130 nonDefaultTree.TreeType = trillian.TreeType_MAP 131 *treeType = nonDefaultTree.TreeType.String() 132 }, 133 wantTree: &nonDefaultTree, 134 initErr: errors.New("map init failed"), 135 wantErr: true, 136 }, 137 }) 138 } 139 140 // runTest executes the createtree command against a fake TrillianAdminServer 141 // for each of the provided tests, and checks that the tree in the request is 142 // as expected, or an expected error occurs. 143 // Prior to each test case, it: 144 // 1. Resets all flags to their original values. 145 // 2. Sets the adminServerAddr flag to point to the fake server. 146 // 3. Calls the test's setFlags func (if provided) to allow it to change flags specific to the test. 147 func runTest(t *testing.T, tests []*testCase) { 148 for _, tc := range tests { 149 t.Run(tc.desc, func(t *testing.T) { 150 ctrl := gomock.NewController(t) 151 defer ctrl.Finish() 152 153 s, stopFakeServer, err := testonly.NewMockServer(ctrl) 154 if err != nil { 155 t.Fatalf("Error starting fake server: %v", err) 156 } 157 defer stopFakeServer() 158 defer flagsaver.Save().Restore() 159 *adminServerAddr = s.Addr 160 if tc.setFlags != nil { 161 tc.setFlags() 162 } 163 164 call := s.Admin.EXPECT().CreateTree(gomock.Any(), gomock.Any()).Return(tc.wantTree, tc.createErr) 165 expectCalls(call, tc.createErr, tc.validateErr) 166 switch *treeType { 167 case "LOG": 168 call := s.Log.EXPECT().InitLog(gomock.Any(), gomock.Any()).Return(&trillian.InitLogResponse{}, tc.initErr) 169 expectCalls(call, tc.initErr, tc.validateErr, tc.createErr) 170 call = s.Log.EXPECT().GetLatestSignedLogRoot(gomock.Any(), gomock.Any()).Return(&trillian.GetLatestSignedLogRootResponse{}, nil) 171 expectCalls(call, nil, tc.validateErr, tc.createErr, tc.initErr) 172 case "MAP": 173 call := s.Map.EXPECT().InitMap(gomock.Any(), gomock.Any()).Return(&trillian.InitMapResponse{}, tc.initErr) 174 expectCalls(call, tc.initErr, tc.validateErr, tc.createErr) 175 call = s.Map.EXPECT().GetSignedMapRootByRevision(gomock.Any(), gomock.Any()).Return(&trillian.GetSignedMapRootResponse{}, nil) 176 expectCalls(call, nil, tc.validateErr, tc.createErr, tc.initErr) 177 } 178 179 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) 180 defer cancel() 181 _, err = createTree(ctx) 182 if hasErr := err != nil; hasErr != tc.wantErr { 183 t.Errorf("createTree() '%v', wantErr = %v", err, tc.wantErr) 184 } 185 }) 186 } 187 } 188 189 // expectCalls returns the minimum number of times a function is expected to be called 190 // given the return error for the function (err), and all previous errors in the function's 191 // code path. 192 func expectCalls(call *gomock.Call, err error, prevErr ...error) *gomock.Call { 193 // If a function prior to this function errored, 194 // we do not expect this function to be called. 195 for _, e := range prevErr { 196 if e != nil { 197 return call.Times(0) 198 } 199 } 200 // If this function errors, it will be retried multiple times. 201 if err != nil { 202 return call.MinTimes(2) 203 } 204 // If this function succeeds it should only be called once. 205 return call.Times(1) 206 } 207 208 // resetFlags sets all flags to their default values. 209 func resetFlags() { 210 flag.Visit(func(f *flag.Flag) { 211 f.Value.Set(f.DefValue) 212 }) 213 }