k8s.io/test-infra@v0.0.0-20240520184403-27c6b4c223d8/gubernator/github_auth_test.py (about)

     1  #!/usr/bin/env python
     2  
     3  # Copyright 2016 The Kubernetes Authors.
     4  #
     5  # Licensed under the Apache License, Version 2.0 (the "License");
     6  # you may not use this file except in compliance with the License.
     7  # You may obtain a copy of the License at
     8  #
     9  #     http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  
    17  import unittest
    18  import urlparse
    19  
    20  import webtest
    21  
    22  import gcs_async_test
    23  import main
    24  
    25  CLIENT_ID = '12345'
    26  CLIENT_SECRET = 'swordfish'
    27  GH_LOGIN_CODE = 'somerandomcode'
    28  
    29  main.app.config['github_client'] = {
    30      'id': CLIENT_ID,
    31      'secret': CLIENT_SECRET,
    32  }
    33  main.app.config['webapp2_extras.sessions']['secret_key'] = 'abcd'
    34  
    35  app = webtest.TestApp(main.app)
    36  
    37  VEND_URL = 'https://github.com/login/oauth/access_token'
    38  USER_URL = 'https://api.github.com/user'
    39  
    40  class TestGithubAuth(unittest.TestCase):
    41      def setUp(self):
    42          app.reset()
    43          self.testbed.init_app_identity_stub()
    44          self.testbed.init_urlfetch_stub()
    45          self.calls = []
    46          self.results = {
    47              VEND_URL: ('{"access_token": "token"}', 200),
    48              USER_URL: ('{"login": "foo"}', 200),
    49          }
    50          gcs_async_test.install_handler_dispatcher(
    51              self.testbed.get_stub('urlfetch'),
    52              (lambda url: url in self.results),
    53              self.dispatcher)
    54  
    55      def dispatcher(self, method, url, payload, headers):
    56          self.calls.append([method, url, payload, headers])
    57          return self.results[url]
    58  
    59      @staticmethod
    60      def do_phase1(arg=''):
    61          return app.get('/github_auth' + arg)
    62  
    63      @staticmethod
    64      def parse_phase1(phase1):
    65          parsed = urlparse.urlparse(phase1.location)
    66          query = urlparse.parse_qs(parsed.query)
    67          state = query.pop('state')[0]
    68          return state, query
    69  
    70      def do_phase2(self, phase1=None, status=None):
    71          if not phase1:
    72              phase1 = self.do_phase1()
    73          state, query = self.parse_phase1(phase1)
    74          code = GH_LOGIN_CODE
    75          return app.get(
    76              query['redirect_uri'][0],
    77              {'code': code, 'state': state},
    78              status=status)
    79  
    80      def test_login_works(self):
    81          "oauth login works"
    82          # 1) Redirect to github
    83          resp = self.do_phase1()
    84          self.assertEqual(resp.status_code, 302)
    85          loc = resp.location
    86          assert loc.startswith('https://github.com/login/oauth/authorize'), loc
    87          state, query = self.parse_phase1(resp)
    88          self.assertEqual(query, {
    89              'redirect_uri': ['http://localhost/github_auth/done'],
    90              'client_id': [CLIENT_ID]})
    91  
    92          # 2) Github redirects back
    93          resp = self.do_phase2(resp)
    94          self.assertIn('Welcome, foo', resp)
    95  
    96          # Test that we received the right calls to our fake API.
    97          self.assertEqual(len(self.calls), 2)
    98  
    99          vend_call = self.calls[0]
   100          user_call = self.calls[1]
   101  
   102          self.assertEqual(vend_call[:2], ['POST', VEND_URL])
   103          self.assertEqual(user_call[:3], ['GET', USER_URL, None])
   104  
   105          self.assertEqual(
   106              urlparse.parse_qs(vend_call[2]),
   107              dict(client_secret=[CLIENT_SECRET], state=[state],
   108                   code=[GH_LOGIN_CODE], client_id=[CLIENT_ID]))
   109          vend_headers = {h.key(): h.value() for h in vend_call[3]}
   110          self.assertEqual(vend_headers, {'Accept': 'application/json'})
   111  
   112      def test_redirect_pr(self):
   113          "login can redirect to another page at the end"
   114          phase1 = self.do_phase1('/pr')
   115          phase2 = self.do_phase2(phase1)
   116          self.assertEqual(phase2.status_code, 302)
   117          self.assertEqual(phase2.location, 'http://localhost/pr')
   118  
   119      def test_redirect_ignored(self):
   120          "login only redirects to allowed URLs"
   121          phase1 = self.do_phase1('/bad/redirect')
   122          phase2 = self.do_phase2(phase1)
   123          self.assertEqual(phase2.status_code, 200)
   124  
   125      def test_phase2_missing_cookie(self):
   126          "missing cookie for phase2 fails (CSRF)"
   127          phase1 = self.do_phase1()
   128          app.reset()  # clears cookies
   129          self.do_phase2(phase1, status=400)
   130  
   131      def test_phase2_mismatched_state(self):
   132          "wrong state for phase2 fails (CSRF)"
   133          phase1 = self.do_phase1()
   134          phase1.location = phase1.location.replace('state=', 'state=NOPE')
   135          self.do_phase2(phase1, status=400)
   136  
   137      def test_phase2_vend_failure(self):
   138          "GitHub API error vending tokens raises 500"
   139          self.results[VEND_URL] = ('', 403)
   140          self.do_phase2(status=500)
   141  
   142      def test_phase2_user_failure(self):
   143          "GitHub API error getting user information raises 500"
   144          self.results[USER_URL] = ('', 403)
   145          self.do_phase2(status=500)