github.com/greenpau/go-authcrunch@v1.1.4/pkg/authz/authenticate_test.go (about)

     1  // Copyright 2022 Paul Greenberg greenpau@outlook.com
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package authz
    16  
    17  import (
    18  	"bytes"
    19  	"crypto/tls"
    20  	"crypto/x509"
    21  	"encoding/json"
    22  	"fmt"
    23  	"github.com/greenpau/go-authcrunch/internal/tests"
    24  	"github.com/greenpau/go-authcrunch/internal/testutils"
    25  	"github.com/greenpau/go-authcrunch/pkg/acl"
    26  	"github.com/greenpau/go-authcrunch/pkg/requests"
    27  	logutil "github.com/greenpau/go-authcrunch/pkg/util/log"
    28  	"io/ioutil"
    29  	"net"
    30  	"net/http"
    31  	"net/http/cookiejar"
    32  	"net/url"
    33  	"strings"
    34  	"time"
    35  
    36  	"net/http/httptest"
    37  	"testing"
    38  )
    39  
    40  type testRequest struct {
    41  	id          string
    42  	roles       []string
    43  	method      string
    44  	path        string
    45  	headers     map[string]string
    46  	query       map[string]string
    47  	contentType string
    48  	token       string
    49  }
    50  
    51  func TestAuthenticate(t *testing.T) {
    52  	logger := logutil.NewLogger()
    53  
    54  	cfg := &PolicyConfig{
    55  		Name:        "mygatekeeper",
    56  		AuthURLPath: "/auth",
    57  		AccessListRules: []*acl.RuleConfiguration{
    58  			{
    59  				Conditions: []string{
    60  					"match roles authp/admin authp/user",
    61  				},
    62  				Action: "allow stop",
    63  			},
    64  		},
    65  		cryptoRawConfigs: []string{"key verify " + testutils.GetSharedKey()},
    66  	}
    67  
    68  	gatekeeper, err := NewGatekeeper(cfg, logger)
    69  	if err != nil {
    70  		t.Fatal(err)
    71  	}
    72  
    73  	var testcases = []struct {
    74  		name      string
    75  		want      map[string]interface{}
    76  		shouldErr bool
    77  		err       error
    78  		disabled  bool
    79  		req       *testRequest
    80  	}{
    81  		{
    82  			name: "admin accesses version with get",
    83  			req: &testRequest{
    84  				roles:  []string{"authp/admin"},
    85  				method: "GET",
    86  				path:   "/version",
    87  			},
    88  			want: map[string]interface{}{
    89  				"response": map[string]interface{}{
    90  					"authorized": true,
    91  				},
    92  				"status_code":  200,
    93  				"content_type": "text/plain; charset=utf-8",
    94  			},
    95  		},
    96  	}
    97  
    98  	// Initialize HTTP server.
    99  	ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   100  		rr := requests.NewAuthorizationRequest()
   101  		err := gatekeeper.Authenticate(w, r, rr)
   102  		resp := make(map[string]interface{})
   103  		if err != nil {
   104  			resp["error"] = err
   105  		}
   106  		resp["response"] = rr.Response
   107  		b, err := json.Marshal(resp)
   108  		if err != nil {
   109  			t.Fatalf("failed to marshal %T: %v", resp, err)
   110  		}
   111  		fmt.Fprintln(w, string(b))
   112  	}))
   113  	defer ts.Close()
   114  
   115  	for _, tc := range testcases {
   116  		t.Run(tc.name, func(t *testing.T) {
   117  			got := make(map[string]interface{})
   118  			if tc.req.method == "" {
   119  				tc.req.method = "GET"
   120  			}
   121  			if tc.disabled {
   122  				return
   123  			}
   124  			msgs := []string{fmt.Sprintf("test name: %s", tc.name)}
   125  			msgs = append(msgs, fmt.Sprintf("HTTP %s %s", tc.req.method, ts.URL+tc.req.path))
   126  
   127  			client := buildClient(t, ts, tc.req)
   128  			if len(tc.req.roles) > 0 {
   129  				msgs = append(msgs, fmt.Sprintf("roles: %s", tc.req.roles))
   130  			}
   131  			if tc.req.token != "" {
   132  				msgs = append(msgs, fmt.Sprintf("token: %s", tc.req.token))
   133  			}
   134  
   135  			req := buildRequest(t, ts, tc.req)
   136  
   137  			resp, err := client.Do(req)
   138  			if tests.EvalErrWithLog(t, err, "response error", tc.shouldErr, tc.err, msgs) {
   139  				return
   140  			}
   141  
   142  			body, err := ioutil.ReadAll(resp.Body)
   143  			resp.Body.Close()
   144  			if err != nil {
   145  				t.Fatal(err)
   146  			}
   147  
   148  			got["status_code"] = resp.StatusCode
   149  			got["content_type"] = resp.Header.Get("Content-Type")
   150  			switch resp.Header.Get("Content-Type") {
   151  			case "image/png":
   152  			default:
   153  				msgs = append(msgs, fmt.Sprintf("response body: %s", body))
   154  			}
   155  
   156  			switch {
   157  			case bytes.HasPrefix(body, []byte(`{`)):
   158  				var decodedResponse map[string]interface{}
   159  				json.Unmarshal(body, &decodedResponse)
   160  				for k, v := range decodedResponse {
   161  					got[k] = v
   162  				}
   163  			default:
   164  				t.Fatalf("detected non-JSON body: %s", strings.Join(msgs, "\n"))
   165  			}
   166  			tests.EvalObjectsWithLog(t, "response body", tc.want, got, msgs)
   167  		})
   168  	}
   169  }
   170  
   171  func buildClient(t *testing.T, ts *httptest.Server, req *testRequest) http.Client {
   172  	cert, err := x509.ParseCertificate(ts.TLS.Certificates[0].Certificate[0])
   173  	if err != nil {
   174  		t.Fatalf("failed extracting server certs: %v", err)
   175  	}
   176  	cp := x509.NewCertPool()
   177  	cp.AddCert(cert)
   178  
   179  	cj, err := cookiejar.New(nil)
   180  	if err != nil {
   181  		t.Fatalf("failed adding cookie jar: %v", err)
   182  	}
   183  
   184  	if len(req.roles) > 0 {
   185  		usr := testutils.NewTestUser()
   186  		usr.SetRolesClaim(req.roles)
   187  
   188  		ks := testutils.NewTestCryptoKeyStore()
   189  		if err := ks.SignToken("access_token", "HS512", usr); err != nil {
   190  			t.Fatalf("Failed to get JWT token for %v: %v", usr.AsMap(), err)
   191  		}
   192  		cookies := []*http.Cookie{
   193  			&http.Cookie{Name: "access_token", Value: usr.Token},
   194  		}
   195  		req.token = usr.Token
   196  
   197  		tsURL, _ := url.Parse(ts.URL)
   198  		cj.SetCookies(tsURL, cookies)
   199  	}
   200  
   201  	return http.Client{
   202  		Jar:     cj,
   203  		Timeout: time.Second * 10,
   204  		Transport: &http.Transport{
   205  			Dial: (&net.Dialer{
   206  				Timeout: 5 * time.Second,
   207  			}).Dial,
   208  			TLSHandshakeTimeout: 5 * time.Second,
   209  			TLSClientConfig: &tls.Config{
   210  				RootCAs: cp,
   211  			},
   212  		},
   213  		CheckRedirect: func(r *http.Request, via []*http.Request) error {
   214  			// Do not follow redirects.
   215  			return http.ErrUseLastResponse
   216  		},
   217  	}
   218  }
   219  
   220  func buildRequest(t *testing.T, ts *httptest.Server, req *testRequest) *http.Request {
   221  	r, err := http.NewRequest(req.method, ts.URL+req.path, nil)
   222  	if err != nil {
   223  		t.Fatal(err)
   224  	}
   225  
   226  	if len(req.headers) > 0 {
   227  		for k, v := range req.headers {
   228  			r.Header.Add(k, v)
   229  		}
   230  	}
   231  
   232  	if len(req.query) > 0 {
   233  		q := r.URL.Query()
   234  		for k, v := range req.query {
   235  			q.Set(k, v)
   236  		}
   237  		r.URL.RawQuery = q.Encode()
   238  	}
   239  	return r
   240  }