github.com/avenga/couper@v1.12.2/eval/lib/saml_test.go (about) 1 package lib_test 2 3 import ( 4 "bytes" 5 "compress/flate" 6 "context" 7 "encoding/base64" 8 "encoding/xml" 9 "io" 10 "net/http" 11 "net/url" 12 "strings" 13 "testing" 14 15 "github.com/zclconf/go-cty/cty" 16 17 "github.com/avenga/couper/config/configload" 18 "github.com/avenga/couper/config/request" 19 "github.com/avenga/couper/errors" 20 "github.com/avenga/couper/eval" 21 "github.com/avenga/couper/eval/lib" 22 "github.com/avenga/couper/internal/test" 23 ) 24 25 func Test_SamlSsoURL(t *testing.T) { 26 tests := []struct { 27 name string 28 hcl string 29 samlLabel string 30 wantPfx string 31 }{ 32 { 33 "metadata found", 34 ` 35 server "test" { 36 } 37 definitions { 38 saml "MySAML" { 39 idp_metadata_file = "testdata/idp-metadata.xml" 40 sp_entity_id = "the-sp" 41 sp_acs_url = "https://sp.example.com/saml/acs" 42 array_attributes = ["memberOf"] 43 } 44 } 45 `, 46 "MySAML", 47 "https://idp.example.org/saml/SSOService", 48 }, 49 } 50 for _, tt := range tests { 51 t.Run(tt.name, func(subT *testing.T) { 52 h := test.New(subT) 53 cf, err := configload.LoadBytes([]byte(tt.hcl), "couper.hcl") 54 if err != nil { 55 h.Must(err) 56 } 57 58 evalContext := cf.Context.Value(request.ContextType).(*eval.Context) 59 req, err := http.NewRequest(http.MethodGet, "https://www.example.com/foo", nil) 60 h.Must(err) 61 evalContext = evalContext.WithClientRequest(req) 62 63 ssoURL, err := evalContext.HCLContext().Functions[lib.FnSamlSsoURL].Call([]cty.Value{cty.StringVal(tt.samlLabel)}) 64 h.Must(err) 65 66 if !strings.HasPrefix(ssoURL.AsString(), tt.wantPfx) { 67 subT.Errorf("Expected to start with %q, got: %#v", tt.wantPfx, ssoURL.AsString()) 68 } 69 70 u, err := url.Parse(ssoURL.AsString()) 71 h.Must(err) 72 73 q := u.Query() 74 samlRequest := q.Get("SAMLRequest") 75 if samlRequest == "" { 76 subT.Fatal("Expected SAMLRequest query param") 77 } 78 79 b64Decoded, err := base64.StdEncoding.DecodeString(samlRequest) 80 h.Must(err) 81 82 fr := flate.NewReader(bytes.NewReader(b64Decoded)) 83 deflated, err := io.ReadAll(fr) 84 h.Must(err) 85 86 var x interface{} 87 err = xml.Unmarshal(deflated, &x) 88 h.Must(err) 89 }) 90 } 91 } 92 93 func TestSamlConfigError(t *testing.T) { 94 tests := []struct { 95 name string 96 config string 97 label string 98 wantErr string 99 }{ 100 { 101 "missing referenced saml IdP metadata", 102 ` 103 server {} 104 definitions { 105 saml "MySAML" { 106 idp_metadata_file = "/not/there" 107 sp_entity_id = "the-sp" 108 sp_acs_url = "https://sp.example.com/saml/acs" 109 } 110 } 111 `, 112 "MyLabel", 113 "configuration error: MySAML: saml2 idp_metadata_file: read error: open /not/there: no such file or directory", 114 }, 115 } 116 117 for _, tt := range tests { 118 t.Run(tt.name, func(subT *testing.T) { 119 _, err := configload.LoadBytes([]byte(tt.config), "test.hcl") 120 if err == nil { 121 subT.Error("expected an error, got nothing") 122 return 123 } 124 gErr := err.(errors.GoError) 125 if gErr.LogError() != tt.wantErr { 126 subT.Errorf("\nWant:\t%q\nGot:\t%q", tt.wantErr, gErr.LogError()) 127 } 128 }) 129 } 130 } 131 132 func TestSamlSsoURLError(t *testing.T) { 133 tests := []struct { 134 name string 135 config string 136 label string 137 wantErr string 138 }{ 139 { 140 "missing saml definitions", 141 ` 142 server {} 143 definitions { 144 } 145 `, 146 "MyLabel", 147 `missing saml block with referenced label "MyLabel"`, 148 }, 149 { 150 "missing referenced saml", 151 ` 152 server {} 153 definitions { 154 saml "MySAML" { 155 idp_metadata_file = "testdata/idp-metadata.xml" 156 sp_entity_id = "the-sp" 157 sp_acs_url = "https://sp.example.com/saml/acs" 158 } 159 } 160 `, 161 "MyLabel", 162 `missing saml block with referenced label "MyLabel"`, 163 }, 164 } 165 166 for _, tt := range tests { 167 t.Run(tt.name, func(subT *testing.T) { 168 h := test.New(subT) 169 couperConf, err := configload.LoadBytes([]byte(tt.config), "test.hcl") 170 h.Must(err) 171 172 ctx, cancel := context.WithCancel(couperConf.Context) 173 couperConf.Context = ctx 174 defer cancel() 175 176 evalContext := couperConf.Context.Value(request.ContextType).(*eval.Context) 177 req, err := http.NewRequest(http.MethodGet, "https://www.example.com/foo", nil) 178 h.Must(err) 179 evalContext = evalContext.WithClientRequest(req) 180 181 _, err = evalContext.HCLContext().Functions[lib.FnSamlSsoURL].Call([]cty.Value{cty.StringVal(tt.label)}) 182 if err == nil { 183 subT.Error("expected an error, got nothing") 184 return 185 } 186 if err.Error() != tt.wantErr { 187 subT.Errorf("\nWant:\t%q\nGot:\t%q", tt.wantErr, err.Error()) 188 } 189 }) 190 } 191 }