github.com/anycable/anycable-go@v1.5.1/identity/jwt.go (about) 1 package identity 2 3 import ( 4 "fmt" 5 "log/slog" 6 "net/url" 7 "strings" 8 9 "github.com/anycable/anycable-go/common" 10 "github.com/golang-jwt/jwt" 11 ) 12 13 const ( 14 expiredMessage = "{\"type\":\"disconnect\",\"reason\":\"token_expired\",\"reconnect\":false}" 15 ) 16 17 type JWTConfig struct { 18 Secret string 19 Param string 20 Algo jwt.SigningMethod 21 Force bool 22 } 23 24 var ( 25 defaultJWTAlgo = jwt.SigningMethodHS256 26 ) 27 28 func NewJWTConfig(secret string) JWTConfig { 29 return JWTConfig{Secret: secret, Param: "jid", Algo: defaultJWTAlgo} 30 } 31 32 func (c JWTConfig) Enabled() bool { 33 return c.Secret != "" 34 } 35 36 type JWTIdentifier struct { 37 secret []byte 38 paramName string 39 headerName string 40 required bool 41 log *slog.Logger 42 } 43 44 var _ Identifier = (*JWTIdentifier)(nil) 45 46 func NewJWTIdentifier(config *JWTConfig, l *slog.Logger) *JWTIdentifier { 47 return &JWTIdentifier{ 48 secret: []byte(config.Secret), 49 paramName: config.Param, 50 headerName: strings.ToLower(fmt.Sprintf("x-%s", config.Param)), 51 required: config.Force, 52 log: l.With("context", "jwt"), 53 } 54 } 55 56 func (i *JWTIdentifier) Identify(sid string, env *common.SessionEnv) (*common.ConnectResult, error) { 57 var rawToken string 58 59 if env.Headers != nil { 60 if v, ok := (*env.Headers)[i.headerName]; ok { 61 rawToken = v 62 } 63 } 64 65 if rawToken == "" { 66 u, err := url.Parse(env.URL) 67 68 if err != nil { 69 return nil, err 70 } 71 72 m, err := url.ParseQuery(u.RawQuery) 73 74 if err != nil { 75 return nil, err 76 } 77 78 if v, ok := m[i.paramName]; ok { 79 rawToken = v[0] 80 } 81 } 82 83 if rawToken == "" { 84 i.log.Debug("no token is found", "url", env.URL, "headers", env.Headers) 85 86 if i.required { 87 return unauthorizedResponse(), nil 88 } 89 90 return nil, nil 91 } 92 93 token, err := jwt.Parse(rawToken, func(token *jwt.Token) (interface{}, error) { 94 if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { 95 return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) 96 } 97 98 return i.secret, nil 99 }) 100 101 if err != nil { 102 if ve, ok := err.(*jwt.ValidationError); ok { 103 if ve.Errors&(jwt.ValidationErrorExpired) != 0 { 104 i.log.Debug("token has expired") 105 106 return expiredResponse(), nil 107 } 108 } 109 110 i.log.Debug("invalid token", "error", err) 111 return unauthorizedResponse(), nil 112 } 113 114 var ids string 115 116 if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid { 117 if v, ok := claims["ext"].(string); ok { 118 ids = v 119 } else { 120 return nil, fmt.Errorf("JWT token doesn't contain identifiers: %v", claims) 121 } 122 } else { 123 return nil, err 124 } 125 126 return &common.ConnectResult{ 127 Identifier: ids, 128 Transmissions: []string{actionCableWelcomeMessage(sid)}, 129 Status: common.SUCCESS, 130 }, nil 131 } 132 133 func unauthorizedResponse() *common.ConnectResult { 134 return &common.ConnectResult{Status: common.FAILURE, Transmissions: []string{actionCableDisconnectUnauthorizedMessage}} 135 } 136 137 func expiredResponse() *common.ConnectResult { 138 return &common.ConnectResult{Status: common.FAILURE, Transmissions: []string{expiredMessage}} 139 }