github.com/avenga/couper@v1.12.2/accesscontrol/saml2_test.go (about)

     1  package accesscontrol_test
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/base64"
     6  	"encoding/xml"
     7  	"net/http"
     8  	"net/http/httptest"
     9  	"net/url"
    10  	"strings"
    11  	"testing"
    12  
    13  	"github.com/google/go-cmp/cmp"
    14  	saml2 "github.com/russellhaering/gosaml2"
    15  	"github.com/russellhaering/gosaml2/types"
    16  
    17  	ac "github.com/avenga/couper/accesscontrol"
    18  	"github.com/avenga/couper/config/reader"
    19  	"github.com/avenga/couper/errors"
    20  	"github.com/avenga/couper/internal/test"
    21  )
    22  
    23  func Test_NewSAML2ACS(t *testing.T) {
    24  	helper := test.New(t)
    25  
    26  	type testCase struct {
    27  		metadataFile, acsURL, spEntityID string
    28  		arrayAttributes                  []string
    29  		expErrMsg                        string
    30  		shouldFail                       bool
    31  	}
    32  
    33  	for _, tc := range []testCase{
    34  		{"testdata/idp-metadata.xml", "http://www.examle.org/saml/acs", "my-sp-entity-id", []string{}, "", false},
    35  		{"not-there.xml", "http://www.examle.org/saml/acs", "my-sp-entity-id", []string{}, "not-there.xml: no such file or directory", true},
    36  	} {
    37  		metadata, err := reader.ReadFromAttrFile("saml2", "", tc.metadataFile)
    38  		if err != nil {
    39  			readErr := err.(errors.GoError)
    40  			if tc.shouldFail {
    41  				if !strings.HasSuffix(readErr.LogError(), tc.expErrMsg) {
    42  					t.Errorf("Want: %q, got: %q", tc.expErrMsg, readErr.LogError())
    43  				}
    44  				continue
    45  			}
    46  			t.Error(err)
    47  			continue
    48  		}
    49  
    50  		_, err = ac.NewSAML2ACS(metadata, "test", tc.acsURL, tc.spEntityID, tc.arrayAttributes)
    51  		helper.Must(err)
    52  	}
    53  }
    54  
    55  func Test_SAML2ACS_Validate(t *testing.T) {
    56  	metadata, err := reader.ReadFromAttrFile("saml2", "", "testdata/idp-metadata.xml")
    57  	if err != nil || metadata == nil {
    58  		t.Fatal("Expected a metadata object")
    59  	}
    60  	sa, err := ac.NewSAML2ACS(metadata, "test", "http://www.examle.org/saml/acs", "my-sp-entity-id", []string{"memberOf"})
    61  	if err != nil || sa == nil {
    62  		t.Fatal("Expected a saml acs object")
    63  	}
    64  
    65  	type testCase struct {
    66  		name    string
    67  		payload string
    68  		wantErr bool
    69  	}
    70  	for _, tc := range []testCase{
    71  		{
    72  			"invalid body",
    73  			"1qp4ghn1pin",
    74  			true,
    75  		},
    76  		{
    77  			"invalid SAMLResponse",
    78  			"SAMLResponse=1qp4ghn1pin",
    79  			true,
    80  		},
    81  		{
    82  			"invalid url-encoded SAMLResponse",
    83  			"SAMLResponse=" + url.QueryEscape("abcde"),
    84  			true,
    85  		},
    86  		{
    87  			"invalid base64- and url-encoded SAMLResponse",
    88  			"SAMLResponse=" + url.QueryEscape(base64.StdEncoding.EncodeToString([]byte("abcde"))),
    89  			true,
    90  		},
    91  		// TODO how to make test for valid SAMLResponse?
    92  	} {
    93  		t.Run(tc.name, func(subT *testing.T) {
    94  			req := httptest.NewRequest(http.MethodPost, "/", bytes.NewBufferString(tc.payload))
    95  			if err := sa.Validate(req); (err != nil) != tc.wantErr {
    96  				subT.Errorf("%s: Validate() error = %v, wantErr %v", tc.name, err, tc.wantErr)
    97  			}
    98  		})
    99  	}
   100  }
   101  
   102  func Test_SAML2ACS_ValidateAssertionInfo(t *testing.T) {
   103  	metadata, err := reader.ReadFromAttrFile("saml2", "", "testdata/idp-metadata.xml")
   104  	if err != nil {
   105  		t.Fatal(err)
   106  	}
   107  	sa, err := ac.NewSAML2ACS(metadata, "test", "http://www.examle.org/saml/acs", "my-sp-entity-id", []string{"memberOf"})
   108  	if err != nil || sa == nil {
   109  		t.Fatal("Expected a saml acs object")
   110  	}
   111  
   112  	type testCase struct {
   113  		name          string
   114  		assertionInfo *saml2.AssertionInfo
   115  		wantErr       bool
   116  	}
   117  	for _, tc := range []testCase{
   118  		{
   119  			"assertion mismatch",
   120  			&saml2.AssertionInfo{
   121  				WarningInfo: &saml2.WarningInfo{
   122  					NotInAudience: true,
   123  				},
   124  			},
   125  			true,
   126  		},
   127  		{
   128  			"assertion match",
   129  			&saml2.AssertionInfo{
   130  				WarningInfo: &saml2.WarningInfo{},
   131  			},
   132  			false,
   133  		},
   134  	} {
   135  		t.Run(tc.name, func(subT *testing.T) {
   136  			if err = sa.ValidateAssertionInfo(tc.assertionInfo); (err != nil) != tc.wantErr {
   137  				subT.Errorf("%s: ValidateAssertionInfo() error = %v, wantErr %v", tc.name, err, tc.wantErr)
   138  			}
   139  		})
   140  	}
   141  }
   142  
   143  func Test_SAML2ACS_GetAssertionData(t *testing.T) {
   144  	metadata, err := reader.ReadFromAttrFile("saml2", "", "testdata/idp-metadata.xml")
   145  	if err != nil || metadata == nil {
   146  		t.Fatal("Expected a metadata object")
   147  	}
   148  	sa, err := ac.NewSAML2ACS(metadata, "test", "http://www.examle.org/saml/acs", "my-sp-entity-id", []string{"memberOf"})
   149  	if err != nil || sa == nil {
   150  		t.Fatal("Expected a saml acs object")
   151  	}
   152  
   153  	valuesWith2MemberOf := saml2.Values{
   154  		"displayName": types.Attribute{
   155  			Name: "displayName",
   156  			Values: []types.AttributeValue{
   157  				{
   158  					Value: "John Doe",
   159  				},
   160  				{
   161  					Value: "Jane Doe",
   162  				},
   163  			},
   164  		},
   165  		"memberOf": types.Attribute{
   166  			Name: "memberOf",
   167  			Values: []types.AttributeValue{
   168  				{
   169  					Value: "group1",
   170  				},
   171  				{
   172  					Value: "group2",
   173  				},
   174  			},
   175  		},
   176  	}
   177  	valuesWith1MemberOf := saml2.Values{
   178  		"displayName": types.Attribute{
   179  			Name: "displayName",
   180  			Values: []types.AttributeValue{
   181  				{
   182  					Value: "Jane Doe",
   183  				},
   184  			},
   185  		},
   186  		"memberOf": types.Attribute{
   187  			Name: "memberOf",
   188  			Values: []types.AttributeValue{
   189  				{
   190  					Value: "group1",
   191  				},
   192  			},
   193  		},
   194  	}
   195  	valuesEmptyMemberOf := saml2.Values{
   196  		"displayName": types.Attribute{
   197  			Name: "displayName",
   198  			Values: []types.AttributeValue{
   199  				{
   200  					Value: "Jane Doe",
   201  				},
   202  			},
   203  		},
   204  		"memberOf": types.Attribute{
   205  			Name:   "memberOf",
   206  			Values: []types.AttributeValue{},
   207  		},
   208  	}
   209  	valuesMissingMemberOf := saml2.Values{
   210  		"displayName": types.Attribute{
   211  			Name: "displayName",
   212  			Values: []types.AttributeValue{
   213  				{
   214  					Value: "Jane Doe",
   215  				},
   216  			},
   217  		},
   218  	}
   219  	var authnStatement types.AuthnStatement
   220  	err = xml.Unmarshal([]byte(`<AuthnStatement xmlns="urn:oasis:names:tc:SAML:2.0:assertion" SessionNotOnOrAfter="2020-11-13T17:06:00Z"/>`), &authnStatement)
   221  	if err != nil {
   222  		t.Fatal(err)
   223  	}
   224  
   225  	type testCase struct {
   226  		name          string
   227  		assertionInfo *saml2.AssertionInfo
   228  		want          map[string]interface{}
   229  	}
   230  	for _, tc := range []testCase{
   231  		{
   232  			"without exp, with 2 memberOf",
   233  			&saml2.AssertionInfo{
   234  				NameID: "abc12345",
   235  				Values: valuesWith2MemberOf,
   236  			},
   237  			map[string]interface{}{
   238  				"sub": "abc12345",
   239  				"attributes": map[string]interface{}{
   240  					"displayName": "Jane Doe",
   241  					"memberOf": []string{
   242  						"group1",
   243  						"group2",
   244  					},
   245  				},
   246  			},
   247  		},
   248  		{
   249  			"without exp, with 1 memberOf",
   250  			&saml2.AssertionInfo{
   251  				NameID: "abc12345",
   252  				Values: valuesWith1MemberOf,
   253  			},
   254  			map[string]interface{}{
   255  				"sub": "abc12345",
   256  				"attributes": map[string]interface{}{
   257  					"displayName": "Jane Doe",
   258  					"memberOf": []string{
   259  						"group1",
   260  					},
   261  				},
   262  			},
   263  		},
   264  		{
   265  			"with exp, with memberOf",
   266  			&saml2.AssertionInfo{
   267  				NameID:              "abc12345",
   268  				SessionNotOnOrAfter: authnStatement.SessionNotOnOrAfter,
   269  				Values:              valuesWith2MemberOf,
   270  			},
   271  			map[string]interface{}{
   272  				"sub": "abc12345",
   273  				"exp": int64(1605287160),
   274  				"attributes": map[string]interface{}{
   275  					"displayName": "Jane Doe",
   276  					"memberOf": []string{
   277  						"group1",
   278  						"group2",
   279  					},
   280  				},
   281  			},
   282  		},
   283  		{
   284  			"without exp, empty memberOf",
   285  			&saml2.AssertionInfo{
   286  				NameID: "abc12345",
   287  				Values: valuesEmptyMemberOf,
   288  			},
   289  			map[string]interface{}{
   290  				"sub": "abc12345",
   291  				"attributes": map[string]interface{}{
   292  					"displayName": "Jane Doe",
   293  					"memberOf":    []string{},
   294  				},
   295  			},
   296  		},
   297  		{
   298  			"without exp, without memberOf",
   299  			&saml2.AssertionInfo{
   300  				NameID: "abc12345",
   301  				Values: valuesMissingMemberOf,
   302  			},
   303  			map[string]interface{}{
   304  				"sub": "abc12345",
   305  				"attributes": map[string]interface{}{
   306  					"displayName": "Jane Doe",
   307  					"memberOf":    []string{},
   308  				},
   309  			},
   310  		},
   311  	} {
   312  		t.Run(tc.name, func(subT *testing.T) {
   313  			assertionData := sa.GetAssertionData(tc.assertionInfo)
   314  			if !cmp.Equal(tc.want, assertionData) {
   315  				subT.Errorf(cmp.Diff(tc.want, assertionData))
   316  			}
   317  		})
   318  	}
   319  }