github.com/greenpau/go-authcrunch@v1.1.4/pkg/idp/saml/authenticate.go (about)

     1  // Copyright 2022 Paul Greenberg greenpau@outlook.com
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package saml
    16  
    17  import (
    18  	"encoding/base64"
    19  	"fmt"
    20  	"github.com/greenpau/go-authcrunch/pkg/requests"
    21  	"strconv"
    22  	"strings"
    23  	"time"
    24  
    25  	"go.uber.org/zap"
    26  )
    27  
    28  // Authenticate performs authentication.
    29  func (b *IdentityProvider) Authenticate(r *requests.Request) error {
    30  	r.Response.Code = 400
    31  	if r.Upstream.Request.Method != "POST" {
    32  		r.Response.Code = 302
    33  		r.Response.RedirectURL = b.loginURL
    34  		return nil
    35  	}
    36  
    37  	if 500 > r.Upstream.Request.ContentLength || r.Upstream.Request.ContentLength > 30000 {
    38  		return fmt.Errorf("request payload is not 500 to 300000 bytes: %d", r.Upstream.Request.ContentLength)
    39  	}
    40  	contentType := r.Upstream.Request.Header.Get("Content-Type")
    41  	if contentType != "application/x-www-form-urlencoded" {
    42  		return fmt.Errorf("request content type is not application/x-www-form-urlencoded")
    43  	}
    44  	if err := r.Upstream.Request.ParseForm(); err != nil {
    45  		return fmt.Errorf("failed to parse form: %v", err)
    46  	}
    47  	if r.Upstream.Request.FormValue("SAMLResponse") == "" {
    48  		return fmt.Errorf("request from has no SAMLResponse field")
    49  	}
    50  	samlResponseBytes, err := base64.StdEncoding.DecodeString(r.Upstream.Request.FormValue("SAMLResponse"))
    51  	if err != nil {
    52  		return fmt.Errorf("failed to decode SAMLResponse: %v", err)
    53  	}
    54  	acsURL := ""
    55  	s := string(samlResponseBytes)
    56  	for _, elem := range []string{"Destination=\""} {
    57  		i := strings.Index(s, elem)
    58  		if i < 0 {
    59  			continue
    60  		}
    61  		j := strings.Index(s[i+len(elem):], "\"")
    62  		if j < 0 {
    63  			continue
    64  		}
    65  		acsURL = s[i+len(elem) : i+len(elem)+j]
    66  	}
    67  
    68  	if acsURL == "" {
    69  		return fmt.Errorf("failed to parse ACS URL")
    70  	}
    71  
    72  	if b.config.Driver == "azure" {
    73  		if !strings.Contains(r.Upstream.Request.Header.Get("Origin"), "login.microsoftonline.com") && !strings.Contains(r.Upstream.Request.Header.Get("Referer"), "windowsazure.com") {
    74  			return fmt.Errorf("Origin does not contain login.microsoftonline.com and Referer is not windowsazure.com")
    75  		}
    76  	}
    77  
    78  	sp, serviceProviderExists := b.serviceProviders[acsURL]
    79  	if !serviceProviderExists {
    80  		return fmt.Errorf("unsupported ACS URL %s", acsURL)
    81  	}
    82  
    83  	samlAssertions, err := sp.ParseXMLResponse(samlResponseBytes, []string{""})
    84  	if err != nil {
    85  		return fmt.Errorf("failed to ParseXMLResponse: %s", err)
    86  	}
    87  
    88  	m := make(map[string]interface{})
    89  	metadata := make(map[string]interface{})
    90  	for _, attrStatement := range samlAssertions.AttributeStatements {
    91  
    92  		for _, attrEntry := range attrStatement.Attributes {
    93  			if len(attrEntry.Values) == 0 {
    94  				continue
    95  			}
    96  			switch {
    97  			case strings.HasSuffix(attrEntry.Name, "Attributes/MaxSessionDuration"):
    98  				multiplier, err := strconv.Atoi(attrEntry.Values[0].Value)
    99  				if err != nil {
   100  					b.logger.Error(
   101  						"Failed parsing Attributes/MaxSessionDuration",
   102  						zap.String("request_id", r.ID),
   103  						zap.String("error", err.Error()),
   104  					)
   105  					continue
   106  				}
   107  				m["exp"] = time.Now().Add(time.Duration(multiplier) * time.Second).Unix()
   108  			case strings.HasSuffix(attrEntry.Name, "identity/claims/displayname"):
   109  				if attrEntry.Values[0].Value != "" {
   110  					m["name"] = attrEntry.Values[0].Value
   111  				}
   112  			case strings.HasSuffix(attrEntry.Name, "identity/claims/emailaddress"):
   113  				if attrEntry.Values[0].Value != "" {
   114  					m["email"] = attrEntry.Values[0].Value
   115  				}
   116  			case strings.HasSuffix(attrEntry.Name, "identity/claims/identityprovider"):
   117  				if attrEntry.Values[0].Value != "" {
   118  					m["origin"] = attrEntry.Values[0].Value
   119  				}
   120  			case strings.HasSuffix(attrEntry.Name, "schemas.microsoft.com/identity/claims/objectidentifier"):
   121  				if attrEntry.Values[0].Value != "" {
   122  					metadata["oid"] = attrEntry.Values[0].Value
   123  				}
   124  			case strings.HasSuffix(attrEntry.Name, "schemas.xmlsoap.org/ws/2005/05/identity/claims/upn"):
   125  				if attrEntry.Values[0].Value != "" {
   126  					metadata["upn"] = attrEntry.Values[0].Value
   127  				}
   128  			case strings.HasSuffix(attrEntry.Name, "identity/claims/name"):
   129  				if attrEntry.Values[0].Value != "" {
   130  					m["sub"] = attrEntry.Values[0].Value
   131  				}
   132  			case strings.HasSuffix(attrEntry.Name, "Attributes/Role"):
   133  				roles := []string{}
   134  				for _, attrEntryElement := range attrEntry.Values {
   135  					roles = append(roles, attrEntryElement.Value)
   136  				}
   137  				if len(roles) > 0 {
   138  					m["roles"] = roles
   139  				}
   140  			}
   141  		}
   142  	}
   143  
   144  	for _, k := range []string{"email", "name"} {
   145  		if _, exists := m[k]; !exists {
   146  			return fmt.Errorf("SAML authorization failed, mandatory %s attribute not found: %v", k, m)
   147  		}
   148  	}
   149  
   150  	if len(metadata) > 0 {
   151  		m["metadata"] = metadata
   152  	}
   153  
   154  	r.Response.Code = 200
   155  	r.Response.Payload = m
   156  	return nil
   157  }