github.com/NBISweden/sda-cli@v0.1.2-0.20240506070033-4c8af88918df/login/login.go (about)

     1  package login
     2  
     3  import (
     4  	"crypto/rand"
     5  	"crypto/sha256"
     6  	"encoding/base64"
     7  	"encoding/hex"
     8  	"encoding/json"
     9  	"errors"
    10  	"flag"
    11  	"fmt"
    12  	"io"
    13  	"net/http"
    14  	"os"
    15  	"os/exec"
    16  	"runtime"
    17  	"strings"
    18  	"time"
    19  
    20  	"gopkg.in/ini.v1"
    21  )
    22  
    23  // Help text and command line flags.
    24  
    25  // Usage text that will be displayed as command line help text when using the
    26  // `help login` command
    27  var Usage = `
    28  
    29  USAGE: %s login <login-target>
    30  
    31  login:
    32      logs in to the SDA using the provided login target.
    33  `
    34  
    35  // ArgHelp is the suffix text that will be displayed after the argument list in
    36  // the module help
    37  var ArgHelp = `
    38      [login-target]
    39          The login target is the base URL of the service.`
    40  
    41  // Args is a flagset that needs to be exported so that it can be written to the
    42  // main program help
    43  var Args = flag.NewFlagSet("login", flag.ExitOnError)
    44  
// S3Config holds the settings that UpdateConfigFile writes to the
// .sda-cli-session file via ini.ReflectFrom; each `ini` tag names the key used
// in that file. GetS3Config fills it from a completed login: AccessKey and
// SecretKey both receive the user's OIDC subject, AccessToken the OIDC access
// token, HostBase and HostBucket the S3 target, and the rest fixed defaults.
// NOTE(review): the key names look like an s3cmd-style configuration — confirm
// which tools consume this file.
type S3Config struct {
	AccessKey            string `ini:"access_key"`
	SecretKey            string `ini:"secret_key"`
	AccessToken          string `ini:"access_token"`
	HostBucket           string `ini:"host_bucket"`
	HostBase             string `ini:"host_base"`
	MultipartChunkSizeMb int64  `ini:"multipart_chunk_size_mb"`
	GuessMimeType        bool   `ini:"guess_mime_type"`
	Encoding             string `ini:"encoding"`
	CheckSslCertificate  bool   `ini:"check_ssl_certificate"`
	CheckSslHostname     bool   `ini:"check_ssl_hostname"`
	UseHTTPS             bool   `ini:"use_https"`
	SocketTimeout        int    `ini:"socket_timeout"`
	HumanReadableSizes   bool   `ini:"human_readable_sizes"`
	PublicKey            string `ini:"public_key"`
}
    61  
// OIDCWellKnown holds the subset of the OIDC discovery document
// (`.well-known/openid-configuration`) used by the device login:
// DeviceAuthorizationEndpoint is where startDeviceLogin starts the flow and
// TokenEndpoint is what waitForLogin polls for the login result.
type OIDCWellKnown struct {
	TokenEndpoint               string `json:"token_endpoint"`
	DeviceAuthorizationEndpoint string `json:"device_authorization_endpoint"`
}
    66  
// DeviceLoginResponse holds the fields of the device authorization response
// used by the login: VerificationURL (decoded from
// `verification_uri_complete`) is opened for the user, DeviceCode is sent when
// polling the token endpoint, and ExpiresIn is the lifetime in seconds that
// bounds how long the login keeps polling.
type DeviceLoginResponse struct {
	VerificationURL string `json:"verification_uri_complete"`
	DeviceCode      string `json:"device_code"`
	ExpiresIn       int    `json:"expires_in"`
}
    72  
// Result is the decoded JSON response of the token endpoint. After a
// successful login AccessToken authenticates the userinfo request and is
// written to the session file; Error and ErrorDescription map the `error` and
// `error_description` members of an error response.
type Result struct {
	AccessToken      string `json:"access_token"`
	IDToken          string `json:"id_token"`
	Scope            string `json:"scope"`
	TokenType        string `json:"token_type"`
	ExpiresIn        int    `json:"expires_in"`
	Error            string `json:"error"`
	ErrorDescription string `json:"error_description"`
}
    82  
// UserInfo holds the claims returned by the OIDC userinfo endpoint. Sub is
// used as both S3 access and secret key in GetS3Config, and Name is printed
// after a successful login.
type UserInfo struct {
	Sub               string   `json:"sub"`
	Name              string   `json:"name"`
	PreferredUsername string   `json:"preferred_username"`
	GivenName         string   `json:"given_name"`
	FamilyName        string   `json:"family_name"`
	Email             string   `json:"email"`
	EmailVerified     bool     `json:"email_verified"`
	Ga4ghPassportV1   []string `json:"ga4gh_passport_v1"`
}
    93  
