github.com/pf-qiu/concourse/v6@v6.7.3-0.20201207032516-1f455d73275f/atc/api/auth/check_authorization_handler_test.go (about)

     1  package auth_test
     2  
     3  import (
     4  	"bytes"
     5  	"io"
     6  	"io/ioutil"
     7  	"net/http"
     8  	"net/http/httptest"
     9  	"net/url"
    10  
    11  	"github.com/pf-qiu/concourse/v6/atc/api/accessor"
    12  	"github.com/pf-qiu/concourse/v6/atc/api/accessor/accessorfakes"
    13  	"github.com/pf-qiu/concourse/v6/atc/api/auth"
    14  	"github.com/pf-qiu/concourse/v6/atc/api/auth/authfakes"
    15  	"github.com/pf-qiu/concourse/v6/atc/auditor/auditorfakes"
    16  
    17  	. "github.com/onsi/ginkgo"
    18  	. "github.com/onsi/gomega"
    19  )
    20  
    21  var _ = Describe("CheckAuthorizationHandler", func() {
    22  	var (
    23  		fakeAccessor *accessorfakes.FakeAccessFactory
    24  		fakeaccess   *accessorfakes.FakeAccess
    25  		fakeRejector *authfakes.FakeRejector
    26  
    27  		server *httptest.Server
    28  		client *http.Client
    29  	)
    30  
    31  	simpleHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    32  		buffer := bytes.NewBufferString("simple ")
    33  
    34  		io.Copy(w, buffer)
    35  		io.Copy(w, r.Body)
    36  	})
    37  
    38  	BeforeEach(func() {
    39  		fakeAccessor = new(accessorfakes.FakeAccessFactory)
    40  		fakeaccess = new(accessorfakes.FakeAccess)
    41  		fakeRejector = new(authfakes.FakeRejector)
    42  
    43  		fakeRejector.UnauthorizedStub = func(w http.ResponseWriter, r *http.Request) {
    44  			http.Error(w, "nope", http.StatusUnauthorized)
    45  		}
    46  
    47  		fakeRejector.ForbiddenStub = func(w http.ResponseWriter, r *http.Request) {
    48  			http.Error(w, "nope", http.StatusForbidden)
    49  		}
    50  
    51  		innerHandler := auth.CheckAuthorizationHandler(
    52  			simpleHandler,
    53  			fakeRejector,
    54  		)
    55  
    56  		server = httptest.NewServer(accessor.NewHandler(
    57  			logger,
    58  			"some-action",
    59  			innerHandler,
    60  			fakeAccessor,
    61  			new(auditorfakes.FakeAuditor),
    62  			map[string]string{},
    63  		))
    64  
    65  		client = &http.Client{
    66  			Transport: &http.Transport{},
    67  		}
    68  	})
    69  
    70  	JustBeforeEach(func() {
    71  		fakeAccessor.CreateReturns(fakeaccess, nil)
    72  	})
    73  
    74  	Context("when a request is made", func() {
    75  		var request *http.Request
    76  		var response *http.Response
    77  
    78  		BeforeEach(func() {
    79  			var err error
    80  			request, err = http.NewRequest("GET", server.URL+"/teams/some-team/pipelines", bytes.NewBufferString("hello"))
    81  			Expect(err).NotTo(HaveOccurred())
    82  			urlValues := url.Values{":team_name": []string{"some-team"}}
    83  			request.URL.RawQuery = urlValues.Encode()
    84  		})
    85  
    86  		JustBeforeEach(func() {
    87  			var err error
    88  
    89  			response, err = client.Do(request)
    90  			Expect(err).NotTo(HaveOccurred())
    91  		})
    92  
    93  		Context("when the request is authenticated", func() {
    94  			BeforeEach(func() {
    95  				fakeaccess.IsAuthenticatedReturns(true)
    96  			})
    97  
    98  			Context("when the bearer token's team matches the request's team", func() {
    99  				BeforeEach(func() {
   100  					fakeaccess.IsAuthorizedReturns(true)
   101  				})
   102  
   103  				It("returns 200", func() {
   104  					Expect(response.StatusCode).To(Equal(http.StatusOK))
   105  				})
   106  
   107  				It("proxies to the handler", func() {
   108  					responseBody, err := ioutil.ReadAll(response.Body)
   109  					Expect(err).NotTo(HaveOccurred())
   110  					Expect(string(responseBody)).To(Equal("simple hello"))
   111  				})
   112  			})
   113  
   114  			Context("when the bearer token's team is set to something other than the request's team", func() {
   115  				BeforeEach(func() {
   116  					fakeaccess.IsAuthorizedReturns(false)
   117  				})
   118  
   119  				It("returns 403", func() {
   120  					Expect(response.StatusCode).To(Equal(http.StatusForbidden))
   121  					responseBody, err := ioutil.ReadAll(response.Body)
   122  					Expect(err).NotTo(HaveOccurred())
   123  					Expect(string(responseBody)).To(Equal("nope\n"))
   124  				})
   125  			})
   126  		})
   127  
   128  		Context("when the request is not authenticated", func() {
   129  			BeforeEach(func() {
   130  				fakeaccess.IsAuthenticatedReturns(false)
   131  			})
   132  
   133  			It("returns 401", func() {
   134  				Expect(response.StatusCode).To(Equal(http.StatusUnauthorized))
   135  				responseBody, err := ioutil.ReadAll(response.Body)
   136  				Expect(err).NotTo(HaveOccurred())
   137  				Expect(string(responseBody)).To(Equal("nope\n"))
   138  			})
   139  		})
   140  	})
   141  })