github.com/pf-qiu/concourse/v6@v6.7.3-0.20201207032516-1f455d73275f/atc/worker/transport/baggageclaim_round_tripper_test.go (about)

     1  package transport_test
     2  
     3  import (
     4  	"errors"
     5  	"net/http"
     6  	"net/url"
     7  	"time"
     8  
     9  	"github.com/pf-qiu/concourse/v6/atc/db/dbfakes"
    10  	"github.com/pf-qiu/concourse/v6/atc/worker/transport"
    11  	"github.com/pf-qiu/concourse/v6/atc/worker/transport/transportfakes"
    12  	"github.com/concourse/retryhttp/retryhttpfakes"
    13  
    14  	"github.com/pf-qiu/concourse/v6/atc/db"
    15  	. "github.com/onsi/ginkgo"
    16  	. "github.com/onsi/gomega"
    17  )
    18  
    19  var _ = Describe("BaggageclaimRoundTripper #RoundTrip", func() {
    20  	var (
    21  		request          http.Request
    22  		fakeDB           *transportfakes.FakeTransportDB
    23  		fakeRoundTripper *retryhttpfakes.FakeRoundTripper
    24  		roundTripper     http.RoundTripper
    25  		response         *http.Response
    26  		err              error
    27  	)
    28  
    29  	BeforeEach(func() {
    30  		fakeDB = new(transportfakes.FakeTransportDB)
    31  		fakeRoundTripper = new(retryhttpfakes.FakeRoundTripper)
    32  		workerBaggageClaimURL := "http://1.2.3.4:7878"
    33  		roundTripper = transport.NewBaggageclaimRoundTripper("some-worker", &workerBaggageClaimURL, fakeDB, fakeRoundTripper)
    34  		requestUrl, err := url.Parse("/something")
    35  		Expect(err).NotTo(HaveOccurred())
    36  
    37  		request = http.Request{
    38  			URL: requestUrl,
    39  		}
    40  
    41  		fakeRoundTripper.RoundTripReturns(&http.Response{StatusCode: http.StatusTeapot}, nil)
    42  	})
    43  
    44  	JustBeforeEach(func() {
    45  		response, err = roundTripper.RoundTrip(&request)
    46  	})
    47  
    48  	It("returns the response", func() {
    49  		Expect(err).NotTo(HaveOccurred())
    50  		Expect(response).To(Equal(&http.Response{StatusCode: http.StatusTeapot}))
    51  	})
    52  
    53  	It("sends the request with worker's baggageclaim url", func() {
    54  		Expect(fakeRoundTripper.RoundTripCallCount()).To(Equal(1))
    55  		actualRequest := fakeRoundTripper.RoundTripArgsForCall(0)
    56  		Expect(actualRequest.URL.Scheme).To(Equal("http"))
    57  		Expect(actualRequest.URL.Host).To(Equal("1.2.3.4:7878"))
    58  		Expect(actualRequest.URL.Path).To(Equal("/something"))
    59  	})
    60  
    61  	It("reuses the request cached host on subsequent calls", func() {
    62  		Expect(fakeDB.GetWorkerCallCount()).To(Equal(0))
    63  		_, err := roundTripper.RoundTrip(&request)
    64  		Expect(err).NotTo(HaveOccurred())
    65  		Expect(fakeDB.GetWorkerCallCount()).To(Equal(0))
    66  	})
    67  
    68  	Context("when inner roundtrip fails", func() {
    69  		BeforeEach(func() {
    70  			fakeRoundTripper.RoundTripReturns(nil, errors.New("some-error"))
    71  
    72  			bcURL := "http://5.6.7.8:7878"
    73  			savedWorker := new(dbfakes.FakeWorker)
    74  			savedWorker.BaggageclaimURLReturns(&bcURL)
    75  			savedWorker.ExpiresAtReturns(time.Now().Add(123 * time.Minute))
    76  			savedWorker.StateReturns(db.WorkerStateRunning)
    77  
    78  			fakeDB.GetWorkerReturns(savedWorker, true, nil)
    79  		})
    80  
    81  		It("updates cached request host on subsequent call", func() {
    82  			Expect(err).To(HaveOccurred())
    83  			Expect(err.Error()).To(ContainSubstring("some-error"))
    84  
    85  			Expect(fakeRoundTripper.RoundTripCallCount()).To(Equal(1))
    86  			actualRequest := fakeRoundTripper.RoundTripArgsForCall(0)
    87  			Expect(actualRequest.URL.Host).To(Equal("1.2.3.4:7878"))
    88  			Expect(fakeDB.GetWorkerCallCount()).To(Equal(0))
    89  
    90  			_, err := roundTripper.RoundTrip(&request)
    91  			Expect(err).To(HaveOccurred())
    92  
    93  			Expect(fakeDB.GetWorkerCallCount()).To(Equal(1))
    94  			Expect(fakeRoundTripper.RoundTripCallCount()).To(Equal(2))
    95  			actualRequest = fakeRoundTripper.RoundTripArgsForCall(1)
    96  			Expect(actualRequest.URL.Host).To(Equal("5.6.7.8:7878"))
    97  		})
    98  
    99  		Context("when the lookup of the worker in the db errors", func() {
   100  			var expectedErr error
   101  			BeforeEach(func() {
   102  				expectedErr = errors.New("some-db-error")
   103  				fakeDB.GetWorkerReturns(nil, true, expectedErr)
   104  			})
   105  
   106  			It("throws an error", func() {
   107  				_, err := roundTripper.RoundTrip(&request)
   108  				Expect(err).To(HaveOccurred())
   109  				Expect(err.Error()).To(ContainSubstring(expectedErr.Error()))
   110  			})
   111  		})
   112  
   113  		Context("when the worker is in the DB and the baggageclaim URL is empty", func() {
   114  			BeforeEach(func() {
   115  
   116  				runningWorker := new(dbfakes.FakeWorker)
   117  				runningWorker.StateReturns(db.WorkerStateStalled)
   118  				runningWorker.BaggageclaimURLReturns(nil)
   119  
   120  				fakeDB.GetWorkerReturns(runningWorker, true, nil)
   121  			})
   122  
   123  			It("throws a descriptive error", func() {
   124  				_, err := roundTripper.RoundTrip(&request)
   125  				Expect(err).To(MatchError("worker 'some-worker' is unreachable (state is 'stalled')"))
   126  			})
   127  		})
   128  
   129  		Context("when the worker is not found in the db", func() {
   130  			BeforeEach(func() {
   131  				fakeDB.GetWorkerReturns(nil, false, nil)
   132  			})
   133  
   134  			It("throws an error", func() {
   135  				_, err := roundTripper.RoundTrip(&request)
   136  				Expect(err).To(HaveOccurred())
   137  				Expect(err).To(Equal(transport.WorkerMissingError{WorkerName: "some-worker"}))
   138  			})
   139  		})
   140  	})
   141  })