github.com/crewjam/saml@v0.4.14/samlsp/fetch_metadata.go (about)

     1  package samlsp
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/xml"
     7  	"errors"
     8  	"io"
     9  	"net/http"
    10  	"net/url"
    11  
    12  	"github.com/crewjam/httperr"
    13  	xrv "github.com/mattermost/xml-roundtrip-validator"
    14  
    15  	"github.com/crewjam/saml/logger"
    16  
    17  	"github.com/crewjam/saml"
    18  )
    19  
    20  // ParseMetadata parses arbitrary SAML IDP metadata.
    21  //
    22  // Note: this is needed because IDP metadata is sometimes wrapped in
    23  // an <EntitiesDescriptor>, and sometimes the top level element is an
    24  // <EntityDescriptor>.
    25  func ParseMetadata(data []byte) (*saml.EntityDescriptor, error) {
    26  	entity := &saml.EntityDescriptor{}
    27  
    28  	if err := xrv.Validate(bytes.NewBuffer(data)); err != nil {
    29  		return nil, err
    30  	}
    31  
    32  	err := xml.Unmarshal(data, entity)
    33  
    34  	// this comparison is ugly, but it is how the error is generated in encoding/xml
    35  	if err != nil && err.Error() == "expected element type <EntityDescriptor> but have <EntitiesDescriptor>" {
    36  		entities := &saml.EntitiesDescriptor{}
    37  		if err := xml.Unmarshal(data, entities); err != nil {
    38  			return nil, err
    39  		}
    40  
    41  		for i, e := range entities.EntityDescriptors {
    42  			if len(e.IDPSSODescriptors) > 0 {
    43  				return &entities.EntityDescriptors[i], nil
    44  			}
    45  		}
    46  		return nil, errors.New("no entity found with IDPSSODescriptor")
    47  	}
    48  	if err != nil {
    49  		return nil, err
    50  	}
    51  	return entity, nil
    52  }
    53  
    54  // FetchMetadata returns metadata from an IDP metadata URL.
    55  func FetchMetadata(ctx context.Context, httpClient *http.Client, metadataURL url.URL) (*saml.EntityDescriptor, error) {
    56  	req, err := http.NewRequest("GET", metadataURL.String(), nil)
    57  	if err != nil {
    58  		return nil, err
    59  	}
    60  	req = req.WithContext(ctx)
    61  
    62  	resp, err := httpClient.Do(req)
    63  	if err != nil {
    64  		return nil, err
    65  	}
    66  	defer func() {
    67  		if err := resp.Body.Close(); err != nil {
    68  			logger.DefaultLogger.Printf("Error while closing response body during fetch metadata: %v", err)
    69  		}
    70  	}()
    71  	if resp.StatusCode >= 400 {
    72  		return nil, httperr.Response(*resp)
    73  	}
    74  
    75  	data, err := io.ReadAll(resp.Body)
    76  	if err != nil {
    77  		return nil, err
    78  	}
    79  
    80  	return ParseMetadata(data)
    81  }