github.com/pyroscope-io/pyroscope@v0.37.3-0.20230725203016-5f6947968bd0/pkg/server/oauth_base.go (about)

     1  package server
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"net/http"
     7  	"net/url"
     8  	"path/filepath"
     9  	"regexp"
    10  	"strings"
    11  
    12  	"github.com/sirupsen/logrus"
    13  	"golang.org/x/oauth2"
    14  )
    15  
// errForbidden is the sentinel error for denied access ("access forbidden").
// Callers should compare against it with errors.Is.
var errForbidden = errors.New("access forbidden")
    17  
// extUserInfo holds the identity details of a user as reported by an
// external OAuth provider; it is the result type of oauthHandler.userAuth.
type extUserInfo struct {
	// Name is the user's name as returned by the provider.
	Name  string
	// Email is the user's e-mail address as returned by the provider.
	Email string
}
    22  
// oauthHandler is the interface an OAuth provider integration implements.
type oauthHandler interface {
	// userAuth resolves the authenticated user's identity using client.
	// NOTE(review): presumably client carries the provider access token
	// (e.g. one built by oauthBase.generateOauthClient) — confirm against
	// the implementations.
	userAuth(client *http.Client) (extUserInfo, error)
	// getOauthBase returns the provider's shared OAuth configuration.
	getOauthBase() oauthBase
}
    27  
// oauthBase holds the OAuth configuration and helpers shared by provider
// handlers: building the authorization redirect, computing the callback URL
// and exchanging the authorization code for an authenticated client.
type oauthBase struct {
	// config is the OAuth2 client configuration. When its RedirectURL is
	// non-empty it is used as the callback URL instead of deriving one from
	// the request.
	config        *oauth2.Config
	// authURL is the provider authorization endpoint; buildAuthQuery adds
	// client_id, scope, redirect_uri, response_type and state to a copy of it.
	authURL       *url.URL
	// apiURL is not used in this file. NOTE(review): presumably the base URL
	// of the provider API queried for user info — confirm.
	apiURL        string
	// log is the logger for this handler (not used in this file).
	log           *logrus.Logger
	// callbackRoute is the server path the provider redirects back to. It is
	// appended to baseURL's path, or to the request host, to form the
	// callback URL.
	callbackRoute string
	// redirectRoute is not used in this file. NOTE(review): presumably the
	// route the user is sent to after login — confirm.
	redirectRoute string
	// baseURL, when non-empty, is the base the callback URL is built from.
	// A missing scheme or host in it is filled in from the request.
	baseURL       string
}
    37  
    38  func (o oauthBase) getCallbackURL(host, configCallbackURL string, hasTLS bool) (string, error) {
    39  	// I don't think this is ever true... but not super sure
    40  	if configCallbackURL != "" {
    41  		return configCallbackURL, nil
    42  	}
    43  
    44  	schema := "http"
    45  	if hasTLS {
    46  		schema = "https"
    47  	}
    48  
    49  	if o.baseURL != "" {
    50  		u, err := url.Parse(o.baseURL)
    51  		if err != nil {
    52  			return "", err
    53  		}
    54  		if u.Scheme == "" {
    55  			u.Scheme = schema
    56  		}
    57  		if u.Host == "" {
    58  			u.Host = host
    59  		}
    60  		u.Path = filepath.Join(u.Path, o.callbackRoute)
    61  		return u.String(), nil
    62  	}
    63  
    64  	if host == "" {
    65  		return "", errors.New("host is empty")
    66  	}
    67  
    68  	return fmt.Sprintf("%v://%v%v", schema, host, o.callbackRoute), nil
    69  }
    70  
    71  func (o oauthBase) buildAuthQuery(r *http.Request, w http.ResponseWriter) (redirectURL string, state string, err error) {
    72  	callbackURL, err := o.getCallbackURL(r.Host, o.config.RedirectURL, r.URL.Query().Get("tls") == "true")
    73  	if err != nil {
    74  		w.WriteHeader(http.StatusBadRequest)
    75  		return "", "", fmt.Errorf("callbackURL parsing failed: %w", err)
    76  	}
    77  
    78  	authURL := *o.authURL
    79  	parameters := url.Values{}
    80  	parameters.Add("client_id", o.config.ClientID)
    81  	parameters.Add("scope", strings.Join(o.config.Scopes, " "))
    82  	parameters.Add("redirect_uri", callbackURL)
    83  	parameters.Add("response_type", "code")
    84  
    85  	// generate state token for CSRF protection
    86  	if state, err = generateStateToken(16); err != nil {
    87  		w.WriteHeader(http.StatusInternalServerError)
    88  		return "", "", fmt.Errorf("problem generating state token: %w", err)
    89  	}
    90  
    91  	parameters.Add("state", state)
    92  	authURL.RawQuery = parameters.Encode()
    93  	return authURL.String(), state, nil
    94  }
    95  
    96  func (o oauthBase) generateOauthClient(r *http.Request) (*http.Client, error) {
    97  	code := r.FormValue("code")
    98  	if code == "" {
    99  		return nil, errors.New("code not found")
   100  	}
   101  
   102  	callbackURL, err := o.getCallbackURL(r.Host, o.config.RedirectURL, r.URL.Query().Get("tls") == "true")
   103  	if err != nil {
   104  		return nil, fmt.Errorf("callbackURL parsing failed: %w", err)
   105  	}
   106  	oauthConf := *o.config
   107  	oauthConf.RedirectURL = callbackURL
   108  	token, err := oauthConf.Exchange(r.Context(), code)
   109  	if err != nil {
   110  		return nil, fmt.Errorf("exchanging auth code for token failed: %w", err)
   111  	}
   112  
   113  	return oauthConf.Client(r.Context(), token), err
   114  }
   115  
   116  func hasMoreLinkResults(headers http.Header) (string, bool) {
   117  	value, exists := headers["Link"]
   118  	if !exists {
   119  		return "", false
   120  	}
   121  
   122  	pattern := regexp.MustCompile(`<([^>]+)>; rel="next"`)
   123  	matches := pattern.FindStringSubmatch(value[0])
   124  	if matches == nil {
   125  		return "", false
   126  	}
   127  
   128  	next := matches[1]
   129  
   130  	return next, true
   131  }