github.com/kubeshop/testkube@v1.17.23/pkg/oauth/oauth.go (about)

     1  package oauth
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"fmt"
     7  	"net/http"
     8  	"strconv"
     9  	"time"
    10  
    11  	"github.com/skratchdot/open-golang/open"
    12  	"golang.org/x/oauth2"
    13  
    14  	"github.com/kubeshop/testkube/pkg/rand"
    15  	"github.com/kubeshop/testkube/pkg/ui"
    16  )
    17  
    18  // key is context key
    19  type key int
    20  
    21  const (
    22  	// localIP is ip uo open local website
    23  	localIP = "127.0.0.1"
    24  	// localPort is port to open local website
    25  	localPort = 13254
    26  	// authTimeout is time to wait for authentication completed
    27  	authTimeout = 60
    28  	// oauthStateStringContextKey is a context key for oauth strategy
    29  	oauthStateStringContextKey key = 987
    30  	// callbackPath is a path to callback handler
    31  	callbackPath = "/oauth/callback"
    32  	// errorPath is a path to error handler
    33  	errorPath = "/oauth/error"
    34  	// redirectDelay is redirect delay
    35  	redirectDelay = 10 * time.Second
    36  	// shutdownTimeout is shutdown timeout
    37  	shutdownTimeout = 5 * time.Second
    38  	// randomLength is a length of a random string
    39  	randomLength = 8
    40  	// successPage is a page to show for success authentication
    41  	successPage = `<html><body><h2>Success!</h2>
    42  		<p>You are authenticated, you can now return to the program.</p></body></html>`
    43  	// errorPage is a page to show for failed authentication
    44  	errorPage = `<html><body><h2>Error!</h2>
    45  		<p>Authentication was failed, please check the program logs.</p></body</html>`
    46  	// AuthorizationPrefix is authorization prefix
    47  	AuthorizationPrefix = "Bearer"
    48  )
    49  
    50  // NewProvider returns new provider
    51  func NewProvider(clientID, clientSecret string, scopes []string) Provider {
    52  	// add transport for self-signed certificate to context
    53  	tr := &http.Transport{
    54  		TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
    55  		Proxy:           http.ProxyFromEnvironment,
    56  	}
    57  
    58  	client := &http.Client{Transport: tr}
    59  	provider := Provider{
    60  		clientID:     clientID,
    61  		clientSecret: clientSecret,
    62  		scopes:       scopes,
    63  		client:       client,
    64  		port:         localPort,
    65  		validators:   map[ProviderType]Validator{},
    66  	}
    67  
    68  	provider.AddValidator(GithubProviderType, NewGithubValidator(client, clientID, clientSecret, scopes))
    69  	return provider
    70  }
    71  
    72  // Provider contains oauth provider config
    73  type Provider struct {
    74  	clientID     string
    75  	clientSecret string
    76  	scopes       []string
    77  	client       *http.Client
    78  	port         int
    79  	validators   map[ProviderType]Validator
    80  }
    81  
    82  // AuthorizedClient is authorized client and token
    83  type AuthorizedClient struct {
    84  	Client *http.Client
    85  	Token  *oauth2.Token
    86  }
    87  
    88  func (p Provider) getOAuthConfig(providerType ProviderType) (*oauth2.Config, error) {
    89  	validator, err := p.GetValidator(providerType)
    90  	if err != nil {
    91  		return nil, err
    92  	}
    93  
    94  	redirectURL := fmt.Sprintf("http://%s:%d%s", localIP, localPort, callbackPath)
    95  	return &oauth2.Config{
    96  		ClientID:     p.clientID,
    97  		ClientSecret: p.clientSecret,
    98  		Endpoint:     validator.GetEndpoint(),
    99  		RedirectURL:  redirectURL,
   100  		Scopes:       p.scopes,
   101  	}, nil
   102  }
   103  
   104  // AddValidator adds validator
   105  func (p Provider) AddValidator(providerType ProviderType, validator Validator) {
   106  	p.validators[providerType] = validator
   107  }
   108  
   109  // GetValidator returns validator
   110  func (p Provider) GetValidator(providerType ProviderType) (Validator, error) {
   111  	validator, ok := p.validators[providerType]
   112  	if !ok {
   113  		return nil, fmt.Errorf("unknown oauth provider %s", providerType)
   114  	}
   115  
   116  	return validator, nil
   117  }
   118  
   119  // ValidateToken validates token
   120  func (p Provider) ValidateToken(providerType ProviderType, token *oauth2.Token) (*oauth2.Token, error) {
   121  	config, err := p.getOAuthConfig(providerType)
   122  	if err != nil {
   123  		return nil, err
   124  	}
   125  
   126  	tokenSource := config.TokenSource(context.Background(), token)
   127  	return tokenSource.Token()
   128  }
   129  
   130  // ValidateAccessToken validates access token
   131  func (p Provider) ValidateAccessToken(providerType ProviderType, accessToken string) error {
   132  	validator, err := p.GetValidator(providerType)
   133  	if err != nil {
   134  		return err
   135  	}
   136  
   137  	return validator.Validate(accessToken)
   138  }
   139  
   140  // AuthenticateUser starts the login process
   141  func (p Provider) AuthenticateUser(providerType ProviderType) (client *AuthorizedClient, err error) {
   142  	oauthStateString := rand.String(randomLength)
   143  	ctx := context.WithValue(context.WithValue(context.Background(), oauth2.HTTPClient, p.client),
   144  		oauthStateStringContextKey, oauthStateString)
   145  
   146  	config, err := p.getOAuthConfig(providerType)
   147  	if err != nil {
   148  		return nil, err
   149  	}
   150  
   151  	authURL := config.AuthCodeURL(oauthStateString, oauth2.AccessTypeOffline)
   152  
   153  	clientChan := make(chan *AuthorizedClient)
   154  	shutdownChan := make(chan struct{})
   155  	cancelChan := make(chan struct{})
   156  
   157  	p.startHTTPServer(ctx, clientChan, shutdownChan, providerType)
   158  
   159  	ui.Info("You will be redirected to your browser for authentication or you can open the url below manually")
   160  	ui.Info(authURL)
   161  
   162  	time.Sleep(redirectDelay)
   163  
   164  	if err = open.Run(authURL); err != nil {
   165  		return nil, err
   166  	}
   167  
   168  	// shutdown the server after timeout
   169  	go func() {
   170  		ui.Info(fmt.Sprintf("Authentication will be cancelled in %d seconds", authTimeout))
   171  		time.Sleep(authTimeout * time.Second)
   172  
   173  		cancelChan <- struct{}{}
   174  	}()
   175  
   176  	// wait for an authenticated client or cancel authentication
   177  	select {
   178  	case client = <-clientChan:
   179  	case <-cancelChan:
   180  		err = fmt.Errorf("authentication timed out and was cancelled")
   181  	}
   182  
   183  	shutdownChan <- struct{}{}
   184  	return client, err
   185  }
   186  
   187  // startHTTPServer starts http server
   188  func (p Provider) startHTTPServer(ctx context.Context, clientChan chan *AuthorizedClient,
   189  	shutdownChan chan struct{}, providerType ProviderType) {
   190  	http.HandleFunc(callbackPath, p.CallbackHandler(ctx, clientChan, providerType))
   191  	http.HandleFunc(errorPath, p.ErrorHandler())
   192  	srv := &http.Server{Addr: ":" + strconv.Itoa(p.port)}
   193  
   194  	// handle server shutdown signal
   195  	go func() {
   196  		<-shutdownChan
   197  
   198  		ui.Info("Shutting down server...")
   199  
   200  		ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(shutdownTimeout))
   201  		defer cancel()
   202  
   203  		if err := srv.Shutdown(ctx); err != nil {
   204  			ui.Errf("stopping http server: %v", err)
   205  		}
   206  	}()
   207  
   208  	// handle callback request
   209  	go func() {
   210  		if err := srv.ListenAndServe(); err != http.ErrServerClosed {
   211  			ui.ExitOnError("starting http server", err)
   212  		}
   213  
   214  		ui.Success("Server gracefully stopped")
   215  	}()
   216  }
   217  
   218  // CallbackHandler is oauth callback handler
   219  func (p Provider) CallbackHandler(ctx context.Context, clientChan chan *AuthorizedClient,
   220  	providerType ProviderType) func(w http.ResponseWriter, r *http.Request) {
   221  	return func(w http.ResponseWriter, r *http.Request) {
   222  		requestState, ok := ctx.Value(oauthStateStringContextKey).(string)
   223  		if !ok {
   224  			ui.Errf("unknown oauth state: %v", ctx.Value(oauthStateStringContextKey))
   225  			http.Redirect(w, r, errorPath, http.StatusTemporaryRedirect)
   226  			return
   227  		}
   228  
   229  		responseState := r.FormValue("state")
   230  		if responseState != requestState {
   231  			ui.Errf("invalid oauth state, expected %s, got %s", requestState, responseState)
   232  			http.Redirect(w, r, errorPath, http.StatusTemporaryRedirect)
   233  			return
   234  		}
   235  
   236  		config, err := p.getOAuthConfig(providerType)
   237  		if err != nil {
   238  			ui.Errf("getting oauth config: %v", err)
   239  			http.Redirect(w, r, errorPath, http.StatusTemporaryRedirect)
   240  			return
   241  		}
   242  
   243  		code := r.FormValue("code")
   244  		token, err := config.Exchange(ctx, code)
   245  		if err != nil {
   246  			ui.Errf("exchanging oauth code: %v", err)
   247  			http.Redirect(w, r, errorPath, http.StatusTemporaryRedirect)
   248  			return
   249  		}
   250  
   251  		if _, err = fmt.Fprint(w, successPage); err != nil {
   252  			ui.Errf("showing success page: %v", err)
   253  			http.Redirect(w, r, errorPath, http.StatusTemporaryRedirect)
   254  			return
   255  		}
   256  
   257  		clientChan <- &AuthorizedClient{
   258  			Client: config.Client(ctx, token),
   259  			Token:  token,
   260  		}
   261  	}
   262  }
   263  
   264  // ErrorHandler is oauth error handler
   265  func (p Provider) ErrorHandler() func(w http.ResponseWriter, r *http.Request) {
   266  	return func(w http.ResponseWriter, r *http.Request) {
   267  		if _, err := fmt.Fprint(w, errorPage); err != nil {
   268  			ui.Errf("showing success page: %v", err)
   269  			return
   270  		}
   271  	}
   272  }