github.com/anycable/anycable-go@v1.5.1/identity/jwt_test.go (about) 1 package identity 2 3 import ( 4 "fmt" 5 "log/slog" 6 "testing" 7 "time" 8 9 "github.com/anycable/anycable-go/common" 10 "github.com/golang-jwt/jwt" 11 "github.com/stretchr/testify/assert" 12 "github.com/stretchr/testify/require" 13 ) 14 15 func TestJWTIdentifierIdentify(t *testing.T) { 16 secret := "ruby-to-go" 17 algo := defaultJWTAlgo 18 ids := "{\"user_id\":\"15\"}" 19 20 config := NewJWTConfig(secret) 21 subject := NewJWTIdentifier(&config, slog.Default()) 22 23 t.Run("with valid token passed as query param", func(t *testing.T) { 24 token := jwt.NewWithClaims(algo, jwt.MapClaims{ 25 "ext": ids, 26 "exp": time.Now().Local().Add(time.Hour * time.Duration(1)).Unix(), 27 }) 28 29 tokenString, err := token.SignedString([]byte(secret)) 30 31 require.Nil(t, err) 32 33 env := common.NewSessionEnv(fmt.Sprintf("ws://demo.anycable.io/cable?jid=%s", tokenString), nil) 34 35 res, err := subject.Identify("12", env) 36 37 require.Nil(t, err) 38 require.NotNil(t, res) 39 assert.Equal(t, ids, res.Identifier) 40 assert.Equal(t, common.SUCCESS, res.Status) 41 assert.Equal(t, []string{`{"type":"welcome","sid":"12"}`}, res.Transmissions) 42 }) 43 44 t.Run("with valid token passed as a header", func(t *testing.T) { 45 token := jwt.NewWithClaims(algo, jwt.MapClaims{ 46 "ext": ids, 47 "exp": time.Now().Local().Add(time.Hour * time.Duration(1)).Unix(), 48 }) 49 50 tokenString, err := token.SignedString([]byte(secret)) 51 52 require.Nil(t, err) 53 54 env := common.NewSessionEnv("ws://demo.anycable.io/cable", &map[string]string{"x-jid": tokenString}) 55 56 res, err := subject.Identify("12", env) 57 58 require.Nil(t, err) 59 require.NotNil(t, res) 60 assert.Equal(t, ids, res.Identifier) 61 assert.Equal(t, common.SUCCESS, res.Status) 62 assert.Equal(t, []string{`{"type":"welcome","sid":"12"}`}, res.Transmissions) 63 }) 64 65 t.Run("with invalid token", func(t *testing.T) { 66 tokenString := "secret-token-not-a-jwt-at-all" 67 env := common.NewSessionEnv(fmt.Sprintf("ws://demo.anycable.io/cable?jid=%s", tokenString), nil) 68 69 res, err := subject.Identify("12", env) 70 71 require.Nil(t, err) 72 require.NotNil(t, res) 73 assert.Equal(t, "", res.Identifier) 74 assert.Equal(t, common.FAILURE, res.Status) 75 assert.Equal(t, []string{"{\"type\":\"disconnect\",\"reason\":\"unauthorized\",\"reconnect\":false}"}, res.Transmissions) 76 }) 77 78 t.Run("with invalid algo", func(t *testing.T) { 79 token := jwt.NewWithClaims(jwt.SigningMethodNone, jwt.MapClaims{ 80 "ext": ids, 81 "exp": time.Now().Local().Add(time.Hour * time.Duration(1)).Unix(), 82 }) 83 84 tokenString, err := token.SignedString(jwt.UnsafeAllowNoneSignatureType) 85 86 require.Nil(t, err) 87 88 env := common.NewSessionEnv(fmt.Sprintf("ws://demo.anycable.io/cable?jid=%s", tokenString), nil) 89 90 res, err := subject.Identify("12", env) 91 92 require.Nil(t, err) 93 require.NotNil(t, res) 94 assert.Equal(t, "", res.Identifier) 95 assert.Equal(t, common.FAILURE, res.Status) 96 assert.Equal(t, []string{"{\"type\":\"disconnect\",\"reason\":\"unauthorized\",\"reconnect\":false}"}, res.Transmissions) 97 }) 98 99 t.Run("with invalid secret", func(t *testing.T) { 100 token := jwt.NewWithClaims(algo, jwt.MapClaims{ 101 "ext": ids, 102 "exp": time.Now().Local().Add(time.Hour * time.Duration(1)).Unix(), 103 }) 104 105 tokenString, err := token.SignedString([]byte("not-a-valid-secret")) 106 107 require.Nil(t, err) 108 109 env := common.NewSessionEnv(fmt.Sprintf("ws://demo.anycable.io/cable?jid=%s", tokenString), nil) 110 111 res, err := subject.Identify("12", env) 112 113 require.Nil(t, err) 114 require.NotNil(t, res) 115 assert.Equal(t, "", res.Identifier) 116 assert.Equal(t, common.FAILURE, res.Status) 117 assert.Equal(t, []string{"{\"type\":\"disconnect\",\"reason\":\"unauthorized\",\"reconnect\":false}"}, res.Transmissions) 118 }) 119 120 t.Run("when token expired", func(t *testing.T) { 121 token := jwt.NewWithClaims(algo, jwt.MapClaims{ 122 "ext": ids, 123 "exp": time.Now().Local().Add(-time.Hour * time.Duration(1)).Unix(), 124 }) 125 126 tokenString, err := token.SignedString([]byte(secret)) 127 128 require.Nil(t, err) 129 130 env := common.NewSessionEnv("ws://demo.anycable.io/cable", &map[string]string{"x-jid": tokenString}) 131 132 res, err := subject.Identify("12", env) 133 134 require.Nil(t, err) 135 require.NotNil(t, res) 136 assert.Equal(t, "", res.Identifier) 137 assert.Equal(t, common.FAILURE, res.Status) 138 assert.Equal(t, []string{"{\"type\":\"disconnect\",\"reason\":\"token_expired\",\"reconnect\":false}"}, res.Transmissions) 139 }) 140 141 t.Run("when token is missing and not required", func(t *testing.T) { 142 env := common.NewSessionEnv("ws://demo.anycable.io/cable", nil) 143 144 res, err := subject.Identify("12", env) 145 146 assert.Nil(t, err) 147 assert.Nil(t, res) 148 }) 149 150 t.Run("when token is missing and required", func(t *testing.T) { 151 config := NewJWTConfig(secret) 152 config.Force = true 153 154 enforced := NewJWTIdentifier(&config, slog.Default()) 155 156 env := common.NewSessionEnv("ws://demo.anycable.io/cable", nil) 157 158 res, err := enforced.Identify("12", env) 159 160 require.Nil(t, err) 161 require.NotNil(t, res) 162 assert.Equal(t, "", res.Identifier) 163 assert.Equal(t, common.FAILURE, res.Status) 164 assert.Equal(t, []string{"{\"type\":\"disconnect\",\"reason\":\"unauthorized\",\"reconnect\":false}"}, res.Transmissions) 165 }) 166 }