github.com/anycable/anycable-go@v1.5.1/identity/jwt.go (about)

     1  package identity
     2  
     3  import (
     4  	"fmt"
     5  	"log/slog"
     6  	"net/url"
     7  	"strings"
     8  
     9  	"github.com/anycable/anycable-go/common"
    10  	"github.com/golang-jwt/jwt"
    11  )
    12  
    13  const (
    14  	expiredMessage = "{\"type\":\"disconnect\",\"reason\":\"token_expired\",\"reconnect\":false}"
    15  )
    16  
    17  type JWTConfig struct {
    18  	Secret string
    19  	Param  string
    20  	Algo   jwt.SigningMethod
    21  	Force  bool
    22  }
    23  
    24  var (
    25  	defaultJWTAlgo = jwt.SigningMethodHS256
    26  )
    27  
    28  func NewJWTConfig(secret string) JWTConfig {
    29  	return JWTConfig{Secret: secret, Param: "jid", Algo: defaultJWTAlgo}
    30  }
    31  
    32  func (c JWTConfig) Enabled() bool {
    33  	return c.Secret != ""
    34  }
    35  
    36  type JWTIdentifier struct {
    37  	secret     []byte
    38  	paramName  string
    39  	headerName string
    40  	required   bool
    41  	log        *slog.Logger
    42  }
    43  
    44  var _ Identifier = (*JWTIdentifier)(nil)
    45  
    46  func NewJWTIdentifier(config *JWTConfig, l *slog.Logger) *JWTIdentifier {
    47  	return &JWTIdentifier{
    48  		secret:     []byte(config.Secret),
    49  		paramName:  config.Param,
    50  		headerName: strings.ToLower(fmt.Sprintf("x-%s", config.Param)),
    51  		required:   config.Force,
    52  		log:        l.With("context", "jwt"),
    53  	}
    54  }
    55  
    56  func (i *JWTIdentifier) Identify(sid string, env *common.SessionEnv) (*common.ConnectResult, error) {
    57  	var rawToken string
    58  
    59  	if env.Headers != nil {
    60  		if v, ok := (*env.Headers)[i.headerName]; ok {
    61  			rawToken = v
    62  		}
    63  	}
    64  
    65  	if rawToken == "" {
    66  		u, err := url.Parse(env.URL)
    67  
    68  		if err != nil {
    69  			return nil, err
    70  		}
    71  
    72  		m, err := url.ParseQuery(u.RawQuery)
    73  
    74  		if err != nil {
    75  			return nil, err
    76  		}
    77  
    78  		if v, ok := m[i.paramName]; ok {
    79  			rawToken = v[0]
    80  		}
    81  	}
    82  
    83  	if rawToken == "" {
    84  		i.log.Debug("no token is found", "url", env.URL, "headers", env.Headers)
    85  
    86  		if i.required {
    87  			return unauthorizedResponse(), nil
    88  		}
    89  
    90  		return nil, nil
    91  	}
    92  
    93  	token, err := jwt.Parse(rawToken, func(token *jwt.Token) (interface{}, error) {
    94  		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
    95  			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
    96  		}
    97  
    98  		return i.secret, nil
    99  	})
   100  
   101  	if err != nil {
   102  		if ve, ok := err.(*jwt.ValidationError); ok {
   103  			if ve.Errors&(jwt.ValidationErrorExpired) != 0 {
   104  				i.log.Debug("token has expired")
   105  
   106  				return expiredResponse(), nil
   107  			}
   108  		}
   109  
   110  		i.log.Debug("invalid token", "error", err)
   111  		return unauthorizedResponse(), nil
   112  	}
   113  
   114  	var ids string
   115  
   116  	if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
   117  		if v, ok := claims["ext"].(string); ok {
   118  			ids = v
   119  		} else {
   120  			return nil, fmt.Errorf("JWT token doesn't contain identifiers: %v", claims)
   121  		}
   122  	} else {
   123  		return nil, err
   124  	}
   125  
   126  	return &common.ConnectResult{
   127  		Identifier:    ids,
   128  		Transmissions: []string{actionCableWelcomeMessage(sid)},
   129  		Status:        common.SUCCESS,
   130  	}, nil
   131  }
   132  
   133  func unauthorizedResponse() *common.ConnectResult {
   134  	return &common.ConnectResult{Status: common.FAILURE, Transmissions: []string{actionCableDisconnectUnauthorizedMessage}}
   135  }
   136  
   137  func expiredResponse() *common.ConnectResult {
   138  	return &common.ConnectResult{Status: common.FAILURE, Transmissions: []string{expiredMessage}}
   139  }