github.com/hashicorp/cap@v0.6.0/oidc/examples/spa/route_callback.go (about) 1 // Copyright (c) HashiCorp, Inc. 2 // SPDX-License-Identifier: MPL-2.0 3 4 package main 5 6 import ( 7 "context" 8 "fmt" 9 "net/http" 10 "os" 11 12 "github.com/hashicorp/cap/oidc" 13 "github.com/hashicorp/cap/oidc/callback" 14 ) 15 16 func CallbackHandler(ctx context.Context, p *oidc.Provider, rc *requestCache, withImplicit bool) (http.HandlerFunc, error) { 17 if withImplicit { 18 c, err := callback.Implicit(ctx, p, rc, successFn(ctx, rc), failedFn(ctx, rc)) 19 if err != nil { 20 return nil, fmt.Errorf("CallbackHandler: %w", err) 21 } 22 return c, nil 23 } 24 c, err := callback.AuthCode(ctx, p, rc, successFn(ctx, rc), failedFn(ctx, rc)) 25 if err != nil { 26 return nil, fmt.Errorf("CallbackHandler: %w", err) 27 } 28 return c, nil 29 } 30 31 func successFn(ctx context.Context, rc *requestCache) callback.SuccessResponseFunc { 32 return func(state string, t oidc.Token, w http.ResponseWriter, req *http.Request) { 33 oidcRequest, err := rc.Read(ctx, state) 34 if err != nil { 35 fmt.Fprintf(os.Stderr, "error reading state during successful response: %s\n", err) 36 http.Error(w, err.Error(), http.StatusInternalServerError) 37 return 38 } 39 if err := rc.SetToken(oidcRequest.State(), t); err != nil { 40 fmt.Fprintf(os.Stderr, "error updating state during successful response: %s\n", err) 41 http.Error(w, err.Error(), http.StatusInternalServerError) 42 return 43 } 44 // Redirect to logged in page 45 http.Redirect(w, req, fmt.Sprintf("/success?state=%s", state), http.StatusSeeOther) 46 } 47 } 48 49 func failedFn(ctx context.Context, rc *requestCache) callback.ErrorResponseFunc { 50 const op = "failedFn" 51 return func(state string, r *callback.AuthenErrorResponse, e error, w http.ResponseWriter, req *http.Request) { 52 var responseErr error 53 defer func() { 54 if _, err := w.Write([]byte(responseErr.Error())); err != nil { 55 fmt.Fprintf(os.Stderr, "error writing failed response: %s\n", err) 56 } 57 }() 58 59 if e != nil { 60 fmt.Fprintf(os.Stderr, "callback error: %s\n", e.Error()) 61 responseErr = e 62 w.WriteHeader(http.StatusInternalServerError) 63 return 64 } 65 if r != nil { 66 fmt.Fprintf(os.Stderr, "callback error from oidc provider: %s\n", r) 67 responseErr = fmt.Errorf("%s: callback error from oidc provider: %s", op, r) 68 w.WriteHeader(http.StatusUnauthorized) 69 return 70 } 71 responseErr = fmt.Errorf("%s: unknown error from callback", op) 72 } 73 }