github.com/louisevanderlith/droxolite@v1.20.2/open/hybridprotector.go (about) 1 package open 2 3 import ( 4 "encoding/base64" 5 "encoding/json" 6 "fmt" 7 "github.com/coreos/go-oidc" 8 "github.com/louisevanderlith/droxolite/mix" 9 "golang.org/x/net/context" 10 "golang.org/x/oauth2" 11 "golang.org/x/oauth2/clientcredentials" 12 "io/ioutil" 13 "log" 14 "net/http" 15 "strings" 16 ) 17 18 func NewHybridLock(p *oidc.Provider, clntCfg *clientcredentials.Config, usrConfig *oauth2.Config) hybridprotector { 19 return hybridprotector{ 20 provider: p, 21 clntConfig: clntCfg, 22 usrConfig: usrConfig, 23 } 24 } 25 26 type hybridprotector struct { 27 provider *oidc.Provider 28 clntConfig *clientcredentials.Config 29 usrConfig *oauth2.Config 30 } 31 32 func (p hybridprotector) Refresh(w http.ResponseWriter, r *http.Request) { 33 jtoken, _ := r.Cookie("acctoken") 34 35 if jtoken == nil { 36 http.Error(w, "", http.StatusUnauthorized) 37 return 38 } 39 40 tkn64, err := base64.URLEncoding.DecodeString(jtoken.Value) 41 42 if err != nil { 43 http.Error(w, err.Error(), http.StatusInternalServerError) 44 return 45 } 46 47 tknVal := oauth2.Token{} 48 err = json.Unmarshal(tkn64, &tknVal) 49 50 if err != nil { 51 http.Error(w, err.Error(), http.StatusInternalServerError) 52 return 53 } 54 55 if tknVal.Valid() { 56 mix.Write(w, mix.JSON(jtoken.Value)) 57 return 58 } 59 60 params := "grant_type=refresh_token&client_id=%s&client_secret=%s&refresh_token=%s" 61 payload := strings.NewReader(fmt.Sprintf(params, p.usrConfig.ClientID, p.usrConfig.ClientSecret, tknVal.RefreshToken)) 62 63 req, _ := http.NewRequest("POST", p.provider.Endpoint().TokenURL, payload) 64 65 req.Header.Add("content-type", "application/x-www-form-urlencoded") 66 67 res, err := http.DefaultClient.Do(req) 68 69 if err != nil { 70 http.Error(w, err.Error(), http.StatusInternalServerError) 71 return 72 } 73 74 defer res.Body.Close() 75 76 body, err := ioutil.ReadAll(res.Body) 77 78 if err != nil { 79 log.Println("ReadAll Error", err) 80 http.Error(w, err.Error(), http.StatusInternalServerError) 81 return 82 } 83 84 err = mix.Write(w, mix.JSON(body)) 85 86 if err != nil { 87 log.Println("Serve Error", err) 88 } 89 } 90 91 func (p hybridprotector) Lock(handler http.Handler) http.Handler { 92 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 93 idtkn := r.Context().Value("IDToken") 94 95 if idtkn == nil { 96 p.Login(w, r) 97 return 98 } 99 100 handler.ServeHTTP(w, r) 101 }) 102 } 103 104 func (p hybridprotector) Login(w http.ResponseWriter, r *http.Request) { 105 state := generateStateOauthCookie(w) 106 http.Redirect(w, r, p.usrConfig.AuthCodeURL(state), http.StatusTemporaryRedirect) 107 } 108 109 func (p hybridprotector) Callback(w http.ResponseWriter, r *http.Request) { 110 state, err := r.Cookie("oauthstate") 111 112 if err != nil { 113 http.Error(w, "state not found", http.StatusInternalServerError) 114 return 115 } 116 117 if r.URL.Query().Get("state") != state.Value { 118 http.Error(w, "state did not match", http.StatusBadRequest) 119 return 120 } 121 122 oauth2Token, err := p.usrConfig.Exchange(r.Context(), r.URL.Query().Get("code")) 123 if err != nil { 124 http.Error(w, "Failed to exchange token: "+err.Error(), http.StatusInternalServerError) 125 return 126 } 127 128 rawIDToken, ok := oauth2Token.Extra("id_token").(string) 129 130 if !ok { 131 http.Error(w, "No id_token field in oauth2 token.", http.StatusInternalServerError) 132 return 133 } 134 135 jtoken, err := json.Marshal(oauth2Token) 136 if !ok { 137 http.Error(w, err.Error(), http.StatusInternalServerError) 138 return 139 } 140 141 tkn64 := base64.URLEncoding.EncodeToString(jtoken) 142 tokencookie := http.Cookie{ 143 Name: "acctoken", 144 Value: tkn64, 145 MaxAge: 0, 146 Path: "/", 147 HttpOnly: true, 148 } 149 http.SetCookie(w, &tokencookie) 150 151 idcookie := http.Cookie{ 152 Name: "idtoken", 153 Value: rawIDToken, 154 MaxAge: 0, 155 Path: "/", 156 HttpOnly: true, 157 } 158 http.SetCookie(w, &idcookie) 159 160 state.MaxAge = -1 161 state.Value = "" 162 http.SetCookie(w, state) 163 164 RedirectToLastLocation(w, r) 165 } 166 167 func (p hybridprotector) Logout(w http.ResponseWriter, r *http.Request) { 168 acc, err := r.Cookie("acctoken") 169 170 if err != nil { 171 http.Error(w, "acctoken not found", http.StatusInternalServerError) 172 return 173 } 174 175 acc.MaxAge = -1 176 acc.Value = "" 177 http.SetCookie(w, acc) 178 179 http.Redirect(w, r, "/", http.StatusTemporaryRedirect) 180 } 181 182 func (p hybridprotector) Protect(next http.Handler) http.Handler { 183 oidcConfig := &oidc.Config{ 184 ClientID: p.usrConfig.ClientID, 185 } 186 187 v := p.provider.Verifier(oidcConfig) 188 189 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 190 setLastLocationCookie(w, r.URL.EscapedPath()) 191 192 jtoken, _ := r.Cookie("acctoken") 193 194 if jtoken == nil { 195 tkn, err := p.clntConfig.Token(r.Context()) 196 197 if err != nil { 198 http.Error(w, err.Error(), http.StatusInternalServerError) 199 return 200 } 201 202 acc := context.WithValue(r.Context(), "Token", *tkn) 203 next.ServeHTTP(w, r.WithContext(acc)) 204 return 205 } 206 207 tkn64, err := base64.URLEncoding.DecodeString(jtoken.Value) 208 209 if err != nil { 210 http.Error(w, err.Error(), http.StatusInternalServerError) 211 return 212 } 213 214 accToken := oauth2.Token{} 215 err = json.Unmarshal(tkn64, &accToken) 216 217 if err != nil { 218 http.Error(w, err.Error(), http.StatusInternalServerError) 219 return 220 } 221 222 xidn := context.WithValue(r.Context(), "Token", accToken) 223 224 rawIDToken, err := r.Cookie("idtoken") 225 226 if err != nil { 227 log.Println("Cookie Error", err) 228 next.ServeHTTP(w, r.WithContext(xidn)) 229 return 230 } 231 232 idToken, err := v.Verify(r.Context(), rawIDToken.Value) 233 if err != nil { 234 log.Println("ID Verify Error", err) 235 next.ServeHTTP(w, r.WithContext(xidn)) 236 return 237 } 238 239 err = idToken.VerifyAccessToken(accToken.AccessToken) 240 241 if err != nil { 242 http.Error(w, err.Error(), http.StatusInternalServerError) 243 return 244 } 245 246 idn := context.WithValue(xidn, "IDToken", idToken) 247 //TODO: Replace IDToken with Claims (User) 248 249 next.ServeHTTP(w, r.WithContext(idn)) 250 }) 251 }