github.com/letsencrypt/trillian@v1.1.2-0.20180615153820-ae375a99d36a/cmd/updatetree/main_test.go (about) 1 // Copyright 2018 Google Inc. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package main 16 17 import ( 18 "context" 19 "errors" 20 "flag" 21 "testing" 22 "time" 23 24 "github.com/golang/mock/gomock" 25 "github.com/google/trillian" 26 "github.com/google/trillian/testonly" 27 "github.com/google/trillian/util/flagsaver" 28 ) 29 30 type testCase struct { 31 desc string 32 setFlags func() 33 updateErr error 34 wantRPC bool 35 updateTree *trillian.Tree 36 wantErr bool 37 wantState trillian.TreeState 38 } 39 40 func TestFreezeTree(t *testing.T) { 41 runTest(t, []*testCase{ 42 { 43 // We don't set the treeID in runTest so this should fail. 44 desc: "missingTreeID", 45 wantErr: true, 46 }, 47 { 48 desc: "mandatoryOptsNotSet", 49 // Undo the flags set by runTest, so that mandatory options are no longer set. 
50 setFlags: resetFlags, 51 wantErr: true, 52 }, 53 { 54 desc: "validUpdateFrozen", 55 setFlags: func() { 56 *treeID = 12345 57 *treeState = "FROZEN" 58 }, 59 wantRPC: true, 60 updateTree: &trillian.Tree{ 61 TreeId: 12345, 62 TreeState: trillian.TreeState_FROZEN, 63 }, 64 wantState: trillian.TreeState_FROZEN, 65 }, 66 { 67 desc: "updateInvalidState", 68 setFlags: func() { 69 *treeID = 12345 70 *treeState = "ITSCOLDOUTSIDE" 71 }, 72 wantErr: true, 73 }, 74 { 75 desc: "unknownTree", 76 setFlags: func() { 77 *treeID = 123456 78 *treeState = "FROZEN" 79 }, 80 wantErr: true, 81 wantRPC: true, 82 updateErr: errors.New("unknown tree id"), 83 }, 84 { 85 desc: "emptyAddr", 86 setFlags: func() { 87 *adminServerAddr = "" 88 *treeID = 12345 89 *treeState = "FROZEN" 90 }, 91 wantErr: true, 92 }, 93 { 94 desc: "updateErr", 95 setFlags: func() { 96 *treeID = 12345 97 *treeState = "FROZEN" 98 }, 99 wantRPC: true, 100 updateErr: errors.New("update tree failed"), 101 wantErr: true, 102 }, 103 }) 104 } 105 106 // runTest executes the updateTree command against a fake TrillianAdminServer 107 // for each of the provided tests, and checks that the tree in the request is 108 // as expected, or an expected error occurs. 109 // Prior to each test case, it: 110 // 1. Resets all flags to their original values. 111 // 2. Sets the adminServerAddr flag to point to the fake server. 112 // 3. Calls the test's setFlags func (if provided) to allow it to change flags specific to the test. 
113 func runTest(t *testing.T, tests []*testCase) { 114 for _, tc := range tests { 115 t.Run(tc.desc, func(t *testing.T) { 116 ctrl := gomock.NewController(t) 117 defer ctrl.Finish() 118 119 s, stopFakeServer, err := testonly.NewMockServer(ctrl) 120 if err != nil { 121 t.Fatalf("Error starting fake server: %v", err) 122 } 123 defer stopFakeServer() 124 defer flagsaver.Save().Restore() 125 *adminServerAddr = s.Addr 126 if tc.setFlags != nil { 127 tc.setFlags() 128 } 129 130 // We might not get as far as updating the tree on the admin server. 131 if tc.wantRPC { 132 call := s.Admin.EXPECT().UpdateTree(gomock.Any(), gomock.Any()).Return(tc.updateTree, tc.updateErr) 133 expectCalls(call, tc.updateErr) 134 } 135 136 ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) 137 defer cancel() 138 tree, err := updateTree(ctx) 139 if hasErr := err != nil; hasErr != tc.wantErr { 140 t.Errorf("updateTree() returned err = '%v', wantErr = %v", err, tc.wantErr) 141 return 142 } 143 144 if err == nil { 145 if got, want := tree.TreeState.String(), tc.wantState.String(); got != want { 146 t.Errorf("updated state incorrect got: %v want: %v", got, want) 147 } 148 } 149 }) 150 } 151 } 152 153 // expectCalls returns the minimum number of times a function is expected to be called 154 // given the return error for the function (err), and all previous errors in the function's 155 // code path. 156 func expectCalls(call *gomock.Call, err error, prevErr ...error) *gomock.Call { 157 // If a function prior to this function errored, 158 // we do not expect this function to be called. 159 for _, e := range prevErr { 160 if e != nil { 161 return call.Times(0) 162 } 163 } 164 // If this function errors, it might be retried multiple times. 165 if err != nil { 166 return call.MinTimes(1) 167 } 168 // If this function succeeds it should only be called once. 169 return call.Times(1) 170 } 171 172 // resetFlags sets all flags to their default values. 
func resetFlags() {
	// flag.Visit only reaches flags that were set through the flag package
	// (command line or flag.Set). Tests in this file also change flags by
	// assigning through their pointers, which flag.Visit does not see, so
	// walk every registered flag instead.
	flag.VisitAll(func(f *flag.Flag) {
		// Best effort: a flag whose default string is rejected by its own Set
		// is left unchanged rather than aborting the reset.
		_ = f.Value.Set(f.DefValue)
	})
}