github.com/hashicorp/cap@v0.6.0/oidc/examples/spa/main.go (about)

     1  // Copyright (c) HashiCorp, Inc.
     2  // SPDX-License-Identifier: MPL-2.0
     3  
     4  package main
     5  
     6  import (
     7  	"context"
     8  	"flag"
     9  	"fmt"
    10  	"net"
    11  	"net/http"
    12  	"os"
    13  	"os/signal"
    14  	"strings"
    15  	"time"
    16  
    17  	"github.com/hashicorp/cap/oidc"
    18  )
    19  
    20  // List of required configuration environment variables
    21  const (
    22  	clientID     = "OIDC_CLIENT_ID"
    23  	clientSecret = "OIDC_CLIENT_SECRET"
    24  	issuer       = "OIDC_ISSUER"
    25  	port         = "OIDC_PORT"
    26  	attemptExp   = "attemptExp"
    27  )
    28  
    29  func envConfig(secretNotRequired bool) (map[string]interface{}, error) {
    30  	const op = "envConfig"
    31  	env := map[string]interface{}{
    32  		clientID:     os.Getenv("OIDC_CLIENT_ID"),
    33  		clientSecret: os.Getenv("OIDC_CLIENT_SECRET"),
    34  		issuer:       os.Getenv("OIDC_ISSUER"),
    35  		port:         os.Getenv("OIDC_PORT"),
    36  		attemptExp:   time.Duration(2 * time.Minute),
    37  	}
    38  	for k, v := range env {
    39  		switch t := v.(type) {
    40  		case string:
    41  			switch k {
    42  			case "OIDC_CLIENT_SECRET":
    43  				switch {
    44  				case secretNotRequired:
    45  					env[k] = "" // unsetting the secret which isn't required
    46  				case t == "":
    47  					return nil, fmt.Errorf("%s: %s is empty.\n\n   Did you intend to use -pkce or -implicit options?", op, k)
    48  				}
    49  			default:
    50  				if t == "" {
    51  					return nil, fmt.Errorf("%s: %s is empty", op, k)
    52  				}
    53  			}
    54  		case time.Duration:
    55  			if t == 0 {
    56  				return nil, fmt.Errorf("%s: %s is empty", op, k)
    57  			}
    58  		default:
    59  			return nil, fmt.Errorf("%s: %s is an unhandled type %t", op, k, t)
    60  		}
    61  	}
    62  	return env, nil
    63  }
    64  
    65  func main() {
    66  	useImplicit := flag.Bool("implicit", false, "use the implicit flow")
    67  	usePKCE := flag.Bool("pkce", false, "use the implicit flow")
    68  	maxAge := flag.Int("max-age", -1, "max age of user authentication")
    69  	scopes := flag.String("scopes", "", "comma separated list of additional scopes to requests")
    70  
    71  	flag.Parse()
    72  	if *useImplicit && *usePKCE {
    73  		fmt.Fprint(os.Stderr, "you can't request both: -implicit and -pkce")
    74  		return
    75  	}
    76  
    77  	optScopes := strings.Split(*scopes, ",")
    78  	for i := range optScopes {
    79  		optScopes[i] = strings.TrimSpace(optScopes[i])
    80  	}
    81  
    82  	env, err := envConfig(*useImplicit || *usePKCE)
    83  	if err != nil {
    84  		fmt.Fprintf(os.Stderr, "%s\n\n", err)
    85  		return
    86  	}
    87  
    88  	// handle ctrl-c while waiting for the callback
    89  	sigintCh := make(chan os.Signal, 1)
    90  	signal.Notify(sigintCh, os.Interrupt)
    91  	defer signal.Stop(sigintCh)
    92  
    93  	ctx, cancel := context.WithCancel(context.Background())
    94  	defer cancel()
    95  
    96  	issuer := env[issuer].(string)
    97  	clientID := env[clientID].(string)
    98  	clientSecret := oidc.ClientSecret(env[clientSecret].(string))
    99  	redirectURL := fmt.Sprintf("http://localhost:%s/callback", env[port].(string))
   100  	timeout := env[attemptExp].(time.Duration)
   101  
   102  	rc := newRequestCache()
   103  
   104  	pc, err := oidc.NewConfig(issuer, clientID, clientSecret, []oidc.Alg{oidc.RS256}, []string{redirectURL})
   105  	if err != nil {
   106  		fmt.Fprint(os.Stderr, err.Error())
   107  		return
   108  	}
   109  
   110  	p, err := oidc.NewProvider(pc)
   111  	if err != nil {
   112  		fmt.Fprint(os.Stderr, err.Error())
   113  		return
   114  	}
   115  	defer p.Done()
   116  
   117  	if err != nil {
   118  		fmt.Fprintf(os.Stderr, "error getting auth url: %s", err)
   119  		return
   120  	}
   121  
   122  	callback, err := CallbackHandler(ctx, p, rc, *useImplicit)
   123  	if err != nil {
   124  		fmt.Fprintf(os.Stderr, "error creating callback handler: %s", err)
   125  		return
   126  	}
   127  
   128  	var requestOptions []oidc.Option
   129  	switch {
   130  	case *useImplicit:
   131  		requestOptions = append(requestOptions, oidc.WithImplicitFlow())
   132  	case *usePKCE:
   133  		v, err := oidc.NewCodeVerifier()
   134  		if err != nil {
   135  			fmt.Fprint(os.Stderr, err.Error())
   136  			return
   137  		}
   138  		requestOptions = append(requestOptions, oidc.WithPKCE(v))
   139  	}
   140  	if *maxAge >= 0 {
   141  		requestOptions = append(requestOptions, oidc.WithMaxAge(uint(*maxAge)))
   142  	}
   143  
   144  	requestOptions = append(requestOptions, oidc.WithScopes(optScopes...))
   145  
   146  	// Set up callback handler
   147  	http.HandleFunc("/callback", callback)
   148  	http.HandleFunc("/login", LoginHandler(ctx, p, rc, timeout, redirectURL, requestOptions))
   149  	http.HandleFunc("/success", SuccessHandler(ctx, rc))
   150  
   151  	listener, err := net.Listen("tcp", fmt.Sprintf("localhost:%s", env[port]))
   152  	if err != nil {
   153  		fmt.Fprint(os.Stderr, err.Error())
   154  		return
   155  	}
   156  	defer listener.Close()
   157  
   158  	srvCh := make(chan error)
   159  	// Start local server
   160  	go func() {
   161  		fmt.Fprintf(os.Stderr, "Complete the login via your OIDC provider. Launching browser to:\n\n    http://localhost:%s/login\n\n\n", env[port])
   162  		err := http.Serve(listener, nil)
   163  		if err != nil && err != http.ErrServerClosed {
   164  			srvCh <- err
   165  		}
   166  	}()
   167  
   168  	// Wait for either the callback to finish, SIGINT to be received or up to 2 minutes
   169  	select {
   170  	case err := <-srvCh:
   171  		fmt.Fprintf(os.Stderr, "server closed with error: %s", err.Error())
   172  		return
   173  	case <-sigintCh:
   174  		fmt.Fprintf(os.Stderr, "Interrupted")
   175  		return
   176  	}
   177  }