github.com/zooyer/miskit@v1.0.71/sdk/sso/sso.go (about)

     1  package sso
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"net/http"
     7  	"net/url"
     8  	"path"
     9  	"strings"
    10  	"time"
    11  
    12  	"github.com/gin-gonic/gin"
    13  	"github.com/zooyer/miskit/log"
    14  	"github.com/zooyer/miskit/zrpc"
    15  )
    16  
// Option configures a Client created by New.
type Option struct {
	// ClientID identifies this application to the SSO server. It is sent as
	// "client_id" on every call and also namespaces the session cookie name
	// and the gin context key used by the session middleware.
	ClientID     string
	// ClientSecret is sent as "client_secret" to the token endpoint.
	ClientSecret string
	// Scope lists the requested scopes; they are joined with single spaces
	// into the "scope" parameter of the authorization URL (omitted if empty).
	Scope        []string
	// Addr is the SSO server base address. Endpoint and page paths are
	// appended to it, and user avatar paths are resolved against it.
	// NOTE(review): presumably scheme+host without a trailing slash — confirm.
	Addr         string
	// Retry, Timeout and Logger are passed unchanged to zrpc.New; presumably
	// the retry count, per-request timeout and request logger of the RPC
	// client — confirm against zrpc.
	Retry        int
	Timeout      time.Duration
	Logger       *log.Logger
}
    26  
// Client is an SSO client. It calls the SSO server's OAuth endpoints
// (/sso/api/v1/oauth/...) and provides gin handlers and middleware that keep a
// cookie-based session (see Session).
type Client struct {
	// option is the configuration given to New.
	option Option
	// client performs the JSON POST calls to the SSO server.
	client *zrpc.Client
}

// Pages has the same layout as Client and exposes helpers that build URLs of
// pages served by the SSO server. Obtain one with Client.Pages.
type Pages Client
    33  
// cookie describes a cookie that the SSO server asks this client to set. The
// fields map one-to-one onto the value, maxAge, path, domain, secure and
// httpOnly arguments of gin.Context.SetCookie; the cookie name itself is
// chosen locally by Client.cookieName.
type cookie struct {
	Cookie   string `json:"cookie"`
	MaxAge   int    `json:"max_age"`
	Path     string `json:"path"`
	Domain   string `json:"domain"`
	Secure   bool   `json:"secure"`
	HttpOnly bool   `json:"http_only"`
}

// sessionResp is the response of the session get/new/del endpoints. Either
// field may be nil: a nil Cookie means no cookie is written, and Userinfo is
// the session's user when the server provides one.
type sessionResp struct {
	Cookie   *cookie   `json:"cookie"`
	Userinfo *Userinfo `json:"userinfo"`
}
    47  
// Token is the token-endpoint response returned by Client.Token and
// Client.RefreshToken. The JSON names follow the OAuth 2.0 / OpenID Connect
// token response. NOTE(review): ExpiresIn is presumably in seconds, as in
// OAuth 2.0 — confirm with the server.
type Token struct {
	AccessToken  string `json:"access_token,omitempty"`
	TokenType    string `json:"token_type,omitempty"`
	ExpiresIn    int64  `json:"expires_in,omitempty"`
	RefreshToken string `json:"refresh_token,omitempty"`
	Scope        string `json:"scope,omitempty"`
	IDToken      string `json:"id_token,omitempty"`
}
    56  