// DeviceLogin carries the state of one device login.
//
// BaseURL (the OIDC URI), ClientID, S3Target (the inbox URI) and PublicKey are
// taken from the service's /info response by NewDeviceLogin. PollingInterval
// is the delay in seconds between token requests. During Login the unexported
// wellKnown and deviceLogin fields, the PKCE CodeVerifier, and finally
// LoginResult and UserInfo are filled in.
type DeviceLogin struct {
	BaseURL         string
	ClientID        string
	S3Target        string
	PublicKey       string
	PollingInterval int
	LoginResult     *Result
	UserInfo        *UserInfo
	wellKnown       *OIDCWellKnown
	deviceLogin     *DeviceLoginResponse
	CodeVerifier    string
}
   106  
   107  type AuthInfo struct {
   108  	ClientID  string `json:"client_id"`
   109  	OidcURI   string `json:"oidc_uri"`
   110  	PublicKey string `json:"public_key"`
   111  	InboxURI  string `json:"inbox_uri"`
   112  }
   113  
   114  // requests the /info endpoint to fetch the parameters needed for login
   115  func GetAuthInfo(baseURL string) (*AuthInfo, error) {
   116  	url := strings.TrimSuffix(baseURL, "/") + "/info"
   117  	resp, err := http.Get(url)
   118  	if err != nil {
   119  		return nil, err
   120  	}
   121  	defer resp.Body.Close()
   122  	var result AuthInfo
   123  	body, err := io.ReadAll(resp.Body)
   124  	if err != nil {
   125  		return nil, err
   126  	}
   127  	err = json.Unmarshal(body, &result)
   128  	if err != nil {
   129  		return nil, err
   130  	}
   131  
   132  	return &result, nil
   133  }
   134  
   135  // creates a .sda-cli-session file and updates its values
   136  func (login *DeviceLogin) UpdateConfigFile() error {
   137  
   138  	out, err := os.Create(".sda-cli-session")
   139  	if err != nil {
   140  		return err
   141  	}
   142  
   143  	cfg, err := ini.Load(".sda-cli-session")
   144  	if err != nil {
   145  		return err
   146  	}
   147  
   148  	s3Config, err := login.GetS3Config()
   149  	if err != nil {
   150  		return err
   151  	}
   152  
   153  	err = ini.ReflectFrom(cfg, s3Config)
   154  	if err != nil {
   155  		return err
   156  	}
   157  	err = cfg.SaveTo(".sda-cli-session")
   158  	if err != nil {
   159  		return err
   160  	}
   161  	defer out.Close()
   162  
   163  	return nil
   164  }
   165  
   166  func NewLogin(args []string) error {
   167  	deviceLogin, err := NewDeviceLogin(args)
   168  	if err != nil {
   169  		return fmt.Errorf("failed to contact authentication service: %v", err)
   170  	}
   171  	err = deviceLogin.Login()
   172  	if err != nil {
   173  		return err
   174  	}
   175  	fmt.Printf("Logged in as %v\n", deviceLogin.UserInfo.Name)
   176  
   177  	return err
   178  }
   179  
   180  // NewDeviceLogin() returns a new `DeviceLogin` with the given `url` and
   181  // `clientID` set.
   182  func NewDeviceLogin(args []string) (DeviceLogin, error) {
   183  
   184  	var loginURL string
   185  	err := Args.Parse(args[1:])
   186  	if err != nil {
   187  		return DeviceLogin{}, fmt.Errorf("failed parsing arguments: %v", err)
   188  	}
   189  	if len(Args.Args()) == 1 {
   190  		loginURL = Args.Args()[0]
   191  	}
   192  	info, err := GetAuthInfo(loginURL)
   193  	if err != nil {
   194  		return DeviceLogin{}, fmt.Errorf("failed to get auth Info: %v", err)
   195  	}
   196  
   197  	return DeviceLogin{BaseURL: info.OidcURI, ClientID: info.ClientID, PollingInterval: 2, S3Target: info.InboxURI, PublicKey: info.PublicKey}, nil
   198  }
   199  
   200  // open opens the specified URL in the default browser of the user.
   201  func open(url string) error {
   202  	var cmd string
   203  	var args []string
   204  
   205  	switch runtime.GOOS {
   206  	case "windows":
   207  		cmd = "cmd"
   208  		args = []string{"/c", "start"}
   209  	case "darwin":
   210  		cmd = "open"
   211  	default: // "linux", "freebsd", "openbsd", "netbsd"
   212  		cmd = "xdg-open"
   213  	}
   214  	args = append(args, url)
   215  
   216  	return exec.Command(cmd, args...).Start()
   217  }
   218  
   219  // Login() does a full login by fetching the remote configuration, starting the
   220  // login procedure, and then waiting for the user to complete login.
   221  func (login *DeviceLogin) Login() error {
   222  
   223  	var err error
   224  	login.wellKnown, err = login.getWellKnown()
   225  	if err != nil {
   226  		return fmt.Errorf("failed to fetch .well-known configuration: %v", err)
   227  	}
   228  
   229  	login.deviceLogin, err = login.startDeviceLogin()
   230  	if err != nil {
   231  		return fmt.Errorf("failed to start device login: %v", err)
   232  	}
   233  	expires := time.Duration(login.deviceLogin.ExpiresIn * int(time.Second))
   234  	fmt.Printf("Login started (expires in %v minutes)\n", expires.Minutes())
   235  
   236  	err = open(login.deviceLogin.VerificationURL)
   237  	if err != nil {
   238  		return fmt.Errorf("failed to open login URL: %v", err)
   239  	}
   240  
   241  	loginResult, err := login.waitForLogin()
   242  	if err != nil {
   243  		return err
   244  	}
   245  	login.LoginResult = loginResult
   246  
   247  	login.UserInfo, err = login.getUserInfo()
   248  	if err != nil {
   249  		return err
   250  	}
   251  
   252  	err = login.UpdateConfigFile()
   253  	if err != nil {
   254  		return err
   255  	}
   256  
   257  	return err
   258  }
   259  
   260  // S3Config() returns a new `S3Config` with the values from the `DeviceLogin`
   261  func (login *DeviceLogin) GetS3Config() (*S3Config, error) {
   262  	if login.LoginResult.AccessToken == "" {
   263  
   264  		return nil, errors.New("no login token available for config")
   265  	}
   266  
   267  	return &S3Config{
   268  		AccessKey:            login.UserInfo.Sub,
   269  		SecretKey:            login.UserInfo.Sub,
   270  		AccessToken:          login.LoginResult.AccessToken,
   271  		HostBucket:           login.S3Target,
   272  		HostBase:             login.S3Target,
   273  		PublicKey:            login.PublicKey,
   274  		MultipartChunkSizeMb: 512,
   275  		GuessMimeType:        false,
   276  		Encoding:             "UTF-8",
   277  		CheckSslCertificate:  false,
   278  		CheckSslHostname:     false,
   279  		UseHTTPS:             true,
   280  		SocketTimeout:        30,
   281  		HumanReadableSizes:   true,
   282  	}, nil
   283  }
   284  
   285  func (login *DeviceLogin) getUserInfo() (*UserInfo, error) {
   286  
   287  	if login.LoginResult.AccessToken == "" {
   288  		return nil, errors.New("login token required to fetch userinfo")
   289  	}
   290  
   291  	req, err := http.NewRequest("GET", login.BaseURL+"/userinfo", nil)
   292  	if err != nil {
   293  		return nil, err
   294  	}
   295  
   296  	req.Header.Set("Authorization", fmt.Sprintf("Bearer %v", login.LoginResult.AccessToken))
   297  
   298  	resp, err := http.DefaultClient.Do(req)
   299  	if err != nil {
   300  		return nil, err
   301  	}
   302  	defer resp.Body.Close()
   303  	body, err := io.ReadAll(resp.Body)
   304  	if err != nil {
   305  		return nil, err
   306  	}
   307  
   308  	if resp.StatusCode < 200 || resp.StatusCode >= 400 {
   309  		err = fmt.Errorf("status code: %v", resp.StatusCode)
   310  
   311  		return nil, fmt.Errorf("request failed: %v", err)
   312  	}
   313  
   314  	var userinfo *UserInfo
   315  	err = json.Unmarshal(body, &userinfo)
   316  
   317  	return userinfo, err
   318  }
   319  
   320  // getWellKnown() makes a GET request to the `.well-known/openid-configuration`
   321  // endpoint of BaseURL and returns the result as `OIDCWellKnown`.
   322  func (login *DeviceLogin) getWellKnown() (*OIDCWellKnown, error) {
   323  
   324  	wellKnownURL := fmt.Sprintf("%v/.well-known/openid-configuration", login.BaseURL)
   325  	resp, err := http.Get(wellKnownURL)
   326  	if err != nil {
   327  		return nil, err
   328  	}
   329  
   330  	defer resp.Body.Close()
   331  	body, err := io.ReadAll(resp.Body)
   332  	if err != nil {
   333  		return nil, err
   334  	}
   335  
   336  	var wellKnownConfig *OIDCWellKnown
   337  	err = json.Unmarshal(body, &wellKnownConfig)
   338  
   339  	return wellKnownConfig, err
   340  }
   341  
   342  // startDeviceLogin() starts a device login towards the URLs in login.wellKnown
   343  // and sets the login.deviceLogin
   344  func (login *DeviceLogin) startDeviceLogin() (*DeviceLoginResponse, error) {
   345  
   346  	var (
   347  		err           error
   348  		codeChallenge string
   349  	)
   350  	login.CodeVerifier, codeChallenge, err = generatePKCE(128)
   351  	if err != nil {
   352  		return nil, fmt.Errorf("could not create pkce: %v", err)
   353  	}
   354  
   355  	loginBody := fmt.Sprintf("response_type=device_code&client_id=%v"+
   356  		"&scope=openid ga4gh_passport_v1 profile email&code_challenge_method=S256&code_challenge=%v", login.ClientID, codeChallenge)
   357  
   358  	req, err := http.NewRequest("POST",
   359  		login.wellKnown.DeviceAuthorizationEndpoint, strings.NewReader(loginBody))
   360  	if err != nil {
   361  		return nil, err
   362  	}
   363  
   364  	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
   365  
   366  	resp, err := http.DefaultClient.Do(req)
   367  	if err != nil {
   368  		return nil, err
   369  	}
   370  
   371  	defer resp.Body.Close()
   372  	body, err := io.ReadAll(resp.Body)
   373  	if err != nil {
   374  		return nil, err
   375  	}
   376  
   377  	if resp.StatusCode < 200 || resp.StatusCode >= 400 {
   378  		err = fmt.Errorf("status code: %v", resp.StatusCode)
   379  
   380  		return nil, fmt.Errorf("request failed: %v", err)
   381  	}
   382  
   383  	var loginResponse *DeviceLoginResponse
   384  	err = json.Unmarshal(body, &loginResponse)
   385  
   386  	return loginResponse, err
   387  }
   388  
   389  // waitForLogin() waits for the remote OIDC server to verify the completed login
   390  // by polling
   391  func (login *DeviceLogin) waitForLogin() (*Result, error) {
   392  
   393  	body := fmt.Sprintf("grant_type=urn:ietf:params:oauth:grant-type:device_code"+
   394  		"&client_id=%v&device_code=%v&code_verifier=%v", login.ClientID, login.deviceLogin.DeviceCode, login.CodeVerifier)
   395  
   396  	expirationTime := time.Now().Unix() + int64(login.deviceLogin.ExpiresIn)
   397  
   398  	for {
   399  		time.Sleep(time.Duration(login.PollingInterval) * time.Second)
   400  
   401  		req, err := http.NewRequest("POST", login.wellKnown.TokenEndpoint,
   402  			strings.NewReader(body))
   403  		if err != nil {
   404  			return nil, err
   405  		}
   406  		req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
   407  
   408  		resp, err := http.DefaultClient.Do(req)
   409  		if err != nil {
   410  			return nil, fmt.Errorf("failure to fetch login token: %v", err)
   411  		}
   412  
   413  		if resp.StatusCode == 200 {
   414  			defer resp.Body.Close()
   415  			respBody, err := io.ReadAll(resp.Body)
   416  			if err != nil {
   417  				return nil, err
   418  			}
   419  
   420  			var loginResult *Result
   421  			err = json.Unmarshal(respBody, &loginResult)
   422  			if err != nil {
   423  				return nil, err
   424  			}
   425  
   426  			return loginResult, nil
   427  		}
   428  
   429  		if expirationTime <= time.Now().Unix() {
   430  
   431  			break
   432  		}
   433  	}
   434  
   435  	return nil, errors.New("login timed out")
   436  }
   437  
   438  func generatePKCE(count int) (string, string, error) {
   439  
   440  	// generate code verifier
   441  	buf := make([]byte, count)
   442  	_, err := io.ReadFull(rand.Reader, buf)
   443  	if err != nil {
   444  		return "", "", err
   445  	}
   446  	verifier := hex.EncodeToString(buf)
   447  
   448  	// generate code challenge
   449  	sha2 := sha256.New()
   450  	_, err = io.WriteString(sha2, verifier)
   451  	if err != nil {
   452  		return "", "", err
   453  	}
   454  	challenge := base64.RawURLEncoding.EncodeToString(sha2.Sum(nil))
   455  
   456  	return verifier, challenge, nil
   457  }