github.com/juju/juju@v0.0.0-20240430160146-1752b71fcf00/api/session_token_login_provider_test.go (about)

     1  // Copyright 2024 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package api_test
     5  
     6  import (
     7  	"encoding/json"
     8  	"fmt"
     9  
    10  	"github.com/juju/errors"
    11  	"github.com/juju/names/v5"
    12  	jc "github.com/juju/testing/checkers"
    13  	gc "gopkg.in/check.v1"
    14  
    15  	"github.com/juju/juju/api"
    16  	"github.com/juju/juju/api/base"
    17  	jujutesting "github.com/juju/juju/juju/testing"
    18  	"github.com/juju/juju/rpc/params"
    19  )
    20  
    21  type sessionTokenLoginProviderProviderSuite struct {
    22  	jujutesting.JujuConnSuite
    23  }
    24  
    25  var _ = gc.Suite(&sessionTokenLoginProviderProviderSuite{})
    26  
    27  func (s *sessionTokenLoginProviderProviderSuite) Test(c *gc.C) {
    28  	info := s.APIInfo(c)
    29  
    30  	sessionToken := "test-session-token"
    31  	userCode := "1234567"
    32  	verificationURI := "http://localhost:8080/test-verification"
    33  
    34  	var loginDetails string
    35  	var obtainedSessionToken string
    36  
    37  	s.PatchValue(api.LoginDeviceAPICall, func(_ base.APICaller, request interface{}, response interface{}) error {
    38  		lr := struct {
    39  			UserCode        string `json:"user-code"`
    40  			VerificationURI string `json:"verification-uri"`
    41  		}{
    42  			UserCode:        userCode,
    43  			VerificationURI: verificationURI,
    44  		}
    45  
    46  		data, err := json.Marshal(lr)
    47  		if err != nil {
    48  			return errors.Trace(err)
    49  		}
    50  
    51  		return json.Unmarshal(data, response)
    52  	})
    53  
    54  	s.PatchValue(api.GetDeviceSessionTokenAPICall, func(_ base.APICaller, request interface{}, response interface{}) error {
    55  		lr := struct {
    56  			SessionToken string `json:"session-token"`
    57  		}{
    58  			SessionToken: sessionToken,
    59  		}
    60  
    61  		data, err := json.Marshal(lr)
    62  		if err != nil {
    63  			return errors.Trace(err)
    64  		}
    65  
    66  		return json.Unmarshal(data, response)
    67  	})
    68  
    69  	s.PatchValue(api.LoginWithSessionTokenAPICall, func(_ base.APICaller, request interface{}, response interface{}) error {
    70  		data, err := json.Marshal(request)
    71  		if err != nil {
    72  			return errors.Trace(err)
    73  		}
    74  
    75  		var lr struct {
    76  			SessionToken string `json:"session-token"`
    77  		}
    78  
    79  		err = json.Unmarshal(data, &lr)
    80  		if err != nil {
    81  			return errors.Trace(err)
    82  		}
    83  
    84  		if lr.SessionToken != sessionToken {
    85  			return &params.Error{
    86  				Message: "unauthorized",
    87  				Code:    params.CodeUnauthorized,
    88  			}
    89  		}
    90  
    91  		loginResult, ok := response.(*params.LoginResult)
    92  		if !ok {
    93  			return errors.Errorf("expected %T, received %T for response type", loginResult, response)
    94  		}
    95  		loginResult.ControllerTag = names.NewControllerTag(info.ControllerUUID).String()
    96  		loginResult.ServerVersion = "3.4.0"
    97  		loginResult.UserInfo = &params.AuthUserInfo{
    98  			DisplayName:      "alice@external",
    99  			Identity:         names.NewUserTag("alice@external").String(),
   100  			ControllerAccess: "superuser",
   101  		}
   102  		return nil
   103  	})
   104  
   105  	apiState, err := api.Open(&api.Info{
   106  		Addrs:          info.Addrs,
   107  		ControllerUUID: info.ControllerUUID,
   108  		CACert:         info.CACert,
   109  	}, api.DialOpts{
   110  		LoginProvider: api.NewSessionTokenLoginProvider(
   111  			"expired-token",
   112  			func(s string, a ...any) error {
   113  				loginDetails = fmt.Sprintf(s, a...)
   114  				return nil
   115  			},
   116  			func(sessionToken string) error {
   117  				obtainedSessionToken = sessionToken
   118  				return nil
   119  			},
   120  		),
   121  	})
   122  	c.Assert(err, jc.ErrorIsNil)
   123  
   124  	c.Assert(loginDetails, gc.Equals, "Please visit http://localhost:8080/test-verification and enter code 1234567 to log in.")
   125  	c.Assert(obtainedSessionToken, gc.Equals, sessionToken)
   126  	defer func() { _ = apiState.Close() }()
   127  }