github.com/pyroscope-io/pyroscope@v0.37.3-0.20230725203016-5f6947968bd0/pkg/server/oauth_base.go (about) 1 package server 2 3 import ( 4 "errors" 5 "fmt" 6 "net/http" 7 "net/url" 8 "path/filepath" 9 "regexp" 10 "strings" 11 12 "github.com/sirupsen/logrus" 13 "golang.org/x/oauth2" 14 ) 15 16 var errForbidden = errors.New("access forbidden") 17 18 type extUserInfo struct { 19 Name string 20 Email string 21 } 22 23 type oauthHandler interface { 24 userAuth(client *http.Client) (extUserInfo, error) 25 getOauthBase() oauthBase 26 } 27 28 type oauthBase struct { 29 config *oauth2.Config 30 authURL *url.URL 31 apiURL string 32 log *logrus.Logger 33 callbackRoute string 34 redirectRoute string 35 baseURL string 36 } 37 38 func (o oauthBase) getCallbackURL(host, configCallbackURL string, hasTLS bool) (string, error) { 39 // I don't think this is ever true... but not super sure 40 if configCallbackURL != "" { 41 return configCallbackURL, nil 42 } 43 44 schema := "http" 45 if hasTLS { 46 schema = "https" 47 } 48 49 if o.baseURL != "" { 50 u, err := url.Parse(o.baseURL) 51 if err != nil { 52 return "", err 53 } 54 if u.Scheme == "" { 55 u.Scheme = schema 56 } 57 if u.Host == "" { 58 u.Host = host 59 } 60 u.Path = filepath.Join(u.Path, o.callbackRoute) 61 return u.String(), nil 62 } 63 64 if host == "" { 65 return "", errors.New("host is empty") 66 } 67 68 return fmt.Sprintf("%v://%v%v", schema, host, o.callbackRoute), nil 69 } 70 71 func (o oauthBase) buildAuthQuery(r *http.Request, w http.ResponseWriter) (redirectURL string, state string, err error) { 72 callbackURL, err := o.getCallbackURL(r.Host, o.config.RedirectURL, r.URL.Query().Get("tls") == "true") 73 if err != nil { 74 w.WriteHeader(http.StatusBadRequest) 75 return "", "", fmt.Errorf("callbackURL parsing failed: %w", err) 76 } 77 78 authURL := *o.authURL 79 parameters := url.Values{} 80 parameters.Add("client_id", o.config.ClientID) 81 parameters.Add("scope", strings.Join(o.config.Scopes, " ")) 82 parameters.Add("redirect_uri", callbackURL) 83 parameters.Add("response_type", "code") 84 85 // generate state token for CSRF protection 86 if state, err = generateStateToken(16); err != nil { 87 w.WriteHeader(http.StatusInternalServerError) 88 return "", "", fmt.Errorf("problem generating state token: %w", err) 89 } 90 91 parameters.Add("state", state) 92 authURL.RawQuery = parameters.Encode() 93 return authURL.String(), state, nil 94 } 95 96 func (o oauthBase) generateOauthClient(r *http.Request) (*http.Client, error) { 97 code := r.FormValue("code") 98 if code == "" { 99 return nil, errors.New("code not found") 100 } 101 102 callbackURL, err := o.getCallbackURL(r.Host, o.config.RedirectURL, r.URL.Query().Get("tls") == "true") 103 if err != nil { 104 return nil, fmt.Errorf("callbackURL parsing failed: %w", err) 105 } 106 oauthConf := *o.config 107 oauthConf.RedirectURL = callbackURL 108 token, err := oauthConf.Exchange(r.Context(), code) 109 if err != nil { 110 return nil, fmt.Errorf("exchanging auth code for token failed: %w", err) 111 } 112 113 return oauthConf.Client(r.Context(), token), err 114 } 115 116 func hasMoreLinkResults(headers http.Header) (string, bool) { 117 value, exists := headers["Link"] 118 if !exists { 119 return "", false 120 } 121 122 pattern := regexp.MustCompile(`<([^>]+)>; rel="next"`) 123 matches := pattern.FindStringSubmatch(value[0]) 124 if matches == nil { 125 return "", false 126 } 127 128 next := matches[1] 129 130 return next, true 131 }