#!/usr/bin/env python

# Copyright 2016 The Kubernetes Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the GitHub OAuth login flow.

The flow exercised here has two phases:
  1. ``/github_auth[/<path>]`` answers with a 302 to GitHub's authorize page,
     carrying ``client_id``, ``redirect_uri`` and a CSRF ``state`` token.
  2. The browser comes back to ``redirect_uri`` with ``code`` and ``state``;
     the app then POSTs to VEND_URL to obtain an access token and GETs
     USER_URL to learn the user's login.

Requests to the GitHub endpoints are answered by a fake urlfetch dispatcher,
so the tests need no network access.
"""

import unittest
import urlparse

import webtest

import gcs_async_test
import main

CLIENT_ID = '12345'
CLIENT_SECRET = 'swordfish'
GH_LOGIN_CODE = 'somerandomcode'

# Fake OAuth app credentials. The assertions below expect CLIENT_ID in the
# authorize redirect and both values in the token-vending request.
main.app.config['github_client'] = {
    'id': CLIENT_ID,
    'secret': CLIENT_SECRET,
}

# Shared by every test; setUp() calls app.reset() so cookies from one test
# do not carry over into the next.
app = webtest.TestApp(main.app)

# GitHub endpoints hit during phase 2; answered by TestGithubAuth.dispatcher.
VEND_URL = 'https://github.com/login/oauth/access_token'
USER_URL = 'https://api.github.com/user'


class TestGithubAuth(unittest.TestCase):
    """End-to-end tests of the two-phase GitHub OAuth login."""

    def setUp(self):
        """Reset client state and route GitHub API calls to self.dispatcher.

        NOTE(review): ``self.testbed`` is never assigned in this class;
        presumably the test runner (an App Engine testbed plugin) injects it.
        """
        app.reset()  # clears cookies left over from a previous test
        self.testbed.init_app_identity_stub()
        self.testbed.init_urlfetch_stub()
        # Every intercepted request is recorded as [method, url, payload, headers].
        self.calls = []
        # Canned (body, status) replies per URL. Individual tests overwrite
        # entries to simulate GitHub API failures; the lambda below reads this
        # dict at call time, so such overrides take effect.
        self.results = {
            VEND_URL: ('{"access_token": "token"}', 200),
            USER_URL: ('{"login": "foo"}', 200),
        }
        # The predicate presumably selects which URLs the fake handles:
        # only those with a canned reply.
        gcs_async_test.install_handler_dispatcher(
            self.testbed.get_stub('urlfetch'),
            (lambda url: url in self.results),
            self.dispatcher)

    def dispatcher(self, method, url, payload, headers):
        """Fake urlfetch handler: record the request, return its canned reply.

        Returns:
            The ``(body, status)`` tuple stored in ``self.results[url]``.
        """
        self.calls.append([method, url, payload, headers])
        return self.results[url]

    @staticmethod
    def do_phase1(arg=''):
        """Start the login flow by requesting ``/github_auth`` + ``arg``.

        Args:
            arg: optional path suffix (e.g. ``'/pr'``) naming the page the
                app may redirect to once login completes; see
                test_redirect_pr and test_redirect_ignored.

        Returns:
            The webtest response (expected to be a redirect to GitHub).
        """
        return app.get('/github_auth' + arg)

    @staticmethod
    def parse_phase1(phase1):
        """Split a phase-1 redirect into its CSRF state and other parameters.

        Args:
            phase1: response returned by do_phase1().

        Returns:
            ``(state, query)``: ``state`` is the single ``state`` value, and
            ``query`` is the remaining ``parse_qs`` dict (values are lists).
        """
        parsed = urlparse.urlparse(phase1.location)
        query = urlparse.parse_qs(parsed.query)
        state = query.pop('state')[0]
        return state, query

    def do_phase2(self, phase1=None, status=None):
        """Simulate GitHub redirecting back to the app after authorization.

        Args:
            phase1: response from do_phase1(); a fresh one is requested when
                omitted.
            status: forwarded to ``app.get`` as the expected status code
                (None keeps webtest's default status check).

        Returns:
            The webtest response for the ``redirect_uri`` request.
        """
        # Compare against the sentinel explicitly rather than relying on the
        # truthiness of a response object.
        if phase1 is None:
            phase1 = self.do_phase1()
        state, query = self.parse_phase1(phase1)
        return app.get(
            query['redirect_uri'][0],
            {'code': GH_LOGIN_CODE, 'state': state},
            status=status)

    def test_login_works(self):
        """oauth login works"""
        # 1) Redirect to github
        resp = self.do_phase1()
        self.assertEqual(resp.status_code, 302)
        loc = resp.location
        assert loc.startswith('https://github.com/login/oauth/authorize'), loc
        state, query = self.parse_phase1(resp)
        self.assertEqual(query, {
            'redirect_uri': ['http://localhost/github_auth/done'],
            'client_id': [CLIENT_ID]})

        # 2) Github redirects back
        resp = self.do_phase2(resp)
        self.assertIn('Welcome, foo', resp)

        # Test that we received the right calls to our fake API:
        # first the token vend, then the user lookup.
        self.assertEqual(len(self.calls), 2)

        vend_call = self.calls[0]
        user_call = self.calls[1]

        self.assertEqual(vend_call[:2], ['POST', VEND_URL])
        # The user lookup is a GET, so it carries no payload.
        self.assertEqual(user_call[:3], ['GET', USER_URL, None])

        self.assertEqual(
            urlparse.parse_qs(vend_call[2]),
            dict(client_secret=[CLIENT_SECRET], state=[state],
                 code=[GH_LOGIN_CODE], client_id=[CLIENT_ID]))
        # Headers arrive as objects exposing key()/value() (presumably
        # urlfetch request header entries), so flatten them for comparison.
        vend_headers = {h.key(): h.value() for h in vend_call[3]}
        self.assertEqual(vend_headers, {'Accept': 'application/json'})

    def test_redirect_pr(self):
        """login can redirect to another page at the end"""
        phase1 = self.do_phase1('/pr')
        phase2 = self.do_phase2(phase1)
        self.assertEqual(phase2.status_code, 302)
        self.assertEqual(phase2.location, 'http://localhost/pr')

    def test_redirect_ignored(self):
        """login only redirects to whitelisted URLs"""
        phase1 = self.do_phase1('/bad/redirect')
        phase2 = self.do_phase2(phase1)
        self.assertEqual(phase2.status_code, 200)

    def test_phase2_missing_cookie(self):
        """missing cookie for phase2 fails (CSRF)"""
        phase1 = self.do_phase1()
        app.reset()  # clears cookies
        self.do_phase2(phase1, status=400)

    def test_phase2_mismatched_state(self):
        """wrong state for phase2 fails (CSRF)"""
        phase1 = self.do_phase1()
        # Corrupt the state token so it no longer matches the one issued.
        phase1.location = phase1.location.replace('state=', 'state=NOPE')
        self.do_phase2(phase1, status=400)

    def test_phase2_vend_failure(self):
        """GitHub API error vending tokens raises 500"""
        self.results[VEND_URL] = ('', 403)
        self.do_phase2(status=500)

    def test_phase2_user_failure(self):
        """GitHub API error getting user information raises 500"""
        self.results[USER_URL] = ('', 403)
        self.do_phase2(status=500)