github.com/versent/saml2aws@v2.17.0+incompatible/saml.go (about)

     1  package saml2aws
     2  
import (
	"fmt"
	"strconv"
	"strings"

	"github.com/beevik/etree"
)
     9  
    10  const (
    11  	assertionTag          = "Assertion"
    12  	attributeStatementTag = "AttributeStatement"
    13  	attributeTag          = "Attribute"
    14  	attributeValueTag     = "AttributeValue"
    15  )
    16  
    17  //ErrMissingElement is the error type that indicates an element and/or attribute is
    18  //missing. It provides a structured error that can be more appropriately acted
    19  //upon.
    20  type ErrMissingElement struct {
    21  	Tag, Attribute string
    22  }
    23  
    24  //ErrMissingAssertion indicates that an appropriate assertion element could not
    25  //be found in the SAML Response
    26  var (
    27  	ErrMissingAssertion = ErrMissingElement{Tag: assertionTag}
    28  )
    29  
    30  func (e ErrMissingElement) Error() string {
    31  	if e.Attribute != "" {
    32  		return fmt.Sprintf("missing %s attribute on %s element", e.Attribute, e.Tag)
    33  	}
    34  	return fmt.Sprintf("missing %s element", e.Tag)
    35  }
    36  
    37  // ExtractSessionDuration this will attempt to extract a session duration from the assertion
    38  // see https://aws.amazon.com/SAML/Attributes/SessionDuration
    39  func ExtractSessionDuration(data []byte) (int64, error) {
    40  
    41  	doc := etree.NewDocument()
    42  	if err := doc.ReadFromBytes(data); err != nil {
    43  		return 0, err
    44  	}
    45  
    46  	assertionElement := doc.FindElement(".//Assertion")
    47  	if assertionElement == nil {
    48  		return 0, ErrMissingAssertion
    49  	}
    50  
    51  	// log.Printf("tag: %s", assertionElement.Tag)
    52  
    53  	//Get the actual assertion attributes
    54  	attributeStatement := assertionElement.FindElement(childPath(assertionElement.Space, attributeStatementTag))
    55  	if attributeStatement == nil {
    56  		return 0, ErrMissingElement{Tag: attributeStatementTag}
    57  	}
    58  
    59  	attributes := attributeStatement.FindElements(childPath(assertionElement.Space, attributeTag))
    60  
    61  	for _, attribute := range attributes {
    62  		if attribute.SelectAttrValue("Name", "") != "https://aws.amazon.com/SAML/Attributes/SessionDuration" {
    63  			continue
    64  		}
    65  		atributeValues := attribute.FindElements(childPath(assertionElement.Space, attributeValueTag))
    66  		for _, attrValue := range atributeValues {
    67  			return strconv.ParseInt(attrValue.Text(), 10, 64)
    68  		}
    69  	}
    70  
    71  	return 0, nil
    72  }
    73  
    74  // ExtractAwsRoles given an assertion document extract the aws roles
    75  func ExtractAwsRoles(data []byte) ([]string, error) {
    76  
    77  	awsroles := []string{}
    78  
    79  	doc := etree.NewDocument()
    80  	if err := doc.ReadFromBytes(data); err != nil {
    81  		return awsroles, err
    82  	}
    83  
    84  	// log.Printf("root tag: %s", doc.Root().Tag)
    85  
    86  	assertionElement := doc.FindElement(".//Assertion")
    87  	if assertionElement == nil {
    88  		return nil, ErrMissingAssertion
    89  	}
    90  
    91  	// log.Printf("tag: %s", assertionElement.Tag)
    92  
    93  	//Get the actual assertion attributes
    94  	attributeStatement := assertionElement.FindElement(childPath(assertionElement.Space, attributeStatementTag))
    95  	if attributeStatement == nil {
    96  		return nil, ErrMissingElement{Tag: attributeStatementTag}
    97  	}
    98  
    99  	// log.Printf("tag: %s", attributeStatement.Tag)
   100  
   101  	attributes := attributeStatement.FindElements(childPath(assertionElement.Space, attributeTag))
   102  	for _, attribute := range attributes {
   103  		if attribute.SelectAttrValue("Name", "") != "https://aws.amazon.com/SAML/Attributes/Role" {
   104  			continue
   105  		}
   106  		atributeValues := attribute.FindElements(childPath(assertionElement.Space, attributeValueTag))
   107  		for _, attrValue := range atributeValues {
   108  			awsroles = append(awsroles, attrValue.Text())
   109  		}
   110  	}
   111  
   112  	return awsroles, nil
   113  }
   114  
   115  func childPath(space, tag string) string {
   116  	if space == "" {
   117  		return "./" + tag
   118  	}
   119  	//log.Printf("query = %s", "./"+space+":"+tag)
   120  	return "./" + space + ":" + tag
   121  }