github.com/in4it/ecs-deploy@v0.0.42-0.20240508120354-ed77ff16df25/api/saml.go (about)

     1  package api
     2  
     3  // saml gin-gonic implementation
     4  // uses saml/samlsp with parts rewritten to make it work with gin-gonic and gin-jwt
     5  
     6  import (
     7  	"context"
     8  
     9  	jwt "github.com/golang-jwt/jwt/v4"
    10  
    11  	"github.com/crewjam/saml"
    12  	"github.com/crewjam/saml/samlsp"
    13  	"github.com/gin-contrib/location"
    14  	"github.com/gin-gonic/gin"
    15  	"github.com/in4it/ecs-deploy/util"
    16  	"github.com/juju/loggo"
    17  
    18  	"crypto/rsa"
    19  	"crypto/tls"
    20  	"crypto/x509"
    21  	"encoding/base64"
    22  	"fmt"
    23  	"net/http"
    24  	"net/url"
    25  	"strings"
    26  	"time"
    27  )
    28  
    29  // logging
    30  var samlLogger = loggo.GetLogger("saml")
    31  
    32  // jwt signing method
    33  var jwtSigningMethod = jwt.SigningMethodHS256
    34  
// SAML holds the state of the gin-gonic SAML service-provider integration.
// Instances are built by newSAML.
type SAML struct {
	// idpMetadataURL is the parsed URL newSAML fetches the IdP metadata from.
	idpMetadataURL    *url.URL
	// sp is the crewjam/saml service provider constructed by newSAML.
	sp                saml.ServiceProvider
	// AllowIDPInitiated, when true, makes getPossibleRequestIDs also accept an
	// empty request ID, i.e. responses not preceded by an SP-issued request.
	AllowIDPInitiated bool
	// TimeFunc supplies "now" for the JWT exp/orig_iat claims; newSAML
	// defaults it to time.Now.
	TimeFunc          func() time.Time
}
    41  
    42  func randomBytes(n int) []byte {
    43  	rv := make([]byte, n)
    44  	if _, err := saml.RandReader.Read(rv); err != nil {
    45  		panic(err)
    46  	}
    47  	return rv
    48  }
    49  
    50  func newSAML(strIdpMetadataURL string, X509KeyPair, keyPEMBlock []byte) (*SAML, error) {
    51  	s := SAML{}
    52  	var err error
    53  
    54  	if s.TimeFunc == nil {
    55  		s.TimeFunc = time.Now
    56  	}
    57  
    58  	keyPair, err := tls.X509KeyPair(X509KeyPair, keyPEMBlock)
    59  	if err != nil {
    60  		// try to fix AWS paramstore newlines missing
    61  		str := string(keyPEMBlock)
    62  		str = strings.Replace(str, "-----BEGIN PRIVATE KEY----- ", "", -1)
    63  		str = strings.Replace(str, " -----END PRIVATE KEY-----", "", -1)
    64  		str = strings.Replace(str, " ", "\n", -1)
    65  		str = "-----BEGIN PRIVATE KEY-----\n" + str + "\n-----END PRIVATE KEY-----"
    66  		keyPair, err = tls.X509KeyPair(X509KeyPair, []byte(str))
    67  		if err != nil {
    68  			return nil, err
    69  		}
    70  	}
    71  	keyPair.Leaf, err = x509.ParseCertificate(keyPair.Certificate[0])
    72  	if err != nil {
    73  		return nil, err
    74  	}
    75  
    76  	s.idpMetadataURL, err = url.Parse(strIdpMetadataURL)
    77  	if err != nil {
    78  		return nil, err
    79  	}
    80  
    81  	rootURL, err := url.Parse(util.GetEnv("SAML_ACS_URL", ""))
    82  	if err != nil {
    83  		return nil, err
    84  	}
    85  
    86  	idpMetadataURL := *s.idpMetadataURL
    87  
    88  	idpMetadata, err := samlsp.FetchMetadata(
    89  		context.Background(),
    90  		http.DefaultClient,
    91  		idpMetadataURL)
    92  
    93  	samlSP, _ := samlsp.New(samlsp.Options{
    94  		URL:         *rootURL,
    95  		Key:         keyPair.PrivateKey.(*rsa.PrivateKey),
    96  		Certificate: keyPair.Leaf,
    97  		IDPMetadata: idpMetadata,
    98  	})
    99  
   100  	s.sp = samlSP.ServiceProvider
   101  	s.AllowIDPInitiated = true
   102  
   103  	return &s, nil
   104  }
   105  
   106  func (s *SAML) getIDPSSOURL() string {
   107  	return s.idpMetadataURL.String()
   108  }
   109  
   110  func (s *SAML) getIDPCert() string {
   111  	return "cert"
   112  }
   113  
   114  func (s *SAML) samlEnabledHandler(c *gin.Context) {
   115  	if util.GetEnv("SAML_ENABLED", "") == "yes" {
   116  		c.JSON(200, gin.H{
   117  			"saml": "enabled",
   118  		})
   119  	} else {
   120  		c.JSON(200, gin.H{
   121  			"saml": "disabled",
   122  		})
   123  	}
   124  }
   125  func (s *SAML) samlResponseHandler(c *gin.Context) {
   126  	assertion, err := s.sp.ParseResponse(c.Request, s.getPossibleRequestIDs(c))
   127  	if err != nil {
   128  		if parseErr, ok := err.(*saml.InvalidResponseError); ok {
   129  			samlLogger.Errorf("RESPONSE: ===\n%s\n===\nNOW: %s\nERROR: %s", parseErr.Response, parseErr.Now, parseErr.PrivateErr)
   130  		}
   131  		c.JSON(http.StatusForbidden, gin.H{
   132  			"error": http.StatusText(http.StatusForbidden),
   133  		})
   134  		return
   135  	}
   136  	// auth OK, create jwt token
   137  	token := jwt.New(jwtSigningMethod)
   138  	claims := token.Claims.(jwt.MapClaims)
   139  	expire := s.TimeFunc().UTC().Add(time.Hour)
   140  	claims["id"] = assertion.Subject.NameID.Value
   141  	claims["exp"] = expire.Unix()
   142  	claims["orig_iat"] = s.TimeFunc().Unix()
   143  
   144  	tokenString, err := token.SignedString([]byte(util.GetEnv("JWT_SECRET", "unsecure secret key 8a045eb")))
   145  
   146  	if err != nil {
   147  		c.JSON(http.StatusInternalServerError, gin.H{
   148  			"error": err.Error(),
   149  		})
   150  		return
   151  	}
   152  	// redirect to UI with jwt token
   153  	c.Redirect(http.StatusFound, util.GetEnv("URL_PREFIX", "")+"/webapp/saml?token="+tokenString)
   154  }
   155  
   156  // samlsp/middleware.go adapted for gin gonic
   157  func (s *SAML) samlInitHandler(c *gin.Context) {
   158  	if c.PostForm("SAMLResponse") != "" {
   159  		s.samlResponseHandler(c)
   160  		return
   161  	}
   162  	url := location.Get(c)
   163  
   164  	binding := saml.HTTPRedirectBinding
   165  	bindingLocation := s.sp.GetSSOBindingLocation(binding)
   166  	if bindingLocation == "" {
   167  		binding = saml.HTTPPostBinding
   168  		bindingLocation = s.sp.GetSSOBindingLocation(binding)
   169  	}
   170  
   171  	req, err := s.sp.MakeAuthenticationRequest(bindingLocation, binding, saml.HTTPPostBinding)
   172  	if err != nil {
   173  		c.JSON(http.StatusInternalServerError, gin.H{
   174  			"error": err.Error(),
   175  		})
   176  		return
   177  	}
   178  	relayState := base64.URLEncoding.EncodeToString(randomBytes(42))
   179  
   180  	secretBlock := x509.MarshalPKCS1PrivateKey(s.sp.Key)
   181  	state := jwt.New(jwtSigningMethod)
   182  	claims := state.Claims.(jwt.MapClaims)
   183  	claims["id"] = req.ID
   184  	claims["uri"] = url.Scheme + url.Host + url.Path
   185  	signedState, err := state.SignedString(secretBlock)
   186  	if err != nil {
   187  		c.JSON(http.StatusInternalServerError, gin.H{
   188  			"error": err.Error(),
   189  		})
   190  		return
   191  	}
   192  
   193  	http.SetCookie(c.Writer, &http.Cookie{
   194  		Name:     fmt.Sprintf("saml_%s", relayState),
   195  		Value:    signedState,
   196  		MaxAge:   int(saml.MaxIssueDelay.Seconds()),
   197  		HttpOnly: true,
   198  		Secure:   url.Scheme == "https",
   199  		Path:     s.sp.AcsURL.Path,
   200  	})
   201  
   202  	if binding == saml.HTTPRedirectBinding {
   203  		redirectURL, err := req.Redirect(relayState, &s.sp)
   204  		if err != nil {
   205  			c.JSON(http.StatusInternalServerError, gin.H{
   206  				"error": err.Error(),
   207  			})
   208  			return
   209  		}
   210  		c.Redirect(http.StatusFound, redirectURL.String())
   211  		return
   212  	}
   213  	if binding == saml.HTTPPostBinding {
   214  		c.Writer.Header().Add("Content-Security-Policy", ""+
   215  			"default-src; "+
   216  			"script-src 'sha256-AjPdJSbZmeWHnEc5ykvJFay8FTWeTeRbs9dutfZ0HqE='; "+
   217  			"reflected-xss block; referrer no-referrer;")
   218  		c.Writer.Header().Add("Content-type", "text/html")
   219  		c.Writer.Write([]byte(`<!DOCTYPE html><html><body>`))
   220  		c.Writer.Write(req.Post(relayState))
   221  		c.Writer.Write([]byte(`</body></html>`))
   222  		return
   223  	}
   224  	panic("no saml binding found")
   225  }
   226  
   227  func (s *SAML) getPossibleRequestIDs(c *gin.Context) []string {
   228  	rv := []string{}
   229  	for _, cookie := range c.Request.Cookies() {
   230  		if !strings.HasPrefix(cookie.Name, "saml_") {
   231  			continue
   232  		}
   233  		samlLogger.Debugf("getPossibleRequestIDs: cookie: %s", cookie.String())
   234  
   235  		jwtParser := jwt.Parser{
   236  			ValidMethods: []string{jwtSigningMethod.Name},
   237  		}
   238  		token, err := jwtParser.Parse(cookie.Value, func(t *jwt.Token) (interface{}, error) {
   239  			secretBlock := x509.MarshalPKCS1PrivateKey(s.sp.Key)
   240  			return secretBlock, nil
   241  		})
   242  		if err != nil || !token.Valid {
   243  			samlLogger.Debugf("... invalid token %s", err)
   244  			continue
   245  		}
   246  		claims := token.Claims.(jwt.MapClaims)
   247  		rv = append(rv, claims["id"].(string))
   248  	}
   249  
   250  	// If IDP initiated requests are allowed, then we can expect an empty response ID.
   251  	if s.AllowIDPInitiated {
   252  		rv = append(rv, "")
   253  	}
   254  
   255  	return rv
   256  }