github.com/NBISweden/sda-cli@v0.1.2-0.20240506070033-4c8af88918df/login/login.go (about) 1 package login 2 3 import ( 4 "crypto/rand" 5 "crypto/sha256" 6 "encoding/base64" 7 "encoding/hex" 8 "encoding/json" 9 "errors" 10 "flag" 11 "fmt" 12 "io" 13 "net/http" 14 "os" 15 "os/exec" 16 "runtime" 17 "strings" 18 "time" 19 20 "gopkg.in/ini.v1" 21 ) 22 23 // Help text and command line flags. 24 25 // Usage text that will be displayed as command line help text when using the 26 // `help login` command 27 var Usage = ` 28 29 USAGE: %s login <login-target> 30 31 login: 32 logs in to the SDA using the provided login target. 33 ` 34 35 // ArgHelp is the suffix text that will be displayed after the argument list in 36 // the module help 37 var ArgHelp = ` 38 [login-target] 39 The login target is the base URL of the service.` 40 41 // Args is a flagset that needs to be exported so that it can be written to the 42 // main program help 43 var Args = flag.NewFlagSet("login", flag.ExitOnError) 44 45 type S3Config struct { 46 AccessKey string `ini:"access_key"` 47 SecretKey string `ini:"secret_key"` 48 AccessToken string `ini:"access_token"` 49 HostBucket string `ini:"host_bucket"` 50 HostBase string `ini:"host_base"` 51 MultipartChunkSizeMb int64 `ini:"multipart_chunk_size_mb"` 52 GuessMimeType bool `ini:"guess_mime_type"` 53 Encoding string `ini:"encoding"` 54 CheckSslCertificate bool `ini:"check_ssl_certificate"` 55 CheckSslHostname bool `ini:"check_ssl_hostname"` 56 UseHTTPS bool `ini:"use_https"` 57 SocketTimeout int `ini:"socket_timeout"` 58 HumanReadableSizes bool `ini:"human_readable_sizes"` 59 PublicKey string `ini:"public_key"` 60 } 61 62 type OIDCWellKnown struct { 63 TokenEndpoint string `json:"token_endpoint"` 64 DeviceAuthorizationEndpoint string `json:"device_authorization_endpoint"` 65 } 66 67 type DeviceLoginResponse struct { 68 VerificationURL string `json:"verification_uri_complete"` 69 DeviceCode string `json:"device_code"` 70 ExpiresIn int `json:"expires_in"` 71 } 72 73 type Result struct { 74 AccessToken string `json:"access_token"` 75 IDToken string `json:"id_token"` 76 Scope string `json:"scope"` 77 TokenType string `json:"token_type"` 78 ExpiresIn int `json:"expires_in"` 79 Error string `json:"error"` 80 ErrorDescription string `json:"error_description"` 81 } 82 83 type UserInfo struct { 84 Sub string `json:"sub"` 85 Name string `json:"name"` 86 PreferredUsername string `json:"preferred_username"` 87 GivenName string `json:"given_name"` 88 FamilyName string `json:"family_name"` 89 Email string `json:"email"` 90 EmailVerified bool `json:"email_verified"` 91 Ga4ghPassportV1 []string `json:"ga4gh_passport_v1"` 92 } 93 94 type DeviceLogin struct { 95 BaseURL string 96 ClientID string 97 S3Target string 98 PublicKey string 99 PollingInterval int 100 LoginResult *Result 101 UserInfo *UserInfo 102 wellKnown *OIDCWellKnown 103 deviceLogin *DeviceLoginResponse 104 CodeVerifier string 105 } 106 107 type AuthInfo struct { 108 ClientID string `json:"client_id"` 109 OidcURI string `json:"oidc_uri"` 110 PublicKey string `json:"public_key"` 111 InboxURI string `json:"inbox_uri"` 112 } 113 114 // requests the /info endpoint to fetch the parameters needed for login 115 func GetAuthInfo(baseURL string) (*AuthInfo, error) { 116 url := strings.TrimSuffix(baseURL, "/") + "/info" 117 resp, err := http.Get(url) 118 if err != nil { 119 return nil, err 120 } 121 defer resp.Body.Close() 122 var result AuthInfo 123 body, err := io.ReadAll(resp.Body) 124 if err != nil { 125 return nil, err 126 } 127 err = json.Unmarshal(body, &result) 128 if err != nil { 129 return nil, err 130 } 131 132 return &result, nil 133 } 134 135 // creates a .sda-cli-session file and updates its values 136 func (login *DeviceLogin) UpdateConfigFile() error { 137 138 out, err := os.Create(".sda-cli-session") 139 if err != nil { 140 return err 141 } 142 143 cfg, err := ini.Load(".sda-cli-session") 144 if err != nil { 145 return err 146 } 147 148 s3Config, err := login.GetS3Config() 149 if err != nil { 150 return err 151 } 152 153 err = ini.ReflectFrom(cfg, s3Config) 154 if err != nil { 155 return err 156 } 157 err = cfg.SaveTo(".sda-cli-session") 158 if err != nil { 159 return err 160 } 161 defer out.Close() 162 163 return nil 164 } 165 166 func NewLogin(args []string) error { 167 deviceLogin, err := NewDeviceLogin(args) 168 if err != nil { 169 return fmt.Errorf("failed to contact authentication service: %v", err) 170 } 171 err = deviceLogin.Login() 172 if err != nil { 173 return err 174 } 175 fmt.Printf("Logged in as %v\n", deviceLogin.UserInfo.Name) 176 177 return err 178 } 179 180 // NewDeviceLogin() returns a new `DeviceLogin` with the given `url` and 181 // `clientID` set. 182 func NewDeviceLogin(args []string) (DeviceLogin, error) { 183 184 var loginURL string 185 err := Args.Parse(args[1:]) 186 if err != nil { 187 return DeviceLogin{}, fmt.Errorf("failed parsing arguments: %v", err) 188 } 189 if len(Args.Args()) == 1 { 190 loginURL = Args.Args()[0] 191 } 192 info, err := GetAuthInfo(loginURL) 193 if err != nil { 194 return DeviceLogin{}, fmt.Errorf("failed to get auth Info: %v", err) 195 } 196 197 return DeviceLogin{BaseURL: info.OidcURI, ClientID: info.ClientID, PollingInterval: 2, S3Target: info.InboxURI, PublicKey: info.PublicKey}, nil 198 } 199 200 // open opens the specified URL in the default browser of the user. 201 func open(url string) error { 202 var cmd string 203 var args []string 204 205 switch runtime.GOOS { 206 case "windows": 207 cmd = "cmd" 208 args = []string{"/c", "start"} 209 case "darwin": 210 cmd = "open" 211 default: // "linux", "freebsd", "openbsd", "netbsd" 212 cmd = "xdg-open" 213 } 214 args = append(args, url) 215 216 return exec.Command(cmd, args...).Start() 217 } 218 219 // Login() does a full login by fetching the remote configuration, starting the 220 // login procedure, and then waiting for the user to complete login. 221 func (login *DeviceLogin) Login() error { 222 223 var err error 224 login.wellKnown, err = login.getWellKnown() 225 if err != nil { 226 return fmt.Errorf("failed to fetch .well-known configuration: %v", err) 227 } 228 229 login.deviceLogin, err = login.startDeviceLogin() 230 if err != nil { 231 return fmt.Errorf("failed to start device login: %v", err) 232 } 233 expires := time.Duration(login.deviceLogin.ExpiresIn * int(time.Second)) 234 fmt.Printf("Login started (expires in %v minutes)\n", expires.Minutes()) 235 236 err = open(login.deviceLogin.VerificationURL) 237 if err != nil { 238 return fmt.Errorf("failed to open login URL: %v", err) 239 } 240 241 loginResult, err := login.waitForLogin() 242 if err != nil { 243 return err 244 } 245 login.LoginResult = loginResult 246 247 login.UserInfo, err = login.getUserInfo() 248 if err != nil { 249 return err 250 } 251 252 err = login.UpdateConfigFile() 253 if err != nil { 254 return err 255 } 256 257 return err 258 } 259 260 // S3Config() returns a new `S3Config` with the values from the `DeviceLogin` 261 func (login *DeviceLogin) GetS3Config() (*S3Config, error) { 262 if login.LoginResult.AccessToken == "" { 263 264 return nil, errors.New("no login token available for config") 265 } 266 267 return &S3Config{ 268 AccessKey: login.UserInfo.Sub, 269 SecretKey: login.UserInfo.Sub, 270 AccessToken: login.LoginResult.AccessToken, 271 HostBucket: login.S3Target, 272 HostBase: login.S3Target, 273 PublicKey: login.PublicKey, 274 MultipartChunkSizeMb: 512, 275 GuessMimeType: false, 276 Encoding: "UTF-8", 277 CheckSslCertificate: false, 278 CheckSslHostname: false, 279 UseHTTPS: true, 280 SocketTimeout: 30, 281 HumanReadableSizes: true, 282 }, nil 283 } 284 285 func (login *DeviceLogin) getUserInfo() (*UserInfo, error) { 286 287 if login.LoginResult.AccessToken == "" { 288 return nil, errors.New("login token required to fetch userinfo") 289 } 290 291 req, err := http.NewRequest("GET", login.BaseURL+"/userinfo", nil) 292 if err != nil { 293 return nil, err 294 } 295 296 req.Header.Set("Authorization", fmt.Sprintf("Bearer %v", login.LoginResult.AccessToken)) 297 298 resp, err := http.DefaultClient.Do(req) 299 if err != nil { 300 return nil, err 301 } 302 defer resp.Body.Close() 303 body, err := io.ReadAll(resp.Body) 304 if err != nil { 305 return nil, err 306 } 307 308 if resp.StatusCode < 200 || resp.StatusCode >= 400 { 309 err = fmt.Errorf("status code: %v", resp.StatusCode) 310 311 return nil, fmt.Errorf("request failed: %v", err) 312 } 313 314 var userinfo *UserInfo 315 err = json.Unmarshal(body, &userinfo) 316 317 return userinfo, err 318 } 319 320 // getWellKnown() makes a GET request to the `.well-known/openid-configuration` 321 // endpoint of BaseURL and returns the result as `OIDCWellKnown`. 322 func (login *DeviceLogin) getWellKnown() (*OIDCWellKnown, error) { 323 324 wellKnownURL := fmt.Sprintf("%v/.well-known/openid-configuration", login.BaseURL) 325 resp, err := http.Get(wellKnownURL) 326 if err != nil { 327 return nil, err 328 } 329 330 defer resp.Body.Close() 331 body, err := io.ReadAll(resp.Body) 332 if err != nil { 333 return nil, err 334 } 335 336 var wellKnownConfig *OIDCWellKnown 337 err = json.Unmarshal(body, &wellKnownConfig) 338 339 return wellKnownConfig, err 340 } 341 342 // startDeviceLogin() starts a device login towards the URLs in login.wellKnown 343 // and sets the login.deviceLogin 344 func (login *DeviceLogin) startDeviceLogin() (*DeviceLoginResponse, error) { 345 346 var ( 347 err error 348 codeChallenge string 349 ) 350 login.CodeVerifier, codeChallenge, err = generatePKCE(128) 351 if err != nil { 352 return nil, fmt.Errorf("could not create pkce: %v", err) 353 } 354 355 loginBody := fmt.Sprintf("response_type=device_code&client_id=%v"+ 356 "&scope=openid ga4gh_passport_v1 profile email&code_challenge_method=S256&code_challenge=%v", login.ClientID, codeChallenge) 357 358 req, err := http.NewRequest("POST", 359 login.wellKnown.DeviceAuthorizationEndpoint, strings.NewReader(loginBody)) 360 if err != nil { 361 return nil, err 362 } 363 364 req.Header.Set("Content-Type", "application/x-www-form-urlencoded") 365 366 resp, err := http.DefaultClient.Do(req) 367 if err != nil { 368 return nil, err 369 } 370 371 defer resp.Body.Close() 372 body, err := io.ReadAll(resp.Body) 373 if err != nil { 374 return nil, err 375 } 376 377 if resp.StatusCode < 200 || resp.StatusCode >= 400 { 378 err = fmt.Errorf("status code: %v", resp.StatusCode) 379 380 return nil, fmt.Errorf("request failed: %v", err) 381 } 382 383 var loginResponse *DeviceLoginResponse 384 err = json.Unmarshal(body, &loginResponse) 385 386 return loginResponse, err 387 } 388 389 // waitForLogin() waits for the remote OIDC server to verify the completed login 390 // by polling 391 func (login *DeviceLogin) waitForLogin() (*Result, error) { 392 393 body := fmt.Sprintf("grant_type=urn:ietf:params:oauth:grant-type:device_code"+ 394 "&client_id=%v&device_code=%v&code_verifier=%v", login.ClientID, login.deviceLogin.DeviceCode, login.CodeVerifier) 395 396 expirationTime := time.Now().Unix() + int64(login.deviceLogin.ExpiresIn) 397 398 for { 399 time.Sleep(time.Duration(login.PollingInterval) * time.Second) 400 401 req, err := http.NewRequest("POST", login.wellKnown.TokenEndpoint, 402 strings.NewReader(body)) 403 if err != nil { 404 return nil, err 405 } 406 req.Header.Set("Content-Type", "application/x-www-form-urlencoded") 407 408 resp, err := http.DefaultClient.Do(req) 409 if err != nil { 410 return nil, fmt.Errorf("failure to fetch login token: %v", err) 411 } 412 413 if resp.StatusCode == 200 { 414 defer resp.Body.Close() 415 respBody, err := io.ReadAll(resp.Body) 416 if err != nil { 417 return nil, err 418 } 419 420 var loginResult *Result 421 err = json.Unmarshal(respBody, &loginResult) 422 if err != nil { 423 return nil, err 424 } 425 426 return loginResult, nil 427 } 428 429 if expirationTime <= time.Now().Unix() { 430 431 break 432 } 433 } 434 435 return nil, errors.New("login timed out") 436 } 437 438 func generatePKCE(count int) (string, string, error) { 439 440 // generate code verifier 441 buf := make([]byte, count) 442 _, err := io.ReadFull(rand.Reader, buf) 443 if err != nil { 444 return "", "", err 445 } 446 verifier := hex.EncodeToString(buf) 447 448 // generate code challenge 449 sha2 := sha256.New() 450 _, err = io.WriteString(sha2, verifier) 451 if err != nil { 452 return "", "", err 453 } 454 challenge := base64.RawURLEncoding.EncodeToString(sha2.Sum(nil)) 455 456 return verifier, challenge, nil 457 }