github.com/anycable/anycable-go@v1.5.1/identity/jwt_test.go (about)

     1  package identity
     2  
     3  import (
     4  	"fmt"
     5  	"log/slog"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/anycable/anycable-go/common"
    10  	"github.com/golang-jwt/jwt"
    11  	"github.com/stretchr/testify/assert"
    12  	"github.com/stretchr/testify/require"
    13  )
    14  
    15  func TestJWTIdentifierIdentify(t *testing.T) {
    16  	secret := "ruby-to-go"
    17  	algo := defaultJWTAlgo
    18  	ids := "{\"user_id\":\"15\"}"
    19  
    20  	config := NewJWTConfig(secret)
    21  	subject := NewJWTIdentifier(&config, slog.Default())
    22  
    23  	t.Run("with valid token passed as query param", func(t *testing.T) {
    24  		token := jwt.NewWithClaims(algo, jwt.MapClaims{
    25  			"ext": ids,
    26  			"exp": time.Now().Local().Add(time.Hour * time.Duration(1)).Unix(),
    27  		})
    28  
    29  		tokenString, err := token.SignedString([]byte(secret))
    30  
    31  		require.Nil(t, err)
    32  
    33  		env := common.NewSessionEnv(fmt.Sprintf("ws://demo.anycable.io/cable?jid=%s", tokenString), nil)
    34  
    35  		res, err := subject.Identify("12", env)
    36  
    37  		require.Nil(t, err)
    38  		require.NotNil(t, res)
    39  		assert.Equal(t, ids, res.Identifier)
    40  		assert.Equal(t, common.SUCCESS, res.Status)
    41  		assert.Equal(t, []string{`{"type":"welcome","sid":"12"}`}, res.Transmissions)
    42  	})
    43  
    44  	t.Run("with valid token passed as a header", func(t *testing.T) {
    45  		token := jwt.NewWithClaims(algo, jwt.MapClaims{
    46  			"ext": ids,
    47  			"exp": time.Now().Local().Add(time.Hour * time.Duration(1)).Unix(),
    48  		})
    49  
    50  		tokenString, err := token.SignedString([]byte(secret))
    51  
    52  		require.Nil(t, err)
    53  
    54  		env := common.NewSessionEnv("ws://demo.anycable.io/cable", &map[string]string{"x-jid": tokenString})
    55  
    56  		res, err := subject.Identify("12", env)
    57  
    58  		require.Nil(t, err)
    59  		require.NotNil(t, res)
    60  		assert.Equal(t, ids, res.Identifier)
    61  		assert.Equal(t, common.SUCCESS, res.Status)
    62  		assert.Equal(t, []string{`{"type":"welcome","sid":"12"}`}, res.Transmissions)
    63  	})
    64  
    65  	t.Run("with invalid token", func(t *testing.T) {
    66  		tokenString := "secret-token-not-a-jwt-at-all"
    67  		env := common.NewSessionEnv(fmt.Sprintf("ws://demo.anycable.io/cable?jid=%s", tokenString), nil)
    68  
    69  		res, err := subject.Identify("12", env)
    70  
    71  		require.Nil(t, err)
    72  		require.NotNil(t, res)
    73  		assert.Equal(t, "", res.Identifier)
    74  		assert.Equal(t, common.FAILURE, res.Status)
    75  		assert.Equal(t, []string{"{\"type\":\"disconnect\",\"reason\":\"unauthorized\",\"reconnect\":false}"}, res.Transmissions)
    76  	})
    77  
    78  	t.Run("with invalid algo", func(t *testing.T) {
    79  		token := jwt.NewWithClaims(jwt.SigningMethodNone, jwt.MapClaims{
    80  			"ext": ids,
    81  			"exp": time.Now().Local().Add(time.Hour * time.Duration(1)).Unix(),
    82  		})
    83  
    84  		tokenString, err := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
    85  
    86  		require.Nil(t, err)
    87  
    88  		env := common.NewSessionEnv(fmt.Sprintf("ws://demo.anycable.io/cable?jid=%s", tokenString), nil)
    89  
    90  		res, err := subject.Identify("12", env)
    91  
    92  		require.Nil(t, err)
    93  		require.NotNil(t, res)
    94  		assert.Equal(t, "", res.Identifier)
    95  		assert.Equal(t, common.FAILURE, res.Status)
    96  		assert.Equal(t, []string{"{\"type\":\"disconnect\",\"reason\":\"unauthorized\",\"reconnect\":false}"}, res.Transmissions)
    97  	})
    98  
    99  	t.Run("with invalid secret", func(t *testing.T) {
   100  		token := jwt.NewWithClaims(algo, jwt.MapClaims{
   101  			"ext": ids,
   102  			"exp": time.Now().Local().Add(time.Hour * time.Duration(1)).Unix(),
   103  		})
   104  
   105  		tokenString, err := token.SignedString([]byte("not-a-valid-secret"))
   106  
   107  		require.Nil(t, err)
   108  
   109  		env := common.NewSessionEnv(fmt.Sprintf("ws://demo.anycable.io/cable?jid=%s", tokenString), nil)
   110  
   111  		res, err := subject.Identify("12", env)
   112  
   113  		require.Nil(t, err)
   114  		require.NotNil(t, res)
   115  		assert.Equal(t, "", res.Identifier)
   116  		assert.Equal(t, common.FAILURE, res.Status)
   117  		assert.Equal(t, []string{"{\"type\":\"disconnect\",\"reason\":\"unauthorized\",\"reconnect\":false}"}, res.Transmissions)
   118  	})
   119  
   120  	t.Run("when token expired", func(t *testing.T) {
   121  		token := jwt.NewWithClaims(algo, jwt.MapClaims{
   122  			"ext": ids,
   123  			"exp": time.Now().Local().Add(-time.Hour * time.Duration(1)).Unix(),
   124  		})
   125  
   126  		tokenString, err := token.SignedString([]byte(secret))
   127  
   128  		require.Nil(t, err)
   129  
   130  		env := common.NewSessionEnv("ws://demo.anycable.io/cable", &map[string]string{"x-jid": tokenString})
   131  
   132  		res, err := subject.Identify("12", env)
   133  
   134  		require.Nil(t, err)
   135  		require.NotNil(t, res)
   136  		assert.Equal(t, "", res.Identifier)
   137  		assert.Equal(t, common.FAILURE, res.Status)
   138  		assert.Equal(t, []string{"{\"type\":\"disconnect\",\"reason\":\"token_expired\",\"reconnect\":false}"}, res.Transmissions)
   139  	})
   140  
   141  	t.Run("when token is missing and not required", func(t *testing.T) {
   142  		env := common.NewSessionEnv("ws://demo.anycable.io/cable", nil)
   143  
   144  		res, err := subject.Identify("12", env)
   145  
   146  		assert.Nil(t, err)
   147  		assert.Nil(t, res)
   148  	})
   149  
   150  	t.Run("when token is missing and required", func(t *testing.T) {
   151  		config := NewJWTConfig(secret)
   152  		config.Force = true
   153  
   154  		enforced := NewJWTIdentifier(&config, slog.Default())
   155  
   156  		env := common.NewSessionEnv("ws://demo.anycable.io/cable", nil)
   157  
   158  		res, err := enforced.Identify("12", env)
   159  
   160  		require.Nil(t, err)
   161  		require.NotNil(t, res)
   162  		assert.Equal(t, "", res.Identifier)
   163  		assert.Equal(t, common.FAILURE, res.Status)
   164  		assert.Equal(t, []string{"{\"type\":\"disconnect\",\"reason\":\"unauthorized\",\"reconnect\":false}"}, res.Transmissions)
   165  	})
   166  }