github.com/hechain20/hechain@v0.0.0-20220316014945-b544036ba106/common/fabhttp/server_test.go (about)

     1  /*
     2  Copyright hechain All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package fabhttp_test
     8  
     9  import (
    10  	"crypto/tls"
    11  	"fmt"
    12  	"io/ioutil"
    13  	"net"
    14  	"net/http"
    15  	"os"
    16  	"path/filepath"
    17  	"syscall"
    18  
    19  	"github.com/hechain20/hechain/common/fabhttp"
    20  	"github.com/hechain20/hechain/core/operations/fakes"
    21  	. "github.com/onsi/ginkgo"
    22  	. "github.com/onsi/gomega"
    23  	"github.com/tedsuo/ifrit"
    24  )
    25  
    26  var _ = Describe("Server", func() {
    27  	const AdditionalTestApiPath = "/some-additional-test-api"
    28  
    29  	var (
    30  		fakeLogger *fakes.Logger
    31  		tempDir    string
    32  
    33  		client       *http.Client
    34  		unauthClient *http.Client
    35  		options      fabhttp.Options
    36  		server       *fabhttp.Server
    37  	)
    38  
    39  	BeforeEach(func() {
    40  		var err error
    41  		tempDir, err = ioutil.TempDir("", "fabhttp-test")
    42  		Expect(err).NotTo(HaveOccurred())
    43  
    44  		generateCertificates(tempDir)
    45  		client = newHTTPClient(tempDir, true)
    46  		unauthClient = newHTTPClient(tempDir, false)
    47  
    48  		fakeLogger = &fakes.Logger{}
    49  		options = fabhttp.Options{
    50  			Logger:        fakeLogger,
    51  			ListenAddress: "127.0.0.1:0",
    52  			TLS: fabhttp.TLS{
    53  				Enabled:            true,
    54  				CertFile:           filepath.Join(tempDir, "server-cert.pem"),
    55  				KeyFile:            filepath.Join(tempDir, "server-key.pem"),
    56  				ClientCertRequired: false,
    57  				ClientCACertFiles:  []string{filepath.Join(tempDir, "client-ca.pem")},
    58  			},
    59  		}
    60  
    61  		server = fabhttp.NewServer(options)
    62  	})
    63  
    64  	AfterEach(func() {
    65  		os.RemoveAll(tempDir)
    66  		if server != nil {
    67  			server.Stop()
    68  		}
    69  	})
    70  
    71  	When("trying to connect with an old TLS version", func() {
    72  		BeforeEach(func() {
    73  			tlsOpts := []func(config *tls.Config){func(config *tls.Config) {
    74  				config.MaxVersion = tls.VersionTLS11
    75  				config.ClientAuth = tls.RequireAndVerifyClientCert
    76  			}}
    77  
    78  			client = newHTTPClient(tempDir, true, tlsOpts...)
    79  		})
    80  
    81  		It("does not answer clients using an older TLS version than 1.2", func() {
    82  			server.RegisterHandler(AdditionalTestApiPath, &fakes.Handler{Code: http.StatusOK, Text: "secure"}, options.TLS.Enabled)
    83  			err := server.Start()
    84  			Expect(err).NotTo(HaveOccurred())
    85  
    86  			addApiURL := fmt.Sprintf("https://%s%s", server.Addr(), AdditionalTestApiPath)
    87  			_, err = client.Get(addApiURL)
    88  			Expect(err.Error()).To(ContainSubstring("tls: protocol version not supported"))
    89  		})
    90  	})
    91  
    92  	It("does not host a secure endpoint for additional APIs by default", func() {
    93  		err := server.Start()
    94  		Expect(err).NotTo(HaveOccurred())
    95  
    96  		addApiURL := fmt.Sprintf("https://%s%s", server.Addr(), AdditionalTestApiPath)
    97  		resp, err := client.Get(addApiURL)
    98  		Expect(err).NotTo(HaveOccurred())
    99  		Expect(resp.StatusCode).To(Equal(http.StatusNotFound)) // service is not handled by default, i.e. in peer
   100  		resp.Body.Close()
   101  
   102  		resp, err = unauthClient.Get(addApiURL)
   103  		Expect(err).NotTo(HaveOccurred())
   104  		Expect(resp.StatusCode).To(Equal(http.StatusNotFound))
   105  	})
   106  
   107  	It("hosts a secure endpoint for additional APIs when added", func() {
   108  		server.RegisterHandler(AdditionalTestApiPath, &fakes.Handler{Code: http.StatusOK, Text: "secure"}, options.TLS.Enabled)
   109  		err := server.Start()
   110  		Expect(err).NotTo(HaveOccurred())
   111  
   112  		addApiURL := fmt.Sprintf("https://%s%s", server.Addr(), AdditionalTestApiPath)
   113  		resp, err := client.Get(addApiURL)
   114  		Expect(err).NotTo(HaveOccurred())
   115  		Expect(resp.StatusCode).To(Equal(http.StatusOK))
   116  		Expect(resp.Header.Get("Content-Type")).To(Equal("text/plain; charset=utf-8"))
   117  		buff, err := ioutil.ReadAll(resp.Body)
   118  		Expect(err).NotTo(HaveOccurred())
   119  		Expect(string(buff)).To(Equal("secure"))
   120  		resp.Body.Close()
   121  
   122  		resp, err = unauthClient.Get(addApiURL)
   123  		Expect(err).NotTo(HaveOccurred())
   124  		Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized))
   125  	})
   126  
   127  	Context("when TLS is disabled", func() {
   128  		BeforeEach(func() {
   129  			options.TLS.Enabled = false
   130  			server = fabhttp.NewServer(options)
   131  		})
   132  
   133  		It("does not host an insecure endpoint for additional APIs by default", func() {
   134  			err := server.Start()
   135  			Expect(err).NotTo(HaveOccurred())
   136  
   137  			addApiURL := fmt.Sprintf("http://%s%s", server.Addr(), AdditionalTestApiPath)
   138  			resp, err := client.Get(addApiURL)
   139  			Expect(err).NotTo(HaveOccurred())
   140  			Expect(resp.StatusCode).To(Equal(http.StatusNotFound)) // service is not handled by default, i.e. in peer
   141  			resp.Body.Close()
   142  		})
   143  
   144  		It("hosts an insecure endpoint for additional APIs when added", func() {
   145  			server.RegisterHandler(AdditionalTestApiPath, &fakes.Handler{Code: http.StatusOK, Text: "insecure"}, options.TLS.Enabled)
   146  			err := server.Start()
   147  			Expect(err).NotTo(HaveOccurred())
   148  
   149  			addApiURL := fmt.Sprintf("http://%s%s", server.Addr(), AdditionalTestApiPath)
   150  			resp, err := client.Get(addApiURL)
   151  			Expect(err).NotTo(HaveOccurred())
   152  			Expect(resp.StatusCode).To(Equal(http.StatusOK))
   153  			Expect(resp.Header.Get("Content-Type")).To(Equal("text/plain; charset=utf-8"))
   154  			buff, err := ioutil.ReadAll(resp.Body)
   155  			Expect(err).NotTo(HaveOccurred())
   156  			Expect(string(buff)).To(Equal("insecure"))
   157  			resp.Body.Close()
   158  		})
   159  	})
   160  
   161  	Context("when ClientCertRequired is true", func() {
   162  		BeforeEach(func() {
   163  			options.TLS.ClientCertRequired = true
   164  			server = fabhttp.NewServer(options)
   165  		})
   166  
   167  		It("requires a client cert to connect", func() {
   168  			err := server.Start()
   169  			Expect(err).NotTo(HaveOccurred())
   170  
   171  			_, err = unauthClient.Get(fmt.Sprintf("https://%s/healthz", server.Addr()))
   172  			Expect(err).To(MatchError(ContainSubstring("remote error: tls: bad certificate")))
   173  		})
   174  	})
   175  
   176  	Context("when listen fails", func() {
   177  		var listener net.Listener
   178  
   179  		BeforeEach(func() {
   180  			var err error
   181  			listener, err = net.Listen("tcp", "127.0.0.1:0")
   182  			Expect(err).NotTo(HaveOccurred())
   183  
   184  			options.ListenAddress = listener.Addr().String()
   185  			server = fabhttp.NewServer(options)
   186  		})
   187  
   188  		AfterEach(func() {
   189  			listener.Close()
   190  		})
   191  
   192  		It("returns an error", func() {
   193  			err := server.Start()
   194  			Expect(err).To(MatchError(ContainSubstring("bind: address already in use")))
   195  		})
   196  	})
   197  
   198  	Context("when a bad TLS configuration is provided", func() {
   199  		BeforeEach(func() {
   200  			options.TLS.CertFile = "cert-file-does-not-exist"
   201  			server = fabhttp.NewServer(options)
   202  		})
   203  
   204  		It("returns an error", func() {
   205  			err := server.Start()
   206  			Expect(err).To(MatchError("open cert-file-does-not-exist: no such file or directory"))
   207  		})
   208  	})
   209  
   210  	It("proxies Log to the provided logger", func() {
   211  		err := server.Log("key", "value")
   212  		Expect(err).NotTo(HaveOccurred())
   213  
   214  		Expect(fakeLogger.WarnCallCount()).To(Equal(1))
   215  		Expect(fakeLogger.WarnArgsForCall(0)).To(Equal([]interface{}{"key", "value"}))
   216  	})
   217  
   218  	Context("when a logger is not provided", func() {
   219  		BeforeEach(func() {
   220  			options.Logger = nil
   221  			server = fabhttp.NewServer(options)
   222  		})
   223  
   224  		It("does not panic when logging", func() {
   225  			Expect(func() { server.Log("key", "value") }).NotTo(Panic())
   226  		})
   227  
   228  		It("returns nil from Log", func() {
   229  			err := server.Log("key", "value")
   230  			Expect(err).NotTo(HaveOccurred())
   231  		})
   232  	})
   233  
   234  	It("supports ifrit", func() {
   235  		process := ifrit.Invoke(server)
   236  		Eventually(process.Ready()).Should(BeClosed())
   237  
   238  		process.Signal(syscall.SIGTERM)
   239  		Eventually(process.Wait()).Should(Receive(BeNil()))
   240  	})
   241  
   242  	Context("when start fails and ifrit is used", func() {
   243  		BeforeEach(func() {
   244  			options.TLS.CertFile = "non-existent-file"
   245  			server = fabhttp.NewServer(options)
   246  		})
   247  
   248  		It("does not close the ready chan", func() {
   249  			process := ifrit.Invoke(server)
   250  			Consistently(process.Ready()).ShouldNot(BeClosed())
   251  			Eventually(process.Wait()).Should(Receive(MatchError("open non-existent-file: no such file or directory")))
   252  		})
   253  	})
   254  })