github.com/hashicorp/cap@v0.6.0/oidc/examples/cli/main.go (about)

     1  // Copyright (c) HashiCorp, Inc.
     2  // SPDX-License-Identifier: MPL-2.0
     3  
     4  package main
     5  
     6  import (
     7  	"context"
     8  	"crypto/ecdsa"
     9  	"crypto/elliptic"
    10  	"crypto/rand"
    11  	"encoding/json"
    12  	"flag"
    13  	"fmt"
    14  	"net"
    15  	"net/http"
    16  	"os"
    17  	"os/signal"
    18  	"strconv"
    19  	"strings"
    20  	"time"
    21  
    22  	"github.com/hashicorp/cap/oidc"
    23  	"github.com/hashicorp/cap/oidc/callback"
    24  	"github.com/hashicorp/cap/util"
    25  	"github.com/hashicorp/go-hclog"
    26  	"golang.org/x/oauth2"
    27  )
    28  
    29  // List of required configuration environment variables
    30  const (
    31  	clientID     = "OIDC_CLIENT_ID"
    32  	clientSecret = "OIDC_CLIENT_SECRET"
    33  	issuer       = "OIDC_ISSUER"
    34  	port         = "OIDC_PORT"
    35  	attemptExp   = "attemptExp"
    36  )
    37  
    38  func envConfig(secretNotRequired bool) (map[string]interface{}, error) {
    39  	const op = "envConfig"
    40  	env := map[string]interface{}{
    41  		clientID:     os.Getenv("OIDC_CLIENT_ID"),
    42  		clientSecret: os.Getenv("OIDC_CLIENT_SECRET"),
    43  		issuer:       os.Getenv("OIDC_ISSUER"),
    44  		port:         os.Getenv("OIDC_PORT"),
    45  		attemptExp:   time.Duration(2 * time.Minute),
    46  	}
    47  	for k, v := range env {
    48  		switch t := v.(type) {
    49  		case string:
    50  			switch k {
    51  			case "OIDC_CLIENT_SECRET":
    52  				switch {
    53  				case secretNotRequired:
    54  					env[k] = "" // unsetting the secret which isn't required
    55  				case t == "":
    56  					return nil, fmt.Errorf("%s: %s is empty.\n\n   Did you intend to use -pkce or -implicit options?", op, k)
    57  				}
    58  			default:
    59  				if t == "" {
    60  					return nil, fmt.Errorf("%s: %s is empty", op, k)
    61  				}
    62  			}
    63  		case time.Duration:
    64  			if t == 0 {
    65  				return nil, fmt.Errorf("%s: %s is empty", op, k)
    66  			}
    67  		default:
    68  			return nil, fmt.Errorf("%s: %s is an unhandled type %t", op, k, t)
    69  		}
    70  	}
    71  	return env, nil
    72  }
    73  
    74  func main() {
    75  	useImplicit := flag.Bool("implicit", false, "use the implicit flow")
    76  	implicitAccessToken := flag.Bool("implicit-access-token", false, "include the access_token in the implicit flow")
    77  	usePKCE := flag.Bool("pkce", false, "use the implicit flow")
    78  	maxAge := flag.Int("max-age", -1, "max age of user authentication")
    79  	scopes := flag.String("scopes", "", "comma separated list of additional scopes to requests")
    80  	useTestProvider := flag.Bool("use-test-provider", false, "use the test oidc provider")
    81  
    82  	flag.Parse()
    83  	if *useImplicit && *usePKCE {
    84  		fmt.Fprint(os.Stderr, "you can't request both: -implicit and -pkce")
    85  		return
    86  	}
    87  
    88  	if (*useImplicit || *implicitAccessToken || *scopes != "") && *useTestProvider {
    89  		fmt.Fprint(os.Stderr, "you can't use the implicit flow, PKCE or scopes with the test provider")
    90  		return
    91  	}
    92  
    93  	optScopes := strings.Split(*scopes, ",")
    94  	for i := range optScopes {
    95  		optScopes[i] = strings.TrimSpace(optScopes[i])
    96  	}
    97  
    98  	var env map[string]interface{}
    99  	var tp *oidc.TestProvider
   100  	if *useTestProvider {
   101  		l, err := oidc.NewTestingLogger(hclog.New(nil))
   102  		if err != nil {
   103  			fmt.Fprintf(os.Stderr, "%s\n\n", err)
   104  			return
   105  		}
   106  		// Generate a key to sign JWTs with throughout most test cases
   107  		priv, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
   108  		if err != nil {
   109  			fmt.Fprintf(os.Stderr, "%s\n\n", err)
   110  			return
   111  		}
   112  		oidcPort := os.Getenv("OIDC_PORT")
   113  		if oidcPort == "" {
   114  			oidcPort, err = func() (string, error) {
   115  				addr, err := net.ResolveTCPAddr("tcp", "localhost:0")
   116  				if err != nil {
   117  					return "", err
   118  				}
   119  
   120  				l, err := net.ListenTCP("tcp", addr)
   121  				if err != nil {
   122  					return "", err
   123  				}
   124  				defer l.Close()
   125  				return strconv.Itoa(l.Addr().(*net.TCPAddr).Port), nil
   126  			}()
   127  			if err != nil {
   128  				fmt.Fprintf(os.Stderr, "env OIDC_PORT is empty and error finding a free port: %s", err.Error())
   129  			}
   130  			return
   131  		}
   132  
   133  		id, secret := "test-rp", "fido"
   134  		tp = oidc.StartTestProvider(l, oidc.WithNoTLS(), oidc.WithTestDefaults(&oidc.TestProviderDefaults{
   135  			CustomClaims: map[string]interface{}{},
   136  			SubjectInfo: map[string]*oidc.TestSubject{
   137  				"alice": {
   138  					Password: "fido",
   139  					UserInfo: map[string]interface{}{
   140  						"email":  "alice@example.com",
   141  						"name":   "alice smith",
   142  						"friend": "eve",
   143  					},
   144  					CustomClaims: map[string]interface{}{
   145  						"email": "alice@example.com",
   146  						"name":  "alice smith",
   147  					},
   148  				},
   149  				"eve": {
   150  					Password: "alice",
   151  					UserInfo: map[string]interface{}{
   152  						"email":  "eve@example.com",
   153  						"name":   "eve smith",
   154  						"friend": "alice",
   155  					},
   156  					CustomClaims: map[string]interface{}{
   157  						"email": "eve@example.com",
   158  						"name":  "eve smith",
   159  					},
   160  				},
   161  			},
   162  			SigningKey: &oidc.TestSigningKey{
   163  				PrivKey: priv,
   164  				PubKey:  priv.Public(),
   165  				Alg:     oidc.ES384,
   166  			},
   167  			AllowedRedirectURIs: []string{fmt.Sprintf("http://localhost:%s/callback", oidcPort)},
   168  			ClientID:            &id,
   169  			ClientSecret:        &secret,
   170  		}))
   171  		defer tp.Stop()
   172  		env = map[string]interface{}{
   173  			clientID:     id,
   174  			clientSecret: secret,
   175  			issuer:       tp.Addr(),
   176  			port:         oidcPort,
   177  			attemptExp:   time.Duration(2 * time.Minute),
   178  		}
   179  	} else {
   180  		var err error
   181  		env, err = envConfig(*useImplicit || *usePKCE)
   182  		if err != nil {
   183  			fmt.Fprintf(os.Stderr, "%s\n\n", err)
   184  			return
   185  		}
   186  	}
   187  
   188  	// handle ctrl-c while waiting for the callback
   189  	sigintCh := make(chan os.Signal, 1)
   190  	signal.Notify(sigintCh, os.Interrupt)
   191  	defer signal.Stop(sigintCh)
   192  
   193  	ctx, cancel := context.WithCancel(context.Background())
   194  	defer cancel()
   195  
   196  	issuer := env[issuer].(string)
   197  	clientID := env[clientID].(string)
   198  	clientSecret := oidc.ClientSecret(env[clientSecret].(string))
   199  	redirectURL := fmt.Sprintf("http://localhost:%s/callback", env[port].(string))
   200  	pc, err := oidc.NewConfig(issuer, clientID, clientSecret, []oidc.Alg{oidc.ES384}, []string{redirectURL})
   201  	if err != nil {
   202  		fmt.Fprint(os.Stderr, err.Error())
   203  		return
   204  	}
   205  
   206  	p, err := oidc.NewProvider(pc)
   207  	if err != nil {
   208  		fmt.Fprint(os.Stderr, err.Error())
   209  		return
   210  	}
   211  	defer p.Done()
   212  
   213  	var requestOptions []oidc.Option
   214  	switch {
   215  	case *useImplicit && !*implicitAccessToken:
   216  		requestOptions = append(requestOptions, oidc.WithImplicitFlow())
   217  	case *useImplicit && *implicitAccessToken:
   218  		requestOptions = append(requestOptions, oidc.WithImplicitFlow(true))
   219  	case *usePKCE:
   220  		v, err := oidc.NewCodeVerifier()
   221  		if err != nil {
   222  			fmt.Fprint(os.Stderr, err.Error())
   223  			return
   224  		}
   225  		requestOptions = append(requestOptions, oidc.WithPKCE(v))
   226  	}
   227  
   228  	if *maxAge >= 0 {
   229  		requestOptions = append(requestOptions, oidc.WithMaxAge(uint(*maxAge)))
   230  	}
   231  
   232  	requestOptions = append(requestOptions, oidc.WithScopes(optScopes...))
   233  
   234  	oidcRequest, err := oidc.NewRequest(env[attemptExp].(time.Duration), redirectURL, requestOptions...)
   235  	if err != nil {
   236  		fmt.Fprint(os.Stderr, err.Error())
   237  		return
   238  	}
   239  
   240  	successFn, successCh := success()
   241  	errorFn, failedCh := failed()
   242  
   243  	var handler http.HandlerFunc
   244  	if *useImplicit {
   245  		handler, err = callback.Implicit(ctx, p, &callback.SingleRequestReader{Request: oidcRequest}, successFn, errorFn)
   246  		if err != nil {
   247  			fmt.Fprintf(os.Stderr, "error creating callback handler: %s", err)
   248  			return
   249  		}
   250  	} else {
   251  		handler, err = callback.AuthCode(ctx, p, &callback.SingleRequestReader{Request: oidcRequest}, successFn, errorFn)
   252  		if err != nil {
   253  			fmt.Fprintf(os.Stderr, "error creating auth code handler: %s", err)
   254  			return
   255  		}
   256  	}
   257  
   258  	authURL, err := p.AuthURL(ctx, oidcRequest)
   259  	if err != nil {
   260  		fmt.Fprintf(os.Stderr, "error getting auth url: %s", err)
   261  		return
   262  	}
   263  
   264  	// Set up callback handler
   265  	http.HandleFunc("/callback", handler)
   266  
   267  	listener, err := net.Listen("tcp", fmt.Sprintf("localhost:%s", env[port]))
   268  	if err != nil {
   269  		fmt.Fprint(os.Stderr, err.Error())
   270  		return
   271  	}
   272  	defer listener.Close()
   273  
   274  	// Open the default browser to the callback URL.
   275  	fmt.Fprintf(os.Stderr, "Complete the login via your OIDC provider. Launching browser to:\n\n    %s\n\n\n", authURL)
   276  	if err := util.OpenURL(authURL); err != nil {
   277  		fmt.Fprintf(os.Stderr, "Error attempting to automatically open browser: '%s'.\nPlease visit the authorization URL manually.", err)
   278  	}
   279  
   280  	srvCh := make(chan error)
   281  	// Start local server
   282  	go func() {
   283  		err := http.Serve(listener, nil)
   284  		if err != nil && err != http.ErrServerClosed {
   285  			srvCh <- err
   286  		}
   287  	}()
   288  
   289  	// Wait for either the callback to finish, SIGINT to be received or up to 2 minutes
   290  	select {
   291  	case err := <-srvCh:
   292  		fmt.Fprintf(os.Stderr, "server closed with error: %s", err.Error())
   293  		return
   294  	case resp := <-successCh:
   295  		if resp.Error != nil {
   296  			fmt.Fprintf(os.Stderr, "channel received success with error: %s", resp.Error)
   297  			return
   298  		}
   299  		printToken(resp.Token)
   300  		printClaims(resp.Token.IDToken())
   301  		printUserInfo(ctx, p, resp.Token)
   302  		return
   303  	case err := <-failedCh:
   304  		if err != nil {
   305  			fmt.Fprintf(os.Stderr, "channel received error: %s", err)
   306  			return
   307  		}
   308  		fmt.Fprint(os.Stderr, "missing error from error channel.  try again?\n")
   309  		return
   310  	case <-sigintCh:
   311  		fmt.Fprintf(os.Stderr, "Interrupted")
   312  		return
   313  	case <-time.After(env[attemptExp].(time.Duration)):
   314  		fmt.Fprintf(os.Stderr, "Timed out waiting for response from provider")
   315  		return
   316  	}
   317  }
   318  
