github.com/hashicorp/cap@v0.6.0/oidc/examples/cli/main.go (about) 1 // Copyright (c) HashiCorp, Inc. 2 // SPDX-License-Identifier: MPL-2.0 3 4 package main 5 6 import ( 7 "context" 8 "crypto/ecdsa" 9 "crypto/elliptic" 10 "crypto/rand" 11 "encoding/json" 12 "flag" 13 "fmt" 14 "net" 15 "net/http" 16 "os" 17 "os/signal" 18 "strconv" 19 "strings" 20 "time" 21 22 "github.com/hashicorp/cap/oidc" 23 "github.com/hashicorp/cap/oidc/callback" 24 "github.com/hashicorp/cap/util" 25 "github.com/hashicorp/go-hclog" 26 "golang.org/x/oauth2" 27 ) 28 29 // List of required configuration environment variables 30 const ( 31 clientID = "OIDC_CLIENT_ID" 32 clientSecret = "OIDC_CLIENT_SECRET" 33 issuer = "OIDC_ISSUER" 34 port = "OIDC_PORT" 35 attemptExp = "attemptExp" 36 ) 37 38 func envConfig(secretNotRequired bool) (map[string]interface{}, error) { 39 const op = "envConfig" 40 env := map[string]interface{}{ 41 clientID: os.Getenv("OIDC_CLIENT_ID"), 42 clientSecret: os.Getenv("OIDC_CLIENT_SECRET"), 43 issuer: os.Getenv("OIDC_ISSUER"), 44 port: os.Getenv("OIDC_PORT"), 45 attemptExp: time.Duration(2 * time.Minute), 46 } 47 for k, v := range env { 48 switch t := v.(type) { 49 case string: 50 switch k { 51 case "OIDC_CLIENT_SECRET": 52 switch { 53 case secretNotRequired: 54 env[k] = "" // unsetting the secret which isn't required 55 case t == "": 56 return nil, fmt.Errorf("%s: %s is empty.\n\n Did you intend to use -pkce or -implicit options?", op, k) 57 } 58 default: 59 if t == "" { 60 return nil, fmt.Errorf("%s: %s is empty", op, k) 61 } 62 } 63 case time.Duration: 64 if t == 0 { 65 return nil, fmt.Errorf("%s: %s is empty", op, k) 66 } 67 default: 68 return nil, fmt.Errorf("%s: %s is an unhandled type %t", op, k, t) 69 } 70 } 71 return env, nil 72 } 73 74 func main() { 75 useImplicit := flag.Bool("implicit", false, "use the implicit flow") 76 implicitAccessToken := flag.Bool("implicit-access-token", false, "include the access_token in the implicit flow") 77 usePKCE := flag.Bool("pkce", false, "use the implicit flow") 78 maxAge := flag.Int("max-age", -1, "max age of user authentication") 79 scopes := flag.String("scopes", "", "comma separated list of additional scopes to requests") 80 useTestProvider := flag.Bool("use-test-provider", false, "use the test oidc provider") 81 82 flag.Parse() 83 if *useImplicit && *usePKCE { 84 fmt.Fprint(os.Stderr, "you can't request both: -implicit and -pkce") 85 return 86 } 87 88 if (*useImplicit || *implicitAccessToken || *scopes != "") && *useTestProvider { 89 fmt.Fprint(os.Stderr, "you can't use the implicit flow, PKCE or scopes with the test provider") 90 return 91 } 92 93 optScopes := strings.Split(*scopes, ",") 94 for i := range optScopes { 95 optScopes[i] = strings.TrimSpace(optScopes[i]) 96 } 97 98 var env map[string]interface{} 99 var tp *oidc.TestProvider 100 if *useTestProvider { 101 l, err := oidc.NewTestingLogger(hclog.New(nil)) 102 if err != nil { 103 fmt.Fprintf(os.Stderr, "%s\n\n", err) 104 return 105 } 106 // Generate a key to sign JWTs with throughout most test cases 107 priv, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) 108 if err != nil { 109 fmt.Fprintf(os.Stderr, "%s\n\n", err) 110 return 111 } 112 oidcPort := os.Getenv("OIDC_PORT") 113 if oidcPort == "" { 114 oidcPort, err = func() (string, error) { 115 addr, err := net.ResolveTCPAddr("tcp", "localhost:0") 116 if err != nil { 117 return "", err 118 } 119 120 l, err := net.ListenTCP("tcp", addr) 121 if err != nil { 122 return "", err 123 } 124 defer l.Close() 125 return strconv.Itoa(l.Addr().(*net.TCPAddr).Port), nil 126 }() 127 if err != nil { 128 fmt.Fprintf(os.Stderr, "env OIDC_PORT is empty and error finding a free port: %s", err.Error()) 129 } 130 return 131 } 132 133 id, secret := "test-rp", "fido" 134 tp = oidc.StartTestProvider(l, oidc.WithNoTLS(), oidc.WithTestDefaults(&oidc.TestProviderDefaults{ 135 CustomClaims: map[string]interface{}{}, 136 SubjectInfo: map[string]*oidc.TestSubject{ 137 "alice": { 138 Password: "fido", 139 UserInfo: map[string]interface{}{ 140 "email": "alice@example.com", 141 "name": "alice smith", 142 "friend": "eve", 143 }, 144 CustomClaims: map[string]interface{}{ 145 "email": "alice@example.com", 146 "name": "alice smith", 147 }, 148 }, 149 "eve": { 150 Password: "alice", 151 UserInfo: map[string]interface{}{ 152 "email": "eve@example.com", 153 "name": "eve smith", 154 "friend": "alice", 155 }, 156 CustomClaims: map[string]interface{}{ 157 "email": "eve@example.com", 158 "name": "eve smith", 159 }, 160 }, 161 }, 162 SigningKey: &oidc.TestSigningKey{ 163 PrivKey: priv, 164 PubKey: priv.Public(), 165 Alg: oidc.ES384, 166 }, 167 AllowedRedirectURIs: []string{fmt.Sprintf("http://localhost:%s/callback", oidcPort)}, 168 ClientID: &id, 169 ClientSecret: &secret, 170 })) 171 defer tp.Stop() 172 env = map[string]interface{}{ 173 clientID: id, 174 clientSecret: secret, 175 issuer: tp.Addr(), 176 port: oidcPort, 177 attemptExp: time.Duration(2 * time.Minute), 178 } 179 } else { 180 var err error 181 env, err = envConfig(*useImplicit || *usePKCE) 182 if err != nil { 183 fmt.Fprintf(os.Stderr, "%s\n\n", err) 184 return 185 } 186 } 187 188 // handle ctrl-c while waiting for the callback 189 sigintCh := make(chan os.Signal, 1) 190 signal.Notify(sigintCh, os.Interrupt) 191 defer signal.Stop(sigintCh) 192 193 ctx, cancel := context.WithCancel(context.Background()) 194 defer cancel() 195 196 issuer := env[issuer].(string) 197 clientID := env[clientID].(string) 198 clientSecret := oidc.ClientSecret(env[clientSecret].(string)) 199 redirectURL := fmt.Sprintf("http://localhost:%s/callback", env[port].(string)) 200 pc, err := oidc.NewConfig(issuer, clientID, clientSecret, []oidc.Alg{oidc.ES384}, []string{redirectURL}) 201 if err != nil { 202 fmt.Fprint(os.Stderr, err.Error()) 203 return 204 } 205 206 p, err := oidc.NewProvider(pc) 207 if err != nil { 208 fmt.Fprint(os.Stderr, err.Error()) 209 return 210 } 211 defer p.Done() 212 213 var requestOptions []oidc.Option 214 switch { 215 case *useImplicit && !*implicitAccessToken: 216 requestOptions = append(requestOptions, oidc.WithImplicitFlow()) 217 case *useImplicit && *implicitAccessToken: 218 requestOptions = append(requestOptions, oidc.WithImplicitFlow(true)) 219 case *usePKCE: 220 v, err := oidc.NewCodeVerifier() 221 if err != nil { 222 fmt.Fprint(os.Stderr, err.Error()) 223 return 224 } 225 requestOptions = append(requestOptions, oidc.WithPKCE(v)) 226 } 227 228 if *maxAge >= 0 { 229 requestOptions = append(requestOptions, oidc.WithMaxAge(uint(*maxAge))) 230 } 231 232 requestOptions = append(requestOptions, oidc.WithScopes(optScopes...)) 233 234 oidcRequest, err := oidc.NewRequest(env[attemptExp].(time.Duration), redirectURL, requestOptions...) 235 if err != nil { 236 fmt.Fprint(os.Stderr, err.Error()) 237 return 238 } 239 240 successFn, successCh := success() 241 errorFn, failedCh := failed() 242 243 var handler http.HandlerFunc 244 if *useImplicit { 245 handler, err = callback.Implicit(ctx, p, &callback.SingleRequestReader{Request: oidcRequest}, successFn, errorFn) 246 if err != nil { 247 fmt.Fprintf(os.Stderr, "error creating callback handler: %s", err) 248 return 249 } 250 } else { 251 handler, err = callback.AuthCode(ctx, p, &callback.SingleRequestReader{Request: oidcRequest}, successFn, errorFn) 252 if err != nil { 253 fmt.Fprintf(os.Stderr, "error creating auth code handler: %s", err) 254 return 255 } 256 } 257 258 authURL, err := p.AuthURL(ctx, oidcRequest) 259 if err != nil { 260 fmt.Fprintf(os.Stderr, "error getting auth url: %s", err) 261 return 262 } 263 264 // Set up callback handler 265 http.HandleFunc("/callback", handler) 266 267 listener, err := net.Listen("tcp", fmt.Sprintf("localhost:%s", env[port])) 268 if err != nil { 269 fmt.Fprint(os.Stderr, err.Error()) 270 return 271 } 272 defer listener.Close() 273 274 // Open the default browser to the callback URL. 275 fmt.Fprintf(os.Stderr, "Complete the login via your OIDC provider. Launching browser to:\n\n %s\n\n\n", authURL) 276 if err := util.OpenURL(authURL); err != nil { 277 fmt.Fprintf(os.Stderr, "Error attempting to automatically open browser: '%s'.\nPlease visit the authorization URL manually.", err) 278 } 279 280 srvCh := make(chan error) 281 // Start local server 282 go func() { 283 err := http.Serve(listener, nil) 284 if err != nil && err != http.ErrServerClosed { 285 srvCh <- err 286 } 287 }() 288 289 // Wait for either the callback to finish, SIGINT to be received or up to 2 minutes 290 select { 291 case err := <-srvCh: 292 fmt.Fprintf(os.Stderr, "server closed with error: %s", err.Error()) 293 return 294 case resp := <-successCh: 295 if resp.Error != nil { 296 fmt.Fprintf(os.Stderr, "channel received success with error: %s", resp.Error) 297 return 298 } 299 printToken(resp.Token) 300 printClaims(resp.Token.IDToken()) 301 printUserInfo(ctx, p, resp.Token) 302 return 303 case err := <-failedCh: 304 if err != nil { 305 fmt.Fprintf(os.Stderr, "channel received error: %s", err) 306 return 307 } 308 fmt.Fprint(os.Stderr, "missing error from error channel. try again?\n") 309 return 310 case <-sigintCh: 311 fmt.Fprintf(os.Stderr, "Interrupted") 312 return 313 case <-time.After(env[attemptExp].(time.Duration)): 314 fmt.Fprintf(os.Stderr, "Timed out waiting for response from provider") 315 return 316 } 317 } 318 319 type successResp struct { 320 Token oidc.Token // Token is populated when the callback successfully exchanges the auth code. 321 Error error // Error is populated when there's an error during the callback 322 } 323 324 func success() (callback.SuccessResponseFunc, <-chan successResp) { 325 const op = "success" 326 doneCh := make(chan successResp) 327 return func(state string, t oidc.Token, w http.ResponseWriter, req *http.Request) { 328 var responseErr error 329 defer func() { 330 doneCh <- successResp{t, responseErr} 331 close(doneCh) 332 }() 333 w.WriteHeader(http.StatusOK) 334 if _, err := w.Write([]byte(successHTML)); err != nil { 335 responseErr = fmt.Errorf("%s: %w", op, err) 336 fmt.Fprintf(os.Stderr, "error writing successful response: %s", err) 337 } 338 }, doneCh 339 } 340 341 func failed() (callback.ErrorResponseFunc, <-chan error) { 342 const op = "failed" 343 doneCh := make(chan error) 344 return func(state string, r *callback.AuthenErrorResponse, e error, w http.ResponseWriter, req *http.Request) { 345 var responseErr error 346 defer func() { 347 if _, err := w.Write([]byte(responseErr.Error())); err != nil { 348 fmt.Fprintf(os.Stderr, "%s: error writing failed response: %s", op, err) 349 } 350 doneCh <- responseErr 351 close(doneCh) 352 }() 353 354 if e != nil { 355 fmt.Fprintf(os.Stderr, "%s: callback error: %s", op, e.Error()) 356 responseErr = e 357 w.WriteHeader(http.StatusInternalServerError) 358 return 359 } 360 if r != nil { 361 responseErr = fmt.Errorf("%s: callback error from oidc provider: %s", op, r) 362 fmt.Fprint(os.Stderr, responseErr.Error()) 363 w.WriteHeader(http.StatusUnauthorized) 364 return 365 } 366 responseErr = fmt.Errorf("%s: unknown error from callback", op) 367 }, doneCh 368 } 369 370 type respToken struct { 371 IDToken string 372 AccessToken string 373 RefreshToken string 374 Expiry time.Time 375 } 376 377 func printClaims(t oidc.IDToken) { 378 const op = "printClaims" 379 var tokenClaims map[string]interface{} 380 if err := t.Claims(&tokenClaims); err != nil { 381 fmt.Fprintf(os.Stderr, "IDToken claims: error parsing: %s", err) 382 } else { 383 if idData, err := json.MarshalIndent(tokenClaims, "", " "); err != nil { 384 fmt.Fprintf(os.Stderr, "%s: %s", op, err) 385 } else { 386 fmt.Fprintf(os.Stderr, "IDToken claims:%s\n", idData) 387 } 388 } 389 } 390 391 func printUserInfo(ctx context.Context, p *oidc.Provider, t oidc.Token) { 392 const op = "printUserInfo" 393 if ts, ok := t.(interface { 394 StaticTokenSource() oauth2.TokenSource 395 }); ok { 396 if ts.StaticTokenSource() == nil { 397 fmt.Fprintf(os.Stderr, "%s: no access_token received, so we're unable to get UserInfo claims", op) 398 return 399 } 400 vc := struct { 401 Sub string 402 }{} 403 if err := t.IDToken().Claims(&vc); err != nil { 404 fmt.Fprintf(os.Stderr, "%s: channel received success, but error getting UserInfo claims: %s", op, err) 405 return 406 } 407 var infoClaims map[string]interface{} 408 err := p.UserInfo(ctx, ts.StaticTokenSource(), vc.Sub, &infoClaims) 409 if err != nil { 410 fmt.Fprintf(os.Stderr, "%s: channel received success, but error getting UserInfo claims: %s", op, err) 411 return 412 } 413 infoData, err := json.MarshalIndent(infoClaims, "", " ") 414 if err != nil { 415 fmt.Fprintf(os.Stderr, "%s: %s", op, err) 416 return 417 } 418 fmt.Fprintf(os.Stderr, "UserInfo claims:%s\n", infoData) 419 return 420 } 421 } 422 423 func printToken(t oidc.Token) { 424 const op = "printToken" 425 tokenData, err := json.MarshalIndent(printableToken(t), "", " ") 426 if err != nil { 427 fmt.Fprintf(os.Stderr, "%s: %s", op, err) 428 return 429 } 430 fmt.Fprintf(os.Stderr, "channel received success.\nToken:%s\n", tokenData) 431 } 432 433 // printableToken is needed because the oidc.Token redacts the IDToken, 434 // AccessToken and RefreshToken 435 func printableToken(t oidc.Token) respToken { 436 return respToken{ 437 IDToken: string(t.IDToken()), 438 AccessToken: string(t.AccessToken()), 439 RefreshToken: string(t.RefreshToken()), 440 Expiry: t.Expiry(), 441 } 442 }