github.com/louisevanderlith/droxolite@v1.20.2/open/hybridprotector.go (about)

     1  package open
     2  
import (
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io/ioutil"
	"log"
	"net/http"
	"net/url"
	"strings"

	"github.com/coreos/go-oidc"
	"github.com/louisevanderlith/droxolite/mix"
	"golang.org/x/net/context"
	"golang.org/x/oauth2"
	"golang.org/x/oauth2/clientcredentials"
)
    17  
    18  func NewHybridLock(p *oidc.Provider, clntCfg *clientcredentials.Config, usrConfig *oauth2.Config) hybridprotector {
    19  	return hybridprotector{
    20  		provider:   p,
    21  		clntConfig: clntCfg,
    22  		usrConfig:  usrConfig,
    23  	}
    24  }
    25  
// hybridprotector guards HTTP handlers with OpenID Connect. Requests carrying
// a user session ("acctoken" cookie) act as that user; requests without one
// fall back to an application token from the client-credentials config.
type hybridprotector struct {
	// provider supplies the token endpoint (Refresh) and the ID-token
	// verifier (Protect).
	provider   *oidc.Provider
	// clntConfig obtains the application token Protect uses when no
	// "acctoken" cookie is present.
	clntConfig *clientcredentials.Config
	// usrConfig drives the user flow: AuthCodeURL (Login), Exchange
	// (Callback), and the client ID/secret used by Refresh and Protect.
	usrConfig  *oauth2.Config
}
    31  
    32  func (p hybridprotector) Refresh(w http.ResponseWriter, r *http.Request) {
    33  	jtoken, _ := r.Cookie("acctoken")
    34  
    35  	if jtoken == nil {
    36  		http.Error(w, "", http.StatusUnauthorized)
    37  		return
    38  	}
    39  
    40  	tkn64, err := base64.URLEncoding.DecodeString(jtoken.Value)
    41  
    42  	if err != nil {
    43  		http.Error(w, err.Error(), http.StatusInternalServerError)
    44  		return
    45  	}
    46  
    47  	tknVal := oauth2.Token{}
    48  	err = json.Unmarshal(tkn64, &tknVal)
    49  
    50  	if err != nil {
    51  		http.Error(w, err.Error(), http.StatusInternalServerError)
    52  		return
    53  	}
    54  
    55  	if tknVal.Valid() {
    56  		mix.Write(w, mix.JSON(jtoken.Value))
    57  		return
    58  	}
    59  
    60  	params := "grant_type=refresh_token&client_id=%s&client_secret=%s&refresh_token=%s"
    61  	payload := strings.NewReader(fmt.Sprintf(params, p.usrConfig.ClientID, p.usrConfig.ClientSecret, tknVal.RefreshToken))
    62  
    63  	req, _ := http.NewRequest("POST", p.provider.Endpoint().TokenURL, payload)
    64  
    65  	req.Header.Add("content-type", "application/x-www-form-urlencoded")
    66  
    67  	res, err := http.DefaultClient.Do(req)
    68  
    69  	if err != nil {
    70  		http.Error(w, err.Error(), http.StatusInternalServerError)
    71  		return
    72  	}
    73  
    74  	defer res.Body.Close()
    75  
    76  	body, err := ioutil.ReadAll(res.Body)
    77  
    78  	if err != nil {
    79  		log.Println("ReadAll Error", err)
    80  		http.Error(w, err.Error(), http.StatusInternalServerError)
    81  		return
    82  	}
    83  
    84  	err = mix.Write(w, mix.JSON(body))
    85  
    86  	if err != nil {
    87  		log.Println("Serve Error", err)
    88  	}
    89  }
    90  
    91  func (p hybridprotector) Lock(handler http.Handler) http.Handler {
    92  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    93  		idtkn := r.Context().Value("IDToken")
    94  
    95  		if idtkn == nil {
    96  			p.Login(w, r)
    97  			return
    98  		}
    99  
   100  		handler.ServeHTTP(w, r)
   101  	})
   102  }
   103  
   104  func (p hybridprotector) Login(w http.ResponseWriter, r *http.Request) {
   105  	state := generateStateOauthCookie(w)
   106  	http.Redirect(w, r, p.usrConfig.AuthCodeURL(state), http.StatusTemporaryRedirect)
   107  }
   108  
   109  func (p hybridprotector) Callback(w http.ResponseWriter, r *http.Request) {
   110  	state, err := r.Cookie("oauthstate")
   111  
   112  	if err != nil {
   113  		http.Error(w, "state not found", http.StatusInternalServerError)
   114  		return
   115  	}
   116  
   117  	if r.URL.Query().Get("state") != state.Value {
   118  		http.Error(w, "state did not match", http.StatusBadRequest)
   119  		return
   120  	}
   121  
   122  	oauth2Token, err := p.usrConfig.Exchange(r.Context(), r.URL.Query().Get("code"))
   123  	if err != nil {
   124  		http.Error(w, "Failed to exchange token: "+err.Error(), http.StatusInternalServerError)
   125  		return
   126  	}
   127  
   128  	rawIDToken, ok := oauth2Token.Extra("id_token").(string)
   129  
   130  	if !ok {
   131  		http.Error(w, "No id_token field in oauth2 token.", http.StatusInternalServerError)
   132  		return
   133  	}
   134  
   135  	jtoken, err := json.Marshal(oauth2Token)
   136  	if !ok {
   137  		http.Error(w, err.Error(), http.StatusInternalServerError)
   138  		return
   139  	}
   140  
   141  	tkn64 := base64.URLEncoding.EncodeToString(jtoken)
   142  	tokencookie := http.Cookie{
   143  		Name:     "acctoken",
   144  		Value:    tkn64,
   145  		MaxAge:   0,
   146  		Path:     "/",
   147  		HttpOnly: true,
   148  	}
   149  	http.SetCookie(w, &tokencookie)
   150  
   151  	idcookie := http.Cookie{
   152  		Name:     "idtoken",
   153  		Value:    rawIDToken,
   154  		MaxAge:   0,
   155  		Path:     "/",
   156  		HttpOnly: true,
   157  	}
   158  	http.SetCookie(w, &idcookie)
   159  
   160  	state.MaxAge = -1
   161  	state.Value = ""
   162  	http.SetCookie(w, state)
   163  
   164  	RedirectToLastLocation(w, r)
   165  }
   166  
   167  func (p hybridprotector) Logout(w http.ResponseWriter, r *http.Request) {
   168  	acc, err := r.Cookie("acctoken")
   169  
   170  	if err != nil {
   171  		http.Error(w, "acctoken not found", http.StatusInternalServerError)
   172  		return
   173  	}
   174  
   175  	acc.MaxAge = -1
   176  	acc.Value = ""
   177  	http.SetCookie(w, acc)
   178  
   179  	http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
   180  }
   181  
   182  func (p hybridprotector) Protect(next http.Handler) http.Handler {
   183  	oidcConfig := &oidc.Config{
   184  		ClientID: p.usrConfig.ClientID,
   185  	}
   186  
   187  	v := p.provider.Verifier(oidcConfig)
   188  
   189  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   190  		setLastLocationCookie(w, r.URL.EscapedPath())
   191  
   192  		jtoken, _ := r.Cookie("acctoken")
   193  
   194  		if jtoken == nil {
   195  			tkn, err := p.clntConfig.Token(r.Context())
   196  
   197  			if err != nil {
   198  				http.Error(w, err.Error(), http.StatusInternalServerError)
   199  				return
   200  			}
   201  
   202  			acc := context.WithValue(r.Context(), "Token", *tkn)
   203  			next.ServeHTTP(w, r.WithContext(acc))
   204  			return
   205  		}
   206  
   207  		tkn64, err := base64.URLEncoding.DecodeString(jtoken.Value)
   208  
   209  		if err != nil {
   210  			http.Error(w, err.Error(), http.StatusInternalServerError)
   211  			return
   212  		}
   213  
   214  		accToken := oauth2.Token{}
   215  		err = json.Unmarshal(tkn64, &accToken)
   216  
   217  		if err != nil {
   218  			http.Error(w, err.Error(), http.StatusInternalServerError)
   219  			return
   220  		}
   221  
   222  		xidn := context.WithValue(r.Context(), "Token", accToken)
   223  
   224  		rawIDToken, err := r.Cookie("idtoken")
   225  
   226  		if err != nil {
   227  			log.Println("Cookie Error", err)
   228  			next.ServeHTTP(w, r.WithContext(xidn))
   229  			return
   230  		}
   231  
   232  		idToken, err := v.Verify(r.Context(), rawIDToken.Value)
   233  		if err != nil {
   234  			log.Println("ID Verify Error", err)
   235  			next.ServeHTTP(w, r.WithContext(xidn))
   236  			return
   237  		}
   238  
   239  		err = idToken.VerifyAccessToken(accToken.AccessToken)
   240  
   241  		if err != nil {
   242  			http.Error(w, err.Error(), http.StatusInternalServerError)
   243  			return
   244  		}
   245  
   246  		idn := context.WithValue(xidn, "IDToken", idToken)
   247  		//TODO: Replace IDToken with Claims (User)
   248  
   249  		next.ServeHTTP(w, r.WithContext(idn))
   250  	})
   251  }