github.com/lestrrat-go/jwx/v2@v2.0.21/jwt/token_test.go (about)

     1  package jwt_test
     2  
     3  import (
     4  	"context"
     5  	"reflect"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/lestrrat-go/jwx/v2/internal/json"
    10  
    11  	"github.com/lestrrat-go/jwx/v2/jwt"
    12  	"github.com/stretchr/testify/assert"
    13  )
    14  
    15  const (
    16  	tokenTime = 233431200
    17  )
    18  
    19  var zeroval reflect.Value
    20  var expectedTokenTime = time.Unix(tokenTime, 0).UTC()
    21  
    22  func TestHeader(t *testing.T) {
    23  	t.Parallel()
    24  	values := map[string]interface{}{
    25  		jwt.AudienceKey:   []string{"developers", "secops", "tac"},
    26  		jwt.ExpirationKey: expectedTokenTime,
    27  		jwt.IssuedAtKey:   expectedTokenTime,
    28  		jwt.IssuerKey:     "http://www.example.com",
    29  		jwt.JwtIDKey:      "e9bc097a-ce51-4036-9562-d2ade882db0d",
    30  		jwt.NotBeforeKey:  expectedTokenTime,
    31  		jwt.SubjectKey:    "unit test",
    32  	}
    33  
    34  	t.Run("Roundtrip", func(t *testing.T) {
    35  		t.Parallel()
    36  		h := jwt.New()
    37  		for k, v := range values {
    38  			if !assert.NoError(t, h.Set(k, v), `h.Set should succeed for key %#v`, k) {
    39  				return
    40  			}
    41  			got, ok := h.Get(k)
    42  			if !assert.True(t, ok, `h.Get should succeed for key %#v`, k) {
    43  				return
    44  			}
    45  			if !reflect.DeepEqual(v, got) {
    46  				t.Fatalf("Values do not match: (%v, %v)", v, got)
    47  			}
    48  		}
    49  	})
    50  
    51  	t.Run("RoundtripError", func(t *testing.T) {
    52  		t.Parallel()
    53  		type dummyStruct struct {
    54  			dummy1 int
    55  			dummy2 float64
    56  		}
    57  		dummy := &dummyStruct{1, 3.4}
    58  
    59  		values := map[string]interface{}{
    60  			jwt.AudienceKey:   dummy,
    61  			jwt.ExpirationKey: dummy,
    62  			jwt.IssuedAtKey:   dummy,
    63  			jwt.IssuerKey:     dummy,
    64  			jwt.JwtIDKey:      dummy,
    65  			jwt.NotBeforeKey:  dummy,
    66  			jwt.SubjectKey:    dummy,
    67  		}
    68  
    69  		h := jwt.New()
    70  		for k, v := range values {
    71  			err := h.Set(k, v)
    72  			if err == nil {
    73  				t.Fatalf("Setting %s value should have failed", k)
    74  			}
    75  		}
    76  		err := h.Set("default", dummy) // private params
    77  		if err != nil {
    78  			t.Fatalf("Setting %s value failed", "default")
    79  		}
    80  		for k := range values {
    81  			_, ok := h.Get(k)
    82  			if ok {
    83  				t.Fatalf("Getting %s value should have failed", k)
    84  			}
    85  		}
    86  		_, ok := h.Get("default")
    87  		if !ok {
    88  			t.Fatal("Failed to get default value")
    89  		}
    90  	})
    91  
    92  	t.Run("GetError", func(t *testing.T) {
    93  		t.Parallel()
    94  		h := jwt.New()
    95  		issuer := h.Issuer()
    96  		if issuer != "" {
    97  			t.Fatalf("Get Issuer should return empty string")
    98  		}
    99  		jwtID := h.JwtID()
   100  		if jwtID != "" {
   101  			t.Fatalf("Get JWT Id should return empty string")
   102  		}
   103  	})
   104  }
   105  
   106  func TestTokenMarshal(t *testing.T) {
   107  	t.Parallel()
   108  	t1 := jwt.New()
   109  	err := t1.Set(jwt.JwtIDKey, "AbCdEfG")
   110  	if err != nil {
   111  		t.Fatalf("Failed to set JWT ID: %s", err.Error())
   112  	}
   113  	err = t1.Set(jwt.SubjectKey, "foobar@example.com")
   114  	if err != nil {
   115  		t.Fatalf("Failed to set Subject: %s", err.Error())
   116  	}
   117  
   118  	// Silly fix to remove monotonic element from time.Time obtained
   119  	// from time.Now(). Without this, the equality comparison goes
   120  	// ga-ga for golang tip (1.9)
   121  	now := time.Unix(time.Now().Unix(), 0)
   122  	err = t1.Set(jwt.IssuedAtKey, now.Unix())
   123  	if err != nil {
   124  		t.Fatalf("Failed to set IssuedAt: %s", err.Error())
   125  	}
   126  	err = t1.Set(jwt.NotBeforeKey, now.Add(5*time.Second))
   127  	if err != nil {
   128  		t.Fatalf("Failed to set NotBefore: %s", err.Error())
   129  	}
   130  	err = t1.Set(jwt.ExpirationKey, now.Add(10*time.Second).Unix())
   131  	if err != nil {
   132  		t.Fatalf("Failed to set Expiration: %s", err.Error())
   133  	}
   134  	err = t1.Set(jwt.AudienceKey, []string{"devops", "secops", "tac"})
   135  	if err != nil {
   136  		t.Fatalf("Failed to set audience: %s", err.Error())
   137  	}
   138  	err = t1.Set("custom", "MyValue")
   139  	if err != nil {
   140  		t.Fatalf(`Failed to set private claim "custom": %s`, err.Error())
   141  	}
   142  	jsonbuf1, err := json.MarshalIndent(t1, "", "  ")
   143  	if err != nil {
   144  		t.Fatalf("JSON Marshal failed: %s", err.Error())
   145  	}
   146  
   147  	t2 := jwt.New()
   148  	if !assert.NoError(t, json.Unmarshal(jsonbuf1, t2), `json.Unmarshal should succeed`) {
   149  		return
   150  	}
   151  
   152  	if !assert.Equal(t, t1, t2, "tokens should match") {
   153  		return
   154  	}
   155  
   156  	_, err = json.MarshalIndent(t2, "", "  ")
   157  	if err != nil {
   158  		t.Fatalf("JSON marshal error: %s", err.Error())
   159  	}
   160  }
   161  
   162  func TestToken(t *testing.T) {
   163  	tok := jwt.New()
   164  
   165  	def := map[string]struct {
   166  		Value  interface{}
   167  		Method string
   168  	}{
   169  		jwt.AudienceKey: {
   170  			Method: "Audience",
   171  			Value:  []string{"developers", "secops", "tac"},
   172  		},
   173  		jwt.ExpirationKey: {
   174  			Method: "Expiration",
   175  			Value:  expectedTokenTime,
   176  		},
   177  		jwt.IssuedAtKey: {
   178  			Method: "IssuedAt",
   179  			Value:  expectedTokenTime,
   180  		},
   181  		jwt.IssuerKey: {
   182  			Method: "Issuer",
   183  			Value:  "http://www.example.com",
   184  		},
   185  		jwt.JwtIDKey: {
   186  			Method: "JwtID",
   187  			Value:  "e9bc097a-ce51-4036-9562-d2ade882db0d",
   188  		},
   189  		jwt.NotBeforeKey: {
   190  			Method: "NotBefore",
   191  			Value:  expectedTokenTime,
   192  		},
   193  		jwt.SubjectKey: {
   194  			Method: "Subject",
   195  			Value:  "unit test",
   196  		},
   197  		"myClaim": {
   198  			Value: "hello, world",
   199  		},
   200  	}
   201  
   202  	t.Run("Set", func(t *testing.T) {
   203  		for k, kdef := range def {
   204  			if !assert.NoError(t, tok.Set(k, kdef.Value), `tok.Set(%s) should succeed`, k) {
   205  				return
   206  			}
   207  		}
   208  	})
   209  	t.Run("Get", func(t *testing.T) {
   210  		rv := reflect.ValueOf(tok)
   211  		for k, kdef := range def {
   212  			getval, ok := tok.Get(k)
   213  			if !assert.True(t, ok, `tok.Get(%s) should succeed`, k) {
   214  				return
   215  			}
   216  
   217  			if mname := kdef.Method; mname != "" {
   218  				method := rv.MethodByName(mname)
   219  				if !assert.NotEqual(t, zeroval, method, `method %s should not be zero value`, mname) {
   220  					return
   221  				}
   222  
   223  				retvals := method.Call(nil)
   224  				if !assert.Len(t, retvals, 1, `should have exactly one return value`) {
   225  					return
   226  				}
   227  
   228  				if !assert.Equal(t, getval, retvals[0].Interface(), `values should match`) {
   229  					return
   230  				}
   231  			}
   232  		}
   233  	})
   234  	t.Run("Roundtrip", func(t *testing.T) {
   235  		buf, err := json.Marshal(tok)
   236  		if !assert.NoError(t, err, `json.Marshal should succeed`) {
   237  			return
   238  		}
   239  
   240  		newtok, err := jwt.ParseInsecure(buf)
   241  		if !assert.NoError(t, err, `jwt.Parse should succeed`) {
   242  			return
   243  		}
   244  
   245  		m1, err := tok.AsMap(context.TODO())
   246  		if !assert.NoError(t, err, `tok.AsMap should succeed`) {
   247  			return
   248  		}
   249  
   250  		m2, err := newtok.AsMap(context.TODO())
   251  		if !assert.NoError(t, err, `tok.AsMap should succeed`) {
   252  			return
   253  		}
   254  
   255  		if !assert.Equal(t, m1, m2, `tokens should match`) {
   256  			return
   257  		}
   258  	})
   259  	t.Run("Set/Remove", func(t *testing.T) {
   260  		ctx := context.TODO()
   261  
   262  		newtok, err := tok.Clone()
   263  		if !assert.NoError(t, err, `tok.Clone should succeed`) {
   264  			return
   265  		}
   266  
   267  		for iter := tok.Iterate(ctx); iter.Next(ctx); {
   268  			pair := iter.Pair()
   269  			newtok.Remove(pair.Key.(string))
   270  		}
   271  
   272  		m, err := newtok.AsMap(ctx)
   273  		if !assert.NoError(t, err, `tok.AsMap should succeed`) {
   274  			return
   275  		}
   276  
   277  		if !assert.Len(t, m, 0, `toks should have 0 tok`) {
   278  			return
   279  		}
   280  
   281  		for iter := tok.Iterate(ctx); iter.Next(ctx); {
   282  			pair := iter.Pair()
   283  			if !assert.NoError(t, newtok.Set(pair.Key.(string), pair.Value), `newtok.Set should succeed`) {
   284  				return
   285  			}
   286  		}
   287  	})
   288  }