github.com/avenga/couper@v1.12.2/eval/lib/saml_test.go (about)

     1  package lib_test
     2  
     3  import (
     4  	"bytes"
     5  	"compress/flate"
     6  	"context"
     7  	"encoding/base64"
     8  	"encoding/xml"
     9  	"io"
    10  	"net/http"
    11  	"net/url"
    12  	"strings"
    13  	"testing"
    14  
    15  	"github.com/zclconf/go-cty/cty"
    16  
    17  	"github.com/avenga/couper/config/configload"
    18  	"github.com/avenga/couper/config/request"
    19  	"github.com/avenga/couper/errors"
    20  	"github.com/avenga/couper/eval"
    21  	"github.com/avenga/couper/eval/lib"
    22  	"github.com/avenga/couper/internal/test"
    23  )
    24  
    25  func Test_SamlSsoURL(t *testing.T) {
    26  	tests := []struct {
    27  		name      string
    28  		hcl       string
    29  		samlLabel string
    30  		wantPfx   string
    31  	}{
    32  		{
    33  			"metadata found",
    34  			`
    35  			server "test" {
    36  			}
    37  			definitions {
    38  				saml "MySAML" {
    39  					idp_metadata_file = "testdata/idp-metadata.xml"
    40  					sp_entity_id = "the-sp"
    41  					sp_acs_url = "https://sp.example.com/saml/acs"
    42  					array_attributes = ["memberOf"]
    43  				}
    44  			}
    45  			`,
    46  			"MySAML",
    47  			"https://idp.example.org/saml/SSOService",
    48  		},
    49  	}
    50  	for _, tt := range tests {
    51  		t.Run(tt.name, func(subT *testing.T) {
    52  			h := test.New(subT)
    53  			cf, err := configload.LoadBytes([]byte(tt.hcl), "couper.hcl")
    54  			if err != nil {
    55  				h.Must(err)
    56  			}
    57  
    58  			evalContext := cf.Context.Value(request.ContextType).(*eval.Context)
    59  			req, err := http.NewRequest(http.MethodGet, "https://www.example.com/foo", nil)
    60  			h.Must(err)
    61  			evalContext = evalContext.WithClientRequest(req)
    62  
    63  			ssoURL, err := evalContext.HCLContext().Functions[lib.FnSamlSsoURL].Call([]cty.Value{cty.StringVal(tt.samlLabel)})
    64  			h.Must(err)
    65  
    66  			if !strings.HasPrefix(ssoURL.AsString(), tt.wantPfx) {
    67  				subT.Errorf("Expected to start with %q, got: %#v", tt.wantPfx, ssoURL.AsString())
    68  			}
    69  
    70  			u, err := url.Parse(ssoURL.AsString())
    71  			h.Must(err)
    72  
    73  			q := u.Query()
    74  			samlRequest := q.Get("SAMLRequest")
    75  			if samlRequest == "" {
    76  				subT.Fatal("Expected SAMLRequest query param")
    77  			}
    78  
    79  			b64Decoded, err := base64.StdEncoding.DecodeString(samlRequest)
    80  			h.Must(err)
    81  
    82  			fr := flate.NewReader(bytes.NewReader(b64Decoded))
    83  			deflated, err := io.ReadAll(fr)
    84  			h.Must(err)
    85  
    86  			var x interface{}
    87  			err = xml.Unmarshal(deflated, &x)
    88  			h.Must(err)
    89  		})
    90  	}
    91  }
    92  
    93  func TestSamlConfigError(t *testing.T) {
    94  	tests := []struct {
    95  		name    string
    96  		config  string
    97  		label   string
    98  		wantErr string
    99  	}{
   100  		{
   101  			"missing referenced saml IdP metadata",
   102  			`
   103  			server {}
   104  			definitions {
   105  			  saml "MySAML" {
   106  			    idp_metadata_file = "/not/there"
   107  			    sp_entity_id = "the-sp"
   108  			    sp_acs_url = "https://sp.example.com/saml/acs"
   109  			  }
   110  			}
   111  			`,
   112  			"MyLabel",
   113  			"configuration error: MySAML: saml2 idp_metadata_file: read error: open /not/there: no such file or directory",
   114  		},
   115  	}
   116  
   117  	for _, tt := range tests {
   118  		t.Run(tt.name, func(subT *testing.T) {
   119  			_, err := configload.LoadBytes([]byte(tt.config), "test.hcl")
   120  			if err == nil {
   121  				subT.Error("expected an error, got nothing")
   122  				return
   123  			}
   124  			gErr := err.(errors.GoError)
   125  			if gErr.LogError() != tt.wantErr {
   126  				subT.Errorf("\nWant:\t%q\nGot:\t%q", tt.wantErr, gErr.LogError())
   127  			}
   128  		})
   129  	}
   130  }
   131  
   132  func TestSamlSsoURLError(t *testing.T) {
   133  	tests := []struct {
   134  		name    string
   135  		config  string
   136  		label   string
   137  		wantErr string
   138  	}{
   139  		{
   140  			"missing saml definitions",
   141  			`
   142  			server {}
   143  			definitions {
   144  			}
   145  			`,
   146  			"MyLabel",
   147  			`missing saml block with referenced label "MyLabel"`,
   148  		},
   149  		{
   150  			"missing referenced saml",
   151  			`
   152  			server {}
   153  			definitions {
   154  			  saml "MySAML" {
   155  			    idp_metadata_file = "testdata/idp-metadata.xml"
   156  			    sp_entity_id = "the-sp"
   157  			    sp_acs_url = "https://sp.example.com/saml/acs"
   158  			  }
   159  			}
   160  			`,
   161  			"MyLabel",
   162  			`missing saml block with referenced label "MyLabel"`,
   163  		},
   164  	}
   165  
   166  	for _, tt := range tests {
   167  		t.Run(tt.name, func(subT *testing.T) {
   168  			h := test.New(subT)
   169  			couperConf, err := configload.LoadBytes([]byte(tt.config), "test.hcl")
   170  			h.Must(err)
   171  
   172  			ctx, cancel := context.WithCancel(couperConf.Context)
   173  			couperConf.Context = ctx
   174  			defer cancel()
   175  
   176  			evalContext := couperConf.Context.Value(request.ContextType).(*eval.Context)
   177  			req, err := http.NewRequest(http.MethodGet, "https://www.example.com/foo", nil)
   178  			h.Must(err)
   179  			evalContext = evalContext.WithClientRequest(req)
   180  
   181  			_, err = evalContext.HCLContext().Functions[lib.FnSamlSsoURL].Call([]cty.Value{cty.StringVal(tt.label)})
   182  			if err == nil {
   183  				subT.Error("expected an error, got nothing")
   184  				return
   185  			}
   186  			if err.Error() != tt.wantErr {
   187  				subT.Errorf("\nWant:\t%q\nGot:\t%q", tt.wantErr, err.Error())
   188  			}
   189  		})
   190  	}
   191  }