github.com/openshift-online/ocm-sdk-go@v0.1.473/internal/client_selector_test.go (about)

     1  /*
     2  Copyright (c) 2021 Red Hat, Inc.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8    http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package internal
    18  
import (
	"context"
	"crypto/x509"
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"strings"

	. "github.com/onsi/ginkgo/v2/dsl/core" // nolint
	. "github.com/onsi/gomega"             // nolint
)
    30  
    31  var _ = Describe("Create client selector", func() {
    32  	It("Can't be created without a logger", func() {
    33  		selector, err := NewClientSelector().Build(context.Background())
    34  		Expect(err).To(HaveOccurred())
    35  		Expect(selector).To(BeNil())
    36  		message := err.Error()
    37  		Expect(message).To(ContainSubstring("logger"))
    38  		Expect(message).To(ContainSubstring("mandatory"))
    39  	})
    40  })
    41  
    42  var _ = Describe("Select client", func() {
    43  	var (
    44  		ctx      context.Context
    45  		selector *ClientSelector
    46  	)
    47  
    48  	BeforeEach(func() {
    49  		var err error
    50  
    51  		// Create a context:
    52  		ctx = context.Background()
    53  
    54  		// Create the selector:
    55  		selector, err = NewClientSelector().
    56  			Logger(logger).
    57  			Build(ctx)
    58  		Expect(err).ToNot(HaveOccurred())
    59  		Expect(selector).ToNot(BeNil())
    60  	})
    61  
    62  	AfterEach(func() {
    63  		// Close the selector:
    64  		err := selector.Close()
    65  		Expect(err).ToNot(HaveOccurred())
    66  	})
    67  
    68  	It("Reuses client for same TCP address", func() {
    69  		address, err := ParseServerAddress(ctx, "tcp://my.server.com")
    70  		Expect(err).ToNot(HaveOccurred())
    71  		firstClient, err := selector.Select(ctx, address)
    72  		Expect(err).ToNot(HaveOccurred())
    73  		secondClient, err := selector.Select(ctx, address)
    74  		Expect(err).ToNot(HaveOccurred())
    75  		Expect(secondClient).To(BeIdenticalTo(firstClient))
    76  	})
    77  
    78  	It("Doesn't reuse client for different TCP addresses", func() {
    79  		firstAddress, err := ParseServerAddress(ctx, "tcp://my.server.com")
    80  		Expect(err).ToNot(HaveOccurred())
    81  		secondAddress, err := ParseServerAddress(ctx, "tcp://your.server.com")
    82  		Expect(err).ToNot(HaveOccurred())
    83  		firstClient, err := selector.Select(ctx, firstAddress)
    84  		Expect(err).ToNot(HaveOccurred())
    85  		secondClient, err := selector.Select(ctx, secondAddress)
    86  		Expect(err).ToNot(HaveOccurred())
    87  		Expect(secondClient == firstClient).To(BeFalse())
    88  	})
    89  
    90  	It("Reuses client for different TCP protocols", func() {
    91  		firstAddress, err := ParseServerAddress(ctx, "http://my.server.com")
    92  		Expect(err).ToNot(HaveOccurred())
    93  		secondAddress, err := ParseServerAddress(ctx, "https://my.server.com")
    94  		Expect(err).ToNot(HaveOccurred())
    95  		firstClient, err := selector.Select(ctx, firstAddress)
    96  		Expect(err).ToNot(HaveOccurred())
    97  		secondClient, err := selector.Select(ctx, secondAddress)
    98  		Expect(err).ToNot(HaveOccurred())
    99  		Expect(secondClient == firstClient).To(BeTrue())
   100  	})
   101  
   102  	It("Doesn't resuse client for different Unix sockets", func() {
   103  		firstAddress, err := ParseServerAddress(ctx, "unix://my.server.com/my.socket")
   104  		Expect(err).ToNot(HaveOccurred())
   105  		secondAddress, err := ParseServerAddress(ctx, "unix://my.server.com/your.socket")
   106  		Expect(err).ToNot(HaveOccurred())
   107  		firstClient, err := selector.Select(ctx, firstAddress)
   108  		Expect(err).ToNot(HaveOccurred())
   109  		secondClient, err := selector.Select(ctx, secondAddress)
   110  		Expect(err).ToNot(HaveOccurred())
   111  		Expect(secondClient == firstClient).To(BeFalse())
   112  	})
   113  })
   114  
   115  var _ = Describe("Redirect Behavior", func() {
   116  	var (
   117  		ctx                  context.Context
   118  		selector             *ClientSelector
   119  		originServer         *httptest.Server
   120  		responseServer       *httptest.Server
   121  		expectedResponseBody string
   122  	)
   123  
   124  	BeforeEach(func() {
   125  		var err error
   126  
   127  		// Create a context:
   128  		ctx = context.Background()
   129  
   130  		expectedResponseBody = "myServerDotComRedirect"
   131  
   132  		responseServer = httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   133  			//nolint
   134  			fmt.Fprintf(w, expectedResponseBody)
   135  		}))
   136  
   137  		// simulate a redirect to a different domain by responding with a localhost url rather than a 127.0.0.1 url
   138  		redirectURL := strings.Replace(responseServer.URL, "127.0.0.1", "localhost", 1)
   139  		originServer = httptest.NewTLSServer(http.RedirectHandler(redirectURL, http.StatusMovedPermanently))
   140  
   141  		cas := x509.NewCertPool()
   142  		cas.AddCert(responseServer.Certificate())
   143  		cas.AddCert(originServer.Certificate())
   144  
   145  		// Create the selector:
   146  		selector, err = NewClientSelector().
   147  			TrustedCAs(cas).
   148  			Insecure(true). //need insecure when using "localhost" to connect or you get TLS verification errors
   149  			Logger(logger).
   150  			Build(ctx)
   151  		Expect(err).ToNot(HaveOccurred())
   152  		Expect(selector).ToNot(BeNil())
   153  	})
   154  
   155  	AfterEach(func() {
   156  		defer responseServer.Close()
   157  		defer originServer.Close()
   158  
   159  		// Close the selector:
   160  		err := selector.Close()
   161  		Expect(err).ToNot(HaveOccurred())
   162  	})
   163  
   164  	It("Doesn't re-use origin host for redirect", func() {
   165  		address, err := ParseServerAddress(ctx, originServer.URL)
   166  		Expect(err).ToNot(HaveOccurred())
   167  
   168  		client, err := selector.Select(ctx, address)
   169  		Expect(err).ToNot(HaveOccurred())
   170  
   171  		resp, err := client.Get(originServer.URL)
   172  		Expect(err).ToNot(HaveOccurred())
   173  		Expect(resp.TLS.ServerName).To(Equal("localhost"))
   174  
   175  		body := make([]byte, len(expectedResponseBody))
   176  		_, _ = resp.Body.Read(body)
   177  		Expect(string(body)).To(Equal("myServerDotComRedirect"))
   178  	})
   179  })