github.com/juju/juju@v0.0.0-20240430160146-1752b71fcf00/api/session_token_login_provider_test.go (about) 1 // Copyright 2024 Canonical Ltd. 2 // Licensed under the AGPLv3, see LICENCE file for details. 3 4 package api_test 5 6 import ( 7 "encoding/json" 8 "fmt" 9 10 "github.com/juju/errors" 11 "github.com/juju/names/v5" 12 jc "github.com/juju/testing/checkers" 13 gc "gopkg.in/check.v1" 14 15 "github.com/juju/juju/api" 16 "github.com/juju/juju/api/base" 17 jujutesting "github.com/juju/juju/juju/testing" 18 "github.com/juju/juju/rpc/params" 19 ) 20 21 type sessionTokenLoginProviderProviderSuite struct { 22 jujutesting.JujuConnSuite 23 } 24 25 var _ = gc.Suite(&sessionTokenLoginProviderProviderSuite{}) 26 27 func (s *sessionTokenLoginProviderProviderSuite) Test(c *gc.C) { 28 info := s.APIInfo(c) 29 30 sessionToken := "test-session-token" 31 userCode := "1234567" 32 verificationURI := "http://localhost:8080/test-verification" 33 34 var loginDetails string 35 var obtainedSessionToken string 36 37 s.PatchValue(api.LoginDeviceAPICall, func(_ base.APICaller, request interface{}, response interface{}) error { 38 lr := struct { 39 UserCode string `json:"user-code"` 40 VerificationURI string `json:"verification-uri"` 41 }{ 42 UserCode: userCode, 43 VerificationURI: verificationURI, 44 } 45 46 data, err := json.Marshal(lr) 47 if err != nil { 48 return errors.Trace(err) 49 } 50 51 return json.Unmarshal(data, response) 52 }) 53 54 s.PatchValue(api.GetDeviceSessionTokenAPICall, func(_ base.APICaller, request interface{}, response interface{}) error { 55 lr := struct { 56 SessionToken string `json:"session-token"` 57 }{ 58 SessionToken: sessionToken, 59 } 60 61 data, err := json.Marshal(lr) 62 if err != nil { 63 return errors.Trace(err) 64 } 65 66 return json.Unmarshal(data, response) 67 }) 68 69 s.PatchValue(api.LoginWithSessionTokenAPICall, func(_ base.APICaller, request interface{}, response interface{}) error { 70 data, err := json.Marshal(request) 71 if err != nil { 72 return errors.Trace(err) 73 } 74 75 var lr struct { 76 SessionToken string `json:"session-token"` 77 } 78 79 err = json.Unmarshal(data, &lr) 80 if err != nil { 81 return errors.Trace(err) 82 } 83 84 if lr.SessionToken != sessionToken { 85 return ¶ms.Error{ 86 Message: "unauthorized", 87 Code: params.CodeUnauthorized, 88 } 89 } 90 91 loginResult, ok := response.(*params.LoginResult) 92 if !ok { 93 return errors.Errorf("expected %T, received %T for response type", loginResult, response) 94 } 95 loginResult.ControllerTag = names.NewControllerTag(info.ControllerUUID).String() 96 loginResult.ServerVersion = "3.4.0" 97 loginResult.UserInfo = ¶ms.AuthUserInfo{ 98 DisplayName: "alice@external", 99 Identity: names.NewUserTag("alice@external").String(), 100 ControllerAccess: "superuser", 101 } 102 return nil 103 }) 104 105 apiState, err := api.Open(&api.Info{ 106 Addrs: info.Addrs, 107 ControllerUUID: info.ControllerUUID, 108 CACert: info.CACert, 109 }, api.DialOpts{ 110 LoginProvider: api.NewSessionTokenLoginProvider( 111 "expired-token", 112 func(s string, a ...any) error { 113 loginDetails = fmt.Sprintf(s, a...) 114 return nil 115 }, 116 func(sessionToken string) error { 117 obtainedSessionToken = sessionToken 118 return nil 119 }, 120 ), 121 }) 122 c.Assert(err, jc.ErrorIsNil) 123 124 c.Assert(loginDetails, gc.Equals, "Please visit http://localhost:8080/test-verification and enter code 1234567 to log in.") 125 c.Assert(obtainedSessionToken, gc.Equals, sessionToken) 126 defer func() { _ = apiState.Close() }() 127 }