github.com/kubeshop/testkube@v1.17.23/pkg/oauth/oauth.go (about) 1 package oauth 2 3 import ( 4 "context" 5 "crypto/tls" 6 "fmt" 7 "net/http" 8 "strconv" 9 "time" 10 11 "github.com/skratchdot/open-golang/open" 12 "golang.org/x/oauth2" 13 14 "github.com/kubeshop/testkube/pkg/rand" 15 "github.com/kubeshop/testkube/pkg/ui" 16 ) 17 18 // key is context key 19 type key int 20 21 const ( 22 // localIP is ip uo open local website 23 localIP = "127.0.0.1" 24 // localPort is port to open local website 25 localPort = 13254 26 // authTimeout is time to wait for authentication completed 27 authTimeout = 60 28 // oauthStateStringContextKey is a context key for oauth strategy 29 oauthStateStringContextKey key = 987 30 // callbackPath is a path to callback handler 31 callbackPath = "/oauth/callback" 32 // errorPath is a path to error handler 33 errorPath = "/oauth/error" 34 // redirectDelay is redirect delay 35 redirectDelay = 10 * time.Second 36 // shutdownTimeout is shutdown timeout 37 shutdownTimeout = 5 * time.Second 38 // randomLength is a length of a random string 39 randomLength = 8 40 // successPage is a page to show for success authentication 41 successPage = `<html><body><h2>Success!</h2> 42 <p>You are authenticated, you can now return to the program.</p></body></html>` 43 // errorPage is a page to show for failed authentication 44 errorPage = `<html><body><h2>Error!</h2> 45 <p>Authentication was failed, please check the program logs.</p></body</html>` 46 // AuthorizationPrefix is authorization prefix 47 AuthorizationPrefix = "Bearer" 48 ) 49 50 // NewProvider returns new provider 51 func NewProvider(clientID, clientSecret string, scopes []string) Provider { 52 // add transport for self-signed certificate to context 53 tr := &http.Transport{ 54 TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, 55 Proxy: http.ProxyFromEnvironment, 56 } 57 58 client := &http.Client{Transport: tr} 59 provider := Provider{ 60 clientID: clientID, 61 clientSecret: clientSecret, 62 scopes: scopes, 63 client: client, 64 port: localPort, 65 validators: map[ProviderType]Validator{}, 66 } 67 68 provider.AddValidator(GithubProviderType, NewGithubValidator(client, clientID, clientSecret, scopes)) 69 return provider 70 } 71 72 // Provider contains oauth provider config 73 type Provider struct { 74 clientID string 75 clientSecret string 76 scopes []string 77 client *http.Client 78 port int 79 validators map[ProviderType]Validator 80 } 81 82 // AuthorizedClient is authorized client and token 83 type AuthorizedClient struct { 84 Client *http.Client 85 Token *oauth2.Token 86 } 87 88 func (p Provider) getOAuthConfig(providerType ProviderType) (*oauth2.Config, error) { 89 validator, err := p.GetValidator(providerType) 90 if err != nil { 91 return nil, err 92 } 93 94 redirectURL := fmt.Sprintf("http://%s:%d%s", localIP, localPort, callbackPath) 95 return &oauth2.Config{ 96 ClientID: p.clientID, 97 ClientSecret: p.clientSecret, 98 Endpoint: validator.GetEndpoint(), 99 RedirectURL: redirectURL, 100 Scopes: p.scopes, 101 }, nil 102 } 103 104 // AddValidator adds validator 105 func (p Provider) AddValidator(providerType ProviderType, validator Validator) { 106 p.validators[providerType] = validator 107 } 108 109 // GetValidator returns validator 110 func (p Provider) GetValidator(providerType ProviderType) (Validator, error) { 111 validator, ok := p.validators[providerType] 112 if !ok { 113 return nil, fmt.Errorf("unknown oauth provider %s", providerType) 114 } 115 116 return validator, nil 117 } 118 119 // ValidateToken validates token 120 func (p Provider) ValidateToken(providerType ProviderType, token *oauth2.Token) (*oauth2.Token, error) { 121 config, err := p.getOAuthConfig(providerType) 122 if err != nil { 123 return nil, err 124 } 125 126 tokenSource := config.TokenSource(context.Background(), token) 127 return tokenSource.Token() 128 } 129 130 // ValidateAccessToken validates access token 131 func (p Provider) ValidateAccessToken(providerType ProviderType, accessToken string) error { 132 validator, err := p.GetValidator(providerType) 133 if err != nil { 134 return err 135 } 136 137 return validator.Validate(accessToken) 138 } 139 140 // AuthenticateUser starts the login process 141 func (p Provider) AuthenticateUser(providerType ProviderType) (client *AuthorizedClient, err error) { 142 oauthStateString := rand.String(randomLength) 143 ctx := context.WithValue(context.WithValue(context.Background(), oauth2.HTTPClient, p.client), 144 oauthStateStringContextKey, oauthStateString) 145 146 config, err := p.getOAuthConfig(providerType) 147 if err != nil { 148 return nil, err 149 } 150 151 authURL := config.AuthCodeURL(oauthStateString, oauth2.AccessTypeOffline) 152 153 clientChan := make(chan *AuthorizedClient) 154 shutdownChan := make(chan struct{}) 155 cancelChan := make(chan struct{}) 156 157 p.startHTTPServer(ctx, clientChan, shutdownChan, providerType) 158 159 ui.Info("You will be redirected to your browser for authentication or you can open the url below manually") 160 ui.Info(authURL) 161 162 time.Sleep(redirectDelay) 163 164 if err = open.Run(authURL); err != nil { 165 return nil, err 166 } 167 168 // shutdown the server after timeout 169 go func() { 170 ui.Info(fmt.Sprintf("Authentication will be cancelled in %d seconds", authTimeout)) 171 time.Sleep(authTimeout * time.Second) 172 173 cancelChan <- struct{}{} 174 }() 175 176 // wait for an authenticated client or cancel authentication 177 select { 178 case client = <-clientChan: 179 case <-cancelChan: 180 err = fmt.Errorf("authentication timed out and was cancelled") 181 } 182 183 shutdownChan <- struct{}{} 184 return client, err 185 } 186 187 // startHTTPServer starts http server 188 func (p Provider) startHTTPServer(ctx context.Context, clientChan chan *AuthorizedClient, 189 shutdownChan chan struct{}, providerType ProviderType) { 190 http.HandleFunc(callbackPath, p.CallbackHandler(ctx, clientChan, providerType)) 191 http.HandleFunc(errorPath, p.ErrorHandler()) 192 srv := &http.Server{Addr: ":" + strconv.Itoa(p.port)} 193 194 // handle server shutdown signal 195 go func() { 196 <-shutdownChan 197 198 ui.Info("Shutting down server...") 199 200 ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(shutdownTimeout)) 201 defer cancel() 202 203 if err := srv.Shutdown(ctx); err != nil { 204 ui.Errf("stopping http server: %v", err) 205 } 206 }() 207 208 // handle callback request 209 go func() { 210 if err := srv.ListenAndServe(); err != http.ErrServerClosed { 211 ui.ExitOnError("starting http server", err) 212 } 213 214 ui.Success("Server gracefully stopped") 215 }() 216 } 217 218 // CallbackHandler is oauth callback handler 219 func (p Provider) CallbackHandler(ctx context.Context, clientChan chan *AuthorizedClient, 220 providerType ProviderType) func(w http.ResponseWriter, r *http.Request) { 221 return func(w http.ResponseWriter, r *http.Request) { 222 requestState, ok := ctx.Value(oauthStateStringContextKey).(string) 223 if !ok { 224 ui.Errf("unknown oauth state: %v", ctx.Value(oauthStateStringContextKey)) 225 http.Redirect(w, r, errorPath, http.StatusTemporaryRedirect) 226 return 227 } 228 229 responseState := r.FormValue("state") 230 if responseState != requestState { 231 ui.Errf("invalid oauth state, expected %s, got %s", requestState, responseState) 232 http.Redirect(w, r, errorPath, http.StatusTemporaryRedirect) 233 return 234 } 235 236 config, err := p.getOAuthConfig(providerType) 237 if err != nil { 238 ui.Errf("getting oauth config: %v", err) 239 http.Redirect(w, r, errorPath, http.StatusTemporaryRedirect) 240 return 241 } 242 243 code := r.FormValue("code") 244 token, err := config.Exchange(ctx, code) 245 if err != nil { 246 ui.Errf("exchanging oauth code: %v", err) 247 http.Redirect(w, r, errorPath, http.StatusTemporaryRedirect) 248 return 249 } 250 251 if _, err = fmt.Fprint(w, successPage); err != nil { 252 ui.Errf("showing success page: %v", err) 253 http.Redirect(w, r, errorPath, http.StatusTemporaryRedirect) 254 return 255 } 256 257 clientChan <- &AuthorizedClient{ 258 Client: config.Client(ctx, token), 259 Token: token, 260 } 261 } 262 } 263 264 // ErrorHandler is oauth error handler 265 func (p Provider) ErrorHandler() func(w http.ResponseWriter, r *http.Request) { 266 return func(w http.ResponseWriter, r *http.Request) { 267 if _, err := fmt.Fprint(w, errorPage); err != nil { 268 ui.Errf("showing success page: %v", err) 269 return 270 } 271 } 272 }