github.com/zooyer/miskit@v1.0.71/sdk/sso/sso.go (about) 1 package sso 2 3 import ( 4 "context" 5 "fmt" 6 "net/http" 7 "net/url" 8 "path" 9 "strings" 10 "time" 11 12 "github.com/gin-gonic/gin" 13 "github.com/zooyer/miskit/log" 14 "github.com/zooyer/miskit/zrpc" 15 ) 16 17 type Option struct { 18 ClientID string 19 ClientSecret string 20 Scope []string 21 Addr string 22 Retry int 23 Timeout time.Duration 24 Logger *log.Logger 25 } 26 27 type Client struct { 28 option Option 29 client *zrpc.Client 30 } 31 32 type Pages Client 33 34 type cookie struct { 35 Cookie string `json:"cookie"` 36 MaxAge int `json:"max_age"` 37 Path string `json:"path"` 38 Domain string `json:"domain"` 39 Secure bool `json:"secure"` 40 HttpOnly bool `json:"http_only"` 41 } 42 43 type sessionResp struct { 44 Cookie *cookie `json:"cookie"` 45 Userinfo *Userinfo `json:"userinfo"` 46 } 47 48 type Token struct { 49 AccessToken string `json:"access_token,omitempty"` 50 TokenType string `json:"token_type,omitempty"` 51 ExpiresIn int64 `json:"expires_in,omitempty"` 52 RefreshToken string `json:"refresh_token,omitempty"` 53 Scope string `json:"scope,omitempty"` 54 IDToken string `json:"id_token,omitempty"` 55 } 56 57 type Userinfo struct { 58 UserID int64 `json:"user_id"` 59 Username string `json:"username,omitempty"` 60 Nickname string `json:"nickname,omitempty"` 61 Realname string `json:"realname,omitempty"` 62 UserPhone string `json:"user_phone,omitempty"` 63 UserEmail string `json:"user_email,omitempty"` 64 UserGender int `json:"user_gender,omitempty"` 65 UserAddress string `json:"user_address,omitempty"` 66 UserAvatar string `json:"user_avatar,omitempty"` 67 UserAccess string `json:"user_access,omitempty"` 68 UserSource string `json:"user_source,omitempty"` 69 UserStatus int `json:"user_status,omitempty"` 70 UserExpiredAt int64 `json:"user_expired_at,omitempty"` 71 } 72 73 type Router interface { 74 gin.IRouter 75 BasePath() string 76 } 77 78 func New(option Option) *Client { 79 return &Client{ 80 option: option, 81 client: zrpc.New("sso", option.Retry, option.Timeout, option.Logger), 82 } 83 } 84 85 func (c *Client) AuthorizeCodeURL(ctx context.Context, redirectURI string) string { 86 var params = url.Values{ 87 "response_type": {"code"}, 88 "client_id": {c.option.ClientID}, 89 } 90 91 if redirectURI != "" { 92 params.Set("redirect_uri", redirectURI) 93 } 94 95 if len(c.option.Scope) > 0 { 96 params.Set("scope", strings.Join(c.option.Scope, " ")) 97 } 98 99 // TODO 后续考虑增加state、code_challenge、code_challenge_method 100 return fmt.Sprintf("%v/sso/authorize?%s", c.option.Addr, params.Encode()) 101 } 102 103 func (c *Client) Token(ctx context.Context, code string) (_ *Token, err error) { 104 var req = map[string]string{ 105 "grant_type": "authorization_code", 106 "code": code, 107 "client_id": c.option.ClientID, 108 "client_secret": c.option.ClientSecret, 109 // TODO 后续增加code_verifier,对code_challenge做校验 110 } 111 112 var ( 113 uri = fmt.Sprintf("%v/sso/api/v1/oauth/token", c.option.Addr) 114 token Token 115 ) 116 117 if _, _, err = c.client.PostJSON(ctx, uri, req, &token); err != nil { 118 return 119 } 120 121 return &token, nil 122 } 123 124 func (c *Client) RefreshToken(ctx context.Context, refreshToken string) (_ *Token, err error) { 125 var req = map[string]string{ 126 "grant_type": "refresh_token", 127 "refresh_token": refreshToken, 128 "client_id": c.option.ClientID, 129 "client_secret": c.option.ClientSecret, 130 } 131 132 var ( 133 uri = fmt.Sprintf("%v/sso/api/v1/oauth/token", c.option.Addr) 134 token Token 135 ) 136 137 if _, _, err = c.client.PostJSON(ctx, uri, req, &token); err != nil { 138 return 139 } 140 141 return &token, nil 142 } 143 144 func (c *Client) Verify(ctx context.Context, accessToken string) (err error) { 145 var ( 146 uri = fmt.Sprintf("%v/sso/api/v1/oauth/verify", c.option.Addr) 147 req = map[string]interface{}{ 148 "client_id": c.option.ClientID, 149 "access_token": accessToken, 150 } 151 resp interface{} 152 ) 153 154 if _, _, err = c.client.PostJSON(ctx, uri, req, &resp); err != nil { 155 return 156 } 157 158 return nil 159 } 160 161 func (c *Client) Userinfo(ctx *gin.Context, accessToken string) (userinfo *Userinfo, err error) { 162 var ( 163 uri = fmt.Sprintf("%v/sso/api/v1/oauth/userinfo", c.option.Addr) 164 req = map[string]interface{}{ 165 "client_id": c.option.ClientID, 166 "access_token": accessToken, 167 } 168 resp Userinfo 169 ) 170 171 if _, _, err = c.client.PostJSON(ctx, uri, req, &resp); err != nil { 172 return 173 } 174 175 return c.userinfo(ctx, &resp), nil 176 } 177 178 func (c *Client) cookieName() string { 179 return fmt.Sprintf("sso-cookie-%v", c.option.ClientID) 180 } 181 182 func (c *Client) contextKey() string { 183 return fmt.Sprintf("sso-session-%v", c.option.ClientID) 184 } 185 186 // setCookie 设置cookie 187 func (c *Client) setCookie(ctx *gin.Context, cookie *cookie) { 188 if cookie == nil { 189 return 190 } 191 192 ctx.SetCookie( 193 c.cookieName(), 194 cookie.Cookie, 195 cookie.MaxAge, 196 cookie.Path, 197 cookie.Domain, 198 cookie.Secure, 199 cookie.HttpOnly, 200 ) 201 } 202 203 // getSession 获取session 204 func (c *Client) getSession(ctx *gin.Context, cookie string) (session *sessionResp, err error) { 205 var ( 206 uri = fmt.Sprintf("%v/sso/api/v1/oauth/session/get", c.option.Addr) 207 req = map[string]interface{}{ 208 "client_id": c.option.ClientID, 209 "cookie": cookie, 210 } 211 resp sessionResp 212 ) 213 214 if _, _, err = c.client.PostJSON(ctx, uri, req, &resp); err != nil { 215 return 216 } 217 218 return &resp, nil 219 } 220 221 // newSession 创建session 222 func (c *Client) newSession(ctx *gin.Context, code string) (session *sessionResp, err error) { 223 token, err := c.Token(ctx, code) 224 if err != nil { 225 return 226 } 227 228 var ( 229 uri = fmt.Sprintf("%v/sso/api/v1/oauth/session/new", c.option.Addr) 230 req = map[string]interface{}{ 231 "client_id": c.option.ClientID, 232 "access_token": token.AccessToken, 233 } 234 resp sessionResp 235 ) 236 237 if _, _, err = c.client.PostJSON(ctx, uri, req, &resp); err != nil { 238 return 239 } 240 241 return &resp, nil 242 } 243 244 // delSession 删除session 245 func (c *Client) delSession(ctx *gin.Context, cookie string) (session *sessionResp, err error) { 246 var ( 247 uri = fmt.Sprintf("%v/sso/api/v1/oauth/session/del", c.option.Addr) 248 req = map[string]interface{}{ 249 "client_id": c.option.ClientID, 250 "cookie": cookie, 251 } 252 resp sessionResp 253 ) 254 255 if _, _, err = c.client.PostJSON(ctx, uri, req, &resp); err != nil { 256 return 257 } 258 259 return &resp, nil 260 } 261 262 type sessionOptions struct { 263 // 用户未登录重定向到登录页,默认未登录则会302重定向到登录页 264 RedirectFunc func(ctx *gin.Context, uri string, err error) 265 // 用户授权登录后回调,默认失败会返回403状态码 266 CallbackFunc func(ctx *gin.Context, userinfo *Userinfo, err error) 267 // 用户注销登录,默认失败会返回 268 LogoutFunc func(ctx *gin.Context, err error) 269 } 270 271 type SessionOption func(options *sessionOptions) 272 273 func WithRedirect(redirect func(ctx *gin.Context, uri string, err error)) SessionOption { 274 return func(options *sessionOptions) { 275 options.RedirectFunc = redirect 276 } 277 } 278 279 func WithCallback(callback func(ctx *gin.Context, userinfo *Userinfo, err error)) SessionOption { 280 return func(options *sessionOptions) { 281 options.CallbackFunc = callback 282 } 283 } 284 285 func WithLogout(logout func(ctx *gin.Context, err error)) SessionOption { 286 return func(options *sessionOptions) { 287 options.LogoutFunc = logout 288 } 289 } 290 291 func (c *Client) userinfo(ctx *gin.Context, userinfo *Userinfo) *Userinfo { 292 if userinfo == nil { 293 return nil 294 } 295 296 if userinfo.UserAvatar != "" { 297 userinfo.UserAvatar = fmt.Sprintf("%s%s", c.option.Addr, userinfo.UserAvatar) 298 } 299 300 return userinfo 301 } 302 303 func (c *Client) oauth(options sessionOptions) gin.HandlerFunc { 304 return func(ctx *gin.Context) { 305 var ( 306 err error 307 req struct { 308 Code string `form:"code" json:"code" binding:"required"` 309 } 310 resp *sessionResp 311 userinfo *Userinfo 312 ) 313 314 defer func() { 315 if options.CallbackFunc != nil { 316 options.CallbackFunc(ctx, userinfo, err) 317 } else { 318 if err != nil { 319 ctx.AbortWithStatus(http.StatusForbidden) 320 } 321 } 322 }() 323 324 if err = ctx.Bind(&req); err != nil { 325 return 326 } 327 328 if resp, err = c.newSession(ctx, req.Code); err != nil { 329 return 330 } 331 332 if userinfo = resp.Userinfo; userinfo != nil { 333 userinfo = c.userinfo(ctx, userinfo) 334 } 335 336 if resp.Cookie != nil { 337 c.setCookie(ctx, resp.Cookie) 338 } 339 } 340 } 341 342 func (c *Client) middleware(loginPath string, options sessionOptions) gin.HandlerFunc { 343 return func(ctx *gin.Context) { 344 var ( 345 err error 346 cookie string 347 ) 348 349 defer func() { 350 if err != nil { 351 if options.RedirectFunc != nil { 352 options.RedirectFunc(ctx, loginPath, err) 353 } else { 354 ctx.Redirect(http.StatusFound, loginPath) 355 } 356 ctx.Abort() 357 } 358 }() 359 360 if cookie, err = ctx.Cookie(c.cookieName()); err != nil { 361 return 362 } 363 364 session, err := c.getSession(ctx, cookie) 365 if err != nil { 366 return 367 } 368 369 ctx.Set(c.contextKey(), c.userinfo(ctx, session.Userinfo)) 370 371 if session.Cookie != nil { 372 c.setCookie(ctx, session.Cookie) 373 } 374 } 375 } 376 377 func (c *Client) login() gin.HandlerFunc { 378 return func(ctx *gin.Context) { 379 var authCodeURL = c.AuthorizeCodeURL(ctx, ctx.Query("redirect_uri")) 380 uri, err := url.Parse(authCodeURL) 381 if err != nil { 382 ctx.Redirect(http.StatusFound, authCodeURL) 383 return 384 } 385 386 var query = uri.Query() 387 for key, values := range ctx.Request.URL.Query() { 388 for _, value := range values { 389 query.Set(key, value) 390 } 391 } 392 393 uri.RawQuery = query.Encode() 394 uri.Fragment = ctx.Request.URL.Fragment 395 396 ctx.Redirect(http.StatusFound, uri.String()) 397 } 398 } 399 400 func (c *Client) logout(options sessionOptions) gin.HandlerFunc { 401 return func(ctx *gin.Context) { 402 var ( 403 err error 404 cookie string 405 ) 406 407 defer func() { 408 if options.LogoutFunc != nil { 409 options.LogoutFunc(ctx, err) 410 } else { 411 if err != nil { 412 ctx.AbortWithStatus(http.StatusForbidden) 413 } 414 } 415 }() 416 417 if cookie, err = ctx.Cookie(c.cookieName()); err != nil { 418 return 419 } 420 421 session, err := c.delSession(ctx, cookie) 422 if err != nil { 423 return 424 } 425 426 if session.Cookie != nil { 427 c.setCookie(ctx, session.Cookie) 428 } 429 } 430 } 431 432 func (c *Client) Session(router Router, loginPath, oauthPath, logoutPath string, options ...SessionOption) (middleware gin.HandlerFunc) { 433 var opt sessionOptions 434 for _, fn := range options { 435 fn(&opt) 436 } 437 438 middleware = c.middleware(path.Join(router.BasePath(), loginPath), opt) 439 440 router.GET(loginPath, c.login()) 441 router.HEAD(loginPath, c.login()) 442 router.GET(oauthPath, c.oauth(opt)) 443 router.POST(oauthPath, c.oauth(opt)) 444 router.GET(logoutPath, middleware, c.logout(opt)) 445 router.POST(logoutPath, middleware, c.logout(opt)) 446 447 return middleware 448 } 449 450 func (c *Client) SessionUserinfo(ctx context.Context) *Userinfo { 451 if userinfo, ok := ctx.Value(c.contextKey()).(*Userinfo); ok { 452 return userinfo 453 } 454 455 return nil 456 } 457 458 func (c *Client) Pages() *Pages { 459 return (*Pages)(c) 460 } 461 462 func (p *Pages) Login() string { 463 return fmt.Sprintf("%v/login", p.option.Addr) 464 } 465 466 func (p *Pages) Home() string { 467 return fmt.Sprintf("%v", p.option.Addr) 468 } 469 470 func (p *Pages) Dashboard() string { 471 return fmt.Sprintf("%v/dashboard", p.option.Addr) 472 } 473 474 func (p *Pages) Profile() string { 475 return fmt.Sprintf("%v/profile", p.option.Addr) 476 }