github.com/hashicorp/cap@v0.6.0/oidc/examples/spa/route_callback.go (about)

     1  // Copyright (c) HashiCorp, Inc.
     2  // SPDX-License-Identifier: MPL-2.0
     3  
     4  package main
     5  
import (
	"context"
	"fmt"
	"net/http"
	"net/url"
	"os"

	"github.com/hashicorp/cap/oidc"
	"github.com/hashicorp/cap/oidc/callback"
)
    15  
    16  func CallbackHandler(ctx context.Context, p *oidc.Provider, rc *requestCache, withImplicit bool) (http.HandlerFunc, error) {
    17  	if withImplicit {
    18  		c, err := callback.Implicit(ctx, p, rc, successFn(ctx, rc), failedFn(ctx, rc))
    19  		if err != nil {
    20  			return nil, fmt.Errorf("CallbackHandler: %w", err)
    21  		}
    22  		return c, nil
    23  	}
    24  	c, err := callback.AuthCode(ctx, p, rc, successFn(ctx, rc), failedFn(ctx, rc))
    25  	if err != nil {
    26  		return nil, fmt.Errorf("CallbackHandler: %w", err)
    27  	}
    28  	return c, nil
    29  }
    30  
    31  func successFn(ctx context.Context, rc *requestCache) callback.SuccessResponseFunc {
    32  	return func(state string, t oidc.Token, w http.ResponseWriter, req *http.Request) {
    33  		oidcRequest, err := rc.Read(ctx, state)
    34  		if err != nil {
    35  			fmt.Fprintf(os.Stderr, "error reading state during successful response: %s\n", err)
    36  			http.Error(w, err.Error(), http.StatusInternalServerError)
    37  			return
    38  		}
    39  		if err := rc.SetToken(oidcRequest.State(), t); err != nil {
    40  			fmt.Fprintf(os.Stderr, "error updating state during successful response: %s\n", err)
    41  			http.Error(w, err.Error(), http.StatusInternalServerError)
    42  			return
    43  		}
    44  		// Redirect to logged in page
    45  		http.Redirect(w, req, fmt.Sprintf("/success?state=%s", state), http.StatusSeeOther)
    46  	}
    47  }
    48  
    49  func failedFn(ctx context.Context, rc *requestCache) callback.ErrorResponseFunc {
    50  	const op = "failedFn"
    51  	return func(state string, r *callback.AuthenErrorResponse, e error, w http.ResponseWriter, req *http.Request) {
    52  		var responseErr error
    53  		defer func() {
    54  			if _, err := w.Write([]byte(responseErr.Error())); err != nil {
    55  				fmt.Fprintf(os.Stderr, "error writing failed response: %s\n", err)
    56  			}
    57  		}()
    58  
    59  		if e != nil {
    60  			fmt.Fprintf(os.Stderr, "callback error: %s\n", e.Error())
    61  			responseErr = e
    62  			w.WriteHeader(http.StatusInternalServerError)
    63  			return
    64  		}
    65  		if r != nil {
    66  			fmt.Fprintf(os.Stderr, "callback error from oidc provider: %s\n", r)
    67  			responseErr = fmt.Errorf("%s: callback error from oidc provider: %s", op, r)
    68  			w.WriteHeader(http.StatusUnauthorized)
    69  			return
    70  		}
    71  		responseErr = fmt.Errorf("%s: unknown error from callback", op)
    72  	}
    73  }