github.com/greenpau/go-authcrunch@v1.0.50/pkg/authz/validator/sources_test.go (about)

     1  // Copyright 2022 Paul Greenberg greenpau@outlook.com
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package validator
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"net/http"
    21  	"net/http/httptest"
    22  	"testing"
    23  	"time"
    24  
    25  	"github.com/greenpau/go-authcrunch/internal/tests"
    26  	"github.com/greenpau/go-authcrunch/internal/testutils"
    27  	"github.com/greenpau/go-authcrunch/pkg/authz/options"
    28  	"github.com/greenpau/go-authcrunch/pkg/errors"
    29  	"github.com/greenpau/go-authcrunch/pkg/requests"
    30  )
    31  
    32  func TestAuthorizationSources(t *testing.T) {
    33  	var testcases = []struct {
    34  		name                         string
    35  		allowedTokenNames            []string
    36  		allowedTokenSources          []string
    37  		enableQueryViolations        bool
    38  		enableCookieViolations       bool
    39  		enableHeaderViolations       bool
    40  		enableBearerHeaderViolations bool
    41  		// The name of the token.
    42  		entries   []*testutils.InjectedTestToken
    43  		want      map[string]interface{}
    44  		shouldErr bool
    45  		err       error
    46  	}{
    47  		{
    48  			name: "default token sources and names with auth header claim injection",
    49  			entries: []*testutils.InjectedTestToken{
    50  				testutils.NewInjectedTestToken("access_token", tokenSourceHeader, `"name": "foo",`),
    51  			},
    52  			want: map[string]interface{}{
    53  				"token_name": "access_token",
    54  				"claim_name": "foo",
    55  			},
    56  			shouldErr: false,
    57  		},
    58  		{
    59  			name: "default token sources and names with cookie claim injection",
    60  			entries: []*testutils.InjectedTestToken{
    61  				testutils.NewInjectedTestToken("access_token", tokenSourceCookie, `"name": "foo",`),
    62  			},
    63  			want: map[string]interface{}{
    64  				"token_name": "access_token",
    65  				"claim_name": "foo",
    66  			},
    67  			shouldErr: false,
    68  		},
    69  		{
    70  			name: "default token sources and names with query parameter claim injection",
    71  			entries: []*testutils.InjectedTestToken{
    72  				testutils.NewInjectedTestToken("", tokenSourceQuery, `"name": "foo",`),
    73  			},
    74  			want: map[string]interface{}{
    75  				"token_name": "access_token",
    76  				"claim_name": "foo",
    77  			},
    78  			shouldErr: false,
    79  		},
    80  		{
    81  			name: "default token source priorities, same token name, different entries injected in query parameter and auth header",
    82  			entries: []*testutils.InjectedTestToken{
    83  				testutils.NewInjectedTestToken("access_token", tokenSourceHeader, `"name": "foo",`),
    84  				testutils.NewInjectedTestToken("access_token", tokenSourceQuery, `"name": "bar",`),
    85  			},
    86  			want: map[string]interface{}{
    87  				"token_name": "access_token",
    88  				"claim_name": "foo",
    89  			},
    90  			shouldErr: false,
    91  		},
    92  		{
    93  			name:                "custom token source priorities, same token name, different entries injected in query parameter and auth header",
    94  			allowedTokenSources: []string{tokenSourceQuery, tokenSourceCookie, tokenSourceHeader},
    95  			entries: []*testutils.InjectedTestToken{
    96  				testutils.NewInjectedTestToken("access_token", tokenSourceHeader, `"name": "foo",`),
    97  				testutils.NewInjectedTestToken("access_token", tokenSourceQuery, `"name": "bar",`),
    98  			},
    99  			want: map[string]interface{}{
   100  				"token_name": "access_token",
   101  				"claim_name": "bar",
   102  			},
   103  			shouldErr: false,
   104  		},
   105  		{
   106  			name:              "default token source priorities, different token name, different entries injected in query parameter and auth header",
   107  			allowedTokenNames: []string{"jwt_access_token"},
   108  			entries: []*testutils.InjectedTestToken{
   109  				testutils.NewInjectedTestToken("", tokenSourceHeader, `"name": "foo",`),
   110  				testutils.NewInjectedTestToken("jwt_access_token", tokenSourceQuery, `"name": "bar",`),
   111  			},
   112  			want: map[string]interface{}{
   113  				"token_name": "jwt_access_token",
   114  				"claim_name": "bar",
   115  			},
   116  			shouldErr: false,
   117  		},
   118  		{
   119  			name: "default token sources and names with custom token name injection",
   120  			entries: []*testutils.InjectedTestToken{
   121  				testutils.NewInjectedTestToken("foobar", tokenSourceHeader, `"name": "foo",`),
   122  			},
   123  			shouldErr: true,
   124  			err:       errors.ErrNoTokenFound,
   125  		},
   126  		{
   127  			name:              "custom token names with standard token name injection",
   128  			allowedTokenNames: []string{"foobar_token"},
   129  			entries: []*testutils.InjectedTestToken{
   130  				testutils.NewInjectedTestToken("access_token", tokenSourceHeader, `"name": "foo",`),
   131  			},
   132  			shouldErr: true,
   133  			err:       errors.ErrNoTokenFound,
   134  		},
   135  		{
   136  			name:                "cookie token source with auth header token injection",
   137  			allowedTokenSources: []string{tokenSourceCookie},
   138  			entries: []*testutils.InjectedTestToken{
   139  				testutils.NewInjectedTestToken("access_token", tokenSourceHeader, `"name": "foo",`),
   140  			},
   141  			shouldErr: true,
   142  			err:       errors.ErrNoTokenFound,
   143  		},
   144  		{
   145  			name:                  "query paramater token source violations",
   146  			enableQueryViolations: true,
   147  			shouldErr:             true,
   148  			err:                   errors.ErrNoTokenFound,
   149  		},
   150  		{
   151  			name:                   "cookie token source violations",
   152  			enableCookieViolations: true,
   153  			shouldErr:              true,
   154  			err:                    errors.ErrNoTokenFound,
   155  		},
   156  		{
   157  			name:                   "header token source violations",
   158  			enableHeaderViolations: true,
   159  			shouldErr:              true,
   160  			err:                    errors.ErrNoTokenFound,
   161  		},
   162  		{
   163  			name:                         "bearer header token source violations",
   164  			enableBearerHeaderViolations: true,
   165  			shouldErr:                    true,
   166  			err:                          errors.ErrNoTokenFound,
   167  		},
   168  	}
   169  
   170  	for _, tc := range testcases {
   171  		t.Run(tc.name, func(t *testing.T) {
   172  			ctx := context.Background()
   173  			ks := testutils.NewTestCryptoKeyStore()
   174  			keys := ks.GetKeys()
   175  			signingKey := keys[0]
   176  			opts := options.NewTokenValidatorOptions()
   177  			if tc.enableBearerHeaderViolations {
   178  				opts.ValidateBearerHeader = true
   179  			}
   180  
   181  			validator := NewTokenValidator()
   182  			accessList := testutils.NewTestGuestAccessList()
   183  
   184  			if err := validator.Configure(ctx, keys, accessList, opts); err != nil {
   185  				t.Fatal(err)
   186  			}
   187  
   188  			if len(tc.allowedTokenSources) > 0 {
   189  				if err := validator.SetSourcePriority(tc.allowedTokenSources); err != nil {
   190  					t.Fatal(err)
   191  				}
   192  			}
   193  
   194  			if len(tc.allowedTokenNames) > 0 {
   195  				if err := validator.setAllowedTokenNames(tc.allowedTokenNames); err != nil {
   196  					t.Fatal(err)
   197  				}
   198  			}
   199  
   200  			handler := func(w http.ResponseWriter, r *http.Request) {
   201  				ctx := context.Background()
   202  				var msgs []string
   203  				msgs = append(msgs, fmt.Sprintf("test name: %s", tc.name))
   204  				if len(tc.allowedTokenNames) > 0 {
   205  					msgs = append(msgs, fmt.Sprintf("allowed token names: %s", tc.allowedTokenNames))
   206  				}
   207  				for i, tkn := range tc.entries {
   208  					msgs = append(msgs, fmt.Sprintf("token %d, name: %s, location: %s", i, tkn.Name, tkn.Location))
   209  				}
   210  				ar := requests.NewAuthorizationRequest()
   211  				ar.ID = "TEST_REQUEST_ID"
   212  				ar.SessionID = "TEST_SESSION_ID"
   213  				usr, err := validator.Authorize(ctx, r, ar)
   214  				if tests.EvalErrWithLog(t, err, tc.want, tc.shouldErr, tc.err, msgs) {
   215  					return
   216  				}
   217  				got := make(map[string]interface{})
   218  				got["token_name"] = usr.TokenName
   219  				got["claim_name"] = usr.Claims.Name
   220  				tests.EvalObjectsWithLog(t, "response", tc.want, got, msgs)
   221  			}
   222  
   223  			reqURI := "/protected/path"
   224  			if tc.enableQueryViolations {
   225  				reqURI += "?access_token=foobarfoo"
   226  			}
   227  
   228  			req, err := http.NewRequest("GET", reqURI, nil)
   229  			if err != nil {
   230  				t.Fatal(err)
   231  			}
   232  
   233  			if tc.enableCookieViolations {
   234  				req.AddCookie(&http.Cookie{
   235  					Name:    "foobar",
   236  					Value:   "foobar",
   237  					Expires: time.Now().Add(time.Minute * time.Duration(30)),
   238  				})
   239  				req.AddCookie(&http.Cookie{
   240  					Name:    "access_token",
   241  					Value:   "foobar",
   242  					Expires: time.Now().Add(time.Minute * time.Duration(30)),
   243  				})
   244  			}
   245  
   246  			if tc.enableBearerHeaderViolations {
   247  				req.Header.Add("Authorization", "Bearer")
   248  			}
   249  
   250  			if tc.enableHeaderViolations {
   251  				req.Header.Add("Authorization", "access_token")
   252  			}
   253  
   254  			for _, entry := range tc.entries {
   255  				tokenName := entry.Name
   256  				if tokenName == "" {
   257  					tokenName = "access_token"
   258  				}
   259  				if err := signingKey.SignToken("HS512", entry.User); err != nil {
   260  					t.Fatal(err)
   261  				}
   262  				switch entry.Location {
   263  				case tokenSourceCookie:
   264  					req.AddCookie(testutils.GetCookie(tokenName, entry.User.Token, 10))
   265  				case tokenSourceHeader:
   266  					req.Header.Set("Authorization", fmt.Sprintf("%s=%s", tokenName, entry.User.Token))
   267  				case tokenSourceQuery:
   268  					q := req.URL.Query()
   269  					q.Set(tokenName, entry.User.Token)
   270  					req.URL.RawQuery = q.Encode()
   271  				case "":
   272  					t.Fatal("malformed test: token injection location is empty")
   273  				default:
   274  					t.Fatalf("malformed test: token injection location %s is not supported", entry.Location)
   275  				}
   276  			}
   277  
   278  			w := httptest.NewRecorder()
   279  			handler(w, req)
   280  			w.Result()
   281  		})
   282  	}
   283  }