// successResp is the value the function returned by success sends on its
// channel once the callback handler has written the success page.
// success builds it with a positional composite literal, so the field order
// here must not change.
type successResp struct {
	Token oidc.Token // Token is the token handed to the success callback.
	Error error      // Error is set when writing the success page to the browser fails.
}
   323  
   324  func success() (callback.SuccessResponseFunc, <-chan successResp) {
   325  	const op = "success"
   326  	doneCh := make(chan successResp)
   327  	return func(state string, t oidc.Token, w http.ResponseWriter, req *http.Request) {
   328  		var responseErr error
   329  		defer func() {
   330  			doneCh <- successResp{t, responseErr}
   331  			close(doneCh)
   332  		}()
   333  		w.WriteHeader(http.StatusOK)
   334  		if _, err := w.Write([]byte(successHTML)); err != nil {
   335  			responseErr = fmt.Errorf("%s: %w", op, err)
   336  			fmt.Fprintf(os.Stderr, "error writing successful response: %s", err)
   337  		}
   338  	}, doneCh
   339  }
   340  
   341  func failed() (callback.ErrorResponseFunc, <-chan error) {
   342  	const op = "failed"
   343  	doneCh := make(chan error)
   344  	return func(state string, r *callback.AuthenErrorResponse, e error, w http.ResponseWriter, req *http.Request) {
   345  		var responseErr error
   346  		defer func() {
   347  			if _, err := w.Write([]byte(responseErr.Error())); err != nil {
   348  				fmt.Fprintf(os.Stderr, "%s: error writing failed response: %s", op, err)
   349  			}
   350  			doneCh <- responseErr
   351  			close(doneCh)
   352  		}()
   353  
   354  		if e != nil {
   355  			fmt.Fprintf(os.Stderr, "%s: callback error: %s", op, e.Error())
   356  			responseErr = e
   357  			w.WriteHeader(http.StatusInternalServerError)
   358  			return
   359  		}
   360  		if r != nil {
   361  			responseErr = fmt.Errorf("%s: callback error from oidc provider: %s", op, r)
   362  			fmt.Fprint(os.Stderr, responseErr.Error())
   363  			w.WriteHeader(http.StatusUnauthorized)
   364  			return
   365  		}
   366  		responseErr = fmt.Errorf("%s: unknown error from callback", op)
   367  	}, doneCh
   368  }
   369  
   370  type respToken struct {
   371  	IDToken      string
   372  	AccessToken  string
   373  	RefreshToken string
   374  	Expiry       time.Time
   375  }
   376  
   377  func printClaims(t oidc.IDToken) {
   378  	const op = "printClaims"
   379  	var tokenClaims map[string]interface{}
   380  	if err := t.Claims(&tokenClaims); err != nil {
   381  		fmt.Fprintf(os.Stderr, "IDToken claims: error parsing: %s", err)
   382  	} else {
   383  		if idData, err := json.MarshalIndent(tokenClaims, "", "    "); err != nil {
   384  			fmt.Fprintf(os.Stderr, "%s: %s", op, err)
   385  		} else {
   386  			fmt.Fprintf(os.Stderr, "IDToken claims:%s\n", idData)
   387  		}
   388  	}
   389  }
   390  
   391  func printUserInfo(ctx context.Context, p *oidc.Provider, t oidc.Token) {
   392  	const op = "printUserInfo"
   393  	if ts, ok := t.(interface {
   394  		StaticTokenSource() oauth2.TokenSource
   395  	}); ok {
   396  		if ts.StaticTokenSource() == nil {
   397  			fmt.Fprintf(os.Stderr, "%s: no access_token received, so we're unable to get UserInfo claims", op)
   398  			return
   399  		}
   400  		vc := struct {
   401  			Sub string
   402  		}{}
   403  		if err := t.IDToken().Claims(&vc); err != nil {
   404  			fmt.Fprintf(os.Stderr, "%s: channel received success, but error getting UserInfo claims: %s", op, err)
   405  			return
   406  		}
   407  		var infoClaims map[string]interface{}
   408  		err := p.UserInfo(ctx, ts.StaticTokenSource(), vc.Sub, &infoClaims)
   409  		if err != nil {
   410  			fmt.Fprintf(os.Stderr, "%s: channel received success, but error getting UserInfo claims: %s", op, err)
   411  			return
   412  		}
   413  		infoData, err := json.MarshalIndent(infoClaims, "", "    ")
   414  		if err != nil {
   415  			fmt.Fprintf(os.Stderr, "%s: %s", op, err)
   416  			return
   417  		}
   418  		fmt.Fprintf(os.Stderr, "UserInfo claims:%s\n", infoData)
   419  		return
   420  	}
   421  }
   422  
   423  func printToken(t oidc.Token) {
   424  	const op = "printToken"
   425  	tokenData, err := json.MarshalIndent(printableToken(t), "", "    ")
   426  	if err != nil {
   427  		fmt.Fprintf(os.Stderr, "%s: %s", op, err)
   428  		return
   429  	}
   430  	fmt.Fprintf(os.Stderr, "channel received success.\nToken:%s\n", tokenData)
   431  }
   432  
   433  // printableToken is needed because the oidc.Token redacts the IDToken,
   434  // AccessToken and RefreshToken
   435  func printableToken(t oidc.Token) respToken {
   436  	return respToken{
   437  		IDToken:      string(t.IDToken()),
   438  		AccessToken:  string(t.AccessToken()),
   439  		RefreshToken: string(t.RefreshToken()),
   440  		Expiry:       t.Expiry(),
   441  	}
   442  }