github.com/shashidharatd/test-infra@v0.0.0-20171006011030-71304e1ca560/gubernator/github_auth_test.py (about)

     1  #!/usr/bin/env python
     2  
     3  # Copyright 2016 The Kubernetes Authors.
     4  #
     5  # Licensed under the Apache License, Version 2.0 (the "License");
     6  # you may not use this file except in compliance with the License.
     7  # You may obtain a copy of the License at
     8  #
     9  #     http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  
    17  import unittest
    18  import urlparse
    19  
    20  import webtest
    21  
    22  import gcs_async_test
    23  import main
    24  
    25  CLIENT_ID = '12345'
    26  CLIENT_SECRET = 'swordfish'
    27  GH_LOGIN_CODE = 'somerandomcode'
    28  
    29  main.app.config['github_client'] = {
    30      'id': CLIENT_ID,
    31      'secret': CLIENT_SECRET,
    32  }
    33  
    34  app = webtest.TestApp(main.app)
    35  
    36  VEND_URL = 'https://github.com/login/oauth/access_token'
    37  USER_URL = 'https://api.github.com/user'
    38  
    39  class TestGithubAuth(unittest.TestCase):
    40      def setUp(self):
    41          app.reset()
    42          self.testbed.init_app_identity_stub()
    43          self.testbed.init_urlfetch_stub()
    44          self.calls = []
    45          self.results = {
    46              VEND_URL: ('{"access_token": "token"}', 200),
    47              USER_URL: ('{"login": "foo"}', 200),
    48          }
    49          gcs_async_test.install_handler_dispatcher(
    50              self.testbed.get_stub('urlfetch'),
    51              (lambda url: url in self.results),
    52              self.dispatcher)
    53  
    54      def dispatcher(self, method, url, payload, headers):
    55          self.calls.append([method, url, payload, headers])
    56          return self.results[url]
    57  
    58      @staticmethod
    59      def do_phase1(arg=''):
    60          return app.get('/github_auth' + arg)
    61  
    62      @staticmethod
    63      def parse_phase1(phase1):
    64          parsed = urlparse.urlparse(phase1.location)
    65          query = urlparse.parse_qs(parsed.query)
    66          state = query.pop('state')[0]
    67          return state, query
    68  
    69      def do_phase2(self, phase1=None, status=None):
    70          if not phase1:
    71              phase1 = self.do_phase1()
    72          state, query = self.parse_phase1(phase1)
    73          code = GH_LOGIN_CODE
    74          return app.get(
    75              query['redirect_uri'][0],
    76              {'code': code, 'state': state},
    77              status=status)
    78  
    79      def test_login_works(self):
    80          "oauth login works"
    81          # 1) Redirect to github
    82          resp = self.do_phase1()
    83          self.assertEqual(resp.status_code, 302)
    84          loc = resp.location
    85          assert loc.startswith('https://github.com/login/oauth/authorize'), loc
    86          state, query = self.parse_phase1(resp)
    87          self.assertEqual(query, {
    88              'redirect_uri': ['http://localhost/github_auth/done'],
    89              'client_id': [CLIENT_ID]})
    90  
    91          # 2) Github redirects back
    92          resp = self.do_phase2(resp)
    93          self.assertIn('Welcome, foo', resp)
    94  
    95          # Test that we received the right calls to our fake API.
    96          self.assertEqual(len(self.calls), 2)
    97  
    98          vend_call = self.calls[0]
    99          user_call = self.calls[1]
   100  
   101          self.assertEqual(vend_call[:2], ['POST', VEND_URL])
   102          self.assertEqual(user_call[:3], ['GET', USER_URL, None])
   103  
   104          self.assertEqual(
   105              urlparse.parse_qs(vend_call[2]),
   106              dict(client_secret=[CLIENT_SECRET], state=[state],
   107                   code=[GH_LOGIN_CODE], client_id=[CLIENT_ID]))
   108          vend_headers = {h.key(): h.value() for h in vend_call[3]}
   109          self.assertEqual(vend_headers, {'Accept': 'application/json'})
   110  
   111      def test_redirect_pr(self):
   112          "login can redirect to another page at the end"
   113          phase1 = self.do_phase1('/pr')
   114          phase2 = self.do_phase2(phase1)
   115          self.assertEqual(phase2.status_code, 302)
   116          self.assertEqual(phase2.location, 'http://localhost/pr')
   117  
   118      def test_redirect_ignored(self):
   119          "login only redirects to whitelisted URLs"
   120          phase1 = self.do_phase1('/bad/redirect')
   121          phase2 = self.do_phase2(phase1)
   122          self.assertEqual(phase2.status_code, 200)
   123  
   124      def test_phase2_missing_cookie(self):
   125          "missing cookie for phase2 fails (CSRF)"
   126          phase1 = self.do_phase1()
   127          app.reset()  # clears cookies
   128          self.do_phase2(phase1, status=400)
   129  
   130      def test_phase2_mismatched_state(self):
   131          "wrong state for phase2 fails (CSRF)"
   132          phase1 = self.do_phase1()
   133          phase1.location = phase1.location.replace('state=', 'state=NOPE')
   134          self.do_phase2(phase1, status=400)
   135  
   136      def test_phase2_vend_failure(self):
   137          "GitHub API error vending tokens raises 500"
   138          self.results[VEND_URL] = ('', 403)
   139          self.do_phase2(status=500)
   140  
   141      def test_phase2_user_failure(self):
   142          "GitHub API error getting user information raises 500"
   143          self.results[USER_URL] = ('', 403)
   144          self.do_phase2(status=500)