github.com/pf-qiu/concourse/v6@v6.7.3-0.20201207032516-1f455d73275f/atc/api/policychecker/checker_test.go (about)

     1  package policychecker_test
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"io/ioutil"
     7  	"net/http"
     8  	"net/http/httptest"
     9  
    10  	"github.com/pf-qiu/concourse/v6/atc/api/accessor"
    11  	"github.com/pf-qiu/concourse/v6/atc/api/accessor/accessorfakes"
    12  	"github.com/pf-qiu/concourse/v6/atc/api/policychecker"
    13  	"github.com/pf-qiu/concourse/v6/atc/policy"
    14  	"github.com/pf-qiu/concourse/v6/atc/policy/policyfakes"
    15  
    16  	. "github.com/onsi/ginkgo"
    17  	. "github.com/onsi/gomega"
    18  )
    19  
    20  var _ = Describe("PolicyChecker", func() {
    21  	var (
    22  		policyFilter policy.Filter
    23  		fakeAccess   *accessorfakes.FakeAccess
    24  		fakeRequest  *http.Request
    25  		result       policy.PolicyCheckOutput
    26  		checkErr     error
    27  	)
    28  
    29  	BeforeEach(func() {
    30  		fakeAccess = new(accessorfakes.FakeAccess)
    31  		fakePolicyAgent = new(policyfakes.FakeAgent)
    32  		fakePolicyAgentFactory.NewAgentReturns(fakePolicyAgent, nil)
    33  
    34  		policyFilter = policy.Filter{
    35  			ActionsToSkip: []string{},
    36  			Actions:       []string{},
    37  			HttpMethods:   []string{},
    38  		}
    39  	})
    40  
    41  	JustBeforeEach(func() {
    42  		policyCheck, err := policy.Initialize(testLogger, "some-cluster", "some-version", policyFilter)
    43  		Expect(err).ToNot(HaveOccurred())
    44  		Expect(policyCheck).ToNot(BeNil())
    45  		result, checkErr = policychecker.NewApiPolicyChecker(policyCheck).Check("some-action", fakeAccess, fakeRequest)
    46  	})
    47  
    48  	Context("when system action", func() {
    49  		BeforeEach(func() {
    50  			fakeAccess.IsSystemReturns(true)
    51  		})
    52  		It("should pass", func() {
    53  			Expect(checkErr).ToNot(HaveOccurred())
    54  			Expect(result.Allowed).To(BeTrue())
    55  		})
    56  		It("Agent should not be called", func() {
    57  			Expect(fakePolicyAgent.CheckCallCount()).To(Equal(0))
    58  		})
    59  	})
    60  
    61  	Context("when not system action", func() {
    62  		BeforeEach(func() {
    63  			fakeAccess.IsSystemReturns(false)
    64  		})
    65  
    66  		Context("when the action should be skipped", func() {
    67  			BeforeEach(func() {
    68  				policyFilter.ActionsToSkip = []string{"some-action"}
    69  			})
    70  			It("should pass", func() {
    71  				Expect(checkErr).ToNot(HaveOccurred())
    72  				Expect(result.Allowed).To(BeTrue())
    73  			})
    74  			It("Agent should not be called", func() {
    75  				Expect(fakePolicyAgent.CheckCallCount()).To(Equal(0))
    76  			})
    77  		})
    78  
    79  		Context("when the http method no need to check", func() {
    80  			BeforeEach(func() {
    81  				fakeRequest = httptest.NewRequest("GET", "/something", nil)
    82  				policyFilter.HttpMethods = []string{"PUT"}
    83  			})
    84  			It("should pass", func() {
    85  				Expect(checkErr).ToNot(HaveOccurred())
    86  				Expect(result.Allowed).To(BeTrue())
    87  			})
    88  			It("Agent should not be called", func() {
    89  				Expect(fakePolicyAgent.CheckCallCount()).To(Equal(0))
    90  			})
    91  		})
    92  
    93  		Context("when not in action list", func() {
    94  			BeforeEach(func() {
    95  				fakeRequest = httptest.NewRequest("PUT", "/something", nil)
    96  				policyFilter.HttpMethods = []string{}
    97  				policyFilter.Actions = []string{}
    98  			})
    99  			It("should pass", func() {
   100  				Expect(checkErr).ToNot(HaveOccurred())
   101  				Expect(result.Allowed).To(BeTrue())
   102  			})
   103  			It("Agent should not be called", func() {
   104  				Expect(fakePolicyAgent.CheckCallCount()).To(Equal(0))
   105  			})
   106  		})
   107  
   108  		Context("when the http method needs to check", func() {
   109  			BeforeEach(func() {
   110  				fakeRequest = httptest.NewRequest("PUT", "/something", nil)
   111  				policyFilter.HttpMethods = []string{"PUT"}
   112  			})
   113  
   114  			Context("when request body is a bad json", func() {
   115  				BeforeEach(func() {
   116  					body := bytes.NewBuffer([]byte("hello"))
   117  					fakeRequest = httptest.NewRequest("PUT", "/something", body)
   118  					fakeRequest.Header.Add("Content-type", "application/json")
   119  				})
   120  
   121  				It("should error", func() {
   122  					Expect(checkErr).To(HaveOccurred())
   123  					Expect(checkErr.Error()).To(Equal(`invalid character 'h' looking for beginning of value`))
   124  					Expect(result.Allowed).To(BeFalse())
   125  				})
   126  				It("Agent should not be called", func() {
   127  					Expect(fakePolicyAgent.CheckCallCount()).To(Equal(0))
   128  				})
   129  			})
   130  
   131  			Context("when request body is a bad yaml", func() {
   132  				BeforeEach(func() {
   133  					body := bytes.NewBuffer([]byte("a:\nb"))
   134  					fakeRequest = httptest.NewRequest("PUT", "/something", body)
   135  					fakeRequest.Header.Add("Content-type", "application/x-yaml")
   136  				})
   137  
   138  				It("should error", func() {
   139  					Expect(checkErr).To(HaveOccurred())
   140  					Expect(checkErr.Error()).To(Equal(`error converting YAML to JSON: yaml: line 3: could not find expected ':'`))
   141  					Expect(result.Allowed).To(BeFalse())
   142  				})
   143  				It("Agent should not be called", func() {
   144  					Expect(fakePolicyAgent.CheckCallCount()).To(Equal(0))
   145  				})
   146  			})
   147  
   148  			Context("when every is ok", func() {
   149  				BeforeEach(func() {
   150  					fakeAccess.TeamRolesReturns(map[string][]string{
   151  						"some-team": []string{"some-role"},
   152  					})
   153  					fakeAccess.ClaimsReturns(accessor.Claims{UserName: "some-user"})
   154  					body := bytes.NewBuffer([]byte("a: b"))
   155  					fakeRequest = httptest.NewRequest("PUT", "/something?:team_name=some-team&:pipeline_name=some-pipeline", body)
   156  					fakeRequest.Header.Add("Content-type", "application/x-yaml")
   157  					fakeRequest.ParseForm()
   158  				})
   159  
   160  				It("should not error", func() {
   161  					Expect(checkErr).ToNot(HaveOccurred())
   162  				})
   163  				It("Agent should be called", func() {
   164  					Expect(fakePolicyAgent.CheckCallCount()).To(Equal(1))
   165  				})
   166  				It("Agent should take correct input", func() {
   167  					Expect(fakePolicyAgent.CheckArgsForCall(0)).To(Equal(policy.PolicyCheckInput{
   168  						Service:        "concourse",
   169  						ClusterName:    "some-cluster",
   170  						ClusterVersion: "some-version",
   171  						HttpMethod:     "PUT",
   172  						Action:         "some-action",
   173  						User:           "some-user",
   174  						Team:           "some-team",
   175  						Roles:          []string{"some-role"},
   176  						Pipeline:       "some-pipeline",
   177  						Data:           map[string]interface{}{"a": "b"},
   178  					}))
   179  				})
   180  
   181  				It("request body should still be readable", func() {
   182  					body, err := ioutil.ReadAll(fakeRequest.Body)
   183  					Expect(err).ToNot(HaveOccurred())
   184  					Expect(body).To(Equal([]byte("a: b")))
   185  				})
   186  
   187  				Context("when Agent says pass", func() {
   188  					BeforeEach(func() {
   189  						fakePolicyAgent.CheckReturns(policy.PassedPolicyCheck(), nil)
   190  					})
   191  
   192  					It("it should pass", func() {
   193  						Expect(checkErr).ToNot(HaveOccurred())
   194  						Expect(result.Allowed).To(BeTrue())
   195  					})
   196  				})
   197  
   198  				Context("when Agent says not-pass", func() {
   199  					BeforeEach(func() {
   200  						fakePolicyAgent.CheckReturns(policy.PolicyCheckOutput{
   201  							Allowed: false,
   202  							Reasons: []string{"a policy says you can't do that"},
   203  						}, nil)
   204  					})
   205  
   206  					It("should not pass", func() {
   207  						Expect(checkErr).ToNot(HaveOccurred())
   208  						Expect(result.Allowed).To(BeFalse())
   209  						Expect(result.Reasons).To(ConsistOf("a policy says you can't do that"))
   210  					})
   211  				})
   212  
   213  				Context("when Agent says error", func() {
   214  					BeforeEach(func() {
   215  						fakePolicyAgent.CheckReturns(policy.FailedPolicyCheck(), errors.New("some-error"))
   216  					})
   217  
   218  					It("should not pass", func() {
   219  						Expect(checkErr).To(HaveOccurred())
   220  						Expect(checkErr.Error()).To(Equal("some-error"))
   221  						Expect(result.Allowed).To(BeFalse())
   222  					})
   223  				})
   224  			})
   225  		})
   226  	})
   227  })