github.com/blockchain-gm/fabric-ca@v0.0.0-20200423072702-b2c40c7ac69c/lib/server_whitebox_test.go (about) 1 /* 2 Copyright IBM Corp. All Rights Reserved. 3 4 SPDX-License-Identifier: Apache-2.0 5 */ 6 7 package lib 8 9 import ( 10 "context" 11 "net/http" 12 "net/http/httptest" 13 "os" 14 "testing" 15 16 "github.com/cloudflare/cfssl/log" 17 "github.com/gorilla/mux" 18 cadb "github.com/hyperledger/fabric-ca/lib/server/db" 19 "github.com/hyperledger/fabric-ca/lib/server/metrics" 20 "github.com/hyperledger/fabric-ca/util" 21 "github.com/hyperledger/fabric/common/metrics/metricsfakes" 22 "github.com/jmoiron/sqlx" 23 . "github.com/onsi/gomega" 24 "github.com/stretchr/testify/assert" 25 ) 26 27 const ( 28 serverPort = 7060 29 affiliationName = "org1" 30 ) 31 32 // TestGetAffliation checks that looking up the affiliation 33 // 'org1.department1' returns exactly one database record after the server 34 // has been started two times. This test makes sure the server does not create 35 // duplicate affiliations in the database every time it is 36 // started.
37 func TestGetAffliation(t *testing.T) { 38 defer func() { 39 err := os.RemoveAll("../testdata/ca-cert.pem") 40 if err != nil { 41 t.Errorf("RemoveAll failed: %s", err) 42 } 43 err = os.RemoveAll("../testdata/fabric-ca-server.db") 44 if err != nil { 45 t.Errorf("RemoveAll failed: %s", err) 46 } 47 err = os.RemoveAll("../testdata/msp") 48 if err != nil { 49 t.Errorf("RemoveAll failed: %s", err) 50 } 51 }() 52 // Start the server at an available port (using port 0 will make OS to 53 // pick an available port) 54 srv := getServer(serverPort, testdataDir, "", -1, t) 55 56 err := srv.Start() 57 if err != nil { 58 t.Fatalf("Server start failed: %v", err) 59 } 60 err = srv.Stop() 61 if err != nil { 62 t.Fatalf("Server stop failed: %v", err) 63 } 64 65 err = srv.Start() 66 if err != nil { 67 t.Fatalf("Server start failed: %v", err) 68 } 69 defer func() { 70 err = srv.Stop() 71 if err != nil { 72 t.Errorf("Failed to stop server: %s", err) 73 } 74 }() 75 76 name := "org1.department1" 77 rows, err := srv.CA.registry.GetAllAffiliations(name) 78 if err != nil { 79 t.Fatalf("Failed to get affiliation %s: %v", affiliationName, err) 80 } 81 var count int 82 for rows.Next() { 83 count++ 84 } 85 if count != 1 { 86 t.Fatalf("Found 0 or more than one record for the affiliation %s in the database, expected 1 record", affiliationName) 87 } 88 } 89 90 func TestServerLogLevel(t *testing.T) { 91 var err error 92 93 srv := TestGetRootServer(t) 94 srv.Config.Debug = false 95 srv.Config.LogLevel = "info" 96 err = srv.Init(false) 97 util.FatalError(t, err, "Failed to init server with 'info' log level") 98 assert.Equal(t, log.Level, log.LevelInfo) 99 100 srv.Config.LogLevel = "Debug" 101 err = srv.Init(false) 102 util.FatalError(t, err, "Failed to init server 'debug' log level") 103 assert.Equal(t, log.Level, log.LevelDebug) 104 105 srv.Config.LogLevel = "warning" 106 err = srv.Init(false) 107 util.FatalError(t, err, "Failed to init server with 'warning' log level") 108 assert.Equal(t, 
log.Level, log.LevelWarning) 109 110 srv.Config.LogLevel = "critical" 111 err = srv.Init(false) 112 util.FatalError(t, err, "Failed to init server with 'critical' log level") 113 assert.Equal(t, log.Level, log.LevelCritical) 114 115 srv.Config.LogLevel = "fatal" 116 err = srv.Init(false) 117 util.FatalError(t, err, "Failed to init server with 'fatal' log level") 118 assert.Equal(t, log.Level, log.LevelFatal) 119 120 srv.Config.Debug = true 121 err = srv.Init(false) 122 assert.Error(t, err, "Should fail, can't specify a log level and set debug true at same time") 123 } 124 125 func TestServerMetrics(t *testing.T) { 126 gt := NewGomegaWithT(t) 127 128 se := &serverEndpoint{ 129 Path: "/test", 130 } 131 132 router := mux.NewRouter() 133 router.Handle(se.Path, se).Name(se.Path) 134 135 fakeCounter := &metricsfakes.Counter{} 136 fakeCounter.WithReturns(fakeCounter) 137 fakeHist := &metricsfakes.Histogram{} 138 fakeHist.WithReturns(fakeHist) 139 server := &Server{ 140 CA: CA{ 141 Config: &CAConfig{ 142 CA: CAInfo{ 143 Name: "ca1", 144 }, 145 }, 146 }, 147 Metrics: metrics.Metrics{ 148 APICounter: fakeCounter, 149 APIDuration: fakeHist, 150 }, 151 mux: router, 152 } 153 154 server.mux.Use(server.middleware) 155 se.Server = server 156 157 req, err := http.NewRequest("GET", "/test", nil) 158 gt.Expect(err).NotTo(HaveOccurred()) 159 160 rr := httptest.NewRecorder() 161 router.ServeHTTP(rr, req) 162 gt.Expect(fakeCounter.AddCallCount()).To(Equal(1)) 163 gt.Expect(fakeCounter.WithArgsForCall(0)).NotTo(BeZero()) 164 gt.Expect(fakeCounter.WithArgsForCall(0)).To(Equal([]string{"ca_name", "ca1", "api_name", "/test", "status_code", "405"})) 165 166 gt.Expect(fakeHist.ObserveCallCount()).To(Equal(1)) 167 gt.Expect(fakeHist.WithArgsForCall(0)).NotTo(BeZero()) 168 gt.Expect(fakeHist.WithArgsForCall(0)).To(Equal([]string{"ca_name", "ca1", "api_name", "/test", "status_code", "405"})) 169 } 170 171 func TestServerHealthCheck(t *testing.T) { 172 srv := TestGetRootServer(t) 173 174 
os.Mkdir("./.tmpDir", 0755) 175 176 dataSource := "./.tmpDir/sqlite.db" 177 srv.CA.Config.DB.Datasource = dataSource 178 defer os.RemoveAll("./.tmpDir") 179 180 db, err := sqlx.Open("sqlite3", dataSource) 181 assert.NoError(t, err) 182 183 srv.CA.db = &cadb.DB{DB: db, IsDBInitialized: false} 184 185 err = srv.HealthCheck(context.Background()) 186 assert.NoError(t, err) 187 188 err = srv.db.Close() 189 assert.NoError(t, err) 190 191 err = srv.HealthCheck(context.Background()) 192 assert.EqualError(t, err, "sql: database is closed") 193 } 194 195 func TestCORS(t *testing.T) { 196 tests := []struct { 197 cors CORS 198 origin string 199 expectHeader bool 200 }{ 201 { 202 cors: CORS{ 203 Enabled: false, 204 }, 205 origin: "badorigin.com", 206 expectHeader: false, 207 }, 208 { 209 cors: CORS{ 210 Enabled: true, 211 Origins: []string{"goodorigin.com"}, 212 }, 213 origin: "goodorigin.com", 214 expectHeader: true, 215 }, 216 { 217 cors: CORS{ 218 Enabled: true, 219 Origins: []string{"goodorigin.com"}, 220 }, 221 origin: "badorigin.com", 222 expectHeader: false, 223 }, 224 } 225 226 for _, test := range tests { 227 _test := test 228 t.Run("", func(t *testing.T) { 229 s := &Server{ 230 Config: &ServerConfig{ 231 CORS: _test.cors, 232 }, 233 } 234 handler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { 235 rw.WriteHeader(http.StatusOK) 236 }) 237 req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) 238 req.Header.Set("Origin", _test.origin) 239 rw := httptest.NewRecorder() 240 s.cors(handler).ServeHTTP(rw, req) 241 res := rw.Result() 242 for k, v := range res.Header { 243 t.Logf("%s : %s", k, v) 244 } 245 _, ok := res.Header["Access-Control-Allow-Origin"] 246 assert.Equal(t, _test.expectHeader, ok) 247 }) 248 } 249 250 }