github.com/letsencrypt/trillian@v1.1.2-0.20180615153820-ae375a99d36a/cmd/updatetree/main_test.go (about)

     1  // Copyright 2018 Google Inc. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package main
    16  
    17  import (
    18  	"context"
    19  	"errors"
    20  	"flag"
    21  	"testing"
    22  	"time"
    23  
    24  	"github.com/golang/mock/gomock"
    25  	"github.com/google/trillian"
    26  	"github.com/google/trillian/testonly"
    27  	"github.com/google/trillian/util/flagsaver"
    28  )
    29  
    30  type testCase struct {
    31  	desc       string
    32  	setFlags   func()
    33  	updateErr  error
    34  	wantRPC    bool
    35  	updateTree *trillian.Tree
    36  	wantErr    bool
    37  	wantState  trillian.TreeState
    38  }
    39  
    40  func TestFreezeTree(t *testing.T) {
    41  	runTest(t, []*testCase{
    42  		{
    43  			// We don't set the treeID in runTest so this should fail.
    44  			desc:    "missingTreeID",
    45  			wantErr: true,
    46  		},
    47  		{
    48  			desc: "mandatoryOptsNotSet",
    49  			// Undo the flags set by runTest, so that mandatory options are no longer set.
    50  			setFlags: resetFlags,
    51  			wantErr:  true,
    52  		},
    53  		{
    54  			desc: "validUpdateFrozen",
    55  			setFlags: func() {
    56  				*treeID = 12345
    57  				*treeState = "FROZEN"
    58  			},
    59  			wantRPC: true,
    60  			updateTree: &trillian.Tree{
    61  				TreeId:    12345,
    62  				TreeState: trillian.TreeState_FROZEN,
    63  			},
    64  			wantState: trillian.TreeState_FROZEN,
    65  		},
    66  		{
    67  			desc: "updateInvalidState",
    68  			setFlags: func() {
    69  				*treeID = 12345
    70  				*treeState = "ITSCOLDOUTSIDE"
    71  			},
    72  			wantErr: true,
    73  		},
    74  		{
    75  			desc: "unknownTree",
    76  			setFlags: func() {
    77  				*treeID = 123456
    78  				*treeState = "FROZEN"
    79  			},
    80  			wantErr:   true,
    81  			wantRPC:   true,
    82  			updateErr: errors.New("unknown tree id"),
    83  		},
    84  		{
    85  			desc: "emptyAddr",
    86  			setFlags: func() {
    87  				*adminServerAddr = ""
    88  				*treeID = 12345
    89  				*treeState = "FROZEN"
    90  			},
    91  			wantErr: true,
    92  		},
    93  		{
    94  			desc: "updateErr",
    95  			setFlags: func() {
    96  				*treeID = 12345
    97  				*treeState = "FROZEN"
    98  			},
    99  			wantRPC:   true,
   100  			updateErr: errors.New("update tree failed"),
   101  			wantErr:   true,
   102  		},
   103  	})
   104  }
   105  
   106  // runTest executes the updateTree command against a fake TrillianAdminServer
   107  // for each of the provided tests, and checks that the tree in the request is
   108  // as expected, or an expected error occurs.
   109  // Prior to each test case, it:
   110  // 1. Resets all flags to their original values.
   111  // 2. Sets the adminServerAddr flag to point to the fake server.
   112  // 3. Calls the test's setFlags func (if provided) to allow it to change flags specific to the test.
   113  func runTest(t *testing.T, tests []*testCase) {
   114  	for _, tc := range tests {
   115  		t.Run(tc.desc, func(t *testing.T) {
   116  			ctrl := gomock.NewController(t)
   117  			defer ctrl.Finish()
   118  
   119  			s, stopFakeServer, err := testonly.NewMockServer(ctrl)
   120  			if err != nil {
   121  				t.Fatalf("Error starting fake server: %v", err)
   122  			}
   123  			defer stopFakeServer()
   124  			defer flagsaver.Save().Restore()
   125  			*adminServerAddr = s.Addr
   126  			if tc.setFlags != nil {
   127  				tc.setFlags()
   128  			}
   129  
   130  			// We might not get as far as updating the tree on the admin server.
   131  			if tc.wantRPC {
   132  				call := s.Admin.EXPECT().UpdateTree(gomock.Any(), gomock.Any()).Return(tc.updateTree, tc.updateErr)
   133  				expectCalls(call, tc.updateErr)
   134  			}
   135  
   136  			ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
   137  			defer cancel()
   138  			tree, err := updateTree(ctx)
   139  			if hasErr := err != nil; hasErr != tc.wantErr {
   140  				t.Errorf("updateTree() returned err = '%v', wantErr = %v", err, tc.wantErr)
   141  				return
   142  			}
   143  
   144  			if err == nil {
   145  				if got, want := tree.TreeState.String(), tc.wantState.String(); got != want {
   146  					t.Errorf("updated state incorrect got: %v want: %v", got, want)
   147  				}
   148  			}
   149  		})
   150  	}
   151  }
   152  
   153  // expectCalls returns the minimum number of times a function is expected to be called
   154  // given the return error for the function (err), and all previous errors in the function's
   155  // code path.
   156  func expectCalls(call *gomock.Call, err error, prevErr ...error) *gomock.Call {
   157  	// If a function prior to this function errored,
   158  	// we do not expect this function to be called.
   159  	for _, e := range prevErr {
   160  		if e != nil {
   161  			return call.Times(0)
   162  		}
   163  	}
   164  	// If this function errors, it might be retried multiple times.
   165  	if err != nil {
   166  		return call.MinTimes(1)
   167  	}
   168  	// If this function succeeds it should only be called once.
   169  	return call.Times(1)
   170  }
   171  
   172  // resetFlags sets all flags to their default values.
   173  func resetFlags() {
   174  	flag.Visit(func(f *flag.Flag) {
   175  		f.Value.Set(f.DefValue)
   176  	})
   177  }