github.com/nats-io/jwt/v2@v2.5.6/operator_claims_test.go (about)

     1  /*
     2   * Copyright 2018 The NATS Authors
     3   * Licensed under the Apache License, Version 2.0 (the "License");
     4   * you may not use this file except in compliance with the License.
     5   * You may obtain a copy of the License at
     6   *
     7   * http://www.apache.org/licenses/LICENSE-2.0
     8   *
     9   * Unless required by applicable law or agreed to in writing, software
    10   * distributed under the License is distributed on an "AS IS" BASIS,
    11   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12   * See the License for the specific language governing permissions and
    13   * limitations under the License.
    14   */
    15  
    16  package jwt
    17  
    18  import (
    19  	"fmt"
    20  	"testing"
    21  	"time"
    22  
    23  	"github.com/nats-io/nkeys"
    24  )
    25  
    26  func TestNewOperatorClaims(t *testing.T) {
    27  	ckp := createOperatorNKey(t)
    28  
    29  	uc := NewOperatorClaims(publicKey(ckp, t))
    30  	uc.Expires = time.Now().Add(time.Hour).Unix()
    31  	uJwt := encode(uc, ckp, t)
    32  
    33  	uc2, err := DecodeOperatorClaims(uJwt)
    34  	if err != nil {
    35  		t.Fatal("failed to decode", err)
    36  	}
    37  
    38  	AssertEquals(uc.String(), uc2.String(), t)
    39  
    40  	AssertEquals(uc.Claims() != nil, true, t)
    41  	AssertEquals(uc.Payload() != nil, true, t)
    42  }
    43  
    44  func TestOperatorSubjects(t *testing.T) {
    45  	type kpInputs struct {
    46  		name string
    47  		kp   nkeys.KeyPair
    48  		ok   bool
    49  	}
    50  
    51  	inputs := []kpInputs{
    52  		{"account", createAccountNKey(t), false},
    53  		{"cluster", createClusterNKey(t), false},
    54  		{"operator", createOperatorNKey(t), true},
    55  		{"server", createServerNKey(t), false},
    56  		{"user", createUserNKey(t), false},
    57  	}
    58  
    59  	for _, i := range inputs {
    60  		c := NewOperatorClaims(publicKey(i.kp, t))
    61  		_, err := c.Encode(createOperatorNKey(t))
    62  		if i.ok && err != nil {
    63  			t.Fatalf("unexpected error for %q: %v", i.name, err)
    64  		}
    65  		if !i.ok && err == nil {
    66  			t.Logf("should have failed to encode server with with %q subject", i.name)
    67  			t.Fail()
    68  		}
    69  	}
    70  }
    71  
    72  func TestInvalidOperatorClaimIssuer(t *testing.T) {
    73  	akp := createOperatorNKey(t)
    74  	ac := NewOperatorClaims(publicKey(akp, t))
    75  	ac.Expires = time.Now().Add(time.Hour).Unix()
    76  	aJwt := encode(ac, akp, t)
    77  
    78  	temp, err := DecodeGeneric(aJwt)
    79  	if err != nil {
    80  		t.Fatal("failed to decode", err)
    81  	}
    82  
    83  	type kpInputs struct {
    84  		name string
    85  		kp   nkeys.KeyPair
    86  		ok   bool
    87  	}
    88  
    89  	inputs := []kpInputs{
    90  		{"account", createAccountNKey(t), false},
    91  		{"user", createUserNKey(t), false},
    92  		{"operator", createOperatorNKey(t), true},
    93  		{"server", createServerNKey(t), false},
    94  		{"cluster", createClusterNKey(t), false},
    95  	}
    96  
    97  	for _, i := range inputs {
    98  		bad := encode(temp, i.kp, t)
    99  		_, err = DecodeOperatorClaims(bad)
   100  		if i.ok && err != nil {
   101  			t.Fatalf("unexpected error for %q: %v", i.name, err)
   102  		}
   103  		if !i.ok && err == nil {
   104  			t.Logf("should have failed to decode account signed by %q", i.name)
   105  			t.Fail()
   106  		}
   107  	}
   108  }
   109  
   110  func TestNewNilOperatorClaims(t *testing.T) {
   111  	v := NewOperatorClaims("")
   112  	if v != nil {
   113  		t.Fatal("expected nil user claim")
   114  	}
   115  }
   116  
   117  func TestOperatorType(t *testing.T) {
   118  	c := NewOperatorClaims(publicKey(createOperatorNKey(t), t))
   119  	s := encode(c, createOperatorNKey(t), t)
   120  	u, err := DecodeOperatorClaims(s)
   121  	if err != nil {
   122  		t.Fatalf("failed to decode operator claim: %v", err)
   123  	}
   124  
   125  	if OperatorClaim != u.Type {
   126  		t.Fatalf("type is unexpected %q (wanted operator)", u.Type)
   127  	}
   128  
   129  }
   130  
   131  func TestSigningKeyValidation(t *testing.T) {
   132  	ckp := createOperatorNKey(t)
   133  	ckp2 := createOperatorNKey(t)
   134  
   135  	uc := NewOperatorClaims(publicKey(ckp, t))
   136  	uc.Expires = time.Now().Add(time.Hour).Unix()
   137  	uc.SigningKeys.Add(publicKey(ckp2, t))
   138  	uJwt := encode(uc, ckp, t)
   139  
   140  	uc2, err := DecodeOperatorClaims(uJwt)
   141  	if err != nil {
   142  		t.Fatal("failed to decode", err)
   143  	}
   144  
   145  	AssertEquals(len(uc2.SigningKeys), 1, t)
   146  	AssertEquals(uc2.SigningKeys[0] == publicKey(ckp2, t), true, t)
   147  
   148  	vr := &ValidationResults{}
   149  	uc.Validate(vr)
   150  
   151  	if len(vr.Issues) != 0 {
   152  		t.Fatal("valid operator key should have no validation issues")
   153  	}
   154  
   155  	uc.SigningKeys.Add("") // add an invalid one
   156  
   157  	vr = &ValidationResults{}
   158  	uc.Validate(vr)
   159  	if len(vr.Issues) != 0 {
   160  		t.Fatal("should not be able to add empty values")
   161  	}
   162  }
   163  
   164  func TestSignedBy(t *testing.T) {
   165  	ckp := createOperatorNKey(t)
   166  	ckp2 := createOperatorNKey(t)
   167  
   168  	uc := NewOperatorClaims(publicKey(ckp, t))
   169  	uc2 := NewOperatorClaims(publicKey(ckp2, t))
   170  
   171  	akp := createAccountNKey(t)
   172  	ac := NewAccountClaims(publicKey(akp, t))
   173  	enc, err := ac.Encode(ckp) // sign with the operator key
   174  	if err != nil {
   175  		t.Fatal("failed to encode", err)
   176  	}
   177  	ac, err = DecodeAccountClaims(enc)
   178  	if err != nil {
   179  		t.Fatal("failed to decode", err)
   180  	}
   181  
   182  	AssertEquals(uc.DidSign(ac), true, t)
   183  	AssertEquals(uc2.DidSign(ac), false, t)
   184  
   185  	enc, err = ac.Encode(ckp2) // sign with the other operator key
   186  	if err != nil {
   187  		t.Fatal("failed to encode", err)
   188  	}
   189  	ac, err = DecodeAccountClaims(enc)
   190  	if err != nil {
   191  		t.Fatal("failed to decode", err)
   192  	}
   193  
   194  	AssertEquals(uc.DidSign(ac), false, t) // no signing key
   195  	AssertEquals(uc2.DidSign(ac), true, t) // actual key
   196  	uc.SigningKeys.Add(publicKey(ckp2, t))
   197  	AssertEquals(uc.DidSign(ac), true, t) // signing key
   198  	uc.StrictSigningKeyUsage = true
   199  	AssertEquals(uc.DidSign(uc), true, t)
   200  	AssertEquals(uc.DidSign(ac), true, t)
   201  	uc2.StrictSigningKeyUsage = true
   202  	AssertEquals(uc2.DidSign(uc2), true, t)
   203  	AssertEquals(uc2.DidSign(ac), false, t)
   204  }
   205  
   206  func testAccountWithAccountServerURL(t *testing.T, u string) error {
   207  	kp := createOperatorNKey(t)
   208  	pk := publicKey(kp, t)
   209  	oc := NewOperatorClaims(pk)
   210  	oc.AccountServerURL = u
   211  
   212  	s, err := oc.Encode(kp)
   213  	if err != nil {
   214  		return err
   215  	}
   216  	oc, err = DecodeOperatorClaims(s)
   217  	if err != nil {
   218  		t.Fatal(err)
   219  	}
   220  	AssertEquals(oc.AccountServerURL, u, t)
   221  	vr := ValidationResults{}
   222  	oc.Validate(&vr)
   223  	if !vr.IsEmpty() {
   224  		errs := vr.Errors()
   225  		return errs[0]
   226  	}
   227  	return nil
   228  }
   229  
   230  func Test_AccountServerURL(t *testing.T) {
   231  	var asuTests = []struct {
   232  		u          string
   233  		shouldFail bool
   234  	}{
   235  		{"", false},
   236  		{"HTTP://foo.bar.com", false},
   237  		{"http://foo.bar.com/foo/bar", false},
   238  		{"http://user:pass@foo.bar.com/foo/bar", false},
   239  		{"https://foo.bar.com", false},
   240  		{"nats://foo.bar.com", false},
   241  		{"/hello", true},
   242  	}
   243  
   244  	for i, tt := range asuTests {
   245  		err := testAccountWithAccountServerURL(t, tt.u)
   246  		if err != nil && tt.shouldFail == false {
   247  			t.Fatalf("expected not to fail: %v", err)
   248  		} else if err == nil && tt.shouldFail {
   249  			t.Fatalf("test %s expected to fail but didn't", asuTests[i].u)
   250  		}
   251  	}
   252  }
   253  
   254  func Test_SystemAccount(t *testing.T) {
   255  	operatorWithSystemAcc := func(t *testing.T, u string) error {
   256  		kp := createOperatorNKey(t)
   257  		pk := publicKey(kp, t)
   258  		oc := NewOperatorClaims(pk)
   259  		oc.SystemAccount = u
   260  		s, err := oc.Encode(kp)
   261  		if err != nil {
   262  			return err
   263  		}
   264  		oc, err = DecodeOperatorClaims(s)
   265  		if err != nil {
   266  			t.Fatal(err)
   267  		}
   268  		AssertEquals(oc.SystemAccount, u, t)
   269  		vr := ValidationResults{}
   270  		oc.Validate(&vr)
   271  		if !vr.IsEmpty() {
   272  			return fmt.Errorf("%s", vr.Errors()[0])
   273  		}
   274  		return nil
   275  	}
   276  	var asuTests = []struct {
   277  		accKey     string
   278  		shouldFail bool
   279  	}{
   280  		{"", false},
   281  		{"x", true},
   282  		{"ADZ547B24WHPLWOK7TMLNBSA7FQFXR6UM2NZ4HHNIB7RDFVZQFOZ4GQQ", false},
   283  		{"ADZ547B24WHPLWOK7TMLNBSA7FQFXR6UM2NZ4HHNIB7RDFVZQFOZ4777", true},
   284  	}
   285  	for i, tt := range asuTests {
   286  		err := operatorWithSystemAcc(t, tt.accKey)
   287  		if err != nil && tt.shouldFail == false {
   288  			t.Fatalf("expected not to fail: %v", err)
   289  		} else if err == nil && tt.shouldFail {
   290  			t.Fatalf("test %s expected to fail but didn't", asuTests[i].accKey)
   291  		}
   292  	}
   293  }
   294  
   295  func Test_AssertServerVersion(t *testing.T) {
   296  	operatorWithAssertServerVer := func(t *testing.T, v string) error {
   297  		kp := createOperatorNKey(t)
   298  		pk := publicKey(kp, t)
   299  		oc := NewOperatorClaims(pk)
   300  		oc.AssertServerVersion = v
   301  		s, err := oc.Encode(kp)
   302  		if err != nil {
   303  			return err
   304  		}
   305  		oc, err = DecodeOperatorClaims(s)
   306  		if err != nil {
   307  			t.Fatal(err)
   308  		}
   309  		AssertEquals(oc.AssertServerVersion, v, t)
   310  		vr := ValidationResults{}
   311  		oc.Validate(&vr)
   312  		if !vr.IsEmpty() {
   313  			return fmt.Errorf("%s", vr.Errors()[0])
   314  		}
   315  		return nil
   316  	}
   317  	var asuTests = []struct {
   318  		assertVer  string
   319  		shouldFail bool
   320  	}{
   321  		{"1.2.3", false},
   322  		{"10.2.3", false},
   323  		{"1.20.3", false},
   324  		{"1.2.30", false},
   325  		{"10.20.30", false},
   326  		{"0.0.0", false},
   327  		{"0.0", true},
   328  		{"0", true},
   329  		{"a", true},
   330  		{"a.b.c", true},
   331  		{"1..1", true},
   332  		{"1a.b.c", true},
   333  		{"-1.0.0", true},
   334  		{"1.-1.0", true},
   335  		{"1.0.-1", true},
   336  	}
   337  	for i, tt := range asuTests {
   338  		err := operatorWithAssertServerVer(t, tt.assertVer)
   339  		if err != nil && tt.shouldFail == false {
   340  			t.Fatalf("expected not to fail: %v", err)
   341  		} else if err == nil && tt.shouldFail {
   342  			t.Fatalf("test %s expected to fail but didn't", asuTests[i].assertVer)
   343  		}
   344  	}
   345  }
   346  
   347  func testOperatorWithOperatorServiceURL(t *testing.T, u string) error {
   348  	kp := createOperatorNKey(t)
   349  	pk := publicKey(kp, t)
   350  	oc := NewOperatorClaims(pk)
   351  	oc.OperatorServiceURLs.Add(u)
   352  
   353  	s, err := oc.Encode(kp)
   354  	if err != nil {
   355  		return err
   356  	}
   357  	oc, err = DecodeOperatorClaims(s)
   358  	if err != nil {
   359  		t.Fatal(err)
   360  	}
   361  	if u != "" {
   362  		AssertEquals(oc.OperatorServiceURLs[0], u, t)
   363  	}
   364  	vr := ValidationResults{}
   365  	oc.Validate(&vr)
   366  	if !vr.IsEmpty() {
   367  		errs := vr.Errors()
   368  		return errs[0]
   369  	}
   370  	return nil
   371  }
   372  
   373  func Test_OperatorServiceURL(t *testing.T) {
   374  	var asuTests = []struct {
   375  		u          string
   376  		shouldFail bool
   377  	}{
   378  		{"", false},
   379  		{"HTTP://foo.bar.com", true},
   380  		{"http://foo.bar.com/foo/bar", true},
   381  		{"nats://user:pass@foo.bar.com", true},
   382  		{"NATS://user:pass@foo.bar.com", true},
   383  		{"NATS://user@foo.bar.com", true},
   384  		{"nats://foo.bar.com/path", true},
   385  		{"tls://foo.bar.com/path", true},
   386  		{"/hello", true},
   387  		{"NATS://foo.bar.com", false},
   388  		{"TLS://foo.bar.com", false},
   389  		{"nats://foo.bar.com", false},
   390  		{"tls://foo.bar.com", false},
   391  	}
   392  
   393  	for i, tt := range asuTests {
   394  		err := testOperatorWithOperatorServiceURL(t, tt.u)
   395  		if err != nil && tt.shouldFail == false {
   396  			t.Fatalf("expected not to fail: %v", err)
   397  		} else if err == nil && tt.shouldFail {
   398  			t.Fatalf("test %s expected to fail but didn't", asuTests[i].u)
   399  		}
   400  	}
   401  
   402  	// now test all of them in a single jwt
   403  	kp := createOperatorNKey(t)
   404  	pk := publicKey(kp, t)
   405  	oc := NewOperatorClaims(pk)
   406  
   407  	encoded := 0
   408  	shouldFail := 0
   409  	for _, v := range asuTests {
   410  		oc.OperatorServiceURLs.Add(v.u)
   411  		// list won't encode empty strings
   412  		if v.u != "" {
   413  			encoded++
   414  		}
   415  		if v.shouldFail {
   416  			shouldFail++
   417  		}
   418  	}
   419  
   420  	s, err := oc.Encode(kp)
   421  	if err != nil {
   422  		t.Fatal(err)
   423  	}
   424  	oc, err = DecodeOperatorClaims(s)
   425  	if err != nil {
   426  		t.Fatal(err)
   427  	}
   428  
   429  	AssertEquals(len(oc.OperatorServiceURLs), encoded, t)
   430  
   431  	vr := ValidationResults{}
   432  	oc.Validate(&vr)
   433  	if vr.IsEmpty() {
   434  		t.Fatal("should have had errors")
   435  	}
   436  
   437  	errs := vr.Errors()
   438  	AssertEquals(len(errs), shouldFail, t)
   439  }
   440  
   441  func TestTags(t *testing.T) {
   442  	okp := createOperatorNKey(t)
   443  	opk := publicKey(okp, t)
   444  
   445  	oc := NewOperatorClaims(opk)
   446  	oc.Tags.Add("one")
   447  	oc.Tags.Add("one") // duplicated tags should be ignored
   448  	oc.Tags.Add("TWO") // should become lower case
   449  	oc.Tags.Add("three")
   450  
   451  	oJwt := encode(oc, okp, t)
   452  
   453  	oc2, err := DecodeOperatorClaims(oJwt)
   454  	if err != nil {
   455  		t.Fatal(err)
   456  	}
   457  	if len(oc2.GenericFields.Tags) != 3 {
   458  		t.Fatal("expected 3 tags")
   459  	}
   460  	for _, v := range oc.GenericFields.Tags {
   461  		AssertFalse(v == "TWO", t)
   462  	}
   463  
   464  	AssertTrue(oc.GenericFields.Tags.Contains("one"), t)
   465  	AssertTrue(oc.GenericFields.Tags.Contains("two"), t)
   466  	AssertTrue(oc.GenericFields.Tags.Contains("three"), t)
   467  }
   468  
   469  func TestOperatorClaims_GetTags(t *testing.T) {
   470  	okp := createOperatorNKey(t)
   471  	opk := publicKey(okp, t)
   472  
   473  	oc := NewOperatorClaims(opk)
   474  	oc.Operator.Tags.Add("foo", "bar")
   475  	tags := oc.GetTags()
   476  	if len(tags) != 2 {
   477  		t.Fatal("expected 2 tags")
   478  	}
   479  	if tags[0] != "foo" {
   480  		t.Fatal("expected tag foo")
   481  	}
   482  	if tags[1] != "bar" {
   483  		t.Fatal("expected tag bar")
   484  	}
   485  
   486  	token, err := oc.Encode(okp)
   487  	if err != nil {
   488  		t.Fatal("error encoding")
   489  	}
   490  	oc, err = DecodeOperatorClaims(token)
   491  	if err != nil {
   492  		t.Fatal("error decoding")
   493  	}
   494  	tags = oc.GetTags()
   495  	if len(tags) != 2 {
   496  		t.Fatal("expected 2 tags")
   497  	}
   498  	if tags[0] != "foo" {
   499  		t.Fatal("expected tag foo")
   500  	}
   501  	if tags[1] != "bar" {
   502  		t.Fatal("expected tag bar")
   503  	}
   504  }