#!/usr/bin/env python

# Copyright 2016 The Kubernetes Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the GitHub OAuth login flow served by ``main.app``.

Upstream path: k8s.io/test-infra, gubernator/github_auth_test.py.
"""

import unittest
import urlparse

import webtest

import gcs_async_test
import main

CLIENT_ID = '12345'
CLIENT_SECRET = 'swordfish'
GH_LOGIN_CODE = 'somerandomcode'

main.app.config['github_client'] = {
    'id': CLIENT_ID,
    'secret': CLIENT_SECRET,
}
# Fixed session key for tests; presumably used to sign the session cookie
# that phase 2 relies on (see test_phase2_missing_cookie) -- confirm.
main.app.config['webapp2_extras.sessions']['secret_key'] = 'abcd'

app = webtest.TestApp(main.app)

# GitHub endpoints the login flow calls; in these tests they are answered by
# TestGithubAuth.dispatcher with canned replies instead of real requests.
VEND_URL = 'https://github.com/login/oauth/access_token'
USER_URL = 'https://api.github.com/user'


class TestGithubAuth(unittest.TestCase):
    """Exercise the two-phase OAuth login flow against faked GitHub APIs.

    Phase 1 is a GET of /github_auth, which redirects to GitHub's authorize
    page. Phase 2 simulates GitHub redirecting back to the app's
    ``redirect_uri`` with a login code and the ``state`` value from phase 1.
    """

    def setUp(self):
        app.reset()
        # NOTE(review): self.testbed is not created in this file; it is
        # expected to be provided by the test harness -- confirm.
        self.testbed.init_app_identity_stub()
        self.testbed.init_urlfetch_stub()
        # Every faked GitHub request, recorded as [method, url, payload,
        # headers].
        self.calls = []
        # Canned (body, status) replies keyed by URL; individual tests
        # overwrite entries to simulate GitHub API failures.
        self.results = {
            VEND_URL: ('{"access_token": "token"}', 200),
            USER_URL: ('{"login": "foo"}', 200),
        }
        # Send urlfetch requests whose URL has a canned reply to dispatcher.
        gcs_async_test.install_handler_dispatcher(
            self.testbed.get_stub('urlfetch'),
            (lambda url: url in self.results),
            self.dispatcher)

    def dispatcher(self, method, url, payload, headers):
        """Record a faked GitHub request and return its canned reply."""
        self.calls.append([method, url, payload, headers])
        return self.results[url]

    @staticmethod
    def do_phase1(arg=''):
        """Start the login flow; ``arg`` is appended to the /github_auth path."""
        return app.get('/github_auth' + arg)

    @staticmethod
    def parse_phase1(phase1):
        """Split the phase-1 redirect URL into ``(state, query)``.

        ``state`` is the single CSRF state string. ``query`` holds the
        remaining parameters as lists, as returned by urlparse.parse_qs.
        """
        parsed = urlparse.urlparse(phase1.location)
        query = urlparse.parse_qs(parsed.query)
        state = query.pop('state')[0]
        return state, query

    def do_phase2(self, phase1=None, status=None):
        """Simulate GitHub redirecting back after the user authorizes.

        Args:
            phase1: response from do_phase1; a fresh one is made if omitted.
            status: expected HTTP status, forwarded to ``app.get``.

        Returns:
            The webtest response for the redirect_uri request.
        """
        if phase1 is None:
            phase1 = self.do_phase1()
        state, query = self.parse_phase1(phase1)
        return app.get(
            query['redirect_uri'][0],
            {'code': GH_LOGIN_CODE, 'state': state},
            status=status)

    def test_login_works(self):
        "oauth login works"
        # 1) Redirect to github
        resp = self.do_phase1()
        self.assertEqual(resp.status_code, 302)
        loc = resp.location
        # Use an assertion method, not a bare assert: bare asserts are
        # stripped when Python runs with -O, which would silently skip this.
        self.assertTrue(
            loc.startswith('https://github.com/login/oauth/authorize'), loc)
        state, query = self.parse_phase1(resp)
        self.assertEqual(query, {
            'redirect_uri': ['http://localhost/github_auth/done'],
            'client_id': [CLIENT_ID]})

        # 2) Github redirects back
        resp = self.do_phase2(resp)
        self.assertIn('Welcome, foo', resp)

        # Test that we received the right calls to our fake API.
        self.assertEqual(len(self.calls), 2)

        vend_call = self.calls[0]
        user_call = self.calls[1]

        self.assertEqual(vend_call[:2], ['POST', VEND_URL])
        self.assertEqual(user_call[:3], ['GET', USER_URL, None])

        self.assertEqual(
            urlparse.parse_qs(vend_call[2]),
            dict(client_secret=[CLIENT_SECRET], state=[state],
                 code=[GH_LOGIN_CODE], client_id=[CLIENT_ID]))
        # Header entries expose key()/value() accessors (as used here).
        vend_headers = {h.key(): h.value() for h in vend_call[3]}
        self.assertEqual(vend_headers, {'Accept': 'application/json'})

    def test_redirect_pr(self):
        "login can redirect to another page at the end"
        phase1 = self.do_phase1('/pr')
        phase2 = self.do_phase2(phase1)
        self.assertEqual(phase2.status_code, 302)
        self.assertEqual(phase2.location, 'http://localhost/pr')

    def test_redirect_ignored(self):
        "login only redirects to allowed URLs"
        phase1 = self.do_phase1('/bad/redirect')
        phase2 = self.do_phase2(phase1)
        self.assertEqual(phase2.status_code, 200)

    def test_phase2_missing_cookie(self):
        "missing cookie for phase2 fails (CSRF)"
        phase1 = self.do_phase1()
        app.reset()  # clears cookies
        self.do_phase2(phase1, status=400)

    def test_phase2_mismatched_state(self):
        "wrong state for phase2 fails (CSRF)"
        phase1 = self.do_phase1()
        # Prefix the real state value so it no longer matches.
        phase1.location = phase1.location.replace('state=', 'state=NOPE')
        self.do_phase2(phase1, status=400)

    def test_phase2_vend_failure(self):
        "GitHub API error vending tokens raises 500"
        self.results[VEND_URL] = ('', 403)
        self.do_phase2(status=500)

    def test_phase2_user_failure(self):
        "GitHub API error getting user information raises 500"
        self.results[USER_URL] = ('', 403)
        self.do_phase2(status=500)