// Userinfo is the user profile returned by the SSO server, both from the
// userinfo endpoint and embedded in session responses. The Client
// post-processes UserAvatar so that it is resolved against Option.Addr.
// NOTE(review): the meaning of the integer codes (UserGender, UserStatus) and
// the unit of UserExpiredAt (presumably a Unix timestamp) are defined by the
// server — confirm.
type Userinfo struct {
	UserID        int64  `json:"user_id"`
	Username      string `json:"username,omitempty"`
	Nickname      string `json:"nickname,omitempty"`
	Realname      string `json:"realname,omitempty"`
	UserPhone     string `json:"user_phone,omitempty"`
	UserEmail     string `json:"user_email,omitempty"`
	UserGender    int    `json:"user_gender,omitempty"`
	UserAddress   string `json:"user_address,omitempty"`
	UserAvatar    string `json:"user_avatar,omitempty"`
	UserAccess    string `json:"user_access,omitempty"`
	UserSource    string `json:"user_source,omitempty"`
	UserStatus    int    `json:"user_status,omitempty"`
	UserExpiredAt int64  `json:"user_expired_at,omitempty"`
}
    72  
// Router is the routing surface Session needs: gin's route registration plus
// the group's base path. Session uses the base path to build the absolute
// login path for redirects. *gin.RouterGroup (and therefore *gin.Engine,
// which embeds it) satisfies this interface.
type Router interface {
	gin.IRouter
	BasePath() string
}
    77  
    78  func New(option Option) *Client {
    79  	return &Client{
    80  		option: option,
    81  		client: zrpc.New("sso", option.Retry, option.Timeout, option.Logger),
    82  	}
    83  }
    84  
    85  func (c *Client) AuthorizeCodeURL(ctx context.Context, redirectURI string) string {
    86  	var params = url.Values{
    87  		"response_type": {"code"},
    88  		"client_id":     {c.option.ClientID},
    89  	}
    90  
    91  	if redirectURI != "" {
    92  		params.Set("redirect_uri", redirectURI)
    93  	}
    94  
    95  	if len(c.option.Scope) > 0 {
    96  		params.Set("scope", strings.Join(c.option.Scope, " "))
    97  	}
    98  
    99  	// TODO 后续考虑增加state、code_challenge、code_challenge_method
   100  	return fmt.Sprintf("%v/sso/authorize?%s", c.option.Addr, params.Encode())
   101  }
   102  
   103  func (c *Client) Token(ctx context.Context, code string) (_ *Token, err error) {
   104  	var req = map[string]string{
   105  		"grant_type":    "authorization_code",
   106  		"code":          code,
   107  		"client_id":     c.option.ClientID,
   108  		"client_secret": c.option.ClientSecret,
   109  		// TODO 后续增加code_verifier,对code_challenge做校验
   110  	}
   111  
   112  	var (
   113  		uri   = fmt.Sprintf("%v/sso/api/v1/oauth/token", c.option.Addr)
   114  		token Token
   115  	)
   116  
   117  	if _, _, err = c.client.PostJSON(ctx, uri, req, &token); err != nil {
   118  		return
   119  	}
   120  
   121  	return &token, nil
   122  }
   123  
   124  func (c *Client) RefreshToken(ctx context.Context, refreshToken string) (_ *Token, err error) {
   125  	var req = map[string]string{
   126  		"grant_type":    "refresh_token",
   127  		"refresh_token": refreshToken,
   128  		"client_id":     c.option.ClientID,
   129  		"client_secret": c.option.ClientSecret,
   130  	}
   131  
   132  	var (
   133  		uri   = fmt.Sprintf("%v/sso/api/v1/oauth/token", c.option.Addr)
   134  		token Token
   135  	)
   136  
   137  	if _, _, err = c.client.PostJSON(ctx, uri, req, &token); err != nil {
   138  		return
   139  	}
   140  
   141  	return &token, nil
   142  }
   143  
   144  func (c *Client) Verify(ctx context.Context, accessToken string) (err error) {
   145  	var (
   146  		uri = fmt.Sprintf("%v/sso/api/v1/oauth/verify", c.option.Addr)
   147  		req = map[string]interface{}{
   148  			"client_id":    c.option.ClientID,
   149  			"access_token": accessToken,
   150  		}
   151  		resp interface{}
   152  	)
   153  
   154  	if _, _, err = c.client.PostJSON(ctx, uri, req, &resp); err != nil {
   155  		return
   156  	}
   157  
   158  	return nil
   159  }
   160  
   161  func (c *Client) Userinfo(ctx *gin.Context, accessToken string) (userinfo *Userinfo, err error) {
   162  	var (
   163  		uri = fmt.Sprintf("%v/sso/api/v1/oauth/userinfo", c.option.Addr)
   164  		req = map[string]interface{}{
   165  			"client_id":    c.option.ClientID,
   166  			"access_token": accessToken,
   167  		}
   168  		resp Userinfo
   169  	)
   170  
   171  	if _, _, err = c.client.PostJSON(ctx, uri, req, &resp); err != nil {
   172  		return
   173  	}
   174  
   175  	return c.userinfo(ctx, &resp), nil
   176  }
   177  
   178  func (c *Client) cookieName() string {
   179  	return fmt.Sprintf("sso-cookie-%v", c.option.ClientID)
   180  }
   181  
   182  func (c *Client) contextKey() string {
   183  	return fmt.Sprintf("sso-session-%v", c.option.ClientID)
   184  }
   185  
   186  // setCookie 设置cookie
   187  func (c *Client) setCookie(ctx *gin.Context, cookie *cookie) {
   188  	if cookie == nil {
   189  		return
   190  	}
   191  
   192  	ctx.SetCookie(
   193  		c.cookieName(),
   194  		cookie.Cookie,
   195  		cookie.MaxAge,
   196  		cookie.Path,
   197  		cookie.Domain,
   198  		cookie.Secure,
   199  		cookie.HttpOnly,
   200  	)
   201  }
   202  
   203  // getSession 获取session
   204  func (c *Client) getSession(ctx *gin.Context, cookie string) (session *sessionResp, err error) {
   205  	var (
   206  		uri = fmt.Sprintf("%v/sso/api/v1/oauth/session/get", c.option.Addr)
   207  		req = map[string]interface{}{
   208  			"client_id": c.option.ClientID,
   209  			"cookie":    cookie,
   210  		}
   211  		resp sessionResp
   212  	)
   213  
   214  	if _, _, err = c.client.PostJSON(ctx, uri, req, &resp); err != nil {
   215  		return
   216  	}
   217  
   218  	return &resp, nil
   219  }
   220  
   221  // newSession 创建session
   222  func (c *Client) newSession(ctx *gin.Context, code string) (session *sessionResp, err error) {
   223  	token, err := c.Token(ctx, code)
   224  	if err != nil {
   225  		return
   226  	}
   227  
   228  	var (
   229  		uri = fmt.Sprintf("%v/sso/api/v1/oauth/session/new", c.option.Addr)
   230  		req = map[string]interface{}{
   231  			"client_id":    c.option.ClientID,
   232  			"access_token": token.AccessToken,
   233  		}
   234  		resp sessionResp
   235  	)
   236  
   237  	if _, _, err = c.client.PostJSON(ctx, uri, req, &resp); err != nil {
   238  		return
   239  	}
   240  
   241  	return &resp, nil
   242  }
   243  
   244  // delSession 删除session
   245  func (c *Client) delSession(ctx *gin.Context, cookie string) (session *sessionResp, err error) {
   246  	var (
   247  		uri = fmt.Sprintf("%v/sso/api/v1/oauth/session/del", c.option.Addr)
   248  		req = map[string]interface{}{
   249  			"client_id": c.option.ClientID,
   250  			"cookie":    cookie,
   251  		}
   252  		resp sessionResp
   253  	)
   254  
   255  	if _, _, err = c.client.PostJSON(ctx, uri, req, &resp); err != nil {
   256  		return
   257  	}
   258  
   259  	return &resp, nil
   260  }
   261  
