github.com/cozy/cozy-stack@v0.0.0-20240603063001-31110fa4cae1/web/accounts/oauth.go (about)

     1  package accounts
     2  
     3  import (
     4  	"encoding/json"
     5  	"errors"
     6  	"net/http"
     7  	"net/url"
     8  
     9  	"github.com/cozy/cozy-stack/model/account"
    10  	"github.com/cozy/cozy-stack/model/instance"
    11  	"github.com/cozy/cozy-stack/model/instance/lifecycle"
    12  	"github.com/cozy/cozy-stack/model/permission"
    13  	"github.com/cozy/cozy-stack/model/session"
    14  	"github.com/cozy/cozy-stack/pkg/config/config"
    15  	"github.com/cozy/cozy-stack/pkg/consts"
    16  	"github.com/cozy/cozy-stack/pkg/couchdb"
    17  	"github.com/cozy/cozy-stack/pkg/jsonapi"
    18  	"github.com/cozy/cozy-stack/pkg/logger"
    19  	"github.com/cozy/cozy-stack/web/auth"
    20  	"github.com/cozy/cozy-stack/web/middlewares"
    21  	"github.com/cozy/cozy-stack/web/oidc"
    22  	jwt "github.com/golang-jwt/jwt/v5"
    23  	"github.com/labstack/echo/v4"
    24  )
    25  
    26  type apiAccount struct {
    27  	*account.Account
    28  }
    29  
    30  func (a *apiAccount) MarshalJSON() ([]byte, error)           { return json.Marshal(a.Account) }
    31  func (a *apiAccount) Relationships() jsonapi.RelationshipMap { return nil }
    32  func (a *apiAccount) Included() []jsonapi.Object             { return nil }
    33  func (a *apiAccount) Links() *jsonapi.LinksList {
    34  	return &jsonapi.LinksList{Self: "/data/" + consts.Accounts + "/" + a.ID()}
    35  }
    36  
    37  func start(c echo.Context) error {
    38  	instance := middlewares.GetInstance(c)
    39  
    40  	accountTypeID := c.Param("accountType")
    41  	accountType, err := account.TypeInfo(accountTypeID, instance.ContextName)
    42  	if err != nil {
    43  		return err
    44  	}
    45  
    46  	state, err := getStorage().Add(&stateHolder{
    47  		InstanceDomain: instance.Domain,
    48  		AccountType:    accountType.ServiceID(),
    49  		ClientState:    c.QueryParam("state"),
    50  		Nonce:          c.QueryParam("nonce"),
    51  		Slug:           c.QueryParam("slug"),
    52  	})
    53  	if err != nil {
    54  		return err
    55  	}
    56  
    57  	url, err := accountType.MakeOauthStartURL(instance, state, c.QueryParams())
    58  	if err != nil {
    59  		return err
    60  	}
    61  
    62  	return c.Redirect(http.StatusSeeOther, url)
    63  }
    64  
    65  func redirectToApp(
    66  	c echo.Context,
    67  	inst *instance.Instance,
    68  	acc *account.Account,
    69  	clientState, slug, connID, connDeleted, errorMessage string,
    70  ) error {
    71  	if slug == "" {
    72  		slug = consts.HomeSlug
    73  	}
    74  	u := inst.SubDomain(slug)
    75  	vv := &url.Values{}
    76  	if acc != nil {
    77  		vv.Add("account", acc.ID())
    78  	}
    79  	if clientState != "" {
    80  		vv.Add("state", clientState)
    81  	}
    82  	if connID != "" {
    83  		vv.Add("connection_id", connID)
    84  	}
    85  	if connDeleted != "" {
    86  		vv.Add("connection_deleted", connDeleted)
    87  	}
    88  	if errorMessage != "" {
    89  		vv.Add("error", errorMessage)
    90  	}
    91  	u.RawQuery = vv.Encode()
    92  	return c.Redirect(http.StatusSeeOther, u.String())
    93  }
    94  
    95  // redirect is the redirect_uri endpoint passed to oauth services
    96  // it should create the account.
    97  // middlewares.NeedInstance is not applied before this handler
    98  // it needs to handle both
    99  // - with instance redirect
   100  // - without instance redirect
   101  func redirect(c echo.Context) error {
   102  	accessCode := c.QueryParam("code")
   103  	accessToken := c.QueryParam("access_token")
   104  	accountTypeID := c.Param("accountType")
   105  
   106  	i, _ := lifecycle.GetInstance(c.Request().Host)
   107  	var clientState, connID, connDeleted, slug string
   108  	var acc *account.Account
   109  
   110  	connID = c.QueryParam("connection_id")
   111  	connDeleted = c.QueryParam("connection_deleted")
   112  
   113  	if accessToken != "" {
   114  		if i == nil {
   115  			return echo.NewHTTPError(http.StatusBadRequest,
   116  				"using ?access_token with instance-less redirect")
   117  		}
   118  
   119  		acc = &account.Account{
   120  			AccountType: accountTypeID,
   121  			Oauth: &account.OauthInfo{
   122  				AccessToken: accessToken,
   123  			},
   124  		}
   125  	} else {
   126  		stateCode := c.QueryParam("state")
   127  		state := getStorage().Find(stateCode)
   128  		if state == nil ||
   129  			state.AccountType != accountTypeID ||
   130  			(i != nil && state.InstanceDomain != i.Domain) {
   131  			return errors.New("bad state")
   132  		}
   133  		if i == nil {
   134  			var err error
   135  			i, err = lifecycle.GetInstance(state.InstanceDomain)
   136  			if err != nil {
   137  				return errors.New("bad state")
   138  			}
   139  		}
   140  
   141  		clientState = state.ClientState
   142  		slug = state.Slug
   143  
   144  		// https://developers.google.com/identity/protocols/oauth2/web-server?hl=en#handlingresponse
   145  		if c.QueryParam("error") == "access_denied" {
   146  			return redirectToApp(c, i, nil, clientState, slug, connID, connDeleted, "access_denied")
   147  		}
   148  
   149  		accountType, err := account.TypeInfo(accountTypeID, i.ContextName)
   150  		if err != nil {
   151  			return err
   152  		}
   153  
   154  		if state.WebviewFlow {
   155  			return redirectToApp(c, i, nil, clientState, slug, connID, connDeleted, "")
   156  		}
   157  
   158  		if accountType.TokenEndpoint == "" {
   159  			params := c.QueryParams()
   160  			params.Del("state")
   161  			acc = &account.Account{
   162  				AccountType: accountTypeID,
   163  				Oauth: &account.OauthInfo{
   164  					ClientID:     accountType.ClientID,
   165  					ClientSecret: accountType.ClientSecret,
   166  					Query:        &params,
   167  				},
   168  			}
   169  		} else {
   170  			acc, err = accountType.RequestAccessToken(i, accessCode, stateCode, state.Nonce)
   171  			if err != nil {
   172  				return err
   173  			}
   174  		}
   175  	}
   176  
   177  	if connID != "" {
   178  		if existingAccount, err := findAccountWithSameConnectionID(i, connID); err == nil {
   179  			acc = existingAccount
   180  		}
   181  	}
   182  
   183  	if acc.ID() == "" {
   184  		if err := couchdb.CreateDoc(i, acc); err != nil {
   185  			return err
   186  		}
   187  	}
   188  
   189  	c.Set("instance", i.WithContextualDomain(c.Request().Host))
   190  	return redirectToApp(c, i, acc, clientState, slug, connID, connDeleted, "")
   191  }
   192  
   193  func findAccountWithSameConnectionID(inst *instance.Instance, connectionID string) (*account.Account, error) {
   194  	var accounts []*account.Account
   195  	req := &couchdb.AllDocsRequest{Limit: 1000}
   196  	err := couchdb.GetAllDocs(inst, consts.Accounts, req, &accounts)
   197  	if err != nil {
   198  		return nil, err
   199  	}
   200  
   201  	for _, a := range accounts {
   202  		if a.Oauth == nil || a.Oauth.Query == nil {
   203  			continue
   204  		}
   205  		connID := a.Oauth.Query.Get("connection_id")
   206  		if connID == connectionID {
   207  			return a, nil
   208  		}
   209  	}
   210  
   211  	return nil, errors.New("not found")
   212  }
   213  
   214  // refresh is an internal route used by konnectors to refresh accounts
   215  // it requires permissions GET:io.cozy.accounts:accountid
   216  func refresh(c echo.Context) error {
   217  	instance := middlewares.GetInstance(c)
   218  	accountid := c.Param("accountid")
   219  
   220  	var acc account.Account
   221  	if err := couchdb.GetDoc(instance, consts.Accounts, accountid, &acc); err != nil {
   222  		return err
   223  	}
   224  
   225  	if err := middlewares.Allow(c, permission.GET, &acc); err != nil {
   226  		return err
   227  	}
   228  
   229  	accountType, err := account.TypeInfo(acc.AccountType, instance.ContextName)
   230  	if err != nil {
   231  		return err
   232  	}
   233  
   234  	err = accountType.RefreshAccount(acc)
   235  	if err != nil {
   236  		return err
   237  	}
   238  
   239  	err = couchdb.UpdateDoc(instance, &acc)
   240  	if err != nil {
   241  		return err
   242  	}
   243  
   244  	return jsonapi.Data(c, http.StatusOK, &apiAccount{&acc}, nil)
   245  }
   246  
   247  // manage redirects the user to the BI webview allowing them to manage their
   248  // bank connections
   249  func manage(c echo.Context) error {
   250  	instance := middlewares.GetInstance(c)
   251  	accountid := c.Param("accountid")
   252  
   253  	var acc account.Account
   254  	if err := couchdb.GetDoc(instance, consts.Accounts, accountid, &acc); err != nil {
   255  		return err
   256  	}
   257  
   258  	accountType, err := account.TypeInfo(acc.AccountType, instance.ContextName)
   259  	if err != nil {
   260  		return err
   261  	}
   262  
   263  	state, err := getStorage().Add(&stateHolder{
   264  		InstanceDomain: instance.Domain,
   265  		AccountType:    accountType.ServiceID(),
   266  		ClientState:    c.QueryParam("state"),
   267  		Slug:           c.QueryParam("slug"),
   268  		WebviewFlow:    true,
   269  	})
   270  	if err != nil {
   271  		return err
   272  	}
   273  
   274  	url, err := accountType.MakeManageURL(instance, state, c.QueryParams())
   275  	if err != nil {
   276  		return err
   277  	}
   278  
   279  	return c.Redirect(http.StatusSeeOther, url)
   280  }
   281  
   282  // reconnect can be used to reconnect a user from BI
   283  func reconnect(c echo.Context) error {
   284  	instance := middlewares.GetInstance(c)
   285  	accountid := c.Param("accountid")
   286  
   287  	var acc account.Account
   288  	if err := couchdb.GetDoc(instance, consts.Accounts, accountid, &acc); err != nil {
   289  		return err
   290  	}
   291  
   292  	accountType, err := account.TypeInfo(acc.AccountType, instance.ContextName)
   293  	if err != nil {
   294  		return err
   295  	}
   296  
   297  	state, err := getStorage().Add(&stateHolder{
   298  		InstanceDomain: instance.Domain,
   299  		AccountType:    accountType.ServiceID(),
   300  		ClientState:    c.QueryParam("state"),
   301  		Slug:           c.QueryParam("slug"),
   302  		WebviewFlow:    true,
   303  	})
   304  	if err != nil {
   305  		return err
   306  	}
   307  
   308  	url, err := accountType.MakeReconnectURL(instance, state, c.QueryParams())
   309  	if err != nil {
   310  		return err
   311  	}
   312  
   313  	return c.Redirect(http.StatusSeeOther, url)
   314  }
   315  
   316  func checkLogin(next echo.HandlerFunc) echo.HandlerFunc {
   317  	return func(c echo.Context) error {
   318  		inst := middlewares.GetInstance(c)
   319  		sess, isLoggedIn := middlewares.GetSession(c)
   320  		wasLoggedIn := isLoggedIn
   321  
   322  		if sess != nil && sess.ShortRun {
   323  			// XXX it's better to create a new session in that case, as the
   324  			// existing short session can easily timeout between now and when
   325  			// the user will come back.
   326  			wasLoggedIn = false
   327  		}
   328  
   329  		if code := c.QueryParam("session_code"); code != "" {
   330  			// XXX we should always clear the session code to avoid it being
   331  			// reused, even if the user is already logged in and we don't want to
   332  			// create a new session
   333  			if checked := inst.CheckAndClearSessionCode(code); checked {
   334  				isLoggedIn = true
   335  			}
   336  		}
   337  
   338  		if !isLoggedIn && checkIDToken(c) {
   339  			isLoggedIn = true
   340  		}
   341  
   342  		if !isLoggedIn {
   343  			return echo.NewHTTPError(http.StatusForbidden)
   344  		}
   345  
   346  		if !wasLoggedIn {
   347  			sessionID, err := auth.SetCookieForNewSession(c, session.ShortRun)
   348  			req := c.Request()
   349  			if err == nil {
   350  				if err = session.StoreNewLoginEntry(inst, sessionID, "", req, "session_code", false); err != nil {
   351  					inst.Logger().Errorf("Could not store session history %q: %s", sessionID, err)
   352  				}
   353  			}
   354  		}
   355  
   356  		return next(c)
   357  	}
   358  }
   359  
   360  func checkIDToken(c echo.Context) bool {
   361  	inst := middlewares.GetInstance(c)
   362  	cfg, ok := config.GetOIDC(inst.ContextName)
   363  	if !ok {
   364  		return false
   365  	}
   366  	allowOAuthToken, _ := cfg["allow_oauth_token"].(bool)
   367  	if !allowOAuthToken {
   368  		return false
   369  	}
   370  	idTokenKeyURL, _ := cfg["id_token_jwk_url"].(string)
   371  	if idTokenKeyURL == "" {
   372  		return false
   373  	}
   374  
   375  	keys, err := oidc.GetIDTokenKeys(idTokenKeyURL)
   376  	if err != nil {
   377  		return false
   378  	}
   379  	idToken := c.QueryParam("id_token")
   380  	token, err := jwt.Parse(idToken, func(token *jwt.Token) (interface{}, error) {
   381  		return oidc.ChooseKeyForIDToken(keys, token)
   382  	})
   383  	if err != nil {
   384  		logger.WithNamespace("oidc").Errorf("Error on jwt.Parse: %s", err)
   385  		return false
   386  	}
   387  	if !token.Valid {
   388  		logger.WithNamespace("oidc").Errorf("Invalid token: %#v", token)
   389  		return false
   390  	}
   391  	claims := token.Claims.(jwt.MapClaims)
   392  	if claims["sub"] == "" || claims["sub"] != inst.OIDCID {
   393  		inst.Logger().WithNamespace("oidc").Errorf("Invalid sub: %s != %s", claims["sub"], inst.OIDCID)
   394  		return false
   395  	}
   396  
   397  	return true
   398  }
   399  
   400  // Routes setups routing for cozy-as-oauth-client routes
   401  // Careful, the normal middlewares NeedInstance and LoadSession are not applied
   402  // to this group in web/routing
   403  func Routes(router *echo.Group) {
   404  	router.GET("/:accountType/start", start, middlewares.NeedInstance, middlewares.LoadSession, checkLogin)
   405  	router.GET("/:accountType/redirect", redirect)
   406  	router.GET("/:accountType/:accountid/manage", manage, middlewares.NeedInstance, middlewares.LoadSession, checkLogin)
   407  	router.POST("/:accountType/:accountid/refresh", refresh, middlewares.NeedInstance)
   408  	router.GET("/:accountType/:accountid/reconnect", reconnect, middlewares.NeedInstance, middlewares.LoadSession, checkLogin)
   409  }