github.com/pf-qiu/concourse/v6@v6.7.3-0.20201207032516-1f455d73275f/atc/api/auth/check_admin_handler_test.go (about)

     1  package auth_test
     2  
     3  import (
     4  	"bytes"
     5  	"io"
     6  	"io/ioutil"
     7  	"net/http"
     8  	"net/http/httptest"
     9  
    10  	"github.com/pf-qiu/concourse/v6/atc/api/accessor"
    11  	"github.com/pf-qiu/concourse/v6/atc/api/accessor/accessorfakes"
    12  	"github.com/pf-qiu/concourse/v6/atc/api/auth"
    13  	"github.com/pf-qiu/concourse/v6/atc/api/auth/authfakes"
    14  	"github.com/pf-qiu/concourse/v6/atc/auditor/auditorfakes"
    15  
    16  	. "github.com/onsi/ginkgo"
    17  	. "github.com/onsi/gomega"
    18  )
    19  
    20  var _ = Describe("CheckAdminHandler", func() {
    21  	var (
    22  		fakeRejector *authfakes.FakeRejector
    23  		fakeAccessor *accessorfakes.FakeAccessFactory
    24  		fakeaccess   *accessorfakes.FakeAccess
    25  		server       *httptest.Server
    26  		client       *http.Client
    27  	)
    28  
    29  	simpleHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    30  		buffer := bytes.NewBufferString("simple ")
    31  
    32  		_, err := io.Copy(w, buffer)
    33  		Expect(err).ToNot(HaveOccurred())
    34  		_, err = io.Copy(w, r.Body)
    35  		Expect(err).ToNot(HaveOccurred())
    36  	})
    37  
    38  	BeforeEach(func() {
    39  		fakeRejector = new(authfakes.FakeRejector)
    40  		fakeAccessor = new(accessorfakes.FakeAccessFactory)
    41  		fakeaccess = new(accessorfakes.FakeAccess)
    42  
    43  		fakeRejector.UnauthorizedStub = func(w http.ResponseWriter, r *http.Request) {
    44  			http.Error(w, "nope", http.StatusUnauthorized)
    45  		}
    46  
    47  		fakeRejector.ForbiddenStub = func(w http.ResponseWriter, r *http.Request) {
    48  			http.Error(w, "still nope", http.StatusForbidden)
    49  		}
    50  
    51  		innerHandler := auth.CheckAdminHandler(
    52  			simpleHandler,
    53  			fakeRejector,
    54  		)
    55  
    56  		server = httptest.NewServer(accessor.NewHandler(
    57  			logger,
    58  			"some-action",
    59  			innerHandler,
    60  			fakeAccessor,
    61  			new(auditorfakes.FakeAuditor),
    62  			map[string]string{},
    63  		))
    64  
    65  		client = &http.Client{
    66  			Transport: &http.Transport{},
    67  		}
    68  	})
    69  
    70  	JustBeforeEach(func() {
    71  		fakeAccessor.CreateReturns(fakeaccess, nil)
    72  	})
    73  
    74  	Context("when a request is made", func() {
    75  		var request *http.Request
    76  		var response *http.Response
    77  
    78  		BeforeEach(func() {
    79  			var err error
    80  
    81  			request, err = http.NewRequest("GET", server.URL, bytes.NewBufferString("hello"))
    82  			Expect(err).NotTo(HaveOccurred())
    83  		})
    84  
    85  		JustBeforeEach(func() {
    86  			var err error
    87  
    88  			response, err = client.Do(request)
    89  			Expect(err).NotTo(HaveOccurred())
    90  		})
    91  
    92  		Context("when the validator returns true", func() {
    93  			BeforeEach(func() {
    94  				fakeaccess.IsAuthenticatedReturns(true)
    95  			})
    96  
    97  			Context("when is admin", func() {
    98  				BeforeEach(func() {
    99  					fakeaccess.IsAdminReturns(true)
   100  				})
   101  
   102  				It("returns 200 OK", func() {
   103  					Expect(response.StatusCode).To(Equal(http.StatusOK))
   104  				})
   105  
   106  				It("proxies to the handler", func() {
   107  					responseBody, err := ioutil.ReadAll(response.Body)
   108  					Expect(err).NotTo(HaveOccurred())
   109  					Expect(string(responseBody)).To(Equal("simple hello"))
   110  				})
   111  			})
   112  
   113  			Context("when is not admin", func() {
   114  				It("returns 403 Forbidden", func() {
   115  					Expect(response.StatusCode).To(Equal(http.StatusForbidden))
   116  				})
   117  			})
   118  		})
   119  
   120  		Context("when the validator returns false", func() {
   121  			BeforeEach(func() {
   122  				fakeaccess.IsAuthenticatedReturns(false)
   123  			})
   124  
   125  			It("rejects the request", func() {
   126  				Expect(response.StatusCode).To(Equal(http.StatusUnauthorized))
   127  				responseBody, err := ioutil.ReadAll(response.Body)
   128  				Expect(err).NotTo(HaveOccurred())
   129  				Expect(string(responseBody)).To(Equal("nope\n"))
   130  			})
   131  		})
   132  	})
   133  })