// sessionOptions holds the optional callbacks configured through the
// SessionOption values passed to Client.Session.
type sessionOptions struct {
	// RedirectFunc handles requests without a valid session (missing cookie or
	// failed session lookup). It receives the login path and the cause, and
	// the request is aborted after it returns. By default the middleware
	// answers with a 302 redirect to the login page.
	RedirectFunc func(ctx *gin.Context, uri string, err error)
	// CallbackFunc is invoked whenever the OAuth callback handler finishes,
	// with the logged-in user (nil on failure) and the error, if any. By
	// default a failure aborts the request with 403 Forbidden.
	CallbackFunc func(ctx *gin.Context, userinfo *Userinfo, err error)
	// LogoutFunc is invoked whenever the logout handler finishes, with the
	// error, if any. By default a failure aborts the request with 403
	// Forbidden.
	LogoutFunc func(ctx *gin.Context, err error)
}

// SessionOption customizes the handlers installed by Client.Session.
type SessionOption func(options *sessionOptions)
   272  
   273  func WithRedirect(redirect func(ctx *gin.Context, uri string, err error)) SessionOption {
   274  	return func(options *sessionOptions) {
   275  		options.RedirectFunc = redirect
   276  	}
   277  }
   278  
   279  func WithCallback(callback func(ctx *gin.Context, userinfo *Userinfo, err error)) SessionOption {
   280  	return func(options *sessionOptions) {
   281  		options.CallbackFunc = callback
   282  	}
   283  }
   284  
   285  func WithLogout(logout func(ctx *gin.Context, err error)) SessionOption {
   286  	return func(options *sessionOptions) {
   287  		options.LogoutFunc = logout
   288  	}
   289  }
   290  
   291  func (c *Client) userinfo(ctx *gin.Context, userinfo *Userinfo) *Userinfo {
   292  	if userinfo == nil {
   293  		return nil
   294  	}
   295  
   296  	if userinfo.UserAvatar != "" {
   297  		userinfo.UserAvatar = fmt.Sprintf("%s%s", c.option.Addr, userinfo.UserAvatar)
   298  	}
   299  
   300  	return userinfo
   301  }
   302  
   303  func (c *Client) oauth(options sessionOptions) gin.HandlerFunc {
   304  	return func(ctx *gin.Context) {
   305  		var (
   306  			err error
   307  			req struct {
   308  				Code string `form:"code" json:"code" binding:"required"`
   309  			}
   310  			resp     *sessionResp
   311  			userinfo *Userinfo
   312  		)
   313  
   314  		defer func() {
   315  			if options.CallbackFunc != nil {
   316  				options.CallbackFunc(ctx, userinfo, err)
   317  			} else {
   318  				if err != nil {
   319  					ctx.AbortWithStatus(http.StatusForbidden)
   320  				}
   321  			}
   322  		}()
   323  
   324  		if err = ctx.Bind(&req); err != nil {
   325  			return
   326  		}
   327  
   328  		if resp, err = c.newSession(ctx, req.Code); err != nil {
   329  			return
   330  		}
   331  
   332  		if userinfo = resp.Userinfo; userinfo != nil {
   333  			userinfo = c.userinfo(ctx, userinfo)
   334  		}
   335  
   336  		if resp.Cookie != nil {
   337  			c.setCookie(ctx, resp.Cookie)
   338  		}
   339  	}
   340  }
   341  
   342  func (c *Client) middleware(loginPath string, options sessionOptions) gin.HandlerFunc {
   343  	return func(ctx *gin.Context) {
   344  		var (
   345  			err    error
   346  			cookie string
   347  		)
   348  
   349  		defer func() {
   350  			if err != nil {
   351  				if options.RedirectFunc != nil {
   352  					options.RedirectFunc(ctx, loginPath, err)
   353  				} else {
   354  					ctx.Redirect(http.StatusFound, loginPath)
   355  				}
   356  				ctx.Abort()
   357  			}
   358  		}()
   359  
   360  		if cookie, err = ctx.Cookie(c.cookieName()); err != nil {
   361  			return
   362  		}
   363  
   364  		session, err := c.getSession(ctx, cookie)
   365  		if err != nil {
   366  			return
   367  		}
   368  
   369  		ctx.Set(c.contextKey(), c.userinfo(ctx, session.Userinfo))
   370  
   371  		if session.Cookie != nil {
   372  			c.setCookie(ctx, session.Cookie)
   373  		}
   374  	}
   375  }
   376  
   377  func (c *Client) login() gin.HandlerFunc {
   378  	return func(ctx *gin.Context) {
   379  		var authCodeURL = c.AuthorizeCodeURL(ctx, ctx.Query("redirect_uri"))
   380  		uri, err := url.Parse(authCodeURL)
   381  		if err != nil {
   382  			ctx.Redirect(http.StatusFound, authCodeURL)
   383  			return
   384  		}
   385  
   386  		var query = uri.Query()
   387  		for key, values := range ctx.Request.URL.Query() {
   388  			for _, value := range values {
   389  				query.Set(key, value)
   390  			}
   391  		}
   392  
   393  		uri.RawQuery = query.Encode()
   394  		uri.Fragment = ctx.Request.URL.Fragment
   395  
   396  		ctx.Redirect(http.StatusFound, uri.String())
   397  	}
   398  }
   399  
   400  func (c *Client) logout(options sessionOptions) gin.HandlerFunc {
   401  	return func(ctx *gin.Context) {
   402  		var (
   403  			err    error
   404  			cookie string
   405  		)
   406  
   407  		defer func() {
   408  			if options.LogoutFunc != nil {
   409  				options.LogoutFunc(ctx, err)
   410  			} else {
   411  				if err != nil {
   412  					ctx.AbortWithStatus(http.StatusForbidden)
   413  				}
   414  			}
   415  		}()
   416  
   417  		if cookie, err = ctx.Cookie(c.cookieName()); err != nil {
   418  			return
   419  		}
   420  
   421  		session, err := c.delSession(ctx, cookie)
   422  		if err != nil {
   423  			return
   424  		}
   425  
   426  		if session.Cookie != nil {
   427  			c.setCookie(ctx, session.Cookie)
   428  		}
   429  	}
   430  }
   431  
   432  func (c *Client) Session(router Router, loginPath, oauthPath, logoutPath string, options ...SessionOption) (middleware gin.HandlerFunc) {
   433  	var opt sessionOptions
   434  	for _, fn := range options {
   435  		fn(&opt)
   436  	}
   437  
   438  	middleware = c.middleware(path.Join(router.BasePath(), loginPath), opt)
   439  
   440  	router.GET(loginPath, c.login())
   441  	router.HEAD(loginPath, c.login())
   442  	router.GET(oauthPath, c.oauth(opt))
   443  	router.POST(oauthPath, c.oauth(opt))
   444  	router.GET(logoutPath, middleware, c.logout(opt))
   445  	router.POST(logoutPath, middleware, c.logout(opt))
   446  
   447  	return middleware
   448  }
   449  
   450  func (c *Client) SessionUserinfo(ctx context.Context) *Userinfo {
   451  	if userinfo, ok := ctx.Value(c.contextKey()).(*Userinfo); ok {
   452  		return userinfo
   453  	}
   454  
   455  	return nil
   456  }
   457  
   458  func (c *Client) Pages() *Pages {
   459  	return (*Pages)(c)
   460  }
   461  
   462  func (p *Pages) Login() string {
   463  	return fmt.Sprintf("%v/login", p.option.Addr)
   464  }
   465  
   466  func (p *Pages) Home() string {
   467  	return fmt.Sprintf("%v", p.option.Addr)
   468  }
   469  
   470  func (p *Pages) Dashboard() string {
   471  	return fmt.Sprintf("%v/dashboard", p.option.Addr)
   472  }
   473  
   474  func (p *Pages) Profile() string {
   475  	return fmt.Sprintf("%v/profile", p.option.Addr)
   476  }