github.com/avenga/couper@v1.12.2/eval/lib/saml.go (about)

     1  package lib
     2  
     3  import (
     4  	"encoding/xml"
     5  	"fmt"
     6  	"net/url"
     7  
     8  	saml2 "github.com/russellhaering/gosaml2"
     9  	"github.com/russellhaering/gosaml2/types"
    10  	"github.com/zclconf/go-cty/cty"
    11  	"github.com/zclconf/go-cty/cty/function"
    12  
    13  	"github.com/avenga/couper/config"
    14  )
    15  
    16  const (
    17  	FnSamlSsoURL            = "saml_sso_url"
    18  	NameIDFormatUnspecified = "urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified"
    19  )
    20  
    21  var NoOpSamlSsoURLFunction = function.New(&function.Spec{
    22  	Params: []function.Parameter{
    23  		{
    24  			Name: "saml_label",
    25  			Type: cty.String,
    26  		},
    27  	},
    28  	Type: function.StaticReturnType(cty.String),
    29  	Impl: func(args []cty.Value, _ cty.Type) (ret cty.Value, err error) {
    30  		if len(args) > 0 {
    31  			return cty.StringVal(""), fmt.Errorf("missing saml block with referenced label %q", args[0].AsString())
    32  		}
    33  		return cty.StringVal(""), fmt.Errorf("missing saml definitions")
    34  	},
    35  })
    36  
    37  func NewSamlSsoURLFunction(configs []*config.SAML, origin *url.URL) function.Function {
    38  	type entity struct {
    39  		config     *config.SAML
    40  		descriptor *types.EntityDescriptor
    41  		err        error
    42  	}
    43  
    44  	samlEntities := make(map[string]*entity)
    45  	for _, conf := range configs {
    46  		metadata := &types.EntityDescriptor{}
    47  		err := xml.Unmarshal(conf.MetadataBytes, metadata)
    48  		samlEntities[conf.Name] = &entity{
    49  			config:     conf,
    50  			descriptor: metadata,
    51  			err:        err,
    52  		}
    53  	}
    54  
    55  	return function.New(&function.Spec{
    56  		Params: []function.Parameter{
    57  			{
    58  				Name: "saml_label",
    59  				Type: cty.String,
    60  			},
    61  		},
    62  		Type: function.StaticReturnType(cty.String),
    63  		Impl: func(args []cty.Value, _ cty.Type) (ret cty.Value, err error) {
    64  			label := args[0].AsString()
    65  			ent, exist := samlEntities[label]
    66  			if !exist {
    67  				return NoOpSamlSsoURLFunction.Call(args)
    68  			}
    69  
    70  			metadata := ent.descriptor
    71  			var ssoURL string
    72  			for _, ssoService := range metadata.IDPSSODescriptor.SingleSignOnServices {
    73  				if ssoService.Binding == "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" {
    74  					ssoURL = ssoService.Location
    75  					continue
    76  				}
    77  			}
    78  
    79  			nameIDFormat := getNameIDFormat(metadata.IDPSSODescriptor.NameIDFormats)
    80  
    81  			absAcsURL, err := AbsoluteURL(ent.config.SpAcsURL, origin)
    82  			if err != nil {
    83  				return cty.StringVal(""), err
    84  			}
    85  
    86  			sp := &saml2.SAMLServiceProvider{
    87  				AssertionConsumerServiceURL: absAcsURL,
    88  				IdentityProviderSSOURL:      ssoURL,
    89  				ServiceProviderIssuer:       ent.config.SpEntityID,
    90  				SignAuthnRequests:           false,
    91  			}
    92  			if nameIDFormat != "" {
    93  				sp.NameIdFormat = nameIDFormat
    94  			}
    95  
    96  			samlSsoURL, err := sp.BuildAuthURL("")
    97  			if err != nil {
    98  				return cty.StringVal(""), err
    99  			}
   100  
   101  			return cty.StringVal(samlSsoURL), nil
   102  		},
   103  	})
   104  }
   105  
   106  func getNameIDFormat(supportedNameIDFormats []types.NameIDFormat) string {
   107  	nameIDFormat := ""
   108  	if isSupportedNameIDFormat(supportedNameIDFormats, NameIDFormatUnspecified) {
   109  		nameIDFormat = NameIDFormatUnspecified
   110  	} else if len(supportedNameIDFormats) > 0 {
   111  		nameIDFormat = supportedNameIDFormats[0].Value
   112  	}
   113  	return nameIDFormat
   114  }
   115  
   116  func isSupportedNameIDFormat(supportedNameIDFormats []types.NameIDFormat, nameIDFormat string) bool {
   117  	for _, n := range supportedNameIDFormats {
   118  		if n.Value == nameIDFormat {
   119  			return true
   120  		}
   121  	}
   122  	return false
   123  }