golang.org/x/oauth2@v0.18.0/google/internal/externalaccountauthorizeduser/externalaccountauthorizeduser_test.go (about) 1 // Copyright 2023 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package externalaccountauthorizeduser 6 7 import ( 8 "context" 9 "encoding/json" 10 "errors" 11 "io/ioutil" 12 "net/http" 13 "net/http/httptest" 14 "testing" 15 "time" 16 17 "golang.org/x/oauth2" 18 "golang.org/x/oauth2/google/internal/stsexchange" 19 ) 20 21 const expiryDelta = 10 * time.Second 22 23 var ( 24 expiry = time.Unix(234852, 0) 25 testNow = func() time.Time { return expiry } 26 testValid = func(t oauth2.Token) bool { 27 return t.AccessToken != "" && !t.Expiry.Round(0).Add(-expiryDelta).Before(testNow()) 28 } 29 ) 30 31 type testRefreshTokenServer struct { 32 URL string 33 Authorization string 34 ContentType string 35 Body string 36 ResponsePayload *stsexchange.Response 37 Response string 38 server *httptest.Server 39 } 40 41 func TestExernalAccountAuthorizedUser_JustToken(t *testing.T) { 42 config := &Config{ 43 Token: "AAAAAAA", 44 Expiry: now().Add(time.Hour), 45 } 46 ts, err := config.TokenSource(context.Background()) 47 if err != nil { 48 t.Fatalf("Error getting token source: %v", err) 49 } 50 51 token, err := ts.Token() 52 if err != nil { 53 t.Fatalf("Error retrieving Token: %v", err) 54 } 55 if got, want := token.AccessToken, "AAAAAAA"; got != want { 56 t.Fatalf("Unexpected access token, got %v, want %v", got, want) 57 } 58 } 59 60 func TestExernalAccountAuthorizedUser_TokenRefreshWithRefreshTokenInRespondse(t *testing.T) { 61 server := &testRefreshTokenServer{ 62 URL: "/", 63 Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=", 64 ContentType: "application/x-www-form-urlencoded", 65 Body: "grant_type=refresh_token&refresh_token=BBBBBBBBB", 66 ResponsePayload: &stsexchange.Response{ 67 ExpiresIn: 3600, 68 AccessToken: "AAAAAAA", 69 RefreshToken: "CCCCCCC", 70 }, 71 } 72 73 url, err := server.run(t) 74 if err != nil { 75 t.Fatalf("Error starting server") 76 } 77 defer server.close(t) 78 79 config := &Config{ 80 RefreshToken: "BBBBBBBBB", 81 TokenURL: url, 82 ClientID: "CLIENT_ID", 83 ClientSecret: "CLIENT_SECRET", 84 } 85 ts, err := config.TokenSource(context.Background()) 86 if err != nil { 87 t.Fatalf("Error getting token source: %v", err) 88 } 89 90 token, err := ts.Token() 91 if err != nil { 92 t.Fatalf("Error retrieving Token: %v", err) 93 } 94 if got, want := token.AccessToken, "AAAAAAA"; got != want { 95 t.Fatalf("Unexpected access token, got %v, want %v", got, want) 96 } 97 if config.RefreshToken != "CCCCCCC" { 98 t.Fatalf("Refresh token not updated") 99 } 100 } 101 102 func TestExernalAccountAuthorizedUser_MinimumFieldsRequiredForRefresh(t *testing.T) { 103 server := &testRefreshTokenServer{ 104 URL: "/", 105 Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=", 106 ContentType: "application/x-www-form-urlencoded", 107 Body: "grant_type=refresh_token&refresh_token=BBBBBBBBB", 108 ResponsePayload: &stsexchange.Response{ 109 ExpiresIn: 3600, 110 AccessToken: "AAAAAAA", 111 }, 112 } 113 114 url, err := server.run(t) 115 if err != nil { 116 t.Fatalf("Error starting server") 117 } 118 defer server.close(t) 119 120 config := &Config{ 121 RefreshToken: "BBBBBBBBB", 122 TokenURL: url, 123 ClientID: "CLIENT_ID", 124 ClientSecret: "CLIENT_SECRET", 125 } 126 ts, err := config.TokenSource(context.Background()) 127 if err != nil { 128 t.Fatalf("Error getting token source: %v", err) 129 } 130 131 token, err := ts.Token() 132 if err != nil { 133 t.Fatalf("Error retrieving Token: %v", err) 134 } 135 if got, want := token.AccessToken, "AAAAAAA"; got != want { 136 t.Fatalf("Unexpected access token, got %v, want %v", got, want) 137 } 138 } 139 140 func TestExternalAccountAuthorizedUser_MissingRefreshFields(t *testing.T) { 141 server := &testRefreshTokenServer{ 142 URL: "/", 143 Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=", 144 ContentType: "application/x-www-form-urlencoded", 145 Body: "grant_type=refresh_token&refresh_token=BBBBBBBBB", 146 ResponsePayload: &stsexchange.Response{ 147 ExpiresIn: 3600, 148 AccessToken: "AAAAAAA", 149 }, 150 } 151 152 url, err := server.run(t) 153 if err != nil { 154 t.Fatalf("Error starting server") 155 } 156 defer server.close(t) 157 testCases := []struct { 158 name string 159 config Config 160 }{ 161 { 162 name: "empty config", 163 config: Config{}, 164 }, 165 { 166 name: "missing refresh token", 167 config: Config{ 168 TokenURL: url, 169 ClientID: "CLIENT_ID", 170 ClientSecret: "CLIENT_SECRET", 171 }, 172 }, 173 { 174 name: "missing token url", 175 config: Config{ 176 RefreshToken: "BBBBBBBBB", 177 ClientID: "CLIENT_ID", 178 ClientSecret: "CLIENT_SECRET", 179 }, 180 }, 181 { 182 name: "missing client id", 183 config: Config{ 184 RefreshToken: "BBBBBBBBB", 185 TokenURL: url, 186 ClientSecret: "CLIENT_SECRET", 187 }, 188 }, 189 { 190 name: "missing client secrect", 191 config: Config{ 192 RefreshToken: "BBBBBBBBB", 193 TokenURL: url, 194 ClientID: "CLIENT_ID", 195 }, 196 }, 197 } 198 for _, tc := range testCases { 199 t.Run(tc.name, func(t *testing.T) { 200 201 expectErrMsg := "oauth2/google: Token should be created with fields to make it valid (`token` and `expiry`), or fields to allow it to refresh (`refresh_token`, `token_url`, `client_id`, `client_secret`)." 202 _, err := tc.config.TokenSource((context.Background())) 203 if err == nil { 204 t.Fatalf("Expected error, but received none") 205 } 206 if got := err.Error(); got != expectErrMsg { 207 t.Fatalf("Unexpected error, got %v, want %v", got, expectErrMsg) 208 } 209 }) 210 } 211 } 212 213 func (trts *testRefreshTokenServer) run(t *testing.T) (string, error) { 214 t.Helper() 215 if trts.server != nil { 216 return "", errors.New("Server is already running") 217 } 218 trts.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 219 if got, want := r.URL.String(), trts.URL; got != want { 220 t.Errorf("URL.String(): got %v but want %v", got, want) 221 } 222 headerAuth := r.Header.Get("Authorization") 223 if got, want := headerAuth, trts.Authorization; got != want { 224 t.Errorf("got %v but want %v", got, want) 225 } 226 headerContentType := r.Header.Get("Content-Type") 227 if got, want := headerContentType, trts.ContentType; got != want { 228 t.Errorf("got %v but want %v", got, want) 229 } 230 body, err := ioutil.ReadAll(r.Body) 231 if err != nil { 232 t.Fatalf("Failed reading request body: %s.", err) 233 } 234 if got, want := string(body), trts.Body; got != want { 235 t.Errorf("Unexpected exchange payload: got %v but want %v", got, want) 236 } 237 w.Header().Set("Content-Type", "application/json") 238 if trts.ResponsePayload != nil { 239 content, err := json.Marshal(trts.ResponsePayload) 240 if err != nil { 241 t.Fatalf("unable to marshall response JSON") 242 } 243 w.Write(content) 244 } else { 245 w.Write([]byte(trts.Response)) 246 } 247 })) 248 return trts.server.URL, nil 249 } 250 251 func (trts *testRefreshTokenServer) close(t *testing.T) error { 252 t.Helper() 253 if trts.server == nil { 254 return errors.New("No server is running") 255 } 256 trts.server.Close() 257 trts.server = nil 258 return nil 259 }