github.com/greenpau/go-authcrunch@v1.0.50/pkg/idp/oauth/provider_test.go (about)

     1  // Copyright 2022 Paul Greenberg greenpau@outlook.com
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package oauth
    16  
    17  import (
    18  	"crypto/rand"
    19  	"crypto/rsa"
    20  	"encoding/json"
    21  	"fmt"
    22  	"github.com/greenpau/go-authcrunch/internal/tests"
    23  	"github.com/greenpau/go-authcrunch/pkg/errors"
    24  	logutil "github.com/greenpau/go-authcrunch/pkg/util/log"
    25  	"go.uber.org/zap"
    26  	"net/http"
    27  	"net/http/httptest"
    28  	"net/url"
    29  	"testing"
    30  )
    31  
    32  func TestNewIdentityProvider(t *testing.T) {
    33  	// Generate JWKS keys from RSA key-pairs.
    34  	pk1, err := rsa.GenerateKey(rand.Reader, 4096)
    35  	if err != nil {
    36  		t.Fatal(err)
    37  	}
    38  	pk2, err := rsa.GenerateKey(rand.Reader, 4096)
    39  	if err != nil {
    40  		t.Fatal(err)
    41  	}
    42  
    43  	jpk1, err := NewJwksKeyFromRSAPrivateKey(pk1)
    44  	if err != nil {
    45  		t.Fatal(err)
    46  	}
    47  
    48  	jpk2, err := NewJwksKeyFromRSAPrivateKey(pk2)
    49  	if err != nil {
    50  		t.Fatal(err)
    51  	}
    52  
    53  	jwksKeys := []*JwksKey{jpk1, jpk2}
    54  
    55  	// Initialize HTTP server.
    56  	ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    57  		resp := make(map[string]interface{})
    58  		switch r.URL.Path {
    59  		case "/oauth/.well-known/openid-configuration":
    60  			resp["authorization_endpoint"] = "https://" + r.Host + "/oauth/authorize"
    61  			resp["token_endpoint"] = "https://" + r.Host + "/oauth/access_token"
    62  			resp["jwks_uri"] = "https://" + r.Host + "/oauth/jwks.json"
    63  		case "/oauth/jwks.json":
    64  			resp["keys"] = jwksKeys
    65  		default:
    66  			t.Fatalf("unsupported path: %v", r.URL.Path)
    67  		}
    68  
    69  		b, err := json.Marshal(resp)
    70  		if err != nil {
    71  			t.Fatalf("failed to marshal %T: %v", resp, err)
    72  		}
    73  
    74  		fmt.Fprintln(w, string(b))
    75  	}))
    76  	defer ts.Close()
    77  
    78  	tsURL, _ := url.Parse(ts.URL)
    79  	// t.Logf("Server: %s", ts.URL)
    80  
    81  	testcases := []struct {
    82  		name      string
    83  		config    *Config
    84  		logger    *zap.Logger
    85  		want      map[string]interface{}
    86  		shouldErr bool
    87  		errPhase  string
    88  		err       error
    89  	}{
    90  		{
    91  			name: "generic oauth provider",
    92  			config: &Config{
    93  				Name:                  "contoso",
    94  				Realm:                 "contoso",
    95  				Driver:                "generic",
    96  				ClientID:              "foo",
    97  				ClientSecret:          "bar",
    98  				BaseAuthURL:           ts.URL + "/oauth",
    99  				MetadataURL:           ts.URL + "/oauth/.well-known/openid-configuration",
   100  				TLSInsecureSkipVerify: true,
   101  			},
   102  			logger: logutil.NewLogger(),
   103  			want: map[string]interface{}{
   104  				"kind":  "oauth",
   105  				"name":  "contoso",
   106  				"realm": "contoso",
   107  				"config": map[string]interface{}{
   108  					"base_auth_url":            ts.URL + "/oauth",
   109  					"client_id":                "foo",
   110  					"client_secret":            "bar",
   111  					"driver":                   "generic",
   112  					"identity_token_name":      "id_token",
   113  					"metadata_url":             ts.URL + "/oauth/.well-known/openid-configuration",
   114  					"name":                     "contoso",
   115  					"realm":                    "contoso",
   116  					"required_token_fields":    []interface{}{"access_token", "id_token"},
   117  					"response_type":            []interface{}{"code"},
   118  					"scopes":                   []interface{}{"openid", "email", "profile"},
   119  					"server_name":              tsURL.Host,
   120  					"tls_insecure_skip_verify": bool(true),
   121  					"login_icon": map[string]interface{}{
   122  						"background_color": string("#324960"),
   123  						"class_name":       string("lab la-codepen la-2x"),
   124  						"color":            string("white"),
   125  						"text_color":       string("#37474f"),
   126  					},
   127  				},
   128  			},
   129  		},
   130  		{
   131  			name: "generic oauth provider with static jwks keys",
   132  			config: &Config{
   133  				Name:                "contoso",
   134  				Realm:               "contoso",
   135  				Driver:              "generic",
   136  				ClientID:            "foo",
   137  				ClientSecret:        "bar",
   138  				BaseAuthURL:         "https://localhost/oauth",
   139  				ResponseType:        []string{"code"},
   140  				RequiredTokenFields: []string{"access_token"},
   141  				AuthorizationURL:    "https://localhost/oauth/authorize",
   142  				TokenURL:            "https://localhost/oauth/access_token",
   143  				JwksKeys: map[string]string{
   144  					"87329db33bf": "../../../testdata/oauth/87329db33bf_pub.pem",
   145  				},
   146  				KeyVerificationDisabled: true,
   147  				TLSInsecureSkipVerify:   true,
   148  			},
   149  			logger: logutil.NewLogger(),
   150  			want: map[string]interface{}{
   151  				"kind":  "oauth",
   152  				"name":  "contoso",
   153  				"realm": "contoso",
   154  				"config": map[string]interface{}{
   155  					"base_auth_url":             "https://localhost/oauth",
   156  					"token_url":                 "https://localhost/oauth/access_token",
   157  					"authorization_url":         "https://localhost/oauth/authorize",
   158  					"client_id":                 "foo",
   159  					"client_secret":             "bar",
   160  					"driver":                    "generic",
   161  					"identity_token_name":       "id_token",
   162  					"name":                      "contoso",
   163  					"realm":                     "contoso",
   164  					"required_token_fields":     []interface{}{"access_token"},
   165  					"response_type":             []interface{}{"code"},
   166  					"scopes":                    []interface{}{"openid", "email", "profile"},
   167  					"server_name":               "localhost",
   168  					"tls_insecure_skip_verify":  true,
   169  					"key_verification_disabled": true,
   170  					"jwks_keys": map[string]interface{}{
   171  						"87329db33bf": "../../../testdata/oauth/87329db33bf_pub.pem",
   172  					},
   173  					"login_icon": map[string]interface{}{
   174  						"background_color": string("#324960"),
   175  						"class_name":       string("lab la-codepen la-2x"),
   176  						"color":            string("white"),
   177  						"text_color":       string("#37474f"),
   178  					},
   179  				},
   180  			},
   181  		},
   182  		{
   183  			name: "test nil logger",
   184  			config: &Config{
   185  				Realm: "azure",
   186  			},
   187  			shouldErr: true,
   188  			errPhase:  "initialize",
   189  			err:       errors.ErrIdentityProviderConfigureLoggerNotFound,
   190  		},
   191  		{
   192  			name: "test invalid config",
   193  			config: &Config{
   194  				Realm: "azure",
   195  			},
   196  			logger:    logutil.NewLogger(),
   197  			shouldErr: true,
   198  			errPhase:  "initialize",
   199  			err:       errors.ErrIdentityProviderConfigureNameEmpty,
   200  		},
   201  	}
   202  	for _, tc := range testcases {
   203  		t.Run(tc.name, func(t *testing.T) {
   204  			got := make(map[string]interface{})
   205  			msgs := []string{fmt.Sprintf("test name: %s", tc.name)}
   206  			msgs = append(msgs, fmt.Sprintf("config:\n%v", tc.config))
   207  
   208  			prv, err := NewIdentityProvider(tc.config, tc.logger)
   209  			if tc.errPhase == "initialize" {
   210  				if tests.EvalErrWithLog(t, err, "NewIdentityProvider", tc.shouldErr, tc.err, msgs) {
   211  					return
   212  				}
   213  			} else {
   214  				if tests.EvalErrWithLog(t, err, "NewIdentityProvider", false, nil, msgs) {
   215  					return
   216  				}
   217  			}
   218  
   219  			err = prv.Configure()
   220  			if tc.errPhase == "configure" {
   221  				if tests.EvalErrWithLog(t, err, "IdentityProvider.Configure", tc.shouldErr, tc.err, msgs) {
   222  					return
   223  				}
   224  			} else {
   225  				if tests.EvalErrWithLog(t, err, "IdentityProvider.Configure", false, nil, msgs) {
   226  					return
   227  				}
   228  			}
   229  
   230  			got["name"] = prv.GetName()
   231  			got["realm"] = prv.GetRealm()
   232  			got["kind"] = prv.GetKind()
   233  			got["config"] = prv.GetConfig()
   234  
   235  			tests.EvalObjectsWithLog(t, "IdentityProvider", tc.want, got, msgs)
   236  		})
   237  	}
   238  }