github.com/google/cloudprober@v0.11.3/common/oauth/bearer_test.go (about) 1 // Copyright 2019 The Cloudprober Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package oauth 16 17 import ( 18 "sync" 19 "testing" 20 "time" 21 22 "github.com/golang/protobuf/proto" 23 configpb "github.com/google/cloudprober/common/oauth/proto" 24 ) 25 26 var global struct { 27 callCounter int 28 mu sync.RWMutex 29 } 30 31 func incCallCounter() { 32 global.mu.Lock() 33 defer global.mu.Unlock() 34 global.callCounter++ 35 } 36 37 func callCounter() int { 38 global.mu.RLock() 39 defer global.mu.RUnlock() 40 return global.callCounter 41 } 42 43 func testTokenFromFile(c *configpb.BearerToken) (string, error) { 44 incCallCounter() 45 return c.GetFile() + "_file_token", nil 46 } 47 48 func testTokenFromCmd(c *configpb.BearerToken) (string, error) { 49 incCallCounter() 50 return c.GetCmd() + "_cmd_token", nil 51 } 52 53 func testTokenFromGCEMetadata(c *configpb.BearerToken) (string, error) { 54 incCallCounter() 55 return c.GetGceServiceAccount() + "_gce_token", nil 56 } 57 58 func TestNewBearerToken(t *testing.T) { 59 getTokenFromFile = testTokenFromFile 60 getTokenFromCmd = testTokenFromCmd 61 getTokenFromGCEMetadata = testTokenFromGCEMetadata 62 63 testConfig := "file: \"f\"" 64 verifyBearerTokenSource(t, true, testConfig, "f_file_token") 65 66 // Disable caching by setting refresh_interval_sec to 0. 67 testConfig = "file: \"f\"\nrefresh_interval_sec: 0" 68 verifyBearerTokenSource(t, false, testConfig, "f_file_token") 69 70 testConfig = "cmd: \"c\"" 71 verifyBearerTokenSource(t, true, testConfig, "c_cmd_token") 72 73 testConfig = "gce_service_account: \"default\"" 74 verifyBearerTokenSource(t, true, testConfig, "default_gce_token") 75 } 76 77 func verifyBearerTokenSource(t *testing.T, cacheEnabled bool, testConfig, expectedToken string) { 78 t.Helper() 79 80 testC := &configpb.BearerToken{} 81 err := proto.UnmarshalText(testConfig, testC) 82 if err != nil { 83 t.Fatalf("error parsing test config (%s): %v", testConfig, err) 84 } 85 86 // Call counter should always increase during token source creation. 87 expectedC := callCounter() + 1 88 89 cts, err := newBearerTokenSource(testC, nil) 90 if err != nil { 91 t.Errorf("got unexpected error: %v", err) 92 } 93 94 cc := callCounter() 95 if cc != expectedC { 96 t.Errorf("unexpected call counter: got=%d, expected=%d", cc, expectedC) 97 } 98 99 tok, err := cts.Token() 100 if err != nil { 101 t.Errorf("unexpected error while retrieving token from config (%s): %v", testConfig, err) 102 } 103 104 if tok.AccessToken != expectedToken { 105 t.Errorf("Got token: %s, expected: %s", tok.AccessToken, expectedToken) 106 } 107 108 // Call counter will increase after Token call only if caching is disabled. 109 if !cacheEnabled { 110 expectedC++ 111 } 112 cc = callCounter() 113 114 if cc != expectedC { 115 t.Errorf("unexpected call counter: got=%d, expected=%d", cc, expectedC) 116 } 117 } 118 119 var ( 120 calledTestTokenOnce bool 121 calledTestTokenOnceMu sync.Mutex 122 ) 123 124 func testTokenRefresh(c *configpb.BearerToken) (string, error) { 125 calledTestTokenOnceMu.Lock() 126 defer calledTestTokenOnceMu.Unlock() 127 if calledTestTokenOnce { 128 return "new-token", nil 129 } 130 calledTestTokenOnce = true 131 return "old-token", nil 132 } 133 134 // TestRefreshCycle verifies that token gets refreshed after the refresh 135 // cycle. 136 func TestRefreshCycle(t *testing.T) { 137 getTokenFromCmd = testTokenRefresh 138 // Disable caching by setting refresh_interval_sec to 0. 139 testConfig := "cmd: \"c\"\nrefresh_interval_sec: 1" 140 141 testC := &configpb.BearerToken{} 142 err := proto.UnmarshalText(testConfig, testC) 143 if err != nil { 144 t.Fatalf("error parsing test config (%s): %v", testConfig, err) 145 } 146 147 ts, err := newBearerTokenSource(testC, nil) 148 if err != nil { 149 t.Errorf("got unexpected error: %v", err) 150 } 151 152 tok, err := ts.Token() 153 if err != nil { 154 t.Errorf("unexpected error while retrieving token from config (%s): %v", testConfig, err) 155 } 156 157 oldToken := "old-token" 158 newToken := "new-token" 159 160 if tok.AccessToken != oldToken { 161 t.Errorf("ts.Token(): got=%s, expected=%s", tok, oldToken) 162 } 163 164 time.Sleep(5 * time.Second) 165 166 tok, err = ts.Token() 167 if err != nil { 168 t.Errorf("unexpected error while retrieving token from config (%s): %v", testConfig, err) 169 } 170 171 if tok.AccessToken != newToken { 172 t.Errorf("ts.Token(): got=%s, expected=%s", tok, newToken) 173 } 174 }