k8s.io/client-go@v0.31.1/transport/token_source_test.go (about) 1 /* 2 Copyright 2018 The Kubernetes Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package transport 18 19 import ( 20 "fmt" 21 "net/http" 22 "reflect" 23 "sync" 24 "testing" 25 "time" 26 27 "golang.org/x/oauth2" 28 ) 29 30 type testTokenSource struct { 31 calls int 32 tok *oauth2.Token 33 err error 34 } 35 36 func (ts *testTokenSource) Token() (*oauth2.Token, error) { 37 ts.calls++ 38 return ts.tok, ts.err 39 } 40 41 func TestCachingTokenSource(t *testing.T) { 42 start := time.Now() 43 tokA := &oauth2.Token{ 44 AccessToken: "a", 45 Expiry: start.Add(10 * time.Minute), 46 } 47 tokB := &oauth2.Token{ 48 AccessToken: "b", 49 Expiry: start.Add(20 * time.Minute), 50 } 51 tests := []struct { 52 name string 53 54 tok *oauth2.Token 55 tsTok *oauth2.Token 56 tsErr error 57 wait time.Duration 58 59 wantTok *oauth2.Token 60 wantErr bool 61 wantTSCalls int 62 }{ 63 { 64 name: "valid token returned from cache", 65 tok: tokA, 66 wantTok: tokA, 67 }, 68 { 69 name: "valid token returned from cache 1 minute before scheduled refresh", 70 tok: tokA, 71 wait: 8 * time.Minute, 72 wantTok: tokA, 73 }, 74 { 75 name: "new token created when cache is empty", 76 tsTok: tokA, 77 wantTok: tokA, 78 wantTSCalls: 1, 79 }, 80 { 81 name: "new token created 1 minute after scheduled refresh", 82 tok: tokA, 83 tsTok: tokB, 84 wait: 10 * time.Minute, 85 wantTok: tokB, 86 wantTSCalls: 1, 87 }, 88 { 89 name: "error on create token returns error", 90 tsErr: fmt.Errorf("error"), 91 wantErr: true, 92 wantTSCalls: 1, 93 }, 94 } 95 for _, c := range tests { 96 t.Run(c.name, func(t *testing.T) { 97 tts := &testTokenSource{ 98 tok: c.tsTok, 99 err: c.tsErr, 100 } 101 102 ts := &cachingTokenSource{ 103 base: tts, 104 tok: c.tok, 105 leeway: 1 * time.Minute, 106 now: func() time.Time { return start.Add(c.wait) }, 107 } 108 109 gotTok, gotErr := ts.Token() 110 if got, want := gotTok, c.wantTok; !reflect.DeepEqual(got, want) { 111 t.Errorf("unexpected token:\n\tgot:\t%#v\n\twant:\t%#v", got, want) 112 } 113 if got, want := tts.calls, c.wantTSCalls; got != want { 114 t.Errorf("unexpected number of Token() calls: got %d, want %d", got, want) 115 } 116 if gotErr == nil && c.wantErr { 117 t.Errorf("wanted error but got none") 118 } 119 if gotErr != nil && !c.wantErr { 120 t.Errorf("unexpected error: %v", gotErr) 121 } 122 }) 123 } 124 } 125 126 func TestCachingTokenSourceRace(t *testing.T) { 127 for i := 0; i < 100; i++ { 128 tts := &testTokenSource{ 129 tok: &oauth2.Token{ 130 AccessToken: "a", 131 Expiry: time.Now().Add(1000 * time.Hour), 132 }, 133 } 134 135 ts := &cachingTokenSource{ 136 now: time.Now, 137 base: tts, 138 leeway: 1 * time.Minute, 139 } 140 141 var wg sync.WaitGroup 142 wg.Add(100) 143 errc := make(chan error, 100) 144 145 for i := 0; i < 100; i++ { 146 go func() { 147 defer wg.Done() 148 if _, err := ts.Token(); err != nil { 149 errc <- err 150 } 151 }() 152 } 153 go func() { 154 wg.Wait() 155 close(errc) 156 }() 157 if err, ok := <-errc; ok { 158 t.Fatalf("err: %v", err) 159 } 160 if tts.calls != 1 { 161 t.Errorf("expected one call to Token() but saw: %d", tts.calls) 162 } 163 } 164 } 165 166 func TestTokenSourceTransportRoundTrip(t *testing.T) { 167 goodToken := &oauth2.Token{ 168 AccessToken: "good", 169 Expiry: time.Now().Add(1000 * time.Hour), 170 } 171 badToken := &oauth2.Token{ 172 AccessToken: "bad", 173 Expiry: time.Now().Add(1000 * time.Hour), 174 } 175 tests := []struct { 176 name string 177 header http.Header 178 token *oauth2.Token 179 cachedToken *oauth2.Token 180 wantCalls int 181 wantCaching bool 182 }{ 183 { 184 name: "skip oauth rt if has authorization header", 185 header: map[string][]string{"Authorization": {"Bearer TOKEN"}}, 186 token: goodToken, 187 }, 188 { 189 name: "authorized on newly acquired good token", 190 token: goodToken, 191 wantCalls: 1, 192 wantCaching: true, 193 }, 194 { 195 name: "authorized on cached good token", 196 token: goodToken, 197 cachedToken: goodToken, 198 wantCalls: 0, 199 wantCaching: true, 200 }, 201 { 202 name: "unauthorized on newly acquired bad token", 203 token: badToken, 204 wantCalls: 1, 205 wantCaching: true, 206 }, 207 { 208 name: "unauthorized on cached bad token", 209 token: badToken, 210 cachedToken: badToken, 211 wantCalls: 0, 212 }, 213 } 214 for _, test := range tests { 215 t.Run(test.name, func(t *testing.T) { 216 tts := &testTokenSource{ 217 tok: test.token, 218 } 219 cachedTokenSource := NewCachedTokenSource(tts) 220 cachedTokenSource.tok = test.cachedToken 221 222 rt := ResettableTokenSourceWrapTransport(cachedTokenSource)(&testTransport{}) 223 224 rt.RoundTrip(&http.Request{Header: test.header}) 225 if tts.calls != test.wantCalls { 226 t.Errorf("RoundTrip() called Token() = %d times, want %d", tts.calls, test.wantCalls) 227 } 228 229 if (cachedTokenSource.tok != nil) != test.wantCaching { 230 t.Errorf("Got caching %v, want caching %v", cachedTokenSource != nil, test.wantCaching) 231 } 232 }) 233 } 234 } 235 236 type uncancellableRT struct { 237 rt http.RoundTripper 238 } 239 240 func (urt *uncancellableRT) RoundTrip(req *http.Request) (*http.Response, error) { 241 return urt.rt.RoundTrip(req) 242 } 243 244 func TestTokenSourceTransportCancelRequest(t *testing.T) { 245 tests := []struct { 246 name string 247 header http.Header 248 wrapTransport func(http.RoundTripper) http.RoundTripper 249 expectCancel bool 250 }{ 251 { 252 name: "cancel req with bearer token skips oauth rt", 253 header: map[string][]string{"Authorization": {"Bearer TOKEN"}}, 254 expectCancel: true, 255 }, 256 { 257 name: "can't cancel request with rts that doesn't implent unwrap or cancel", 258 wrapTransport: func(rt http.RoundTripper) http.RoundTripper { 259 return &uncancellableRT{rt: rt} 260 }, 261 expectCancel: false, 262 }, 263 } 264 for _, test := range tests { 265 t.Run(test.name, func(t *testing.T) { 266 baseRecorder := &testTransport{} 267 268 var base http.RoundTripper = baseRecorder 269 if test.wrapTransport != nil { 270 base = test.wrapTransport(base) 271 } 272 273 rt := &tokenSourceTransport{ 274 base: base, 275 ort: &oauth2.Transport{ 276 Base: base, 277 }, 278 } 279 280 rt.CancelRequest(&http.Request{ 281 Header: test.header, 282 }) 283 284 if baseRecorder.canceled != test.expectCancel { 285 t.Errorf("unexpected cancel: got=%v, want=%v", baseRecorder.canceled, test.expectCancel) 286 } 287 }) 288 } 289 } 290 291 type testTransport struct { 292 canceled bool 293 base http.RoundTripper 294 } 295 296 func (rt *testTransport) RoundTrip(req *http.Request) (*http.Response, error) { 297 if req.Header["Authorization"][0] == "Bearer bad" { 298 return &http.Response{StatusCode: 401}, nil 299 } 300 return nil, nil 301 } 302 303 func (rt *testTransport) CancelRequest(req *http.Request) { 304 rt.canceled = true 305 if rt.base != nil { 306 tryCancelRequest(rt.base, req) 307 } 308 }